refactor: migrate definition to api v1 (#1879)

* refactor: user api v1

* refactor: system setting to apiv1

* chore: remove unused definition

* chore: update

* chore: refactor: system setting

* chore: update

* refactor: migrate tag

* feat: migrate activity store

* refactor: migrate shortcut apiv1

* chore: update
This commit is contained in:
boojack
2023-07-02 18:56:25 +08:00
committed by GitHub
parent b84ecc4574
commit 66e65e4dc1
59 changed files with 1387 additions and 2608 deletions

View File

@ -1,137 +0,0 @@
package api
import "github.com/usememos/memos/server/profile"
// ActivityType is the type for an activity.
type ActivityType string
const (
// User related.
// ActivityUserCreate is the type for creating users.
ActivityUserCreate ActivityType = "user.create"
// ActivityUserUpdate is the type for updating users.
ActivityUserUpdate ActivityType = "user.update"
// ActivityUserDelete is the type for deleting users.
ActivityUserDelete ActivityType = "user.delete"
// ActivityUserAuthSignIn is the type for user signin.
ActivityUserAuthSignIn ActivityType = "user.auth.signin"
// ActivityUserAuthSignUp is the type for user signup.
ActivityUserAuthSignUp ActivityType = "user.auth.signup"
// ActivityUserSettingUpdate is the type for updating user settings.
ActivityUserSettingUpdate ActivityType = "user.setting.update"
// Memo related.
// ActivityMemoCreate is the type for creating memos.
ActivityMemoCreate ActivityType = "memo.create"
// ActivityMemoUpdate is the type for updating memos.
ActivityMemoUpdate ActivityType = "memo.update"
// ActivityMemoDelete is the type for deleting memos.
ActivityMemoDelete ActivityType = "memo.delete"
// Shortcut related.
// ActivityShortcutCreate is the type for creating shortcuts.
ActivityShortcutCreate ActivityType = "shortcut.create"
// ActivityShortcutUpdate is the type for updating shortcuts.
ActivityShortcutUpdate ActivityType = "shortcut.update"
// ActivityShortcutDelete is the type for deleting shortcuts.
ActivityShortcutDelete ActivityType = "shortcut.delete"
// Resource related.
// ActivityResourceCreate is the type for creating resources.
ActivityResourceCreate ActivityType = "resource.create"
// ActivityResourceDelete is the type for deleting resources.
ActivityResourceDelete ActivityType = "resource.delete"
// Tag related.
// ActivityTagCreate is the type for creating tags.
ActivityTagCreate ActivityType = "tag.create"
// ActivityTagDelete is the type for deleting tags.
ActivityTagDelete ActivityType = "tag.delete"
// Server related.
// ActivityServerStart is the type for starting server.
ActivityServerStart ActivityType = "server.start"
)
// ActivityLevel is the level of activities.
type ActivityLevel string
const (
// ActivityInfo is the INFO level of activities.
ActivityInfo ActivityLevel = "INFO"
// ActivityWarn is the WARN level of activities.
ActivityWarn ActivityLevel = "WARN"
// ActivityError is the ERROR level of activities.
ActivityError ActivityLevel = "ERROR"
)
type ActivityUserCreatePayload struct {
UserID int `json:"userId"`
Username string `json:"username"`
Role Role `json:"role"`
}
type ActivityUserAuthSignInPayload struct {
UserID int `json:"userId"`
IP string `json:"ip"`
}
type ActivityUserAuthSignUpPayload struct {
Username string `json:"username"`
IP string `json:"ip"`
}
type ActivityMemoCreatePayload struct {
Content string `json:"content"`
Visibility string `json:"visibility"`
}
type ActivityShortcutCreatePayload struct {
Title string `json:"title"`
Payload string `json:"payload"`
}
type ActivityResourceCreatePayload struct {
Filename string `json:"filename"`
Type string `json:"type"`
Size int64 `json:"size"`
}
type ActivityTagCreatePayload struct {
TagName string `json:"tagName"`
}
type ActivityServerStartPayload struct {
ServerID string `json:"serverId"`
Profile *profile.Profile `json:"profile"`
}
type Activity struct {
ID int `json:"id"`
// Standard fields
CreatorID int `json:"creatorId"`
CreatedTs int64 `json:"createdTs"`
// Domain specific fields
Type ActivityType `json:"type"`
Level ActivityLevel `json:"level"`
Payload string `json:"payload"`
}
// ActivityCreate is the API message for creating an activity.
type ActivityCreate struct {
// Standard fields
CreatorID int
// Domain specific fields
Type ActivityType `json:"type"`
Level ActivityLevel
Payload string `json:"payload"`
}

View File

@ -1,53 +0,0 @@
package api
type Shortcut struct {
ID int `json:"id"`
// Standard fields
RowStatus RowStatus `json:"rowStatus"`
CreatorID int `json:"creatorId"`
CreatedTs int64 `json:"createdTs"`
UpdatedTs int64 `json:"updatedTs"`
// Domain specific fields
Title string `json:"title"`
Payload string `json:"payload"`
}
type ShortcutCreate struct {
// Standard fields
CreatorID int `json:"-"`
// Domain specific fields
Title string `json:"title"`
Payload string `json:"payload"`
}
type ShortcutPatch struct {
ID int `json:"-"`
// Standard fields
UpdatedTs *int64
RowStatus *RowStatus `json:"rowStatus"`
// Domain specific fields
Title *string `json:"title"`
Payload *string `json:"payload"`
}
type ShortcutFind struct {
ID *int
// Standard fields
CreatorID *int
// Domain specific fields
Title *string `json:"title"`
}
type ShortcutDelete struct {
ID *int
// Standard fields
CreatorID *int
}

View File

@ -1,29 +0,0 @@
package api
import "github.com/usememos/memos/server/profile"
type SystemStatus struct {
Host *User `json:"host"`
Profile profile.Profile `json:"profile"`
DBSize int64 `json:"dbSize"`
// System settings
// Allow sign up.
AllowSignUp bool `json:"allowSignUp"`
// Disable public memos.
DisablePublicMemos bool `json:"disablePublicMemos"`
// Max upload size.
MaxUploadSizeMiB int `json:"maxUploadSizeMiB"`
// Additional style.
AdditionalStyle string `json:"additionalStyle"`
// Additional script.
AdditionalScript string `json:"additionalScript"`
// Customized server profile, including server name and external url.
CustomizedProfile CustomizedProfile `json:"customizedProfile"`
// Storage service ID.
StorageServiceID int `json:"storageServiceId"`
// Local storage path.
LocalStoragePath string `json:"localStoragePath"`
// Memo display with updated timestamp.
MemoDisplayWithUpdatedTs bool `json:"memoDisplayWithUpdatedTs"`
}

View File

@ -1,201 +0,0 @@
package api
import (
"encoding/json"
"fmt"
"strings"
"golang.org/x/exp/slices"
)
type SystemSettingName string
const (
// SystemSettingServerIDName is the name of server id.
SystemSettingServerIDName SystemSettingName = "server-id"
// SystemSettingSecretSessionName is the name of secret session.
SystemSettingSecretSessionName SystemSettingName = "secret-session"
// SystemSettingAllowSignUpName is the name of allow signup setting.
SystemSettingAllowSignUpName SystemSettingName = "allow-signup"
// SystemSettingDisablePublicMemosName is the name of disable public memos setting.
SystemSettingDisablePublicMemosName SystemSettingName = "disable-public-memos"
// SystemSettingMaxUploadSizeMiBName is the name of max upload size setting.
SystemSettingMaxUploadSizeMiBName SystemSettingName = "max-upload-size-mib"
// SystemSettingAdditionalStyleName is the name of additional style.
SystemSettingAdditionalStyleName SystemSettingName = "additional-style"
// SystemSettingAdditionalScriptName is the name of additional script.
SystemSettingAdditionalScriptName SystemSettingName = "additional-script"
// SystemSettingCustomizedProfileName is the name of customized server profile.
SystemSettingCustomizedProfileName SystemSettingName = "customized-profile"
// SystemSettingStorageServiceIDName is the name of storage service ID.
SystemSettingStorageServiceIDName SystemSettingName = "storage-service-id"
// SystemSettingLocalStoragePathName is the name of local storage path.
SystemSettingLocalStoragePathName SystemSettingName = "local-storage-path"
// SystemSettingOpenAIConfigName is the name of OpenAI config.
SystemSettingOpenAIConfigName SystemSettingName = "openai-config"
// SystemSettingTelegramBotToken is the name of Telegram Bot Token.
SystemSettingTelegramBotTokenName SystemSettingName = "telegram-bot-token"
SystemSettingMemoDisplayWithUpdatedTsName SystemSettingName = "memo-display-with-updated-ts"
)
// CustomizedProfile is the struct definition for SystemSettingCustomizedProfileName system setting item.
type CustomizedProfile struct {
// Name is the server name, default is `memos`
Name string `json:"name"`
// LogoURL is the url of logo image.
LogoURL string `json:"logoUrl"`
// Description is the server description.
Description string `json:"description"`
// Locale is the server default locale.
Locale string `json:"locale"`
// Appearance is the server default appearance.
Appearance string `json:"appearance"`
// ExternalURL is the external url of server. e.g. https://usermemos.com
ExternalURL string `json:"externalUrl"`
}
type OpenAIConfig struct {
Key string `json:"key"`
Host string `json:"host"`
}
func (key SystemSettingName) String() string {
switch key {
case SystemSettingServerIDName:
return "server-id"
case SystemSettingSecretSessionName:
return "secret-session"
case SystemSettingAllowSignUpName:
return "allow-signup"
case SystemSettingDisablePublicMemosName:
return "disable-public-memos"
case SystemSettingMaxUploadSizeMiBName:
return "max-upload-size-mib"
case SystemSettingAdditionalStyleName:
return "additional-style"
case SystemSettingAdditionalScriptName:
return "additional-script"
case SystemSettingCustomizedProfileName:
return "customized-profile"
case SystemSettingStorageServiceIDName:
return "storage-service-id"
case SystemSettingLocalStoragePathName:
return "local-storage-path"
case SystemSettingOpenAIConfigName:
return "openai-config"
case SystemSettingTelegramBotTokenName:
return "telegram-bot-token"
case SystemSettingMemoDisplayWithUpdatedTsName:
return "memo-display-with-updated-ts"
}
return ""
}
type SystemSetting struct {
Name SystemSettingName `json:"name"`
// Value is a JSON string with basic value.
Value string `json:"value"`
Description string `json:"description"`
}
type SystemSettingUpsert struct {
Name SystemSettingName `json:"name"`
Value string `json:"value"`
Description string `json:"description"`
}
const systemSettingUnmarshalError = `failed to unmarshal value from system setting "%v"`
func (upsert SystemSettingUpsert) Validate() error {
switch settingName := upsert.Name; settingName {
case SystemSettingServerIDName:
return fmt.Errorf("updating %v is not allowed", settingName)
case SystemSettingAllowSignUpName:
var value bool
if err := json.Unmarshal([]byte(upsert.Value), &value); err != nil {
return fmt.Errorf(systemSettingUnmarshalError, settingName)
}
case SystemSettingDisablePublicMemosName:
var value bool
if err := json.Unmarshal([]byte(upsert.Value), &value); err != nil {
return fmt.Errorf(systemSettingUnmarshalError, settingName)
}
case SystemSettingMaxUploadSizeMiBName:
var value int
if err := json.Unmarshal([]byte(upsert.Value), &value); err != nil {
return fmt.Errorf(systemSettingUnmarshalError, settingName)
}
case SystemSettingAdditionalStyleName:
var value string
if err := json.Unmarshal([]byte(upsert.Value), &value); err != nil {
return fmt.Errorf(systemSettingUnmarshalError, settingName)
}
case SystemSettingAdditionalScriptName:
var value string
if err := json.Unmarshal([]byte(upsert.Value), &value); err != nil {
return fmt.Errorf(systemSettingUnmarshalError, settingName)
}
case SystemSettingCustomizedProfileName:
customizedProfile := CustomizedProfile{
Name: "memos",
LogoURL: "",
Description: "",
Locale: "en",
Appearance: "system",
ExternalURL: "",
}
if err := json.Unmarshal([]byte(upsert.Value), &customizedProfile); err != nil {
return fmt.Errorf(systemSettingUnmarshalError, settingName)
}
if !slices.Contains(UserSettingLocaleValue, customizedProfile.Locale) {
return fmt.Errorf(`invalid locale value for system setting "%v"`, settingName)
}
if !slices.Contains(UserSettingAppearanceValue, customizedProfile.Appearance) {
return fmt.Errorf(`invalid appearance value for system setting "%v"`, settingName)
}
case SystemSettingStorageServiceIDName:
value := DatabaseStorage
if err := json.Unmarshal([]byte(upsert.Value), &value); err != nil {
return fmt.Errorf(systemSettingUnmarshalError, settingName)
}
return nil
case SystemSettingLocalStoragePathName:
value := ""
if err := json.Unmarshal([]byte(upsert.Value), &value); err != nil {
return fmt.Errorf(systemSettingUnmarshalError, settingName)
}
case SystemSettingOpenAIConfigName:
value := OpenAIConfig{}
if err := json.Unmarshal([]byte(upsert.Value), &value); err != nil {
return fmt.Errorf(systemSettingUnmarshalError, settingName)
}
case SystemSettingTelegramBotTokenName:
if upsert.Value == "" {
return nil
}
// Bot Token with Reverse Proxy shoule like `http.../bot<token>`
if strings.HasPrefix(upsert.Value, "http") {
slashIndex := strings.LastIndexAny(upsert.Value, "/")
if strings.HasPrefix(upsert.Value[slashIndex:], "/bot") {
return nil
}
return fmt.Errorf("token start with `http` must end with `/bot<token>`")
}
fragments := strings.Split(upsert.Value, ":")
if len(fragments) != 2 {
return fmt.Errorf(systemSettingUnmarshalError, settingName)
}
case SystemSettingMemoDisplayWithUpdatedTsName:
var value bool
if err := json.Unmarshal([]byte(upsert.Value), &value); err != nil {
return fmt.Errorf(systemSettingUnmarshalError, settingName)
}
default:
return fmt.Errorf("invalid system setting name")
}
return nil
}
type SystemSettingFind struct {
Name SystemSettingName `json:"name"`
}

View File

@ -1,20 +0,0 @@
package api
type Tag struct {
Name string
CreatorID int
}
type TagUpsert struct {
Name string
CreatorID int `json:"-"`
}
type TagFind struct {
CreatorID int
}
type TagDelete struct {
Name string `json:"name"`
CreatorID int
}

View File

@ -1,158 +0,0 @@
package api
import (
"fmt"
"github.com/usememos/memos/common"
)
// Role is the type of a role.
type Role string
const (
// Host is the HOST role.
Host Role = "HOST"
// Admin is the ADMIN role.
Admin Role = "ADMIN"
// NormalUser is the USER role.
NormalUser Role = "USER"
)
func (e Role) String() string {
switch e {
case Host:
return "HOST"
case Admin:
return "ADMIN"
case NormalUser:
return "USER"
}
return "USER"
}
type User struct {
ID int `json:"id"`
// Standard fields
RowStatus RowStatus `json:"rowStatus"`
CreatedTs int64 `json:"createdTs"`
UpdatedTs int64 `json:"updatedTs"`
// Domain specific fields
Username string `json:"username"`
Role Role `json:"role"`
Email string `json:"email"`
Nickname string `json:"nickname"`
PasswordHash string `json:"-"`
OpenID string `json:"openId"`
AvatarURL string `json:"avatarUrl"`
UserSettingList []*UserSetting `json:"userSettingList"`
}
type UserCreate struct {
// Domain specific fields
Username string `json:"username"`
Role Role `json:"role"`
Email string `json:"email"`
Nickname string `json:"nickname"`
Password string `json:"password"`
PasswordHash string
OpenID string
}
func (create UserCreate) Validate() error {
if len(create.Username) < 3 {
return fmt.Errorf("username is too short, minimum length is 3")
}
if len(create.Username) > 32 {
return fmt.Errorf("username is too long, maximum length is 32")
}
if len(create.Password) < 3 {
return fmt.Errorf("password is too short, minimum length is 3")
}
if len(create.Password) > 512 {
return fmt.Errorf("password is too long, maximum length is 512")
}
if len(create.Nickname) > 64 {
return fmt.Errorf("nickname is too long, maximum length is 64")
}
if create.Email != "" {
if len(create.Email) > 256 {
return fmt.Errorf("email is too long, maximum length is 256")
}
if !common.ValidateEmail(create.Email) {
return fmt.Errorf("invalid email format")
}
}
return nil
}
type UserPatch struct {
ID int `json:"-"`
// Standard fields
UpdatedTs *int64
RowStatus *RowStatus `json:"rowStatus"`
// Domain specific fields
Username *string `json:"username"`
Email *string `json:"email"`
Nickname *string `json:"nickname"`
Password *string `json:"password"`
ResetOpenID *bool `json:"resetOpenId"`
AvatarURL *string `json:"avatarUrl"`
PasswordHash *string
OpenID *string
}
func (patch UserPatch) Validate() error {
if patch.Username != nil && len(*patch.Username) < 3 {
return fmt.Errorf("username is too short, minimum length is 3")
}
if patch.Username != nil && len(*patch.Username) > 32 {
return fmt.Errorf("username is too long, maximum length is 32")
}
if patch.Password != nil && len(*patch.Password) < 3 {
return fmt.Errorf("password is too short, minimum length is 3")
}
if patch.Password != nil && len(*patch.Password) > 512 {
return fmt.Errorf("password is too long, maximum length is 512")
}
if patch.Nickname != nil && len(*patch.Nickname) > 64 {
return fmt.Errorf("nickname is too long, maximum length is 64")
}
if patch.AvatarURL != nil {
if len(*patch.AvatarURL) > 2<<20 {
return fmt.Errorf("avatar is too large, maximum is 2MB")
}
}
if patch.Email != nil && *patch.Email != "" {
if len(*patch.Email) > 256 {
return fmt.Errorf("email is too long, maximum length is 256")
}
if !common.ValidateEmail(*patch.Email) {
return fmt.Errorf("invalid email format")
}
}
return nil
}
type UserFind struct {
ID *int `json:"id"`
// Standard fields
RowStatus *RowStatus `json:"rowStatus"`
// Domain specific fields
Username *string `json:"username"`
Role *Role
Email *string `json:"email"`
Nickname *string `json:"nickname"`
OpenID *string
}
type UserDelete struct {
ID int
}

View File

@ -1,134 +0,0 @@
package api
import (
"encoding/json"
"fmt"
"strconv"
"golang.org/x/exp/slices"
)
type UserSettingKey string
const (
// UserSettingLocaleKey is the key type for user locale.
UserSettingLocaleKey UserSettingKey = "locale"
// UserSettingAppearanceKey is the key type for user appearance.
UserSettingAppearanceKey UserSettingKey = "appearance"
// UserSettingMemoVisibilityKey is the key type for user preference memo default visibility.
UserSettingMemoVisibilityKey UserSettingKey = "memo-visibility"
// UserSettingTelegramUserID is the key type for telegram UserID of memos user.
UserSettingTelegramUserIDKey UserSettingKey = "telegram-user-id"
)
// String returns the string format of UserSettingKey type.
func (key UserSettingKey) String() string {
switch key {
case UserSettingLocaleKey:
return "locale"
case UserSettingAppearanceKey:
return "appearance"
case UserSettingMemoVisibilityKey:
return "memo-visibility"
case UserSettingTelegramUserIDKey:
return "telegram-user-id"
}
return ""
}
var (
UserSettingLocaleValue = []string{
"de",
"en",
"es",
"fr",
"hr",
"it",
"ja",
"ko",
"nl",
"pl",
"pt-BR",
"ru",
"sl",
"sv",
"tr",
"uk",
"vi",
"zh-Hans",
"zh-Hant",
}
UserSettingAppearanceValue = []string{"system", "light", "dark"}
UserSettingMemoVisibilityValue = []Visibility{Private, Protected, Public}
)
type UserSetting struct {
UserID int
Key UserSettingKey `json:"key"`
// Value is a JSON string with basic value
Value string `json:"value"`
}
type UserSettingUpsert struct {
UserID int `json:"-"`
Key UserSettingKey `json:"key"`
Value string `json:"value"`
}
func (upsert UserSettingUpsert) Validate() error {
if upsert.Key == UserSettingLocaleKey {
localeValue := "en"
err := json.Unmarshal([]byte(upsert.Value), &localeValue)
if err != nil {
return fmt.Errorf("failed to unmarshal user setting locale value")
}
if !slices.Contains(UserSettingLocaleValue, localeValue) {
return fmt.Errorf("invalid user setting locale value")
}
} else if upsert.Key == UserSettingAppearanceKey {
appearanceValue := "system"
err := json.Unmarshal([]byte(upsert.Value), &appearanceValue)
if err != nil {
return fmt.Errorf("failed to unmarshal user setting appearance value")
}
if !slices.Contains(UserSettingAppearanceValue, appearanceValue) {
return fmt.Errorf("invalid user setting appearance value")
}
} else if upsert.Key == UserSettingMemoVisibilityKey {
memoVisibilityValue := Private
err := json.Unmarshal([]byte(upsert.Value), &memoVisibilityValue)
if err != nil {
return fmt.Errorf("failed to unmarshal user setting memo visibility value")
}
if !slices.Contains(UserSettingMemoVisibilityValue, memoVisibilityValue) {
return fmt.Errorf("invalid user setting memo visibility value")
}
} else if upsert.Key == UserSettingTelegramUserIDKey {
var s string
err := json.Unmarshal([]byte(upsert.Value), &s)
if err != nil {
return fmt.Errorf("invalid user setting telegram user id value")
}
if s == "" {
return nil
}
if _, err := strconv.Atoi(s); err != nil {
return fmt.Errorf("invalid user setting telegram user id value")
}
} else {
return fmt.Errorf("invalid user setting key")
}
return nil
}
type UserSettingFind struct {
UserID *int
Key UserSettingKey `json:"key"`
}
type UserSettingDelete struct {
UserID int
}

View File

@ -59,6 +59,10 @@ const (
ActivityServerStart ActivityType = "server.start" ActivityServerStart ActivityType = "server.start"
) )
func (t ActivityType) String() string {
return string(t)
}
// ActivityLevel is the level of activities. // ActivityLevel is the level of activities.
type ActivityLevel string type ActivityLevel string
@ -71,6 +75,10 @@ const (
ActivityError ActivityLevel = "ERROR" ActivityError ActivityLevel = "ERROR"
) )
func (l ActivityLevel) String() string {
return string(l)
}
type ActivityUserCreatePayload struct { type ActivityUserCreatePayload struct {
UserID int `json:"userId"` UserID int `json:"userId"`
Username string `json:"username"` Username string `json:"username"`

View File

@ -85,7 +85,7 @@ func (s *APIV1Service) registerAuthRoutes(g *echo.Group) {
} }
var userInfo *idp.IdentityProviderUserInfo var userInfo *idp.IdentityProviderUserInfo
if identityProvider.Type == store.IdentityProviderOAuth2 { if identityProvider.Type == store.IdentityProviderOAuth2Type {
oauth2IdentityProvider, err := oauth2.NewIdentityProvider(identityProvider.Config.OAuth2Config) oauth2IdentityProvider, err := oauth2.NewIdentityProvider(identityProvider.Config.OAuth2Config)
if err != nil { if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create identity provider instance").SetInternal(err) return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create identity provider instance").SetInternal(err)
@ -121,7 +121,7 @@ func (s *APIV1Service) registerAuthRoutes(g *echo.Group) {
userCreate := &store.User{ userCreate := &store.User{
Username: userInfo.Identifier, Username: userInfo.Identifier,
// The new signup user should be normal user by default. // The new signup user should be normal user by default.
Role: store.NormalUser, Role: store.RoleUser,
Nickname: userInfo.DisplayName, Nickname: userInfo.DisplayName,
Email: userInfo.Email, Email: userInfo.Email,
OpenID: common.GenUUID(), OpenID: common.GenUUID(),
@ -135,7 +135,7 @@ func (s *APIV1Service) registerAuthRoutes(g *echo.Group) {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to generate password hash").SetInternal(err) return echo.NewHTTPError(http.StatusInternalServerError, "Failed to generate password hash").SetInternal(err)
} }
userCreate.PasswordHash = string(passwordHash) userCreate.PasswordHash = string(passwordHash)
user, err = s.Store.CreateUserV1(ctx, userCreate) user, err = s.Store.CreateUser(ctx, userCreate)
if err != nil { if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create user").SetInternal(err) return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create user").SetInternal(err)
} }
@ -160,7 +160,7 @@ func (s *APIV1Service) registerAuthRoutes(g *echo.Group) {
return echo.NewHTTPError(http.StatusBadRequest, "Malformatted signup request").SetInternal(err) return echo.NewHTTPError(http.StatusBadRequest, "Malformatted signup request").SetInternal(err)
} }
hostUserType := store.Host hostUserType := store.RoleHost
existedHostUsers, err := s.Store.ListUsers(ctx, &store.FindUser{ existedHostUsers, err := s.Store.ListUsers(ctx, &store.FindUser{
Role: &hostUserType, Role: &hostUserType,
}) })
@ -171,13 +171,13 @@ func (s *APIV1Service) registerAuthRoutes(g *echo.Group) {
userCreate := &store.User{ userCreate := &store.User{
Username: signup.Username, Username: signup.Username,
// The new signup user should be normal user by default. // The new signup user should be normal user by default.
Role: store.NormalUser, Role: store.RoleUser,
Nickname: signup.Username, Nickname: signup.Username,
OpenID: common.GenUUID(), OpenID: common.GenUUID(),
} }
if len(existedHostUsers) == 0 { if len(existedHostUsers) == 0 {
// Change the default role to host if there is no host user. // Change the default role to host if there is no host user.
userCreate.Role = store.Host userCreate.Role = store.RoleHost
} else { } else {
allowSignUpSetting, err := s.Store.GetSystemSetting(ctx, &store.FindSystemSetting{ allowSignUpSetting, err := s.Store.GetSystemSetting(ctx, &store.FindSystemSetting{
Name: SystemSettingAllowSignUpName.String(), Name: SystemSettingAllowSignUpName.String(),
@ -204,7 +204,7 @@ func (s *APIV1Service) registerAuthRoutes(g *echo.Group) {
} }
userCreate.PasswordHash = string(passwordHash) userCreate.PasswordHash = string(passwordHash)
user, err := s.Store.CreateUserV1(ctx, userCreate) user, err := s.Store.CreateUser(ctx, userCreate)
if err != nil { if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create user").SetInternal(err) return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create user").SetInternal(err)
} }
@ -234,7 +234,7 @@ func (s *APIV1Service) createAuthSignInActivity(c echo.Context, user *store.User
if err != nil { if err != nil {
return errors.Wrap(err, "failed to marshal activity payload") return errors.Wrap(err, "failed to marshal activity payload")
} }
activity, err := s.Store.CreateActivityV1(ctx, &store.ActivityMessage{ activity, err := s.Store.CreateActivity(ctx, &store.ActivityMessage{
CreatorID: user.ID, CreatorID: user.ID,
Type: string(ActivityUserAuthSignIn), Type: string(ActivityUserAuthSignIn),
Level: string(ActivityInfo), Level: string(ActivityInfo),
@ -256,7 +256,7 @@ func (s *APIV1Service) createAuthSignUpActivity(c echo.Context, user *store.User
if err != nil { if err != nil {
return errors.Wrap(err, "failed to marshal activity payload") return errors.Wrap(err, "failed to marshal activity payload")
} }
activity, err := s.Store.CreateActivityV1(ctx, &store.ActivityMessage{ activity, err := s.Store.CreateActivity(ctx, &store.ActivityMessage{
CreatorID: user.ID, CreatorID: user.ID,
Type: string(ActivityUserAuthSignUp), Type: string(ActivityUserAuthSignUp),
Level: string(ActivityInfo), Level: string(ActivityInfo),

View File

@ -13,12 +13,6 @@ const (
Archived RowStatus = "ARCHIVED" Archived RowStatus = "ARCHIVED"
) )
func (e RowStatus) String() string { func (r RowStatus) String() string {
switch e { return string(r)
case Normal:
return "NORMAL"
case Archived:
return "ARCHIVED"
}
return ""
} }

View File

@ -14,9 +14,13 @@ import (
type IdentityProviderType string type IdentityProviderType string
const ( const (
IdentityProviderOAuth2 IdentityProviderType = "OAUTH2" IdentityProviderOAuth2Type IdentityProviderType = "OAUTH2"
) )
func (t IdentityProviderType) String() string {
return string(t)
}
type IdentityProviderConfig struct { type IdentityProviderConfig struct {
OAuth2Config *IdentityProviderOAuth2Config `json:"oauth2Config"` OAuth2Config *IdentityProviderOAuth2Config `json:"oauth2Config"`
} }
@ -53,7 +57,7 @@ type CreateIdentityProviderRequest struct {
} }
type UpdateIdentityProviderRequest struct { type UpdateIdentityProviderRequest struct {
ID int ID int `json:"-"`
Type IdentityProviderType `json:"type"` Type IdentityProviderType `json:"type"`
Name *string `json:"name"` Name *string `json:"name"`
IdentifierFilter *string `json:"identifierFilter"` IdentifierFilter *string `json:"identifierFilter"`
@ -74,7 +78,7 @@ func (s *APIV1Service) registerIdentityProviderRoutes(g *echo.Group) {
if err != nil { if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err) return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err)
} }
if user == nil || user.Role != store.Host { if user == nil || user.Role != store.RoleHost {
return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized") return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized")
} }
@ -108,7 +112,7 @@ func (s *APIV1Service) registerIdentityProviderRoutes(g *echo.Group) {
if err != nil { if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err) return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err)
} }
if user == nil || user.Role != store.Host { if user == nil || user.Role != store.RoleHost {
return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized") return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized")
} }
@ -153,7 +157,7 @@ func (s *APIV1Service) registerIdentityProviderRoutes(g *echo.Group) {
if err != nil { if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err) return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err)
} }
if user == nil || user.Role == store.Host { if user == nil || user.Role == store.RoleHost {
isHostUser = true isHostUser = true
} }
} }
@ -183,7 +187,7 @@ func (s *APIV1Service) registerIdentityProviderRoutes(g *echo.Group) {
if err != nil { if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err) return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err)
} }
if user == nil || user.Role != store.Host { if user == nil || user.Role != store.RoleHost {
return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized") return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized")
} }
@ -217,7 +221,7 @@ func (s *APIV1Service) registerIdentityProviderRoutes(g *echo.Group) {
if err != nil { if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err) return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err)
} }
if user == nil || user.Role != store.Host { if user == nil || user.Role != store.RoleHost {
return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized") return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized")
} }

View File

@ -82,7 +82,7 @@ func JWTMiddleware(server *APIV1Service, next echo.HandlerFunc, secret string) e
} }
// Skip validation for server status endpoints. // Skip validation for server status endpoints.
if common.HasPrefixes(path, "/api/ping", "/api/v1/idp", "/api/user/:id") && method == http.MethodGet { if common.HasPrefixes(path, "/api/v1/ping", "/api/v1/idp", "/api/user/:id") && method == http.MethodGet {
return next(c) return next(c)
} }
@ -93,7 +93,7 @@ func JWTMiddleware(server *APIV1Service, next echo.HandlerFunc, secret string) e
return next(c) return next(c)
} }
// When the request is not authenticated, we allow the user to access the memo endpoints for those public memos. // When the request is not authenticated, we allow the user to access the memo endpoints for those public memos.
if common.HasPrefixes(path, "/api/status", "/api/memo") && method == http.MethodGet { if common.HasPrefixes(path, "/api/v1/status", "/api/memo") && method == http.MethodGet {
return next(c) return next(c)
} }
return echo.NewHTTPError(http.StatusUnauthorized, "Missing access token") return echo.NewHTTPError(http.StatusUnauthorized, "Missing access token")

