chore: update common utils (#1908)

This commit is contained in:
boojack 2023-07-06 22:53:38 +08:00 committed by GitHub
parent a7573d5705
commit 0e05c62a3b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
30 changed files with 168 additions and 204 deletions

View File

@ -8,7 +8,7 @@ import (
"github.com/labstack/echo/v4"
"github.com/pkg/errors"
"github.com/usememos/memos/common"
"github.com/usememos/memos/common/util"
"github.com/usememos/memos/plugin/idp"
"github.com/usememos/memos/plugin/idp/oauth2"
"github.com/usememos/memos/server/auth"
@ -43,7 +43,7 @@ func (s *APIV1Service) registerAuthRoutes(g *echo.Group) {
user, err := s.Store.GetUser(ctx, &store.FindUser{
Username: &signin.Username,
})
if err != nil && common.ErrorCode(err) != common.NotFound {
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Incorrect login credentials, please try again")
}
if user == nil {
@ -114,7 +114,7 @@ func (s *APIV1Service) registerAuthRoutes(g *echo.Group) {
user, err := s.Store.GetUser(ctx, &store.FindUser{
Username: &userInfo.Identifier,
})
if err != nil && common.ErrorCode(err) != common.NotFound {
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Incorrect login credentials, please try again")
}
if user == nil {
@ -124,9 +124,9 @@ func (s *APIV1Service) registerAuthRoutes(g *echo.Group) {
Role: store.RoleUser,
Nickname: userInfo.DisplayName,
Email: userInfo.Email,
OpenID: common.GenUUID(),
OpenID: util.GenUUID(),
}
password, err := common.RandomString(20)
password, err := util.RandomString(20)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to generate random password").SetInternal(err)
}
@ -173,7 +173,7 @@ func (s *APIV1Service) registerAuthRoutes(g *echo.Group) {
// The new signup user should be normal user by default.
Role: store.RoleUser,
Nickname: signup.Username,
OpenID: common.GenUUID(),
OpenID: util.GenUUID(),
}
if len(existedHostUsers) == 0 {
// Change the default role to host if there is no host user.
@ -182,7 +182,7 @@ func (s *APIV1Service) registerAuthRoutes(g *echo.Group) {
allowSignUpSetting, err := s.Store.GetSystemSetting(ctx, &store.FindSystemSetting{
Name: SystemSettingAllowSignUpName.String(),
})
if err != nil && common.ErrorCode(err) != common.NotFound {
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find system setting").SetInternal(err)
}

View File

@ -7,7 +7,6 @@ import (
"strconv"
"github.com/labstack/echo/v4"
"github.com/usememos/memos/common"
"github.com/usememos/memos/store"
)
@ -231,9 +230,6 @@ func (s *APIV1Service) registerIdentityProviderRoutes(g *echo.Group) {
}
if err = s.Store.DeleteIdentityProvider(ctx, &store.DeleteIdentityProvider{ID: identityProviderID}); err != nil {
if common.ErrorCode(err) == common.NotFound {
return echo.NewHTTPError(http.StatusNotFound, fmt.Sprintf("Identity provider ID not found: %d", identityProviderID))
}
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to delete identity provider").SetInternal(err)
}
return c.JSON(http.StatusOK, true)

View File

@ -10,7 +10,7 @@ import (
"github.com/golang-jwt/jwt/v4"
"github.com/labstack/echo/v4"
"github.com/pkg/errors"
"github.com/usememos/memos/common"
"github.com/usememos/memos/common/util"
"github.com/usememos/memos/server/auth"
"github.com/usememos/memos/store"
)
@ -82,18 +82,18 @@ func JWTMiddleware(server *APIV1Service, next echo.HandlerFunc, secret string) e
}
// Skip validation for server status endpoints.
if common.HasPrefixes(path, "/api/v1/ping", "/api/v1/idp", "/api/v1/status", "/api/v1/user/:id") && method == http.MethodGet {
if util.HasPrefixes(path, "/api/v1/ping", "/api/v1/idp", "/api/v1/status", "/api/v1/user/:id") && method == http.MethodGet {
return next(c)
}
token := findAccessToken(c)
if token == "" {
// Allow the user to access the public endpoints.
if common.HasPrefixes(path, "/o") {
if util.HasPrefixes(path, "/o") {
return next(c)
}
// When the request is not authenticated, we allow the user to access the memo endpoints for those public memos.
if common.HasPrefixes(path, "/api/v1/memo") && method == http.MethodGet {
if util.HasPrefixes(path, "/api/v1/memo") && method == http.MethodGet {
return next(c)
}
return echo.NewHTTPError(http.StatusUnauthorized, "Missing access token")
@ -215,7 +215,7 @@ func (s *APIV1Service) defaultAuthSkipper(c echo.Context) bool {
path := c.Path()
// Skip auth.
if common.HasPrefixes(path, "/api/v1/auth") {
if util.HasPrefixes(path, "/api/v1/auth") {
return true
}
@ -225,7 +225,7 @@ func (s *APIV1Service) defaultAuthSkipper(c echo.Context) bool {
user, err := s.Store.GetUser(ctx, &store.FindUser{
OpenID: &openID,
})
if err != nil && common.ErrorCode(err) != common.NotFound {
if err != nil {
return false
}
if user != nil {

View File

@ -11,7 +11,6 @@ import (
"github.com/labstack/echo/v4"
"github.com/pkg/errors"
"github.com/usememos/memos/common"
"github.com/usememos/memos/store"
)
@ -135,7 +134,6 @@ func (s *APIV1Service) registerMemoRoutes(g *echo.Group) {
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user setting").SetInternal(err)
}
if userMemoVisibilitySetting != nil {
memoVisibility := Private
err := json.Unmarshal([]byte(userMemoVisibilitySetting.Value), &memoVisibility)
@ -169,6 +167,9 @@ func (s *APIV1Service) registerMemoRoutes(g *echo.Group) {
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err)
}
if user == nil {
return echo.NewHTTPError(http.StatusNotFound, "User not found")
}
// Enforce normal user to create private memo if public memos are disabled.
if user.Role == store.RoleUser {
createMemoRequest.Visibility = Private
@ -210,6 +211,10 @@ func (s *APIV1Service) registerMemoRoutes(g *echo.Group) {
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to compose memo").SetInternal(err)
}
if memo == nil {
return echo.NewHTTPError(http.StatusNotFound, fmt.Sprintf("Memo not found: %d", memo.ID))
}
memoResponse, err := s.convertMemoFromStore(ctx, memo)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to compose memo response").SetInternal(err)
@ -235,6 +240,9 @@ func (s *APIV1Service) registerMemoRoutes(g *echo.Group) {
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find memo").SetInternal(err)
}
if memo == nil {
return echo.NewHTTPError(http.StatusNotFound, fmt.Sprintf("Memo not found: %d", memoID))
}
if memo.CreatorID != userID {
return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized")
}
@ -275,6 +283,9 @@ func (s *APIV1Service) registerMemoRoutes(g *echo.Group) {
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find memo").SetInternal(err)
}
if memo == nil {
return echo.NewHTTPError(http.StatusNotFound, fmt.Sprintf("Memo not found: %d", memoID))
}
if patchMemoRequest.ResourceIDList != nil {
addedResourceIDList, removedResourceIDList := getIDListDiff(memo.ResourceIDList, patchMemoRequest.ResourceIDList)
@ -326,6 +337,10 @@ func (s *APIV1Service) registerMemoRoutes(g *echo.Group) {
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find memo").SetInternal(err)
}
if memo == nil {
return echo.NewHTTPError(http.StatusNotFound, fmt.Sprintf("Memo not found: %d", memoID))
}
memoResponse, err := s.convertMemoFromStore(ctx, memo)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to compose memo response").SetInternal(err)
@ -424,11 +439,11 @@ func (s *APIV1Service) registerMemoRoutes(g *echo.Group) {
ID: &memoID,
})
if err != nil {
if common.ErrorCode(err) == common.NotFound {
return echo.NewHTTPError(http.StatusNotFound, fmt.Sprintf("Memo ID not found: %d", memoID)).SetInternal(err)
}
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Failed to find memo by ID: %v", memoID)).SetInternal(err)
}
if memo == nil {
return echo.NewHTTPError(http.StatusNotFound, fmt.Sprintf("Memo not found: %d", memoID))
}
userID, ok := c.Get(getUserIDContextKey()).(int)
if memo.Visibility == store.Private {
@ -585,6 +600,9 @@ func (s *APIV1Service) registerMemoRoutes(g *echo.Group) {
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find memo").SetInternal(err)
}
if memo == nil {
return echo.NewHTTPError(http.StatusNotFound, fmt.Sprintf("Memo not found: %d", memoID))
}
if memo.CreatorID != userID {
return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized")
}
@ -592,9 +610,6 @@ func (s *APIV1Service) registerMemoRoutes(g *echo.Group) {
if err := s.Store.DeleteMemo(ctx, &store.DeleteMemo{
ID: memoID,
}); err != nil {
if common.ErrorCode(err) == common.NotFound {
return echo.NewHTTPError(http.StatusNotFound, fmt.Sprintf("Memo ID not found: %d", memoID))
}
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Failed to delete memo ID: %v", memoID)).SetInternal(err)
}
return c.JSON(http.StatusOK, true)

View File

@ -39,6 +39,9 @@ func (s *APIV1Service) registerMemoOrganizerRoutes(g *echo.Group) {
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find memo").SetInternal(err)
}
if memo == nil {
return echo.NewHTTPError(http.StatusNotFound, fmt.Sprintf("Memo not found: %v", memoID))
}
if memo.CreatorID != userID {
return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized")
}
@ -53,7 +56,7 @@ func (s *APIV1Service) registerMemoOrganizerRoutes(g *echo.Group) {
UserID: userID,
Pinned: request.Pinned,
}
_, err = s.Store.UpsertMemoOrganizerV1(ctx, upsert)
_, err = s.Store.UpsertMemoOrganizer(ctx, upsert)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to upsert memo organizer").SetInternal(err)
}
@ -64,6 +67,9 @@ func (s *APIV1Service) registerMemoOrganizerRoutes(g *echo.Group) {
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Failed to find memo by ID: %v", memoID)).SetInternal(err)
}
if memo == nil {
return echo.NewHTTPError(http.StatusNotFound, fmt.Sprintf("Memo not found: %v", memoID))
}
memoResponse, err := s.convertMemoFromStore(ctx, memo)
if err != nil {

View File

@ -116,6 +116,9 @@ func (s *APIV1Service) registerMemoResourceRoutes(g *echo.Group) {
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find memo").SetInternal(err)
}
if memo == nil {
return echo.NewHTTPError(http.StatusBadRequest, "Memo not found")
}
if memo.CreatorID != userID {
return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized")
}

View File

@ -21,8 +21,8 @@ import (
"github.com/disintegration/imaging"
"github.com/labstack/echo/v4"
"github.com/pkg/errors"
"github.com/usememos/memos/common"
"github.com/usememos/memos/common/log"
"github.com/usememos/memos/common/util"
"github.com/usememos/memos/plugin/storage/s3"
"github.com/usememos/memos/store"
"go.uber.org/zap"
@ -101,7 +101,7 @@ func (s *APIV1Service) registerResourceRoutes(g *echo.Group) {
Filename: request.Filename,
ExternalLink: request.ExternalLink,
Type: request.Type,
PublicID: common.GenUUID(),
PublicID: util.GenUUID(),
}
if request.ExternalLink != "" {
// Only allow those external links scheme with http/https
@ -208,7 +208,7 @@ func (s *APIV1Service) registerResourceRoutes(g *echo.Group) {
}
}
publicID := common.GenUUID()
publicID := util.GenUUID()
if storageServiceID == DatabaseStorage {
fileBytes, err := io.ReadAll(sourceFile)
if err != nil {
@ -226,7 +226,7 @@ func (s *APIV1Service) registerResourceRoutes(g *echo.Group) {
// as it handles the os-specific path separator automatically.
// path.Join() always uses '/' as path separator.
systemSettingLocalStoragePath, err := s.Store.GetSystemSetting(ctx, &store.FindSystemSetting{Name: SystemSettingLocalStoragePathName.String()})
if err != nil && common.ErrorCode(err) != common.NotFound {
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find local storage path setting").SetInternal(err)
}
localStoragePath := "assets/{publicid}"
@ -268,6 +268,9 @@ func (s *APIV1Service) registerResourceRoutes(g *echo.Group) {
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find storage").SetInternal(err)
}
if storage == nil {
return echo.NewHTTPError(http.StatusNotFound, fmt.Sprintf("Storage %d not found", storageServiceID))
}
storageMessage, err := ConvertStorageFromStore(storage)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to convert storage").SetInternal(err)
@ -366,6 +369,9 @@ func (s *APIV1Service) registerResourceRoutes(g *echo.Group) {
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find resource").SetInternal(err)
}
if resource == nil {
return echo.NewHTTPError(http.StatusNotFound, fmt.Sprintf("Resource not found: %d", resourceID))
}
if resource.CreatorID != userID {
return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized")
}
@ -384,7 +390,7 @@ func (s *APIV1Service) registerResourceRoutes(g *echo.Group) {
update.Filename = request.Filename
}
if request.ResetPublicID != nil && *request.ResetPublicID {
publicID := common.GenUUID()
publicID := util.GenUUID()
update.PublicID = &publicID
}
@ -415,7 +421,7 @@ func (s *APIV1Service) registerResourceRoutes(g *echo.Group) {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find resource").SetInternal(err)
}
if resource == nil {
return echo.NewHTTPError(http.StatusNotFound, "Resource not found")
return echo.NewHTTPError(http.StatusNotFound, fmt.Sprintf("Resource not found: %d", resourceID))
}
if resource.InternalPath != "" {
@ -465,6 +471,9 @@ func (s *APIV1Service) registerResourcePublicRoutes(g *echo.Group) {
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Failed to find resource by ID: %v", resourceID)).SetInternal(err)
}
if resource == nil {
return echo.NewHTTPError(http.StatusNotFound, fmt.Sprintf("Resource not found: %d", resourceID))
}
// Private resource require logined user is the creator
if resourceVisibility == store.Private && (!ok || userID != resource.CreatorID) {
@ -485,7 +494,7 @@ func (s *APIV1Service) registerResourcePublicRoutes(g *echo.Group) {
}
}
if c.QueryParam("thumbnail") == "1" && common.HasPrefixes(resource.Type, "image/png", "image/jpeg") {
if c.QueryParam("thumbnail") == "1" && util.HasPrefixes(resource.Type, "image/png", "image/jpeg") {
ext := filepath.Ext(resource.Filename)
thumbnailPath := path.Join(s.Profile.Data, thumbnailImagePath, fmt.Sprintf("%d-%s%s", resource.ID, resource.PublicID, ext))
thumbnailBlob, err := getOrGenerateThumbnailImage(blob, thumbnailPath)

View File

@ -1,9 +1,10 @@
package server
package v1
import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
"strconv"
"strings"
@ -11,13 +12,15 @@ import (
"github.com/gorilla/feeds"
"github.com/labstack/echo/v4"
apiv1 "github.com/usememos/memos/api/v1"
"github.com/usememos/memos/common"
"github.com/usememos/memos/common/util"
"github.com/usememos/memos/store"
"github.com/yuin/goldmark"
)
func (s *Server) registerRSSRoutes(g *echo.Group) {
const maxRSSItemCount = 100
const maxRSSItemTitleLength = 100
func (s *APIV1Service) registerRSSRoutes(g *echo.Group) {
g.GET("/explore/rss.xml", func(c echo.Context) error {
ctx := c.Request().Context()
systemCustomizedProfile, err := s.getSystemCustomizedProfile(ctx)
@ -77,10 +80,7 @@ func (s *Server) registerRSSRoutes(g *echo.Group) {
})
}
const MaxRSSItemCount = 100
const MaxRSSItemTitleLength = 100
func (s *Server) generateRSSFromMemoList(ctx context.Context, memoList []*store.Memo, baseURL string, profile *apiv1.CustomizedProfile) (string, error) {
func (s *APIV1Service) generateRSSFromMemoList(ctx context.Context, memoList []*store.Memo, baseURL string, profile *CustomizedProfile) (string, error) {
feed := &feeds.Feed{
Title: profile.Name,
Link: &feeds.Link{Href: baseURL},
@ -88,7 +88,7 @@ func (s *Server) generateRSSFromMemoList(ctx context.Context, memoList []*store.
Created: time.Now(),
}
var itemCountLimit = common.Min(len(memoList), MaxRSSItemCount)
var itemCountLimit = util.Min(len(memoList), maxRSSItemCount)
feed.Items = make([]*feeds.Item, itemCountLimit)
for i := 0; i < itemCountLimit; i++ {
memo := memoList[i]
@ -107,6 +107,9 @@ func (s *Server) generateRSSFromMemoList(ctx context.Context, memoList []*store.
if err != nil {
return "", err
}
if resource == nil {
return "", fmt.Errorf("Resource not found: %d", resourceID)
}
enclosure := feeds.Enclosure{}
if resource.ExternalLink != "" {
enclosure.Url = resource.ExternalLink
@ -126,14 +129,14 @@ func (s *Server) generateRSSFromMemoList(ctx context.Context, memoList []*store.
return rss, nil
}
func (s *Server) getSystemCustomizedProfile(ctx context.Context) (*apiv1.CustomizedProfile, error) {
func (s *APIV1Service) getSystemCustomizedProfile(ctx context.Context) (*CustomizedProfile, error) {
systemSetting, err := s.Store.GetSystemSetting(ctx, &store.FindSystemSetting{
Name: apiv1.SystemSettingCustomizedProfileName.String(),
Name: SystemSettingCustomizedProfileName.String(),
})
if err != nil {
return nil, err
}
customizedProfile := &apiv1.CustomizedProfile{
customizedProfile := &CustomizedProfile{
Name: "memos",
LogoURL: "",
Description: "",
@ -155,7 +158,7 @@ func getRSSItemTitle(content string) string {
title = strings.Split(content, "\n")[0][2:]
} else {
title = strings.Split(content, "\n")[0]
var titleLengthLimit = common.Min(len(title), MaxRSSItemTitleLength)
var titleLengthLimit = util.Min(len(title), maxRSSItemTitleLength)
if titleLengthLimit < len(title) {
title = title[:titleLengthLimit] + "..."
}

View File

@ -99,6 +99,9 @@ func (s *APIV1Service) registerShortcutRoutes(g *echo.Group) {
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find shortcut").SetInternal(err)
}
if shortcut == nil {
return echo.NewHTTPError(http.StatusNotFound, fmt.Sprintf("Shortcut not found: %d", shortcutID))
}
if shortcut.CreatorID != userID {
return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized")
}
@ -165,7 +168,7 @@ func (s *APIV1Service) registerShortcutRoutes(g *echo.Group) {
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Failed to fetch shortcut by ID %d", shortcutID)).SetInternal(err)
}
if shortcut == nil {
return echo.NewHTTPError(http.StatusNotFound, fmt.Sprintf("Shortcut by ID %d not found", shortcutID))
return echo.NewHTTPError(http.StatusNotFound, fmt.Sprintf("Shortcut not found: %d", shortcutID))
}
return c.JSON(http.StatusOK, convertShortcutFromStore(shortcut))
})
@ -187,6 +190,9 @@ func (s *APIV1Service) registerShortcutRoutes(g *echo.Group) {
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find shortcut").SetInternal(err)
}
if shortcut == nil {
return echo.NewHTTPError(http.StatusNotFound, fmt.Sprintf("Shortcut not found: %d", shortcutID))
}
if shortcut.CreatorID != userID {
return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized")
}

View File

@ -9,7 +9,7 @@ import (
"github.com/labstack/echo/v4"
"github.com/pkg/errors"
"github.com/usememos/memos/common"
"github.com/usememos/memos/common/util"
"github.com/usememos/memos/store"
"golang.org/x/crypto/bcrypt"
)
@ -77,7 +77,7 @@ func (create CreateUserRequest) Validate() error {
if len(create.Email) > 256 {
return fmt.Errorf("email is too long, maximum length is 256")
}
if !common.ValidateEmail(create.Email) {
if !util.ValidateEmail(create.Email) {
return fmt.Errorf("invalid email format")
}
}
@ -120,7 +120,7 @@ func (update UpdateUserRequest) Validate() error {
if len(*update.Email) > 256 {
return fmt.Errorf("email is too long, maximum length is 256")
}
if !common.ValidateEmail(*update.Email) {
if !util.ValidateEmail(*update.Email) {
return fmt.Errorf("invalid email format")
}
}
@ -141,6 +141,9 @@ func (s *APIV1Service) registerUserRoutes(g *echo.Group) {
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user by id").SetInternal(err)
}
if currentUser == nil {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing auth session")
}
if currentUser.Role != store.RoleHost {
return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized to create user")
}
@ -168,7 +171,7 @@ func (s *APIV1Service) registerUserRoutes(g *echo.Group) {
Email: userCreate.Email,
Nickname: userCreate.Nickname,
PasswordHash: string(passwordHash),
OpenID: common.GenUUID(),
OpenID: util.GenUUID(),
})
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create user").SetInternal(err)
@ -211,6 +214,9 @@ func (s *APIV1Service) registerUserRoutes(g *echo.Group) {
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err)
}
if user == nil {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing auth session")
}
list, err := s.Store.ListUserSettings(ctx, &store.FindUserSetting{
UserID: &userID,
@ -306,7 +312,7 @@ func (s *APIV1Service) registerUserRoutes(g *echo.Group) {
userUpdate.PasswordHash = &passwordHashStr
}
if request.ResetOpenID != nil && *request.ResetOpenID {
openID := common.GenUUID()
openID := util.GenUUID()
userUpdate.OpenID = &openID
}
if request.AvatarURL != nil {

View File

@ -21,6 +21,10 @@ func NewAPIV1Service(secret string, profile *profile.Profile, store *store.Store
}
func (s *APIV1Service) Register(rootGroup *echo.Group) {
// Register RSS routes.
s.registerRSSRoutes(rootGroup)
// Register API v1 routes.
apiV1Group := rootGroup.Group("/api/v1")
apiV1Group.Use(func(next echo.HandlerFunc) echo.HandlerFunc {
return JWTMiddleware(s, next, s.Secret)
@ -40,6 +44,7 @@ func (s *APIV1Service) Register(rootGroup *echo.Group) {
s.registerMemoResourceRoutes(apiV1Group)
s.registerMemoRelationRoutes(apiV1Group)
// Register public routes.
publicGroup := rootGroup.Group("/o")
publicGroup.Use(func(next echo.HandlerFunc) echo.HandlerFunc {
return JWTMiddleware(s, next, s.Secret)

Binary file not shown.

Before

Width:  |  Height:  |  Size: 374 KiB

View File

@ -1,72 +0,0 @@
package common
import (
"errors"
)
// Code is the error code.
type Code int
// Application error codes.
const (
// 0 ~ 99 general error.
Ok Code = 0
Internal Code = 1
NotAuthorized Code = 2
Invalid Code = 3
NotFound Code = 4
Conflict Code = 5
NotImplemented Code = 6
)
// Error represents an application-specific error. Application errors can be
// unwrapped by the caller to extract out the code & message.
//
// Any non-application error (such as a disk error) should be reported as an
// Internal error and the human user should only see "Internal error" as the
// message. These low-level internal error details should only be logged and
// reported to the operator of the application (not the end user).
type Error struct {
// Machine-readable error code.
Code Code
// Embedded error.
Err error
}
// Error implements the error interface. Not used by the application otherwise.
func (e *Error) Error() string {
return e.Err.Error()
}
// ErrorCode unwraps an application error and returns its code.
// Non-application errors always return EINTERNAL.
func ErrorCode(err error) Code {
var e *Error
if err == nil {
return Ok
} else if errors.As(err, &e) {
return e.Code
}
return Internal
}
// ErrorMessage unwraps an application error and returns its message.
// Non-application errors always return "Internal error".
func ErrorMessage(err error) string {
var e *Error
if err == nil {
return ""
} else if errors.As(err, &e) {
return e.Err.Error()
}
return "Internal error."
}
// Errorf is a helper function to return an Error with a given code and error.
func Errorf(code Code, err error) *Error {
return &Error{
Code: code,
Err: err,
}
}

View File

@ -1,4 +1,4 @@
package common
package util
import (
"crypto/rand"

View File

@ -1,4 +1,4 @@
package common
package util
import (
"testing"

View File

@ -6,10 +6,6 @@ Memos is built with a curated tech stack. It is optimized for developer experien
2. It requires zero config.
3. 1 command to start backend and 1 command to start frontend, both with live reload support.
## Tech Stack
![tech-stack](https://raw.githubusercontent.com/usememos/memos/main/assets/tech-stack.png)
## Prerequisites
- [Go](https://golang.org/doc/install)

View File

@ -10,7 +10,7 @@ import (
"github.com/google/uuid"
"github.com/pkg/errors"
apiv1 "github.com/usememos/memos/api/v1"
"github.com/usememos/memos/common"
"github.com/usememos/memos/common/util"
"github.com/usememos/memos/plugin/telegram"
"github.com/usememos/memos/server/profile"
"github.com/usememos/memos/store"
@ -86,8 +86,6 @@ func NewServer(ctx context.Context, profile *profile.Profile, store *store.Store
s.Secret = secret
rootGroup := e.Group("")
s.registerRSSRoutes(rootGroup)
apiV1Service := apiv1.NewAPIV1Service(s.Secret, profile, store)
apiV1Service.Register(rootGroup)
@ -129,7 +127,7 @@ func (s *Server) getSystemServerID(ctx context.Context) (string, error) {
serverIDSetting, err := s.Store.GetSystemSetting(ctx, &store.FindSystemSetting{
Name: apiv1.SystemSettingServerIDName.String(),
})
if err != nil && common.ErrorCode(err) != common.NotFound {
if err != nil {
return "", err
}
if serverIDSetting == nil || serverIDSetting.Value == "" {
@ -148,7 +146,7 @@ func (s *Server) getSystemSecretSessionName(ctx context.Context) (string, error)
secretSessionNameValue, err := s.Store.GetSystemSetting(ctx, &store.FindSystemSetting{
Name: apiv1.SystemSettingSecretSessionName.String(),
})
if err != nil && common.ErrorCode(err) != common.NotFound {
if err != nil {
return "", err
}
if secretSessionNameValue == nil || secretSessionNameValue.Value == "" {
@ -190,5 +188,5 @@ func defaultGetRequestSkipper(c echo.Context) bool {
func defaultAPIRequestSkipper(c echo.Context) bool {
path := c.Path()
return common.HasPrefixes(path, "/api", "/api/v1")
return util.HasPrefixes(path, "/api", "/api/v1")
}

View File

@ -9,7 +9,7 @@ import (
"github.com/pkg/errors"
apiv1 "github.com/usememos/memos/api/v1"
"github.com/usememos/memos/common"
"github.com/usememos/memos/common/util"
"github.com/usememos/memos/plugin/telegram"
"github.com/usememos/memos/store"
)
@ -94,7 +94,7 @@ func (t *telegramHandler) MessageHandle(ctx context.Context, bot *telegram.Bot,
Type: mime,
Size: int64(len(blob)),
Blob: blob,
PublicID: common.GenUUID(),
PublicID: util.GenUUID(),
})
if err != nil {
_, err := bot.EditMessage(ctx, message.Chat.ID, reply.MessageID, fmt.Sprintf("failed to CreateResource: %s", err), nil)

View File

@ -7,7 +7,7 @@ import (
"golang.org/x/crypto/bcrypt"
"github.com/usememos/memos/common"
"github.com/usememos/memos/common/util"
"github.com/usememos/memos/store"
)
@ -51,7 +51,7 @@ func (s setupService) createUser(ctx context.Context, hostUsername, hostPassword
// The new signup user should be normal user by default.
Role: store.RoleHost,
Nickname: hostUsername,
OpenID: common.GenUUID(),
OpenID: util.GenUUID(),
}
if len(userCreate.Username) < 3 {
@ -73,7 +73,7 @@ func (s setupService) createUser(ctx context.Context, hostUsername, hostPassword
if len(userCreate.Email) > 256 {
return fmt.Errorf("email is too long, maximum length is 256")
}
if !common.ValidateEmail(userCreate.Email) {
if !util.ValidateEmail(userCreate.Email) {
return fmt.Errorf("invalid email format")
}
}

View File

@ -17,7 +17,6 @@ type Activity struct {
Payload string
}
// CreateActivity creates an instance of Activity.
func (s *Store) CreateActivity(ctx context.Context, create *Activity) (*Activity, error) {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {

View File

@ -1,19 +0,0 @@
package store
import (
"database/sql"
"errors"
)
func FormatError(err error) error {
if err == nil {
return nil
}
switch err {
case sql.ErrNoRows:
return errors.New("data not found")
default:
return err
}
}

View File

@ -256,7 +256,7 @@ func (s *Store) DeleteIdentityProvider(ctx context.Context, delete *DeleteIdenti
}
func listIdentityProviders(ctx context.Context, tx *sql.Tx, find *FindIdentityProvider) ([]*IdentityProvider, error) {
where, args := []string{"TRUE"}, []any{}
where, args := []string{"1 = 1"}, []any{}
if v := find.ID; v != nil {
where, args = append(where, fmt.Sprintf("id = $%d", len(args)+1)), append(args, *v)
}

View File

@ -23,7 +23,7 @@ type DeleteMemoOrganizer struct {
UserID *int
}
func (s *Store) UpsertMemoOrganizerV1(ctx context.Context, upsert *MemoOrganizer) (*MemoOrganizer, error) {
func (s *Store) UpsertMemoOrganizer(ctx context.Context, upsert *MemoOrganizer) (*MemoOrganizer, error) {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, err
@ -53,7 +53,7 @@ func (s *Store) UpsertMemoOrganizerV1(ctx context.Context, upsert *MemoOrganizer
return memoOrganizer, nil
}
func (s *Store) GetMemoOrganizerV1(ctx context.Context, find *FindMemoOrganizer) (*MemoOrganizer, error) {
func (s *Store) GetMemoOrganizer(ctx context.Context, find *FindMemoOrganizer) (*MemoOrganizer, error) {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, err
@ -69,6 +69,7 @@ func (s *Store) GetMemoOrganizerV1(ctx context.Context, find *FindMemoOrganizer)
where = append(where, "user_id = ?")
args = append(args, find.UserID)
}
query := fmt.Sprintf(`
SELECT
memo_id,
@ -78,6 +79,12 @@ func (s *Store) GetMemoOrganizerV1(ctx context.Context, find *FindMemoOrganizer)
WHERE %s
`, strings.Join(where, " AND "))
row := tx.QueryRowContext(ctx, query, args...)
if err := row.Err(); err != nil {
return nil, err
}
if row == nil {
return nil, nil
}
memoOrganizer := &MemoOrganizer{}
if err := row.Scan(
@ -88,13 +95,17 @@ func (s *Store) GetMemoOrganizerV1(ctx context.Context, find *FindMemoOrganizer)
return nil, err
}
if err := tx.Commit(); err != nil {
return nil, err
}
return memoOrganizer, nil
}
func (s *Store) DeleteMemoOrganizerV1(ctx context.Context, delete *DeleteMemoOrganizer) error {
func (s *Store) DeleteMemoOrganizer(ctx context.Context, delete *DeleteMemoOrganizer) error {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return FormatError(err)
return err
}
defer tx.Rollback()
@ -110,11 +121,12 @@ func (s *Store) DeleteMemoOrganizerV1(ctx context.Context, delete *DeleteMemoOrg
stmt := `DELETE FROM memo_organizer WHERE ` + strings.Join(where, " AND ")
_, err = tx.ExecContext(ctx, stmt, args...)
if err != nil {
return FormatError(err)
return err
}
if err := tx.Commit(); err != nil {
return FormatError(err)
// Prevent linter warning.
return err
}
return nil
@ -139,7 +151,7 @@ func vacuumMemoOrganizer(ctx context.Context, tx *sql.Tx) error {
)`
_, err := tx.ExecContext(ctx, stmt)
if err != nil {
return FormatError(err)
return err
}
return nil

View File

@ -49,7 +49,7 @@ func (s *Store) UpsertMemoRelation(ctx context.Context, create *MemoRelation) (*
type = EXCLUDED.type
RETURNING memo_id, related_memo_id, type
`
memoRelationMessage := &MemoRelation{}
memoRelation := &MemoRelation{}
if err := tx.QueryRowContext(
ctx,
query,
@ -57,16 +57,18 @@ func (s *Store) UpsertMemoRelation(ctx context.Context, create *MemoRelation) (*
create.RelatedMemoID,
create.Type,
).Scan(
&memoRelationMessage.MemoID,
&memoRelationMessage.RelatedMemoID,
&memoRelationMessage.Type,
&memoRelation.MemoID,
&memoRelation.RelatedMemoID,
&memoRelation.Type,
); err != nil {
return nil, err
}
if err := tx.Commit(); err != nil {
return nil, err
}
return memoRelationMessage, nil
return memoRelation, nil
}
func (s *Store) ListMemoRelations(ctx context.Context, find *FindMemoRelation) ([]*MemoRelation, error) {

View File

@ -33,7 +33,7 @@ type DeleteMemoResource struct {
func (s *Store) UpsertMemoResource(ctx context.Context, upsert *UpsertMemoResource) (*MemoResource, error) {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, FormatError(err)
return nil, err
}
defer tx.Rollback()
@ -62,11 +62,11 @@ func (s *Store) UpsertMemoResource(ctx context.Context, upsert *UpsertMemoResour
&memoResource.CreatedTs,
&memoResource.UpdatedTs,
); err != nil {
return nil, FormatError(err)
return nil, err
}
if err := tx.Commit(); err != nil {
return nil, FormatError(err)
return nil, err
}
return memoResource, nil
@ -117,7 +117,7 @@ func (s *Store) GetMemoResource(ctx context.Context, find *FindMemoResource) (*M
func (s *Store) DeleteMemoResource(ctx context.Context, delete *DeleteMemoResource) error {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return FormatError(err)
return err
}
defer tx.Rollback()
@ -133,11 +133,12 @@ func (s *Store) DeleteMemoResource(ctx context.Context, delete *DeleteMemoResour
stmt := `DELETE FROM memo_resource WHERE ` + strings.Join(where, " AND ")
_, err = tx.ExecContext(ctx, stmt, args...)
if err != nil {
return FormatError(err)
return err
}
if err := tx.Commit(); err != nil {
return FormatError(err)
// Prevent linter warning.
return err
}
return nil
@ -165,7 +166,7 @@ func listMemoResources(ctx context.Context, tx *sql.Tx, find *FindMemoResource)
`
rows, err := tx.QueryContext(ctx, query, args...)
if err != nil {
return nil, FormatError(err)
return nil, err
}
defer rows.Close()
@ -178,7 +179,7 @@ func listMemoResources(ctx context.Context, tx *sql.Tx, find *FindMemoResource)
&memoResource.CreatedTs,
&memoResource.UpdatedTs,
); err != nil {
return nil, FormatError(err)
return nil, err
}
list = append(list, &memoResource)
@ -210,7 +211,7 @@ func vacuumMemoResource(ctx context.Context, tx *sql.Tx) error {
)`
_, err := tx.ExecContext(ctx, stmt)
if err != nil {
return FormatError(err)
return err
}
return nil

View File

@ -51,7 +51,7 @@ type DeleteResource struct {
func (s *Store) CreateResource(ctx context.Context, create *Resource) (*Resource, error) {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, FormatError(err)
return nil, err
}
defer tx.Rollback()
@ -85,7 +85,7 @@ func (s *Store) CreateResource(ctx context.Context, create *Resource) (*Resource
func (s *Store) ListResources(ctx context.Context, find *FindResource) ([]*Resource, error) {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, FormatError(err)
return nil, err
}
defer tx.Rollback()
@ -108,7 +108,7 @@ func (s *Store) ListResources(ctx context.Context, find *FindResource) ([]*Resou
func (s *Store) GetResource(ctx context.Context, find *FindResource) (*Resource, error) {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, FormatError(err)
return nil, err
}
defer tx.Rollback()
@ -131,7 +131,7 @@ func (s *Store) GetResource(ctx context.Context, find *FindResource) (*Resource,
func (s *Store) UpdateResource(ctx context.Context, update *UpdateResource) (*Resource, error) {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, FormatError(err)
return nil, err
}
defer tx.Rollback()
@ -168,7 +168,7 @@ func (s *Store) UpdateResource(ctx context.Context, update *UpdateResource) (*Re
&resource.PublicID,
}
if err := tx.QueryRowContext(ctx, query, args...).Scan(dests...); err != nil {
return nil, FormatError(err)
return nil, err
}
if err := tx.Commit(); err != nil {
@ -181,7 +181,7 @@ func (s *Store) UpdateResource(ctx context.Context, update *UpdateResource) (*Re
func (s *Store) DeleteResource(ctx context.Context, delete *DeleteResource) error {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return FormatError(err)
return err
}
defer tx.Rollback()
@ -243,7 +243,7 @@ func listResources(ctx context.Context, tx *sql.Tx, find *FindResource) ([]*Reso
rows, err := tx.QueryContext(ctx, query, args...)
if err != nil {
return nil, FormatError(err)
return nil, err
}
defer rows.Close()
@ -267,13 +267,13 @@ func listResources(ctx context.Context, tx *sql.Tx, find *FindResource) ([]*Reso
dests = append(dests, &resource.Blob)
}
if err := rows.Scan(dests...); err != nil {
return nil, FormatError(err)
return nil, err
}
list = append(list, &resource)
}
if err := rows.Err(); err != nil {
return nil, FormatError(err)
return nil, err
}
return list, nil
@ -292,7 +292,7 @@ func vacuumResource(ctx context.Context, tx *sql.Tx) error {
)`
_, err := tx.ExecContext(ctx, stmt)
if err != nil {
return FormatError(err)
return err
}
return nil

View File

@ -15,7 +15,7 @@ type Store struct {
systemSettingCache sync.Map // map[string]*SystemSetting
userCache sync.Map // map[int]*User
userSettingCache sync.Map // map[string]*UserSetting
shortcutCache sync.Map // map[int]*shortcutRaw
shortcutCache sync.Map // map[int]*Shortcut
idpCache sync.Map // map[int]*IdentityProvider
}

View File

@ -65,17 +65,13 @@ type UpdateUser struct {
}
type FindUser struct {
ID *int
// Standard fields
ID *int
RowStatus *RowStatus
// Domain specific fields
Username *string
Role *Role
Email *string
Nickname *string
OpenID *string
Username *string
Role *Role
Email *string
Nickname *string
OpenID *string
}
type DeleteUser struct {

View File

@ -37,6 +37,7 @@ func TestIdentityProviderStore(t *testing.T) {
ID: &createdIDP.ID,
})
require.NoError(t, err)
require.NotNil(t, idp)
require.Equal(t, createdIDP, idp)
newName := "My GitHub OAuth"
updatedIdp, err := ts.UpdateIdentityProvider(ctx, &store.UpdateIdentityProvider{

View File

@ -32,6 +32,7 @@ func TestMemoStore(t *testing.T) {
ID: &memo.ID,
})
require.NoError(t, err)
require.NotNil(t, memo)
memoList, err := ts.ListMemos(ctx, &store.FindMemo{
CreatorID: &user.ID,
})