diff --git a/api/common.go b/api/common.go deleted file mode 100644 index 4f16292d..00000000 --- a/api/common.go +++ /dev/null @@ -1,24 +0,0 @@ -package api - -// UnknownID is the ID for unknowns. -const UnknownID = -1 - -// RowStatus is the status for a row. -type RowStatus string - -const ( - // Normal is the status for a normal row. - Normal RowStatus = "NORMAL" - // Archived is the status for an archived row. - Archived RowStatus = "ARCHIVED" -) - -func (e RowStatus) String() string { - switch e { - case Normal: - return "NORMAL" - case Archived: - return "ARCHIVED" - } - return "" -} diff --git a/api/memo.go b/api/memo.go deleted file mode 100644 index 71bd7258..00000000 --- a/api/memo.go +++ /dev/null @@ -1,94 +0,0 @@ -package api - -// Visibility is the type of a visibility. -type Visibility string - -const ( - // Public is the PUBLIC visibility. - Public Visibility = "PUBLIC" - // Protected is the PROTECTED visibility. - Protected Visibility = "PROTECTED" - // Private is the PRIVATE visibility. - Private Visibility = "PRIVATE" -) - -func (v Visibility) String() string { - switch v { - case Public: - return "PUBLIC" - case Protected: - return "PROTECTED" - case Private: - return "PRIVATE" - } - return "PRIVATE" -} - -type MemoResponse struct { - ID int `json:"id"` - - // Standard fields - RowStatus RowStatus `json:"rowStatus"` - CreatorID int `json:"creatorId"` - CreatedTs int64 `json:"createdTs"` - UpdatedTs int64 `json:"updatedTs"` - - // Domain specific fields - DisplayTs int64 `json:"displayTs"` - Content string `json:"content"` - Visibility Visibility `json:"visibility"` - Pinned bool `json:"pinned"` - - // Related fields - CreatorName string `json:"creatorName"` - ResourceList []*Resource `json:"resourceList"` - RelationList []*MemoRelation `json:"relationList"` -} - -type CreateMemoRequest struct { - // Standard fields - CreatorID int `json:"-"` - CreatedTs *int64 `json:"createdTs"` - - // Domain specific fields - Visibility Visibility `json:"visibility"` - Content string `json:"content"` - - // Related fields - ResourceIDList []int `json:"resourceIdList"` - RelationList []*MemoRelationUpsert `json:"relationList"` -} - -type PatchMemoRequest struct { - ID int `json:"-"` - - // Standard fields - CreatedTs *int64 `json:"createdTs"` - UpdatedTs *int64 - RowStatus *RowStatus `json:"rowStatus"` - - // Domain specific fields - Content *string `json:"content"` - Visibility *Visibility `json:"visibility"` - - // Related fields - ResourceIDList []int `json:"resourceIdList"` - RelationList []*MemoRelationUpsert `json:"relationList"` -} - -type FindMemoRequest struct { - ID *int - - // Standard fields - RowStatus *RowStatus - CreatorID *int - - // Domain specific fields - Pinned *bool - ContentSearch []string - VisibilityList []Visibility - - // Pagination - Limit *int - Offset *int -} diff --git a/api/memo_organizer.go b/api/memo_organizer.go deleted file mode 100644 index 664ff1f0..00000000 --- a/api/memo_organizer.go +++ /dev/null @@ -1,24 +0,0 @@ -package api - -type MemoOrganizer struct { - // Domain specific fields - MemoID int - UserID int - Pinned bool -} - -type MemoOrganizerUpsert struct { - MemoID int `json:"-"` - UserID int `json:"-"` - Pinned bool `json:"pinned"` -} - -type MemoOrganizerFind struct { - MemoID int - UserID int -} - -type MemoOrganizerDelete struct { - MemoID *int - UserID *int -} diff --git a/api/memo_relation.go b/api/memo_relation.go deleted file mode 100644 index 9e8c022b..00000000 --- a/api/memo_relation.go +++ /dev/null @@ -1,19 +0,0 @@ -package api - -type MemoRelationType string - -const ( - MemoRelationReference MemoRelationType = "REFERENCE" - MemoRelationAdditional MemoRelationType = "ADDITIONAL" -) - -type MemoRelation struct { - MemoID int `json:"memoId"` - RelatedMemoID int `json:"relatedMemoId"` - Type MemoRelationType `json:"type"` -} - -type MemoRelationUpsert struct { - RelatedMemoID int `json:"relatedMemoId"` - Type MemoRelationType `json:"type"` -} diff --git a/api/memo_resource.go b/api/memo_resource.go deleted file mode 100644 index 3c0b82ce..00000000 --- a/api/memo_resource.go +++ /dev/null @@ -1,24 +0,0 @@ -package api - -type MemoResource struct { - MemoID int - ResourceID int - CreatedTs int64 - UpdatedTs int64 -} - -type MemoResourceUpsert struct { - MemoID int `json:"-"` - ResourceID int - UpdatedTs *int64 -} - -type MemoResourceFind struct { - MemoID *int - ResourceID *int -} - -type MemoResourceDelete struct { - MemoID *int - ResourceID *int -} diff --git a/api/resource.go b/api/resource.go deleted file mode 100644 index 08485ebd..00000000 --- a/api/resource.go +++ /dev/null @@ -1,22 +0,0 @@ -package api - -type Resource struct { - ID int `json:"id"` - - // Standard fields - CreatorID int `json:"creatorId"` - CreatedTs int64 `json:"createdTs"` - UpdatedTs int64 `json:"updatedTs"` - - // Domain specific fields - Filename string `json:"filename"` - Blob []byte `json:"-"` - InternalPath string `json:"-"` - ExternalLink string `json:"externalLink"` - Type string `json:"type"` - Size int64 `json:"size"` - PublicID string `json:"publicId"` - - // Related fields - LinkedMemoAmount int `json:"linkedMemoAmount"` -} diff --git a/api/v1/auth.go b/api/v1/auth.go index 263623b2..53e5b41e 100644 --- a/api/v1/auth.go +++ b/api/v1/auth.go @@ -234,7 +234,7 @@ func (s *APIV1Service) createAuthSignInActivity(c echo.Context, user *store.User if err != nil { return errors.Wrap(err, "failed to marshal activity payload") } - activity, err := s.Store.CreateActivity(ctx, &store.ActivityMessage{ + activity, err := s.Store.CreateActivity(ctx, &store.Activity{ CreatorID: user.ID, Type: string(ActivityUserAuthSignIn), Level: string(ActivityInfo), @@ -256,7 +256,7 @@ func (s *APIV1Service) createAuthSignUpActivity(c echo.Context, user *store.User if err != nil { return errors.Wrap(err, "failed to marshal activity payload") } - activity, err := s.Store.CreateActivity(ctx, &store.ActivityMessage{ + activity, err := s.Store.CreateActivity(ctx, &store.Activity{ CreatorID: user.ID, Type: string(ActivityUserAuthSignUp), Level: string(ActivityInfo), diff --git a/server/http_getter.go b/api/v1/http_getter.go similarity index 92% rename from server/http_getter.go rename to api/v1/http_getter.go index 7146ce99..6d24887c 100644 --- a/server/http_getter.go +++ b/api/v1/http_getter.go @@ -1,4 +1,4 @@ -package server +package v1 import ( "fmt" @@ -9,7 +9,7 @@ import ( getter "github.com/usememos/memos/plugin/http-getter" ) -func registerGetterPublicRoutes(g *echo.Group) { +func (*APIV1Service) registerGetterPublicRoutes(g *echo.Group) { g.GET("/get/httpmeta", func(c echo.Context) error { urlStr := c.QueryParam("url") if urlStr == "" { @@ -23,7 +23,7 @@ func registerGetterPublicRoutes(g *echo.Group) { if err != nil { return echo.NewHTTPError(http.StatusNotAcceptable, fmt.Sprintf("Failed to get website meta with url: %s", urlStr)).SetInternal(err) } - return c.JSON(http.StatusOK, composeResponse(htmlMeta)) + return c.JSON(http.StatusOK, htmlMeta) }) g.GET("/get/image", func(c echo.Context) error { diff --git a/api/v1/jwt.go b/api/v1/jwt.go index fb30adc9..3c0b8742 100644 --- a/api/v1/jwt.go +++ b/api/v1/jwt.go @@ -82,7 +82,7 @@ func JWTMiddleware(server *APIV1Service, next echo.HandlerFunc, secret string) e } // Skip validation for server status endpoints. - if common.HasPrefixes(path, "/api/v1/ping", "/api/v1/idp", "/api/v1/status", "/api/user/:id") && method == http.MethodGet { + if common.HasPrefixes(path, "/api/v1/ping", "/api/v1/idp", "/api/v1/status", "/api/v1/user/:id") && method == http.MethodGet { return next(c) } @@ -93,7 +93,7 @@ func JWTMiddleware(server *APIV1Service, next echo.HandlerFunc, secret string) e return next(c) } // When the request is not authenticated, we allow the user to access the memo endpoints for those public memos. - if common.HasPrefixes(path, "/api/memo") && method == http.MethodGet { + if common.HasPrefixes(path, "/api/v1/memo") && method == http.MethodGet { return next(c) } return echo.NewHTTPError(http.StatusUnauthorized, "Missing access token") diff --git a/api/v1/memo.go b/api/v1/memo.go index 05c2717e..03dc7943 100644 --- a/api/v1/memo.go +++ b/api/v1/memo.go @@ -1,5 +1,20 @@ package v1 +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "strconv" + "strings" + "time" + + "github.com/labstack/echo/v4" + "github.com/pkg/errors" + "github.com/usememos/memos/common" + "github.com/usememos/memos/store" +) + // Visibility is the type of a visibility. type Visibility string @@ -13,5 +28,731 @@ const ( ) func (v Visibility) String() string { - return string(v) + switch v { + case Public: + return "PUBLIC" + case Protected: + return "PROTECTED" + case Private: + return "PRIVATE" + } + return "PRIVATE" +} + +type Memo struct { + ID int `json:"id"` + + // Standard fields + RowStatus RowStatus `json:"rowStatus"` + CreatorID int `json:"creatorId"` + CreatedTs int64 `json:"createdTs"` + UpdatedTs int64 `json:"updatedTs"` + + // Domain specific fields + DisplayTs int64 `json:"displayTs"` + Content string `json:"content"` + Visibility Visibility `json:"visibility"` + Pinned bool `json:"pinned"` + + // Related fields + CreatorName string `json:"creatorName"` + ResourceList []*Resource `json:"resourceList"` + RelationList []*MemoRelation `json:"relationList"` +} + +type CreateMemoRequest struct { + // Standard fields + CreatorID int `json:"-"` + CreatedTs *int64 `json:"createdTs"` + + // Domain specific fields + Visibility Visibility `json:"visibility"` + Content string `json:"content"` + + // Related fields + ResourceIDList []int `json:"resourceIdList"` + RelationList []*UpsertMemoRelationRequest `json:"relationList"` +} + +type PatchMemoRequest struct { + ID int `json:"-"` + + // Standard fields + CreatedTs *int64 `json:"createdTs"` + UpdatedTs *int64 + RowStatus *RowStatus `json:"rowStatus"` + + // Domain specific fields + Content *string `json:"content"` + Visibility *Visibility `json:"visibility"` + + // Related fields + ResourceIDList []int `json:"resourceIdList"` + RelationList []*UpsertMemoRelationRequest `json:"relationList"` +} + +type FindMemoRequest struct { + ID *int + + // Standard fields + RowStatus *RowStatus + CreatorID *int + + // Domain specific fields + Pinned *bool + ContentSearch []string + VisibilityList []Visibility + + // Pagination + Limit *int + Offset *int +} + +// maxContentLength means the max memo content bytes is 1MB. +const maxContentLength = 1 << 30 + +func (s *APIV1Service) registerMemoRoutes(g *echo.Group) { + g.POST("/memo", func(c echo.Context) error { + ctx := c.Request().Context() + userID, ok := c.Get(getUserIDContextKey()).(int) + if !ok { + return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") + } + + createMemoRequest := &CreateMemoRequest{} + if err := json.NewDecoder(c.Request().Body).Decode(createMemoRequest); err != nil { + return echo.NewHTTPError(http.StatusBadRequest, "Malformatted post memo request").SetInternal(err) + } + if len(createMemoRequest.Content) > maxContentLength { + return echo.NewHTTPError(http.StatusBadRequest, "Content size overflow, up to 1MB") + } + + if createMemoRequest.Visibility == "" { + userMemoVisibilitySetting, err := s.Store.GetUserSetting(ctx, &store.FindUserSetting{ + UserID: &userID, + Key: UserSettingMemoVisibilityKey.String(), + }) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user setting").SetInternal(err) + } + + if userMemoVisibilitySetting != nil { + memoVisibility := Private + err := json.Unmarshal([]byte(userMemoVisibilitySetting.Value), &memoVisibility) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, "Failed to unmarshal user setting value").SetInternal(err) + } + createMemoRequest.Visibility = memoVisibility + } else { + // Private is the default memo visibility. + createMemoRequest.Visibility = Private + } + } + + // Find disable public memos system setting. + disablePublicMemosSystemSetting, err := s.Store.GetSystemSetting(ctx, &store.FindSystemSetting{ + Name: SystemSettingDisablePublicMemosName.String(), + }) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find system setting").SetInternal(err) + } + if disablePublicMemosSystemSetting != nil { + disablePublicMemos := false + err = json.Unmarshal([]byte(disablePublicMemosSystemSetting.Value), &disablePublicMemos) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, "Failed to unmarshal system setting").SetInternal(err) + } + if disablePublicMemos { + user, err := s.Store.GetUser(ctx, &store.FindUser{ + ID: &userID, + }) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err) + } + // Enforce normal user to create private memo if public memos are disabled. + if user.Role == store.RoleUser { + createMemoRequest.Visibility = Private + } + } + } + + createMemoRequest.CreatorID = userID + memo, err := s.Store.CreateMemo(ctx, convertCreateMemoRequestToMemoMessage(createMemoRequest)) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create memo").SetInternal(err) + } + if err := s.createMemoCreateActivity(ctx, memo); err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create activity").SetInternal(err) + } + + for _, resourceID := range createMemoRequest.ResourceIDList { + if _, err := s.Store.UpsertMemoResource(ctx, &store.UpsertMemoResource{ + MemoID: memo.ID, + ResourceID: resourceID, + }); err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, "Failed to upsert memo resource").SetInternal(err) + } + } + + for _, memoRelationUpsert := range createMemoRequest.RelationList { + if _, err := s.Store.UpsertMemoRelation(ctx, &store.MemoRelation{ + MemoID: memo.ID, + RelatedMemoID: memoRelationUpsert.RelatedMemoID, + Type: store.MemoRelationType(memoRelationUpsert.Type), + }); err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, "Failed to upsert memo relation").SetInternal(err) + } + } + + memo, err = s.Store.GetMemo(ctx, &store.FindMemo{ + ID: &memo.ID, + }) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, "Failed to compose memo").SetInternal(err) + } + memoResponse, err := s.convertMemoFromStore(ctx, memo) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, "Failed to compose memo response").SetInternal(err) + } + return c.JSON(http.StatusOK, memoResponse) + }) + + g.PATCH("/memo/:memoId", func(c echo.Context) error { + ctx := c.Request().Context() + userID, ok := c.Get(getUserIDContextKey()).(int) + if !ok { + return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") + } + + memoID, err := strconv.Atoi(c.Param("memoId")) + if err != nil { + return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("memoId"))).SetInternal(err) + } + + memo, err := s.Store.GetMemo(ctx, &store.FindMemo{ + ID: &memoID, + }) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find memo").SetInternal(err) + } + if memo.CreatorID != userID { + return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized") + } + + currentTs := time.Now().Unix() + patchMemoRequest := &PatchMemoRequest{ + ID: memoID, + UpdatedTs: ¤tTs, + } + if err := json.NewDecoder(c.Request().Body).Decode(patchMemoRequest); err != nil { + return echo.NewHTTPError(http.StatusBadRequest, "Malformatted patch memo request").SetInternal(err) + } + + if patchMemoRequest.Content != nil && len(*patchMemoRequest.Content) > maxContentLength { + return echo.NewHTTPError(http.StatusBadRequest, "Content size overflow, up to 1MB").SetInternal(err) + } + + updateMemoMessage := &store.UpdateMemo{ + ID: memoID, + CreatedTs: patchMemoRequest.CreatedTs, + UpdatedTs: patchMemoRequest.UpdatedTs, + Content: patchMemoRequest.Content, + } + if patchMemoRequest.RowStatus != nil { + rowStatus := store.RowStatus(patchMemoRequest.RowStatus.String()) + updateMemoMessage.RowStatus = &rowStatus + } + if patchMemoRequest.Visibility != nil { + visibility := store.Visibility(patchMemoRequest.Visibility.String()) + updateMemoMessage.Visibility = &visibility + } + + err = s.Store.UpdateMemo(ctx, updateMemoMessage) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, "Failed to patch memo").SetInternal(err) + } + memo, err = s.Store.GetMemo(ctx, &store.FindMemo{ID: &memoID}) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find memo").SetInternal(err) + } + + if patchMemoRequest.ResourceIDList != nil { + addedResourceIDList, removedResourceIDList := getIDListDiff(memo.ResourceIDList, patchMemoRequest.ResourceIDList) + for _, resourceID := range addedResourceIDList { + if _, err := s.Store.UpsertMemoResource(ctx, &store.UpsertMemoResource{ + MemoID: memo.ID, + ResourceID: resourceID, + }); err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, "Failed to upsert memo resource").SetInternal(err) + } + } + for _, resourceID := range removedResourceIDList { + if err := s.Store.DeleteMemoResource(ctx, &store.DeleteMemoResource{ + MemoID: &memo.ID, + ResourceID: &resourceID, + }); err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, "Failed to delete memo resource").SetInternal(err) + } + } + } + + if patchMemoRequest.RelationList != nil { + patchMemoRelationList := make([]*store.MemoRelation, 0) + for _, memoRelation := range patchMemoRequest.RelationList { + patchMemoRelationList = append(patchMemoRelationList, &store.MemoRelation{ + MemoID: memo.ID, + RelatedMemoID: memoRelation.RelatedMemoID, + Type: store.MemoRelationType(memoRelation.Type), + }) + } + addedMemoRelationList, removedMemoRelationList := getMemoRelationListDiff(memo.RelationList, patchMemoRelationList) + for _, memoRelation := range addedMemoRelationList { + if _, err := s.Store.UpsertMemoRelation(ctx, memoRelation); err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, "Failed to upsert memo relation").SetInternal(err) + } + } + for _, memoRelation := range removedMemoRelationList { + if err := s.Store.DeleteMemoRelation(ctx, &store.DeleteMemoRelation{ + MemoID: &memo.ID, + RelatedMemoID: &memoRelation.RelatedMemoID, + Type: &memoRelation.Type, + }); err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, "Failed to delete memo relation").SetInternal(err) + } + } + } + + memo, err = s.Store.GetMemo(ctx, &store.FindMemo{ID: &memoID}) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find memo").SetInternal(err) + } + memoResponse, err := s.convertMemoFromStore(ctx, memo) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, "Failed to compose memo response").SetInternal(err) + } + return c.JSON(http.StatusOK, memoResponse) + }) + + g.GET("/memo", func(c echo.Context) error { + ctx := c.Request().Context() + findMemoMessage := &store.FindMemo{} + if userID, err := strconv.Atoi(c.QueryParam("creatorId")); err == nil { + findMemoMessage.CreatorID = &userID + } + + currentUserID, ok := c.Get(getUserIDContextKey()).(int) + if !ok { + if findMemoMessage.CreatorID == nil { + return echo.NewHTTPError(http.StatusBadRequest, "Missing user id to find memo") + } + findMemoMessage.VisibilityList = []store.Visibility{store.Public} + } else { + if findMemoMessage.CreatorID == nil { + findMemoMessage.CreatorID = ¤tUserID + } else { + findMemoMessage.VisibilityList = []store.Visibility{store.Public, store.Protected} + } + } + + rowStatus := store.RowStatus(c.QueryParam("rowStatus")) + if rowStatus != "" { + findMemoMessage.RowStatus = &rowStatus + } + pinnedStr := c.QueryParam("pinned") + if pinnedStr != "" { + pinned := pinnedStr == "true" + findMemoMessage.Pinned = &pinned + } + + contentSearch := []string{} + tag := c.QueryParam("tag") + if tag != "" { + contentSearch = append(contentSearch, "#"+tag) + } + contentSlice := c.QueryParams()["content"] + if len(contentSlice) > 0 { + contentSearch = append(contentSearch, contentSlice...) + } + findMemoMessage.ContentSearch = contentSearch + + visibilityListStr := c.QueryParam("visibility") + if visibilityListStr != "" { + visibilityList := []store.Visibility{} + for _, visibility := range strings.Split(visibilityListStr, ",") { + visibilityList = append(visibilityList, store.Visibility(visibility)) + } + findMemoMessage.VisibilityList = visibilityList + } + if limit, err := strconv.Atoi(c.QueryParam("limit")); err == nil { + findMemoMessage.Limit = &limit + } + if offset, err := strconv.Atoi(c.QueryParam("offset")); err == nil { + findMemoMessage.Offset = &offset + } + + memoDisplayWithUpdatedTs, err := s.getMemoDisplayWithUpdatedTsSettingValue(ctx) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, "Failed to get memo display with updated ts setting value").SetInternal(err) + } + if memoDisplayWithUpdatedTs { + findMemoMessage.OrderByUpdatedTs = true + } + + list, err := s.Store.ListMemos(ctx, findMemoMessage) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, "Failed to fetch memo list").SetInternal(err) + } + memoResponseList := []*Memo{} + for _, memo := range list { + memoResponse, err := s.convertMemoFromStore(ctx, memo) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, "Failed to compose memo response").SetInternal(err) + } + memoResponseList = append(memoResponseList, memoResponse) + } + return c.JSON(http.StatusOK, memoResponseList) + }) + + g.GET("/memo/:memoId", func(c echo.Context) error { + ctx := c.Request().Context() + memoID, err := strconv.Atoi(c.Param("memoId")) + if err != nil { + return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("memoId"))).SetInternal(err) + } + + memo, err := s.Store.GetMemo(ctx, &store.FindMemo{ + ID: &memoID, + }) + if err != nil { + if common.ErrorCode(err) == common.NotFound { + return echo.NewHTTPError(http.StatusNotFound, fmt.Sprintf("Memo ID not found: %d", memoID)).SetInternal(err) + } + return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Failed to find memo by ID: %v", memoID)).SetInternal(err) + } + + userID, ok := c.Get(getUserIDContextKey()).(int) + if memo.Visibility == store.Private { + if !ok || memo.CreatorID != userID { + return echo.NewHTTPError(http.StatusForbidden, "this memo is private only") + } + } else if memo.Visibility == store.Protected { + if !ok { + return echo.NewHTTPError(http.StatusForbidden, "this memo is protected, missing user in session") + } + } + memoResponse, err := s.convertMemoFromStore(ctx, memo) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, "Failed to compose memo response").SetInternal(err) + } + return c.JSON(http.StatusOK, memoResponse) + }) + + g.GET("/memo/stats", func(c echo.Context) error { + ctx := c.Request().Context() + normalStatus := store.Normal + findMemoMessage := &store.FindMemo{ + RowStatus: &normalStatus, + } + if creatorID, err := strconv.Atoi(c.QueryParam("creatorId")); err == nil { + findMemoMessage.CreatorID = &creatorID + } + if findMemoMessage.CreatorID == nil { + return echo.NewHTTPError(http.StatusBadRequest, "Missing user id to find memo") + } + + currentUserID, ok := c.Get(getUserIDContextKey()).(int) + if !ok { + findMemoMessage.VisibilityList = []store.Visibility{store.Public} + } else { + if *findMemoMessage.CreatorID != currentUserID { + findMemoMessage.VisibilityList = []store.Visibility{store.Public, store.Protected} + } else { + findMemoMessage.VisibilityList = []store.Visibility{store.Public, store.Protected, store.Private} + } + } + + memoDisplayWithUpdatedTs, err := s.getMemoDisplayWithUpdatedTsSettingValue(ctx) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, "Failed to get memo display with updated ts setting value").SetInternal(err) + } + if memoDisplayWithUpdatedTs { + findMemoMessage.OrderByUpdatedTs = true + } + + list, err := s.Store.ListMemos(ctx, findMemoMessage) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find memo list").SetInternal(err) + } + memoResponseList := []*Memo{} + for _, memo := range list { + memoResponse, err := s.convertMemoFromStore(ctx, memo) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, "Failed to compose memo response").SetInternal(err) + } + memoResponseList = append(memoResponseList, memoResponse) + } + + displayTsList := []int64{} + for _, memo := range memoResponseList { + displayTsList = append(displayTsList, memo.DisplayTs) + } + return c.JSON(http.StatusOK, displayTsList) + }) + + g.GET("/memo/all", func(c echo.Context) error { + ctx := c.Request().Context() + findMemoMessage := &store.FindMemo{} + _, ok := c.Get(getUserIDContextKey()).(int) + if !ok { + findMemoMessage.VisibilityList = []store.Visibility{store.Public} + } else { + findMemoMessage.VisibilityList = []store.Visibility{store.Public, store.Protected} + } + + pinnedStr := c.QueryParam("pinned") + if pinnedStr != "" { + pinned := pinnedStr == "true" + findMemoMessage.Pinned = &pinned + } + + contentSearch := []string{} + tag := c.QueryParam("tag") + if tag != "" { + contentSearch = append(contentSearch, "#"+tag+" ") + } + contentSlice := c.QueryParams()["content"] + if len(contentSlice) > 0 { + contentSearch = append(contentSearch, contentSlice...) + } + findMemoMessage.ContentSearch = contentSearch + + visibilityListStr := c.QueryParam("visibility") + if visibilityListStr != "" { + visibilityList := []store.Visibility{} + for _, visibility := range strings.Split(visibilityListStr, ",") { + visibilityList = append(visibilityList, store.Visibility(visibility)) + } + findMemoMessage.VisibilityList = visibilityList + } + if limit, err := strconv.Atoi(c.QueryParam("limit")); err == nil { + findMemoMessage.Limit = &limit + } + if offset, err := strconv.Atoi(c.QueryParam("offset")); err == nil { + findMemoMessage.Offset = &offset + } + + // Only fetch normal status memos. + normalStatus := store.Normal + findMemoMessage.RowStatus = &normalStatus + + memoDisplayWithUpdatedTs, err := s.getMemoDisplayWithUpdatedTsSettingValue(ctx) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, "Failed to get memo display with updated ts setting value").SetInternal(err) + } + if memoDisplayWithUpdatedTs { + findMemoMessage.OrderByUpdatedTs = true + } + + list, err := s.Store.ListMemos(ctx, findMemoMessage) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, "Failed to fetch all memo list").SetInternal(err) + } + memoResponseList := []*Memo{} + for _, memo := range list { + memoResponse, err := s.convertMemoFromStore(ctx, memo) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, "Failed to compose memo response").SetInternal(err) + } + memoResponseList = append(memoResponseList, memoResponse) + } + return c.JSON(http.StatusOK, memoResponseList) + }) + + g.DELETE("/memo/:memoId", func(c echo.Context) error { + ctx := c.Request().Context() + userID, ok := c.Get(getUserIDContextKey()).(int) + if !ok { + return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") + } + memoID, err := strconv.Atoi(c.Param("memoId")) + if err != nil { + return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("memoId"))).SetInternal(err) + } + + memo, err := s.Store.GetMemo(ctx, &store.FindMemo{ + ID: &memoID, + }) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find memo").SetInternal(err) + } + if memo.CreatorID != userID { + return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized") + } + + if err := s.Store.DeleteMemo(ctx, &store.DeleteMemo{ + ID: memoID, + }); err != nil { + if common.ErrorCode(err) == common.NotFound { + return echo.NewHTTPError(http.StatusNotFound, fmt.Sprintf("Memo ID not found: %d", memoID)) + } + return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Failed to delete memo ID: %v", memoID)).SetInternal(err) + } + return c.JSON(http.StatusOK, true) + }) +} + +func (s *APIV1Service) createMemoCreateActivity(ctx context.Context, memo *store.Memo) error { + payload := ActivityMemoCreatePayload{ + Content: memo.Content, + Visibility: memo.Visibility.String(), + } + payloadBytes, err := json.Marshal(payload) + if err != nil { + return errors.Wrap(err, "failed to marshal activity payload") + } + activity, err := s.Store.CreateActivity(ctx, &store.Activity{ + CreatorID: memo.CreatorID, + Type: ActivityMemoCreate.String(), + Level: ActivityInfo.String(), + Payload: string(payloadBytes), + }) + if err != nil || activity == nil { + return errors.Wrap(err, "failed to create activity") + } + return err +} + +func (s *APIV1Service) convertMemoFromStore(ctx context.Context, memo *store.Memo) (*Memo, error) { + memoResponse := &Memo{ + ID: memo.ID, + RowStatus: RowStatus(memo.RowStatus.String()), + CreatorID: memo.CreatorID, + CreatedTs: memo.CreatedTs, + UpdatedTs: memo.UpdatedTs, + Content: memo.Content, + Visibility: Visibility(memo.Visibility.String()), + Pinned: memo.Pinned, + } + + // Compose creator name. + user, err := s.Store.GetUser(ctx, &store.FindUser{ + ID: &memoResponse.CreatorID, + }) + if err != nil { + return nil, err + } + if user.Nickname != "" { + memoResponse.CreatorName = user.Nickname + } else { + memoResponse.CreatorName = user.Username + } + + // Compose display ts. + memoResponse.DisplayTs = memoResponse.CreatedTs + // Find memo display with updated ts setting. + memoDisplayWithUpdatedTs, err := s.getMemoDisplayWithUpdatedTsSettingValue(ctx) + if err != nil { + return nil, err + } + if memoDisplayWithUpdatedTs { + memoResponse.DisplayTs = memoResponse.UpdatedTs + } + + relationList := []*MemoRelation{} + for _, relation := range memo.RelationList { + relationList = append(relationList, convertMemoRelationFromStore(relation)) + } + memoResponse.RelationList = relationList + + resourceList := []*Resource{} + for _, resourceID := range memo.ResourceIDList { + resource, err := s.Store.GetResource(ctx, &store.FindResource{ + ID: &resourceID, + }) + if err != nil { + return nil, err + } + if resource != nil { + resourceList = append(resourceList, convertResourceFromStore(resource)) + } + } + memoResponse.ResourceList = resourceList + + return memoResponse, nil +} + +func (s *APIV1Service) getMemoDisplayWithUpdatedTsSettingValue(ctx context.Context) (bool, error) { + memoDisplayWithUpdatedTsSetting, err := s.Store.GetSystemSetting(ctx, &store.FindSystemSetting{ + Name: SystemSettingMemoDisplayWithUpdatedTsName.String(), + }) + if err != nil { + return false, errors.Wrap(err, "failed to find system setting") + } + memoDisplayWithUpdatedTs := false + if memoDisplayWithUpdatedTsSetting != nil { + err = json.Unmarshal([]byte(memoDisplayWithUpdatedTsSetting.Value), &memoDisplayWithUpdatedTs) + if err != nil { + return false, errors.Wrap(err, "failed to unmarshal system setting value") + } + } + return memoDisplayWithUpdatedTs, nil +} + +func convertCreateMemoRequestToMemoMessage(memoCreate *CreateMemoRequest) *store.Memo { + createdTs := time.Now().Unix() + if memoCreate.CreatedTs != nil { + createdTs = *memoCreate.CreatedTs + } + return &store.Memo{ + CreatorID: memoCreate.CreatorID, + CreatedTs: createdTs, + Content: memoCreate.Content, + Visibility: store.Visibility(memoCreate.Visibility), + } +} + +func getMemoRelationListDiff(oldList, newList []*store.MemoRelation) (addedList, removedList []*store.MemoRelation) { + oldMap := map[string]bool{} + for _, relation := range oldList { + oldMap[fmt.Sprintf("%d-%s", relation.RelatedMemoID, relation.Type)] = true + } + newMap := map[string]bool{} + for _, relation := range newList { + newMap[fmt.Sprintf("%d-%s", relation.RelatedMemoID, relation.Type)] = true + } + for _, relation := range oldList { + key := fmt.Sprintf("%d-%s", relation.RelatedMemoID, relation.Type) + if !newMap[key] { + removedList = append(removedList, relation) + } + } + for _, relation := range newList { + key := fmt.Sprintf("%d-%s", relation.RelatedMemoID, relation.Type) + if !oldMap[key] { + addedList = append(addedList, relation) + } + } + return addedList, removedList +} + +func getIDListDiff(oldList, newList []int) (addedList, removedList []int) { + oldMap := map[int]bool{} + for _, id := range oldList { + oldMap[id] = true + } + newMap := map[int]bool{} + for _, id := range newList { + newMap[id] = true + } + for id := range oldMap { + if !newMap[id] { + removedList = append(removedList, id) + } + } + for id := range newMap { + if !oldMap[id] { + addedList = append(addedList, id) + } + } + return addedList, removedList } diff --git a/api/v1/memo_organizer.go b/api/v1/memo_organizer.go new file mode 100644 index 00000000..91bf1609 --- /dev/null +++ b/api/v1/memo_organizer.go @@ -0,0 +1,74 @@ +package v1 + +import ( + "encoding/json" + "fmt" + "net/http" + "strconv" + + "github.com/labstack/echo/v4" + "github.com/usememos/memos/store" +) + +type MemoOrganizer struct { + MemoID int `json:"memoId"` + UserID int `json:"userId"` + Pinned bool `json:"pinned"` +} + +type UpsertMemoOrganizerRequest struct { + Pinned bool `json:"pinned"` +} + +func (s *APIV1Service) registerMemoOrganizerRoutes(g *echo.Group) { + g.POST("/memo/:memoId/organizer", func(c echo.Context) error { + ctx := c.Request().Context() + memoID, err := strconv.Atoi(c.Param("memoId")) + if err != nil { + return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("memoId"))).SetInternal(err) + } + + userID, ok := c.Get(getUserIDContextKey()).(int) + if !ok { + return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") + } + + memo, err := s.Store.GetMemo(ctx, &store.FindMemo{ + ID: &memoID, + }) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find memo").SetInternal(err) + } + if memo.CreatorID != userID { + return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized") + } + + request := &UpsertMemoOrganizerRequest{} + if err := json.NewDecoder(c.Request().Body).Decode(request); err != nil { + return echo.NewHTTPError(http.StatusBadRequest, "Malformatted post memo organizer request").SetInternal(err) + } + + upsert := &store.MemoOrganizer{ + MemoID: memoID, + UserID: userID, + Pinned: request.Pinned, + } + _, err = s.Store.UpsertMemoOrganizerV1(ctx, upsert) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, "Failed to upsert memo organizer").SetInternal(err) + } + + memo, err = s.Store.GetMemo(ctx, &store.FindMemo{ + ID: &memoID, + }) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Failed to find memo by ID: %v", memoID)).SetInternal(err) + } + + memoResponse, err := s.convertMemoFromStore(ctx, memo) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, "Failed to compose memo response").SetInternal(err) + } + return c.JSON(http.StatusOK, memoResponse) + }) +} diff --git a/server/memo_relation.go b/api/v1/memo_relation.go similarity index 68% rename from server/memo_relation.go rename to api/v1/memo_relation.go index e472cb54..9ab8079a 100644 --- a/server/memo_relation.go +++ b/api/v1/memo_relation.go @@ -1,4 +1,4 @@ -package server +package v1 import ( "encoding/json" @@ -6,13 +6,29 @@ import ( "net/http" "strconv" - "github.com/usememos/memos/api" - "github.com/usememos/memos/store" - "github.com/labstack/echo/v4" + "github.com/usememos/memos/store" ) -func (s *Server) registerMemoRelationRoutes(g *echo.Group) { +type MemoRelationType string + +const ( + MemoRelationReference MemoRelationType = "REFERENCE" + MemoRelationAdditional MemoRelationType = "ADDITIONAL" +) + +type MemoRelation struct { + MemoID int `json:"memoId"` + RelatedMemoID int `json:"relatedMemoId"` + Type MemoRelationType `json:"type"` +} + +type UpsertMemoRelationRequest struct { + RelatedMemoID int `json:"relatedMemoId"` + Type MemoRelationType `json:"type"` +} + +func (s *APIV1Service) registerMemoRelationRoutes(g *echo.Group) { g.POST("/memo/:memoId/relation", func(c echo.Context) error { ctx := c.Request().Context() memoID, err := strconv.Atoi(c.Param("memoId")) @@ -20,20 +36,20 @@ func (s *Server) registerMemoRelationRoutes(g *echo.Group) { return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("memoId"))).SetInternal(err) } - memoRelationUpsert := &api.MemoRelationUpsert{} - if err := json.NewDecoder(c.Request().Body).Decode(memoRelationUpsert); err != nil { + request := &UpsertMemoRelationRequest{} + if err := json.NewDecoder(c.Request().Body).Decode(request); err != nil { return echo.NewHTTPError(http.StatusBadRequest, "Malformatted post memo relation request").SetInternal(err) } - memoRelation, err := s.Store.UpsertMemoRelation(ctx, &store.MemoRelationMessage{ + memoRelation, err := s.Store.UpsertMemoRelation(ctx, &store.MemoRelation{ MemoID: memoID, - RelatedMemoID: memoRelationUpsert.RelatedMemoID, - Type: store.MemoRelationType(memoRelationUpsert.Type), + RelatedMemoID: request.RelatedMemoID, + Type: store.MemoRelationType(request.Type), }) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, "Failed to upsert memo relation").SetInternal(err) } - return c.JSON(http.StatusOK, composeResponse(memoRelation)) + return c.JSON(http.StatusOK, memoRelation) }) g.GET("/memo/:memoId/relation", func(c echo.Context) error { @@ -43,13 +59,13 @@ func (s *Server) registerMemoRelationRoutes(g *echo.Group) { return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("memoId"))).SetInternal(err) } - memoRelationList, err := s.Store.ListMemoRelations(ctx, &store.FindMemoRelationMessage{ + memoRelationList, err := s.Store.ListMemoRelations(ctx, &store.FindMemoRelation{ MemoID: &memoID, }) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, "Failed to list memo relations").SetInternal(err) } - return c.JSON(http.StatusOK, composeResponse(memoRelationList)) + return c.JSON(http.StatusOK, memoRelationList) }) g.DELETE("/memo/:memoId/relation/:relatedMemoId/type/:relationType", func(c echo.Context) error { @@ -64,7 +80,7 @@ func (s *Server) registerMemoRelationRoutes(g *echo.Group) { } relationType := store.MemoRelationType(c.Param("relationType")) - if err := s.Store.DeleteMemoRelation(ctx, &store.DeleteMemoRelationMessage{ + if err := s.Store.DeleteMemoRelation(ctx, &store.DeleteMemoRelation{ MemoID: &memoID, RelatedMemoID: &relatedMemoID, Type: &relationType, @@ -75,10 +91,10 @@ func (s *Server) registerMemoRelationRoutes(g *echo.Group) { }) } -func convertMemoRelationMessageToMemoRelation(memoRelation *store.MemoRelationMessage) *api.MemoRelation { - return &api.MemoRelation{ +func convertMemoRelationFromStore(memoRelation *store.MemoRelation) *MemoRelation { + return &MemoRelation{ MemoID: memoRelation.MemoID, RelatedMemoID: memoRelation.RelatedMemoID, - Type: api.MemoRelationType(memoRelation.Type), + Type: MemoRelationType(memoRelation.Type), } } diff --git a/api/v1/memo_resource.go b/api/v1/memo_resource.go index 551b36ce..fded1087 100644 --- a/api/v1/memo_resource.go +++ b/api/v1/memo_resource.go @@ -1,16 +1,26 @@ package v1 +import ( + "encoding/json" + "fmt" + "net/http" + "strconv" + "time" + + "github.com/labstack/echo/v4" + "github.com/usememos/memos/store" +) + type MemoResource struct { - MemoID int - ResourceID int - CreatedTs int64 - UpdatedTs int64 + MemoID int `json:"memoId"` + ResourceID int `json:"resourceId"` + CreatedTs int64 `json:"createdTs"` + UpdatedTs int64 `json:"updatedTs"` } -type MemoResourceUpsert struct { - MemoID int `json:"-"` - ResourceID int - UpdatedTs *int64 +type UpsertMemoResourceRequest struct { + ResourceID int `json:"resourceId"` + UpdatedTs *int64 `json:"updatedTs"` } type MemoResourceFind struct { @@ -22,3 +32,100 @@ type MemoResourceDelete struct { MemoID *int ResourceID *int } + +func (s *APIV1Service) registerMemoResourceRoutes(g *echo.Group) { + g.POST("/memo/:memoId/resource", func(c echo.Context) error { + ctx := c.Request().Context() + memoID, err := strconv.Atoi(c.Param("memoId")) + if err != nil { + return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("memoId"))).SetInternal(err) + } + + userID, ok := c.Get(getUserIDContextKey()).(int) + if !ok { + return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") + } + request := &UpsertMemoResourceRequest{} + if err := json.NewDecoder(c.Request().Body).Decode(request); err != nil { + return echo.NewHTTPError(http.StatusBadRequest, "Malformatted post memo resource request").SetInternal(err) + } + resource, err := s.Store.GetResource(ctx, &store.FindResource{ + ID: &request.ResourceID, + }) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, "Failed to fetch resource").SetInternal(err) + } + if resource == nil { + return echo.NewHTTPError(http.StatusBadRequest, "Resource not found").SetInternal(err) + } else if resource.CreatorID != userID { + return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized to bind this resource").SetInternal(err) + } + + upsert := &store.UpsertMemoResource{ + MemoID: memoID, + ResourceID: request.ResourceID, + CreatedTs: time.Now().Unix(), + } + if request.UpdatedTs != nil { + upsert.UpdatedTs = request.UpdatedTs + } + if _, err := s.Store.UpsertMemoResource(ctx, upsert); err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, "Failed to upsert memo resource").SetInternal(err) + } + return c.JSON(http.StatusOK, true) + }) + + g.GET("/memo/:memoId/resource", func(c echo.Context) error { + ctx := c.Request().Context() + memoID, err := strconv.Atoi(c.Param("memoId")) + if err != nil { + return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("memoId"))).SetInternal(err) + } + + list, err := s.Store.ListResources(ctx, &store.FindResource{ + MemoID: &memoID, + }) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, "Failed to fetch resource list").SetInternal(err) + } + resourceList := []*Resource{} + for _, resource := range list { + resourceList = append(resourceList, convertResourceFromStore(resource)) + } + return c.JSON(http.StatusOK, resourceList) + }) + + g.DELETE("/memo/:memoId/resource/:resourceId", func(c echo.Context) error { + ctx := c.Request().Context() + userID, ok := c.Get(getUserIDContextKey()).(int) + if !ok { + return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") + } + memoID, err := strconv.Atoi(c.Param("memoId")) + if err != nil { + return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Memo ID is not a number: %s", c.Param("memoId"))).SetInternal(err) + } + resourceID, err := strconv.Atoi(c.Param("resourceId")) + if err != nil { + return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Resource ID is not a number: %s", c.Param("resourceId"))).SetInternal(err) + } + + memo, err := s.Store.GetMemo(ctx, &store.FindMemo{ + ID: &memoID, + }) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find memo").SetInternal(err) + } + if memo.CreatorID != userID { + return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized") + } + + if err := s.Store.DeleteMemoResource(ctx, &store.DeleteMemoResource{ + MemoID: &memoID, + ResourceID: &resourceID, + }); err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, "Failed to fetch resource list").SetInternal(err) + } + return c.JSON(http.StatusOK, true) + }) +} diff --git a/api/v1/resource.go b/api/v1/resource.go index 9450cb8e..046f1ef0 100644 --- a/api/v1/resource.go +++ b/api/v1/resource.go @@ -144,7 +144,7 @@ func (s *APIV1Service) registerResourceRoutes(g *echo.Group) { } } - resource, err := s.Store.CreateResourceV1(ctx, create) + resource, err := s.Store.CreateResource(ctx, create) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create resource").SetInternal(err) } @@ -311,7 +311,7 @@ func (s *APIV1Service) registerResourceRoutes(g *echo.Group) { } create.PublicID = publicID - resource, err := s.Store.CreateResourceV1(ctx, create) + resource, err := s.Store.CreateResource(ctx, create) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create resource").SetInternal(err) } @@ -430,7 +430,7 @@ func (s *APIV1Service) registerResourceRoutes(g *echo.Group) { log.Warn(fmt.Sprintf("failed to delete local thumbnail with path %s", thumbnailPath), zap.Error(err)) } - if err := s.Store.DeleteResourceV1(ctx, &store.DeleteResource{ + if err := s.Store.DeleteResource(ctx, &store.DeleteResource{ ID: resourceID, }); err != nil { return echo.NewHTTPError(http.StatusInternalServerError, "Failed to delete resource").SetInternal(err) @@ -522,7 +522,7 @@ func (s *APIV1Service) createResourceCreateActivity(ctx context.Context, resourc if err != nil { return errors.Wrap(err, "failed to marshal activity payload") } - activity, err := s.Store.CreateActivity(ctx, &store.ActivityMessage{ + activity, err := s.Store.CreateActivity(ctx, &store.Activity{ CreatorID: resource.CreatorID, Type: ActivityResourceCreate.String(), Level: ActivityInfo.String(), diff --git a/api/v1/shortcut.go b/api/v1/shortcut.go index 04620414..435680b2 100644 --- a/api/v1/shortcut.go +++ b/api/v1/shortcut.go @@ -210,7 +210,7 @@ func (s *APIV1Service) createShortcutCreateActivity(c echo.Context, shortcut *Sh if err != nil { return errors.Wrap(err, "failed to marshal activity payload") } - activity, err := s.Store.CreateActivity(ctx, &store.ActivityMessage{ + activity, err := s.Store.CreateActivity(ctx, &store.Activity{ CreatorID: shortcut.CreatorID, Type: ActivityShortcutCreate.String(), Level: ActivityInfo.String(), diff --git a/api/v1/tag.go b/api/v1/tag.go index 0530581b..c6188ef9 100644 --- a/api/v1/tag.go +++ b/api/v1/tag.go @@ -84,7 +84,7 @@ func (s *APIV1Service) registerTagRoutes(g *echo.Group) { return echo.NewHTTPError(http.StatusBadRequest, "Missing user session") } normalRowStatus := store.Normal - memoFind := &store.FindMemoMessage{ + memoFind := &store.FindMemo{ CreatorID: &userID, ContentSearch: []string{"#"}, RowStatus: &normalRowStatus, @@ -157,7 +157,7 @@ func (s *APIV1Service) createTagCreateActivity(c echo.Context, tag *Tag) error { if err != nil { return errors.Wrap(err, "failed to marshal activity payload") } - activity, err := s.Store.CreateActivity(ctx, &store.ActivityMessage{ + activity, err := s.Store.CreateActivity(ctx, &store.Activity{ CreatorID: tag.CreatorID, Type: ActivityTagCreate.String(), Level: ActivityInfo.String(), diff --git a/api/v1/user.go b/api/v1/user.go index 1caa4090..67daa950 100644 --- a/api/v1/user.go +++ b/api/v1/user.go @@ -377,7 +377,7 @@ func (s *APIV1Service) createUserCreateActivity(c echo.Context, user *User) erro if err != nil { return errors.Wrap(err, "failed to marshal activity payload") } - activity, err := s.Store.CreateActivity(ctx, &store.ActivityMessage{ + activity, err := s.Store.CreateActivity(ctx, &store.Activity{ CreatorID: user.ID, Type: ActivityUserCreate.String(), Level: ActivityInfo.String(), diff --git a/api/v1/v1.go b/api/v1/v1.go index d4096efd..074a7ab5 100644 --- a/api/v1/v1.go +++ b/api/v1/v1.go @@ -35,10 +35,15 @@ func (s *APIV1Service) Register(rootGroup *echo.Group) { s.registerShortcutRoutes(apiV1Group) s.registerStorageRoutes(apiV1Group) s.registerResourceRoutes(apiV1Group) + s.registerMemoRoutes(apiV1Group) + s.registerMemoOrganizerRoutes(apiV1Group) + s.registerMemoResourceRoutes(apiV1Group) + s.registerMemoRelationRoutes(apiV1Group) publicGroup := rootGroup.Group("/o") publicGroup.Use(func(next echo.HandlerFunc) echo.HandlerFunc { return JWTMiddleware(s, next, s.Secret) }) + s.registerGetterPublicRoutes(publicGroup) s.registerResourcePublicRoutes(publicGroup) } diff --git a/server/common.go b/server/common.go deleted file mode 100644 index 7f6fccdc..00000000 --- a/server/common.go +++ /dev/null @@ -1,56 +0,0 @@ -package server - -import ( - "net/http" - - "github.com/labstack/echo/v4" - "github.com/usememos/memos/common" - "github.com/usememos/memos/store" -) - -type response struct { - Data any `json:"data"` -} - -func composeResponse(data any) response { - return response{ - Data: data, - } -} - -func defaultGetRequestSkipper(c echo.Context) bool { - return c.Request().Method == http.MethodGet -} - -func defaultAPIRequestSkipper(c echo.Context) bool { - path := c.Path() - return common.HasPrefixes(path, "/api") -} - -func (s *Server) defaultAuthSkipper(c echo.Context) bool { - ctx := c.Request().Context() - path := c.Path() - - // Skip auth. - if common.HasPrefixes(path, "/api/v1/auth") { - return true - } - - // If there is openId in query string and related user is found, then skip auth. - openID := c.QueryParam("openId") - if openID != "" { - user, err := s.Store.GetUser(ctx, &store.FindUser{ - OpenID: &openID, - }) - if err != nil && common.ErrorCode(err) != common.NotFound { - return false - } - if user != nil { - // Stores userID into context. - c.Set(getUserIDContextKey(), user.ID) - return true - } - } - - return false -} diff --git a/server/jwt.go b/server/jwt.go deleted file mode 100644 index fc2fe255..00000000 --- a/server/jwt.go +++ /dev/null @@ -1,206 +0,0 @@ -package server - -import ( - "fmt" - "net/http" - "strconv" - "strings" - "time" - - "github.com/golang-jwt/jwt/v4" - "github.com/labstack/echo/v4" - "github.com/pkg/errors" - "github.com/usememos/memos/common" - "github.com/usememos/memos/server/auth" - "github.com/usememos/memos/store" -) - -const ( - // Context section - // The key name used to store user id in the context - // user id is extracted from the jwt token subject field. - userIDContextKey = "user-id" -) - -func getUserIDContextKey() string { - return userIDContextKey -} - -// Claims creates a struct that will be encoded to a JWT. -// We add jwt.RegisteredClaims as an embedded type, to provide fields such as name. -type Claims struct { - Name string `json:"name"` - jwt.RegisteredClaims -} - -func extractTokenFromHeader(c echo.Context) (string, error) { - authHeader := c.Request().Header.Get("Authorization") - if authHeader == "" { - return "", nil - } - - authHeaderParts := strings.Fields(authHeader) - if len(authHeaderParts) != 2 || strings.ToLower(authHeaderParts[0]) != "bearer" { - return "", errors.New("Authorization header format must be Bearer {token}") - } - - return authHeaderParts[1], nil -} - -func findAccessToken(c echo.Context) string { - accessToken := "" - cookie, _ := c.Cookie(auth.AccessTokenCookieName) - if cookie != nil { - accessToken = cookie.Value - } - if accessToken == "" { - accessToken, _ = extractTokenFromHeader(c) - } - - return accessToken -} - -func audienceContains(audience jwt.ClaimStrings, token string) bool { - for _, v := range audience { - if v == token { - return true - } - } - return false -} - -// JWTMiddleware validates the access token. -// If the access token is about to expire or has expired and the request has a valid refresh token, it -// will try to generate new access token and refresh token. -func JWTMiddleware(server *Server, next echo.HandlerFunc, secret string) echo.HandlerFunc { - return func(c echo.Context) error { - path := c.Request().URL.Path - method := c.Request().Method - - if server.defaultAuthSkipper(c) { - return next(c) - } - - token := findAccessToken(c) - if token == "" { - // Allow the user to access the public endpoints. - if common.HasPrefixes(path, "/o") { - return next(c) - } - // When the request is not authenticated, we allow the user to access the memo endpoints for those public memos. - if common.HasPrefixes(path, "/api/memo") && method == http.MethodGet { - return next(c) - } - return echo.NewHTTPError(http.StatusUnauthorized, "Missing access token") - } - - claims := &Claims{} - accessToken, err := jwt.ParseWithClaims(token, claims, func(t *jwt.Token) (any, error) { - if t.Method.Alg() != jwt.SigningMethodHS256.Name { - return nil, errors.Errorf("unexpected access token signing method=%v, expect %v", t.Header["alg"], jwt.SigningMethodHS256) - } - if kid, ok := t.Header["kid"].(string); ok { - if kid == "v1" { - return []byte(secret), nil - } - } - return nil, errors.Errorf("unexpected access token kid=%v", t.Header["kid"]) - }) - - if !accessToken.Valid { - return echo.NewHTTPError(http.StatusUnauthorized, "Invalid access token.") - } - - if !audienceContains(claims.Audience, auth.AccessTokenAudienceName) { - return echo.NewHTTPError(http.StatusUnauthorized, fmt.Sprintf("Invalid access token, audience mismatch, got %q, expected %q.", claims.Audience, auth.AccessTokenAudienceName)) - } - generateToken := time.Until(claims.ExpiresAt.Time) < auth.RefreshThresholdDuration - if err != nil { - var ve *jwt.ValidationError - if errors.As(err, &ve) { - // If expiration error is the only error, we will clear the err - // and generate new access token and refresh token - if ve.Errors == jwt.ValidationErrorExpired { - generateToken = true - } - } else { - return echo.NewHTTPError(http.StatusUnauthorized, errors.Wrap(err, "Invalid or expired access token")) - } - } - - // We either have a valid access token or we will attempt to generate new access token and refresh token - ctx := c.Request().Context() - userID, err := strconv.Atoi(claims.Subject) - if err != nil { - return echo.NewHTTPError(http.StatusUnauthorized, "Malformed ID in the token.") - } - - // Even if there is no error, we still need to make sure the user still exists. - user, err := server.Store.GetUser(ctx, &store.FindUser{ - ID: &userID, - }) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Server error to find user ID: %d", userID)).SetInternal(err) - } - if user == nil { - return echo.NewHTTPError(http.StatusUnauthorized, fmt.Sprintf("Failed to find user ID: %d", userID)) - } - - if generateToken { - generateTokenFunc := func() error { - rc, err := c.Cookie(auth.RefreshTokenCookieName) - if err != nil { - return echo.NewHTTPError(http.StatusUnauthorized, "Failed to generate access token. Missing refresh token.") - } - - // Parses token and checks if it's valid. - refreshTokenClaims := &Claims{} - refreshToken, err := jwt.ParseWithClaims(rc.Value, refreshTokenClaims, func(t *jwt.Token) (any, error) { - if t.Method.Alg() != jwt.SigningMethodHS256.Name { - return nil, errors.Errorf("unexpected refresh token signing method=%v, expected %v", t.Header["alg"], jwt.SigningMethodHS256) - } - - if kid, ok := t.Header["kid"].(string); ok { - if kid == "v1" { - return []byte(secret), nil - } - } - return nil, errors.Errorf("unexpected refresh token kid=%v", t.Header["kid"]) - }) - if err != nil { - if err == jwt.ErrSignatureInvalid { - return echo.NewHTTPError(http.StatusUnauthorized, "Failed to generate access token. Invalid refresh token signature.") - } - return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Server error to refresh expired token. User Id %d", userID)).SetInternal(err) - } - - if !audienceContains(refreshTokenClaims.Audience, auth.RefreshTokenAudienceName) { - return echo.NewHTTPError(http.StatusUnauthorized, - fmt.Sprintf("Invalid refresh token, audience mismatch, got %q, expected %q. you may send request to the wrong environment", - refreshTokenClaims.Audience, - auth.RefreshTokenAudienceName, - )) - } - - // If we have a valid refresh token, we will generate new access token and refresh token - if refreshToken != nil && refreshToken.Valid { - if err := auth.GenerateTokensAndSetCookies(c, user, secret); err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Server error to refresh expired token. User Id %d", userID)).SetInternal(err) - } - } - - return nil - } - - // It may happen that we still have a valid access token, but we encounter issue when trying to generate new token - // In such case, we won't return the error. - if err := generateTokenFunc(); err != nil && !accessToken.Valid { - return err - } - } - - // Stores userID into context. - c.Set(getUserIDContextKey(), userID) - return next(c) - } -} diff --git a/server/memo.go b/server/memo.go deleted file mode 100644 index d39fec4f..00000000 --- a/server/memo.go +++ /dev/null @@ -1,735 +0,0 @@ -package server - -import ( - "context" - "encoding/json" - "fmt" - "net/http" - "strconv" - "strings" - "time" - - "github.com/pkg/errors" - "github.com/usememos/memos/api" - apiv1 "github.com/usememos/memos/api/v1" - "github.com/usememos/memos/common" - "github.com/usememos/memos/store" - - "github.com/labstack/echo/v4" -) - -// maxContentLength means the max memo content bytes is 1MB. -const maxContentLength = 1 << 30 - -func (s *Server) registerMemoRoutes(g *echo.Group) { - g.POST("/memo", func(c echo.Context) error { - ctx := c.Request().Context() - userID, ok := c.Get(getUserIDContextKey()).(int) - if !ok { - return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") - } - - createMemoRequest := &api.CreateMemoRequest{} - if err := json.NewDecoder(c.Request().Body).Decode(createMemoRequest); err != nil { - return echo.NewHTTPError(http.StatusBadRequest, "Malformatted post memo request").SetInternal(err) - } - if len(createMemoRequest.Content) > maxContentLength { - return echo.NewHTTPError(http.StatusBadRequest, "Content size overflow, up to 1MB") - } - - if createMemoRequest.Visibility == "" { - userMemoVisibilitySetting, err := s.Store.GetUserSetting(ctx, &store.FindUserSetting{ - UserID: &userID, - Key: apiv1.UserSettingMemoVisibilityKey.String(), - }) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user setting").SetInternal(err) - } - - if userMemoVisibilitySetting != nil { - memoVisibility := api.Private - err := json.Unmarshal([]byte(userMemoVisibilitySetting.Value), &memoVisibility) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, "Failed to unmarshal user setting value").SetInternal(err) - } - createMemoRequest.Visibility = memoVisibility - } else { - // Private is the default memo visibility. - createMemoRequest.Visibility = api.Private - } - } - - // Find disable public memos system setting. - disablePublicMemosSystemSetting, err := s.Store.GetSystemSetting(ctx, &store.FindSystemSetting{ - Name: apiv1.SystemSettingDisablePublicMemosName.String(), - }) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find system setting").SetInternal(err) - } - if disablePublicMemosSystemSetting != nil { - disablePublicMemos := false - err = json.Unmarshal([]byte(disablePublicMemosSystemSetting.Value), &disablePublicMemos) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, "Failed to unmarshal system setting").SetInternal(err) - } - if disablePublicMemos { - user, err := s.Store.GetUser(ctx, &store.FindUser{ - ID: &userID, - }) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err) - } - // Enforce normal user to create private memo if public memos are disabled. - if user.Role == store.RoleUser { - createMemoRequest.Visibility = api.Private - } - } - } - - createMemoRequest.CreatorID = userID - memoMessage, err := s.Store.CreateMemo(ctx, convertCreateMemoRequestToMemoMessage(createMemoRequest)) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create memo").SetInternal(err) - } - if err := s.createMemoCreateActivity(ctx, memoMessage); err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create activity").SetInternal(err) - } - - for _, resourceID := range createMemoRequest.ResourceIDList { - if _, err := s.Store.UpsertMemoResource(ctx, &api.MemoResourceUpsert{ - MemoID: memoMessage.ID, - ResourceID: resourceID, - }); err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, "Failed to upsert memo resource").SetInternal(err) - } - } - - for _, memoRelationUpsert := range createMemoRequest.RelationList { - if _, err := s.Store.UpsertMemoRelation(ctx, &store.MemoRelationMessage{ - MemoID: memoMessage.ID, - RelatedMemoID: memoRelationUpsert.RelatedMemoID, - Type: store.MemoRelationType(memoRelationUpsert.Type), - }); err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, "Failed to upsert memo relation").SetInternal(err) - } - } - - memoMessage, err = s.Store.GetMemo(ctx, &store.FindMemoMessage{ - ID: &memoMessage.ID, - }) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, "Failed to compose memo").SetInternal(err) - } - memoResponse, err := s.composeMemoMessageToMemoResponse(ctx, memoMessage) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, "Failed to compose memo response").SetInternal(err) - } - return c.JSON(http.StatusOK, composeResponse(memoResponse)) - }) - - g.PATCH("/memo/:memoId", func(c echo.Context) error { - ctx := c.Request().Context() - userID, ok := c.Get(getUserIDContextKey()).(int) - if !ok { - return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") - } - - memoID, err := strconv.Atoi(c.Param("memoId")) - if err != nil { - return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("memoId"))).SetInternal(err) - } - - memoMessage, err := s.Store.GetMemo(ctx, &store.FindMemoMessage{ - ID: &memoID, - }) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find memo").SetInternal(err) - } - if memoMessage.CreatorID != userID { - return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized") - } - - currentTs := time.Now().Unix() - patchMemoRequest := &api.PatchMemoRequest{ - ID: memoID, - UpdatedTs: ¤tTs, - } - if err := json.NewDecoder(c.Request().Body).Decode(patchMemoRequest); err != nil { - return echo.NewHTTPError(http.StatusBadRequest, "Malformatted patch memo request").SetInternal(err) - } - - if patchMemoRequest.Content != nil && len(*patchMemoRequest.Content) > maxContentLength { - return echo.NewHTTPError(http.StatusBadRequest, "Content size overflow, up to 1MB").SetInternal(err) - } - - updateMemoMessage := &store.UpdateMemoMessage{ - ID: memoID, - CreatedTs: patchMemoRequest.CreatedTs, - UpdatedTs: patchMemoRequest.UpdatedTs, - Content: patchMemoRequest.Content, - } - if patchMemoRequest.RowStatus != nil { - rowStatus := store.RowStatus(patchMemoRequest.RowStatus.String()) - updateMemoMessage.RowStatus = &rowStatus - } - if patchMemoRequest.Visibility != nil { - visibility := store.Visibility(patchMemoRequest.Visibility.String()) - updateMemoMessage.Visibility = &visibility - } - - err = s.Store.UpdateMemo(ctx, updateMemoMessage) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, "Failed to patch memo").SetInternal(err) - } - memoMessage, err = s.Store.GetMemo(ctx, &store.FindMemoMessage{ID: &memoID}) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find memo").SetInternal(err) - } - - if patchMemoRequest.ResourceIDList != nil { - addedResourceIDList, removedResourceIDList := getIDListDiff(memoMessage.ResourceIDList, patchMemoRequest.ResourceIDList) - for _, resourceID := range addedResourceIDList { - if _, err := s.Store.UpsertMemoResource(ctx, &api.MemoResourceUpsert{ - MemoID: memoMessage.ID, - ResourceID: resourceID, - }); err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, "Failed to upsert memo resource").SetInternal(err) - } - } - for _, resourceID := range removedResourceIDList { - if err := s.Store.DeleteMemoResource(ctx, &api.MemoResourceDelete{ - MemoID: &memoMessage.ID, - ResourceID: &resourceID, - }); err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, "Failed to delete memo resource").SetInternal(err) - } - } - } - - if patchMemoRequest.RelationList != nil { - patchMemoRelationList := make([]*store.MemoRelationMessage, 0) - for _, memoRelation := range patchMemoRequest.RelationList { - patchMemoRelationList = append(patchMemoRelationList, &store.MemoRelationMessage{ - MemoID: memoMessage.ID, - RelatedMemoID: memoRelation.RelatedMemoID, - Type: store.MemoRelationType(memoRelation.Type), - }) - } - addedMemoRelationList, removedMemoRelationList := getMemoRelationListDiff(memoMessage.RelationList, patchMemoRelationList) - for _, memoRelation := range addedMemoRelationList { - if _, err := s.Store.UpsertMemoRelation(ctx, memoRelation); err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, "Failed to upsert memo relation").SetInternal(err) - } - } - for _, memoRelation := range removedMemoRelationList { - if err := s.Store.DeleteMemoRelation(ctx, &store.DeleteMemoRelationMessage{ - MemoID: &memoMessage.ID, - RelatedMemoID: &memoRelation.RelatedMemoID, - Type: &memoRelation.Type, - }); err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, "Failed to delete memo relation").SetInternal(err) - } - } - } - - memoMessage, err = s.Store.GetMemo(ctx, &store.FindMemoMessage{ID: &memoID}) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find memo").SetInternal(err) - } - memoResponse, err := s.composeMemoMessageToMemoResponse(ctx, memoMessage) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, "Failed to compose memo response").SetInternal(err) - } - return c.JSON(http.StatusOK, composeResponse(memoResponse)) - }) - - g.GET("/memo", func(c echo.Context) error { - ctx := c.Request().Context() - findMemoMessage := &store.FindMemoMessage{} - if userID, err := strconv.Atoi(c.QueryParam("creatorId")); err == nil { - findMemoMessage.CreatorID = &userID - } - - currentUserID, ok := c.Get(getUserIDContextKey()).(int) - if !ok { - if findMemoMessage.CreatorID == nil { - return echo.NewHTTPError(http.StatusBadRequest, "Missing user id to find memo") - } - findMemoMessage.VisibilityList = []store.Visibility{store.Public} - } else { - if findMemoMessage.CreatorID == nil { - findMemoMessage.CreatorID = ¤tUserID - } else { - findMemoMessage.VisibilityList = []store.Visibility{store.Public, store.Protected} - } - } - - rowStatus := store.RowStatus(c.QueryParam("rowStatus")) - if rowStatus != "" { - findMemoMessage.RowStatus = &rowStatus - } - pinnedStr := c.QueryParam("pinned") - if pinnedStr != "" { - pinned := pinnedStr == "true" - findMemoMessage.Pinned = &pinned - } - - contentSearch := []string{} - tag := c.QueryParam("tag") - if tag != "" { - contentSearch = append(contentSearch, "#"+tag) - } - contentSlice := c.QueryParams()["content"] - if len(contentSlice) > 0 { - contentSearch = append(contentSearch, contentSlice...) - } - findMemoMessage.ContentSearch = contentSearch - - visibilityListStr := c.QueryParam("visibility") - if visibilityListStr != "" { - visibilityList := []store.Visibility{} - for _, visibility := range strings.Split(visibilityListStr, ",") { - visibilityList = append(visibilityList, store.Visibility(visibility)) - } - findMemoMessage.VisibilityList = visibilityList - } - if limit, err := strconv.Atoi(c.QueryParam("limit")); err == nil { - findMemoMessage.Limit = &limit - } - if offset, err := strconv.Atoi(c.QueryParam("offset")); err == nil { - findMemoMessage.Offset = &offset - } - - memoDisplayWithUpdatedTs, err := s.getMemoDisplayWithUpdatedTsSettingValue(ctx) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, "Failed to get memo display with updated ts setting value").SetInternal(err) - } - if memoDisplayWithUpdatedTs { - findMemoMessage.OrderByUpdatedTs = true - } - - memoMessageList, err := s.Store.ListMemos(ctx, findMemoMessage) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, "Failed to fetch memo list").SetInternal(err) - } - memoResponseList := []*api.MemoResponse{} - for _, memoMessage := range memoMessageList { - memoResponse, err := s.composeMemoMessageToMemoResponse(ctx, memoMessage) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, "Failed to compose memo response").SetInternal(err) - } - memoResponseList = append(memoResponseList, memoResponse) - } - return c.JSON(http.StatusOK, composeResponse(memoResponseList)) - }) - - g.GET("/memo/:memoId", func(c echo.Context) error { - ctx := c.Request().Context() - memoID, err := strconv.Atoi(c.Param("memoId")) - if err != nil { - return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("memoId"))).SetInternal(err) - } - - memoMessage, err := s.Store.GetMemo(ctx, &store.FindMemoMessage{ - ID: &memoID, - }) - if err != nil { - if common.ErrorCode(err) == common.NotFound { - return echo.NewHTTPError(http.StatusNotFound, fmt.Sprintf("Memo ID not found: %d", memoID)).SetInternal(err) - } - return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Failed to find memo by ID: %v", memoID)).SetInternal(err) - } - - userID, ok := c.Get(getUserIDContextKey()).(int) - if memoMessage.Visibility == store.Private { - if !ok || memoMessage.CreatorID != userID { - return echo.NewHTTPError(http.StatusForbidden, "this memo is private only") - } - } else if memoMessage.Visibility == store.Protected { - if !ok { - return echo.NewHTTPError(http.StatusForbidden, "this memo is protected, missing user in session") - } - } - memoResponse, err := s.composeMemoMessageToMemoResponse(ctx, memoMessage) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, "Failed to compose memo response").SetInternal(err) - } - return c.JSON(http.StatusOK, composeResponse(memoResponse)) - }) - - g.POST("/memo/:memoId/organizer", func(c echo.Context) error { - ctx := c.Request().Context() - memoID, err := strconv.Atoi(c.Param("memoId")) - if err != nil { - return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("memoId"))).SetInternal(err) - } - - userID, ok := c.Get(getUserIDContextKey()).(int) - if !ok { - return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") - } - - memo, err := s.Store.GetMemo(ctx, &store.FindMemoMessage{ - ID: &memoID, - }) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find memo").SetInternal(err) - } - if memo.CreatorID != userID { - return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized") - } - - memoOrganizerUpsert := &api.MemoOrganizerUpsert{} - if err := json.NewDecoder(c.Request().Body).Decode(memoOrganizerUpsert); err != nil { - return echo.NewHTTPError(http.StatusBadRequest, "Malformatted post memo organizer request").SetInternal(err) - } - memoOrganizerUpsert.MemoID = memoID - memoOrganizerUpsert.UserID = userID - - err = s.Store.UpsertMemoOrganizer(ctx, memoOrganizerUpsert) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, "Failed to upsert memo organizer").SetInternal(err) - } - - memoMessage, err := s.Store.GetMemo(ctx, &store.FindMemoMessage{ - ID: &memoID, - }) - if err != nil { - if common.ErrorCode(err) == common.NotFound { - return echo.NewHTTPError(http.StatusNotFound, fmt.Sprintf("Memo ID not found: %d", memoID)).SetInternal(err) - } - return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Failed to find memo by ID: %v", memoID)).SetInternal(err) - } - memoResponse, err := s.composeMemoMessageToMemoResponse(ctx, memoMessage) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, "Failed to compose memo response").SetInternal(err) - } - return c.JSON(http.StatusOK, composeResponse(memoResponse)) - }) - - g.GET("/memo/stats", func(c echo.Context) error { - ctx := c.Request().Context() - normalStatus := store.Normal - findMemoMessage := &store.FindMemoMessage{ - RowStatus: &normalStatus, - } - if creatorID, err := strconv.Atoi(c.QueryParam("creatorId")); err == nil { - findMemoMessage.CreatorID = &creatorID - } - if findMemoMessage.CreatorID == nil { - return echo.NewHTTPError(http.StatusBadRequest, "Missing user id to find memo") - } - - currentUserID, ok := c.Get(getUserIDContextKey()).(int) - if !ok { - findMemoMessage.VisibilityList = []store.Visibility{store.Public} - } else { - if *findMemoMessage.CreatorID != currentUserID { - findMemoMessage.VisibilityList = []store.Visibility{store.Public, store.Protected} - } else { - findMemoMessage.VisibilityList = []store.Visibility{store.Public, store.Protected, store.Private} - } - } - - memoDisplayWithUpdatedTs, err := s.getMemoDisplayWithUpdatedTsSettingValue(ctx) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, "Failed to get memo display with updated ts setting value").SetInternal(err) - } - if memoDisplayWithUpdatedTs { - findMemoMessage.OrderByUpdatedTs = true - } - - memoMessageList, err := s.Store.ListMemos(ctx, findMemoMessage) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find memo list").SetInternal(err) - } - memoResponseList := []*api.MemoResponse{} - for _, memoMessage := range memoMessageList { - memoResponse, err := s.composeMemoMessageToMemoResponse(ctx, memoMessage) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, "Failed to compose memo response").SetInternal(err) - } - memoResponseList = append(memoResponseList, memoResponse) - } - - displayTsList := []int64{} - for _, memo := range memoResponseList { - displayTsList = append(displayTsList, memo.DisplayTs) - } - return c.JSON(http.StatusOK, composeResponse(displayTsList)) - }) - - g.GET("/memo/all", func(c echo.Context) error { - ctx := c.Request().Context() - findMemoMessage := &store.FindMemoMessage{} - _, ok := c.Get(getUserIDContextKey()).(int) - if !ok { - findMemoMessage.VisibilityList = []store.Visibility{store.Public} - } else { - findMemoMessage.VisibilityList = []store.Visibility{store.Public, store.Protected} - } - - pinnedStr := c.QueryParam("pinned") - if pinnedStr != "" { - pinned := pinnedStr == "true" - findMemoMessage.Pinned = &pinned - } - - contentSearch := []string{} - tag := c.QueryParam("tag") - if tag != "" { - contentSearch = append(contentSearch, "#"+tag+" ") - } - contentSlice := c.QueryParams()["content"] - if len(contentSlice) > 0 { - contentSearch = append(contentSearch, contentSlice...) - } - findMemoMessage.ContentSearch = contentSearch - - visibilityListStr := c.QueryParam("visibility") - if visibilityListStr != "" { - visibilityList := []store.Visibility{} - for _, visibility := range strings.Split(visibilityListStr, ",") { - visibilityList = append(visibilityList, store.Visibility(visibility)) - } - findMemoMessage.VisibilityList = visibilityList - } - if limit, err := strconv.Atoi(c.QueryParam("limit")); err == nil { - findMemoMessage.Limit = &limit - } - if offset, err := strconv.Atoi(c.QueryParam("offset")); err == nil { - findMemoMessage.Offset = &offset - } - - // Only fetch normal status memos. - normalStatus := store.Normal - findMemoMessage.RowStatus = &normalStatus - - memoDisplayWithUpdatedTs, err := s.getMemoDisplayWithUpdatedTsSettingValue(ctx) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, "Failed to get memo display with updated ts setting value").SetInternal(err) - } - if memoDisplayWithUpdatedTs { - findMemoMessage.OrderByUpdatedTs = true - } - - memoMessageList, err := s.Store.ListMemos(ctx, findMemoMessage) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, "Failed to fetch all memo list").SetInternal(err) - } - memoResponseList := []*api.MemoResponse{} - for _, memoMessage := range memoMessageList { - memoResponse, err := s.composeMemoMessageToMemoResponse(ctx, memoMessage) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, "Failed to compose memo response").SetInternal(err) - } - memoResponseList = append(memoResponseList, memoResponse) - } - return c.JSON(http.StatusOK, composeResponse(memoResponseList)) - }) - - g.DELETE("/memo/:memoId", func(c echo.Context) error { - ctx := c.Request().Context() - userID, ok := c.Get(getUserIDContextKey()).(int) - if !ok { - return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") - } - memoID, err := strconv.Atoi(c.Param("memoId")) - if err != nil { - return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("memoId"))).SetInternal(err) - } - - memo, err := s.Store.GetMemo(ctx, &store.FindMemoMessage{ - ID: &memoID, - }) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find memo").SetInternal(err) - } - if memo.CreatorID != userID { - return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized") - } - - if err := s.Store.DeleteMemo(ctx, &store.DeleteMemoMessage{ - ID: memoID, - }); err != nil { - if common.ErrorCode(err) == common.NotFound { - return echo.NewHTTPError(http.StatusNotFound, fmt.Sprintf("Memo ID not found: %d", memoID)) - } - return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Failed to delete memo ID: %v", memoID)).SetInternal(err) - } - return c.JSON(http.StatusOK, true) - }) -} - -func (s *Server) createMemoCreateActivity(ctx context.Context, memo *store.MemoMessage) error { - payload := apiv1.ActivityMemoCreatePayload{ - Content: memo.Content, - Visibility: memo.Visibility.String(), - } - payloadBytes, err := json.Marshal(payload) - if err != nil { - return errors.Wrap(err, "failed to marshal activity payload") - } - activity, err := s.Store.CreateActivity(ctx, &store.ActivityMessage{ - CreatorID: memo.CreatorID, - Type: apiv1.ActivityMemoCreate.String(), - Level: apiv1.ActivityInfo.String(), - Payload: string(payloadBytes), - }) - if err != nil || activity == nil { - return errors.Wrap(err, "failed to create activity") - } - return err -} - -func getIDListDiff(oldList, newList []int) (addedList, removedList []int) { - oldMap := map[int]bool{} - for _, id := range oldList { - oldMap[id] = true - } - newMap := map[int]bool{} - for _, id := range newList { - newMap[id] = true - } - for id := range oldMap { - if !newMap[id] { - removedList = append(removedList, id) - } - } - for id := range newMap { - if !oldMap[id] { - addedList = append(addedList, id) - } - } - return addedList, removedList -} - -func getMemoRelationListDiff(oldList, newList []*store.MemoRelationMessage) (addedList, removedList []*store.MemoRelationMessage) { - oldMap := map[string]bool{} - for _, relation := range oldList { - oldMap[fmt.Sprintf("%d-%s", relation.RelatedMemoID, relation.Type)] = true - } - newMap := map[string]bool{} - for _, relation := range newList { - newMap[fmt.Sprintf("%d-%s", relation.RelatedMemoID, relation.Type)] = true - } - for _, relation := range oldList { - key := fmt.Sprintf("%d-%s", relation.RelatedMemoID, relation.Type) - if !newMap[key] { - removedList = append(removedList, relation) - } - } - for _, relation := range newList { - key := fmt.Sprintf("%d-%s", relation.RelatedMemoID, relation.Type) - if !oldMap[key] { - addedList = append(addedList, relation) - } - } - return addedList, removedList -} - -func convertCreateMemoRequestToMemoMessage(memoCreate *api.CreateMemoRequest) *store.MemoMessage { - createdTs := time.Now().Unix() - if memoCreate.CreatedTs != nil { - createdTs = *memoCreate.CreatedTs - } - return &store.MemoMessage{ - CreatorID: memoCreate.CreatorID, - CreatedTs: createdTs, - Content: memoCreate.Content, - Visibility: store.Visibility(memoCreate.Visibility), - } -} - -func (s *Server) composeMemoMessageToMemoResponse(ctx context.Context, memoMessage *store.MemoMessage) (*api.MemoResponse, error) { - memoResponse := &api.MemoResponse{ - ID: memoMessage.ID, - RowStatus: api.RowStatus(memoMessage.RowStatus.String()), - CreatorID: memoMessage.CreatorID, - CreatedTs: memoMessage.CreatedTs, - UpdatedTs: memoMessage.UpdatedTs, - Content: memoMessage.Content, - Visibility: api.Visibility(memoMessage.Visibility.String()), - Pinned: memoMessage.Pinned, - } - - // Compose creator name. - user, err := s.Store.GetUser(ctx, &store.FindUser{ - ID: &memoResponse.CreatorID, - }) - if err != nil { - return nil, err - } - if user.Nickname != "" { - memoResponse.CreatorName = user.Nickname - } else { - memoResponse.CreatorName = user.Username - } - - // Compose display ts. - memoResponse.DisplayTs = memoResponse.CreatedTs - // Find memo display with updated ts setting. - memoDisplayWithUpdatedTs, err := s.getMemoDisplayWithUpdatedTsSettingValue(ctx) - if err != nil { - return nil, err - } - if memoDisplayWithUpdatedTs { - memoResponse.DisplayTs = memoResponse.UpdatedTs - } - - relationList := []*api.MemoRelation{} - for _, relation := range memoMessage.RelationList { - relationList = append(relationList, convertMemoRelationMessageToMemoRelation(relation)) - } - memoResponse.RelationList = relationList - - resourceList := []*api.Resource{} - for _, resourceID := range memoMessage.ResourceIDList { - resource, err := s.Store.GetResource(ctx, &store.FindResource{ - ID: &resourceID, - }) - if err != nil { - return nil, err - } - if resource != nil { - resourceList = append(resourceList, convertResourceFromStore(resource)) - } - } - memoResponse.ResourceList = resourceList - - return memoResponse, nil -} - -func (s *Server) getMemoDisplayWithUpdatedTsSettingValue(ctx context.Context) (bool, error) { - memoDisplayWithUpdatedTsSetting, err := s.Store.GetSystemSetting(ctx, &store.FindSystemSetting{ - Name: apiv1.SystemSettingMemoDisplayWithUpdatedTsName.String(), - }) - if err != nil { - return false, errors.Wrap(err, "failed to find system setting") - } - memoDisplayWithUpdatedTs := false - if memoDisplayWithUpdatedTsSetting != nil { - err = json.Unmarshal([]byte(memoDisplayWithUpdatedTsSetting.Value), &memoDisplayWithUpdatedTs) - if err != nil { - return false, errors.Wrap(err, "failed to unmarshal system setting value") - } - } - return memoDisplayWithUpdatedTs, nil -} - -func convertResourceFromStore(resource *store.Resource) *api.Resource { - return &api.Resource{ - ID: resource.ID, - CreatorID: resource.CreatorID, - CreatedTs: resource.CreatedTs, - UpdatedTs: resource.UpdatedTs, - Filename: resource.Filename, - Blob: resource.Blob, - InternalPath: resource.InternalPath, - ExternalLink: resource.ExternalLink, - Type: resource.Type, - Size: resource.Size, - PublicID: resource.PublicID, - LinkedMemoAmount: resource.LinkedMemoAmount, - } -} diff --git a/server/memo_resource.go b/server/memo_resource.go deleted file mode 100644 index 3a3e9060..00000000 --- a/server/memo_resource.go +++ /dev/null @@ -1,107 +0,0 @@ -package server - -import ( - "encoding/json" - "fmt" - "net/http" - "strconv" - "time" - - "github.com/usememos/memos/api" - "github.com/usememos/memos/store" - - "github.com/labstack/echo/v4" -) - -func (s *Server) registerMemoResourceRoutes(g *echo.Group) { - g.POST("/memo/:memoId/resource", func(c echo.Context) error { - ctx := c.Request().Context() - memoID, err := strconv.Atoi(c.Param("memoId")) - if err != nil { - return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("memoId"))).SetInternal(err) - } - - userID, ok := c.Get(getUserIDContextKey()).(int) - if !ok { - return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") - } - memoResourceUpsert := &api.MemoResourceUpsert{} - if err := json.NewDecoder(c.Request().Body).Decode(memoResourceUpsert); err != nil { - return echo.NewHTTPError(http.StatusBadRequest, "Malformatted post memo resource request").SetInternal(err) - } - resource, err := s.Store.GetResource(ctx, &store.FindResource{ - ID: &memoResourceUpsert.ResourceID, - }) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, "Failed to fetch resource").SetInternal(err) - } - if resource == nil { - return echo.NewHTTPError(http.StatusBadRequest, "Resource not found").SetInternal(err) - } else if resource.CreatorID != userID { - return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized to bind this resource").SetInternal(err) - } - - memoResourceUpsert.MemoID = memoID - currentTs := time.Now().Unix() - memoResourceUpsert.UpdatedTs = ¤tTs - if _, err := s.Store.UpsertMemoResource(ctx, memoResourceUpsert); err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, "Failed to upsert memo resource").SetInternal(err) - } - return c.JSON(http.StatusOK, true) - }) - - g.GET("/memo/:memoId/resource", func(c echo.Context) error { - ctx := c.Request().Context() - memoID, err := strconv.Atoi(c.Param("memoId")) - if err != nil { - return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("memoId"))).SetInternal(err) - } - - list, err := s.Store.ListResources(ctx, &store.FindResource{ - MemoID: &memoID, - }) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, "Failed to fetch resource list").SetInternal(err) - } - resourceList := []*api.Resource{} - for _, resource := range list { - resourceList = append(resourceList, convertResourceFromStore(resource)) - } - return c.JSON(http.StatusOK, composeResponse(resourceList)) - }) - - g.DELETE("/memo/:memoId/resource/:resourceId", func(c echo.Context) error { - ctx := c.Request().Context() - userID, ok := c.Get(getUserIDContextKey()).(int) - if !ok { - return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") - } - memoID, err := strconv.Atoi(c.Param("memoId")) - if err != nil { - return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Memo ID is not a number: %s", c.Param("memoId"))).SetInternal(err) - } - resourceID, err := strconv.Atoi(c.Param("resourceId")) - if err != nil { - return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Resource ID is not a number: %s", c.Param("resourceId"))).SetInternal(err) - } - - memo, err := s.Store.GetMemo(ctx, &store.FindMemoMessage{ - ID: &memoID, - }) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find memo").SetInternal(err) - } - if memo.CreatorID != userID { - return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized") - } - - memoResourceDelete := &api.MemoResourceDelete{ - MemoID: &memoID, - ResourceID: &resourceID, - } - if err := s.Store.DeleteMemoResource(ctx, memoResourceDelete); err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, "Failed to fetch resource list").SetInternal(err) - } - return c.JSON(http.StatusOK, true) - }) -} diff --git a/server/rss.go b/server/rss.go index 31343306..81a47c7d 100644 --- a/server/rss.go +++ b/server/rss.go @@ -26,7 +26,7 @@ func (s *Server) registerRSSRoutes(g *echo.Group) { } normalStatus := store.Normal - memoFind := store.FindMemoMessage{ + memoFind := store.FindMemo{ RowStatus: &normalStatus, VisibilityList: []store.Visibility{store.Public}, } @@ -57,7 +57,7 @@ func (s *Server) registerRSSRoutes(g *echo.Group) { } normalStatus := store.Normal - memoFind := store.FindMemoMessage{ + memoFind := store.FindMemo{ CreatorID: &id, RowStatus: &normalStatus, VisibilityList: []store.Visibility{store.Public}, @@ -80,7 +80,7 @@ func (s *Server) registerRSSRoutes(g *echo.Group) { const MaxRSSItemCount = 100 const MaxRSSItemTitleLength = 100 -func (s *Server) generateRSSFromMemoList(ctx context.Context, memoList []*store.MemoMessage, baseURL string, profile *apiv1.CustomizedProfile) (string, error) { +func (s *Server) generateRSSFromMemoList(ctx context.Context, memoList []*store.Memo, baseURL string, profile *apiv1.CustomizedProfile) (string, error) { feed := &feeds.Feed{ Title: profile.Name, Link: &feeds.Link{Href: baseURL}, diff --git a/server/server.go b/server/server.go index 3c3c24f7..ba543610 100644 --- a/server/server.go +++ b/server/server.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "fmt" + "net/http" "time" "github.com/google/uuid" @@ -87,20 +88,6 @@ func NewServer(ctx context.Context, profile *profile.Profile, store *store.Store rootGroup := e.Group("") s.registerRSSRoutes(rootGroup) - publicGroup := e.Group("/o") - publicGroup.Use(func(next echo.HandlerFunc) echo.HandlerFunc { - return JWTMiddleware(s, next, s.Secret) - }) - registerGetterPublicRoutes(publicGroup) - - apiGroup := e.Group("/api") - apiGroup.Use(func(next echo.HandlerFunc) echo.HandlerFunc { - return JWTMiddleware(s, next, s.Secret) - }) - s.registerMemoRoutes(apiGroup) - s.registerMemoResourceRoutes(apiGroup) - s.registerMemoRelationRoutes(apiGroup) - apiV1Service := apiv1.NewAPIV1Service(s.Secret, profile, store) apiV1Service.Register(rootGroup) @@ -185,7 +172,7 @@ func (s *Server) createServerStartActivity(ctx context.Context) error { if err != nil { return errors.Wrap(err, "failed to marshal activity payload") } - activity, err := s.Store.CreateActivity(ctx, &store.ActivityMessage{ + activity, err := s.Store.CreateActivity(ctx, &store.Activity{ CreatorID: apiv1.UnknownID, Type: apiv1.ActivityServerStart.String(), Level: apiv1.ActivityInfo.String(), @@ -196,3 +183,12 @@ func (s *Server) createServerStartActivity(ctx context.Context) error { } return err } + +func defaultGetRequestSkipper(c echo.Context) bool { + return c.Request().Method == http.MethodGet +} + +func defaultAPIRequestSkipper(c echo.Context) bool { + path := c.Path() + return common.HasPrefixes(path, "/api", "/api/v1") +} diff --git a/server/telegram.go b/server/telegram.go index 884ee3a7..1e072e46 100644 --- a/server/telegram.go +++ b/server/telegram.go @@ -8,7 +8,6 @@ import ( "strconv" "github.com/pkg/errors" - "github.com/usememos/memos/api" apiv1 "github.com/usememos/memos/api/v1" "github.com/usememos/memos/common" "github.com/usememos/memos/plugin/telegram" @@ -61,20 +60,19 @@ func (t *telegramHandler) MessageHandle(ctx context.Context, bot *telegram.Bot, return err } - // create memo - memoCreate := api.CreateMemoRequest{ + create := &store.Memo{ CreatorID: creatorID, - Visibility: api.Private, + Visibility: store.Private, } if message.Text != nil { - memoCreate.Content = *message.Text + create.Content = *message.Text } if blobs != nil && message.Caption != nil { - memoCreate.Content = *message.Caption + create.Content = *message.Caption } - memoMessage, err := t.store.CreateMemo(ctx, convertCreateMemoRequestToMemoMessage(&memoCreate)) + memoMessage, err := t.store.CreateMemo(ctx, create) if err != nil { _, err := bot.EditMessage(ctx, message.Chat.ID, reply.MessageID, fmt.Sprintf("failed to CreateMemo: %s", err), nil) return err @@ -90,7 +88,7 @@ func (t *telegramHandler) MessageHandle(ctx context.Context, bot *telegram.Bot, case ".png": mime = "image/png" } - resource, err := t.store.CreateResourceV1(ctx, &store.Resource{ + resource, err := t.store.CreateResource(ctx, &store.Resource{ CreatorID: creatorID, Filename: filename, Type: mime, @@ -103,7 +101,7 @@ func (t *telegramHandler) MessageHandle(ctx context.Context, bot *telegram.Bot, return err } - _, err = t.store.UpsertMemoResource(ctx, &api.MemoResourceUpsert{ + _, err = t.store.UpsertMemoResource(ctx, &store.UpsertMemoResource{ MemoID: memoMessage.ID, ResourceID: resource.ID, }) @@ -126,7 +124,7 @@ func (t *telegramHandler) CallbackQueryHandle(ctx context.Context, bot *telegram return bot.AnswerCallbackQuery(ctx, callbackQuery.ID, fmt.Sprintf("fail to parse callbackQuery.Data %s", callbackQuery.Data)) } - update := store.UpdateMemoMessage{ + update := store.UpdateMemo{ ID: memoID, Visibility: &visibility, } diff --git a/store/activity.go b/store/activity.go index 57b01986..c59fc4db 100644 --- a/store/activity.go +++ b/store/activity.go @@ -4,7 +4,7 @@ import ( "context" ) -type ActivityMessage struct { +type Activity struct { ID int // Standard fields @@ -18,7 +18,7 @@ type ActivityMessage struct { } // CreateActivity creates an instance of Activity. -func (s *Store) CreateActivity(ctx context.Context, create *ActivityMessage) (*ActivityMessage, error) { +func (s *Store) CreateActivity(ctx context.Context, create *Activity) (*Activity, error) { tx, err := s.db.BeginTx(ctx, nil) if err != nil { return nil, err @@ -45,6 +45,7 @@ func (s *Store) CreateActivity(ctx context.Context, create *ActivityMessage) (*A if err := tx.Commit(); err != nil { return nil, err } - activityMessage := create - return activityMessage, nil + + activity := create + return activity, nil } diff --git a/store/idp.go b/store/idp.go index f4e81fc1..954f0294 100644 --- a/store/idp.go +++ b/store/idp.go @@ -123,6 +123,10 @@ func (s *Store) ListIdentityProviders(ctx context.Context, find *FindIdentityPro return nil, err } + if err := tx.Commit(); err != nil { + return nil, err + } + for _, item := range list { s.idpCache.Store(item.ID, item) } @@ -150,6 +154,10 @@ func (s *Store) GetIdentityProvider(ctx context.Context, find *FindIdentityProvi return nil, nil } + if err := tx.Commit(); err != nil { + return nil, err + } + identityProvider := list[0] s.idpCache.Store(identityProvider.ID, identityProvider) return identityProvider, nil diff --git a/store/memo.go b/store/memo.go index acb43c97..fba079c3 100644 --- a/store/memo.go +++ b/store/memo.go @@ -7,8 +7,6 @@ import ( "strconv" "strings" "time" - - "github.com/usememos/memos/common" ) // Visibility is the type of a visibility. @@ -35,7 +33,7 @@ func (v Visibility) String() string { return "PRIVATE" } -type MemoMessage struct { +type Memo struct { ID int // Standard fields @@ -51,10 +49,10 @@ type MemoMessage struct { // Composed fields Pinned bool ResourceIDList []int - RelationList []*MemoRelationMessage + RelationList []*MemoRelation } -type FindMemoMessage struct { +type FindMemo struct { ID *int // Standard fields @@ -72,7 +70,7 @@ type FindMemoMessage struct { OrderByUpdatedTs bool } -type UpdateMemoMessage struct { +type UpdateMemo struct { ID int CreatedTs *int64 UpdatedTs *int64 @@ -81,14 +79,14 @@ type UpdateMemoMessage struct { Visibility *Visibility } -type DeleteMemoMessage struct { +type DeleteMemo struct { ID int } -func (s *Store) CreateMemo(ctx context.Context, create *MemoMessage) (*MemoMessage, error) { +func (s *Store) CreateMemo(ctx context.Context, create *Memo) (*Memo, error) { tx, err := s.db.BeginTx(ctx, nil) if err != nil { - return nil, FormatError(err) + return nil, err } defer tx.Rollback() @@ -119,19 +117,20 @@ func (s *Store) CreateMemo(ctx context.Context, create *MemoMessage) (*MemoMessa &create.UpdatedTs, &create.RowStatus, ); err != nil { - return nil, FormatError(err) + return nil, err } if err := tx.Commit(); err != nil { - return nil, FormatError(err) + return nil, err } - memoMessage := create - return memoMessage, nil + + memo := create + return memo, nil } -func (s *Store) ListMemos(ctx context.Context, find *FindMemoMessage) ([]*MemoMessage, error) { +func (s *Store) ListMemos(ctx context.Context, find *FindMemo) ([]*Memo, error) { tx, err := s.db.BeginTx(ctx, nil) if err != nil { - return nil, FormatError(err) + return nil, err } defer tx.Rollback() @@ -140,13 +139,17 @@ func (s *Store) ListMemos(ctx context.Context, find *FindMemoMessage) ([]*MemoMe return nil, err } + if err := tx.Commit(); err != nil { + return nil, err + } + return list, nil } -func (s *Store) GetMemo(ctx context.Context, find *FindMemoMessage) (*MemoMessage, error) { +func (s *Store) GetMemo(ctx context.Context, find *FindMemo) (*Memo, error) { tx, err := s.db.BeginTx(ctx, nil) if err != nil { - return nil, FormatError(err) + return nil, err } defer tx.Rollback() @@ -155,14 +158,18 @@ func (s *Store) GetMemo(ctx context.Context, find *FindMemoMessage) (*MemoMessag return nil, err } if len(list) == 0 { - return nil, &common.Error{Code: common.NotFound, Err: fmt.Errorf("memo not found")} + return nil, nil } - memoMessage := list[0] - return memoMessage, nil + if err := tx.Commit(); err != nil { + return nil, err + } + + memo := list[0] + return memo, nil } -func (s *Store) UpdateMemo(ctx context.Context, update *UpdateMemoMessage) error { +func (s *Store) UpdateMemo(ctx context.Context, update *UpdateMemo) error { tx, err := s.db.BeginTx(ctx, nil) if err != nil { return err @@ -199,27 +206,20 @@ func (s *Store) UpdateMemo(ctx context.Context, update *UpdateMemoMessage) error return err } -func (s *Store) DeleteMemo(ctx context.Context, delete *DeleteMemoMessage) error { +func (s *Store) DeleteMemo(ctx context.Context, delete *DeleteMemo) error { tx, err := s.db.BeginTx(ctx, nil) if err != nil { - return FormatError(err) + return err } defer tx.Rollback() where, args := []string{"id = ?"}, []any{delete.ID} stmt := `DELETE FROM memo WHERE ` + strings.Join(where, " AND ") - result, err := tx.ExecContext(ctx, stmt, args...) - if err != nil { - return FormatError(err) - } - - rows, err := result.RowsAffected() + _, err = tx.ExecContext(ctx, stmt, args...) if err != nil { return err } - if rows == 0 { - return &common.Error{Code: common.NotFound, Err: fmt.Errorf("idp not found")} - } + if err := s.vacuumImpl(ctx, tx); err != nil { return err } @@ -230,7 +230,7 @@ func (s *Store) DeleteMemo(ctx context.Context, delete *DeleteMemoMessage) error func (s *Store) FindMemosVisibilityList(ctx context.Context, memoIDs []int) ([]Visibility, error) { tx, err := s.db.BeginTx(ctx, nil) if err != nil { - return nil, FormatError(err) + return nil, err } defer tx.Rollback() @@ -247,7 +247,7 @@ func (s *Store) FindMemosVisibilityList(ctx context.Context, memoIDs []int) ([]V rows, err := tx.QueryContext(ctx, query, args...) if err != nil { - return nil, FormatError(err) + return nil, err } defer rows.Close() @@ -255,19 +255,19 @@ func (s *Store) FindMemosVisibilityList(ctx context.Context, memoIDs []int) ([]V for rows.Next() { var visibility Visibility if err := rows.Scan(&visibility); err != nil { - return nil, FormatError(err) + return nil, err } visibilityList = append(visibilityList, visibility) } if err := rows.Err(); err != nil { - return nil, FormatError(err) + return nil, err } return visibilityList, nil } -func listMemos(ctx context.Context, tx *sql.Tx, find *FindMemoMessage) ([]*MemoMessage, error) { +func listMemos(ctx context.Context, tx *sql.Tx, find *FindMemo) ([]*Memo, error) { where, args := []string{"1 = 1"}, []any{} if v := find.ID; v != nil { @@ -343,68 +343,68 @@ func listMemos(ctx context.Context, tx *sql.Tx, find *FindMemoMessage) ([]*MemoM rows, err := tx.QueryContext(ctx, query, args...) if err != nil { - return nil, FormatError(err) + return nil, err } defer rows.Close() - memoMessageList := make([]*MemoMessage, 0) + list := make([]*Memo, 0) for rows.Next() { - var memoMessage MemoMessage + var memo Memo var memoResourceIDList sql.NullString var memoRelationList sql.NullString if err := rows.Scan( - &memoMessage.ID, - &memoMessage.CreatorID, - &memoMessage.CreatedTs, - &memoMessage.UpdatedTs, - &memoMessage.RowStatus, - &memoMessage.Content, - &memoMessage.Visibility, - &memoMessage.Pinned, + &memo.ID, + &memo.CreatorID, + &memo.CreatedTs, + &memo.UpdatedTs, + &memo.RowStatus, + &memo.Content, + &memo.Visibility, + &memo.Pinned, &memoResourceIDList, &memoRelationList, ); err != nil { - return nil, FormatError(err) + return nil, err } if memoResourceIDList.Valid { idStringList := strings.Split(memoResourceIDList.String, ",") - memoMessage.ResourceIDList = make([]int, 0, len(idStringList)) + memo.ResourceIDList = make([]int, 0, len(idStringList)) for _, idString := range idStringList { id, err := strconv.Atoi(idString) if err != nil { - return nil, FormatError(err) + return nil, err } - memoMessage.ResourceIDList = append(memoMessage.ResourceIDList, id) + memo.ResourceIDList = append(memo.ResourceIDList, id) } } if memoRelationList.Valid { - memoMessage.RelationList = make([]*MemoRelationMessage, 0) + memo.RelationList = make([]*MemoRelation, 0) relatedMemoTypeList := strings.Split(memoRelationList.String, ",") for _, relatedMemoType := range relatedMemoTypeList { relatedMemoTypeList := strings.Split(relatedMemoType, ":") if len(relatedMemoTypeList) != 2 { - return nil, &common.Error{Code: common.Invalid, Err: fmt.Errorf("invalid relation format")} + return nil, fmt.Errorf("invalid relation format") } relatedMemoID, err := strconv.Atoi(relatedMemoTypeList[0]) if err != nil { - return nil, FormatError(err) + return nil, err } - memoMessage.RelationList = append(memoMessage.RelationList, &MemoRelationMessage{ - MemoID: memoMessage.ID, + memo.RelationList = append(memo.RelationList, &MemoRelation{ + MemoID: memo.ID, RelatedMemoID: relatedMemoID, Type: MemoRelationType(relatedMemoTypeList[1]), }) } } - memoMessageList = append(memoMessageList, &memoMessage) + list = append(list, &memo) } if err := rows.Err(); err != nil { - return nil, FormatError(err) + return nil, err } - return memoMessageList, nil + return list, nil } func vacuumMemo(ctx context.Context, tx *sql.Tx) error { @@ -420,7 +420,7 @@ func vacuumMemo(ctx context.Context, tx *sql.Tx) error { )` _, err := tx.ExecContext(ctx, stmt) if err != nil { - return FormatError(err) + return err } return nil diff --git a/store/memo_organizer.go b/store/memo_organizer.go index 1456e33d..2e10a5da 100644 --- a/store/memo_organizer.go +++ b/store/memo_organizer.go @@ -5,117 +5,31 @@ import ( "database/sql" "fmt" "strings" - - "github.com/usememos/memos/api" - "github.com/usememos/memos/common" ) -// memoOrganizerRaw is the store model for an MemoOrganizer. -// Fields have exactly the same meanings as MemoOrganizer. -type memoOrganizerRaw struct { - // Domain specific fields +type MemoOrganizer struct { MemoID int UserID int Pinned bool } -func (raw *memoOrganizerRaw) toMemoOrganizer() *api.MemoOrganizer { - return &api.MemoOrganizer{ - MemoID: raw.MemoID, - UserID: raw.UserID, - Pinned: raw.Pinned, - } +type FindMemoOrganizer struct { + MemoID int + UserID int } -func (s *Store) FindMemoOrganizer(ctx context.Context, find *api.MemoOrganizerFind) (*api.MemoOrganizer, error) { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return nil, FormatError(err) - } - defer tx.Rollback() +type DeleteMemoOrganizer struct { + MemoID *int + UserID *int +} - memoOrganizerRaw, err := findMemoOrganizer(ctx, tx, find) +func (s *Store) UpsertMemoOrganizerV1(ctx context.Context, upsert *MemoOrganizer) (*MemoOrganizer, error) { + tx, err := s.db.BeginTx(ctx, nil) if err != nil { return nil, err } - - memoOrganizer := memoOrganizerRaw.toMemoOrganizer() - - return memoOrganizer, nil -} - -func (s *Store) UpsertMemoOrganizer(ctx context.Context, upsert *api.MemoOrganizerUpsert) error { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return FormatError(err) - } defer tx.Rollback() - if err := upsertMemoOrganizer(ctx, tx, upsert); err != nil { - return err - } - - if err := tx.Commit(); err != nil { - return FormatError(err) - } - - return nil -} - -func (s *Store) DeleteMemoOrganizer(ctx context.Context, delete *api.MemoOrganizerDelete) error { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return FormatError(err) - } - defer tx.Rollback() - - if err := deleteMemoOrganizer(ctx, tx, delete); err != nil { - return err - } - - if err := tx.Commit(); err != nil { - return FormatError(err) - } - - return nil -} - -func findMemoOrganizer(ctx context.Context, tx *sql.Tx, find *api.MemoOrganizerFind) (*memoOrganizerRaw, error) { - query := ` - SELECT - memo_id, - user_id, - pinned - FROM memo_organizer - WHERE memo_id = ? AND user_id = ? - ` - row, err := tx.QueryContext(ctx, query, find.MemoID, find.UserID) - if err != nil { - return nil, FormatError(err) - } - defer row.Close() - - if !row.Next() { - return nil, &common.Error{Code: common.NotFound, Err: fmt.Errorf("not found")} - } - - var memoOrganizerRaw memoOrganizerRaw - if err := row.Scan( - &memoOrganizerRaw.MemoID, - &memoOrganizerRaw.UserID, - &memoOrganizerRaw.Pinned, - ); err != nil { - return nil, FormatError(err) - } - - if err := row.Err(); err != nil { - return nil, err - } - - return &memoOrganizerRaw, nil -} - -func upsertMemoOrganizer(ctx context.Context, tx *sql.Tx, upsert *api.MemoOrganizerUpsert) error { query := ` INSERT INTO memo_organizer ( memo_id, @@ -126,21 +40,64 @@ func upsertMemoOrganizer(ctx context.Context, tx *sql.Tx, upsert *api.MemoOrgani ON CONFLICT(memo_id, user_id) DO UPDATE SET pinned = EXCLUDED.pinned - RETURNING memo_id, user_id, pinned ` - var memoOrganizer api.MemoOrganizer - if err := tx.QueryRowContext(ctx, query, upsert.MemoID, upsert.UserID, upsert.Pinned).Scan( + if _, err := tx.ExecContext(ctx, query, upsert.MemoID, upsert.UserID, upsert.Pinned); err != nil { + return nil, err + } + + if err := tx.Commit(); err != nil { + return nil, err + } + + memoOrganizer := upsert + return memoOrganizer, nil +} + +func (s *Store) GetMemoOrganizerV1(ctx context.Context, find *FindMemoOrganizer) (*MemoOrganizer, error) { + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return nil, err + } + defer tx.Rollback() + + where, args := []string{}, []any{} + if find.MemoID != 0 { + where = append(where, "memo_id = ?") + args = append(args, find.MemoID) + } + if find.UserID != 0 { + where = append(where, "user_id = ?") + args = append(args, find.UserID) + } + query := fmt.Sprintf(` + SELECT + memo_id, + user_id, + pinned + FROM memo_organizer + WHERE %s + `, strings.Join(where, " AND ")) + row := tx.QueryRowContext(ctx, query, args...) + + memoOrganizer := &MemoOrganizer{} + if err := row.Scan( &memoOrganizer.MemoID, &memoOrganizer.UserID, &memoOrganizer.Pinned, ); err != nil { - return FormatError(err) + return nil, err } - return nil + return memoOrganizer, nil } -func deleteMemoOrganizer(ctx context.Context, tx *sql.Tx, delete *api.MemoOrganizerDelete) error { +func (s *Store) DeleteMemoOrganizerV1(ctx context.Context, delete *DeleteMemoOrganizer) error { + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return FormatError(err) + } + defer tx.Rollback() + where, args := []string{}, []any{} if v := delete.MemoID; v != nil { @@ -151,14 +108,13 @@ func deleteMemoOrganizer(ctx context.Context, tx *sql.Tx, delete *api.MemoOrgani } stmt := `DELETE FROM memo_organizer WHERE ` + strings.Join(where, " AND ") - result, err := tx.ExecContext(ctx, stmt, args...) + _, err = tx.ExecContext(ctx, stmt, args...) if err != nil { return FormatError(err) } - rows, _ := result.RowsAffected() - if rows == 0 { - return &common.Error{Code: common.NotFound, Err: fmt.Errorf("memo organizer not found")} + if err := tx.Commit(); err != nil { + return FormatError(err) } return nil diff --git a/store/memo_relation.go b/store/memo_relation.go index 6fe5c081..d0d1c8c7 100644 --- a/store/memo_relation.go +++ b/store/memo_relation.go @@ -3,10 +3,7 @@ package store import ( "context" "database/sql" - "fmt" "strings" - - "github.com/usememos/memos/common" ) type MemoRelationType string @@ -16,28 +13,28 @@ const ( MemoRelationAdditional MemoRelationType = "ADDITIONAL" ) -type MemoRelationMessage struct { +type MemoRelation struct { MemoID int RelatedMemoID int Type MemoRelationType } -type FindMemoRelationMessage struct { +type FindMemoRelation struct { MemoID *int RelatedMemoID *int Type *MemoRelationType } -type DeleteMemoRelationMessage struct { +type DeleteMemoRelation struct { MemoID *int RelatedMemoID *int Type *MemoRelationType } -func (s *Store) UpsertMemoRelation(ctx context.Context, create *MemoRelationMessage) (*MemoRelationMessage, error) { +func (s *Store) UpsertMemoRelation(ctx context.Context, create *MemoRelation) (*MemoRelation, error) { tx, err := s.db.BeginTx(ctx, nil) if err != nil { - return nil, FormatError(err) + return nil, err } defer tx.Rollback() @@ -52,7 +49,7 @@ func (s *Store) UpsertMemoRelation(ctx context.Context, create *MemoRelationMess type = EXCLUDED.type RETURNING memo_id, related_memo_id, type ` - memoRelationMessage := &MemoRelationMessage{} + memoRelationMessage := &MemoRelation{} if err := tx.QueryRowContext( ctx, query, @@ -64,18 +61,18 @@ func (s *Store) UpsertMemoRelation(ctx context.Context, create *MemoRelationMess &memoRelationMessage.RelatedMemoID, &memoRelationMessage.Type, ); err != nil { - return nil, FormatError(err) + return nil, err } if err := tx.Commit(); err != nil { - return nil, FormatError(err) + return nil, err } return memoRelationMessage, nil } -func (s *Store) ListMemoRelations(ctx context.Context, find *FindMemoRelationMessage) ([]*MemoRelationMessage, error) { +func (s *Store) ListMemoRelations(ctx context.Context, find *FindMemoRelation) ([]*MemoRelation, error) { tx, err := s.db.BeginTx(ctx, nil) if err != nil { - return nil, FormatError(err) + return nil, err } defer tx.Rollback() @@ -84,13 +81,17 @@ func (s *Store) ListMemoRelations(ctx context.Context, find *FindMemoRelationMes return nil, err } + if err := tx.Commit(); err != nil { + return nil, err + } + return list, nil } -func (s *Store) GetMemoRelation(ctx context.Context, find *FindMemoRelationMessage) (*MemoRelationMessage, error) { +func (s *Store) GetMemoRelation(ctx context.Context, find *FindMemoRelation) (*MemoRelation, error) { tx, err := s.db.BeginTx(ctx, nil) if err != nil { - return nil, FormatError(err) + return nil, err } defer tx.Rollback() @@ -100,15 +101,20 @@ func (s *Store) GetMemoRelation(ctx context.Context, find *FindMemoRelationMessa } if len(list) == 0 { - return nil, &common.Error{Code: common.NotFound, Err: fmt.Errorf("not found")} + return nil, nil } + + if err := tx.Commit(); err != nil { + return nil, err + } + return list[0], nil } -func (s *Store) DeleteMemoRelation(ctx context.Context, delete *DeleteMemoRelationMessage) error { +func (s *Store) DeleteMemoRelation(ctx context.Context, delete *DeleteMemoRelation) error { tx, err := s.db.BeginTx(ctx, nil) if err != nil { - return FormatError(err) + return err } defer tx.Rollback() @@ -127,16 +133,17 @@ func (s *Store) DeleteMemoRelation(ctx context.Context, delete *DeleteMemoRelati DELETE FROM memo_relation WHERE ` + strings.Join(where, " AND ") if _, err := tx.ExecContext(ctx, query, args...); err != nil { - return FormatError(err) + return err } if err := tx.Commit(); err != nil { - return FormatError(err) + // Prevent lint warning. + return err } return nil } -func listMemoRelations(ctx context.Context, tx *sql.Tx, find *FindMemoRelationMessage) ([]*MemoRelationMessage, error) { +func listMemoRelations(ctx context.Context, tx *sql.Tx, find *FindMemoRelation) ([]*MemoRelation, error) { where, args := []string{"TRUE"}, []any{} if find.MemoID != nil { where, args = append(where, "memo_id = ?"), append(args, find.MemoID) @@ -156,24 +163,24 @@ func listMemoRelations(ctx context.Context, tx *sql.Tx, find *FindMemoRelationMe FROM memo_relation WHERE `+strings.Join(where, " AND "), args...) if err != nil { - return nil, FormatError(err) + return nil, err } defer rows.Close() - memoRelationMessages := []*MemoRelationMessage{} + memoRelationMessages := []*MemoRelation{} for rows.Next() { - memoRelationMessage := &MemoRelationMessage{} + memoRelationMessage := &MemoRelation{} if err := rows.Scan( &memoRelationMessage.MemoID, &memoRelationMessage.RelatedMemoID, &memoRelationMessage.Type, ); err != nil { - return nil, FormatError(err) + return nil, err } memoRelationMessages = append(memoRelationMessages, memoRelationMessage) } if err := rows.Err(); err != nil { - return nil, FormatError(err) + return nil, err } return memoRelationMessages, nil } @@ -183,7 +190,7 @@ func vacuumMemoRelations(ctx context.Context, tx *sql.Tx) error { DELETE FROM memo_relation WHERE memo_id NOT IN (SELECT id FROM memo) OR related_memo_id NOT IN (SELECT id FROM memo) `); err != nil { - return FormatError(err) + return err } return nil } diff --git a/store/memo_resource.go b/store/memo_resource.go index 804d5961..7796d088 100644 --- a/store/memo_resource.go +++ b/store/memo_resource.go @@ -3,11 +3,7 @@ package store import ( "context" "database/sql" - "fmt" "strings" - - "github.com/usememos/memos/api" - "github.com/usememos/memos/common" ) type MemoResource struct { @@ -17,11 +13,65 @@ type MemoResource struct { UpdatedTs int64 } +type UpsertMemoResource struct { + MemoID int + ResourceID int + CreatedTs int64 + UpdatedTs *int64 +} + type FindMemoResource struct { MemoID *int ResourceID *int } +type DeleteMemoResource struct { + MemoID *int + ResourceID *int +} + +func (s *Store) UpsertMemoResource(ctx context.Context, upsert *UpsertMemoResource) (*MemoResource, error) { + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return nil, FormatError(err) + } + defer tx.Rollback() + + set := []string{"memo_id", "resource_id"} + args := []any{upsert.MemoID, upsert.ResourceID} + placeholder := []string{"?", "?"} + + if v := upsert.UpdatedTs; v != nil { + set, args, placeholder = append(set, "updated_ts"), append(args, v), append(placeholder, "?") + } + + query := ` + INSERT INTO memo_resource ( + ` + strings.Join(set, ", ") + ` + ) + VALUES (` + strings.Join(placeholder, ",") + `) + ON CONFLICT(memo_id, resource_id) DO UPDATE + SET + updated_ts = EXCLUDED.updated_ts + RETURNING memo_id, resource_id, created_ts, updated_ts + ` + memoResource := &MemoResource{} + if err := tx.QueryRowContext(ctx, query, args...).Scan( + &memoResource.MemoID, + &memoResource.ResourceID, + &memoResource.CreatedTs, + &memoResource.UpdatedTs, + ); err != nil { + return nil, FormatError(err) + } + + if err := tx.Commit(); err != nil { + return nil, FormatError(err) + } + + return memoResource, nil +} + func (s *Store) ListMemoResources(ctx context.Context, find *FindMemoResource) ([]*MemoResource, error) { tx, err := s.db.BeginTx(ctx, nil) if err != nil { @@ -41,6 +91,58 @@ func (s *Store) ListMemoResources(ctx context.Context, find *FindMemoResource) ( return list, nil } +func (s *Store) GetMemoResource(ctx context.Context, find *FindMemoResource) (*MemoResource, error) { + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return nil, err + } + defer tx.Rollback() + + list, err := listMemoResources(ctx, tx, find) + if err != nil { + return nil, err + } + if len(list) == 0 { + return nil, nil + } + + if err := tx.Commit(); err != nil { + return nil, err + } + + memoResource := list[0] + return memoResource, nil +} + +func (s *Store) DeleteMemoResource(ctx context.Context, delete *DeleteMemoResource) error { + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return FormatError(err) + } + defer tx.Rollback() + + where, args := []string{}, []any{} + + if v := delete.MemoID; v != nil { + where, args = append(where, "memo_id = ?"), append(args, *v) + } + if v := delete.ResourceID; v != nil { + where, args = append(where, "resource_id = ?"), append(args, *v) + } + + stmt := `DELETE FROM memo_resource WHERE ` + strings.Join(where, " AND ") + _, err = tx.ExecContext(ctx, stmt, args...) + if err != nil { + return FormatError(err) + } + + if err := tx.Commit(); err != nil { + return FormatError(err) + } + + return nil +} + func listMemoResources(ctx context.Context, tx *sql.Tx, find *FindMemoResource) ([]*MemoResource, error) { where, args := []string{"1 = 1"}, []any{} @@ -89,207 +191,6 @@ func listMemoResources(ctx context.Context, tx *sql.Tx, find *FindMemoResource) return list, nil } -// memoResourceRaw is the store model for an MemoResource. -// Fields have exactly the same meanings as MemoResource. -type memoResourceRaw struct { - MemoID int - ResourceID int - CreatedTs int64 - UpdatedTs int64 -} - -func (raw *memoResourceRaw) toMemoResource() *api.MemoResource { - return &api.MemoResource{ - MemoID: raw.MemoID, - ResourceID: raw.ResourceID, - CreatedTs: raw.CreatedTs, - UpdatedTs: raw.UpdatedTs, - } -} - -func (s *Store) FindMemoResourceList(ctx context.Context, find *api.MemoResourceFind) ([]*api.MemoResource, error) { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return nil, FormatError(err) - } - defer tx.Rollback() - - memoResourceRawList, err := findMemoResourceList(ctx, tx, find) - if err != nil { - return nil, err - } - - list := []*api.MemoResource{} - for _, raw := range memoResourceRawList { - memoResource := raw.toMemoResource() - list = append(list, memoResource) - } - - return list, nil -} - -func (s *Store) FindMemoResource(ctx context.Context, find *api.MemoResourceFind) (*api.MemoResource, error) { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return nil, FormatError(err) - } - defer tx.Rollback() - - list, err := findMemoResourceList(ctx, tx, find) - if err != nil { - return nil, err - } - - if len(list) == 0 { - return nil, &common.Error{Code: common.NotFound, Err: fmt.Errorf("not found")} - } - - memoResourceRaw := list[0] - - return memoResourceRaw.toMemoResource(), nil -} - -func (s *Store) UpsertMemoResource(ctx context.Context, upsert *api.MemoResourceUpsert) (*api.MemoResource, error) { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return nil, FormatError(err) - } - defer tx.Rollback() - - memoResourceRaw, err := upsertMemoResource(ctx, tx, upsert) - if err != nil { - return nil, err - } - - if err := tx.Commit(); err != nil { - return nil, FormatError(err) - } - - return memoResourceRaw.toMemoResource(), nil -} - -func (s *Store) DeleteMemoResource(ctx context.Context, delete *api.MemoResourceDelete) error { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return FormatError(err) - } - defer tx.Rollback() - - if err := deleteMemoResource(ctx, tx, delete); err != nil { - return FormatError(err) - } - - if err := tx.Commit(); err != nil { - return FormatError(err) - } - - return nil -} - -func findMemoResourceList(ctx context.Context, tx *sql.Tx, find *api.MemoResourceFind) ([]*memoResourceRaw, error) { - where, args := []string{"1 = 1"}, []any{} - - if v := find.MemoID; v != nil { - where, args = append(where, "memo_id = ?"), append(args, *v) - } - if v := find.ResourceID; v != nil { - where, args = append(where, "resource_id = ?"), append(args, *v) - } - - query := ` - SELECT - memo_id, - resource_id, - created_ts, - updated_ts - FROM memo_resource - WHERE ` + strings.Join(where, " AND ") + ` - ORDER BY updated_ts DESC - ` - rows, err := tx.QueryContext(ctx, query, args...) - if err != nil { - return nil, FormatError(err) - } - defer rows.Close() - - memoResourceRawList := make([]*memoResourceRaw, 0) - for rows.Next() { - var memoResourceRaw memoResourceRaw - if err := rows.Scan( - &memoResourceRaw.MemoID, - &memoResourceRaw.ResourceID, - &memoResourceRaw.CreatedTs, - &memoResourceRaw.UpdatedTs, - ); err != nil { - return nil, FormatError(err) - } - - memoResourceRawList = append(memoResourceRawList, &memoResourceRaw) - } - - if err := rows.Err(); err != nil { - return nil, err - } - - return memoResourceRawList, nil -} - -func upsertMemoResource(ctx context.Context, tx *sql.Tx, upsert *api.MemoResourceUpsert) (*memoResourceRaw, error) { - set := []string{"memo_id", "resource_id"} - args := []any{upsert.MemoID, upsert.ResourceID} - placeholder := []string{"?", "?"} - - if v := upsert.UpdatedTs; v != nil { - set, args, placeholder = append(set, "updated_ts"), append(args, v), append(placeholder, "?") - } - - query := ` - INSERT INTO memo_resource ( - ` + strings.Join(set, ", ") + ` - ) - VALUES (` + strings.Join(placeholder, ",") + `) - ON CONFLICT(memo_id, resource_id) DO UPDATE - SET - updated_ts = EXCLUDED.updated_ts - RETURNING memo_id, resource_id, created_ts, updated_ts - ` - var memoResourceRaw memoResourceRaw - if err := tx.QueryRowContext(ctx, query, args...).Scan( - &memoResourceRaw.MemoID, - &memoResourceRaw.ResourceID, - &memoResourceRaw.CreatedTs, - &memoResourceRaw.UpdatedTs, - ); err != nil { - return nil, FormatError(err) - } - - return &memoResourceRaw, nil -} - -func deleteMemoResource(ctx context.Context, tx *sql.Tx, delete *api.MemoResourceDelete) error { - where, args := []string{}, []any{} - - if v := delete.MemoID; v != nil { - where, args = append(where, "memo_id = ?"), append(args, *v) - } - if v := delete.ResourceID; v != nil { - where, args = append(where, "resource_id = ?"), append(args, *v) - } - - stmt := `DELETE FROM memo_resource WHERE ` + strings.Join(where, " AND ") - result, err := tx.ExecContext(ctx, stmt, args...) - if err != nil { - return FormatError(err) - } - - rows, _ := result.RowsAffected() - if rows == 0 { - return &common.Error{Code: common.NotFound, Err: fmt.Errorf("memo resource not found")} - } - - return nil -} - func vacuumMemoResource(ctx context.Context, tx *sql.Tx) error { stmt := ` DELETE FROM diff --git a/store/resource.go b/store/resource.go index 95375f0f..e51e6fe8 100644 --- a/store/resource.go +++ b/store/resource.go @@ -48,7 +48,7 @@ type DeleteResource struct { ID int } -func (s *Store) CreateResourceV1(ctx context.Context, create *Resource) (*Resource, error) { +func (s *Store) CreateResource(ctx context.Context, create *Resource) (*Resource, error) { tx, err := s.db.BeginTx(ctx, nil) if err != nil { return nil, FormatError(err) @@ -98,6 +98,10 @@ func (s *Store) ListResources(ctx context.Context, find *FindResource) ([]*Resou return nil, err } + if err := tx.Commit(); err != nil { + return nil, err + } + return resources, nil } @@ -113,14 +117,14 @@ func (s *Store) GetResource(ctx context.Context, find *FindResource) (*Resource, return nil, err } - if err := tx.Commit(); err != nil { - return nil, err - } - if len(resources) == 0 { return nil, nil } + if err := tx.Commit(); err != nil { + return nil, err + } + return resources[0], nil } @@ -174,7 +178,7 @@ func (s *Store) UpdateResource(ctx context.Context, update *UpdateResource) (*Re return &resource, nil } -func (s *Store) DeleteResourceV1(ctx context.Context, delete *DeleteResource) error { +func (s *Store) DeleteResource(ctx context.Context, delete *DeleteResource) error { tx, err := s.db.BeginTx(ctx, nil) if err != nil { return FormatError(err) diff --git a/store/shortcut.go b/store/shortcut.go index de6e25ca..7fb4047f 100644 --- a/store/shortcut.go +++ b/store/shortcut.go @@ -85,6 +85,10 @@ func (s *Store) ListShortcuts(ctx context.Context, find *FindShortcut) ([]*Short return nil, err } + if err := tx.Commit(); err != nil { + return nil, err + } + return list, nil } @@ -104,6 +108,10 @@ func (s *Store) GetShortcut(ctx context.Context, find *FindShortcut) (*Shortcut, return nil, nil } + if err := tx.Commit(); err != nil { + return nil, err + } + shortcut := list[0] return shortcut, nil } diff --git a/store/storage.go b/store/storage.go index cceee950..d043c3eb 100644 --- a/store/storage.go +++ b/store/storage.go @@ -69,6 +69,10 @@ func (s *Store) ListStorages(ctx context.Context, find *FindStorage) ([]*Storage return nil, err } + if err := tx.Commit(); err != nil { + return nil, err + } + return list, nil } @@ -87,6 +91,10 @@ func (s *Store) GetStorage(ctx context.Context, find *FindStorage) (*Storage, er return nil, nil } + if err := tx.Commit(); err != nil { + return nil, err + } + return list[0], nil } diff --git a/store/system_setting.go b/store/system_setting.go index e72430a7..6c06ff5b 100644 --- a/store/system_setting.go +++ b/store/system_setting.go @@ -57,6 +57,10 @@ func (s *Store) ListSystemSettings(ctx context.Context, find *FindSystemSetting) return nil, err } + if err := tx.Commit(); err != nil { + return nil, err + } + for _, systemSettingMessage := range list { s.systemSettingCache.Store(systemSettingMessage.Name, systemSettingMessage) } @@ -85,6 +89,10 @@ func (s *Store) GetSystemSetting(ctx context.Context, find *FindSystemSetting) ( return nil, nil } + if err := tx.Commit(); err != nil { + return nil, err + } + systemSettingMessage := list[0] s.systemSettingCache.Store(systemSettingMessage.Name, systemSettingMessage) return systemSettingMessage, nil diff --git a/store/tag.go b/store/tag.go index 9de47807..c6295291 100644 --- a/store/tag.go +++ b/store/tag.go @@ -88,6 +88,10 @@ func (s *Store) ListTags(ctx context.Context, find *FindTag) ([]*Tag, error) { return nil, err } + if err := tx.Commit(); err != nil { + return nil, err + } + return list, nil } diff --git a/store/user.go b/store/user.go index a2b8ddb9..ea747d1f 100644 --- a/store/user.go +++ b/store/user.go @@ -120,6 +120,7 @@ func (s *Store) CreateUser(ctx context.Context, create *User) (*User, error) { if err := tx.Commit(); err != nil { return nil, err } + user := create s.userCache.Store(user.ID, user) return user, nil @@ -202,6 +203,10 @@ func (s *Store) ListUsers(ctx context.Context, find *FindUser) ([]*User, error) return nil, err } + if err := tx.Commit(); err != nil { + return nil, err + } + for _, user := range list { s.userCache.Store(user.ID, user) } @@ -228,6 +233,11 @@ func (s *Store) GetUser(ctx context.Context, find *FindUser) (*User, error) { if len(list) == 0 { return nil, nil } + + if err := tx.Commit(); err != nil { + return nil, err + } + user := list[0] s.userCache.Store(user.ID, user) return user, nil diff --git a/store/user_setting.go b/store/user_setting.go index 55415d37..8fd6c284 100644 --- a/store/user_setting.go +++ b/store/user_setting.go @@ -57,6 +57,10 @@ func (s *Store) ListUserSettings(ctx context.Context, find *FindUserSetting) ([] return nil, err } + if err := tx.Commit(); err != nil { + return nil, err + } + for _, userSetting := range userSettingList { s.userSettingCache.Store(getUserSettingCacheKey(userSetting.UserID, userSetting.Key), userSetting) } @@ -85,6 +89,10 @@ func (s *Store) GetUserSetting(ctx context.Context, find *FindUserSetting) (*Use return nil, nil } + if err := tx.Commit(); err != nil { + return nil, err + } + userSetting := list[0] s.userSettingCache.Store(getUserSettingCacheKey(userSetting.UserID, userSetting.Key), userSetting) return userSetting, nil diff --git a/test/server/memo_relation_test.go b/test/server/memo_relation_test.go index 28e3a16a..a7d705c5 100644 --- a/test/server/memo_relation_test.go +++ b/test/server/memo_relation_test.go @@ -9,7 +9,6 @@ import ( "github.com/pkg/errors" "github.com/stretchr/testify/require" - "github.com/usememos/memos/api" apiv1 "github.com/usememos/memos/api/v1" ) @@ -26,17 +25,17 @@ func TestMemoRelationServer(t *testing.T) { user, err := s.postAuthSignup(signup) require.NoError(t, err) require.Equal(t, signup.Username, user.Username) - memo, err := s.postMemoCreate(&api.CreateMemoRequest{ + memo, err := s.postMemoCreate(&apiv1.CreateMemoRequest{ Content: "test memo", }) require.NoError(t, err) require.Equal(t, "test memo", memo.Content) - memo2, err := s.postMemoCreate(&api.CreateMemoRequest{ + memo2, err := s.postMemoCreate(&apiv1.CreateMemoRequest{ Content: "test memo2", - RelationList: []*api.MemoRelationUpsert{ + RelationList: []*apiv1.UpsertMemoRelationRequest{ { RelatedMemoID: memo.ID, - Type: api.MemoRelationReference, + Type: apiv1.MemoRelationReference, }, }, }) @@ -46,14 +45,14 @@ func TestMemoRelationServer(t *testing.T) { require.NoError(t, err) require.Len(t, memoList, 2) require.Len(t, memo2.RelationList, 1) - err = s.deleteMemoRelation(memo2.ID, memo.ID, api.MemoRelationReference) + err = s.deleteMemoRelation(memo2.ID, memo.ID, apiv1.MemoRelationReference) require.NoError(t, err) memo2, err = s.getMemo(memo2.ID) require.NoError(t, err) require.Len(t, memo2.RelationList, 0) - memoRelation, err := s.postMemoRelationUpsert(memo2.ID, &api.MemoRelationUpsert{ + memoRelation, err := s.postMemoRelationUpsert(memo2.ID, &apiv1.UpsertMemoRelationRequest{ RelatedMemoID: memo.ID, - Type: api.MemoRelationReference, + Type: apiv1.MemoRelationReference, }) require.NoError(t, err) require.Equal(t, memo.ID, memoRelation.RelatedMemoID) @@ -62,13 +61,13 @@ func TestMemoRelationServer(t *testing.T) { require.Len(t, memo2.RelationList, 1) } -func (s *TestingServer) postMemoRelationUpsert(memoID int, memoRelationUpsert *api.MemoRelationUpsert) (*api.MemoRelation, error) { +func (s *TestingServer) postMemoRelationUpsert(memoID int, memoRelationUpsert *apiv1.UpsertMemoRelationRequest) (*apiv1.MemoRelation, error) { rawData, err := json.Marshal(&memoRelationUpsert) if err != nil { return nil, errors.Wrap(err, "failed to marshal memo relation upsert") } reader := bytes.NewReader(rawData) - body, err := s.post(fmt.Sprintf("/api/memo/%d/relation", memoID), reader, nil) + body, err := s.post(fmt.Sprintf("/api/v1/memo/%d/relation", memoID), reader, nil) if err != nil { return nil, err } @@ -79,17 +78,14 @@ func (s *TestingServer) postMemoRelationUpsert(memoID int, memoRelationUpsert *a return nil, errors.Wrap(err, "fail to read response body") } - type MemoCreateResponse struct { - Data *api.MemoRelation `json:"data"` - } - res := new(MemoCreateResponse) - if err = json.Unmarshal(buf.Bytes(), res); err != nil { + memoRelation := &apiv1.MemoRelation{} + if err = json.Unmarshal(buf.Bytes(), memoRelation); err != nil { return nil, errors.Wrap(err, "fail to unmarshal post memo relation upsert response") } - return res.Data, nil + return memoRelation, nil } -func (s *TestingServer) deleteMemoRelation(memoID int, relatedMemoID int, relationType api.MemoRelationType) error { - _, err := s.delete(fmt.Sprintf("/api/memo/%d/relation/%d/type/%s", memoID, relatedMemoID, relationType), nil) +func (s *TestingServer) deleteMemoRelation(memoID int, relatedMemoID int, relationType apiv1.MemoRelationType) error { + _, err := s.delete(fmt.Sprintf("/api/v1/memo/%d/relation/%d/type/%s", memoID, relatedMemoID, relationType), nil) return err } diff --git a/test/server/memo_test.go b/test/server/memo_test.go index 30fa4734..571163f7 100644 --- a/test/server/memo_test.go +++ b/test/server/memo_test.go @@ -9,7 +9,6 @@ import ( "github.com/pkg/errors" "github.com/stretchr/testify/require" - "github.com/usememos/memos/api" apiv1 "github.com/usememos/memos/api/v1" ) @@ -26,7 +25,7 @@ func TestMemoServer(t *testing.T) { user, err := s.postAuthSignup(signup) require.NoError(t, err) require.Equal(t, signup.Username, user.Username) - memo, err := s.postMemoCreate(&api.CreateMemoRequest{ + memo, err := s.postMemoCreate(&apiv1.CreateMemoRequest{ Content: "test memo", }) require.NoError(t, err) @@ -35,20 +34,18 @@ func TestMemoServer(t *testing.T) { require.NoError(t, err) require.Len(t, memoList, 1) updatedContent := "updated memo" - memo, err = s.patchMemo(&api.PatchMemoRequest{ + memo, err = s.patchMemo(&apiv1.PatchMemoRequest{ ID: memo.ID, Content: &updatedContent, }) require.NoError(t, err) require.Equal(t, updatedContent, memo.Content) require.Equal(t, false, memo.Pinned) - memo, err = s.postMemosOrganizer(&api.MemoOrganizerUpsert{ - MemoID: memo.ID, - UserID: user.ID, + _, err = s.postMemoOrganizer(memo.ID, &apiv1.UpsertMemoOrganizerRequest{ Pinned: true, }) require.NoError(t, err) - memo, err = s.patchMemo(&api.PatchMemoRequest{ + memo, err = s.patchMemo(&apiv1.PatchMemoRequest{ ID: memo.ID, Content: &updatedContent, }) @@ -62,8 +59,8 @@ func TestMemoServer(t *testing.T) { require.Len(t, memoList, 0) } -func (s *TestingServer) getMemo(memoID int) (*api.MemoResponse, error) { - body, err := s.get(fmt.Sprintf("/api/memo/%d", memoID), nil) +func (s *TestingServer) getMemo(memoID int) (*apiv1.Memo, error) { + body, err := s.get(fmt.Sprintf("/api/v1/memo/%d", memoID), nil) if err != nil { return nil, err } @@ -74,18 +71,15 @@ func (s *TestingServer) getMemo(memoID int) (*api.MemoResponse, error) { return nil, errors.Wrap(err, "fail to read response body") } - type MemoCreateResponse struct { - Data *api.MemoResponse `json:"data"` - } - res := new(MemoCreateResponse) - if err = json.Unmarshal(buf.Bytes(), res); err != nil { + memo := &apiv1.Memo{} + if err = json.Unmarshal(buf.Bytes(), memo); err != nil { return nil, errors.Wrap(err, "fail to unmarshal get memo response") } - return res.Data, nil + return memo, nil } -func (s *TestingServer) getMemoList() ([]*api.MemoResponse, error) { - body, err := s.get("/api/memo", nil) +func (s *TestingServer) getMemoList() ([]*apiv1.Memo, error) { + body, err := s.get("/api/v1/memo", nil) if err != nil { return nil, err } @@ -96,23 +90,20 @@ func (s *TestingServer) getMemoList() ([]*api.MemoResponse, error) { return nil, errors.Wrap(err, "fail to read response body") } - type MemoCreateResponse struct { - Data []*api.MemoResponse `json:"data"` - } - res := new(MemoCreateResponse) - if err = json.Unmarshal(buf.Bytes(), res); err != nil { + memoList := []*apiv1.Memo{} + if err = json.Unmarshal(buf.Bytes(), &memoList); err != nil { return nil, errors.Wrap(err, "fail to unmarshal get memo list response") } - return res.Data, nil + return memoList, nil } -func (s *TestingServer) postMemoCreate(memoCreate *api.CreateMemoRequest) (*api.MemoResponse, error) { +func (s *TestingServer) postMemoCreate(memoCreate *apiv1.CreateMemoRequest) (*apiv1.Memo, error) { rawData, err := json.Marshal(&memoCreate) if err != nil { return nil, errors.Wrap(err, "failed to marshal memo create") } reader := bytes.NewReader(rawData) - body, err := s.post("/api/memo", reader, nil) + body, err := s.post("/api/v1/memo", reader, nil) if err != nil { return nil, err } @@ -123,23 +114,20 @@ func (s *TestingServer) postMemoCreate(memoCreate *api.CreateMemoRequest) (*api. return nil, errors.Wrap(err, "fail to read response body") } - type MemoCreateResponse struct { - Data *api.MemoResponse `json:"data"` - } - res := new(MemoCreateResponse) - if err = json.Unmarshal(buf.Bytes(), res); err != nil { + memo := &apiv1.Memo{} + if err = json.Unmarshal(buf.Bytes(), memo); err != nil { return nil, errors.Wrap(err, "fail to unmarshal post memo create response") } - return res.Data, nil + return memo, nil } -func (s *TestingServer) patchMemo(memoPatch *api.PatchMemoRequest) (*api.MemoResponse, error) { +func (s *TestingServer) patchMemo(memoPatch *apiv1.PatchMemoRequest) (*apiv1.Memo, error) { rawData, err := json.Marshal(&memoPatch) if err != nil { return nil, errors.Wrap(err, "failed to marshal memo patch") } reader := bytes.NewReader(rawData) - body, err := s.patch(fmt.Sprintf("/api/memo/%d", memoPatch.ID), reader, nil) + body, err := s.patch(fmt.Sprintf("/api/v1/memo/%d", memoPatch.ID), reader, nil) if err != nil { return nil, err } @@ -150,28 +138,25 @@ func (s *TestingServer) patchMemo(memoPatch *api.PatchMemoRequest) (*api.MemoRes return nil, errors.Wrap(err, "fail to read response body") } - type MemoPatchResponse struct { - Data *api.MemoResponse `json:"data"` - } - res := new(MemoPatchResponse) - if err = json.Unmarshal(buf.Bytes(), res); err != nil { + memo := &apiv1.Memo{} + if err = json.Unmarshal(buf.Bytes(), memo); err != nil { return nil, errors.Wrap(err, "fail to unmarshal patch memo response") } - return res.Data, nil + return memo, nil } func (s *TestingServer) deleteMemo(memoID int) error { - _, err := s.delete(fmt.Sprintf("/api/memo/%d", memoID), nil) + _, err := s.delete(fmt.Sprintf("/api/v1/memo/%d", memoID), nil) return err } -func (s *TestingServer) postMemosOrganizer(memosOrganizer *api.MemoOrganizerUpsert) (*api.MemoResponse, error) { +func (s *TestingServer) postMemoOrganizer(memoID int, memosOrganizer *apiv1.UpsertMemoOrganizerRequest) (*apiv1.Memo, error) { rawData, err := json.Marshal(&memosOrganizer) if err != nil { return nil, errors.Wrap(err, "failed to marshal memos organizer") } reader := bytes.NewReader(rawData) - body, err := s.post(fmt.Sprintf("/api/memo/%d/organizer", memosOrganizer.MemoID), reader, nil) + body, err := s.post(fmt.Sprintf("/api/v1/memo/%d/organizer", memoID), reader, nil) if err != nil { return nil, err } @@ -182,12 +167,9 @@ func (s *TestingServer) postMemosOrganizer(memosOrganizer *api.MemoOrganizerUpse return nil, errors.Wrap(err, "fail to read response body") } - type MemoOrganizerResponse struct { - Data *api.MemoResponse `json:"data"` - } - res := new(MemoOrganizerResponse) - if err = json.Unmarshal(buf.Bytes(), res); err != nil { + memo := &apiv1.Memo{} + if err = json.Unmarshal(buf.Bytes(), memo); err != nil { return nil, errors.Wrap(err, "fail to unmarshal organizer memo create response") } - return res.Data, err + return memo, err } diff --git a/test/store/memo_relation_test.go b/test/store/memo_relation_test.go index c852ac46..c6918671 100644 --- a/test/store/memo_relation_test.go +++ b/test/store/memo_relation_test.go @@ -13,7 +13,7 @@ func TestMemoRelationStore(t *testing.T) { ts := NewTestingStore(ctx, t) user, err := createTestingHostUser(ctx, ts) require.NoError(t, err) - memoCreate := &store.MemoMessage{ + memoCreate := &store.Memo{ CreatorID: user.ID, Content: "test_content", Visibility: store.Public, @@ -21,7 +21,7 @@ func TestMemoRelationStore(t *testing.T) { memo, err := ts.CreateMemo(ctx, memoCreate) require.NoError(t, err) require.Equal(t, memoCreate.Content, memo.Content) - memo2Create := &store.MemoMessage{ + memo2Create := &store.Memo{ CreatorID: user.ID, Content: "test_content_2", Visibility: store.Public, @@ -29,14 +29,14 @@ func TestMemoRelationStore(t *testing.T) { memo2, err := ts.CreateMemo(ctx, memo2Create) require.NoError(t, err) require.Equal(t, memo2Create.Content, memo2.Content) - memoRelationMessage := &store.MemoRelationMessage{ + memoRelationMessage := &store.MemoRelation{ MemoID: memo.ID, RelatedMemoID: memo2.ID, Type: store.MemoRelationReference, } _, err = ts.UpsertMemoRelation(ctx, memoRelationMessage) require.NoError(t, err) - memoRelation, err := ts.ListMemoRelations(ctx, &store.FindMemoRelationMessage{ + memoRelation, err := ts.ListMemoRelations(ctx, &store.FindMemoRelation{ MemoID: &memo.ID, }) require.NoError(t, err) @@ -44,11 +44,11 @@ func TestMemoRelationStore(t *testing.T) { require.Equal(t, memo2.ID, memoRelation[0].RelatedMemoID) require.Equal(t, memo.ID, memoRelation[0].MemoID) require.Equal(t, store.MemoRelationReference, memoRelation[0].Type) - err = ts.DeleteMemo(ctx, &store.DeleteMemoMessage{ + err = ts.DeleteMemo(ctx, &store.DeleteMemo{ ID: memo2.ID, }) require.NoError(t, err) - memoRelation, err = ts.ListMemoRelations(ctx, &store.FindMemoRelationMessage{ + memoRelation, err = ts.ListMemoRelations(ctx, &store.FindMemoRelation{ MemoID: &memo.ID, }) require.NoError(t, err) diff --git a/test/store/memo_test.go b/test/store/memo_test.go index d4294585..e32ed682 100644 --- a/test/store/memo_test.go +++ b/test/store/memo_test.go @@ -13,7 +13,7 @@ func TestMemoStore(t *testing.T) { ts := NewTestingStore(ctx, t) user, err := createTestingHostUser(ctx, ts) require.NoError(t, err) - memoCreate := &store.MemoMessage{ + memoCreate := &store.Memo{ CreatorID: user.ID, Content: "test_content", Visibility: store.Public, @@ -22,23 +22,23 @@ func TestMemoStore(t *testing.T) { require.NoError(t, err) require.Equal(t, memoCreate.Content, memo.Content) memoPatchContent := "test_content_2" - memoPatch := &store.UpdateMemoMessage{ + memoPatch := &store.UpdateMemo{ ID: memo.ID, Content: &memoPatchContent, } err = ts.UpdateMemo(ctx, memoPatch) require.NoError(t, err) - memo, err = ts.GetMemo(ctx, &store.FindMemoMessage{ + memo, err = ts.GetMemo(ctx, &store.FindMemo{ ID: &memo.ID, }) require.NoError(t, err) - memoList, err := ts.ListMemos(ctx, &store.FindMemoMessage{ + memoList, err := ts.ListMemos(ctx, &store.FindMemo{ CreatorID: &user.ID, }) require.NoError(t, err) require.Equal(t, 1, len(memoList)) require.Equal(t, memo, memoList[0]) - err = ts.DeleteMemo(ctx, &store.DeleteMemoMessage{ + err = ts.DeleteMemo(ctx, &store.DeleteMemo{ ID: memo.ID, }) require.NoError(t, err) diff --git a/test/store/resource_test.go b/test/store/resource_test.go index def65860..d525b9cf 100644 --- a/test/store/resource_test.go +++ b/test/store/resource_test.go @@ -11,7 +11,7 @@ import ( func TestResourceStore(t *testing.T) { ctx := context.Background() ts := NewTestingStore(ctx, t) - _, err := ts.CreateResourceV1(ctx, &store.Resource{ + _, err := ts.CreateResource(ctx, &store.Resource{ CreatorID: 101, Filename: "test.epub", Blob: []byte("test"), @@ -49,11 +49,11 @@ func TestResourceStore(t *testing.T) { require.NoError(t, err) require.Nil(t, notFoundResource) - err = ts.DeleteResourceV1(ctx, &store.DeleteResource{ + err = ts.DeleteResource(ctx, &store.DeleteResource{ ID: 1, }) require.NoError(t, err) - err = ts.DeleteResourceV1(ctx, &store.DeleteResource{ + err = ts.DeleteResource(ctx, &store.DeleteResource{ ID: 2, }) require.NoError(t, err) diff --git a/web/src/components/Settings/MyAccountSection.tsx b/web/src/components/Settings/MyAccountSection.tsx index 7bb13581..d9115440 100644 --- a/web/src/components/Settings/MyAccountSection.tsx +++ b/web/src/components/Settings/MyAccountSection.tsx @@ -11,7 +11,7 @@ const MyAccountSection = () => { const { t } = useTranslation(); const userStore = useUserStore(); const user = userStore.state.user as User; - const openAPIRoute = `${window.location.origin}/api/memo?openId=${user.openId}`; + const openAPIRoute = `${window.location.origin}/api/v1/memo?openId=${user.openId}`; const handleResetOpenIdBtnClick = async () => { showCommonDialog({ diff --git a/web/src/components/ShareMemoDialog.tsx b/web/src/components/ShareMemoDialog.tsx index 0355515c..4924506d 100644 --- a/web/src/components/ShareMemoDialog.tsx +++ b/web/src/components/ShareMemoDialog.tsx @@ -52,7 +52,7 @@ const ShareMemoDialog: React.FC = (props: Props) => { useEffect(() => { getMemoStats(user.id) - .then(({ data: { data } }) => { + .then(({ data }) => { setPartialState({ memoAmount: data.length, }); diff --git a/web/src/components/UsageHeatMap.tsx b/web/src/components/UsageHeatMap.tsx index 7ad9164e..6f7416ae 100644 --- a/web/src/components/UsageHeatMap.tsx +++ b/web/src/components/UsageHeatMap.tsx @@ -57,7 +57,7 @@ const UsageHeatMap = () => { useEffect(() => { getMemoStats(currentUserId) - .then(({ data: { data } }) => { + .then(({ data }) => { setMemoAmount(data.length); const newStat: DailyUsageStat[] = getInitialUsageStat(usedDaysAmount, beginDayTimestamp); for (const record of data) { diff --git a/web/src/components/kit/DatePicker.tsx b/web/src/components/kit/DatePicker.tsx index 1cf7c200..b3192a35 100644 --- a/web/src/components/kit/DatePicker.tsx +++ b/web/src/components/kit/DatePicker.tsx @@ -26,7 +26,7 @@ const DatePicker: React.FC = (props: DatePickerProps) => { }, [datestamp]); useEffect(() => { - getMemoStats(currentUserId).then(({ data: { data } }) => { + getMemoStats(currentUserId).then(({ data }) => { const m = new Map(); for (const record of data) { const date = getDateStampByDate(record * 1000); diff --git a/web/src/helpers/api.ts b/web/src/helpers/api.ts index 939f6f11..09c973b4 100644 --- a/web/src/helpers/api.ts +++ b/web/src/helpers/api.ts @@ -1,11 +1,5 @@ import axios from "axios"; -type ResponseObject = { - data: T; - error?: string; - message?: string; -}; - export function getSystemStatus() { return axios.get("/api/v1/status"); } @@ -85,7 +79,7 @@ export function getAllMemos(memoFind?: MemoFind) { queryList.push(`limit=${memoFind.limit}`); } - return axios.get>(`/api/memo/all?${queryList.join("&")}`); + return axios.get(`/api/v1/memo/all?${queryList.join("&")}`); } export function getMemoList(memoFind?: MemoFind) { @@ -105,39 +99,39 @@ export function getMemoList(memoFind?: MemoFind) { if (memoFind?.limit) { queryList.push(`limit=${memoFind.limit}`); } - return axios.get>(`/api/memo?${queryList.join("&")}`); + return axios.get(`/api/v1/memo?${queryList.join("&")}`); } export function getMemoStats(userId: UserId) { - return axios.get>(`/api/memo/stats?creatorId=${userId}`); + return axios.get(`/api/v1/memo/stats?creatorId=${userId}`); } export function getMemoById(id: MemoId) { - return axios.get>(`/api/memo/${id}`); + return axios.get(`/api/v1/memo/${id}`); } export function createMemo(memoCreate: MemoCreate) { - return axios.post>("/api/memo", memoCreate); + return axios.post("/api/v1/memo", memoCreate); } export function patchMemo(memoPatch: MemoPatch) { - return axios.patch>(`/api/memo/${memoPatch.id}`, memoPatch); + return axios.patch(`/api/v1/memo/${memoPatch.id}`, memoPatch); } export function pinMemo(memoId: MemoId) { - return axios.post(`/api/memo/${memoId}/organizer`, { + return axios.post(`/api/v1/memo/${memoId}/organizer`, { pinned: true, }); } export function unpinMemo(memoId: MemoId) { - return axios.post(`/api/memo/${memoId}/organizer`, { + return axios.post(`/api/v1/memo/${memoId}/organizer`, { pinned: false, }); } export function deleteMemo(memoId: MemoId) { - return axios.delete(`/api/memo/${memoId}`); + return axios.delete(`/api/v1/memo/${memoId}`); } export function getShortcutList(shortcutFind?: ShortcutFind) { @@ -192,17 +186,17 @@ export function deleteResourceById(id: ResourceId) { } export function getMemoResourceList(memoId: MemoId) { - return axios.get>(`/api/memo/${memoId}/resource`); + return axios.get(`/api/v1/memo/${memoId}/resource`); } export function upsertMemoResource(memoId: MemoId, resourceId: ResourceId) { - return axios.post(`/api/memo/${memoId}/resource`, { + return axios.post(`/api/v1/memo/${memoId}/resource`, { resourceId, }); } export function deleteMemoResource(memoId: MemoId, resourceId: ResourceId) { - return axios.delete(`/api/memo/${memoId}/resource/${resourceId}`); + return axios.delete(`/api/v1/memo/${memoId}/resource/${resourceId}`); } export function getTagList(tagFind?: TagFind) { diff --git a/web/src/store/module/memo.ts b/web/src/store/module/memo.ts index c37153d3..a0df71b0 100644 --- a/web/src/store/module/memo.ts +++ b/web/src/store/module/memo.ts @@ -21,7 +21,7 @@ export const useMemoStore = () => { const memoCacheStore = useMemoCacheStore(); const fetchMemoById = async (memoId: MemoId) => { - const { data } = (await api.getMemoById(memoId)).data; + const { data } = await api.getMemoById(memoId); const memo = convertResponseModelMemo(data); return memo; @@ -42,7 +42,7 @@ export const useMemoStore = () => { if (userStore.isVisitorMode()) { memoFind.creatorId = userStore.getUserIdFromPath(); } - const { data } = (await api.getMemoList(memoFind)).data; + const { data } = await api.getMemoList(memoFind); const fetchedMemos = data.map((m) => convertResponseModelMemo(m)); store.dispatch(upsertMemos(fetchedMemos)); store.dispatch(setIsFetching(false)); @@ -60,7 +60,7 @@ export const useMemoStore = () => { offset, }; - const { data } = (await api.getAllMemos(memoFind)).data; + const { data } = await api.getAllMemos(memoFind); const fetchedMemos = data.map((m) => convertResponseModelMemo(m)); for (const m of fetchedMemos) { @@ -76,7 +76,7 @@ export const useMemoStore = () => { if (userStore.isVisitorMode()) { memoFind.creatorId = userStore.getUserIdFromPath(); } - const { data } = (await api.getMemoList(memoFind)).data; + const { data } = await api.getMemoList(memoFind); const archivedMemos = data.map((m) => { return convertResponseModelMemo(m); }); @@ -97,14 +97,14 @@ export const useMemoStore = () => { return state.memos.filter((m) => m.content.match(regex)); }, createMemo: async (memoCreate: MemoCreate) => { - const { data } = (await api.createMemo(memoCreate)).data; + const { data } = await api.createMemo(memoCreate); const memo = convertResponseModelMemo(data); store.dispatch(createMemo(memo)); memoCacheStore.setMemoCache(memo); return memo; }, patchMemo: async (memoPatch: MemoPatch): Promise => { - const { data } = (await api.patchMemo(memoPatch)).data; + const { data } = await api.patchMemo(memoPatch); const memo = convertResponseModelMemo(data); store.dispatch(patchMemo(omit(memo, "pinned"))); memoCacheStore.setMemoCache(memo); diff --git a/web/src/store/zustand/memo.ts b/web/src/store/zustand/memo.ts index 09b15490..669d69ac 100644 --- a/web/src/store/zustand/memo.ts +++ b/web/src/store/zustand/memo.ts @@ -12,7 +12,7 @@ export const useMemoCacheStore = create( return memo; } - const { data } = (await api.getMemoById(memoId)).data; + const { data } = await api.getMemoById(memoId); const formatedMemo = convertResponseModelMemo(data); set((state) => {