From b6564bcd770b22c329aaf8968c5837ac3997d3cd Mon Sep 17 00:00:00 2001 From: boojack Date: Mon, 1 May 2023 16:09:41 +0800 Subject: [PATCH] feat: implement memo relation server (#1618) --- api/memo.go | 11 ++-- api/memo_relation.go | 19 ++++++ server/memo.go | 25 ++++++++ server/memo_relation.go | 76 ++++++++++++++++++++++++ server/server.go | 1 + store/memo.go | 5 ++ store/memo_relation.go | 20 +++++++ test/server/memo_relation_test.go | 97 +++++++++++++++++++++++++++++++ test/server/memo_test.go | 30 ++++++++-- 9 files changed, 276 insertions(+), 8 deletions(-) create mode 100644 api/memo_relation.go create mode 100644 server/memo_relation.go create mode 100644 test/server/memo_relation_test.go diff --git a/api/memo.go b/api/memo.go index 80ce1b52..3cbb528b 100644 --- a/api/memo.go +++ b/api/memo.go @@ -42,8 +42,9 @@ type Memo struct { Pinned bool `json:"pinned"` // Related fields - CreatorName string `json:"creatorName"` - ResourceList []*Resource `json:"resourceList"` + CreatorName string `json:"creatorName"` + ResourceList []*Resource `json:"resourceList"` + RelationList []*MemoRelation `json:"relationList"` } type MemoCreate struct { @@ -56,7 +57,8 @@ type MemoCreate struct { Content string `json:"content"` // Related fields - ResourceIDList []int `json:"resourceIdList"` + ResourceIDList []int `json:"resourceIdList"` + MemoRelationList []*MemoRelationUpsert `json:"memoRelationList"` } type MemoPatch struct { @@ -72,7 +74,8 @@ type MemoPatch struct { Visibility *Visibility `json:"visibility"` // Related fields - ResourceIDList []int `json:"resourceIdList"` + ResourceIDList []int `json:"resourceIdList"` + MemoRelationList []*MemoRelationUpsert `json:"memoRelationList"` } type MemoFind struct { diff --git a/api/memo_relation.go b/api/memo_relation.go new file mode 100644 index 00000000..8e5df41d --- /dev/null +++ b/api/memo_relation.go @@ -0,0 +1,19 @@ +package api + +type MemoRelationType string + +const ( + MemoRelationReference MemoRelationType = "REFERENCE" + MemoRelationAdditional MemoRelationType = "ADDITIONAL" +) + +type MemoRelation struct { + MemoID int + RelatedMemoID int + Type MemoRelationType +} + +type MemoRelationUpsert struct { + RelatedMemoID int + Type MemoRelationType +} diff --git a/server/memo.go b/server/memo.go index e9aea9e6..fae1a843 100644 --- a/server/memo.go +++ b/server/memo.go @@ -11,6 +11,7 @@ import ( "github.com/pkg/errors" "github.com/usememos/memos/api" "github.com/usememos/memos/common" + "github.com/usememos/memos/store" "github.com/labstack/echo/v4" ) @@ -101,6 +102,18 @@ func (s *Server) registerMemoRoutes(g *echo.Group) { } } + if s.Profile.IsDev() { + for _, memoRelationUpsert := range memoCreate.MemoRelationList { + if _, err := s.Store.UpsertMemoRelation(ctx, &store.MemoRelationMessage{ + MemoID: memo.ID, + RelatedMemoID: memoRelationUpsert.RelatedMemoID, + Type: store.MemoRelationType(memoRelationUpsert.Type), + }); err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, "Failed to upsert memo relation").SetInternal(err) + } + } + } + memo, err = s.Store.ComposeMemo(ctx, memo) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, "Failed to compose memo").SetInternal(err) @@ -157,6 +170,18 @@ func (s *Server) registerMemoRoutes(g *echo.Group) { } } + if s.Profile.IsDev() { + for _, memoRelationUpsert := range memoPatch.MemoRelationList { + if _, err := s.Store.UpsertMemoRelation(ctx, &store.MemoRelationMessage{ + MemoID: memo.ID, + RelatedMemoID: memoRelationUpsert.RelatedMemoID, + Type: store.MemoRelationType(memoRelationUpsert.Type), + }); err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, "Failed to upsert memo relation").SetInternal(err) + } + } + } + memo, err = s.Store.ComposeMemo(ctx, memo) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, "Failed to compose memo").SetInternal(err) diff --git a/server/memo_relation.go b/server/memo_relation.go new file mode 100644 index 00000000..0e5136ae --- /dev/null +++ b/server/memo_relation.go @@ -0,0 +1,76 @@ +package server + +import ( + "encoding/json" + "fmt" + "net/http" + "strconv" + + "github.com/usememos/memos/api" + "github.com/usememos/memos/store" + + "github.com/labstack/echo/v4" +) + +func (s *Server) registerMemoRelationRoutes(g *echo.Group) { + g.POST("/memo/:memoId/relation", func(c echo.Context) error { + ctx := c.Request().Context() + memoID, err := strconv.Atoi(c.Param("memoId")) + if err != nil { + return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("memoId"))).SetInternal(err) + } + + memoRelationUpsert := &api.MemoRelationUpsert{} + if err := json.NewDecoder(c.Request().Body).Decode(memoRelationUpsert); err != nil { + return echo.NewHTTPError(http.StatusBadRequest, "Malformatted post memo relation request").SetInternal(err) + } + + memoRelation, err := s.Store.UpsertMemoRelation(ctx, &store.MemoRelationMessage{ + MemoID: memoID, + RelatedMemoID: memoRelationUpsert.RelatedMemoID, + Type: store.MemoRelationType(memoRelationUpsert.Type), + }) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, "Failed to upsert memo relation").SetInternal(err) + } + return c.JSON(http.StatusOK, composeResponse(memoRelation)) + }) + + g.GET("/memo/:memoId/relation", func(c echo.Context) error { + ctx := c.Request().Context() + memoID, err := strconv.Atoi(c.Param("memoId")) + if err != nil { + return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("memoId"))).SetInternal(err) + } + + memoRelationList, err := s.Store.ListMemoRelations(ctx, &store.FindMemoRelationMessage{ + MemoID: &memoID, + }) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, "Failed to list memo relations").SetInternal(err) + } + return c.JSON(http.StatusOK, composeResponse(memoRelationList)) + }) + + g.DELETE("/memo/:memoId/relation/:relatedMemoId/type/:relationType", func(c echo.Context) error { + ctx := c.Request().Context() + memoID, err := strconv.Atoi(c.Param("memoId")) + if err != nil { + return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Memo ID is not a number: %s", c.Param("memoId"))).SetInternal(err) + } + relatedMemoID, err := strconv.Atoi(c.Param("relatedMemoId")) + if err != nil { + return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Related memo ID is not a number: %s", c.Param("resourceId"))).SetInternal(err) + } + relationType := store.MemoRelationType(c.Param("relationType")) + + if err := s.Store.DeleteMemoRelation(ctx, &store.DeleteMemoRelationMessage{ + MemoID: &memoID, + RelatedMemoID: &relatedMemoID, + Type: &relationType, + }); err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, "Failed to delete memo relation").SetInternal(err) + } + return c.JSON(http.StatusOK, true) + }) +} diff --git a/server/server.go b/server/server.go index 3c82e8e2..8ffc3cac 100644 --- a/server/server.go +++ b/server/server.go @@ -109,6 +109,7 @@ func NewServer(ctx context.Context, profile *profile.Profile) (*Server, error) { s.registerStorageRoutes(apiGroup) s.registerIdentityProviderRoutes(apiGroup) s.registerOpenAIRoutes(apiGroup) + s.registerMemoRelationRoutes(apiGroup) return s, nil } diff --git a/store/memo.go b/store/memo.go index 448cecde..46a0d770 100644 --- a/store/memo.go +++ b/store/memo.go @@ -53,6 +53,11 @@ func (s *Store) ComposeMemo(ctx context.Context, memo *api.Memo) (*api.Memo, err if err := s.ComposeMemoResourceList(ctx, memo); err != nil { return nil, err } + if s.profile.IsDev() { + if err := s.ComposeMemoRelationList(ctx, memo); err != nil { + return nil, err + } + } return memo, nil } diff --git a/store/memo_relation.go b/store/memo_relation.go index 6fe5c081..06cd6397 100644 --- a/store/memo_relation.go +++ b/store/memo_relation.go @@ -6,9 +6,29 @@ import ( "fmt" "strings" + "github.com/usememos/memos/api" "github.com/usememos/memos/common" ) +func (s *Store) ComposeMemoRelationList(ctx context.Context, memo *api.Memo) error { + memoRelationList, err := s.ListMemoRelations(ctx, &FindMemoRelationMessage{ + MemoID: &memo.ID, + }) + if err != nil { + return err + } + + for _, memoRelation := range memoRelationList { + memo.RelationList = append(memo.RelationList, &api.MemoRelation{ + MemoID: memoRelation.MemoID, + RelatedMemoID: memoRelation.RelatedMemoID, + Type: api.MemoRelationType(memoRelation.Type), + }) + } + + return nil +} + type MemoRelationType string const ( diff --git a/test/server/memo_relation_test.go b/test/server/memo_relation_test.go new file mode 100644 index 00000000..8be8a000 --- /dev/null +++ b/test/server/memo_relation_test.go @@ -0,0 +1,97 @@ +package testserver + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "testing" + + "github.com/pkg/errors" + "github.com/stretchr/testify/require" + "github.com/usememos/memos/api" +) + +func TestMemoRelationServer(t *testing.T) { + ctx := context.Background() + s, err := NewTestingServer(ctx, t) + require.NoError(t, err) + defer s.Shutdown(ctx) + + signup := &api.SignUp{ + Username: "testuser", + Password: "testpassword", + } + user, err := s.postAuthSignup(signup) + require.NoError(t, err) + require.Equal(t, signup.Username, user.Username) + memoList, err := s.getMemoList() + require.NoError(t, err) + require.Len(t, memoList, 0) + memo, err := s.postMemoCreate(&api.MemoCreate{ + Content: "test memo", + }) + require.NoError(t, err) + require.Equal(t, "test memo", memo.Content) + memo2, err := s.postMemoCreate(&api.MemoCreate{ + Content: "test memo2", + MemoRelationList: []*api.MemoRelationUpsert{ + { + RelatedMemoID: memo.ID, + Type: api.MemoRelationReference, + }, + }, + }) + require.NoError(t, err) + require.Equal(t, "test memo2", memo2.Content) + memoList, err = s.getMemoList() + require.NoError(t, err) + require.Len(t, memoList, 2) + require.Len(t, memo2.RelationList, 1) + err = s.deleteMemoRelation(memo2.ID, memo.ID, api.MemoRelationReference) + require.NoError(t, err) + memo2, err = s.getMemo(memo2.ID) + require.NoError(t, err) + require.Len(t, memo2.RelationList, 0) + memoRelation, err := s.postMemoRelationUpsert(memo2.ID, &api.MemoRelationUpsert{ + RelatedMemoID: memo.ID, + Type: api.MemoRelationReference, + }) + require.NoError(t, err) + require.Equal(t, memo.ID, memoRelation.RelatedMemoID) + memo2, err = s.getMemo(memo2.ID) + require.NoError(t, err) + require.Len(t, memo2.RelationList, 1) +} + +func (s *TestingServer) postMemoRelationUpsert(memoID int, memoRelationUpsert *api.MemoRelationUpsert) (*api.MemoRelation, error) { + rawData, err := json.Marshal(&memoRelationUpsert) + if err != nil { + return nil, errors.Wrap(err, "failed to marshal memo relation upsert") + } + reader := bytes.NewReader(rawData) + body, err := s.post(fmt.Sprintf("/api/memo/%d/relation", memoID), reader, nil) + if err != nil { + return nil, err + } + + buf := &bytes.Buffer{} + _, err = buf.ReadFrom(body) + if err != nil { + return nil, errors.Wrap(err, "fail to read response body") + } + + type MemoCreateResponse struct { + Data *api.MemoRelation `json:"data"` + } + res := new(MemoCreateResponse) + if err = json.Unmarshal(buf.Bytes(), res); err != nil { + return nil, errors.Wrap(err, "fail to unmarshal post memo relation upsert response") + } + return res.Data, nil +} + +func (s *TestingServer) deleteMemoRelation(memoID int, relatedMemoID int, relationType api.MemoRelationType) error { + _, err := s.delete(fmt.Sprintf("/api/memo/%d/relation/%d/type/%s", memoID, relatedMemoID, relationType), nil) + return err +} diff --git a/test/server/memo_test.go b/test/server/memo_test.go index 7ee1174d..a8562b95 100644 --- a/test/server/memo_test.go +++ b/test/server/memo_test.go @@ -37,13 +37,13 @@ func TestMemoServer(t *testing.T) { require.NoError(t, err) require.Len(t, memoList, 1) updatedContent := "updated memo" - memo, err = s.patchMemoPatch(&api.MemoPatch{ + memo, err = s.patchMemo(&api.MemoPatch{ ID: memo.ID, Content: &updatedContent, }) require.NoError(t, err) require.Equal(t, updatedContent, memo.Content) - err = s.postMemoDelete(&api.MemoDelete{ + err = s.deleteMemo(&api.MemoDelete{ ID: memo.ID, }) require.NoError(t, err) @@ -52,6 +52,28 @@ func TestMemoServer(t *testing.T) { require.Len(t, memoList, 0) } +func (s *TestingServer) getMemo(memoID int) (*api.Memo, error) { + body, err := s.get(fmt.Sprintf("/api/memo/%d", memoID), nil) + if err != nil { + return nil, err + } + + buf := &bytes.Buffer{} + _, err = buf.ReadFrom(body) + if err != nil { + return nil, errors.Wrap(err, "fail to read response body") + } + + type MemoCreateResponse struct { + Data *api.Memo `json:"data"` + } + res := new(MemoCreateResponse) + if err = json.Unmarshal(buf.Bytes(), res); err != nil { + return nil, errors.Wrap(err, "fail to unmarshal get memo response") + } + return res.Data, nil +} + func (s *TestingServer) getMemoList() ([]*api.Memo, error) { body, err := s.get("/api/memo", nil) if err != nil { @@ -101,7 +123,7 @@ func (s *TestingServer) postMemoCreate(memoCreate *api.MemoCreate) (*api.Memo, e return res.Data, nil } -func (s *TestingServer) patchMemoPatch(memoPatch *api.MemoPatch) (*api.Memo, error) { +func (s *TestingServer) patchMemo(memoPatch *api.MemoPatch) (*api.Memo, error) { rawData, err := json.Marshal(&memoPatch) if err != nil { return nil, errors.Wrap(err, "failed to marshal memo patch") @@ -128,7 +150,7 @@ func (s *TestingServer) patchMemoPatch(memoPatch *api.MemoPatch) (*api.Memo, err return res.Data, nil } -func (s *TestingServer) postMemoDelete(memoDelete *api.MemoDelete) error { +func (s *TestingServer) deleteMemo(memoDelete *api.MemoDelete) error { _, err := s.delete(fmt.Sprintf("/api/memo/%d", memoDelete.ID), nil) return err }