View File

@ -13,13 +13,5 @@ const (
) )
func (v Visibility) String() string { func (v Visibility) String() string {
switch v { return string(v)
case Public:
return "PUBLIC"
case Protected:
return "PROTECTED"
case Private:
return "PRIVATE"
}
return "PRIVATE"
} }

View File

@ -1,4 +1,4 @@
package server package v1
import ( import (
"encoding/json" "encoding/json"
@ -7,34 +7,79 @@ import (
"strconv" "strconv"
"time" "time"
"github.com/pkg/errors"
"github.com/usememos/memos/api"
"github.com/usememos/memos/common"
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
"github.com/pkg/errors"
"github.com/usememos/memos/store"
) )
func (s *Server) registerShortcutRoutes(g *echo.Group) { type Shortcut struct {
ID int `json:"id"`
// Standard fields
RowStatus RowStatus `json:"rowStatus"`
CreatorID int `json:"creatorId"`
CreatedTs int64 `json:"createdTs"`
UpdatedTs int64 `json:"updatedTs"`
// Domain specific fields
Title string `json:"title"`
Payload string `json:"payload"`
}
type CreateShortcutRequest struct {
Title string `json:"title"`
Payload string `json:"payload"`
}
type UpdateShortcutRequest struct {
RowStatus *RowStatus `json:"rowStatus"`
Title *string `json:"title"`
Payload *string `json:"payload"`
}
type ShortcutFind struct {
ID *int
// Standard fields
CreatorID *int
// Domain specific fields
Title *string `json:"title"`
}
type ShortcutDelete struct {
ID *int
// Standard fields
CreatorID *int
}
func (s *APIV1Service) registerShortcutRoutes(g *echo.Group) {
g.POST("/shortcut", func(c echo.Context) error { g.POST("/shortcut", func(c echo.Context) error {
ctx := c.Request().Context() ctx := c.Request().Context()
userID, ok := c.Get(getUserIDContextKey()).(int) userID, ok := c.Get(getUserIDContextKey()).(int)
if !ok { if !ok {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
} }
shortcutCreate := &api.ShortcutCreate{} shortcutCreate := &CreateShortcutRequest{}
if err := json.NewDecoder(c.Request().Body).Decode(shortcutCreate); err != nil { if err := json.NewDecoder(c.Request().Body).Decode(shortcutCreate); err != nil {
return echo.NewHTTPError(http.StatusBadRequest, "Malformatted post shortcut request").SetInternal(err) return echo.NewHTTPError(http.StatusBadRequest, "Malformatted post shortcut request").SetInternal(err)
} }
shortcutCreate.CreatorID = userID shortcut, err := s.Store.CreateShortcut(ctx, &store.Shortcut{
shortcut, err := s.Store.CreateShortcut(ctx, shortcutCreate) CreatorID: userID,
Title: shortcutCreate.Title,
Payload: shortcutCreate.Payload,
})
if err != nil { if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create shortcut").SetInternal(err) return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create shortcut").SetInternal(err)
} }
if err := s.createShortcutCreateActivity(c, shortcut); err != nil {
shortcutMessage := convertShortcutFromStore(shortcut)
if err := s.createShortcutCreateActivity(c, shortcutMessage); err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create activity").SetInternal(err) return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create activity").SetInternal(err)
} }
return c.JSON(http.StatusOK, composeResponse(shortcut)) return c.JSON(http.StatusOK, shortcutMessage)
}) })
g.PATCH("/shortcut/:shortcutId", func(c echo.Context) error { g.PATCH("/shortcut/:shortcutId", func(c echo.Context) error {
@ -48,10 +93,9 @@ func (s *Server) registerShortcutRoutes(g *echo.Group) {
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("shortcutId"))).SetInternal(err) return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("shortcutId"))).SetInternal(err)
} }
shortcutFind := &api.ShortcutFind{ shortcut, err := s.Store.GetShortcut(ctx, &store.FindShortcut{
ID: &shortcutID, ID: &shortcutID,
} })
shortcut, err := s.Store.FindShortcut(ctx, shortcutFind)
if err != nil { if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find shortcut").SetInternal(err) return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find shortcut").SetInternal(err)
} }
@ -59,20 +103,32 @@ func (s *Server) registerShortcutRoutes(g *echo.Group) {
return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized") return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized")
} }
currentTs := time.Now().Unix() request := &UpdateShortcutRequest{}
shortcutPatch := &api.ShortcutPatch{ if err := json.NewDecoder(c.Request().Body).Decode(request); err != nil {
UpdatedTs: &currentTs,
}
if err := json.NewDecoder(c.Request().Body).Decode(shortcutPatch); err != nil {
return echo.NewHTTPError(http.StatusBadRequest, "Malformatted patch shortcut request").SetInternal(err) return echo.NewHTTPError(http.StatusBadRequest, "Malformatted patch shortcut request").SetInternal(err)
} }
shortcutPatch.ID = shortcutID currentTs := time.Now().Unix()
shortcut, err = s.Store.PatchShortcut(ctx, shortcutPatch) shortcutUpdate := &store.UpdateShortcut{
ID: shortcutID,
UpdatedTs: &currentTs,
}
if request.RowStatus != nil {
rowStatus := store.RowStatus(*request.RowStatus)
shortcutUpdate.RowStatus = &rowStatus
}
if request.Title != nil {
shortcutUpdate.Title = request.Title
}
if request.Payload != nil {
shortcutUpdate.Payload = request.Payload
}
shortcut, err = s.Store.UpdateShortcut(ctx, shortcutUpdate)
if err != nil { if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to patch shortcut").SetInternal(err) return echo.NewHTTPError(http.StatusInternalServerError, "Failed to patch shortcut").SetInternal(err)
} }
return c.JSON(http.StatusOK, composeResponse(shortcut)) return c.JSON(http.StatusOK, convertShortcutFromStore(shortcut))
}) })
g.GET("/shortcut", func(c echo.Context) error { g.GET("/shortcut", func(c echo.Context) error {
@ -82,14 +138,17 @@ func (s *Server) registerShortcutRoutes(g *echo.Group) {
return echo.NewHTTPError(http.StatusBadRequest, "Missing user id to find shortcut") return echo.NewHTTPError(http.StatusBadRequest, "Missing user id to find shortcut")
} }
shortcutFind := &api.ShortcutFind{ list, err := s.Store.ListShortcuts(ctx, &store.FindShortcut{
CreatorID: &userID, CreatorID: &userID,
} })
list, err := s.Store.FindShortcutList(ctx, shortcutFind)
if err != nil { if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to fetch shortcut list").SetInternal(err) return echo.NewHTTPError(http.StatusInternalServerError, "Failed to get shortcut list").SetInternal(err)
} }
return c.JSON(http.StatusOK, composeResponse(list)) shortcutMessageList := make([]*Shortcut, 0, len(list))
for _, shortcut := range list {
shortcutMessageList = append(shortcutMessageList, convertShortcutFromStore(shortcut))
}
return c.JSON(http.StatusOK, shortcutMessageList)
}) })
g.GET("/shortcut/:shortcutId", func(c echo.Context) error { g.GET("/shortcut/:shortcutId", func(c echo.Context) error {
@ -99,14 +158,16 @@ func (s *Server) registerShortcutRoutes(g *echo.Group) {
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("shortcutId"))).SetInternal(err) return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("shortcutId"))).SetInternal(err)
} }
shortcutFind := &api.ShortcutFind{ shortcut, err := s.Store.GetShortcut(ctx, &store.FindShortcut{
ID: &shortcutID, ID: &shortcutID,
} })
shortcut, err := s.Store.FindShortcut(ctx, shortcutFind)
if err != nil { if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Failed to fetch shortcut by ID %d", *shortcutFind.ID)).SetInternal(err) return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Failed to fetch shortcut by ID %d", shortcutID)).SetInternal(err)
} }
return c.JSON(http.StatusOK, composeResponse(shortcut)) if shortcut == nil {
return echo.NewHTTPError(http.StatusNotFound, fmt.Sprintf("Shortcut by ID %d not found", shortcutID))
}
return c.JSON(http.StatusOK, convertShortcutFromStore(shortcut))
}) })
g.DELETE("/shortcut/:shortcutId", func(c echo.Context) error { g.DELETE("/shortcut/:shortcutId", func(c echo.Context) error {
@ -120,10 +181,9 @@ func (s *Server) registerShortcutRoutes(g *echo.Group) {
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("shortcutId"))).SetInternal(err) return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("shortcutId"))).SetInternal(err)
} }
shortcutFind := &api.ShortcutFind{ shortcut, err := s.Store.GetShortcut(ctx, &store.FindShortcut{
ID: &shortcutID, ID: &shortcutID,
} })
shortcut, err := s.Store.FindShortcut(ctx, shortcutFind)
if err != nil { if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find shortcut").SetInternal(err) return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find shortcut").SetInternal(err)
} }
@ -131,22 +191,18 @@ func (s *Server) registerShortcutRoutes(g *echo.Group) {
return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized") return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized")
} }
shortcutDelete := &api.ShortcutDelete{ if err := s.Store.DeleteShortcut(ctx, &store.DeleteShortcut{
ID: &shortcutID, ID: &shortcutID,
} }); err != nil {
if err := s.Store.DeleteShortcut(ctx, shortcutDelete); err != nil {
if common.ErrorCode(err) == common.NotFound {
return echo.NewHTTPError(http.StatusNotFound, fmt.Sprintf("Shortcut ID not found: %d", shortcutID))
}
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to delete shortcut").SetInternal(err) return echo.NewHTTPError(http.StatusInternalServerError, "Failed to delete shortcut").SetInternal(err)
} }
return c.JSON(http.StatusOK, true) return c.JSON(http.StatusOK, true)
}) })
} }
func (s *Server) createShortcutCreateActivity(c echo.Context, shortcut *api.Shortcut) error { func (s *APIV1Service) createShortcutCreateActivity(c echo.Context, shortcut *Shortcut) error {
ctx := c.Request().Context() ctx := c.Request().Context()
payload := api.ActivityShortcutCreatePayload{ payload := ActivityShortcutCreatePayload{
Title: shortcut.Title, Title: shortcut.Title,
Payload: shortcut.Payload, Payload: shortcut.Payload,
} }
@ -154,10 +210,10 @@ func (s *Server) createShortcutCreateActivity(c echo.Context, shortcut *api.Shor
if err != nil { if err != nil {
return errors.Wrap(err, "failed to marshal activity payload") return errors.Wrap(err, "failed to marshal activity payload")
} }
activity, err := s.Store.CreateActivity(ctx, &api.ActivityCreate{ activity, err := s.Store.CreateActivity(ctx, &store.ActivityMessage{
CreatorID: shortcut.CreatorID, CreatorID: shortcut.CreatorID,
Type: api.ActivityShortcutCreate, Type: ActivityShortcutCreate.String(),
Level: api.ActivityInfo, Level: ActivityInfo.String(),
Payload: string(payloadBytes), Payload: string(payloadBytes),
}) })
if err != nil || activity == nil { if err != nil || activity == nil {
@ -165,3 +221,15 @@ func (s *Server) createShortcutCreateActivity(c echo.Context, shortcut *api.Shor
} }
return err return err
} }
func convertShortcutFromStore(shortcut *store.Shortcut) *Shortcut {
return &Shortcut{
ID: shortcut.ID,
RowStatus: RowStatus(shortcut.RowStatus),
CreatorID: shortcut.CreatorID,
Title: shortcut.Title,
Payload: shortcut.Payload,
CreatedTs: shortcut.CreatedTs,
UpdatedTs: shortcut.UpdatedTs,
}
}

8
api/v1/storage.go Normal file
View File

@ -0,0 +1,8 @@
package v1
const (
// LocalStorage means the storage service is local file system.
LocalStorage = -1
// DatabaseStorage means the storage service is database.
DatabaseStorage = 0
)

169
api/v1/system.go Normal file
View File

@ -0,0 +1,169 @@
package v1
import (
"encoding/json"
"net/http"
"os"
"github.com/labstack/echo/v4"
"github.com/usememos/memos/common/log"
"github.com/usememos/memos/server/profile"
"github.com/usememos/memos/store"
"go.uber.org/zap"
)
type SystemStatus struct {
Host *User `json:"host"`
Profile profile.Profile `json:"profile"`
DBSize int64 `json:"dbSize"`
// System settings
// Allow sign up.
AllowSignUp bool `json:"allowSignUp"`
// Disable public memos.
DisablePublicMemos bool `json:"disablePublicMemos"`
// Max upload size.
MaxUploadSizeMiB int `json:"maxUploadSizeMiB"`
// Additional style.
AdditionalStyle string `json:"additionalStyle"`
// Additional script.
AdditionalScript string `json:"additionalScript"`
// Customized server profile, including server name and external url.
CustomizedProfile CustomizedProfile `json:"customizedProfile"`
// Storage service ID.
StorageServiceID int `json:"storageServiceId"`
// Local storage path.
LocalStoragePath string `json:"localStoragePath"`
// Memo display with updated timestamp.
MemoDisplayWithUpdatedTs bool `json:"memoDisplayWithUpdatedTs"`
}
func (s *APIV1Service) registerSystemRoutes(g *echo.Group) {
g.GET("/ping", func(c echo.Context) error {
return c.JSON(http.StatusOK, s.Profile)
})
g.GET("/status", func(c echo.Context) error {
ctx := c.Request().Context()
systemStatus := SystemStatus{
Profile: *s.Profile,
DBSize: 0,
AllowSignUp: false,
DisablePublicMemos: false,
MaxUploadSizeMiB: 32,
AdditionalStyle: "",
AdditionalScript: "",
CustomizedProfile: CustomizedProfile{
Name: "memos",
LogoURL: "",
Description: "",
Locale: "en",
Appearance: "system",
ExternalURL: "",
},
StorageServiceID: DatabaseStorage,
LocalStoragePath: "assets/{timestamp}_{filename}",
MemoDisplayWithUpdatedTs: false,
}
hostUserType := store.RoleHost
hostUser, err := s.Store.GetUser(ctx, &store.FindUser{
Role: &hostUserType,
})
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find host user").SetInternal(err)
}
if hostUser != nil {
// data desensitize
hostUser.OpenID = ""
hostUser.Email = ""
systemStatus.Host = converUserFromStore(hostUser)
}
systemSettingList, err := s.Store.ListSystemSettings(ctx, &store.FindSystemSetting{})
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find system setting list").SetInternal(err)
}
for _, systemSetting := range systemSettingList {
if systemSetting.Name == SystemSettingServerIDName.String() || systemSetting.Name == SystemSettingSecretSessionName.String() || systemSetting.Name == SystemSettingTelegramBotTokenName.String() {
continue
}
var baseValue any
err := json.Unmarshal([]byte(systemSetting.Value), &baseValue)
if err != nil {
log.Warn("Failed to unmarshal system setting value", zap.String("setting name", systemSetting.Name))
continue
}
switch systemSetting.Name {
case SystemSettingAllowSignUpName.String():
systemStatus.AllowSignUp = baseValue.(bool)
case SystemSettingDisablePublicMemosName.String():
systemStatus.DisablePublicMemos = baseValue.(bool)
case SystemSettingMaxUploadSizeMiBName.String():
systemStatus.MaxUploadSizeMiB = int(baseValue.(float64))
case SystemSettingAdditionalStyleName.String():
systemStatus.AdditionalStyle = baseValue.(string)
case SystemSettingAdditionalScriptName.String():
systemStatus.AdditionalScript = baseValue.(string)
case SystemSettingCustomizedProfileName.String():
customizedProfile := CustomizedProfile{}
if err := json.Unmarshal([]byte(systemSetting.Value), &customizedProfile); err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to unmarshal system setting customized profile value").SetInternal(err)
}
systemStatus.CustomizedProfile = customizedProfile
case SystemSettingStorageServiceIDName.String():
systemStatus.StorageServiceID = int(baseValue.(float64))
case SystemSettingLocalStoragePathName.String():
systemStatus.LocalStoragePath = baseValue.(string)
case SystemSettingMemoDisplayWithUpdatedTsName.String():
systemStatus.MemoDisplayWithUpdatedTs = baseValue.(bool)
default:
log.Warn("Unknown system setting name", zap.String("setting name", systemSetting.Name))
}
}
userID, ok := c.Get(getUserIDContextKey()).(int)
// Get database size for host user.
if ok {
user, err := s.Store.GetUser(ctx, &store.FindUser{
ID: &userID,
})
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err)
}
if user != nil && user.Role == store.RoleHost {
fi, err := os.Stat(s.Profile.DSN)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to read database fileinfo").SetInternal(err)
}
systemStatus.DBSize = fi.Size()
}
}
return c.JSON(http.StatusOK, systemStatus)
})
g.POST("/system/vacuum", func(c echo.Context) error {
ctx := c.Request().Context()
userID, ok := c.Get(getUserIDContextKey()).(int)
if !ok {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
}
user, err := s.Store.GetUser(ctx, &store.FindUser{
ID: &userID,
})
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err)
}
if user == nil || user.Role != store.RoleHost {
return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized")
}
if err := s.Store.Vacuum(ctx); err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to vacuum database").SetInternal(err)
}
return c.JSON(http.StatusOK, true)
})
}

View File

