diff --git a/store/driver.go b/store/driver.go index 6ad9cc04..f0f6fea3 100644 --- a/store/driver.go +++ b/store/driver.go @@ -14,6 +14,10 @@ type Driver interface { UpdateResource(ctx context.Context, update *UpdateResource) (*Resource, error) DeleteResource(ctx context.Context, delete *DeleteResource) error + UpsertMemoRelation(ctx context.Context, create *MemoRelation) (*MemoRelation, error) + ListMemoRelations(ctx context.Context, find *FindMemoRelation) ([]*MemoRelation, error) + DeleteMemoRelation(ctx context.Context, delete *DeleteMemoRelation) error + UpsertSystemSetting(ctx context.Context, upsert *SystemSetting) (*SystemSetting, error) ListSystemSettings(ctx context.Context, find *FindSystemSetting) ([]*SystemSetting, error) diff --git a/store/memo_relation.go b/store/memo_relation.go index a411a775..35106252 100644 --- a/store/memo_relation.go +++ b/store/memo_relation.go @@ -3,7 +3,6 @@ package store import ( "context" "database/sql" - "strings" ) type MemoRelationType string @@ -32,77 +31,11 @@ type DeleteMemoRelation struct { } func (s *Store) UpsertMemoRelation(ctx context.Context, create *MemoRelation) (*MemoRelation, error) { - stmt := ` - INSERT INTO memo_relation ( - memo_id, - related_memo_id, - type - ) - VALUES (?, ?, ?) - ON CONFLICT (memo_id, related_memo_id, type) DO UPDATE SET - type = EXCLUDED.type - RETURNING memo_id, related_memo_id, type - ` - memoRelation := &MemoRelation{} - if err := s.db.QueryRowContext( - ctx, - stmt, - create.MemoID, - create.RelatedMemoID, - create.Type, - ).Scan( - &memoRelation.MemoID, - &memoRelation.RelatedMemoID, - &memoRelation.Type, - ); err != nil { - return nil, err - } - - return memoRelation, nil + return s.driver.UpsertMemoRelation(ctx, create) } func (s *Store) ListMemoRelations(ctx context.Context, find *FindMemoRelation) ([]*MemoRelation, error) { - where, args := []string{"TRUE"}, []any{} - if find.MemoID != nil { - where, args = append(where, "memo_id = ?"), append(args, find.MemoID) - } - if find.RelatedMemoID != nil { - where, args = append(where, "related_memo_id = ?"), append(args, find.RelatedMemoID) - } - if find.Type != nil { - where, args = append(where, "type = ?"), append(args, find.Type) - } - - rows, err := s.db.QueryContext(ctx, ` - SELECT - memo_id, - related_memo_id, - type - FROM memo_relation - WHERE `+strings.Join(where, " AND "), args...) - if err != nil { - return nil, err - } - defer rows.Close() - - list := []*MemoRelation{} - for rows.Next() { - memoRelation := &MemoRelation{} - if err := rows.Scan( - &memoRelation.MemoID, - &memoRelation.RelatedMemoID, - &memoRelation.Type, - ); err != nil { - return nil, err - } - list = append(list, memoRelation) - } - - if err := rows.Err(); err != nil { - return nil, err - } - - return list, nil + return s.driver.ListMemoRelations(ctx, find) } func (s *Store) GetMemoRelation(ctx context.Context, find *FindMemoRelation) (*MemoRelation, error) { @@ -119,27 +52,7 @@ func (s *Store) GetMemoRelation(ctx context.Context, find *FindMemoRelation) (*M } func (s *Store) DeleteMemoRelation(ctx context.Context, delete *DeleteMemoRelation) error { - where, args := []string{"TRUE"}, []any{} - if delete.MemoID != nil { - where, args = append(where, "memo_id = ?"), append(args, delete.MemoID) - } - if delete.RelatedMemoID != nil { - where, args = append(where, "related_memo_id = ?"), append(args, delete.RelatedMemoID) - } - if delete.Type != nil { - where, args = append(where, "type = ?"), append(args, delete.Type) - } - stmt := ` - DELETE FROM memo_relation - WHERE ` + strings.Join(where, " AND ") - result, err := s.db.ExecContext(ctx, stmt, args...) - if err != nil { - return err - } - if _, err = result.RowsAffected(); err != nil { - return err - } - return nil + return s.driver.DeleteMemoRelation(ctx, delete) } func vacuumMemoRelations(ctx context.Context, tx *sql.Tx) error { diff --git a/store/sqlite/memo_relation.go b/store/sqlite/memo_relation.go new file mode 100644 index 00000000..86f6de01 --- /dev/null +++ b/store/sqlite/memo_relation.go @@ -0,0 +1,106 @@ +package sqlite + +import ( + "context" + "strings" + + "github.com/usememos/memos/store" +) + +func (d *Driver) UpsertMemoRelation(ctx context.Context, create *store.MemoRelation) (*store.MemoRelation, error) { + stmt := ` + INSERT INTO memo_relation ( + memo_id, + related_memo_id, + type + ) + VALUES (?, ?, ?) + ON CONFLICT (memo_id, related_memo_id, type) DO UPDATE SET + type = EXCLUDED.type + RETURNING memo_id, related_memo_id, type + ` + memoRelation := &store.MemoRelation{} + if err := d.db.QueryRowContext( + ctx, + stmt, + create.MemoID, + create.RelatedMemoID, + create.Type, + ).Scan( + &memoRelation.MemoID, + &memoRelation.RelatedMemoID, + &memoRelation.Type, + ); err != nil { + return nil, err + } + + return memoRelation, nil +} + +func (d *Driver) ListMemoRelations(ctx context.Context, find *store.FindMemoRelation) ([]*store.MemoRelation, error) { + where, args := []string{"TRUE"}, []any{} + if find.MemoID != nil { + where, args = append(where, "memo_id = ?"), append(args, find.MemoID) + } + if find.RelatedMemoID != nil { + where, args = append(where, "related_memo_id = ?"), append(args, find.RelatedMemoID) + } + if find.Type != nil { + where, args = append(where, "type = ?"), append(args, find.Type) + } + + rows, err := d.db.QueryContext(ctx, ` + SELECT + memo_id, + related_memo_id, + type + FROM memo_relation + WHERE `+strings.Join(where, " AND "), args...) + if err != nil { + return nil, err + } + defer rows.Close() + + list := []*store.MemoRelation{} + for rows.Next() { + memoRelation := &store.MemoRelation{} + if err := rows.Scan( + &memoRelation.MemoID, + &memoRelation.RelatedMemoID, + &memoRelation.Type, + ); err != nil { + return nil, err + } + list = append(list, memoRelation) + } + + if err := rows.Err(); err != nil { + return nil, err + } + + return list, nil +} + +func (d *Driver) DeleteMemoRelation(ctx context.Context, delete *store.DeleteMemoRelation) error { + where, args := []string{"TRUE"}, []any{} + if delete.MemoID != nil { + where, args = append(where, "memo_id = ?"), append(args, delete.MemoID) + } + if delete.RelatedMemoID != nil { + where, args = append(where, "related_memo_id = ?"), append(args, delete.RelatedMemoID) + } + if delete.Type != nil { + where, args = append(where, "type = ?"), append(args, delete.Type) + } + stmt := ` + DELETE FROM memo_relation + WHERE ` + strings.Join(where, " AND ") + result, err := d.db.ExecContext(ctx, stmt, args...) + if err != nil { + return err + } + if _, err = result.RowsAffected(); err != nil { + return err + } + return nil +}