From fab8a71fd27263af2aa5300d8301c3512d04a722 Mon Sep 17 00:00:00 2001 From: boojack Date: Tue, 25 Apr 2023 23:27:38 +0800 Subject: [PATCH] feat: implement memo relation store (#1598) * feat: implement memo relation store * chore: update --- store/memo_relation.go | 189 +++++++++++++++++++++++++++++++ store/store.go | 3 + test/store/memo_relation_test.go | 57 ++++++++++ test/test.go | 2 +- 4 files changed, 250 insertions(+), 1 deletion(-) create mode 100644 store/memo_relation.go create mode 100644 test/store/memo_relation_test.go diff --git a/store/memo_relation.go b/store/memo_relation.go new file mode 100644 index 00000000..6fe5c081 --- /dev/null +++ b/store/memo_relation.go @@ -0,0 +1,189 @@ +package store + +import ( + "context" + "database/sql" + "fmt" + "strings" + + "github.com/usememos/memos/common" +) + +type MemoRelationType string + +const ( + MemoRelationReference MemoRelationType = "REFERENCE" + MemoRelationAdditional MemoRelationType = "ADDITIONAL" +) + +type MemoRelationMessage struct { + MemoID int + RelatedMemoID int + Type MemoRelationType +} + +type FindMemoRelationMessage struct { + MemoID *int + RelatedMemoID *int + Type *MemoRelationType +} + +type DeleteMemoRelationMessage struct { + MemoID *int + RelatedMemoID *int + Type *MemoRelationType +} + +func (s *Store) UpsertMemoRelation(ctx context.Context, create *MemoRelationMessage) (*MemoRelationMessage, error) { + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return nil, FormatError(err) + } + defer tx.Rollback() + + query := ` + INSERT INTO memo_relation ( + memo_id, + related_memo_id, + type + ) + VALUES (?, ?, ?) + ON CONFLICT (memo_id, related_memo_id, type) DO UPDATE SET + type = EXCLUDED.type + RETURNING memo_id, related_memo_id, type + ` + memoRelationMessage := &MemoRelationMessage{} + if err := tx.QueryRowContext( + ctx, + query, + create.MemoID, + create.RelatedMemoID, + create.Type, + ).Scan( + &memoRelationMessage.MemoID, + &memoRelationMessage.RelatedMemoID, + &memoRelationMessage.Type, + ); err != nil { + return nil, FormatError(err) + } + if err := tx.Commit(); err != nil { + return nil, FormatError(err) + } + return memoRelationMessage, nil +} + +func (s *Store) ListMemoRelations(ctx context.Context, find *FindMemoRelationMessage) ([]*MemoRelationMessage, error) { + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return nil, FormatError(err) + } + defer tx.Rollback() + + list, err := listMemoRelations(ctx, tx, find) + if err != nil { + return nil, err + } + + return list, nil +} + +func (s *Store) GetMemoRelation(ctx context.Context, find *FindMemoRelationMessage) (*MemoRelationMessage, error) { + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return nil, FormatError(err) + } + defer tx.Rollback() + + list, err := listMemoRelations(ctx, tx, find) + if err != nil { + return nil, err + } + + if len(list) == 0 { + return nil, &common.Error{Code: common.NotFound, Err: fmt.Errorf("not found")} + } + return list[0], nil +} + +func (s *Store) DeleteMemoRelation(ctx context.Context, delete *DeleteMemoRelationMessage) error { + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return FormatError(err) + } + defer tx.Rollback() + + where, args := []string{"TRUE"}, []any{} + if delete.MemoID != nil { + where, args = append(where, "memo_id = ?"), append(args, delete.MemoID) + } + if delete.RelatedMemoID != nil { + where, args = append(where, "related_memo_id = ?"), append(args, delete.RelatedMemoID) + } + if delete.Type != nil { + where, args = append(where, "type = ?"), append(args, delete.Type) + } + + query := ` + DELETE FROM memo_relation + WHERE ` + strings.Join(where, " AND ") + if _, err := tx.ExecContext(ctx, query, args...); err != nil { + return FormatError(err) + } + + if err := tx.Commit(); err != nil { + return FormatError(err) + } + return nil +} + +func listMemoRelations(ctx context.Context, tx *sql.Tx, find *FindMemoRelationMessage) ([]*MemoRelationMessage, error) { + where, args := []string{"TRUE"}, []any{} + if find.MemoID != nil { + where, args = append(where, "memo_id = ?"), append(args, find.MemoID) + } + if find.RelatedMemoID != nil { + where, args = append(where, "related_memo_id = ?"), append(args, find.RelatedMemoID) + } + if find.Type != nil { + where, args = append(where, "type = ?"), append(args, find.Type) + } + + rows, err := tx.QueryContext(ctx, ` + SELECT + memo_id, + related_memo_id, + type + FROM memo_relation + WHERE `+strings.Join(where, " AND "), args...) + if err != nil { + return nil, FormatError(err) + } + defer rows.Close() + + memoRelationMessages := []*MemoRelationMessage{} + for rows.Next() { + memoRelationMessage := &MemoRelationMessage{} + if err := rows.Scan( + &memoRelationMessage.MemoID, + &memoRelationMessage.RelatedMemoID, + &memoRelationMessage.Type, + ); err != nil { + return nil, FormatError(err) + } + memoRelationMessages = append(memoRelationMessages, memoRelationMessage) + } + if err := rows.Err(); err != nil { + return nil, FormatError(err) + } + return memoRelationMessages, nil +} + +func vacuumMemoRelations(ctx context.Context, tx *sql.Tx) error { + if _, err := tx.ExecContext(ctx, ` + DELETE FROM memo_relation + WHERE memo_id NOT IN (SELECT id FROM memo) OR related_memo_id NOT IN (SELECT id FROM memo) + `); err != nil { + return FormatError(err) + } + return nil +} diff --git a/store/store.go b/store/store.go index 50f3afc3..b26f8bf8 100644 --- a/store/store.go +++ b/store/store.go @@ -71,6 +71,9 @@ func vacuum(ctx context.Context, tx *sql.Tx) error { if err := vacuumMemoResource(ctx, tx); err != nil { return err } + if err := vacuumMemoRelations(ctx, tx); err != nil { + return err + } if err := vacuumTag(ctx, tx); err != nil { // Prevent revive warning. return err diff --git a/test/store/memo_relation_test.go b/test/store/memo_relation_test.go new file mode 100644 index 00000000..4db7f9c4 --- /dev/null +++ b/test/store/memo_relation_test.go @@ -0,0 +1,57 @@ +package teststore + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + "github.com/usememos/memos/api" + "github.com/usememos/memos/store" +) + +func TestMemoRelationStore(t *testing.T) { + ctx := context.Background() + ts := NewTestingStore(ctx, t) + user, err := createTestingHostUser(ctx, ts) + require.NoError(t, err) + memoCreate := &api.MemoCreate{ + CreatorID: user.ID, + Content: "test_content", + Visibility: api.Public, + } + memo, err := ts.CreateMemo(ctx, memoCreate) + require.NoError(t, err) + require.Equal(t, memoCreate.Content, memo.Content) + memoCreate = &api.MemoCreate{ + CreatorID: user.ID, + Content: "test_content_2", + Visibility: api.Public, + } + memo2, err := ts.CreateMemo(ctx, memoCreate) + require.NoError(t, err) + require.Equal(t, memoCreate.Content, memo2.Content) + memoRelationMessage := &store.MemoRelationMessage{ + MemoID: memo.ID, + RelatedMemoID: memo2.ID, + Type: store.MemoRelationReference, + } + _, err = ts.UpsertMemoRelation(ctx, memoRelationMessage) + require.NoError(t, err) + memoRelation, err := ts.ListMemoRelations(ctx, &store.FindMemoRelationMessage{ + MemoID: &memo.ID, + }) + require.NoError(t, err) + require.Equal(t, 1, len(memoRelation)) + require.Equal(t, memo2.ID, memoRelation[0].RelatedMemoID) + require.Equal(t, memo.ID, memoRelation[0].MemoID) + require.Equal(t, store.MemoRelationReference, memoRelation[0].Type) + err = ts.DeleteMemo(ctx, &api.MemoDelete{ + ID: memo2.ID, + }) + require.NoError(t, err) + memoRelation, err = ts.ListMemoRelations(ctx, &store.FindMemoRelationMessage{ + MemoID: &memo.ID, + }) + require.NoError(t, err) + require.Equal(t, 0, len(memoRelation)) +} diff --git a/test/test.go b/test/test.go index a70c2b85..64afe786 100644 --- a/test/test.go +++ b/test/test.go @@ -25,7 +25,7 @@ func getUnusedPort() int { func GetTestingProfile(t *testing.T) *profile.Profile { // Get a temporary directory for the test data. dir := t.TempDir() - mode := "prod" + mode := "dev" port := getUnusedPort() return &profile.Profile{ Mode: mode,