@ -3,7 +3,11 @@ package v1
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/http"
"strings" "strings"
"github.com/labstack/echo/v4"
"github.com/usememos/memos/store"
) )
type SystemSettingName string type SystemSettingName string
@ -29,10 +33,9 @@ const (
SystemSettingStorageServiceIDName SystemSettingName = "storage-service-id" SystemSettingStorageServiceIDName SystemSettingName = "storage-service-id"
// SystemSettingLocalStoragePathName is the name of local storage path. // SystemSettingLocalStoragePathName is the name of local storage path.
SystemSettingLocalStoragePathName SystemSettingName = "local-storage-path" SystemSettingLocalStoragePathName SystemSettingName = "local-storage-path"
// SystemSettingOpenAIConfigName is the name of OpenAI config.
SystemSettingOpenAIConfigName SystemSettingName = "openai-config"
// SystemSettingTelegramBotToken is the name of Telegram Bot Token. // SystemSettingTelegramBotToken is the name of Telegram Bot Token.
SystemSettingTelegramBotTokenName SystemSettingName = "telegram-bot-token" SystemSettingTelegramBotTokenName SystemSettingName = "telegram-bot-token"
// SystemSettingMemoDisplayWithUpdatedTsName is the name of memo display with updated ts.
SystemSettingMemoDisplayWithUpdatedTsName SystemSettingName = "memo-display-with-updated-ts" SystemSettingMemoDisplayWithUpdatedTsName SystemSettingName = "memo-display-with-updated-ts"
) )
@ -52,41 +55,8 @@ type CustomizedProfile struct {
ExternalURL string `json:"externalUrl"` ExternalURL string `json:"externalUrl"`
} }
type OpenAIConfig struct {
Key string `json:"key"`
Host string `json:"host"`
}
func (key SystemSettingName) String() string { func (key SystemSettingName) String() string {
switch key { return string(key)
case SystemSettingServerIDName:
return "server-id"
case SystemSettingSecretSessionName:
return "secret-session"
case SystemSettingAllowSignUpName:
return "allow-signup"
case SystemSettingDisablePublicMemosName:
return "disable-public-memos"
case SystemSettingMaxUploadSizeMiBName:
return "max-upload-size-mib"
case SystemSettingAdditionalStyleName:
return "additional-style"
case SystemSettingAdditionalScriptName:
return "additional-script"
case SystemSettingCustomizedProfileName:
return "customized-profile"
case SystemSettingStorageServiceIDName:
return "storage-service-id"
case SystemSettingLocalStoragePathName:
return "local-storage-path"
case SystemSettingOpenAIConfigName:
return "openai-config"
case SystemSettingTelegramBotTokenName:
return "telegram-bot-token"
case SystemSettingMemoDisplayWithUpdatedTsName:
return "memo-display-with-updated-ts"
}
return ""
} }
type SystemSetting struct { type SystemSetting struct {
@ -96,7 +66,7 @@ type SystemSetting struct {
Description string `json:"description"` Description string `json:"description"`
} }
type SystemSettingUpsert struct { type UpsertSystemSettingRequest struct {
Name SystemSettingName `json:"name"` Name SystemSettingName `json:"name"`
Value string `json:"value"` Value string `json:"value"`
Description string `json:"description"` Description string `json:"description"`
@ -104,7 +74,7 @@ type SystemSettingUpsert struct {
const systemSettingUnmarshalError = `failed to unmarshal value from system setting "%v"` const systemSettingUnmarshalError = `failed to unmarshal value from system setting "%v"`
func (upsert SystemSettingUpsert) Validate() error { func (upsert UpsertSystemSettingRequest) Validate() error {
switch settingName := upsert.Name; settingName { switch settingName := upsert.Name; settingName {
case SystemSettingServerIDName: case SystemSettingServerIDName:
return fmt.Errorf("updating %v is not allowed", settingName) return fmt.Errorf("updating %v is not allowed", settingName)
@ -157,11 +127,6 @@ func (upsert SystemSettingUpsert) Validate() error {
if err := json.Unmarshal([]byte(upsert.Value), &value); err != nil { if err := json.Unmarshal([]byte(upsert.Value), &value); err != nil {
return fmt.Errorf(systemSettingUnmarshalError, settingName) return fmt.Errorf(systemSettingUnmarshalError, settingName)
} }
case SystemSettingOpenAIConfigName:
value := OpenAIConfig{}
if err := json.Unmarshal([]byte(upsert.Value), &value); err != nil {
return fmt.Errorf(systemSettingUnmarshalError, settingName)
}
case SystemSettingTelegramBotTokenName: case SystemSettingTelegramBotTokenName:
if upsert.Value == "" { if upsert.Value == "" {
return nil return nil
@ -189,6 +154,77 @@ func (upsert SystemSettingUpsert) Validate() error {
return nil return nil
} }
type SystemSettingFind struct { func (s *APIV1Service) registerSystemSettingRoutes(g *echo.Group) {
Name SystemSettingName `json:"name"` g.POST("/system/setting", func(c echo.Context) error {
ctx := c.Request().Context()
userID, ok := c.Get(getUserIDContextKey()).(int)
if !ok {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
}
user, err := s.Store.GetUser(ctx, &store.FindUser{
ID: &userID,
})
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err)
}
if user == nil || user.Role != store.RoleHost {
return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized")
}
systemSettingUpsert := &UpsertSystemSettingRequest{}
if err := json.NewDecoder(c.Request().Body).Decode(systemSettingUpsert); err != nil {
return echo.NewHTTPError(http.StatusBadRequest, "Malformatted post system setting request").SetInternal(err)
}
if err := systemSettingUpsert.Validate(); err != nil {
return echo.NewHTTPError(http.StatusBadRequest, "invalid system setting").SetInternal(err)
}
systemSetting, err := s.Store.UpsertSystemSetting(ctx, &store.SystemSetting{
Name: systemSettingUpsert.Name.String(),
Value: systemSettingUpsert.Value,
Description: systemSettingUpsert.Description,
})
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to upsert system setting").SetInternal(err)
}
return c.JSON(http.StatusOK, convertSystemSettingFromStore(systemSetting))
})
g.GET("/system/setting", func(c echo.Context) error {
ctx := c.Request().Context()
userID, ok := c.Get(getUserIDContextKey()).(int)
if !ok {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
}
user, err := s.Store.GetUser(ctx, &store.FindUser{
ID: &userID,
})
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err)
}
if user == nil || user.Role != store.RoleHost {
return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized")
}
list, err := s.Store.ListSystemSettings(ctx, &store.FindSystemSetting{})
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find system setting list").SetInternal(err)
}
systemSettingList := make([]*SystemSetting, 0, len(list))
for _, systemSetting := range list {
systemSettingList = append(systemSettingList, convertSystemSettingFromStore(systemSetting))
}
return c.JSON(http.StatusOK, systemSettingList)
})
}
func convertSystemSettingFromStore(systemSetting *store.SystemSetting) *SystemSetting {
return &SystemSetting{
Name: SystemSettingName(systemSetting.Name),
Value: systemSetting.Value,
Description: systemSetting.Description,
}
} }

View File

@ -1,4 +1,4 @@
package server package v1
import ( import (
"encoding/json" "encoding/json"
@ -7,16 +7,26 @@ import (
"regexp" "regexp"
"sort" "sort"
"github.com/labstack/echo/v4"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/usememos/memos/api"
"github.com/usememos/memos/common"
"github.com/usememos/memos/store" "github.com/usememos/memos/store"
"golang.org/x/exp/slices" "golang.org/x/exp/slices"
"github.com/labstack/echo/v4"
) )
func (s *Server) registerTagRoutes(g *echo.Group) { type Tag struct {
Name string
CreatorID int
}
type UpsertTagRequest struct {
Name string `json:"name"`
}
type DeleteTagRequest struct {
Name string `json:"name"`
}
func (s *APIV1Service) registerTagRoutes(g *echo.Group) {
g.POST("/tag", func(c echo.Context) error { g.POST("/tag", func(c echo.Context) error {
ctx := c.Request().Context() ctx := c.Request().Context()
userID, ok := c.Get(getUserIDContextKey()).(int) userID, ok := c.Get(getUserIDContextKey()).(int)
@ -24,7 +34,7 @@ func (s *Server) registerTagRoutes(g *echo.Group) {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
} }
tagUpsert := &api.TagUpsert{} tagUpsert := &UpsertTagRequest{}
if err := json.NewDecoder(c.Request().Body).Decode(tagUpsert); err != nil { if err := json.NewDecoder(c.Request().Body).Decode(tagUpsert); err != nil {
return echo.NewHTTPError(http.StatusBadRequest, "Malformatted post tag request").SetInternal(err) return echo.NewHTTPError(http.StatusBadRequest, "Malformatted post tag request").SetInternal(err)
} }
@ -32,15 +42,18 @@ func (s *Server) registerTagRoutes(g *echo.Group) {
return echo.NewHTTPError(http.StatusBadRequest, "Tag name shouldn't be empty") return echo.NewHTTPError(http.StatusBadRequest, "Tag name shouldn't be empty")
} }
tagUpsert.CreatorID = userID tag, err := s.Store.UpsertTagV1(ctx, &store.Tag{
tag, err := s.Store.UpsertTag(ctx, tagUpsert) Name: tagUpsert.Name,
CreatorID: userID,
})
if err != nil { if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to upsert tag").SetInternal(err) return echo.NewHTTPError(http.StatusInternalServerError, "Failed to upsert tag").SetInternal(err)
} }
if err := s.createTagCreateActivity(c, tag); err != nil { tagMessage := convertTagFromStore(tag)
if err := s.createTagCreateActivity(c, tagMessage); err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create activity").SetInternal(err) return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create activity").SetInternal(err)
} }
return c.JSON(http.StatusOK, composeResponse(tag.Name)) return c.JSON(http.StatusOK, tagMessage.Name)
}) })
g.GET("/tag", func(c echo.Context) error { g.GET("/tag", func(c echo.Context) error {
@ -50,19 +63,18 @@ func (s *Server) registerTagRoutes(g *echo.Group) {
return echo.NewHTTPError(http.StatusBadRequest, "Missing user id to find tag") return echo.NewHTTPError(http.StatusBadRequest, "Missing user id to find tag")
} }
tagFind := &api.TagFind{ list, err := s.Store.ListTags(ctx, &store.FindTag{
CreatorID: userID, CreatorID: userID,
} })
tagList, err := s.Store.FindTagList(ctx, tagFind)
if err != nil { if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find tag list").SetInternal(err) return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find tag list").SetInternal(err)
} }
tagNameList := []string{} tagNameList := []string{}
for _, tag := range tagList { for _, tag := range list {
tagNameList = append(tagNameList, tag.Name) tagNameList = append(tagNameList, tag.Name)
} }
return c.JSON(http.StatusOK, composeResponse(tagNameList)) return c.JSON(http.StatusOK, tagNameList)
}) })
g.GET("/tag/suggestion", func(c echo.Context) error { g.GET("/tag/suggestion", func(c echo.Context) error {
@ -83,15 +95,14 @@ func (s *Server) registerTagRoutes(g *echo.Group) {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find memo list").SetInternal(err) return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find memo list").SetInternal(err)
} }
tagFind := &api.TagFind{ list, err := s.Store.ListTags(ctx, &store.FindTag{
CreatorID: userID, CreatorID: userID,
} })
existTagList, err := s.Store.FindTagList(ctx, tagFind)
if err != nil { if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find tag list").SetInternal(err) return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find tag list").SetInternal(err)
} }
tagNameList := []string{} tagNameList := []string{}
for _, tag := range existTagList { for _, tag := range list {
tagNameList = append(tagNameList, tag.Name) tagNameList = append(tagNameList, tag.Name)
} }
@ -108,7 +119,7 @@ func (s *Server) registerTagRoutes(g *echo.Group) {
tagList = append(tagList, tag) tagList = append(tagList, tag)
} }
sort.Strings(tagList) sort.Strings(tagList)
return c.JSON(http.StatusOK, composeResponse(tagList)) return c.JSON(http.StatusOK, tagList)
}) })
g.POST("/tag/delete", func(c echo.Context) error { g.POST("/tag/delete", func(c echo.Context) error {
@ -118,7 +129,7 @@ func (s *Server) registerTagRoutes(g *echo.Group) {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
} }
tagDelete := &api.TagDelete{} tagDelete := &DeleteTagRequest{}
if err := json.NewDecoder(c.Request().Body).Decode(tagDelete); err != nil { if err := json.NewDecoder(c.Request().Body).Decode(tagDelete); err != nil {
return echo.NewHTTPError(http.StatusBadRequest, "Malformatted post tag request").SetInternal(err) return echo.NewHTTPError(http.StatusBadRequest, "Malformatted post tag request").SetInternal(err)
} }
@ -126,17 +137,45 @@ func (s *Server) registerTagRoutes(g *echo.Group) {
return echo.NewHTTPError(http.StatusBadRequest, "Tag name shouldn't be empty") return echo.NewHTTPError(http.StatusBadRequest, "Tag name shouldn't be empty")
} }
tagDelete.CreatorID = userID err := s.Store.DeleteTag(ctx, &store.DeleteTag{
if err := s.Store.DeleteTag(ctx, tagDelete); err != nil { Name: tagDelete.Name,
if common.ErrorCode(err) == common.NotFound { CreatorID: userID,
return echo.NewHTTPError(http.StatusNotFound, fmt.Sprintf("Tag name not found: %s", tagDelete.Name)) })
} if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Failed to delete tag name: %v", tagDelete.Name)).SetInternal(err) return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Failed to delete tag name: %v", tagDelete.Name)).SetInternal(err)
} }
return c.JSON(http.StatusOK, true) return c.JSON(http.StatusOK, true)
}) })
} }
func (s *APIV1Service) createTagCreateActivity(c echo.Context, tag *Tag) error {
ctx := c.Request().Context()
payload := ActivityTagCreatePayload{
TagName: tag.Name,
}
payloadBytes, err := json.Marshal(payload)
if err != nil {
return errors.Wrap(err, "failed to marshal activity payload")
}
activity, err := s.Store.CreateActivity(ctx, &store.ActivityMessage{
CreatorID: tag.CreatorID,
Type: ActivityTagCreate.String(),
Level: ActivityInfo.String(),
Payload: string(payloadBytes),
})
if err != nil || activity == nil {
return errors.Wrap(err, "failed to create activity")
}
return err
}
func convertTagFromStore(tag *store.Tag) *Tag {
return &Tag{
Name: tag.Name,
CreatorID: tag.CreatorID,
}
}
var tagRegexp = regexp.MustCompile(`#([^\s#]+)`) var tagRegexp = regexp.MustCompile(`#([^\s#]+)`)
func findTagListFromMemoContent(memoContent string) []string { func findTagListFromMemoContent(memoContent string) []string {
@ -154,24 +193,3 @@ func findTagListFromMemoContent(memoContent string) []string {
sort.Strings(tagList) sort.Strings(tagList)
return tagList return tagList
} }
func (s *Server) createTagCreateActivity(c echo.Context, tag *api.Tag) error {
ctx := c.Request().Context()
payload := api.ActivityTagCreatePayload{
TagName: tag.Name,
}
payloadBytes, err := json.Marshal(payload)
if err != nil {
return errors.Wrap(err, "failed to marshal activity payload")
}
activity, err := s.Store.CreateActivity(ctx, &api.ActivityCreate{
CreatorID: tag.CreatorID,
Type: api.ActivityTagCreate,
Level: api.ActivityInfo,
Payload: string(payloadBytes),
})
if err != nil || activity == nil {
return errors.Wrap(err, "failed to create activity")
}
return err
}

View File

@ -1,4 +1,4 @@
package server package v1
import ( import (
"testing" "testing"

View File

@ -1,9 +0,0 @@
package v1
import "github.com/labstack/echo/v4"
func (*APIV1Service) registerTestRoutes(g *echo.Group) {
g.GET("/test", func(c echo.Context) error {
return c.String(200, "Hello World")
})
}

View File

@ -1,25 +1,406 @@
package v1 package v1
import (
"encoding/json"
"fmt"
"net/http"
"strconv"
"time"
"github.com/labstack/echo/v4"
"github.com/pkg/errors"
"github.com/usememos/memos/common"
"github.com/usememos/memos/store"
"golang.org/x/crypto/bcrypt"
)
// Role is the type of a role. // Role is the type of a role.
type Role string type Role string
const ( const (
// Host is the HOST role. // RoleHost is the HOST role.
Host Role = "HOST" RoleHost Role = "HOST"
// Admin is the ADMIN role. // RoleAdmin is the ADMIN role.
Admin Role = "ADMIN" RoleAdmin Role = "ADMIN"
// NormalUser is the USER role. // RoleUser is the USER role.
NormalUser Role = "USER" RoleUser Role = "USER"
) )
func (e Role) String() string { func (role Role) String() string {
switch e { return string(role)
case Host: }
return "HOST"
case Admin: type User struct {
return "ADMIN" ID int `json:"id"`
case NormalUser:
return "USER" // Standard fields
} RowStatus RowStatus `json:"rowStatus"`
return "USER" CreatedTs int64 `json:"createdTs"`
UpdatedTs int64 `json:"updatedTs"`
// Domain specific fields
Username string `json:"username"`
Role Role `json:"role"`
Email string `json:"email"`
Nickname string `json:"nickname"`
PasswordHash string `json:"-"`
OpenID string `json:"openId"`
AvatarURL string `json:"avatarUrl"`
UserSettingList []*UserSetting `json:"userSettingList"`
}
type CreateUserRequest struct {
Username string `json:"username"`
Role Role `json:"role"`
Email string `json:"email"`
Nickname string `json:"nickname"`
Password string `json:"password"`
}
func (create CreateUserRequest) Validate() error {
if len(create.Username) < 3 {
return fmt.Errorf("username is too short, minimum length is 3")
}
if len(create.Username) > 32 {
return fmt.Errorf("username is too long, maximum length is 32")
}
if len(create.Password) < 3 {
return fmt.Errorf("password is too short, minimum length is 3")
}
if len(create.Password) > 512 {
return fmt.Errorf("password is too long, maximum length is 512")
}
if len(create.Nickname) > 64 {
return fmt.Errorf("nickname is too long, maximum length is 64")
}
if create.Email != "" {
if len(create.Email) > 256 {
return fmt.Errorf("email is too long, maximum length is 256")
}
if !common.ValidateEmail(create.Email) {
return fmt.Errorf("invalid email format")
}
}
return nil
}
type UpdateUserRequest struct {
RowStatus *RowStatus `json:"rowStatus"`
Username *string `json:"username"`
Email *string `json:"email"`
Nickname *string `json:"nickname"`
Password *string `json:"password"`
ResetOpenID *bool `json:"resetOpenId"`
AvatarURL *string `json:"avatarUrl"`
}
func (update UpdateUserRequest) Validate() error {
if update.Username != nil && len(*update.Username) < 3 {
return fmt.Errorf("username is too short, minimum length is 3")
}
if update.Username != nil && len(*update.Username) > 32 {
return fmt.Errorf("username is too long, maximum length is 32")
}
if update.Password != nil && len(*update.Password) < 3 {
return fmt.Errorf("password is too short, minimum length is 3")
}
if update.Password != nil && len(*update.Password) > 512 {
return fmt.Errorf("password is too long, maximum length is 512")
}
if update.Nickname != nil && len(*update.Nickname) > 64 {
return fmt.Errorf("nickname is too long, maximum length is 64")
}
if update.AvatarURL != nil {
if len(*update.AvatarURL) > 2<<20 {
return fmt.Errorf("avatar is too large, maximum is 2MB")
}
}
if update.Email != nil && *update.Email != "" {
if len(*update.Email) > 256 {
return fmt.Errorf("email is too long, maximum length is 256")
}
if !common.ValidateEmail(*update.Email) {
return fmt.Errorf("invalid email format")
}
}
return nil
}
func (s *APIV1Service) registerUserRoutes(g *echo.Group) {
g.POST("/user", func(c echo.Context) error {
ctx := c.Request().Context()
userID, ok := c.Get(getUserIDContextKey()).(int)
if !ok {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing auth session")
}
currentUser, err := s.Store.GetUser(ctx, &store.FindUser{
ID: &userID,
})
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user by id").SetInternal(err)
}
if currentUser.Role != store.RoleHost {
return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized to create user")
}
userCreate := &CreateUserRequest{}
if err := json.NewDecoder(c.Request().Body).Decode(userCreate); err != nil {
return echo.NewHTTPError(http.StatusBadRequest, "Malformatted post user request").SetInternal(err)
}
if err := userCreate.Validate(); err != nil {
return echo.NewHTTPError(http.StatusBadRequest, "Invalid user create format").SetInternal(err)
}
// Disallow host user to be created.
if userCreate.Role == RoleHost {
return echo.NewHTTPError(http.StatusForbidden, "Could not create host user")
}
passwordHash, err := bcrypt.GenerateFromPassword([]byte(userCreate.Password), bcrypt.DefaultCost)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to generate password hash").SetInternal(err)
}
user, err := s.Store.CreateUser(ctx, &store.User{
Username: userCreate.Username,
Role: store.Role(userCreate.Role),
Email: userCreate.Email,
Nickname: userCreate.Nickname,
PasswordHash: string(passwordHash),
OpenID: common.GenUUID(),
})
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create user").SetInternal(err)
}
userMessage := converUserFromStore(user)
if err := s.createUserCreateActivity(c, userMessage); err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create activity").SetInternal(err)
}
return c.JSON(http.StatusOK, userMessage)
})
g.GET("/user", func(c echo.Context) error {
ctx := c.Request().Context()
list, err := s.Store.ListUsers(ctx, &store.FindUser{})
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to fetch user list").SetInternal(err)
}
userMessageList := make([]*User, 0, len(list))
for _, user := range list {
userMessage := converUserFromStore(user)
// data desensitize
userMessage.OpenID = ""
userMessage.Email = ""
userMessageList = append(userMessageList, userMessage)
}
return c.JSON(http.StatusOK, userMessageList)
})
// GET /api/user/me is used to check if the user is logged in.
g.GET("/user/me", func(c echo.Context) error {
ctx := c.Request().Context()
userID, ok := c.Get(getUserIDContextKey()).(int)
if !ok {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing auth session")
}
user, err := s.Store.GetUser(ctx, &store.FindUser{ID: &userID})
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err)
}
list, err := s.Store.ListUserSettings(ctx, &store.FindUserSetting{
UserID: &userID,
})
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find userSettingList").SetInternal(err)
}
userSettingList := []*UserSetting{}
for _, userSetting := range list {
userSettingList = append(userSettingList, convertUserSettingFromStore(userSetting))
}
userMessage := converUserFromStore(user)
userMessage.UserSettingList = userSettingList
return c.JSON(http.StatusOK, userMessage)
})
g.GET("/user/:id", func(c echo.Context) error {
ctx := c.Request().Context()
id, err := strconv.Atoi(c.Param("id"))
if err != nil {
return echo.NewHTTPError(http.StatusBadRequest, "Malformatted user id").SetInternal(err)
}
user, err := s.Store.GetUser(ctx, &store.FindUser{ID: &id})
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err)
}
if user == nil {
return echo.NewHTTPError(http.StatusNotFound, "User not found")
}
userMessage := converUserFromStore(user)
// data desensitize
userMessage.OpenID = ""
userMessage.Email = ""
return c.JSON(http.StatusOK, userMessage)
})
g.PATCH("/user/:id", func(c echo.Context) error {
ctx := c.Request().Context()
userID, err := strconv.Atoi(c.Param("id"))
if err != nil {
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("id"))).SetInternal(err)
}
currentUserID, ok := c.Get(getUserIDContextKey()).(int)
if !ok {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
}
currentUser, err := s.Store.GetUser(ctx, &store.FindUser{ID: &currentUserID})
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err)
}
if currentUser == nil {
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Current session user not found with ID: %d", currentUserID)).SetInternal(err)
} else if currentUser.Role != store.RoleHost && currentUserID != userID {
return echo.NewHTTPError(http.StatusForbidden, "Unauthorized to update user").SetInternal(err)
}
request := &UpdateUserRequest{}
if err := json.NewDecoder(c.Request().Body).Decode(request); err != nil {
return echo.NewHTTPError(http.StatusBadRequest, "Malformatted patch user request").SetInternal(err)
}
if err := request.Validate(); err != nil {
return echo.NewHTTPError(http.StatusBadRequest, "Invalid update user request").SetInternal(err)
}
currentTs := time.Now().Unix()
userUpdate := &store.UpdateUser{
ID: userID,
UpdatedTs: &currentTs,
}
if request.RowStatus != nil {
rowStatus := store.RowStatus(request.RowStatus.String())
userUpdate.RowStatus = &rowStatus
}
if request.Username != nil {
userUpdate.Username = request.Username
}
if request.Email != nil {
userUpdate.Email = request.Email
}
if request.Nickname != nil {
userUpdate.Nickname = request.Nickname
}
if request.Password != nil {
passwordHash, err := bcrypt.GenerateFromPassword([]byte(*request.Password), bcrypt.DefaultCost)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to generate password hash").SetInternal(err)
}
passwordHashStr := string(passwordHash)
userUpdate.PasswordHash = &passwordHashStr
}
if request.ResetOpenID != nil && *request.ResetOpenID {
openID := common.GenUUID()
userUpdate.OpenID = &openID
}
if request.AvatarURL != nil {
userUpdate.AvatarURL = request.AvatarURL
}
user, err := s.Store.UpdateUser(ctx, userUpdate)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to patch user").SetInternal(err)
}
list, err := s.Store.ListUserSettings(ctx, &store.FindUserSetting{
UserID: &userID,
})
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find userSettingList").SetInternal(err)
}
userSettingList := []*UserSetting{}
for _, userSetting := range list {
userSettingList = append(userSettingList, convertUserSettingFromStore(userSetting))
}
userMessage := converUserFromStore(user)
userMessage.UserSettingList = userSettingList
return c.JSON(http.StatusOK, userMessage)
})
g.DELETE("/user/:id", func(c echo.Context) error {
ctx := c.Request().Context()
currentUserID, ok := c.Get(getUserIDContextKey()).(int)
if !ok {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
}
currentUser, err := s.Store.GetUser(ctx, &store.FindUser{
ID: &currentUserID,
})
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err)
}
if currentUser == nil {
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Current session user not found with ID: %d", currentUserID)).SetInternal(err)
} else if currentUser.Role != store.RoleHost {
return echo.NewHTTPError(http.StatusForbidden, "Unauthorized to delete user").SetInternal(err)
}
userID, err := strconv.Atoi(c.Param("id"))
if err != nil {
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("id"))).SetInternal(err)
}
userDelete := &store.DeleteUser{
ID: userID,
}
if err := s.Store.DeleteUser(ctx, userDelete); err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to delete user").SetInternal(err)
}
return c.JSON(http.StatusOK, true)
})
}
func (s *APIV1Service) createUserCreateActivity(c echo.Context, user *User) error {
ctx := c.Request().Context()
payload := ActivityUserCreatePayload{
UserID: user.ID,
Username: user.Username,
Role: user.Role,
}
payloadBytes, err := json.Marshal(payload)
if err != nil {
return errors.Wrap(err, "failed to marshal activity payload")
}
activity, err := s.Store.CreateActivity(ctx, &store.ActivityMessage{
CreatorID: user.ID,
Type: ActivityUserCreate.String(),
Level: ActivityInfo.String(),
Payload: string(payloadBytes),
})
if err != nil || activity == nil {
return errors.Wrap(err, "failed to create activity")
}
return err
}
func converUserFromStore(user *store.User) *User {
return &User{
ID: user.ID,
RowStatus: RowStatus(user.RowStatus),
CreatedTs: user.CreatedTs,
UpdatedTs: user.UpdatedTs,
Username: user.Username,
Role: Role(user.Role),
Email: user.Email,
Nickname: user.Nickname,
PasswordHash: user.PasswordHash,
OpenID: user.OpenID,
AvatarURL: user.AvatarURL,
}
} }

