diff --git a/api/v1/auth.go b/api/v1/auth.go index 53e5b41e..58e64fcc 100644 --- a/api/v1/auth.go +++ b/api/v1/auth.go @@ -8,7 +8,7 @@ import ( "github.com/labstack/echo/v4" "github.com/pkg/errors" - "github.com/usememos/memos/common" + "github.com/usememos/memos/common/util" "github.com/usememos/memos/plugin/idp" "github.com/usememos/memos/plugin/idp/oauth2" "github.com/usememos/memos/server/auth" @@ -43,7 +43,7 @@ func (s *APIV1Service) registerAuthRoutes(g *echo.Group) { user, err := s.Store.GetUser(ctx, &store.FindUser{ Username: &signin.Username, }) - if err != nil && common.ErrorCode(err) != common.NotFound { + if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, "Incorrect login credentials, please try again") } if user == nil { @@ -114,7 +114,7 @@ func (s *APIV1Service) registerAuthRoutes(g *echo.Group) { user, err := s.Store.GetUser(ctx, &store.FindUser{ Username: &userInfo.Identifier, }) - if err != nil && common.ErrorCode(err) != common.NotFound { + if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, "Incorrect login credentials, please try again") } if user == nil { @@ -124,9 +124,9 @@ func (s *APIV1Service) registerAuthRoutes(g *echo.Group) { Role: store.RoleUser, Nickname: userInfo.DisplayName, Email: userInfo.Email, - OpenID: common.GenUUID(), + OpenID: util.GenUUID(), } - password, err := common.RandomString(20) + password, err := util.RandomString(20) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, "Failed to generate random password").SetInternal(err) } @@ -173,7 +173,7 @@ func (s *APIV1Service) registerAuthRoutes(g *echo.Group) { // The new signup user should be normal user by default. Role: store.RoleUser, Nickname: signup.Username, - OpenID: common.GenUUID(), + OpenID: util.GenUUID(), } if len(existedHostUsers) == 0 { // Change the default role to host if there is no host user. @@ -182,7 +182,7 @@ func (s *APIV1Service) registerAuthRoutes(g *echo.Group) { allowSignUpSetting, err := s.Store.GetSystemSetting(ctx, &store.FindSystemSetting{ Name: SystemSettingAllowSignUpName.String(), }) - if err != nil && common.ErrorCode(err) != common.NotFound { + if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find system setting").SetInternal(err) } diff --git a/api/v1/idp.go b/api/v1/idp.go index 941f036a..4ceb9678 100644 --- a/api/v1/idp.go +++ b/api/v1/idp.go @@ -7,7 +7,6 @@ import ( "strconv" "github.com/labstack/echo/v4" - "github.com/usememos/memos/common" "github.com/usememos/memos/store" ) @@ -231,9 +230,6 @@ func (s *APIV1Service) registerIdentityProviderRoutes(g *echo.Group) { } if err = s.Store.DeleteIdentityProvider(ctx, &store.DeleteIdentityProvider{ID: identityProviderID}); err != nil { - if common.ErrorCode(err) == common.NotFound { - return echo.NewHTTPError(http.StatusNotFound, fmt.Sprintf("Identity provider ID not found: %d", identityProviderID)) - } return echo.NewHTTPError(http.StatusInternalServerError, "Failed to delete identity provider").SetInternal(err) } return c.JSON(http.StatusOK, true) diff --git a/api/v1/jwt.go b/api/v1/jwt.go index 3c0b8742..f8dd3879 100644 --- a/api/v1/jwt.go +++ b/api/v1/jwt.go @@ -10,7 +10,7 @@ import ( "github.com/golang-jwt/jwt/v4" "github.com/labstack/echo/v4" "github.com/pkg/errors" - "github.com/usememos/memos/common" + "github.com/usememos/memos/common/util" "github.com/usememos/memos/server/auth" "github.com/usememos/memos/store" ) @@ -82,18 +82,18 @@ func JWTMiddleware(server *APIV1Service, next echo.HandlerFunc, secret string) e } // Skip validation for server status endpoints. - if common.HasPrefixes(path, "/api/v1/ping", "/api/v1/idp", "/api/v1/status", "/api/v1/user/:id") && method == http.MethodGet { + if util.HasPrefixes(path, "/api/v1/ping", "/api/v1/idp", "/api/v1/status", "/api/v1/user/:id") && method == http.MethodGet { return next(c) } token := findAccessToken(c) if token == "" { // Allow the user to access the public endpoints. - if common.HasPrefixes(path, "/o") { + if util.HasPrefixes(path, "/o") { return next(c) } // When the request is not authenticated, we allow the user to access the memo endpoints for those public memos. - if common.HasPrefixes(path, "/api/v1/memo") && method == http.MethodGet { + if util.HasPrefixes(path, "/api/v1/memo") && method == http.MethodGet { return next(c) } return echo.NewHTTPError(http.StatusUnauthorized, "Missing access token") @@ -215,7 +215,7 @@ func (s *APIV1Service) defaultAuthSkipper(c echo.Context) bool { path := c.Path() // Skip auth. - if common.HasPrefixes(path, "/api/v1/auth") { + if util.HasPrefixes(path, "/api/v1/auth") { return true } @@ -225,7 +225,7 @@ func (s *APIV1Service) defaultAuthSkipper(c echo.Context) bool { user, err := s.Store.GetUser(ctx, &store.FindUser{ OpenID: &openID, }) - if err != nil && common.ErrorCode(err) != common.NotFound { + if err != nil { return false } if user != nil { diff --git a/api/v1/memo.go b/api/v1/memo.go index 03dc7943..045bf396 100644 --- a/api/v1/memo.go +++ b/api/v1/memo.go @@ -11,7 +11,6 @@ import ( "github.com/labstack/echo/v4" "github.com/pkg/errors" - "github.com/usememos/memos/common" "github.com/usememos/memos/store" ) @@ -135,7 +134,6 @@ func (s *APIV1Service) registerMemoRoutes(g *echo.Group) { if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user setting").SetInternal(err) } - if userMemoVisibilitySetting != nil { memoVisibility := Private err := json.Unmarshal([]byte(userMemoVisibilitySetting.Value), &memoVisibility) @@ -169,6 +167,9 @@ func (s *APIV1Service) registerMemoRoutes(g *echo.Group) { if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err) } + if user == nil { + return echo.NewHTTPError(http.StatusNotFound, "User not found") + } // Enforce normal user to create private memo if public memos are disabled. if user.Role == store.RoleUser { createMemoRequest.Visibility = Private @@ -210,6 +211,10 @@ func (s *APIV1Service) registerMemoRoutes(g *echo.Group) { if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, "Failed to compose memo").SetInternal(err) } + if memo == nil { + return echo.NewHTTPError(http.StatusNotFound, fmt.Sprintf("Memo not found: %d", memo.ID)) + } + memoResponse, err := s.convertMemoFromStore(ctx, memo) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, "Failed to compose memo response").SetInternal(err) @@ -235,6 +240,9 @@ func (s *APIV1Service) registerMemoRoutes(g *echo.Group) { if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find memo").SetInternal(err) } + if memo == nil { + return echo.NewHTTPError(http.StatusNotFound, fmt.Sprintf("Memo not found: %d", memoID)) + } if memo.CreatorID != userID { return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized") } @@ -275,6 +283,9 @@ func (s *APIV1Service) registerMemoRoutes(g *echo.Group) { if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find memo").SetInternal(err) } + if memo == nil { + return echo.NewHTTPError(http.StatusNotFound, fmt.Sprintf("Memo not found: %d", memoID)) + } if patchMemoRequest.ResourceIDList != nil { addedResourceIDList, removedResourceIDList := getIDListDiff(memo.ResourceIDList, patchMemoRequest.ResourceIDList) @@ -326,6 +337,10 @@ func (s *APIV1Service) registerMemoRoutes(g *echo.Group) { if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find memo").SetInternal(err) } + if memo == nil { + return echo.NewHTTPError(http.StatusNotFound, fmt.Sprintf("Memo not found: %d", memoID)) + } + memoResponse, err := s.convertMemoFromStore(ctx, memo) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, "Failed to compose memo response").SetInternal(err) @@ -424,11 +439,11 @@ func (s *APIV1Service) registerMemoRoutes(g *echo.Group) { ID: &memoID, }) if err != nil { - if common.ErrorCode(err) == common.NotFound { - return echo.NewHTTPError(http.StatusNotFound, fmt.Sprintf("Memo ID not found: %d", memoID)).SetInternal(err) - } return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Failed to find memo by ID: %v", memoID)).SetInternal(err) } + if memo == nil { + return echo.NewHTTPError(http.StatusNotFound, fmt.Sprintf("Memo not found: %d", memoID)) + } userID, ok := c.Get(getUserIDContextKey()).(int) if memo.Visibility == store.Private { @@ -585,6 +600,9 @@ func (s *APIV1Service) registerMemoRoutes(g *echo.Group) { if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find memo").SetInternal(err) } + if memo == nil { + return echo.NewHTTPError(http.StatusNotFound, fmt.Sprintf("Memo not found: %d", memoID)) + } if memo.CreatorID != userID { return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized") } @@ -592,9 +610,6 @@ func (s *APIV1Service) registerMemoRoutes(g *echo.Group) { if err := s.Store.DeleteMemo(ctx, &store.DeleteMemo{ ID: memoID, }); err != nil { - if common.ErrorCode(err) == common.NotFound { - return echo.NewHTTPError(http.StatusNotFound, fmt.Sprintf("Memo ID not found: %d", memoID)) - } return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Failed to delete memo ID: %v", memoID)).SetInternal(err) } return c.JSON(http.StatusOK, true) diff --git a/api/v1/memo_organizer.go b/api/v1/memo_organizer.go index 91bf1609..9adc9f66 100644 --- a/api/v1/memo_organizer.go +++ b/api/v1/memo_organizer.go @@ -39,6 +39,9 @@ func (s *APIV1Service) registerMemoOrganizerRoutes(g *echo.Group) { if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find memo").SetInternal(err) } + if memo == nil { + return echo.NewHTTPError(http.StatusNotFound, fmt.Sprintf("Memo not found: %v", memoID)) + } if memo.CreatorID != userID { return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized") } @@ -53,7 +56,7 @@ func (s *APIV1Service) registerMemoOrganizerRoutes(g *echo.Group) { UserID: userID, Pinned: request.Pinned, } - _, err = s.Store.UpsertMemoOrganizerV1(ctx, upsert) + _, err = s.Store.UpsertMemoOrganizer(ctx, upsert) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, "Failed to upsert memo organizer").SetInternal(err) } @@ -64,6 +67,9 @@ func (s *APIV1Service) registerMemoOrganizerRoutes(g *echo.Group) { if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Failed to find memo by ID: %v", memoID)).SetInternal(err) } + if memo == nil { + return echo.NewHTTPError(http.StatusNotFound, fmt.Sprintf("Memo not found: %v", memoID)) + } memoResponse, err := s.convertMemoFromStore(ctx, memo) if err != nil { diff --git a/api/v1/memo_resource.go b/api/v1/memo_resource.go index fded1087..eaba21ff 100644 --- a/api/v1/memo_resource.go +++ b/api/v1/memo_resource.go @@ -116,6 +116,9 @@ func (s *APIV1Service) registerMemoResourceRoutes(g *echo.Group) { if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find memo").SetInternal(err) } + if memo == nil { + return echo.NewHTTPError(http.StatusBadRequest, "Memo not found") + } if memo.CreatorID != userID { return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized") } diff --git a/api/v1/resource.go b/api/v1/resource.go index 046f1ef0..05853d19 100644 --- a/api/v1/resource.go +++ b/api/v1/resource.go @@ -21,8 +21,8 @@ import ( "github.com/disintegration/imaging" "github.com/labstack/echo/v4" "github.com/pkg/errors" - "github.com/usememos/memos/common" "github.com/usememos/memos/common/log" + "github.com/usememos/memos/common/util" "github.com/usememos/memos/plugin/storage/s3" "github.com/usememos/memos/store" "go.uber.org/zap" @@ -101,7 +101,7 @@ func (s *APIV1Service) registerResourceRoutes(g *echo.Group) { Filename: request.Filename, ExternalLink: request.ExternalLink, Type: request.Type, - PublicID: common.GenUUID(), + PublicID: util.GenUUID(), } if request.ExternalLink != "" { // Only allow those external links scheme with http/https @@ -208,7 +208,7 @@ func (s *APIV1Service) registerResourceRoutes(g *echo.Group) { } } - publicID := common.GenUUID() + publicID := util.GenUUID() if storageServiceID == DatabaseStorage { fileBytes, err := io.ReadAll(sourceFile) if err != nil { @@ -226,7 +226,7 @@ func (s *APIV1Service) registerResourceRoutes(g *echo.Group) { // as it handles the os-specific path separator automatically. // path.Join() always uses '/' as path separator. systemSettingLocalStoragePath, err := s.Store.GetSystemSetting(ctx, &store.FindSystemSetting{Name: SystemSettingLocalStoragePathName.String()}) - if err != nil && common.ErrorCode(err) != common.NotFound { + if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find local storage path setting").SetInternal(err) } localStoragePath := "assets/{publicid}" @@ -268,6 +268,9 @@ func (s *APIV1Service) registerResourceRoutes(g *echo.Group) { if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find storage").SetInternal(err) } + if storage == nil { + return echo.NewHTTPError(http.StatusNotFound, fmt.Sprintf("Storage %d not found", storageServiceID)) + } storageMessage, err := ConvertStorageFromStore(storage) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, "Failed to convert storage").SetInternal(err) @@ -366,6 +369,9 @@ func (s *APIV1Service) registerResourceRoutes(g *echo.Group) { if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find resource").SetInternal(err) } + if resource == nil { + return echo.NewHTTPError(http.StatusNotFound, fmt.Sprintf("Resource not found: %d", resourceID)) + } if resource.CreatorID != userID { return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized") } @@ -384,7 +390,7 @@ func (s *APIV1Service) registerResourceRoutes(g *echo.Group) { update.Filename = request.Filename } if request.ResetPublicID != nil && *request.ResetPublicID { - publicID := common.GenUUID() + publicID := util.GenUUID() update.PublicID = &publicID } @@ -415,7 +421,7 @@ func (s *APIV1Service) registerResourceRoutes(g *echo.Group) { return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find resource").SetInternal(err) } if resource == nil { - return echo.NewHTTPError(http.StatusNotFound, "Resource not found") + return echo.NewHTTPError(http.StatusNotFound, fmt.Sprintf("Resource not found: %d", resourceID)) } if resource.InternalPath != "" { @@ -465,6 +471,9 @@ func (s *APIV1Service) registerResourcePublicRoutes(g *echo.Group) { if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Failed to find resource by ID: %v", resourceID)).SetInternal(err) } + if resource == nil { + return echo.NewHTTPError(http.StatusNotFound, fmt.Sprintf("Resource not found: %d", resourceID)) + } // Private resource require logined user is the creator if resourceVisibility == store.Private && (!ok || userID != resource.CreatorID) { @@ -485,7 +494,7 @@ func (s *APIV1Service) registerResourcePublicRoutes(g *echo.Group) { } } - if c.QueryParam("thumbnail") == "1" && common.HasPrefixes(resource.Type, "image/png", "image/jpeg") { + if c.QueryParam("thumbnail") == "1" && util.HasPrefixes(resource.Type, "image/png", "image/jpeg") { ext := filepath.Ext(resource.Filename) thumbnailPath := path.Join(s.Profile.Data, thumbnailImagePath, fmt.Sprintf("%d-%s%s", resource.ID, resource.PublicID, ext)) thumbnailBlob, err := getOrGenerateThumbnailImage(blob, thumbnailPath) diff --git a/server/rss.go b/api/v1/rss.go similarity index 87% rename from server/rss.go rename to api/v1/rss.go index 81a47c7d..2f9b4fa7 100644 --- a/server/rss.go +++ b/api/v1/rss.go @@ -1,9 +1,10 @@ -package server +package v1 import ( "bytes" "context" "encoding/json" + "fmt" "net/http" "strconv" "strings" @@ -11,13 +12,15 @@ import ( "github.com/gorilla/feeds" "github.com/labstack/echo/v4" - apiv1 "github.com/usememos/memos/api/v1" - "github.com/usememos/memos/common" + "github.com/usememos/memos/common/util" "github.com/usememos/memos/store" "github.com/yuin/goldmark" ) -func (s *Server) registerRSSRoutes(g *echo.Group) { +const maxRSSItemCount = 100 +const maxRSSItemTitleLength = 100 + +func (s *APIV1Service) registerRSSRoutes(g *echo.Group) { g.GET("/explore/rss.xml", func(c echo.Context) error { ctx := c.Request().Context() systemCustomizedProfile, err := s.getSystemCustomizedProfile(ctx) @@ -77,10 +80,7 @@ func (s *Server) registerRSSRoutes(g *echo.Group) { }) } -const MaxRSSItemCount = 100 -const MaxRSSItemTitleLength = 100 - -func (s *Server) generateRSSFromMemoList(ctx context.Context, memoList []*store.Memo, baseURL string, profile *apiv1.CustomizedProfile) (string, error) { +func (s *APIV1Service) generateRSSFromMemoList(ctx context.Context, memoList []*store.Memo, baseURL string, profile *CustomizedProfile) (string, error) { feed := &feeds.Feed{ Title: profile.Name, Link: &feeds.Link{Href: baseURL}, @@ -88,7 +88,7 @@ func (s *Server) generateRSSFromMemoList(ctx context.Context, memoList []*store. Created: time.Now(), } - var itemCountLimit = common.Min(len(memoList), MaxRSSItemCount) + var itemCountLimit = util.Min(len(memoList), maxRSSItemCount) feed.Items = make([]*feeds.Item, itemCountLimit) for i := 0; i < itemCountLimit; i++ { memo := memoList[i] @@ -107,6 +107,9 @@ func (s *Server) generateRSSFromMemoList(ctx context.Context, memoList []*store. if err != nil { return "", err } + if resource == nil { + return "", fmt.Errorf("Resource not found: %d", resourceID) + } enclosure := feeds.Enclosure{} if resource.ExternalLink != "" { enclosure.Url = resource.ExternalLink @@ -126,14 +129,14 @@ func (s *Server) generateRSSFromMemoList(ctx context.Context, memoList []*store. return rss, nil } -func (s *Server) getSystemCustomizedProfile(ctx context.Context) (*apiv1.CustomizedProfile, error) { +func (s *APIV1Service) getSystemCustomizedProfile(ctx context.Context) (*CustomizedProfile, error) { systemSetting, err := s.Store.GetSystemSetting(ctx, &store.FindSystemSetting{ - Name: apiv1.SystemSettingCustomizedProfileName.String(), + Name: SystemSettingCustomizedProfileName.String(), }) if err != nil { return nil, err } - customizedProfile := &apiv1.CustomizedProfile{ + customizedProfile := &CustomizedProfile{ Name: "memos", LogoURL: "", Description: "", @@ -155,7 +158,7 @@ func getRSSItemTitle(content string) string { title = strings.Split(content, "\n")[0][2:] } else { title = strings.Split(content, "\n")[0] - var titleLengthLimit = common.Min(len(title), MaxRSSItemTitleLength) + var titleLengthLimit = util.Min(len(title), maxRSSItemTitleLength) if titleLengthLimit < len(title) { title = title[:titleLengthLimit] + "..." } diff --git a/api/v1/shortcut.go b/api/v1/shortcut.go index 435680b2..8e99b690 100644 --- a/api/v1/shortcut.go +++ b/api/v1/shortcut.go @@ -99,6 +99,9 @@ func (s *APIV1Service) registerShortcutRoutes(g *echo.Group) { if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find shortcut").SetInternal(err) } + if shortcut == nil { + return echo.NewHTTPError(http.StatusNotFound, fmt.Sprintf("Shortcut not found: %d", shortcutID)) + } if shortcut.CreatorID != userID { return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized") } @@ -165,7 +168,7 @@ func (s *APIV1Service) registerShortcutRoutes(g *echo.Group) { return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Failed to fetch shortcut by ID %d", shortcutID)).SetInternal(err) } if shortcut == nil { - return echo.NewHTTPError(http.StatusNotFound, fmt.Sprintf("Shortcut by ID %d not found", shortcutID)) + return echo.NewHTTPError(http.StatusNotFound, fmt.Sprintf("Shortcut not found: %d", shortcutID)) } return c.JSON(http.StatusOK, convertShortcutFromStore(shortcut)) }) @@ -187,6 +190,9 @@ func (s *APIV1Service) registerShortcutRoutes(g *echo.Group) { if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find shortcut").SetInternal(err) } + if shortcut == nil { + return echo.NewHTTPError(http.StatusNotFound, fmt.Sprintf("Shortcut not found: %d", shortcutID)) + } if shortcut.CreatorID != userID { return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized") } diff --git a/api/v1/user.go b/api/v1/user.go index 67daa950..623be1c5 100644 --- a/api/v1/user.go +++ b/api/v1/user.go @@ -9,7 +9,7 @@ import ( "github.com/labstack/echo/v4" "github.com/pkg/errors" - "github.com/usememos/memos/common" + "github.com/usememos/memos/common/util" "github.com/usememos/memos/store" "golang.org/x/crypto/bcrypt" ) @@ -77,7 +77,7 @@ func (create CreateUserRequest) Validate() error { if len(create.Email) > 256 { return fmt.Errorf("email is too long, maximum length is 256") } - if !common.ValidateEmail(create.Email) { + if !util.ValidateEmail(create.Email) { return fmt.Errorf("invalid email format") } } @@ -120,7 +120,7 @@ func (update UpdateUserRequest) Validate() error { if len(*update.Email) > 256 { return fmt.Errorf("email is too long, maximum length is 256") } - if !common.ValidateEmail(*update.Email) { + if !util.ValidateEmail(*update.Email) { return fmt.Errorf("invalid email format") } } @@ -141,6 +141,9 @@ func (s *APIV1Service) registerUserRoutes(g *echo.Group) { if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user by id").SetInternal(err) } + if currentUser == nil { + return echo.NewHTTPError(http.StatusUnauthorized, "Missing auth session") + } if currentUser.Role != store.RoleHost { return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized to create user") } @@ -168,7 +171,7 @@ func (s *APIV1Service) registerUserRoutes(g *echo.Group) { Email: userCreate.Email, Nickname: userCreate.Nickname, PasswordHash: string(passwordHash), - OpenID: common.GenUUID(), + OpenID: util.GenUUID(), }) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create user").SetInternal(err) @@ -211,6 +214,9 @@ func (s *APIV1Service) registerUserRoutes(g *echo.Group) { if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err) } + if user == nil { + return echo.NewHTTPError(http.StatusUnauthorized, "Missing auth session") + } list, err := s.Store.ListUserSettings(ctx, &store.FindUserSetting{ UserID: &userID, @@ -306,7 +312,7 @@ func (s *APIV1Service) registerUserRoutes(g *echo.Group) { userUpdate.PasswordHash = &passwordHashStr } if request.ResetOpenID != nil && *request.ResetOpenID { - openID := common.GenUUID() + openID := util.GenUUID() userUpdate.OpenID = &openID } if request.AvatarURL != nil { diff --git a/api/v1/v1.go b/api/v1/v1.go index 074a7ab5..fcb605cc 100644 --- a/api/v1/v1.go +++ b/api/v1/v1.go @@ -21,6 +21,10 @@ func NewAPIV1Service(secret string, profile *profile.Profile, store *store.Store } func (s *APIV1Service) Register(rootGroup *echo.Group) { + // Register RSS routes. + s.registerRSSRoutes(rootGroup) + + // Register API v1 routes. apiV1Group := rootGroup.Group("/api/v1") apiV1Group.Use(func(next echo.HandlerFunc) echo.HandlerFunc { return JWTMiddleware(s, next, s.Secret) @@ -40,6 +44,7 @@ func (s *APIV1Service) Register(rootGroup *echo.Group) { s.registerMemoResourceRoutes(apiV1Group) s.registerMemoRelationRoutes(apiV1Group) + // Register public routes. publicGroup := rootGroup.Group("/o") publicGroup.Use(func(next echo.HandlerFunc) echo.HandlerFunc { return JWTMiddleware(s, next, s.Secret) diff --git a/assets/tech-stack.png b/assets/tech-stack.png deleted file mode 100644 index 9144b026..00000000 Binary files a/assets/tech-stack.png and /dev/null differ diff --git a/common/error.go b/common/error.go deleted file mode 100644 index 3545a9bb..00000000 --- a/common/error.go +++ /dev/null @@ -1,72 +0,0 @@ -package common - -import ( - "errors" -) - -// Code is the error code. -type Code int - -// Application error codes. -const ( - // 0 ~ 99 general error. - Ok Code = 0 - Internal Code = 1 - NotAuthorized Code = 2 - Invalid Code = 3 - NotFound Code = 4 - Conflict Code = 5 - NotImplemented Code = 6 -) - -// Error represents an application-specific error. Application errors can be -// unwrapped by the caller to extract out the code & message. -// -// Any non-application error (such as a disk error) should be reported as an -// Internal error and the human user should only see "Internal error" as the -// message. These low-level internal error details should only be logged and -// reported to the operator of the application (not the end user). -type Error struct { - // Machine-readable error code. - Code Code - - // Embedded error. - Err error -} - -// Error implements the error interface. Not used by the application otherwise. -func (e *Error) Error() string { - return e.Err.Error() -} - -// ErrorCode unwraps an application error and returns its code. -// Non-application errors always return EINTERNAL. -func ErrorCode(err error) Code { - var e *Error - if err == nil { - return Ok - } else if errors.As(err, &e) { - return e.Code - } - return Internal -} - -// ErrorMessage unwraps an application error and returns its message. -// Non-application errors always return "Internal error". -func ErrorMessage(err error) string { - var e *Error - if err == nil { - return "" - } else if errors.As(err, &e) { - return e.Err.Error() - } - return "Internal error." -} - -// Errorf is a helper function to return an Error with a given code and error. -func Errorf(code Code, err error) *Error { - return &Error{ - Code: code, - Err: err, - } -} diff --git a/common/util.go b/common/util/util.go similarity index 98% rename from common/util.go rename to common/util/util.go index 571dbf44..4ae025d6 100644 --- a/common/util.go +++ b/common/util/util.go @@ -1,4 +1,4 @@ -package common +package util import ( "crypto/rand" diff --git a/common/util_test.go b/common/util/util_test.go similarity index 96% rename from common/util_test.go rename to common/util/util_test.go index 9cd7af82..7a80416b 100644 --- a/common/util_test.go +++ b/common/util/util_test.go @@ -1,4 +1,4 @@ -package common +package util import ( "testing" diff --git a/docs/development.md b/docs/development.md index 97db8ed9..32d8e073 100644 --- a/docs/development.md +++ b/docs/development.md @@ -6,10 +6,6 @@ Memos is built with a curated tech stack. It is optimized for developer experien 2. It requires zero config. 3. 1 command to start backend and 1 command to start frontend, both with live reload support. -## Tech Stack - -![tech-stack](https://raw.githubusercontent.com/usememos/memos/main/assets/tech-stack.png) - ## Prerequisites - [Go](https://golang.org/doc/install) diff --git a/server/server.go b/server/server.go index ba543610..de51bc43 100644 --- a/server/server.go +++ b/server/server.go @@ -10,7 +10,7 @@ import ( "github.com/google/uuid" "github.com/pkg/errors" apiv1 "github.com/usememos/memos/api/v1" - "github.com/usememos/memos/common" + "github.com/usememos/memos/common/util" "github.com/usememos/memos/plugin/telegram" "github.com/usememos/memos/server/profile" "github.com/usememos/memos/store" @@ -86,8 +86,6 @@ func NewServer(ctx context.Context, profile *profile.Profile, store *store.Store s.Secret = secret rootGroup := e.Group("") - s.registerRSSRoutes(rootGroup) - apiV1Service := apiv1.NewAPIV1Service(s.Secret, profile, store) apiV1Service.Register(rootGroup) @@ -129,7 +127,7 @@ func (s *Server) getSystemServerID(ctx context.Context) (string, error) { serverIDSetting, err := s.Store.GetSystemSetting(ctx, &store.FindSystemSetting{ Name: apiv1.SystemSettingServerIDName.String(), }) - if err != nil && common.ErrorCode(err) != common.NotFound { + if err != nil { return "", err } if serverIDSetting == nil || serverIDSetting.Value == "" { @@ -148,7 +146,7 @@ func (s *Server) getSystemSecretSessionName(ctx context.Context) (string, error) secretSessionNameValue, err := s.Store.GetSystemSetting(ctx, &store.FindSystemSetting{ Name: apiv1.SystemSettingSecretSessionName.String(), }) - if err != nil && common.ErrorCode(err) != common.NotFound { + if err != nil { return "", err } if secretSessionNameValue == nil || secretSessionNameValue.Value == "" { @@ -190,5 +188,5 @@ func defaultGetRequestSkipper(c echo.Context) bool { func defaultAPIRequestSkipper(c echo.Context) bool { path := c.Path() - return common.HasPrefixes(path, "/api", "/api/v1") + return util.HasPrefixes(path, "/api", "/api/v1") } diff --git a/server/telegram.go b/server/telegram.go index 1e072e46..a93ad32b 100644 --- a/server/telegram.go +++ b/server/telegram.go @@ -9,7 +9,7 @@ import ( "github.com/pkg/errors" apiv1 "github.com/usememos/memos/api/v1" - "github.com/usememos/memos/common" + "github.com/usememos/memos/common/util" "github.com/usememos/memos/plugin/telegram" "github.com/usememos/memos/store" ) @@ -94,7 +94,7 @@ func (t *telegramHandler) MessageHandle(ctx context.Context, bot *telegram.Bot, Type: mime, Size: int64(len(blob)), Blob: blob, - PublicID: common.GenUUID(), + PublicID: util.GenUUID(), }) if err != nil { _, err := bot.EditMessage(ctx, message.Chat.ID, reply.MessageID, fmt.Sprintf("failed to CreateResource: %s", err), nil) diff --git a/setup/setup.go b/setup/setup.go index 1eefb25b..0ed27e79 100644 --- a/setup/setup.go +++ b/setup/setup.go @@ -7,7 +7,7 @@ import ( "golang.org/x/crypto/bcrypt" - "github.com/usememos/memos/common" + "github.com/usememos/memos/common/util" "github.com/usememos/memos/store" ) @@ -51,7 +51,7 @@ func (s setupService) createUser(ctx context.Context, hostUsername, hostPassword // The new signup user should be normal user by default. Role: store.RoleHost, Nickname: hostUsername, - OpenID: common.GenUUID(), + OpenID: util.GenUUID(), } if len(userCreate.Username) < 3 { @@ -73,7 +73,7 @@ func (s setupService) createUser(ctx context.Context, hostUsername, hostPassword if len(userCreate.Email) > 256 { return fmt.Errorf("email is too long, maximum length is 256") } - if !common.ValidateEmail(userCreate.Email) { + if !util.ValidateEmail(userCreate.Email) { return fmt.Errorf("invalid email format") } } diff --git a/store/activity.go b/store/activity.go index c59fc4db..1b7a12a5 100644 --- a/store/activity.go +++ b/store/activity.go @@ -17,7 +17,6 @@ type Activity struct { Payload string } -// CreateActivity creates an instance of Activity. func (s *Store) CreateActivity(ctx context.Context, create *Activity) (*Activity, error) { tx, err := s.db.BeginTx(ctx, nil) if err != nil { diff --git a/store/error.go b/store/error.go deleted file mode 100644 index 2bbef297..00000000 --- a/store/error.go +++ /dev/null @@ -1,19 +0,0 @@ -package store - -import ( - "database/sql" - "errors" -) - -func FormatError(err error) error { - if err == nil { - return nil - } - - switch err { - case sql.ErrNoRows: - return errors.New("data not found") - default: - return err - } -} diff --git a/store/idp.go b/store/idp.go index 954f0294..4f22a871 100644 --- a/store/idp.go +++ b/store/idp.go @@ -256,7 +256,7 @@ func (s *Store) DeleteIdentityProvider(ctx context.Context, delete *DeleteIdenti } func listIdentityProviders(ctx context.Context, tx *sql.Tx, find *FindIdentityProvider) ([]*IdentityProvider, error) { - where, args := []string{"TRUE"}, []any{} + where, args := []string{"1 = 1"}, []any{} if v := find.ID; v != nil { where, args = append(where, fmt.Sprintf("id = $%d", len(args)+1)), append(args, *v) } diff --git a/store/memo_organizer.go b/store/memo_organizer.go index 2e10a5da..1047f30c 100644 --- a/store/memo_organizer.go +++ b/store/memo_organizer.go @@ -23,7 +23,7 @@ type DeleteMemoOrganizer struct { UserID *int } -func (s *Store) UpsertMemoOrganizerV1(ctx context.Context, upsert *MemoOrganizer) (*MemoOrganizer, error) { +func (s *Store) UpsertMemoOrganizer(ctx context.Context, upsert *MemoOrganizer) (*MemoOrganizer, error) { tx, err := s.db.BeginTx(ctx, nil) if err != nil { return nil, err @@ -53,7 +53,7 @@ func (s *Store) UpsertMemoOrganizerV1(ctx context.Context, upsert *MemoOrganizer return memoOrganizer, nil } -func (s *Store) GetMemoOrganizerV1(ctx context.Context, find *FindMemoOrganizer) (*MemoOrganizer, error) { +func (s *Store) GetMemoOrganizer(ctx context.Context, find *FindMemoOrganizer) (*MemoOrganizer, error) { tx, err := s.db.BeginTx(ctx, nil) if err != nil { return nil, err @@ -69,6 +69,7 @@ func (s *Store) GetMemoOrganizerV1(ctx context.Context, find *FindMemoOrganizer) where = append(where, "user_id = ?") args = append(args, find.UserID) } + query := fmt.Sprintf(` SELECT memo_id, @@ -78,6 +79,12 @@ func (s *Store) GetMemoOrganizerV1(ctx context.Context, find *FindMemoOrganizer) WHERE %s `, strings.Join(where, " AND ")) row := tx.QueryRowContext(ctx, query, args...) + if err := row.Err(); err != nil { + return nil, err + } + if row == nil { + return nil, nil + } memoOrganizer := &MemoOrganizer{} if err := row.Scan( @@ -88,13 +95,17 @@ func (s *Store) GetMemoOrganizerV1(ctx context.Context, find *FindMemoOrganizer) return nil, err } + if err := tx.Commit(); err != nil { + return nil, err + } + return memoOrganizer, nil } -func (s *Store) DeleteMemoOrganizerV1(ctx context.Context, delete *DeleteMemoOrganizer) error { +func (s *Store) DeleteMemoOrganizer(ctx context.Context, delete *DeleteMemoOrganizer) error { tx, err := s.db.BeginTx(ctx, nil) if err != nil { - return FormatError(err) + return err } defer tx.Rollback() @@ -110,11 +121,12 @@ func (s *Store) DeleteMemoOrganizerV1(ctx context.Context, delete *DeleteMemoOrg stmt := `DELETE FROM memo_organizer WHERE ` + strings.Join(where, " AND ") _, err = tx.ExecContext(ctx, stmt, args...) if err != nil { - return FormatError(err) + return err } if err := tx.Commit(); err != nil { - return FormatError(err) + // Prevent linter warning. + return err } return nil @@ -139,7 +151,7 @@ func vacuumMemoOrganizer(ctx context.Context, tx *sql.Tx) error { )` _, err := tx.ExecContext(ctx, stmt) if err != nil { - return FormatError(err) + return err } return nil diff --git a/store/memo_relation.go b/store/memo_relation.go index d0d1c8c7..3230b54a 100644 --- a/store/memo_relation.go +++ b/store/memo_relation.go @@ -49,7 +49,7 @@ func (s *Store) UpsertMemoRelation(ctx context.Context, create *MemoRelation) (* type = EXCLUDED.type RETURNING memo_id, related_memo_id, type ` - memoRelationMessage := &MemoRelation{} + memoRelation := &MemoRelation{} if err := tx.QueryRowContext( ctx, query, @@ -57,16 +57,18 @@ func (s *Store) UpsertMemoRelation(ctx context.Context, create *MemoRelation) (* create.RelatedMemoID, create.Type, ).Scan( - &memoRelationMessage.MemoID, - &memoRelationMessage.RelatedMemoID, - &memoRelationMessage.Type, + &memoRelation.MemoID, + &memoRelation.RelatedMemoID, + &memoRelation.Type, ); err != nil { return nil, err } + if err := tx.Commit(); err != nil { return nil, err } - return memoRelationMessage, nil + + return memoRelation, nil } func (s *Store) ListMemoRelations(ctx context.Context, find *FindMemoRelation) ([]*MemoRelation, error) { diff --git a/store/memo_resource.go b/store/memo_resource.go index 7796d088..41024b37 100644 --- a/store/memo_resource.go +++ b/store/memo_resource.go @@ -33,7 +33,7 @@ type DeleteMemoResource struct { func (s *Store) UpsertMemoResource(ctx context.Context, upsert *UpsertMemoResource) (*MemoResource, error) { tx, err := s.db.BeginTx(ctx, nil) if err != nil { - return nil, FormatError(err) + return nil, err } defer tx.Rollback() @@ -62,11 +62,11 @@ func (s *Store) UpsertMemoResource(ctx context.Context, upsert *UpsertMemoResour &memoResource.CreatedTs, &memoResource.UpdatedTs, ); err != nil { - return nil, FormatError(err) + return nil, err } if err := tx.Commit(); err != nil { - return nil, FormatError(err) + return nil, err } return memoResource, nil @@ -117,7 +117,7 @@ func (s *Store) GetMemoResource(ctx context.Context, find *FindMemoResource) (*M func (s *Store) DeleteMemoResource(ctx context.Context, delete *DeleteMemoResource) error { tx, err := s.db.BeginTx(ctx, nil) if err != nil { - return FormatError(err) + return err } defer tx.Rollback() @@ -133,11 +133,12 @@ func (s *Store) DeleteMemoResource(ctx context.Context, delete *DeleteMemoResour stmt := `DELETE FROM memo_resource WHERE ` + strings.Join(where, " AND ") _, err = tx.ExecContext(ctx, stmt, args...) if err != nil { - return FormatError(err) + return err } if err := tx.Commit(); err != nil { - return FormatError(err) + // Prevent linter warning. + return err } return nil @@ -165,7 +166,7 @@ func listMemoResources(ctx context.Context, tx *sql.Tx, find *FindMemoResource) ` rows, err := tx.QueryContext(ctx, query, args...) if err != nil { - return nil, FormatError(err) + return nil, err } defer rows.Close() @@ -178,7 +179,7 @@ func listMemoResources(ctx context.Context, tx *sql.Tx, find *FindMemoResource) &memoResource.CreatedTs, &memoResource.UpdatedTs, ); err != nil { - return nil, FormatError(err) + return nil, err } list = append(list, &memoResource) @@ -210,7 +211,7 @@ func vacuumMemoResource(ctx context.Context, tx *sql.Tx) error { )` _, err := tx.ExecContext(ctx, stmt) if err != nil { - return FormatError(err) + return err } return nil diff --git a/store/resource.go b/store/resource.go index e51e6fe8..e4f75685 100644 --- a/store/resource.go +++ b/store/resource.go @@ -51,7 +51,7 @@ type DeleteResource struct { func (s *Store) CreateResource(ctx context.Context, create *Resource) (*Resource, error) { tx, err := s.db.BeginTx(ctx, nil) if err != nil { - return nil, FormatError(err) + return nil, err } defer tx.Rollback() @@ -85,7 +85,7 @@ func (s *Store) CreateResource(ctx context.Context, create *Resource) (*Resource func (s *Store) ListResources(ctx context.Context, find *FindResource) ([]*Resource, error) { tx, err := s.db.BeginTx(ctx, nil) if err != nil { - return nil, FormatError(err) + return nil, err } defer tx.Rollback() @@ -108,7 +108,7 @@ func (s *Store) ListResources(ctx context.Context, find *FindResource) ([]*Resou func (s *Store) GetResource(ctx context.Context, find *FindResource) (*Resource, error) { tx, err := s.db.BeginTx(ctx, nil) if err != nil { - return nil, FormatError(err) + return nil, err } defer tx.Rollback() @@ -131,7 +131,7 @@ func (s *Store) GetResource(ctx context.Context, find *FindResource) (*Resource, func (s *Store) UpdateResource(ctx context.Context, update *UpdateResource) (*Resource, error) { tx, err := s.db.BeginTx(ctx, nil) if err != nil { - return nil, FormatError(err) + return nil, err } defer tx.Rollback() @@ -168,7 +168,7 @@ func (s *Store) UpdateResource(ctx context.Context, update *UpdateResource) (*Re &resource.PublicID, } if err := tx.QueryRowContext(ctx, query, args...).Scan(dests...); err != nil { - return nil, FormatError(err) + return nil, err } if err := tx.Commit(); err != nil { @@ -181,7 +181,7 @@ func (s *Store) UpdateResource(ctx context.Context, update *UpdateResource) (*Re func (s *Store) DeleteResource(ctx context.Context, delete *DeleteResource) error { tx, err := s.db.BeginTx(ctx, nil) if err != nil { - return FormatError(err) + return err } defer tx.Rollback() @@ -243,7 +243,7 @@ func listResources(ctx context.Context, tx *sql.Tx, find *FindResource) ([]*Reso rows, err := tx.QueryContext(ctx, query, args...) if err != nil { - return nil, FormatError(err) + return nil, err } defer rows.Close() @@ -267,13 +267,13 @@ func listResources(ctx context.Context, tx *sql.Tx, find *FindResource) ([]*Reso dests = append(dests, &resource.Blob) } if err := rows.Scan(dests...); err != nil { - return nil, FormatError(err) + return nil, err } list = append(list, &resource) } if err := rows.Err(); err != nil { - return nil, FormatError(err) + return nil, err } return list, nil @@ -292,7 +292,7 @@ func vacuumResource(ctx context.Context, tx *sql.Tx) error { )` _, err := tx.ExecContext(ctx, stmt) if err != nil { - return FormatError(err) + return err } return nil diff --git a/store/store.go b/store/store.go index 661ed08e..6002ec2f 100644 --- a/store/store.go +++ b/store/store.go @@ -15,7 +15,7 @@ type Store struct { systemSettingCache sync.Map // map[string]*SystemSetting userCache sync.Map // map[int]*User userSettingCache sync.Map // map[string]*UserSetting - shortcutCache sync.Map // map[int]*shortcutRaw + shortcutCache sync.Map // map[int]*Shortcut idpCache sync.Map // map[int]*IdentityProvider } diff --git a/store/user.go b/store/user.go index ea747d1f..13b71ef4 100644 --- a/store/user.go +++ b/store/user.go @@ -65,17 +65,13 @@ type UpdateUser struct { } type FindUser struct { - ID *int - - // Standard fields + ID *int RowStatus *RowStatus - - // Domain specific fields - Username *string - Role *Role - Email *string - Nickname *string - OpenID *string + Username *string + Role *Role + Email *string + Nickname *string + OpenID *string } type DeleteUser struct { diff --git a/test/store/idp_test.go b/test/store/idp_test.go index 0d8276c4..847db471 100644 --- a/test/store/idp_test.go +++ b/test/store/idp_test.go @@ -37,6 +37,7 @@ func TestIdentityProviderStore(t *testing.T) { ID: &createdIDP.ID, }) require.NoError(t, err) + require.NotNil(t, idp) require.Equal(t, createdIDP, idp) newName := "My GitHub OAuth" updatedIdp, err := ts.UpdateIdentityProvider(ctx, &store.UpdateIdentityProvider{ diff --git a/test/store/memo_test.go b/test/store/memo_test.go index e32ed682..16e9393d 100644 --- a/test/store/memo_test.go +++ b/test/store/memo_test.go @@ -32,6 +32,7 @@ func TestMemoStore(t *testing.T) { ID: &memo.ID, }) require.NoError(t, err) + require.NotNil(t, memo) memoList, err := ts.ListMemos(ctx, &store.FindMemo{ CreatorID: &user.ID, })