View File

@ -3,8 +3,10 @@ package v1
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"strconv" "net/http"
"github.com/labstack/echo/v4"
"github.com/usememos/memos/store"
"golang.org/x/exp/slices" "golang.org/x/exp/slices"
) )
@ -63,19 +65,18 @@ var (
) )
type UserSetting struct { type UserSetting struct {
UserID int UserID int `json:"userId"`
Key UserSettingKey `json:"key"` Key UserSettingKey `json:"key"`
// Value is a JSON string with basic value Value string `json:"value"`
Value string `json:"value"`
} }
type UserSettingUpsert struct { type UpsertUserSettingRequest struct {
UserID int `json:"-"` UserID int `json:"-"`
Key UserSettingKey `json:"key"` Key UserSettingKey `json:"key"`
Value string `json:"value"` Value string `json:"value"`
} }
func (upsert UserSettingUpsert) Validate() error { func (upsert UpsertUserSettingRequest) Validate() error {
if upsert.Key == UserSettingLocaleKey { if upsert.Key == UserSettingLocaleKey {
localeValue := "en" localeValue := "en"
err := json.Unmarshal([]byte(upsert.Value), &localeValue) err := json.Unmarshal([]byte(upsert.Value), &localeValue)
@ -104,18 +105,11 @@ func (upsert UserSettingUpsert) Validate() error {
return fmt.Errorf("invalid user setting memo visibility value") return fmt.Errorf("invalid user setting memo visibility value")
} }
} else if upsert.Key == UserSettingTelegramUserIDKey { } else if upsert.Key == UserSettingTelegramUserIDKey {
var s string var key string
err := json.Unmarshal([]byte(upsert.Value), &s) err := json.Unmarshal([]byte(upsert.Value), &key)
if err != nil { if err != nil {
return fmt.Errorf("invalid user setting telegram user id value") return fmt.Errorf("invalid user setting telegram user id value")
} }
if s == "" {
return nil
}
if _, err := strconv.Atoi(s); err != nil {
return fmt.Errorf("invalid user setting telegram user id value")
}
} else { } else {
return fmt.Errorf("invalid user setting key") return fmt.Errorf("invalid user setting key")
} }
@ -123,12 +117,41 @@ func (upsert UserSettingUpsert) Validate() error {
return nil return nil
} }
type UserSettingFind struct { func (s *APIV1Service) registerUserSettingRoutes(g *echo.Group) {
UserID *int g.POST("/user/setting", func(c echo.Context) error {
ctx := c.Request().Context()
userID, ok := c.Get(getUserIDContextKey()).(int)
if !ok {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing auth session")
}
Key UserSettingKey `json:"key"` userSettingUpsert := &UpsertUserSettingRequest{}
if err := json.NewDecoder(c.Request().Body).Decode(userSettingUpsert); err != nil {
return echo.NewHTTPError(http.StatusBadRequest, "Malformatted post user setting upsert request").SetInternal(err)
}
if err := userSettingUpsert.Validate(); err != nil {
return echo.NewHTTPError(http.StatusBadRequest, "Invalid user setting format").SetInternal(err)
}
userSettingUpsert.UserID = userID
userSetting, err := s.Store.UpsertUserSetting(ctx, &store.UserSetting{
UserID: userID,
Key: userSettingUpsert.Key.String(),
Value: userSettingUpsert.Value,
})
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to upsert user setting").SetInternal(err)
}
userSettingMessage := convertUserSettingFromStore(userSetting)
return c.JSON(http.StatusOK, userSettingMessage)
})
} }
type UserSettingDelete struct { func convertUserSettingFromStore(userSetting *store.UserSetting) *UserSetting {
UserID int return &UserSetting{
UserID: userSetting.UserID,
Key: UserSettingKey(userSetting.Key),
Value: userSetting.Value,
}
} }

View File

@ -25,7 +25,12 @@ func (s *APIV1Service) Register(rootGroup *echo.Group) {
apiV1Group.Use(func(next echo.HandlerFunc) echo.HandlerFunc { apiV1Group.Use(func(next echo.HandlerFunc) echo.HandlerFunc {
return JWTMiddleware(s, next, s.Secret) return JWTMiddleware(s, next, s.Secret)
}) })
s.registerTestRoutes(apiV1Group) s.registerSystemRoutes(apiV1Group)
s.registerSystemSettingRoutes(apiV1Group)
s.registerAuthRoutes(apiV1Group) s.registerAuthRoutes(apiV1Group)
s.registerIdentityProviderRoutes(apiV1Group) s.registerIdentityProviderRoutes(apiV1Group)
s.registerUserRoutes(apiV1Group)
s.registerUserSettingRoutes(apiV1Group)
s.registerTagRoutes(apiV1Group)
s.registerShortcutRoutes(apiV1Group)
} }

View File

@ -4,8 +4,8 @@ import (
"net/http" "net/http"
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
"github.com/usememos/memos/api"
"github.com/usememos/memos/common" "github.com/usememos/memos/common"
"github.com/usememos/memos/store"
) )
type response struct { type response struct {
@ -39,10 +39,9 @@ func (s *Server) defaultAuthSkipper(c echo.Context) bool {
// If there is openId in query string and related user is found, then skip auth. // If there is openId in query string and related user is found, then skip auth.
openID := c.QueryParam("openId") openID := c.QueryParam("openId")
if openID != "" { if openID != "" {
userFind := &api.UserFind{ user, err := s.Store.GetUser(ctx, &store.FindUser{
OpenID: &openID, OpenID: &openID,
} })
user, err := s.Store.FindUser(ctx, userFind)
if err != nil && common.ErrorCode(err) != common.NotFound { if err != nil && common.ErrorCode(err) != common.NotFound {
return false return false
} }

View File

@ -81,11 +81,6 @@ func JWTMiddleware(server *Server, next echo.HandlerFunc, secret string) echo.Ha
return next(c) return next(c)
} }
// Skip validation for server status endpoints.
if common.HasPrefixes(path, "/api/ping", "/api/v1/idp", "/api/user/:id") && method == http.MethodGet {
return next(c)
}
token := findAccessToken(c) token := findAccessToken(c)
if token == "" { if token == "" {
// Allow the user to access the public endpoints. // Allow the user to access the public endpoints.
@ -93,7 +88,7 @@ func JWTMiddleware(server *Server, next echo.HandlerFunc, secret string) echo.Ha
return next(c) return next(c)
} }
// When the request is not authenticated, we allow the user to access the memo endpoints for those public memos. // When the request is not authenticated, we allow the user to access the memo endpoints for those public memos.
if common.HasPrefixes(path, "/api/status", "/api/memo") && method == http.MethodGet { if common.HasPrefixes(path, "/api/memo") && method == http.MethodGet {
return next(c) return next(c)
} }
return echo.NewHTTPError(http.StatusUnauthorized, "Missing access token") return echo.NewHTTPError(http.StatusUnauthorized, "Missing access token")

View File

@ -60,10 +60,10 @@ func (s *Server) registerMemoRoutes(g *echo.Group) {
} }
// Find disable public memos system setting. // Find disable public memos system setting.
disablePublicMemosSystemSetting, err := s.Store.FindSystemSetting(ctx, &api.SystemSettingFind{ disablePublicMemosSystemSetting, err := s.Store.GetSystemSetting(ctx, &store.FindSystemSetting{
Name: api.SystemSettingDisablePublicMemosName, Name: apiv1.SystemSettingDisablePublicMemosName.String(),
}) })
if err != nil && common.ErrorCode(err) != common.NotFound { if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find system setting").SetInternal(err) return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find system setting").SetInternal(err)
} }
if disablePublicMemosSystemSetting != nil { if disablePublicMemosSystemSetting != nil {
@ -73,14 +73,14 @@ func (s *Server) registerMemoRoutes(g *echo.Group) {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to unmarshal system setting").SetInternal(err) return echo.NewHTTPError(http.StatusInternalServerError, "Failed to unmarshal system setting").SetInternal(err)
} }
if disablePublicMemos { if disablePublicMemos {
user, err := s.Store.FindUser(ctx, &api.UserFind{ user, err := s.Store.GetUser(ctx, &store.FindUser{
ID: &userID, ID: &userID,
}) })
if err != nil { if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err) return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err)
} }
// Enforce normal user to create private memo if public memos are disabled. // Enforce normal user to create private memo if public memos are disabled.
if user.Role == "USER" { if user.Role == store.RoleUser {
createMemoRequest.Visibility = api.Private createMemoRequest.Visibility = api.Private
} }
} }
@ -91,7 +91,7 @@ func (s *Server) registerMemoRoutes(g *echo.Group) {
if err != nil { if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create memo").SetInternal(err) return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create memo").SetInternal(err)
} }
if err := createMemoCreateActivity(c.Request().Context(), s.Store, memoMessage); err != nil { if err := s.createMemoCreateActivity(ctx, memoMessage); err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create activity").SetInternal(err) return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create activity").SetInternal(err)
} }
@ -561,8 +561,8 @@ func (s *Server) registerMemoRoutes(g *echo.Group) {
}) })
} }
func createMemoCreateActivity(ctx context.Context, store *store.Store, memo *store.MemoMessage) error { func (s *Server) createMemoCreateActivity(ctx context.Context, memo *store.MemoMessage) error {
payload := api.ActivityMemoCreatePayload{ payload := apiv1.ActivityMemoCreatePayload{
Content: memo.Content, Content: memo.Content,
Visibility: memo.Visibility.String(), Visibility: memo.Visibility.String(),
} }
@ -570,10 +570,10 @@ func createMemoCreateActivity(ctx context.Context, store *store.Store, memo *sto
if err != nil { if err != nil {
return errors.Wrap(err, "failed to marshal activity payload") return errors.Wrap(err, "failed to marshal activity payload")
} }
activity, err := store.CreateActivity(ctx, &api.ActivityCreate{ activity, err := s.Store.CreateActivity(ctx, &store.ActivityMessage{
CreatorID: memo.CreatorID, CreatorID: memo.CreatorID,
Type: api.ActivityMemoCreate, Type: apiv1.ActivityMemoCreate.String(),
Level: api.ActivityInfo, Level: apiv1.ActivityInfo.String(),
Payload: string(payloadBytes), Payload: string(payloadBytes),
}) })
if err != nil || activity == nil { if err != nil || activity == nil {
@ -654,7 +654,7 @@ func (s *Server) composeMemoMessageToMemoResponse(ctx context.Context, memoMessa
} }
// Compose creator name. // Compose creator name.
user, err := s.Store.FindUser(ctx, &api.UserFind{ user, err := s.Store.GetUser(ctx, &store.FindUser{
ID: &memoResponse.CreatorID, ID: &memoResponse.CreatorID,
}) })
if err != nil { if err != nil {
@ -699,10 +699,10 @@ func (s *Server) composeMemoMessageToMemoResponse(ctx context.Context, memoMessa
} }
func (s *Server) getMemoDisplayWithUpdatedTsSettingValue(ctx context.Context) (bool, error) { func (s *Server) getMemoDisplayWithUpdatedTsSettingValue(ctx context.Context) (bool, error) {
memoDisplayWithUpdatedTsSetting, err := s.Store.FindSystemSetting(ctx, &api.SystemSettingFind{ memoDisplayWithUpdatedTsSetting, err := s.Store.GetSystemSetting(ctx, &store.FindSystemSetting{
Name: api.SystemSettingMemoDisplayWithUpdatedTsName, Name: apiv1.SystemSettingMemoDisplayWithUpdatedTsName.String(),
}) })
if err != nil && common.ErrorCode(err) != common.NotFound { if err != nil {
return false, errors.Wrap(err, "failed to find system setting") return false, errors.Wrap(err, "failed to find system setting")
} }
memoDisplayWithUpdatedTs := false memoDisplayWithUpdatedTs := false

View File

@ -1,49 +0,0 @@
package server
import (
"encoding/json"
"net/http"
"github.com/labstack/echo/v4"
"github.com/usememos/memos/api"
"github.com/usememos/memos/common"
"github.com/usememos/memos/plugin/openai"
)
func (s *Server) registerOpenAIRoutes(g *echo.Group) {
g.POST("/openai/chat-completion", func(c echo.Context) error {
ctx := c.Request().Context()
openAIConfigSetting, err := s.Store.FindSystemSetting(ctx, &api.SystemSettingFind{
Name: api.SystemSettingOpenAIConfigName,
})
if err != nil && common.ErrorCode(err) != common.NotFound {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find openai key").SetInternal(err)
}
openAIConfig := api.OpenAIConfig{}
if openAIConfigSetting != nil {
err = json.Unmarshal([]byte(openAIConfigSetting.Value), &openAIConfig)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to unmarshal openai system setting value").SetInternal(err)
}
}
if openAIConfig.Key == "" {
return echo.NewHTTPError(http.StatusBadRequest, "OpenAI API key not set")
}
messages := []openai.ChatCompletionMessage{}
if err := json.NewDecoder(c.Request().Body).Decode(&messages); err != nil {
return echo.NewHTTPError(http.StatusBadRequest, "Malformatted post chat completion request").SetInternal(err)
}
if len(messages) == 0 {
return echo.NewHTTPError(http.StatusBadRequest, "No messages provided")
}
result, err := openai.PostChatCompletion(messages, openAIConfig.Key, openAIConfig.Host)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to post chat completion").SetInternal(err)
}
return c.JSON(http.StatusOK, composeResponse(result))
})
}

View File

@ -22,6 +22,7 @@ import (
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/usememos/memos/api" "github.com/usememos/memos/api"
apiv1 "github.com/usememos/memos/api/v1"
"github.com/usememos/memos/common" "github.com/usememos/memos/common"
"github.com/usememos/memos/common/log" "github.com/usememos/memos/common/log"
"github.com/usememos/memos/plugin/storage/s3" "github.com/usememos/memos/plugin/storage/s3"
@ -102,7 +103,7 @@ func (s *Server) registerResourceRoutes(g *echo.Group) {
if err != nil { if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create resource").SetInternal(err) return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create resource").SetInternal(err)
} }
if err := createResourceCreateActivity(c.Request().Context(), s.Store, resource); err != nil { if err := s.createResourceCreateActivity(ctx, resource); err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create activity").SetInternal(err) return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create activity").SetInternal(err)
} }
return c.JSON(http.StatusOK, composeResponse(resource)) return c.JSON(http.StatusOK, composeResponse(resource))
@ -116,7 +117,7 @@ func (s *Server) registerResourceRoutes(g *echo.Group) {
} }
// This is the backend default max upload size limit. // This is the backend default max upload size limit.
maxUploadSetting := s.Store.GetSystemSettingValueOrDefault(&ctx, api.SystemSettingMaxUploadSizeMiBName, "32") maxUploadSetting := s.Store.GetSystemSettingValueWithDefault(&ctx, apiv1.SystemSettingMaxUploadSizeMiBName.String(), "32")
var settingMaxUploadSizeBytes int var settingMaxUploadSizeBytes int
if settingMaxUploadSizeMiB, err := strconv.Atoi(maxUploadSetting); err == nil { if settingMaxUploadSizeMiB, err := strconv.Atoi(maxUploadSetting); err == nil {
settingMaxUploadSizeBytes = settingMaxUploadSizeMiB * MebiByte settingMaxUploadSizeBytes = settingMaxUploadSizeMiB * MebiByte
@ -150,8 +151,8 @@ func (s *Server) registerResourceRoutes(g *echo.Group) {
defer sourceFile.Close() defer sourceFile.Close()
var resourceCreate *api.ResourceCreate var resourceCreate *api.ResourceCreate
systemSettingStorageServiceID, err := s.Store.FindSystemSetting(ctx, &api.SystemSettingFind{Name: api.SystemSettingStorageServiceIDName}) systemSettingStorageServiceID, err := s.Store.GetSystemSetting(ctx, &store.FindSystemSetting{Name: apiv1.SystemSettingStorageServiceIDName.String()})
if err != nil && common.ErrorCode(err) != common.NotFound { if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find storage").SetInternal(err) return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find storage").SetInternal(err)
} }
storageServiceID := api.DatabaseStorage storageServiceID := api.DatabaseStorage
@ -179,7 +180,7 @@ func (s *Server) registerResourceRoutes(g *echo.Group) {
// filepath.Join() should be used for local file paths, // filepath.Join() should be used for local file paths,
// as it handles the os-specific path separator automatically. // as it handles the os-specific path separator automatically.
// path.Join() always uses '/' as path separator. // path.Join() always uses '/' as path separator.
systemSettingLocalStoragePath, err := s.Store.FindSystemSetting(ctx, &api.SystemSettingFind{Name: api.SystemSettingLocalStoragePathName}) systemSettingLocalStoragePath, err := s.Store.GetSystemSetting(ctx, &store.FindSystemSetting{Name: apiv1.SystemSettingLocalStoragePathName.String()})
if err != nil && common.ErrorCode(err) != common.NotFound { if err != nil && common.ErrorCode(err) != common.NotFound {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find local storage path setting").SetInternal(err) return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find local storage path setting").SetInternal(err)
} }
@ -265,7 +266,7 @@ func (s *Server) registerResourceRoutes(g *echo.Group) {
if err != nil { if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create resource").SetInternal(err) return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create resource").SetInternal(err)
} }
if err := createResourceCreateActivity(c.Request().Context(), s.Store, resource); err != nil { if err := s.createResourceCreateActivity(ctx, resource); err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create activity").SetInternal(err) return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create activity").SetInternal(err)
} }
return c.JSON(http.StatusOK, composeResponse(resource)) return c.JSON(http.StatusOK, composeResponse(resource))
@ -530,8 +531,8 @@ func (s *Server) registerResourcePublicRoutes(g *echo.Group) {
}) })
} }
func createResourceCreateActivity(ctx context.Context, store *store.Store, resource *api.Resource) error { func (s *Server) createResourceCreateActivity(ctx context.Context, resource *api.Resource) error {
payload := api.ActivityResourceCreatePayload{ payload := apiv1.ActivityResourceCreatePayload{
Filename: resource.Filename, Filename: resource.Filename,
Type: resource.Type, Type: resource.Type,
Size: resource.Size, Size: resource.Size,
@ -540,10 +541,10 @@ func createResourceCreateActivity(ctx context.Context, store *store.Store, resou
if err != nil { if err != nil {
return errors.Wrap(err, "failed to marshal activity payload") return errors.Wrap(err, "failed to marshal activity payload")
} }
activity, err := store.CreateActivity(ctx, &api.ActivityCreate{ activity, err := s.Store.CreateActivity(ctx, &store.ActivityMessage{
CreatorID: resource.CreatorID, CreatorID: resource.CreatorID,
Type: api.ActivityResourceCreate, Type: apiv1.ActivityResourceCreate.String(),
Level: api.ActivityInfo, Level: apiv1.ActivityInfo.String(),
Payload: string(payloadBytes), Payload: string(payloadBytes),
}) })
if err != nil || activity == nil { if err != nil || activity == nil {

View File

@ -12,6 +12,7 @@ import (
"github.com/gorilla/feeds" "github.com/gorilla/feeds"
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
"github.com/usememos/memos/api" "github.com/usememos/memos/api"
apiv1 "github.com/usememos/memos/api/v1"
"github.com/usememos/memos/common" "github.com/usememos/memos/common"
"github.com/usememos/memos/store" "github.com/usememos/memos/store"
"github.com/yuin/goldmark" "github.com/yuin/goldmark"
@ -80,7 +81,7 @@ func (s *Server) registerRSSRoutes(g *echo.Group) {
const MaxRSSItemCount = 100 const MaxRSSItemCount = 100
const MaxRSSItemTitleLength = 100 const MaxRSSItemTitleLength = 100
func (s *Server) generateRSSFromMemoList(ctx context.Context, memoList []*store.MemoMessage, baseURL string, profile *api.CustomizedProfile) (string, error) { func (s *Server) generateRSSFromMemoList(ctx context.Context, memoList []*store.MemoMessage, baseURL string, profile *apiv1.CustomizedProfile) (string, error) {
feed := &feeds.Feed{ feed := &feeds.Feed{
Title: profile.Name, Title: profile.Name,
Link: &feeds.Link{Href: baseURL}, Link: &feeds.Link{Href: baseURL},
@ -126,15 +127,14 @@ func (s *Server) generateRSSFromMemoList(ctx context.Context, memoList []*store.
return rss, nil return rss, nil
} }
func (s *Server) getSystemCustomizedProfile(ctx context.Context) (*api.CustomizedProfile, error) { func (s *Server) getSystemCustomizedProfile(ctx context.Context) (*apiv1.CustomizedProfile, error) {
systemSetting, err := s.Store.FindSystemSetting(ctx, &api.SystemSettingFind{ systemSetting, err := s.Store.GetSystemSetting(ctx, &store.FindSystemSetting{
Name: api.SystemSettingCustomizedProfileName, Name: apiv1.SystemSettingCustomizedProfileName.String(),
}) })
if err != nil && common.ErrorCode(err) != common.NotFound { if err != nil {
return nil, err return nil, err
} }
customizedProfile := &apiv1.CustomizedProfile{
customizedProfile := &api.CustomizedProfile{
Name: "memos", Name: "memos",
LogoURL: "", LogoURL: "",
Description: "", Description: "",

View File

@ -6,9 +6,10 @@ import (
"fmt" "fmt"
"time" "time"
"github.com/google/uuid"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/usememos/memos/api" apiv1 "github.com/usememos/memos/api/v1"
apiV1 "github.com/usememos/memos/api/v1" "github.com/usememos/memos/common"
"github.com/usememos/memos/plugin/telegram" "github.com/usememos/memos/plugin/telegram"
"github.com/usememos/memos/server/profile" "github.com/usememos/memos/server/profile"
"github.com/usememos/memos/store" "github.com/usememos/memos/store"
@ -97,18 +98,13 @@ func NewServer(ctx context.Context, profile *profile.Profile, store *store.Store
apiGroup.Use(func(next echo.HandlerFunc) echo.HandlerFunc { apiGroup.Use(func(next echo.HandlerFunc) echo.HandlerFunc {
return JWTMiddleware(s, next, s.Secret) return JWTMiddleware(s, next, s.Secret)
}) })
s.registerSystemRoutes(apiGroup)
s.registerUserRoutes(apiGroup)
s.registerMemoRoutes(apiGroup) s.registerMemoRoutes(apiGroup)
s.registerMemoResourceRoutes(apiGroup) s.registerMemoResourceRoutes(apiGroup)
s.registerShortcutRoutes(apiGroup)
s.registerResourceRoutes(apiGroup) s.registerResourceRoutes(apiGroup)
s.registerTagRoutes(apiGroup)
s.registerStorageRoutes(apiGroup) s.registerStorageRoutes(apiGroup)
s.registerOpenAIRoutes(apiGroup)
s.registerMemoRelationRoutes(apiGroup) s.registerMemoRelationRoutes(apiGroup)
apiV1Service := apiV1.NewAPIV1Service(s.Secret, profile, store) apiV1Service := apiv1.NewAPIV1Service(s.Secret, profile, store)
apiV1Service.Register(rootGroup) apiV1Service.Register(rootGroup)
return s, nil return s, nil
@ -145,8 +141,46 @@ func (s *Server) GetEcho() *echo.Echo {
return s.e return s.e
} }
func (s *Server) getSystemServerID(ctx context.Context) (string, error) {
serverIDSetting, err := s.Store.GetSystemSetting(ctx, &store.FindSystemSetting{
Name: apiv1.SystemSettingServerIDName.String(),
})
if err != nil && common.ErrorCode(err) != common.NotFound {
return "", err
}
if serverIDSetting == nil || serverIDSetting.Value == "" {
serverIDSetting, err = s.Store.UpsertSystemSetting(ctx, &store.SystemSetting{
Name: apiv1.SystemSettingServerIDName.String(),
Value: uuid.NewString(),
})
if err != nil {
return "", err
}
}
return serverIDSetting.Value, nil
}
func (s *Server) getSystemSecretSessionName(ctx context.Context) (string, error) {
secretSessionNameValue, err := s.Store.GetSystemSetting(ctx, &store.FindSystemSetting{
Name: apiv1.SystemSettingSecretSessionName.String(),
})
if err != nil && common.ErrorCode(err) != common.NotFound {
return "", err
}
if secretSessionNameValue == nil || secretSessionNameValue.Value == "" {
secretSessionNameValue, err = s.Store.UpsertSystemSetting(ctx, &store.SystemSetting{
Name: apiv1.SystemSettingSecretSessionName.String(),
Value: uuid.NewString(),
})
if err != nil {
return "", err
}
}
return secretSessionNameValue.Value, nil
}
func (s *Server) createServerStartActivity(ctx context.Context) error { func (s *Server) createServerStartActivity(ctx context.Context) error {
payload := api.ActivityServerStartPayload{ payload := apiv1.ActivityServerStartPayload{
ServerID: s.ID, ServerID: s.ID,
Profile: s.Profile, Profile: s.Profile,
} }
@ -154,10 +188,10 @@ func (s *Server) createServerStartActivity(ctx context.Context) error {
if err != nil { if err != nil {
return errors.Wrap(err, "failed to marshal activity payload") return errors.Wrap(err, "failed to marshal activity payload")
} }
activity, err := s.Store.CreateActivity(ctx, &api.ActivityCreate{ activity, err := s.Store.CreateActivity(ctx, &store.ActivityMessage{
CreatorID: api.UnknownID, CreatorID: apiv1.UnknownID,
Type: api.ActivityServerStart, Type: apiv1.ActivityServerStart.String(),
Level: api.ActivityInfo, Level: apiv1.ActivityInfo.String(),
Payload: string(payloadBytes), Payload: string(payloadBytes),
}) })
if err != nil || activity == nil { if err != nil || activity == nil {

View File

@ -8,7 +8,9 @@ import (
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
"github.com/usememos/memos/api" "github.com/usememos/memos/api"
apiv1 "github.com/usememos/memos/api/v1"
"github.com/usememos/memos/common" "github.com/usememos/memos/common"
"github.com/usememos/memos/store"
) )
func (s *Server) registerStorageRoutes(g *echo.Group) { func (s *Server) registerStorageRoutes(g *echo.Group) {
@ -19,13 +21,13 @@ func (s *Server) registerStorageRoutes(g *echo.Group) {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
} }
user, err := s.Store.FindUser(ctx, &api.UserFind{ user, err := s.Store.GetUser(ctx, &store.FindUser{
ID: &userID, ID: &userID,
}) })
if err != nil { if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err) return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err)
} }
if user == nil || user.Role != api.Host { if user == nil || user.Role != store.RoleHost {
return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized") return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized")
} }
@ -48,13 +50,13 @@ func (s *Server) registerStorageRoutes(g *echo.Group) {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
} }
user, err := s.Store.FindUser(ctx, &api.UserFind{ user, err := s.Store.GetUser(ctx, &store.FindUser{
ID: &userID, ID: &userID,
}) })
if err != nil { if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err) return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err)
} }
if user == nil || user.Role != api.Host { if user == nil || user.Role != store.RoleHost {
return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized") return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized")
} }
@ -84,14 +86,14 @@ func (s *Server) registerStorageRoutes(g *echo.Group) {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
} }
user, err := s.Store.FindUser(ctx, &api.UserFind{ user, err := s.Store.GetUser(ctx, &store.FindUser{
ID: &userID, ID: &userID,
}) })
if err != nil { if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err) return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err)
} }
// We should only show storage list to host user. // We should only show storage list to host user.
if user == nil || user.Role != api.Host { if user == nil || user.Role != store.RoleHost {
return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized") return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized")
} }
@ -109,13 +111,13 @@ func (s *Server) registerStorageRoutes(g *echo.Group) {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
} }
user, err := s.Store.FindUser(ctx, &api.UserFind{ user, err := s.Store.GetUser(ctx, &store.FindUser{
ID: &userID, ID: &userID,
}) })
if err != nil { if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err) return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err)
} }
if user == nil || user.Role != api.Host { if user == nil || user.Role != store.RoleHost {
return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized") return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized")
} }
@ -124,8 +126,8 @@ func (s *Server) registerStorageRoutes(g *echo.Group) {
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("storageId"))).SetInternal(err) return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("storageId"))).SetInternal(err)
} }
systemSetting, err := s.Store.FindSystemSetting(ctx, &api.SystemSettingFind{Name: api.SystemSettingStorageServiceIDName}) systemSetting, err := s.Store.GetSystemSetting(ctx, &store.FindSystemSetting{Name: apiv1.SystemSettingStorageServiceIDName.String()})
if err != nil && common.ErrorCode(err) != common.NotFound { if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find storage").SetInternal(err) return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find storage").SetInternal(err)
} }
if systemSetting != nil { if systemSetting != nil {

View File

@ -1,242 +0,0 @@
package server
import (
"context"
"encoding/json"
"net/http"
"os"
"github.com/google/uuid"
"github.com/usememos/memos/api"
"github.com/usememos/memos/common"
"github.com/usememos/memos/common/log"
"go.uber.org/zap"
"github.com/labstack/echo/v4"
)
func (s *Server) registerSystemRoutes(g *echo.Group) {
g.GET("/ping", func(c echo.Context) error {
return c.JSON(http.StatusOK, composeResponse(s.Profile))
})
g.GET("/status", func(c echo.Context) error {
ctx := c.Request().Context()
hostUserType := api.Host
hostUserFind := api.UserFind{
Role: &hostUserType,
}
hostUser, err := s.Store.FindUser(ctx, &hostUserFind)
if err != nil && common.ErrorCode(err) != common.NotFound {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find host user").SetInternal(err)
}
if hostUser != nil {
// data desensitize
hostUser.OpenID = ""
hostUser.Email = ""
}
systemStatus := api.SystemStatus{
Host: hostUser,
Profile: *s.Profile,
DBSize: 0,
AllowSignUp: false,
DisablePublicMemos: false,
MaxUploadSizeMiB: 32,
AdditionalStyle: "",
AdditionalScript: "",
CustomizedProfile: api.CustomizedProfile{
Name: "memos",
LogoURL: "",
Description: "",
Locale: "en",
Appearance: "system",
ExternalURL: "",
},
StorageServiceID: api.DatabaseStorage,
LocalStoragePath: "assets/{timestamp}_{filename}",
MemoDisplayWithUpdatedTs: false,
}
systemSettingList, err := s.Store.FindSystemSettingList(ctx, &api.SystemSettingFind{})
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find system setting list").SetInternal(err)
}
for _, systemSetting := range systemSettingList {
if systemSetting.Name == api.SystemSettingServerIDName || systemSetting.Name == api.SystemSettingSecretSessionName || systemSetting.Name == api.SystemSettingOpenAIConfigName || systemSetting.Name == api.SystemSettingTelegramBotTokenName {
continue
}
var baseValue any
err := json.Unmarshal([]byte(systemSetting.Value), &baseValue)
if err != nil {
log.Warn("Failed to unmarshal system setting value", zap.String("setting name", systemSetting.Name.String()))
continue
}
switch systemSetting.Name {
case api.SystemSettingAllowSignUpName:
systemStatus.AllowSignUp = baseValue.(bool)
case api.SystemSettingDisablePublicMemosName:
systemStatus.DisablePublicMemos = baseValue.(bool)
case api.SystemSettingMaxUploadSizeMiBName:
systemStatus.MaxUploadSizeMiB = int(baseValue.(float64))
case api.SystemSettingAdditionalStyleName:
systemStatus.AdditionalStyle = baseValue.(string)
case api.SystemSettingAdditionalScriptName:
systemStatus.AdditionalScript = baseValue.(string)
case api.SystemSettingCustomizedProfileName:
customizedProfile := api.CustomizedProfile{}
if err := json.Unmarshal([]byte(systemSetting.Value), &customizedProfile); err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to unmarshal system setting customized profile value").SetInternal(err)
}
systemStatus.CustomizedProfile = customizedProfile
case api.SystemSettingStorageServiceIDName:
systemStatus.StorageServiceID = int(baseValue.(float64))
case api.SystemSettingLocalStoragePathName:
systemStatus.LocalStoragePath = baseValue.(string)
case api.SystemSettingMemoDisplayWithUpdatedTsName:
systemStatus.MemoDisplayWithUpdatedTs = baseValue.(bool)
default:
log.Warn("Unknown system setting name", zap.String("setting name", systemSetting.Name.String()))
}
}
userID, ok := c.Get(getUserIDContextKey()).(int)
// Get database size for host user.
if ok {
user, err := s.Store.FindUser(ctx, &api.UserFind{
ID: &userID,
})
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err)
}
if user != nil && user.Role == api.Host {
fi, err := os.Stat(s.Profile.DSN)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to read database fileinfo").SetInternal(err)
}
systemStatus.DBSize = fi.Size()
}
}
return c.JSON(http.StatusOK, composeResponse(systemStatus))
})
g.POST("/system/setting", func(c echo.Context) error {
ctx := c.Request().Context()
userID, ok := c.Get(getUserIDContextKey()).(int)
if !ok {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
}
user, err := s.Store.FindUser(ctx, &api.UserFind{
ID: &userID,
})
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err)
}
if user == nil || user.Role != api.Host {
return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized")
}
systemSettingUpsert := &api.SystemSettingUpsert{}
if err := json.NewDecoder(c.Request().Body).Decode(systemSettingUpsert); err != nil {
return echo.NewHTTPError(http.StatusBadRequest, "Malformatted post system setting request").SetInternal(err)
}
if err := systemSettingUpsert.Validate(); err != nil {
return echo.NewHTTPError(http.StatusBadRequest, "system setting invalidate").SetInternal(err)
}
systemSetting, err := s.Store.UpsertSystemSetting(ctx, systemSettingUpsert)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to upsert system setting").SetInternal(err)
}
return c.JSON(http.StatusOK, composeResponse(systemSetting))
})
g.GET("/system/setting", func(c echo.Context) error {
ctx := c.Request().Context()
userID, ok := c.Get(getUserIDContextKey()).(int)
if !ok {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
}
user, err := s.Store.FindUser(ctx, &api.UserFind{
ID: &userID,
})
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err)
}
if user == nil || user.Role != api.Host {
return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized")
}
systemSettingList, err := s.Store.FindSystemSettingList(ctx, &api.SystemSettingFind{})
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find system setting list").SetInternal(err)
}
return c.JSON(http.StatusOK, composeResponse(systemSettingList))
})
g.POST("/system/vacuum", func(c echo.Context) error {
ctx := c.Request().Context()
userID, ok := c.Get(getUserIDContextKey()).(int)
if !ok {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
}
user, err := s.Store.FindUser(ctx, &api.UserFind{
ID: &userID,
})
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err)
}
if user == nil || user.Role != api.Host {
return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized")
}
if err := s.Store.Vacuum(ctx); err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to vacuum database").SetInternal(err)
}
return c.JSON(http.StatusOK, true)
})
}
func (s *Server) getSystemServerID(ctx context.Context) (string, error) {
serverIDValue, err := s.Store.FindSystemSetting(ctx, &api.SystemSettingFind{
Name: api.SystemSettingServerIDName,
})
if err != nil && common.ErrorCode(err) != common.NotFound {
return "", err
}
if serverIDValue == nil || serverIDValue.Value == "" {
serverIDValue, err = s.Store.UpsertSystemSetting(ctx, &api.SystemSettingUpsert{
Name: api.SystemSettingServerIDName,
Value: uuid.NewString(),
})
if err != nil {
return "", err
}
}
return serverIDValue.Value, nil
}
func (s *Server) getSystemSecretSessionName(ctx context.Context) (string, error) {
secretSessionNameValue, err := s.Store.FindSystemSetting(ctx, &api.SystemSettingFind{
Name: api.SystemSettingSecretSessionName,
})
if err != nil && common.ErrorCode(err) != common.NotFound {
return "", err
}
if secretSessionNameValue == nil || secretSessionNameValue.Value == "" {
secretSessionNameValue, err = s.Store.UpsertSystemSetting(ctx, &api.SystemSettingUpsert{
Name: api.SystemSettingSecretSessionName,
Value: uuid.NewString(),
})
if err != nil {
return "", err
}
}
return secretSessionNameValue.Value, nil
}

View File

@ -24,7 +24,7 @@ func newTelegramHandler(store *store.Store) *telegramHandler {
} }
func (t *telegramHandler) BotToken(ctx context.Context) string { func (t *telegramHandler) BotToken(ctx context.Context) string {
return t.store.GetSystemSettingValueOrDefault(&ctx, api.SystemSettingTelegramBotTokenName, "") return t.store.GetSystemSettingValueWithDefault(&ctx, apiv1.SystemSettingTelegramBotTokenName.String(), "")
} }
const ( const (
@ -80,11 +80,6 @@ func (t *telegramHandler) MessageHandle(ctx context.Context, bot *telegram.Bot,
return err return err
} }
if err := createMemoCreateActivity(ctx, t.store, memoMessage); err != nil {
_, err := bot.EditMessage(ctx, message.Chat.ID, reply.MessageID, fmt.Sprintf("failed to createMemoCreateActivity: %s", err), nil)
return err
}
// create resources // create resources
for filename, blob := range blobs { for filename, blob := range blobs {
// TODO support more // TODO support more
@ -108,10 +103,6 @@ func (t *telegramHandler) MessageHandle(ctx context.Context, bot *telegram.Bot,
_, err := bot.EditMessage(ctx, message.Chat.ID, reply.MessageID, fmt.Sprintf("failed to CreateResource: %s", err), nil) _, err := bot.EditMessage(ctx, message.Chat.ID, reply.MessageID, fmt.Sprintf("failed to CreateResource: %s", err), nil)
return err return err
} }
if err := createResourceCreateActivity(ctx, t.store, resource); err != nil {
_, err := bot.EditMessage(ctx, message.Chat.ID, reply.MessageID, fmt.Sprintf("failed to createResourceCreateActivity: %s", err), nil)
return err
}
_, err = t.store.UpsertMemoResource(ctx, &api.MemoResourceUpsert{ _, err = t.store.UpsertMemoResource(ctx, &api.MemoResourceUpsert{
MemoID: memoMessage.ID, MemoID: memoMessage.ID,

View File

@ -1,306 +0,0 @@
package server
import (
"encoding/json"
"fmt"
"net/http"
"strconv"
"time"
"github.com/pkg/errors"
"github.com/usememos/memos/api"
apiv1 "github.com/usememos/memos/api/v1"
"github.com/usememos/memos/common"
"github.com/usememos/memos/store"
"github.com/labstack/echo/v4"
"golang.org/x/crypto/bcrypt"
)
func (s *Server) registerUserRoutes(g *echo.Group) {
g.POST("/user", func(c echo.Context) error {
ctx := c.Request().Context()
userID, ok := c.Get(getUserIDContextKey()).(int)
if !ok {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing auth session")
}
currentUser, err := s.Store.FindUser(ctx, &api.UserFind{
ID: &userID,
})
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user by id").SetInternal(err)
}
if currentUser.Role != api.Host {
return echo.NewHTTPError(http.StatusUnauthorized, "Only Host user can create member")
}
userCreate := &api.UserCreate{}
if err := json.NewDecoder(c.Request().Body).Decode(userCreate); err != nil {
return echo.NewHTTPError(http.StatusBadRequest, "Malformatted post user request").SetInternal(err)
}
if userCreate.Role == api.Host {
return echo.NewHTTPError(http.StatusForbidden, "Could not create host user")
}
userCreate.OpenID = common.GenUUID()
if err := userCreate.Validate(); err != nil {
return echo.NewHTTPError(http.StatusBadRequest, "Invalid user create format").SetInternal(err)
}
passwordHash, err := bcrypt.GenerateFromPassword([]byte(userCreate.Password), bcrypt.DefaultCost)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to generate password hash").SetInternal(err)
}
userCreate.PasswordHash = string(passwordHash)
user, err := s.Store.CreateUser(ctx, userCreate)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create user").SetInternal(err)
}
if err := s.createUserCreateActivity(c, user); err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create activity").SetInternal(err)
}
return c.JSON(http.StatusOK, composeResponse(user))
})
g.GET("/user", func(c echo.Context) error {
ctx := c.Request().Context()
userList, err := s.Store.FindUserList(ctx, &api.UserFind{})
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to fetch user list").SetInternal(err)
}
for _, user := range userList {
// data desensitize
user.OpenID = ""
user.Email = ""
}
return c.JSON(http.StatusOK, composeResponse(userList))
})
g.POST("/user/setting", func(c echo.Context) error {
ctx := c.Request().Context()
userID, ok := c.Get(getUserIDContextKey()).(int)
if !ok {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing auth session")
}
userSettingUpsert := &apiv1.UserSettingUpsert{}
if err := json.NewDecoder(c.Request().Body).Decode(userSettingUpsert); err != nil {
return echo.NewHTTPError(http.StatusBadRequest, "Malformatted post user setting upsert request").SetInternal(err)
}
if err := userSettingUpsert.Validate(); err != nil {
return echo.NewHTTPError(http.StatusBadRequest, "Invalid user setting format").SetInternal(err)
}
userSettingUpsert.UserID = userID
userSetting, err := s.Store.UpsertUserSetting(ctx, &store.UserSetting{
UserID: userID,
Key: userSettingUpsert.Key.String(),
Value: userSettingUpsert.Value,
})
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to upsert user setting").SetInternal(err)
}
userSettingMessage := convertUserSettingFromStore(userSetting)
return c.JSON(http.StatusOK, composeResponse(userSettingMessage))
})
// GET /api/user/me is used to check if the user is logged in.
g.GET("/user/me", func(c echo.Context) error {
ctx := c.Request().Context()
userID, ok := c.Get(getUserIDContextKey()).(int)
if !ok {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing auth session")
}
userFind := &api.UserFind{
ID: &userID,
}
user, err := s.Store.FindUser(ctx, userFind)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err)
}
list, err := s.Store.ListUserSettings(ctx, &store.FindUserSetting{
UserID: &userID,
})
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find userSettingList").SetInternal(err)
}
userSettingList := []*api.UserSetting{}
for _, item := range list {
userSetting := convertUserSettingFromStore(item)
userSettingList = append(userSettingList, &api.UserSetting{
UserID: userSetting.UserID,
Key: api.UserSettingKey(userSetting.Key),
Value: userSetting.Value,
})
}
user.UserSettingList = userSettingList
return c.JSON(http.StatusOK, composeResponse(user))
})
g.GET("/user/:id", func(c echo.Context) error {
ctx := c.Request().Context()
id, err := strconv.Atoi(c.Param("id"))
if err != nil {
return echo.NewHTTPError(http.StatusBadRequest, "Malformatted user id").SetInternal(err)
}
user, err := s.Store.FindUser(ctx, &api.UserFind{
ID: &id,
})
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to fetch user").SetInternal(err)
}
if user != nil {
// data desensitize
user.OpenID = ""
user.Email = ""
}
return c.JSON(http.StatusOK, composeResponse(user))
})
g.PATCH("/user/:id", func(c echo.Context) error {
ctx := c.Request().Context()
userID, err := strconv.Atoi(c.Param("id"))
if err != nil {
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("id"))).SetInternal(err)
}
currentUserID, ok := c.Get(getUserIDContextKey()).(int)
if !ok {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
}
currentUser, err := s.Store.FindUser(ctx, &api.UserFind{
ID: &currentUserID,
})
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err)
}
if currentUser == nil {
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Current session user not found with ID: %d", currentUserID)).SetInternal(err)
} else if currentUser.Role != api.Host && currentUserID != userID {
return echo.NewHTTPError(http.StatusForbidden, "Access forbidden for current session user").SetInternal(err)
}
currentTs := time.Now().Unix()
userPatch := &api.UserPatch{
UpdatedTs: &currentTs,
}
if err := json.NewDecoder(c.Request().Body).Decode(userPatch); err != nil {
return echo.NewHTTPError(http.StatusBadRequest, "Malformatted patch user request").SetInternal(err)
}
userPatch.ID = userID
if userPatch.Password != nil && *userPatch.Password != "" {
passwordHash, err := bcrypt.GenerateFromPassword([]byte(*userPatch.Password), bcrypt.DefaultCost)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to generate password hash").SetInternal(err)
}
passwordHashStr := string(passwordHash)
userPatch.PasswordHash = &passwordHashStr
}
if userPatch.ResetOpenID != nil && *userPatch.ResetOpenID {
openID := common.GenUUID()
userPatch.OpenID = &openID
}
if err := userPatch.Validate(); err != nil {
return echo.NewHTTPError(http.StatusBadRequest, "Invalid user patch format").SetInternal(err)
}
user, err := s.Store.PatchUser(ctx, userPatch)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to patch user").SetInternal(err)
}
list, err := s.Store.ListUserSettings(ctx, &store.FindUserSetting{
UserID: &userID,
})
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find userSettingList").SetInternal(err)
}
userSettingList := []*api.UserSetting{}
for _, item := range list {
userSetting := convertUserSettingFromStore(item)
userSettingList = append(userSettingList, &api.UserSetting{
UserID: userSetting.UserID,
Key: api.UserSettingKey(userSetting.Key),
Value: userSetting.Value,
})
}
user.UserSettingList = userSettingList
return c.JSON(http.StatusOK, composeResponse(user))
})
g.DELETE("/user/:id", func(c echo.Context) error {
ctx := c.Request().Context()
currentUserID, ok := c.Get(getUserIDContextKey()).(int)
if !ok {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
}
currentUser, err := s.Store.FindUser(ctx, &api.UserFind{
ID: &currentUserID,
})
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err)
}
if currentUser == nil {
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Current session user not found with ID: %d", currentUserID)).SetInternal(err)
} else if currentUser.Role != api.Host {
return echo.NewHTTPError(http.StatusForbidden, "Access forbidden for current session user").SetInternal(err)
}
userID, err := strconv.Atoi(c.Param("id"))
if err != nil {
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("id"))).SetInternal(err)
}
userDelete := &api.UserDelete{
ID: userID,
}
if err := s.Store.DeleteUser(ctx, userDelete); err != nil {
if common.ErrorCode(err) == common.NotFound {
return echo.NewHTTPError(http.StatusNotFound, fmt.Sprintf("User ID not found: %d", userID))
}
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to delete user").SetInternal(err)
}
return c.JSON(http.StatusOK, true)
})
}
func (s *Server) createUserCreateActivity(c echo.Context, user *api.User) error {
ctx := c.Request().Context()
payload := api.ActivityUserCreatePayload{
UserID: user.ID,
Username: user.Username,
Role: user.Role,
}
payloadBytes, err := json.Marshal(payload)
if err != nil {
return errors.Wrap(err, "failed to marshal activity payload")
}
activity, err := s.Store.CreateActivity(ctx, &api.ActivityCreate{
CreatorID: user.ID,
Type: api.ActivityUserCreate,
Level: api.ActivityInfo,
Payload: string(payloadBytes),
})
if err != nil || activity == nil {
return errors.Wrap(err, "failed to create activity")
}
return err
}
func convertUserSettingFromStore(userSetting *store.UserSetting) *apiv1.UserSetting {
return &apiv1.UserSetting{
UserID: userSetting.UserID,
Key: apiv1.UserSettingKey(userSetting.Key),
Value: userSetting.Value,
}
}

View File

@ -7,7 +7,6 @@ import (
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
"github.com/usememos/memos/api"
"github.com/usememos/memos/common" "github.com/usememos/memos/common"
"github.com/usememos/memos/store" "github.com/usememos/memos/store"
) )
@ -33,10 +32,8 @@ func (s setupService) Setup(ctx context.Context, hostUsername, hostPassword stri
} }
func (s setupService) makeSureHostUserNotExists(ctx context.Context) error { func (s setupService) makeSureHostUserNotExists(ctx context.Context) error {
hostUserType := api.Host hostUserType := store.RoleHost
existedHostUsers, err := s.store.FindUserList(ctx, &api.UserFind{ existedHostUsers, err := s.store.ListUsers(ctx, &store.FindUser{Role: &hostUserType})
Role: &hostUserType,
})
if err != nil { if err != nil {
return fmt.Errorf("find user list: %w", err) return fmt.Errorf("find user list: %w", err)
} }
@ -52,7 +49,7 @@ func (s setupService) createUser(ctx context.Context, hostUsername, hostPassword
userCreate := &store.User{ userCreate := &store.User{
Username: hostUsername, Username: hostUsername,
// The new signup user should be normal user by default. // The new signup user should be normal user by default.
Role: store.Host, Role: store.RoleHost,
Nickname: hostUsername, Nickname: hostUsername,
OpenID: common.GenUUID(), OpenID: common.GenUUID(),
} }
@ -87,7 +84,7 @@ func (s setupService) createUser(ctx context.Context, hostUsername, hostPassword
} }
userCreate.PasswordHash = string(passwordHash) userCreate.PasswordHash = string(passwordHash)
if _, err := s.store.CreateUserV1(ctx, userCreate); err != nil { if _, err := s.store.CreateUser(ctx, userCreate); err != nil {
return fmt.Errorf("failed to create user: %w", err) return fmt.Errorf("failed to create user: %w", err)
} }

View File

@ -2,9 +2,6 @@ package store
import ( import (
"context" "context"
"database/sql"
"github.com/usememos/memos/api"
) )
type ActivityMessage struct { type ActivityMessage struct {
@ -20,8 +17,8 @@ type ActivityMessage struct {
Payload string Payload string
} }
// CreateActivityV1 creates an instance of Activity. // CreateActivity creates an instance of Activity.
func (s *Store) CreateActivityV1(ctx context.Context, create *ActivityMessage) (*ActivityMessage, error) { func (s *Store) CreateActivity(ctx context.Context, create *ActivityMessage) (*ActivityMessage, error) {
tx, err := s.db.BeginTx(ctx, nil) tx, err := s.db.BeginTx(ctx, nil)
if err != nil { if err != nil {
return nil, FormatError(err) return nil, FormatError(err)
@ -51,80 +48,3 @@ func (s *Store) CreateActivityV1(ctx context.Context, create *ActivityMessage) (
activityMessage := create activityMessage := create
return activityMessage, nil return activityMessage, nil
} }
// activityRaw is the store model for an Activity.
// Fields have exactly the same meanings as Activity.
type activityRaw struct {
ID int
// Standard fields
CreatorID int
CreatedTs int64
// Domain specific fields
Type api.ActivityType
Level api.ActivityLevel
Payload string
}
// toActivity creates an instance of Activity based on the ActivityRaw.
func (raw *activityRaw) toActivity() *api.Activity {
return &api.Activity{
ID: raw.ID,
CreatorID: raw.CreatorID,
CreatedTs: raw.CreatedTs,
Type: raw.Type,
Level: raw.Level,
Payload: raw.Payload,
}
}
// CreateActivity creates an instance of Activity.
func (s *Store) CreateActivity(ctx context.Context, create *api.ActivityCreate) (*api.Activity, error) {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, FormatError(err)
}
defer tx.Rollback()
activityRaw, err := createActivity(ctx, tx, create)
if err != nil {
return nil, err
}
if err := tx.Commit(); err != nil {
return nil, FormatError(err)
}
activity := activityRaw.toActivity()
return activity, nil
}
// createActivity creates a new activity.
func createActivity(ctx context.Context, tx *sql.Tx, create *api.ActivityCreate) (*activityRaw, error) {
query := `
INSERT INTO activity (
creator_id,
type,
level,
payload
)
VALUES (?, ?, ?, ?)
RETURNING id, type, level, payload, creator_id, created_ts
`
var activityRaw activityRaw
if err := tx.QueryRowContext(ctx, query, create.CreatorID, create.Type, create.Level, create.Payload).Scan(
&activityRaw.ID,
&activityRaw.Type,
&activityRaw.Level,
&activityRaw.Payload,
&activityRaw.CreatorID,
&activityRaw.CreatedTs,
); err != nil {
return nil, FormatError(err)
}
return &activityRaw, nil
}

View File

@ -4,6 +4,6 @@ import (
"fmt" "fmt"
) )
func getUserSettingCacheKeyV1(userID int, key string) string { func getUserSettingCacheKey(userID int, key string) string {
return fmt.Sprintf("%d-%s", userID, key) return fmt.Sprintf("%d-%s", userID, key)
} }

View File

@ -11,9 +11,13 @@ import (
type IdentityProviderType string type IdentityProviderType string
const ( const (
IdentityProviderOAuth2 IdentityProviderType = "OAUTH2" IdentityProviderOAuth2Type IdentityProviderType = "OAUTH2"
) )
func (t IdentityProviderType) String() string {
return string(t)
}
type IdentityProviderConfig struct { type IdentityProviderConfig struct {
OAuth2Config *IdentityProviderOAuth2Config OAuth2Config *IdentityProviderOAuth2Config
} }
@ -66,7 +70,7 @@ func (s *Store) CreateIdentityProvider(ctx context.Context, create *IdentityProv
defer tx.Rollback() defer tx.Rollback()
var configBytes []byte var configBytes []byte
if create.Type == IdentityProviderOAuth2 { if create.Type == IdentityProviderOAuth2Type {
configBytes, err = json.Marshal(create.Config.OAuth2Config) configBytes, err = json.Marshal(create.Config.OAuth2Config)
if err != nil { if err != nil {
return nil, err return nil, err
@ -167,7 +171,7 @@ func (s *Store) UpdateIdentityProvider(ctx context.Context, update *UpdateIdenti
} }
if v := update.Config; v != nil { if v := update.Config; v != nil {
var configBytes []byte var configBytes []byte
if update.Type == IdentityProviderOAuth2 { if update.Type == IdentityProviderOAuth2Type {
configBytes, err = json.Marshal(update.Config.OAuth2Config) configBytes, err = json.Marshal(update.Config.OAuth2Config)
if err != nil { if err != nil {
return nil, err return nil, err
@ -197,7 +201,7 @@ func (s *Store) UpdateIdentityProvider(ctx context.Context, update *UpdateIdenti
return nil, err return nil, err
} }
if identityProvider.Type == IdentityProviderOAuth2 { if identityProvider.Type == IdentityProviderOAuth2Type {
oauth2Config := &IdentityProviderOAuth2Config{} oauth2Config := &IdentityProviderOAuth2Config{}
if err := json.Unmarshal([]byte(identityProviderConfig), oauth2Config); err != nil { if err := json.Unmarshal([]byte(identityProviderConfig), oauth2Config); err != nil {
return nil, err return nil, err
@ -279,7 +283,7 @@ func listIdentityProviders(ctx context.Context, tx *sql.Tx, find *FindIdentityPr
return nil, err return nil, err
} }
if identityProvider.Type == IdentityProviderOAuth2 { if identityProvider.Type == IdentityProviderOAuth2Type {
oauth2Config := &IdentityProviderOAuth2Config{} oauth2Config := &IdentityProviderOAuth2Config{}
if err := json.Unmarshal([]byte(identityProviderConfig), oauth2Config); err != nil { if err := json.Unmarshal([]byte(identityProviderConfig), oauth2Config); err != nil {
return nil, err return nil, err

View File

@ -3,20 +3,14 @@ package store
import ( import (
"context" "context"
"database/sql" "database/sql"
"fmt"
"strings" "strings"
"github.com/usememos/memos/api"
"github.com/usememos/memos/common"
) )
// shortcutRaw is the store model for an Shortcut. type Shortcut struct {
// Fields have exactly the same meanings as Shortcut.
type shortcutRaw struct {
ID int ID int
// Standard fields // Standard fields
RowStatus api.RowStatus RowStatus RowStatus
CreatorID int CreatorID int
CreatedTs int64 CreatedTs int64
UpdatedTs int64 UpdatedTs int64
@ -26,134 +20,33 @@ type shortcutRaw struct {
Payload string Payload string
} }
func (raw *shortcutRaw) toShortcut() *api.Shortcut { type UpdateShortcut struct {
return &api.Shortcut{ ID int
ID: raw.ID,
RowStatus: raw.RowStatus, UpdatedTs *int64
CreatorID: raw.CreatorID, RowStatus *RowStatus
CreatedTs: raw.CreatedTs, Title *string
UpdatedTs: raw.UpdatedTs, Payload *string
Title: raw.Title,
Payload: raw.Payload,
}
} }
func (s *Store) CreateShortcut(ctx context.Context, create *api.ShortcutCreate) (*api.Shortcut, error) { type FindShortcut struct {
tx, err := s.db.BeginTx(ctx, nil) ID *int
if err != nil { CreatorID *int
return nil, FormatError(err) Title *string
} }
defer tx.Rollback()
shortcutRaw, err := createShortcut(ctx, tx, create) type DeleteShortcut struct {
ID *int
CreatorID *int
}
func (s *Store) CreateShortcut(ctx context.Context, create *Shortcut) (*Shortcut, error) {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if err := tx.Commit(); err != nil {
return nil, FormatError(err)
}
s.shortcutCache.Store(shortcutRaw.ID, shortcutRaw)
shortcut := shortcutRaw.toShortcut()
return shortcut, nil
}
func (s *Store) PatchShortcut(ctx context.Context, patch *api.ShortcutPatch) (*api.Shortcut, error) {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, FormatError(err)
}
defer tx.Rollback() defer tx.Rollback()
shortcutRaw, err := patchShortcut(ctx, tx, patch)
if err != nil {
return nil, err
}
if err := tx.Commit(); err != nil {
return nil, FormatError(err)
}
s.shortcutCache.Store(shortcutRaw.ID, shortcutRaw)
shortcut := shortcutRaw.toShortcut()
return shortcut, nil
}
func (s *Store) FindShortcutList(ctx context.Context, find *api.ShortcutFind) ([]*api.Shortcut, error) {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, FormatError(err)
}
defer tx.Rollback()
shortcutRawList, err := findShortcutList(ctx, tx, find)
if err != nil {
return nil, err
}
list := []*api.Shortcut{}
for _, raw := range shortcutRawList {
list = append(list, raw.toShortcut())
}
return list, nil
}
func (s *Store) FindShortcut(ctx context.Context, find *api.ShortcutFind) (*api.Shortcut, error) {
if find.ID != nil {
if shortcut, ok := s.shortcutCache.Load(*find.ID); ok {
return shortcut.(*shortcutRaw).toShortcut(), nil
}
}
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, FormatError(err)
}
defer tx.Rollback()
list, err := findShortcutList(ctx, tx, find)
if err != nil {
return nil, err
}
if len(list) == 0 {
return nil, &common.Error{Code: common.NotFound, Err: fmt.Errorf("not found")}
}
shortcutRaw := list[0]
s.shortcutCache.Store(shortcutRaw.ID, shortcutRaw)
shortcut := shortcutRaw.toShortcut()
return shortcut, nil
}
func (s *Store) DeleteShortcut(ctx context.Context, delete *api.ShortcutDelete) error {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return FormatError(err)
}
defer tx.Rollback()
err = deleteShortcut(ctx, tx, delete)
if err != nil {
return FormatError(err)
}
if err := tx.Commit(); err != nil {
return FormatError(err)
}
s.shortcutCache.Delete(*delete.ID)
return nil
}
func createShortcut(ctx context.Context, tx *sql.Tx, create *api.ShortcutCreate) (*shortcutRaw, error) {
query := ` query := `
INSERT INTO shortcut ( INSERT INTO shortcut (
title, title,
@ -161,41 +54,81 @@ func createShortcut(ctx context.Context, tx *sql.Tx, create *api.ShortcutCreate)
creator_id creator_id
) )
VALUES (?, ?, ?) VALUES (?, ?, ?)
RETURNING id, title, payload, creator_id, created_ts, updated_ts, row_status RETURNING id, created_ts, updated_ts, row_status
` `
var shortcutRaw shortcutRaw
if err := tx.QueryRowContext(ctx, query, create.Title, create.Payload, create.CreatorID).Scan( if err := tx.QueryRowContext(ctx, query, create.Title, create.Payload, create.CreatorID).Scan(
&shortcutRaw.ID, &create.ID,
&shortcutRaw.Title, &create.CreatedTs,
&shortcutRaw.Payload, &create.UpdatedTs,
&shortcutRaw.CreatorID, &create.RowStatus,
&shortcutRaw.CreatedTs,
&shortcutRaw.UpdatedTs,
&shortcutRaw.RowStatus,
); err != nil { ); err != nil {
return nil, FormatError(err) return nil, err
} }
return &shortcutRaw, nil if err := tx.Commit(); err != nil {
return nil, err
}
shortcut := create
return shortcut, nil
} }
func patchShortcut(ctx context.Context, tx *sql.Tx, patch *api.ShortcutPatch) (*shortcutRaw, error) { func (s *Store) ListShortcuts(ctx context.Context, find *FindShortcut) ([]*Shortcut, error) {
set, args := []string{}, []any{} tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
defer tx.Rollback()
if v := patch.UpdatedTs; v != nil { list, err := listShortcuts(ctx, tx, find)
if err != nil {
return nil, err
}
return list, nil
}
func (s *Store) GetShortcut(ctx context.Context, find *FindShortcut) (*Shortcut, error) {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
defer tx.Rollback()
list, err := listShortcuts(ctx, tx, find)
if err != nil {
return nil, err
}
if len(list) == 0 {
return nil, nil
}
shortcut := list[0]
return shortcut, nil
}
func (s *Store) UpdateShortcut(ctx context.Context, update *UpdateShortcut) (*Shortcut, error) {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
defer tx.Rollback()
set, args := []string{}, []any{}
if v := update.UpdatedTs; v != nil {
set, args = append(set, "updated_ts = ?"), append(args, *v) set, args = append(set, "updated_ts = ?"), append(args, *v)
} }
if v := patch.Title; v != nil { if v := update.Title; v != nil {
set, args = append(set, "title = ?"), append(args, *v) set, args = append(set, "title = ?"), append(args, *v)
} }
if v := patch.Payload; v != nil { if v := update.Payload; v != nil {
set, args = append(set, "payload = ?"), append(args, *v) set, args = append(set, "payload = ?"), append(args, *v)
} }
if v := patch.RowStatus; v != nil { if v := update.RowStatus; v != nil {
set, args = append(set, "row_status = ?"), append(args, *v) set, args = append(set, "row_status = ?"), append(args, *v)
} }
args = append(args, update.ID)
args = append(args, patch.ID)
query := ` query := `
UPDATE shortcut UPDATE shortcut
@ -203,23 +136,55 @@ func patchShortcut(ctx context.Context, tx *sql.Tx, patch *api.ShortcutPatch) (*
WHERE id = ? WHERE id = ?
RETURNING id, title, payload, creator_id, created_ts, updated_ts, row_status RETURNING id, title, payload, creator_id, created_ts, updated_ts, row_status
` `
var shortcutRaw shortcutRaw shortcut := &Shortcut{}
if err := tx.QueryRowContext(ctx, query, args...).Scan( if err := tx.QueryRowContext(ctx, query, args...).Scan(
&shortcutRaw.ID, &shortcut.ID,
&shortcutRaw.Title, &shortcut.Title,
&shortcutRaw.Payload, &shortcut.Payload,
&shortcutRaw.CreatorID, &shortcut.CreatorID,
&shortcutRaw.CreatedTs, &shortcut.CreatedTs,
&shortcutRaw.UpdatedTs, &shortcut.UpdatedTs,
&shortcutRaw.RowStatus, &shortcut.RowStatus,
); err != nil { ); err != nil {
return nil, FormatError(err) return nil, err
} }
return &shortcutRaw, nil if err := tx.Commit(); err != nil {
return nil, err
}
return shortcut, nil
} }
func findShortcutList(ctx context.Context, tx *sql.Tx, find *api.ShortcutFind) ([]*shortcutRaw, error) { func (s *Store) DeleteShortcut(ctx context.Context, delete *DeleteShortcut) error {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return err
}
defer tx.Rollback()
where, args := []string{}, []any{}
if v := delete.ID; v != nil {
where, args = append(where, "id = ?"), append(args, *v)
}
if v := delete.CreatorID; v != nil {
where, args = append(where, "creator_id = ?"), append(args, *v)
}
stmt := `DELETE FROM shortcut WHERE ` + strings.Join(where, " AND ")
if _, err := tx.ExecContext(ctx, stmt, args...); err != nil {
return err
}
if err := tx.Commit(); err != nil {
return err
}
s.shortcutCache.Delete(*delete.ID)
return nil
}
func listShortcuts(ctx context.Context, tx *sql.Tx, find *FindShortcut) ([]*Shortcut, error) {
where, args := []string{"1 = 1"}, []any{} where, args := []string{"1 = 1"}, []any{}
if v := find.ID; v != nil { if v := find.ID; v != nil {
@ -251,53 +216,28 @@ func findShortcutList(ctx context.Context, tx *sql.Tx, find *api.ShortcutFind) (
} }
defer rows.Close() defer rows.Close()
shortcutRawList := make([]*shortcutRaw, 0) list := make([]*Shortcut, 0)
for rows.Next() { for rows.Next() {
var shortcutRaw shortcutRaw var shortcut Shortcut
if err := rows.Scan( if err := rows.Scan(
&shortcutRaw.ID, &shortcut.ID,
&shortcutRaw.Title, &shortcut.Title,
&shortcutRaw.Payload, &shortcut.Payload,
&shortcutRaw.CreatorID, &shortcut.CreatorID,
&shortcutRaw.CreatedTs, &shortcut.CreatedTs,
&shortcutRaw.UpdatedTs, &shortcut.UpdatedTs,
&shortcutRaw.RowStatus, &shortcut.RowStatus,
); err != nil { ); err != nil {
return nil, FormatError(err) return nil, FormatError(err)
} }
list = append(list, &shortcut)
shortcutRawList = append(shortcutRawList, &shortcutRaw)
} }
if err := rows.Err(); err != nil { if err := rows.Err(); err != nil {
return nil, FormatError(err) return nil, FormatError(err)
} }
return shortcutRawList, nil return list, nil
}
func deleteShortcut(ctx context.Context, tx *sql.Tx, delete *api.ShortcutDelete) error {
where, args := []string{}, []any{}
if v := delete.ID; v != nil {
where, args = append(where, "id = ?"), append(args, *v)
}
if v := delete.CreatorID; v != nil {
where, args = append(where, "creator_id = ?"), append(args, *v)
}
stmt := `DELETE FROM shortcut WHERE ` + strings.Join(where, " AND ")
result, err := tx.ExecContext(ctx, stmt, args...)
if err != nil {
return FormatError(err)
}
rows, _ := result.RowsAffected()
if rows == 0 {
return &common.Error{Code: common.NotFound, Err: fmt.Errorf("shortcut not found")}
}
return nil
} }
func vacuumShortcut(ctx context.Context, tx *sql.Tx) error { func vacuumShortcut(ctx context.Context, tx *sql.Tx) error {

View File

@ -12,9 +12,8 @@ import (
type Store struct { type Store struct {
Profile *profile.Profile Profile *profile.Profile
db *sql.DB db *sql.DB
systemSettingCache sync.Map // map[string]*systemSettingRaw systemSettingCache sync.Map // map[string]*SystemSetting
userCache sync.Map // map[int]*userRaw userCache sync.Map // map[int]*User
userV1Cache sync.Map // map[string]*User
userSettingCache sync.Map // map[string]*UserSetting userSettingCache sync.Map // map[string]*UserSetting
shortcutCache sync.Map // map[int]*shortcutRaw shortcutCache sync.Map // map[int]*shortcutRaw
idpCache sync.Map // map[int]*IdentityProvider idpCache sync.Map // map[int]*IdentityProvider
@ -36,7 +35,7 @@ func (s *Store) GetDB() *sql.DB {
func (s *Store) Vacuum(ctx context.Context) error { func (s *Store) Vacuum(ctx context.Context) error {
tx, err := s.db.BeginTx(ctx, nil) tx, err := s.db.BeginTx(ctx, nil)
if err != nil { if err != nil {
return FormatError(err) return err
} }
defer tx.Rollback() defer tx.Rollback()
@ -45,7 +44,7 @@ func (s *Store) Vacuum(ctx context.Context) error {
} }
if err := tx.Commit(); err != nil { if err := tx.Commit(); err != nil {
return FormatError(err) return err
} }
// Vacuum sqlite database file size after deleting resource. // Vacuum sqlite database file size after deleting resource.

View File

@ -3,11 +3,7 @@ package store
import ( import (
"context" "context"
"database/sql" "database/sql"
"fmt"
"strings" "strings"
"github.com/usememos/memos/api"
"github.com/usememos/memos/common"
) )
type SystemSetting struct { type SystemSetting struct {
@ -20,10 +16,39 @@ type FindSystemSetting struct {
Name string Name string
} }
func (s *Store) UpsertSystemSetting(ctx context.Context, upsert *SystemSetting) (*SystemSetting, error) {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
defer tx.Rollback()
query := `
INSERT INTO system_setting (
name, value, description
)
VALUES (?, ?, ?)
ON CONFLICT(name) DO UPDATE
SET
value = EXCLUDED.value,
description = EXCLUDED.description
`
if _, err := tx.ExecContext(ctx, query, upsert.Name, upsert.Value, upsert.Description); err != nil {
return nil, err
}
if err := tx.Commit(); err != nil {
return nil, err
}
systemSetting := upsert
return systemSetting, nil
}
func (s *Store) ListSystemSettings(ctx context.Context, find *FindSystemSetting) ([]*SystemSetting, error) { func (s *Store) ListSystemSettings(ctx context.Context, find *FindSystemSetting) ([]*SystemSetting, error) {
tx, err := s.db.BeginTx(ctx, nil) tx, err := s.db.BeginTx(ctx, nil)
if err != nil { if err != nil {
return nil, FormatError(err) return nil, err
} }
defer tx.Rollback() defer tx.Rollback()
@ -47,7 +72,7 @@ func (s *Store) GetSystemSetting(ctx context.Context, find *FindSystemSetting) (
tx, err := s.db.BeginTx(ctx, nil) tx, err := s.db.BeginTx(ctx, nil)
if err != nil { if err != nil {
return nil, FormatError(err) return nil, err
} }
defer tx.Rollback() defer tx.Rollback()
@ -65,6 +90,15 @@ func (s *Store) GetSystemSetting(ctx context.Context, find *FindSystemSetting) (
return systemSettingMessage, nil return systemSettingMessage, nil
} }
func (s *Store) GetSystemSettingValueWithDefault(ctx *context.Context, settingName string, defaultValue string) string {
if setting, err := s.GetSystemSetting(*ctx, &FindSystemSetting{
Name: settingName,
}); err == nil && setting != nil {
return setting.Value
}
return defaultValue
}
func listSystemSettings(ctx context.Context, tx *sql.Tx, find *FindSystemSetting) ([]*SystemSetting, error) { func listSystemSettings(ctx context.Context, tx *sql.Tx, find *FindSystemSetting) ([]*SystemSetting, error) {
where, args := []string{"1 = 1"}, []any{} where, args := []string{"1 = 1"}, []any{}
if find.Name != "" { if find.Name != "" {
@ -81,7 +115,7 @@ func listSystemSettings(ctx context.Context, tx *sql.Tx, find *FindSystemSetting
rows, err := tx.QueryContext(ctx, query, args...) rows, err := tx.QueryContext(ctx, query, args...)
if err != nil { if err != nil {
return nil, FormatError(err) return nil, err
} }
defer rows.Close() defer rows.Close()
@ -93,7 +127,7 @@ func listSystemSettings(ctx context.Context, tx *sql.Tx, find *FindSystemSetting
&systemSettingMessage.Value, &systemSettingMessage.Value,
&systemSettingMessage.Description, &systemSettingMessage.Description,
); err != nil { ); err != nil {
return nil, FormatError(err) return nil, err
} }
list = append(list, systemSettingMessage) list = append(list, systemSettingMessage)
} }
@ -104,160 +138,3 @@ func listSystemSettings(ctx context.Context, tx *sql.Tx, find *FindSystemSetting
return list, nil return list, nil
} }
type systemSettingRaw struct {
Name api.SystemSettingName
Value string
Description string
}
func (raw *systemSettingRaw) toSystemSetting() *api.SystemSetting {
return &api.SystemSetting{
Name: raw.Name,
Value: raw.Value,
Description: raw.Description,
}
}
func (s *Store) UpsertSystemSetting(ctx context.Context, upsert *api.SystemSettingUpsert) (*api.SystemSetting, error) {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, FormatError(err)
}
defer tx.Rollback()
systemSettingRaw, err := upsertSystemSetting(ctx, tx, upsert)
if err != nil {
return nil, err
}
if err := tx.Commit(); err != nil {
return nil, err
}
systemSetting := systemSettingRaw.toSystemSetting()
s.systemSettingCache.Store(systemSettingRaw.Name, systemSettingRaw)
return systemSetting, nil
}
func (s *Store) FindSystemSettingList(ctx context.Context, find *api.SystemSettingFind) ([]*api.SystemSetting, error) {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, FormatError(err)
}
defer tx.Rollback()
systemSettingRawList, err := findSystemSettingList(ctx, tx, find)
if err != nil {
return nil, err
}
if err := tx.Commit(); err != nil {
return nil, err
}
list := []*api.SystemSetting{}
for _, raw := range systemSettingRawList {
s.systemSettingCache.Store(raw.Name, raw)
list = append(list, raw.toSystemSetting())
}
return list, nil
}
func (s *Store) FindSystemSetting(ctx context.Context, find *api.SystemSettingFind) (*api.SystemSetting, error) {
if systemSetting, ok := s.systemSettingCache.Load(find.Name); ok {
systemSettingRaw := systemSetting.(*systemSettingRaw)
return systemSettingRaw.toSystemSetting(), nil
}
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, FormatError(err)
}
defer tx.Rollback()
systemSettingRawList, err := findSystemSettingList(ctx, tx, find)
if err != nil {
return nil, err
}
if len(systemSettingRawList) == 0 {
return nil, &common.Error{Code: common.NotFound, Err: fmt.Errorf("not found")}
}
systemSettingRaw := systemSettingRawList[0]
s.systemSettingCache.Store(systemSettingRaw.Name, systemSettingRaw)
return systemSettingRaw.toSystemSetting(), nil
}
func (s *Store) GetSystemSettingValueOrDefault(ctx *context.Context, find api.SystemSettingName, defaultValue string) string {
if setting, err := s.FindSystemSetting(*ctx, &api.SystemSettingFind{
Name: find,
}); err == nil {
return setting.Value
}
return defaultValue
}
func upsertSystemSetting(ctx context.Context, tx *sql.Tx, upsert *api.SystemSettingUpsert) (*systemSettingRaw, error) {
query := `
INSERT INTO system_setting (
name, value, description
)
VALUES (?, ?, ?)
ON CONFLICT(name) DO UPDATE
SET
value = EXCLUDED.value,
description = EXCLUDED.description
RETURNING name, value, description
`
var systemSettingRaw systemSettingRaw
if err := tx.QueryRowContext(ctx, query, upsert.Name, upsert.Value, upsert.Description).Scan(
&systemSettingRaw.Name,
&systemSettingRaw.Value,
&systemSettingRaw.Description,
); err != nil {
return nil, FormatError(err)
}
return &systemSettingRaw, nil
}
func findSystemSettingList(ctx context.Context, tx *sql.Tx, find *api.SystemSettingFind) ([]*systemSettingRaw, error) {
where, args := []string{"1 = 1"}, []any{}
if find.Name.String() != "" {
where, args = append(where, "name = ?"), append(args, find.Name.String())
}
query := `
SELECT
name,
value,
description
FROM system_setting
WHERE ` + strings.Join(where, " AND ")
rows, err := tx.QueryContext(ctx, query, args...)
if err != nil {
return nil, FormatError(err)
}
defer rows.Close()
systemSettingRawList := make([]*systemSettingRaw, 0)
for rows.Next() {
var systemSettingRaw systemSettingRaw
if err := rows.Scan(
&systemSettingRaw.Name,
&systemSettingRaw.Value,
&systemSettingRaw.Description,
); err != nil {
return nil, FormatError(err)
}
systemSettingRawList = append(systemSettingRawList, &systemSettingRaw)
}
if err := rows.Err(); err != nil {
return nil, FormatError(err)
}
return systemSettingRawList, nil
}

View File

@ -5,83 +5,29 @@ import (
"database/sql" "database/sql"
"fmt" "fmt"
"strings" "strings"
"github.com/usememos/memos/api"
"github.com/usememos/memos/common"
) )
type tagRaw struct { type Tag struct {
Name string Name string
CreatorID int CreatorID int
} }
func (raw *tagRaw) toTag() *api.Tag { type FindTag struct {
return &api.Tag{ CreatorID int
Name: raw.Name,
CreatorID: raw.CreatorID,
}
} }
func (s *Store) UpsertTag(ctx context.Context, upsert *api.TagUpsert) (*api.Tag, error) { type DeleteTag struct {
Name string
CreatorID int
}
func (s *Store) UpsertTagV1(ctx context.Context, upsert *Tag) (*Tag, error) {
tx, err := s.db.BeginTx(ctx, nil) tx, err := s.db.BeginTx(ctx, nil)
if err != nil { if err != nil {
return nil, FormatError(err) return nil, FormatError(err)
} }
defer tx.Rollback() defer tx.Rollback()
tagRaw, err := upsertTag(ctx, tx, upsert)
if err != nil {
return nil, err
}
if err := tx.Commit(); err != nil {
return nil, err
}
tag := tagRaw.toTag()
return tag, nil
}
func (s *Store) FindTagList(ctx context.Context, find *api.TagFind) ([]*api.Tag, error) {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, FormatError(err)
}
defer tx.Rollback()
tagRawList, err := findTagList(ctx, tx, find)
if err != nil {
return nil, err
}
list := []*api.Tag{}
for _, raw := range tagRawList {
list = append(list, raw.toTag())
}
return list, nil
}
func (s *Store) DeleteTag(ctx context.Context, delete *api.TagDelete) error {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return FormatError(err)
}
defer tx.Rollback()
if err := deleteTag(ctx, tx, delete); err != nil {
return FormatError(err)
}
if err := tx.Commit(); err != nil {
return FormatError(err)
}
return nil
}
func upsertTag(ctx context.Context, tx *sql.Tx, upsert *api.TagUpsert) (*tagRaw, error) {
query := ` query := `
INSERT INTO tag ( INSERT INTO tag (
name, creator_id name, creator_id
@ -90,22 +36,27 @@ func upsertTag(ctx context.Context, tx *sql.Tx, upsert *api.TagUpsert) (*tagRaw,
ON CONFLICT(name, creator_id) DO UPDATE ON CONFLICT(name, creator_id) DO UPDATE
SET SET
name = EXCLUDED.name name = EXCLUDED.name
RETURNING name, creator_id
` `
var tagRaw tagRaw if _, err := tx.ExecContext(ctx, query, upsert.Name, upsert.CreatorID); err != nil {
if err := tx.QueryRowContext(ctx, query, upsert.Name, upsert.CreatorID).Scan( return nil, err
&tagRaw.Name,
&tagRaw.CreatorID,
); err != nil {
return nil, FormatError(err)
} }
return &tagRaw, nil if err := tx.Commit(); err != nil {
return nil, err
}
tag := upsert
return tag, nil
} }
func findTagList(ctx context.Context, tx *sql.Tx, find *api.TagFind) ([]*tagRaw, error) { func (s *Store) ListTags(ctx context.Context, find *FindTag) ([]*Tag, error) {
where, args := []string{"creator_id = ?"}, []any{find.CreatorID} tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, FormatError(err)
}
defer tx.Rollback()
where, args := []string{"creator_id = ?"}, []any{find.CreatorID}
query := ` query := `
SELECT SELECT
name, name,
@ -120,38 +71,48 @@ func findTagList(ctx context.Context, tx *sql.Tx, find *api.TagFind) ([]*tagRaw,
} }
defer rows.Close() defer rows.Close()
tagRawList := make([]*tagRaw, 0) list := []*Tag{}
for rows.Next() { for rows.Next() {
var tagRaw tagRaw tag := &Tag{}
if err := rows.Scan( if err := rows.Scan(
&tagRaw.Name, &tag.Name,
&tagRaw.CreatorID, &tag.CreatorID,
); err != nil { ); err != nil {
return nil, FormatError(err) return nil, FormatError(err)
} }
tagRawList = append(tagRawList, &tagRaw) list = append(list, tag)
} }
if err := rows.Err(); err != nil { if err := rows.Err(); err != nil {
return nil, FormatError(err) return nil, FormatError(err)
} }
return tagRawList, nil return list, nil
} }
func deleteTag(ctx context.Context, tx *sql.Tx, delete *api.TagDelete) error { func (s *Store) DeleteTag(ctx context.Context, delete *DeleteTag) error {
where, args := []string{"name = ?", "creator_id = ?"}, []any{delete.Name, delete.CreatorID} tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return FormatError(err)
}
defer tx.Rollback()
stmt := `DELETE FROM tag WHERE ` + strings.Join(where, " AND ") where, args := []string{"name = ?", "creator_id = ?"}, []any{delete.Name, delete.CreatorID}
result, err := tx.ExecContext(ctx, stmt, args...) query := `DELETE FROM tag WHERE ` + strings.Join(where, " AND ")
result, err := tx.ExecContext(ctx, query, args...)
if err != nil { if err != nil {
return FormatError(err) return FormatError(err)
} }
rows, _ := result.RowsAffected() rows, _ := result.RowsAffected()
if rows == 0 { if rows == 0 {
return &common.Error{Code: common.NotFound, Err: fmt.Errorf("tag not found")} return fmt.Errorf("tag not found")
}
if err := tx.Commit(); err != nil {
// Prevent linter warning.
return err
} }
return nil return nil

View File

@ -3,32 +3,29 @@ package store
import ( import (
"context" "context"
"database/sql" "database/sql"
"fmt" "errors"
"strings" "strings"
"github.com/usememos/memos/api"
"github.com/usememos/memos/common"
) )
// Role is the type of a role. // Role is the type of a role.
type Role string type Role string
const ( const (
// Host is the HOST role. // RoleHost is the HOST role.
Host Role = "HOST" RoleHost Role = "HOST"
// Admin is the ADMIN role. // RoleAdmin is the ADMIN role.
Admin Role = "ADMIN" RoleAdmin Role = "ADMIN"
// NormalUser is the USER role. // RoleUser is the USER role.
NormalUser Role = "USER" RoleUser Role = "USER"
) )
func (e Role) String() string { func (e Role) String() string {
switch e { switch e {
case Host: case RoleHost:
return "HOST" return "HOST"
case Admin: case RoleAdmin:
return "ADMIN" return "ADMIN"
case NormalUser: case RoleUser:
return "USER" return "USER"
} }
return "USER" return "USER"
@ -81,7 +78,11 @@ type FindUser struct {
OpenID *string OpenID *string
} }
func (s *Store) CreateUserV1(ctx context.Context, create *User) (*User, error) { type DeleteUser struct {
ID int
}
func (s *Store) CreateUser(ctx context.Context, create *User) (*User, error) {
tx, err := s.db.BeginTx(ctx, nil) tx, err := s.db.BeginTx(ctx, nil)
if err != nil { if err != nil {
return nil, err return nil, err
@ -120,7 +121,7 @@ func (s *Store) CreateUserV1(ctx context.Context, create *User) (*User, error) {
return nil, err return nil, err
} }
user := create user := create
s.userV1Cache.Store(user.ID, user) s.userCache.Store(user.ID, user)
return user, nil return user, nil
} }
@ -185,7 +186,7 @@ func (s *Store) UpdateUser(ctx context.Context, update *UpdateUser) (*User, erro
return nil, err return nil, err
} }
s.userV1Cache.Store(user.ID, user) s.userCache.Store(user.ID, user)
return user, nil return user, nil
} }
@ -202,15 +203,15 @@ func (s *Store) ListUsers(ctx context.Context, find *FindUser) ([]*User, error)
} }
for _, user := range list { for _, user := range list {
s.userV1Cache.Store(user.ID, user) s.userCache.Store(user.ID, user)
} }
return list, nil return list, nil
} }
func (s *Store) GetUser(ctx context.Context, find *FindUser) (*User, error) { func (s *Store) GetUser(ctx context.Context, find *FindUser) (*User, error) {
if find.ID != nil { if find.ID != nil {
if user, ok := s.userV1Cache.Load(*find.ID); ok { if cache, ok := s.userCache.Load(*find.ID); ok {
return user.(*User), nil return cache.(*User), nil
} }
} }
@ -228,10 +229,43 @@ func (s *Store) GetUser(ctx context.Context, find *FindUser) (*User, error) {
return nil, nil return nil, nil
} }
user := list[0] user := list[0]
s.userV1Cache.Store(user.ID, user) s.userCache.Store(user.ID, user)
return user, nil return user, nil
} }
func (s *Store) DeleteUser(ctx context.Context, delete *DeleteUser) error {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return err
}
defer tx.Rollback()
result, err := tx.ExecContext(ctx, `
DELETE FROM user WHERE id = ?
`, delete.ID)
if err != nil {
return err
}
rows, err := result.RowsAffected()
if err != nil {
return err
}
if rows == 0 {
return errors.New("user not found")
}
if err := s.vacuumImpl(ctx, tx); err != nil {
return err
}
if err := tx.Commit(); err != nil {
return err
}
s.userCache.Delete(delete.ID)
return nil
}
func listUsers(ctx context.Context, tx *sql.Tx, find *FindUser) ([]*User, error) { func listUsers(ctx context.Context, tx *sql.Tx, find *FindUser) ([]*User, error) {
where, args := []string{"1 = 1"}, []any{} where, args := []string{"1 = 1"}, []any{}
@ -304,342 +338,3 @@ func listUsers(ctx context.Context, tx *sql.Tx, find *FindUser) ([]*User, error)
return list, nil return list, nil
} }
// userRaw is the store model for an User.
// Fields have exactly the same meanings as User.
type userRaw struct {
ID int
// Standard fields
RowStatus api.RowStatus
CreatedTs int64
UpdatedTs int64
// Domain specific fields
Username string
Role api.Role
Email string
Nickname string
PasswordHash string
OpenID string
AvatarURL string
}
func (raw *userRaw) toUser() *api.User {
return &api.User{
ID: raw.ID,
RowStatus: raw.RowStatus,
CreatedTs: raw.CreatedTs,
UpdatedTs: raw.UpdatedTs,
Username: raw.Username,
Role: raw.Role,
Email: raw.Email,
Nickname: raw.Nickname,
PasswordHash: raw.PasswordHash,
OpenID: raw.OpenID,
AvatarURL: raw.AvatarURL,
}
}
func (s *Store) CreateUser(ctx context.Context, create *api.UserCreate) (*api.User, error) {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, FormatError(err)
}
defer tx.Rollback()
userRaw, err := createUser(ctx, tx, create)
if err != nil {
return nil, err
}
if err := tx.Commit(); err != nil {
return nil, FormatError(err)
}
s.userCache.Store(userRaw.ID, userRaw)
user := userRaw.toUser()
return user, nil
}
func (s *Store) PatchUser(ctx context.Context, patch *api.UserPatch) (*api.User, error) {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, FormatError(err)
}
defer tx.Rollback()
userRaw, err := patchUser(ctx, tx, patch)
if err != nil {
return nil, err
}
if err := tx.Commit(); err != nil {
return nil, FormatError(err)
}
s.userCache.Store(userRaw.ID, userRaw)
user := userRaw.toUser()
return user, nil
}
func (s *Store) FindUserList(ctx context.Context, find *api.UserFind) ([]*api.User, error) {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, FormatError(err)
}
defer tx.Rollback()
userRawList, err := findUserList(ctx, tx, find)
if err != nil {
return nil, err
}
list := []*api.User{}
for _, raw := range userRawList {
list = append(list, raw.toUser())
}
return list, nil
}
func (s *Store) FindUser(ctx context.Context, find *api.UserFind) (*api.User, error) {
if find.ID != nil {
if user, ok := s.userCache.Load(*find.ID); ok {
return user.(*userRaw).toUser(), nil
}
}
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, FormatError(err)
}
defer tx.Rollback()
list, err := findUserList(ctx, tx, find)
if err != nil {
return nil, err
}
if len(list) == 0 {
return nil, &common.Error{Code: common.NotFound, Err: fmt.Errorf("not found user with filter %+v", find)}
}
userRaw := list[0]
s.userCache.Store(userRaw.ID, userRaw)
user := userRaw.toUser()
return user, nil
}
func (s *Store) DeleteUser(ctx context.Context, delete *api.UserDelete) error {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return FormatError(err)
}
defer tx.Rollback()
if err := deleteUser(ctx, tx, delete); err != nil {
return err
}
if err := s.vacuumImpl(ctx, tx); err != nil {
return err
}
if err := tx.Commit(); err != nil {
return err
}
s.userCache.Delete(delete.ID)
return nil
}
func createUser(ctx context.Context, tx *sql.Tx, create *api.UserCreate) (*userRaw, error) {
query := `
INSERT INTO user (
username,
role,
email,
nickname,
password_hash,
open_id
)
VALUES (?, ?, ?, ?, ?, ?)
RETURNING id, username, role, email, nickname, password_hash, open_id, avatar_url, created_ts, updated_ts, row_status
`
var userRaw userRaw
if err := tx.QueryRowContext(ctx, query,
create.Username,
create.Role,
create.Email,
create.Nickname,
create.PasswordHash,
create.OpenID,
).Scan(
&userRaw.ID,
&userRaw.Username,
&userRaw.Role,
&userRaw.Email,
&userRaw.Nickname,
&userRaw.PasswordHash,
&userRaw.OpenID,
&userRaw.AvatarURL,
&userRaw.CreatedTs,
&userRaw.UpdatedTs,
&userRaw.RowStatus,
); err != nil {
return nil, FormatError(err)
}
return &userRaw, nil
}
func patchUser(ctx context.Context, tx *sql.Tx, patch *api.UserPatch) (*userRaw, error) {
set, args := []string{}, []any{}
if v := patch.UpdatedTs; v != nil {
set, args = append(set, "updated_ts = ?"), append(args, *v)
}
if v := patch.RowStatus; v != nil {
set, args = append(set, "row_status = ?"), append(args, *v)
}
if v := patch.Username; v != nil {
set, args = append(set, "username = ?"), append(args, *v)
}
if v := patch.Email; v != nil {
set, args = append(set, "email = ?"), append(args, *v)
}
if v := patch.Nickname; v != nil {
set, args = append(set, "nickname = ?"), append(args, *v)
}
if v := patch.AvatarURL; v != nil {
set, args = append(set, "avatar_url = ?"), append(args, *v)
}
if v := patch.PasswordHash; v != nil {
set, args = append(set, "password_hash = ?"), append(args, *v)
}
if v := patch.OpenID; v != nil {
set, args = append(set, "open_id = ?"), append(args, *v)
}
args = append(args, patch.ID)
query := `
UPDATE user
SET ` + strings.Join(set, ", ") + `
WHERE id = ?
RETURNING id, username, role, email, nickname, password_hash, open_id, avatar_url, created_ts, updated_ts, row_status
`
var userRaw userRaw
if err := tx.QueryRowContext(ctx, query, args...).Scan(
&userRaw.ID,
&userRaw.Username,
&userRaw.Role,
&userRaw.Email,
&userRaw.Nickname,
&userRaw.PasswordHash,
&userRaw.OpenID,
&userRaw.AvatarURL,
&userRaw.CreatedTs,
&userRaw.UpdatedTs,
&userRaw.RowStatus,
); err != nil {
return nil, FormatError(err)
}
return &userRaw, nil
}
func findUserList(ctx context.Context, tx *sql.Tx, find *api.UserFind) ([]*userRaw, error) {
where, args := []string{"1 = 1"}, []any{}
if v := find.ID; v != nil {
where, args = append(where, "id = ?"), append(args, *v)
}
if v := find.Username; v != nil {
where, args = append(where, "username = ?"), append(args, *v)
}
if v := find.Role; v != nil {
where, args = append(where, "role = ?"), append(args, *v)
}
if v := find.Email; v != nil {
where, args = append(where, "email = ?"), append(args, *v)
}
if v := find.Nickname; v != nil {
where, args = append(where, "nickname = ?"), append(args, *v)
}
if v := find.OpenID; v != nil {
where, args = append(where, "open_id = ?"), append(args, *v)
}
query := `
SELECT
id,
username,
role,
email,
nickname,
password_hash,
open_id,
avatar_url,
created_ts,
updated_ts,
row_status
FROM user
WHERE ` + strings.Join(where, " AND ") + `
ORDER BY created_ts DESC, row_status DESC
`
rows, err := tx.QueryContext(ctx, query, args...)
if err != nil {
return nil, FormatError(err)
}
defer rows.Close()
userRawList := make([]*userRaw, 0)
for rows.Next() {
var userRaw userRaw
if err := rows.Scan(
&userRaw.ID,
&userRaw.Username,
&userRaw.Role,
&userRaw.Email,
&userRaw.Nickname,
&userRaw.PasswordHash,
&userRaw.OpenID,
&userRaw.AvatarURL,
&userRaw.CreatedTs,
&userRaw.UpdatedTs,
&userRaw.RowStatus,
); err != nil {
return nil, FormatError(err)
}
userRawList = append(userRawList, &userRaw)
}
if err := rows.Err(); err != nil {
return nil, FormatError(err)
}
return userRawList, nil
}
func deleteUser(ctx context.Context, tx *sql.Tx, delete *api.UserDelete) error {
result, err := tx.ExecContext(ctx, `
DELETE FROM user WHERE id = ?
`, delete.ID)
if err != nil {
return FormatError(err)
}
rows, err := result.RowsAffected()
if err != nil {
return err
}
if rows == 0 {
return &common.Error{Code: common.NotFound, Err: fmt.Errorf("user not found")}
}
return nil
}

View File

@ -20,7 +20,7 @@ type FindUserSetting struct {
func (s *Store) UpsertUserSetting(ctx context.Context, upsert *UserSetting) (*UserSetting, error) { func (s *Store) UpsertUserSetting(ctx context.Context, upsert *UserSetting) (*UserSetting, error) {
tx, err := s.db.BeginTx(ctx, nil) tx, err := s.db.BeginTx(ctx, nil)
if err != nil { if err != nil {
return nil, FormatError(err) return nil, err
} }
defer tx.Rollback() defer tx.Rollback()
@ -41,14 +41,14 @@ func (s *Store) UpsertUserSetting(ctx context.Context, upsert *UserSetting) (*Us
} }
userSetting := upsert userSetting := upsert
s.userSettingCache.Store(getUserSettingCacheKeyV1(userSetting.UserID, userSetting.Key), userSetting) s.userSettingCache.Store(getUserSettingCacheKey(userSetting.UserID, userSetting.Key), userSetting)
return userSetting, nil return userSetting, nil
} }
func (s *Store) ListUserSettings(ctx context.Context, find *FindUserSetting) ([]*UserSetting, error) { func (s *Store) ListUserSettings(ctx context.Context, find *FindUserSetting) ([]*UserSetting, error) {
tx, err := s.db.BeginTx(ctx, nil) tx, err := s.db.BeginTx(ctx, nil)
if err != nil { if err != nil {
return nil, FormatError(err) return nil, err
} }
defer tx.Rollback() defer tx.Rollback()
@ -58,21 +58,21 @@ func (s *Store) ListUserSettings(ctx context.Context, find *FindUserSetting) ([]
} }
for _, userSetting := range userSettingList { for _, userSetting := range userSettingList {
s.userSettingCache.Store(getUserSettingCacheKeyV1(userSetting.UserID, userSetting.Key), userSetting) s.userSettingCache.Store(getUserSettingCacheKey(userSetting.UserID, userSetting.Key), userSetting)
} }
return userSettingList, nil return userSettingList, nil
} }
func (s *Store) GetUserSetting(ctx context.Context, find *FindUserSetting) (*UserSetting, error) { func (s *Store) GetUserSetting(ctx context.Context, find *FindUserSetting) (*UserSetting, error) {
if find.UserID != nil { if find.UserID != nil {
if cache, ok := s.userSettingCache.Load(getUserSettingCacheKeyV1(*find.UserID, find.Key)); ok { if cache, ok := s.userSettingCache.Load(getUserSettingCacheKey(*find.UserID, find.Key)); ok {
return cache.(*UserSetting), nil return cache.(*UserSetting), nil
} }
} }
tx, err := s.db.BeginTx(ctx, nil) tx, err := s.db.BeginTx(ctx, nil)
if err != nil { if err != nil {
return nil, FormatError(err) return nil, err
} }
defer tx.Rollback() defer tx.Rollback()
@ -84,8 +84,9 @@ func (s *Store) GetUserSetting(ctx context.Context, find *FindUserSetting) (*Use
if len(list) == 0 { if len(list) == 0 {
return nil, nil return nil, nil
} }
userSetting := list[0] userSetting := list[0]
s.userSettingCache.Store(getUserSettingCacheKeyV1(userSetting.UserID, userSetting.Key), userSetting) s.userSettingCache.Store(getUserSettingCacheKey(userSetting.UserID, userSetting.Key), userSetting)
return userSetting, nil return userSetting, nil
} }
@ -108,7 +109,7 @@ func listUserSettings(ctx context.Context, tx *sql.Tx, find *FindUserSetting) ([
WHERE ` + strings.Join(where, " AND ") WHERE ` + strings.Join(where, " AND ")
rows, err := tx.QueryContext(ctx, query, args...) rows, err := tx.QueryContext(ctx, query, args...)
if err != nil { if err != nil {
return nil, FormatError(err) return nil, err
} }
defer rows.Close() defer rows.Close()
@ -120,13 +121,13 @@ func listUserSettings(ctx context.Context, tx *sql.Tx, find *FindUserSetting) ([
&userSetting.Key, &userSetting.Key,
&userSetting.Value, &userSetting.Value,
); err != nil { ); err != nil {
return nil, FormatError(err) return nil, err
} }
userSettingList = append(userSettingList, &userSetting) userSettingList = append(userSettingList, &userSetting)
} }
if err := rows.Err(); err != nil { if err := rows.Err(); err != nil {
return nil, FormatError(err) return nil, err
} }
return userSettingList, nil return userSettingList, nil
@ -145,7 +146,7 @@ func vacuumUserSetting(ctx context.Context, tx *sql.Tx) error {
)` )`
_, err := tx.ExecContext(ctx, stmt) _, err := tx.ExecContext(ctx, stmt)
if err != nil { if err != nil {
return FormatError(err) return err
} }
return nil return nil

View File

@ -8,7 +8,6 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/usememos/memos/api"
apiv1 "github.com/usememos/memos/api/v1" apiv1 "github.com/usememos/memos/api/v1"
) )
@ -27,7 +26,7 @@ func TestAuthServer(t *testing.T) {
require.Equal(t, signup.Username, user.Username) require.Equal(t, signup.Username, user.Username)
} }
func (s *TestingServer) postAuthSignup(signup *apiv1.SignUp) (*api.User, error) { func (s *TestingServer) postAuthSignup(signup *apiv1.SignUp) (*apiv1.User, error) {
rawData, err := json.Marshal(&signup) rawData, err := json.Marshal(&signup)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "failed to marshal signup") return nil, errors.Wrap(err, "failed to marshal signup")
@ -44,7 +43,7 @@ func (s *TestingServer) postAuthSignup(signup *apiv1.SignUp) (*api.User, error)
return nil, errors.Wrap(err, "fail to read response body") return nil, errors.Wrap(err, "fail to read response body")
} }
user := &api.User{} user := &apiv1.User{}
if err = json.Unmarshal(buf.Bytes(), user); err != nil { if err = json.Unmarshal(buf.Bytes(), user); err != nil {
return nil, errors.Wrap(err, "fail to unmarshal post signup response") return nil, errors.Wrap(err, "fail to unmarshal post signup response")
} }

View File

@ -8,7 +8,6 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/usememos/memos/api"
apiv1 "github.com/usememos/memos/api/v1" apiv1 "github.com/usememos/memos/api/v1"
) )
@ -20,7 +19,7 @@ func TestSystemServer(t *testing.T) {
status, err := s.getSystemStatus() status, err := s.getSystemStatus()
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, (*api.User)(nil), status.Host) require.Equal(t, (*apiv1.User)(nil), status.Host)
signup := &apiv1.SignUp{ signup := &apiv1.SignUp{
Username: "testuser", Username: "testuser",
@ -36,8 +35,8 @@ func TestSystemServer(t *testing.T) {
require.Equal(t, user.Username, status.Host.Username) require.Equal(t, user.Username, status.Host.Username)
} }
func (s *TestingServer) getSystemStatus() (*api.SystemStatus, error) { func (s *TestingServer) getSystemStatus() (*apiv1.SystemStatus, error) {
body, err := s.get("/api/status", nil) body, err := s.get("/api/v1/status", nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -48,12 +47,9 @@ func (s *TestingServer) getSystemStatus() (*api.SystemStatus, error) {
return nil, errors.Wrap(err, "fail to read response body") return nil, errors.Wrap(err, "fail to read response body")
} }
type SystemStatusResponse struct { systemStatus := &apiv1.SystemStatus{}
Data *api.SystemStatus `json:"data"` if err = json.Unmarshal(buf.Bytes(), systemStatus); err != nil {
}
res := new(SystemStatusResponse)
if err = json.Unmarshal(buf.Bytes(), res); err != nil {
return nil, errors.Wrap(err, "fail to unmarshal get system status response") return nil, errors.Wrap(err, "fail to unmarshal get system status response")
} }
return res.Data, nil return systemStatus, nil
} }

View File

@ -14,7 +14,7 @@ func TestIdentityProviderStore(t *testing.T) {
ts := NewTestingStore(ctx, t) ts := NewTestingStore(ctx, t)
createdIDP, err := ts.CreateIdentityProvider(ctx, &store.IdentityProvider{ createdIDP, err := ts.CreateIdentityProvider(ctx, &store.IdentityProvider{
Name: "GitHub OAuth", Name: "GitHub OAuth",
Type: store.IdentityProviderOAuth2, Type: store.IdentityProviderOAuth2Type,
IdentifierFilter: "", IdentifierFilter: "",
Config: &store.IdentityProviderConfig{ Config: &store.IdentityProviderConfig{
OAuth2Config: &store.IdentityProviderOAuth2Config{ OAuth2Config: &store.IdentityProviderOAuth2Config{

View File

@ -6,30 +6,30 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/usememos/memos/api" apiv1 "github.com/usememos/memos/api/v1"
"github.com/usememos/memos/store" "github.com/usememos/memos/store"
) )
func TestSystemSettingStore(t *testing.T) { func TestSystemSettingStore(t *testing.T) {
ctx := context.Background() ctx := context.Background()
ts := NewTestingStore(ctx, t) ts := NewTestingStore(ctx, t)
_, err := ts.UpsertSystemSetting(ctx, &api.SystemSettingUpsert{ _, err := ts.UpsertSystemSetting(ctx, &store.SystemSetting{
Name: api.SystemSettingServerIDName, Name: apiv1.SystemSettingServerIDName.String(),
Value: "test_server_id", Value: "test_server_id",
}) })
require.NoError(t, err) require.NoError(t, err)
_, err = ts.UpsertSystemSetting(ctx, &api.SystemSettingUpsert{ _, err = ts.UpsertSystemSetting(ctx, &store.SystemSetting{
Name: api.SystemSettingSecretSessionName, Name: apiv1.SystemSettingSecretSessionName.String(),
Value: "test_secret_session_name", Value: "test_secret_session_name",
}) })
require.NoError(t, err) require.NoError(t, err)
_, err = ts.UpsertSystemSetting(ctx, &api.SystemSettingUpsert{ _, err = ts.UpsertSystemSetting(ctx, &store.SystemSetting{
Name: api.SystemSettingAllowSignUpName, Name: apiv1.SystemSettingAllowSignUpName.String(),
Value: "true", Value: "true",
}) })
require.NoError(t, err) require.NoError(t, err)
_, err = ts.UpsertSystemSetting(ctx, &api.SystemSettingUpsert{ _, err = ts.UpsertSystemSetting(ctx, &store.SystemSetting{
Name: api.SystemSettingLocalStoragePathName, Name: apiv1.SystemSettingLocalStoragePathName.String(),
Value: "/tmp/memos", Value: "/tmp/memos",
}) })
require.NoError(t, err) require.NoError(t, err)

View File

@ -13,15 +13,21 @@ func TestUserSettingStore(t *testing.T) {
ts := NewTestingStore(ctx, t) ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts) user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err) require.NoError(t, err)
_, err = ts.UpsertUserSetting(ctx, &store.UserSetting{ testSetting, err := ts.UpsertUserSetting(ctx, &store.UserSetting{
UserID: user.ID, UserID: user.ID,
Key: "test_key", Key: "test_key",
Value: "test_value", Value: "test_value",
}) })
require.NoError(t, err) require.NoError(t, err)
localeSetting, err := ts.UpsertUserSetting(ctx, &store.UserSetting{
UserID: user.ID,
Key: "locale",
Value: "zh",
})
require.NoError(t, err)
list, err := ts.ListUserSettings(ctx, &store.FindUserSetting{}) list, err := ts.ListUserSettings(ctx, &store.FindUserSetting{})
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 1, len(list)) require.Equal(t, 2, len(list))
require.Equal(t, "test_key", list[0].Key) require.Equal(t, testSetting, list[0])
require.Equal(t, "test_value", list[0].Value) require.Equal(t, localeSetting, list[1])
} }

View File

@ -5,7 +5,6 @@ import (
"testing" "testing"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/usememos/memos/api"
"github.com/usememos/memos/store" "github.com/usememos/memos/store"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
) )
@ -18,7 +17,7 @@ func TestUserStore(t *testing.T) {
users, err := ts.ListUsers(ctx, &store.FindUser{}) users, err := ts.ListUsers(ctx, &store.FindUser{})
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 1, len(users)) require.Equal(t, 1, len(users))
require.Equal(t, store.Host, users[0].Role) require.Equal(t, store.RoleHost, users[0].Role)
require.Equal(t, user, users[0]) require.Equal(t, user, users[0])
userPatchNickname := "test_nickname_2" userPatchNickname := "test_nickname_2"
userPatch := &store.UpdateUser{ userPatch := &store.UpdateUser{
@ -28,7 +27,7 @@ func TestUserStore(t *testing.T) {
user, err = ts.UpdateUser(ctx, userPatch) user, err = ts.UpdateUser(ctx, userPatch)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, userPatchNickname, user.Nickname) require.Equal(t, userPatchNickname, user.Nickname)
err = ts.DeleteUser(ctx, &api.UserDelete{ err = ts.DeleteUser(ctx, &store.DeleteUser{
ID: user.ID, ID: user.ID,
}) })
require.NoError(t, err) require.NoError(t, err)
@ -40,7 +39,7 @@ func TestUserStore(t *testing.T) {
func createTestingHostUser(ctx context.Context, ts *store.Store) (*store.User, error) { func createTestingHostUser(ctx context.Context, ts *store.Store) (*store.User, error) {
userCreate := &store.User{ userCreate := &store.User{
Username: "test", Username: "test",
Role: store.Host, Role: store.RoleHost,
Email: "test@test.com", Email: "test@test.com",
Nickname: "test_nickname", Nickname: "test_nickname",
OpenID: "test_open_id", OpenID: "test_open_id",
@ -50,6 +49,6 @@ func createTestingHostUser(ctx context.Context, ts *store.Store) (*store.User, e
return nil, err return nil, err
} }
userCreate.PasswordHash = string(passwordHash) userCreate.PasswordHash = string(passwordHash)
user, err := ts.CreateUserV1(ctx, userCreate) user, err := ts.CreateUser(ctx, userCreate)
return user, err return user, err
} }

View File

@ -31,7 +31,7 @@ const CreateTagDialog: React.FC<Props> = (props: Props) => {
useEffect(() => { useEffect(() => {
getTagSuggestionList().then(({ data }) => { getTagSuggestionList().then(({ data }) => {
setSuggestTagNameList(data.data.filter((tag) => validateTagName(tag))); setSuggestTagNameList(data.filter((tag) => validateTagName(tag)));
}); });
}, [tagNameList]); }, [tagNameList]);

View File

@ -29,7 +29,7 @@ const PreferencesSection = () => {
}, []); }, []);
const fetchUserList = async () => { const fetchUserList = async () => {
const { data } = (await api.getUserList()).data; const { data } = await api.getUserList();
setUserList(data); setUserList(data);
}; };

View File

@ -39,7 +39,7 @@ const SystemSection = () => {
}, []); }, []);
useEffect(() => { useEffect(() => {
api.getSystemSetting().then(({ data: { data: systemSettings } }) => { api.getSystemSetting().then(({ data: systemSettings }) => {
const telegramBotSetting = systemSettings.find((setting) => setting.name === "telegram-bot-token"); const telegramBotSetting = systemSettings.find((setting) => setting.name === "telegram-bot-token");
if (telegramBotSetting) { if (telegramBotSetting) {
setTelegramBotToken(telegramBotSetting.value); setTelegramBotToken(telegramBotSetting.value);

View File

@ -7,19 +7,19 @@ type ResponseObject<T> = {
}; };
export function getSystemStatus() { export function getSystemStatus() {
return axios.get<ResponseObject<SystemStatus>>("/api/status"); return axios.get<SystemStatus>("/api/v1/status");
} }
export function getSystemSetting() { export function getSystemSetting() {
return axios.get<ResponseObject<SystemSetting[]>>("/api/system/setting"); return axios.get<SystemSetting[]>("/api/v1/system/setting");
} }
export function upsertSystemSetting(systemSetting: SystemSetting) { export function upsertSystemSetting(systemSetting: SystemSetting) {
return axios.post<ResponseObject<SystemSetting>>("/api/system/setting", systemSetting); return axios.post<SystemSetting>("/api/v1/system/setting", systemSetting);
} }
export function vacuumDatabase() { export function vacuumDatabase() {
return axios.post("/api/system/vacuum"); return axios.post("/api/v1/system/vacuum");
} }
export function signin(username: string, password: string) { export function signin(username: string, password: string) {
@ -49,31 +49,31 @@ export function signout() {
} }
export function createUser(userCreate: UserCreate) { export function createUser(userCreate: UserCreate) {
return axios.post<ResponseObject<User>>("/api/user", userCreate); return axios.post<User>("/api/v1/user", userCreate);
} }
export function getMyselfUser() { export function getMyselfUser() {
return axios.get<ResponseObject<User>>("/api/user/me"); return axios.get<User>("/api/v1/user/me");
} }
export function getUserList() { export function getUserList() {
return axios.get<ResponseObject<User[]>>("/api/user"); return axios.get<User[]>("/api/v1/user");
} }
export function getUserById(id: number) { export function getUserById(id: number) {
return axios.get<ResponseObject<User>>(`/api/user/${id}`); return axios.get<User>(`/api/v1/user/${id}`);
} }
export function upsertUserSetting(upsert: UserSettingUpsert) { export function upsertUserSetting(upsert: UserSettingUpsert) {
return axios.post<ResponseObject<UserSetting>>(`/api/user/setting`, upsert); return axios.post<UserSetting>(`/api/v1/user/setting`, upsert);
} }
export function patchUser(userPatch: UserPatch) { export function patchUser(userPatch: UserPatch) {
return axios.patch<ResponseObject<User>>(`/api/user/${userPatch.id}`, userPatch); return axios.patch<User>(`/api/v1/user/${userPatch.id}`, userPatch);
} }
export function deleteUser(userDelete: UserDelete) { export function deleteUser(userDelete: UserDelete) {
return axios.delete(`/api/user/${userDelete.id}`); return axios.delete(`/api/v1/user/${userDelete.id}`);
} }
export function getAllMemos(memoFind?: MemoFind) { export function getAllMemos(memoFind?: MemoFind) {
@ -145,19 +145,19 @@ export function getShortcutList(shortcutFind?: ShortcutFind) {
if (shortcutFind?.creatorId) { if (shortcutFind?.creatorId) {
queryList.push(`creatorId=${shortcutFind.creatorId}`); queryList.push(`creatorId=${shortcutFind.creatorId}`);
} }
return axios.get<ResponseObject<Shortcut[]>>(`/api/shortcut?${queryList.join("&")}`); return axios.get<Shortcut[]>(`/api/v1/shortcut?${queryList.join("&")}`);
} }
export function createShortcut(shortcutCreate: ShortcutCreate) { export function createShortcut(shortcutCreate: ShortcutCreate) {
return axios.post<ResponseObject<Shortcut>>("/api/shortcut", shortcutCreate); return axios.post<Shortcut>("/api/v1/shortcut", shortcutCreate);
} }
export function patchShortcut(shortcutPatch: ShortcutPatch) { export function patchShortcut(shortcutPatch: ShortcutPatch) {
return axios.patch<ResponseObject<Shortcut>>(`/api/shortcut/${shortcutPatch.id}`, shortcutPatch); return axios.patch<Shortcut>(`/api/v1/shortcut/${shortcutPatch.id}`, shortcutPatch);
} }
export function deleteShortcutById(shortcutId: ShortcutId) { export function deleteShortcutById(shortcutId: ShortcutId) {
return axios.delete(`/api/shortcut/${shortcutId}`); return axios.delete(`/api/v1/shortcut/${shortcutId}`);
} }
export function getResourceList() { export function getResourceList() {
@ -210,21 +210,21 @@ export function getTagList(tagFind?: TagFind) {
if (tagFind?.creatorId) { if (tagFind?.creatorId) {
queryList.push(`creatorId=${tagFind.creatorId}`); queryList.push(`creatorId=${tagFind.creatorId}`);
} }
return axios.get<ResponseObject<string[]>>(`/api/tag?${queryList.join("&")}`); return axios.get<string[]>(`/api/v1/tag?${queryList.join("&")}`);
} }
export function getTagSuggestionList() { export function getTagSuggestionList() {
return axios.get<ResponseObject<string[]>>(`/api/tag/suggestion`); return axios.get<string[]>(`/api/v1/tag/suggestion`);
} }
export function upsertTag(tagName: string) { export function upsertTag(tagName: string) {
return axios.post<ResponseObject<string>>(`/api/tag`, { return axios.post<string>(`/api/v1/tag`, {
name: tagName, name: tagName,
}); });
} }
export function deleteTag(tagName: string) { export function deleteTag(tagName: string) {
return axios.post<ResponseObject<boolean>>(`/api/tag/delete`, { return axios.post(`/api/v1/tag/delete`, {
name: tagName, name: tagName,
}); });
} }

View File

@ -35,7 +35,7 @@ export const initialGlobalState = async () => {
defaultGlobalState.appearance = storageAppearance; defaultGlobalState.appearance = storageAppearance;
} }
const { data } = (await api.getSystemStatus()).data; const { data } = await api.getSystemStatus();
if (data) { if (data) {
const customizedProfile = data.customizedProfile; const customizedProfile = data.customizedProfile;
defaultGlobalState.systemStatus = { defaultGlobalState.systemStatus = {
@ -68,7 +68,7 @@ export const useGlobalStore = () => {
return state.systemStatus.profile.mode !== "prod"; return state.systemStatus.profile.mode !== "prod";
}, },
fetchSystemStatus: async () => { fetchSystemStatus: async () => {
const { data: systemStatus } = (await api.getSystemStatus()).data; const { data: systemStatus } = await api.getSystemStatus();
store.dispatch(setGlobalState({ systemStatus: systemStatus })); store.dispatch(setGlobalState({ systemStatus: systemStatus }));
return systemStatus; return systemStatus;
}, },

View File

@ -18,7 +18,7 @@ export const useShortcutStore = () => {
return store.getState().shortcut; return store.getState().shortcut;
}, },
getMyAllShortcuts: async () => { getMyAllShortcuts: async () => {
const { data } = (await api.getShortcutList()).data; const { data } = await api.getShortcutList();
const shortcuts = data.map((s) => convertResponseModelShortcut(s)); const shortcuts = data.map((s) => convertResponseModelShortcut(s));
store.dispatch(setShortcuts(shortcuts)); store.dispatch(setShortcuts(shortcuts));
}, },
@ -32,12 +32,12 @@ export const useShortcutStore = () => {
return null; return null;
}, },
createShortcut: async (shortcutCreate: ShortcutCreate) => { createShortcut: async (shortcutCreate: ShortcutCreate) => {
const { data } = (await api.createShortcut(shortcutCreate)).data; const { data } = await api.createShortcut(shortcutCreate);
const shortcut = convertResponseModelShortcut(data); const shortcut = convertResponseModelShortcut(data);
store.dispatch(createShortcut(shortcut)); store.dispatch(createShortcut(shortcut));
}, },
patchShortcut: async (shortcutPatch: ShortcutPatch) => { patchShortcut: async (shortcutPatch: ShortcutPatch) => {
const { data } = (await api.patchShortcut(shortcutPatch)).data; const { data } = await api.patchShortcut(shortcutPatch);
const shortcut = convertResponseModelShortcut(data); const shortcut = convertResponseModelShortcut(data);
store.dispatch(patchShortcut(shortcut)); store.dispatch(patchShortcut(shortcut));
}, },

View File

@ -16,7 +16,7 @@ export const useTagStore = () => {
if (userStore.isVisitorMode()) { if (userStore.isVisitorMode()) {
tagFind.creatorId = userStore.getUserIdFromPath(); tagFind.creatorId = userStore.getUserIdFromPath();
} }
const { data } = (await api.getTagList(tagFind)).data; const { data } = await api.getTagList(tagFind);
store.dispatch(setTags(data)); store.dispatch(setTags(data));
}, },
upsertTag: async (tagName: string) => { upsertTag: async (tagName: string) => {

View File

@ -59,7 +59,7 @@ export const initialUserState = async () => {
store.dispatch(setHost(convertResponseModelUser(systemStatus.host))); store.dispatch(setHost(convertResponseModelUser(systemStatus.host)));
} }
const { data } = (await api.getMyselfUser()).data; const { data } = await api.getMyselfUser();
if (data) { if (data) {
const user = convertResponseModelUser(data); const user = convertResponseModelUser(data);
store.dispatch(setUser(user)); store.dispatch(setUser(user));
@ -83,7 +83,7 @@ const getUserIdFromPath = () => {
}; };
const doSignIn = async () => { const doSignIn = async () => {
const { data: user } = (await api.getMyselfUser()).data; const { data: user } = await api.getMyselfUser();
if (user) { if (user) {
store.dispatch(setUser(convertResponseModelUser(user))); store.dispatch(setUser(convertResponseModelUser(user)));
} else { } else {
@ -120,7 +120,7 @@ export const useUserStore = () => {
} }
}, },
getUserById: async (userId: UserId) => { getUserById: async (userId: UserId) => {
const { data } = (await api.getUserById(userId)).data; const { data } = await api.getUserById(userId);
if (data) { if (data) {
const user = convertResponseModelUser(data); const user = convertResponseModelUser(data);
store.dispatch(setUserById(user)); store.dispatch(setUserById(user));
@ -141,7 +141,7 @@ export const useUserStore = () => {
store.dispatch(patchUser({ localSetting })); store.dispatch(patchUser({ localSetting }));
}, },
patchUser: async (userPatch: UserPatch): Promise<void> => { patchUser: async (userPatch: UserPatch): Promise<void> => {
const { data } = (await api.patchUser(userPatch)).data; const { data } = await api.patchUser(userPatch);
if (userPatch.id === store.getState().user.user?.id) { if (userPatch.id === store.getState().user.user?.id) {
const user = convertResponseModelUser(data); const user = convertResponseModelUser(data);
store.dispatch(patchUser(user)); store.dispatch(patchUser(user));

View File

@ -12,11 +12,6 @@ interface CustomizedProfile {
externalUrl: string; externalUrl: string;
} }
interface OpenAIConfig {
key: string;
host: string;
}
interface SystemStatus { interface SystemStatus {
host?: User; host?: User;
profile: Profile; profile: Profile;