chore: move api directory

This commit is contained in:
Steven
2024-02-29 01:16:43 +08:00
parent 1aa75847d6
commit 3e50bee7da
46 changed files with 13 additions and 13 deletions

View File

@@ -0,0 +1,63 @@
package auth
import (
"fmt"
"time"
"github.com/golang-jwt/jwt/v5"
)
const (
// issuer is the issuer of the jwt token.
Issuer = "memos"
// Signing key section. For now, this is only used for signing, not for verifying since we only
// have 1 version. But it will be used to maintain backward compatibility if we change the signing mechanism.
KeyID = "v1"
// AccessTokenAudienceName is the audience name of the access token.
AccessTokenAudienceName = "user.access-token"
AccessTokenDuration = 7 * 24 * time.Hour
// CookieExpDuration expires slightly earlier than the jwt expiration. Client would be logged out if the user
// cookie expires, thus the client would always logout first before attempting to make a request with the expired jwt.
CookieExpDuration = AccessTokenDuration - 1*time.Minute
// AccessTokenCookieName is the cookie name of access token.
AccessTokenCookieName = "memos.access-token"
)
type ClaimsMessage struct {
Name string `json:"name"`
jwt.RegisteredClaims
}
// GenerateAccessToken generates an access token.
func GenerateAccessToken(username string, userID int32, expirationTime time.Time, secret []byte) (string, error) {
return generateToken(username, userID, AccessTokenAudienceName, expirationTime, secret)
}
// generateToken generates a jwt token.
func generateToken(username string, userID int32, audience string, expirationTime time.Time, secret []byte) (string, error) {
registeredClaims := jwt.RegisteredClaims{
Issuer: Issuer,
Audience: jwt.ClaimStrings{audience},
IssuedAt: jwt.NewNumericDate(time.Now()),
Subject: fmt.Sprint(userID),
}
if !expirationTime.IsZero() {
registeredClaims.ExpiresAt = jwt.NewNumericDate(expirationTime)
}
// Declare the token with the HS256 algorithm used for signing, and the claims.
token := jwt.NewWithClaims(jwt.SigningMethodHS256, &ClaimsMessage{
Name: username,
RegisteredClaims: registeredClaims,
})
token.Header["kid"] = KeyID
// Create the JWT string.
tokenString, err := token.SignedString(secret)
if err != nil {
return "", err
}
return tokenString, nil
}

368
server/route/api/v1/auth.go Normal file
View File

@@ -0,0 +1,368 @@
package v1
import (
"context"
"encoding/json"
"fmt"
"net/http"
"regexp"
"strings"
"time"
"github.com/labstack/echo/v4"
"github.com/pkg/errors"
"golang.org/x/crypto/bcrypt"
"github.com/usememos/memos/internal/util"
"github.com/usememos/memos/plugin/idp"
"github.com/usememos/memos/plugin/idp/oauth2"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/server/route/api/auth"
"github.com/usememos/memos/store"
)
type SignIn struct {
Username string `json:"username"`
Password string `json:"password"`
Remember bool `json:"remember"`
}
type SSOSignIn struct {
IdentityProviderID int32 `json:"identityProviderId"`
Code string `json:"code"`
RedirectURI string `json:"redirectUri"`
}
type SignUp struct {
Username string `json:"username"`
Password string `json:"password"`
}
func (s *APIV1Service) registerAuthRoutes(g *echo.Group) {
g.POST("/auth/signin", s.SignIn)
g.POST("/auth/signin/sso", s.SignInSSO)
g.POST("/auth/signout", s.SignOut)
g.POST("/auth/signup", s.SignUp)
}
// SignIn godoc
//
// @Summary Sign-in to memos.
// @Tags auth
// @Accept json
// @Produce json
// @Param body body SignIn true "Sign-in object"
// @Success 200 {object} store.User "User information"
// @Failure 400 {object} nil "Malformatted signin request"
// @Failure 401 {object} nil "Password login is deactivated | Incorrect login credentials, please try again"
// @Failure 403 {object} nil "User has been archived with username %s"
// @Failure 500 {object} nil "Failed to find system setting | Failed to unmarshal system setting | Incorrect login credentials, please try again | Failed to generate tokens | Failed to create activity"
// @Router /api/v1/auth/signin [POST]
func (s *APIV1Service) SignIn(c echo.Context) error {
ctx := c.Request().Context()
workspaceGeneralSetting, err := s.Store.GetWorkspaceGeneralSetting(ctx)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find system setting").SetInternal(err)
}
if workspaceGeneralSetting.DisallowPasswordLogin {
return echo.NewHTTPError(http.StatusUnauthorized, "password login is deactivated").SetInternal(err)
}
signin := &SignIn{}
if err := json.NewDecoder(c.Request().Body).Decode(signin); err != nil {
return echo.NewHTTPError(http.StatusBadRequest, "Malformatted signin request").SetInternal(err)
}
user, err := s.Store.GetUser(ctx, &store.FindUser{
Username: &signin.Username,
})
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Incorrect login credentials, please try again")
}
if user == nil {
return echo.NewHTTPError(http.StatusUnauthorized, "Incorrect login credentials, please try again")
} else if user.RowStatus == store.Archived {
return echo.NewHTTPError(http.StatusForbidden, fmt.Sprintf("User has been archived with username %s", signin.Username))
}
// Compare the stored hashed password, with the hashed version of the password that was received.
if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(signin.Password)); err != nil {
// If the two passwords don't match, return a 401 status.
return echo.NewHTTPError(http.StatusUnauthorized, "Incorrect login credentials, please try again")
}
var expireAt time.Time
// Set cookie expiration to 100 years to make it persistent.
cookieExp := time.Now().AddDate(100, 0, 0)
if !signin.Remember {
expireAt = time.Now().Add(auth.AccessTokenDuration)
cookieExp = time.Now().Add(auth.CookieExpDuration)
}
accessToken, err := auth.GenerateAccessToken(user.Username, user.ID, expireAt, []byte(s.Secret))
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("failed to generate tokens, err: %s", err)).SetInternal(err)
}
if err := s.UpsertAccessTokenToStore(ctx, user, accessToken); err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("failed to upsert access token, err: %s", err)).SetInternal(err)
}
setTokenCookie(c, auth.AccessTokenCookieName, accessToken, cookieExp)
userMessage := convertUserFromStore(user)
return c.JSON(http.StatusOK, userMessage)
}
// SignInSSO godoc
//
// @Summary Sign-in to memos using SSO.
// @Tags auth
// @Accept json
// @Produce json
// @Param body body SSOSignIn true "SSO sign-in object"
// @Success 200 {object} store.User "User information"
// @Failure 400 {object} nil "Malformatted signin request"
// @Failure 401 {object} nil "Access denied, identifier does not match the filter."
// @Failure 403 {object} nil "User has been archived with username {username}"
// @Failure 404 {object} nil "Identity provider not found"
// @Failure 500 {object} nil "Failed to find identity provider | Failed to create identity provider instance | Failed to exchange token | Failed to get user info | Failed to compile identifier filter | Incorrect login credentials, please try again | Failed to generate random password | Failed to generate password hash | Failed to create user | Failed to generate tokens | Failed to create activity"
// @Router /api/v1/auth/signin/sso [POST]
func (s *APIV1Service) SignInSSO(c echo.Context) error {
ctx := c.Request().Context()
signin := &SSOSignIn{}
if err := json.NewDecoder(c.Request().Body).Decode(signin); err != nil {
return echo.NewHTTPError(http.StatusBadRequest, "Malformatted signin request").SetInternal(err)
}
identityProvider, err := s.Store.GetIdentityProvider(ctx, &store.FindIdentityProvider{
ID: &signin.IdentityProviderID,
})
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find identity provider").SetInternal(err)
}
if identityProvider == nil {
return echo.NewHTTPError(http.StatusNotFound, "Identity provider not found")
}
var userInfo *idp.IdentityProviderUserInfo
if identityProvider.Type == store.IdentityProviderOAuth2Type {
oauth2IdentityProvider, err := oauth2.NewIdentityProvider(identityProvider.Config.OAuth2Config)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create identity provider instance").SetInternal(err)
}
token, err := oauth2IdentityProvider.ExchangeToken(ctx, signin.RedirectURI, signin.Code)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to exchange token").SetInternal(err)
}
userInfo, err = oauth2IdentityProvider.UserInfo(token)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to get user info").SetInternal(err)
}
}
identifierFilter := identityProvider.IdentifierFilter
if identifierFilter != "" {
identifierFilterRegex, err := regexp.Compile(identifierFilter)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to compile identifier filter").SetInternal(err)
}
if !identifierFilterRegex.MatchString(userInfo.Identifier) {
return echo.NewHTTPError(http.StatusUnauthorized, "Access denied, identifier does not match the filter.").SetInternal(err)
}
}
user, err := s.Store.GetUser(ctx, &store.FindUser{
Username: &userInfo.Identifier,
})
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Incorrect login credentials, please try again")
}
if user == nil {
workspaceGeneralSetting, err := s.Store.GetWorkspaceGeneralSetting(ctx)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find system setting").SetInternal(err)
}
if workspaceGeneralSetting.DisallowSignup {
return echo.NewHTTPError(http.StatusUnauthorized, "signup is disabled").SetInternal(err)
}
userCreate := &store.User{
Username: userInfo.Identifier,
// The new signup user should be normal user by default.
Role: store.RoleUser,
Nickname: userInfo.DisplayName,
Email: userInfo.Email,
}
password, err := util.RandomString(20)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to generate random password").SetInternal(err)
}
passwordHash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to generate password hash").SetInternal(err)
}
userCreate.PasswordHash = string(passwordHash)
user, err = s.Store.CreateUser(ctx, userCreate)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create user").SetInternal(err)
}
}
if user.RowStatus == store.Archived {
return echo.NewHTTPError(http.StatusForbidden, fmt.Sprintf("User has been archived with username %s", userInfo.Identifier))
}
accessToken, err := auth.GenerateAccessToken(user.Username, user.ID, time.Now().Add(auth.AccessTokenDuration), []byte(s.Secret))
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("failed to generate tokens, err: %s", err)).SetInternal(err)
}
if err := s.UpsertAccessTokenToStore(ctx, user, accessToken); err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("failed to upsert access token, err: %s", err)).SetInternal(err)
}
cookieExp := time.Now().Add(auth.CookieExpDuration)
setTokenCookie(c, auth.AccessTokenCookieName, accessToken, cookieExp)
userMessage := convertUserFromStore(user)
return c.JSON(http.StatusOK, userMessage)
}
// SignOut godoc
//
// @Summary Sign-out from memos.
// @Tags auth
// @Produce json
// @Success 200 {boolean} true "Sign-out success"
// @Router /api/v1/auth/signout [POST]
func (s *APIV1Service) SignOut(c echo.Context) error {
accessToken := findAccessToken(c)
userID, _ := getUserIDFromAccessToken(accessToken, s.Secret)
err := removeAccessTokenAndCookies(c, s.Store, userID, accessToken)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("failed to remove access token, err: %s", err)).SetInternal(err)
}
return c.JSON(http.StatusOK, true)
}
// SignUp godoc
//
// @Summary Sign-up to memos.
// @Tags auth
// @Accept json
// @Produce json
// @Param body body SignUp true "Sign-up object"
// @Success 200 {object} store.User "User information"
// @Failure 400 {object} nil "Malformatted signup request | Failed to find users"
// @Failure 401 {object} nil "signup is disabled"
// @Failure 403 {object} nil "Forbidden"
// @Failure 404 {object} nil "Not found"
// @Failure 500 {object} nil "Failed to find system setting | Failed to unmarshal system setting allow signup | Failed to generate password hash | Failed to create user | Failed to generate tokens | Failed to create activity"
// @Router /api/v1/auth/signup [POST]
func (s *APIV1Service) SignUp(c echo.Context) error {
ctx := c.Request().Context()
signup := &SignUp{}
if err := json.NewDecoder(c.Request().Body).Decode(signup); err != nil {
return echo.NewHTTPError(http.StatusBadRequest, "Malformatted signup request").SetInternal(err)
}
hostUserType := store.RoleHost
existedHostUsers, err := s.Store.ListUsers(ctx, &store.FindUser{
Role: &hostUserType,
})
if err != nil {
return echo.NewHTTPError(http.StatusBadRequest, "Failed to find users").SetInternal(err)
}
if !util.ResourceNameMatcher.MatchString(strings.ToLower(signup.Username)) {
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Invalid username %s", signup.Username)).SetInternal(err)
}
userCreate := &store.User{
Username: signup.Username,
// The new signup user should be normal user by default.
Role: store.RoleUser,
Nickname: signup.Username,
}
if len(existedHostUsers) == 0 {
// Change the default role to host if there is no host user.
userCreate.Role = store.RoleHost
} else {
workspaceGeneralSetting, err := s.Store.GetWorkspaceGeneralSetting(ctx)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find system setting").SetInternal(err)
}
if workspaceGeneralSetting.DisallowSignup {
return echo.NewHTTPError(http.StatusUnauthorized, "signup is disabled").SetInternal(err)
}
if workspaceGeneralSetting.DisallowPasswordLogin {
return echo.NewHTTPError(http.StatusUnauthorized, "password login is deactivated").SetInternal(err)
}
}
passwordHash, err := bcrypt.GenerateFromPassword([]byte(signup.Password), bcrypt.DefaultCost)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to generate password hash").SetInternal(err)
}
userCreate.PasswordHash = string(passwordHash)
user, err := s.Store.CreateUser(ctx, userCreate)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create user").SetInternal(err)
}
accessToken, err := auth.GenerateAccessToken(user.Username, user.ID, time.Now().Add(auth.AccessTokenDuration), []byte(s.Secret))
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("failed to generate tokens, err: %s", err)).SetInternal(err)
}
if err := s.UpsertAccessTokenToStore(ctx, user, accessToken); err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("failed to upsert access token, err: %s", err)).SetInternal(err)
}
cookieExp := time.Now().Add(auth.CookieExpDuration)
setTokenCookie(c, auth.AccessTokenCookieName, accessToken, cookieExp)
userMessage := convertUserFromStore(user)
return c.JSON(http.StatusOK, userMessage)
}
func (s *APIV1Service) UpsertAccessTokenToStore(ctx context.Context, user *store.User, accessToken string) error {
userAccessTokens, err := s.Store.GetUserAccessTokens(ctx, user.ID)
if err != nil {
return errors.Wrap(err, "failed to get user access tokens")
}
userAccessToken := storepb.AccessTokensUserSetting_AccessToken{
AccessToken: accessToken,
Description: "Account sign in",
}
userAccessTokens = append(userAccessTokens, &userAccessToken)
if _, err := s.Store.UpsertUserSetting(ctx, &storepb.UserSetting{
UserId: user.ID,
Key: storepb.UserSettingKey_USER_SETTING_ACCESS_TOKENS,
Value: &storepb.UserSetting_AccessTokens{
AccessTokens: &storepb.AccessTokensUserSetting{
AccessTokens: userAccessTokens,
},
},
}); err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("failed to upsert user setting, err: %s", err)).SetInternal(err)
}
return nil
}
// removeAccessTokenAndCookies removes the jwt token from the cookies.
func removeAccessTokenAndCookies(c echo.Context, s *store.Store, userID int32, token string) error {
err := s.RemoveUserAccessToken(c.Request().Context(), userID, token)
if err != nil {
return err
}
cookieExp := time.Now().Add(-1 * time.Hour)
setTokenCookie(c, auth.AccessTokenCookieName, "", cookieExp)
return nil
}
// setTokenCookie sets the token to the cookie.
func setTokenCookie(c echo.Context, name, token string, expiration time.Time) {
cookie := new(http.Cookie)
cookie.Name = name
cookie.Value = token
cookie.Expires = expiration
cookie.Path = "/"
// Http-only helps mitigate the risk of client side script accessing the protected cookie.
cookie.HttpOnly = true
cookie.SameSite = http.SameSiteStrictMode
c.SetCookie(cookie)
}

View File

@@ -0,0 +1,15 @@
package v1
// RowStatus is the status for a row.
type RowStatus string
const (
// Normal is the status for a normal row.
Normal RowStatus = "NORMAL"
// Archived is the status for an archived row.
Archived RowStatus = "ARCHIVED"
)
func (r RowStatus) String() string {
return string(r)
}

3393
server/route/api/v1/docs.go Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,49 @@
package v1
import (
"fmt"
"net/http"
"net/url"
"github.com/labstack/echo/v4"
getter "github.com/usememos/memos/plugin/http-getter"
)
func (*APIV1Service) registerGetterPublicRoutes(g *echo.Group) {
// GET /get/image?url={url} - Get image.
g.GET("/get/image", GetImage)
}
// GetImage godoc
//
// @Summary Get GetImage from URL
// @Tags image-url
// @Produce GetImage/*
// @Param url query string true "Image url"
// @Success 200 {object} nil "Image"
// @Failure 400 {object} nil "Missing GetImage url | Wrong url | Failed to get GetImage url: %s"
// @Failure 500 {object} nil "Failed to write GetImage blob"
// @Router /o/get/GetImage [GET]
func GetImage(c echo.Context) error {
urlStr := c.QueryParam("url")
if urlStr == "" {
return echo.NewHTTPError(http.StatusBadRequest, "Missing image url")
}
if _, err := url.Parse(urlStr); err != nil {
return echo.NewHTTPError(http.StatusBadRequest, "Wrong url").SetInternal(err)
}
image, err := getter.GetImage(urlStr)
if err != nil {
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Failed to get image url: %s", urlStr)).SetInternal(err)
}
c.Response().Writer.WriteHeader(http.StatusOK)
c.Response().Writer.Header().Set("Content-Type", image.Mediatype)
c.Response().Writer.Header().Set(echo.HeaderCacheControl, "max-age=31536000, immutable")
if _, err := c.Response().Writer.Write(image.Blob); err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to write image blob").SetInternal(err)
}
return nil
}

349
server/route/api/v1/idp.go Normal file
View File

@@ -0,0 +1,349 @@
package v1
import (
"encoding/json"
"fmt"
"net/http"
"github.com/labstack/echo/v4"
"github.com/usememos/memos/internal/util"
"github.com/usememos/memos/store"
)
type IdentityProviderType string
const (
IdentityProviderOAuth2Type IdentityProviderType = "OAUTH2"
)
func (t IdentityProviderType) String() string {
return string(t)
}
type IdentityProviderConfig struct {
OAuth2Config *IdentityProviderOAuth2Config `json:"oauth2Config"`
}
type IdentityProviderOAuth2Config struct {
ClientID string `json:"clientId"`
ClientSecret string `json:"clientSecret"`
AuthURL string `json:"authUrl"`
TokenURL string `json:"tokenUrl"`
UserInfoURL string `json:"userInfoUrl"`
Scopes []string `json:"scopes"`
FieldMapping *FieldMapping `json:"fieldMapping"`
}
type FieldMapping struct {
Identifier string `json:"identifier"`
DisplayName string `json:"displayName"`
Email string `json:"email"`
}
type IdentityProvider struct {
ID int32 `json:"id"`
Name string `json:"name"`
Type IdentityProviderType `json:"type"`
IdentifierFilter string `json:"identifierFilter"`
Config *IdentityProviderConfig `json:"config"`
}
type CreateIdentityProviderRequest struct {
Name string `json:"name"`
Type IdentityProviderType `json:"type"`
IdentifierFilter string `json:"identifierFilter"`
Config *IdentityProviderConfig `json:"config"`
}
type UpdateIdentityProviderRequest struct {
ID int32 `json:"-"`
Type IdentityProviderType `json:"type"`
Name *string `json:"name"`
IdentifierFilter *string `json:"identifierFilter"`
Config *IdentityProviderConfig `json:"config"`
}
func (s *APIV1Service) registerIdentityProviderRoutes(g *echo.Group) {
g.GET("/idp", s.GetIdentityProviderList)
g.POST("/idp", s.CreateIdentityProvider)
g.GET("/idp/:idpId", s.GetIdentityProvider)
g.PATCH("/idp/:idpId", s.UpdateIdentityProvider)
g.DELETE("/idp/:idpId", s.DeleteIdentityProvider)
}
// GetIdentityProviderList godoc
//
// @Summary Get a list of identity providers
// @Description *clientSecret is only available for host user
// @Tags idp
// @Produce json
// @Success 200 {object} []IdentityProvider "List of available identity providers"
// @Failure 500 {object} nil "Failed to find identity provider list | Failed to find user"
// @Router /api/v1/idp [GET]
func (s *APIV1Service) GetIdentityProviderList(c echo.Context) error {
ctx := c.Request().Context()
list, err := s.Store.ListIdentityProviders(ctx, &store.FindIdentityProvider{})
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find identity provider list").SetInternal(err)
}
userID, ok := c.Get(userIDContextKey).(int32)
isHostUser := false
if ok {
user, err := s.Store.GetUser(ctx, &store.FindUser{
ID: &userID,
})
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err)
}
if user == nil || user.Role == store.RoleHost {
isHostUser = true
}
}
identityProviderList := []*IdentityProvider{}
for _, item := range list {
identityProvider := convertIdentityProviderFromStore(item)
// data desensitize
if !isHostUser {
identityProvider.Config.OAuth2Config.ClientSecret = ""
}
identityProviderList = append(identityProviderList, identityProvider)
}
return c.JSON(http.StatusOK, identityProviderList)
}
// CreateIdentityProvider godoc
//
// @Summary Create Identity Provider
// @Tags idp
// @Accept json
// @Produce json
// @Param body body CreateIdentityProviderRequest true "Identity provider information"
// @Success 200 {object} store.IdentityProvider "Identity provider information"
// @Failure 401 {object} nil "Missing user in session | Unauthorized"
// @Failure 400 {object} nil "Malformatted post identity provider request"
// @Failure 500 {object} nil "Failed to find user | Failed to create identity provider"
// @Router /api/v1/idp [POST]
func (s *APIV1Service) CreateIdentityProvider(c echo.Context) error {
ctx := c.Request().Context()
userID, ok := c.Get(userIDContextKey).(int32)
if !ok {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
}
user, err := s.Store.GetUser(ctx, &store.FindUser{
ID: &userID,
})
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err)
}
if user == nil || user.Role != store.RoleHost {
return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized")
}
identityProviderCreate := &CreateIdentityProviderRequest{}
if err := json.NewDecoder(c.Request().Body).Decode(identityProviderCreate); err != nil {
return echo.NewHTTPError(http.StatusBadRequest, "Malformatted post identity provider request").SetInternal(err)
}
identityProvider, err := s.Store.CreateIdentityProvider(ctx, &store.IdentityProvider{
Name: identityProviderCreate.Name,
Type: store.IdentityProviderType(identityProviderCreate.Type),
IdentifierFilter: identityProviderCreate.IdentifierFilter,
Config: convertIdentityProviderConfigToStore(identityProviderCreate.Config),
})
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create identity provider").SetInternal(err)
}
return c.JSON(http.StatusOK, convertIdentityProviderFromStore(identityProvider))
}
// GetIdentityProvider godoc
//
// @Summary Get an identity provider by ID
// @Tags idp
// @Accept json
// @Produce json
// @Param idpId path int true "Identity provider ID"
// @Success 200 {object} store.IdentityProvider "Requested identity provider"
// @Failure 400 {object} nil "ID is not a number: %s"
// @Failure 401 {object} nil "Missing user in session | Unauthorized"
// @Failure 404 {object} nil "Identity provider not found"
// @Failure 500 {object} nil "Failed to find identity provider list | Failed to find user"
// @Router /api/v1/idp/{idpId} [GET]
func (s *APIV1Service) GetIdentityProvider(c echo.Context) error {
ctx := c.Request().Context()
userID, ok := c.Get(userIDContextKey).(int32)
if !ok {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
}
user, err := s.Store.GetUser(ctx, &store.FindUser{
ID: &userID,
})
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err)
}
if user == nil || user.Role != store.RoleHost {
return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized")
}
identityProviderID, err := util.ConvertStringToInt32(c.Param("idpId"))
if err != nil {
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("idpId"))).SetInternal(err)
}
identityProvider, err := s.Store.GetIdentityProvider(ctx, &store.FindIdentityProvider{
ID: &identityProviderID,
})
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to get identity provider").SetInternal(err)
}
if identityProvider == nil {
return echo.NewHTTPError(http.StatusNotFound, "Identity provider not found")
}
return c.JSON(http.StatusOK, convertIdentityProviderFromStore(identityProvider))
}
// DeleteIdentityProvider godoc
//
// @Summary Delete an identity provider by ID
// @Tags idp
// @Accept json
// @Produce json
// @Param idpId path int true "Identity Provider ID"
// @Success 200 {boolean} true "Identity Provider deleted"
// @Failure 400 {object} nil "ID is not a number: %s | Malformatted patch identity provider request"
// @Failure 401 {object} nil "Missing user in session | Unauthorized"
// @Failure 500 {object} nil "Failed to find user | Failed to patch identity provider"
// @Router /api/v1/idp/{idpId} [DELETE]
func (s *APIV1Service) DeleteIdentityProvider(c echo.Context) error {
ctx := c.Request().Context()
userID, ok := c.Get(userIDContextKey).(int32)
if !ok {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
}
user, err := s.Store.GetUser(ctx, &store.FindUser{
ID: &userID,
})
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err)
}
if user == nil || user.Role != store.RoleHost {
return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized")
}
identityProviderID, err := util.ConvertStringToInt32(c.Param("idpId"))
if err != nil {
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("idpId"))).SetInternal(err)
}
if err = s.Store.DeleteIdentityProvider(ctx, &store.DeleteIdentityProvider{ID: identityProviderID}); err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to delete identity provider").SetInternal(err)
}
return c.JSON(http.StatusOK, true)
}
// UpdateIdentityProvider godoc
//
// @Summary Update an identity provider by ID
// @Tags idp
// @Accept json
// @Produce json
// @Param idpId path int true "Identity Provider ID"
// @Param body body UpdateIdentityProviderRequest true "Patched identity provider information"
// @Success 200 {object} store.IdentityProvider "Patched identity provider"
// @Failure 400 {object} nil "ID is not a number: %s | Malformatted patch identity provider request"
// @Failure 401 {object} nil "Missing user in session | Unauthorized
// @Failure 500 {object} nil "Failed to find user | Failed to patch identity provider"
// @Router /api/v1/idp/{idpId} [PATCH]
func (s *APIV1Service) UpdateIdentityProvider(c echo.Context) error {
ctx := c.Request().Context()
userID, ok := c.Get(userIDContextKey).(int32)
if !ok {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
}
user, err := s.Store.GetUser(ctx, &store.FindUser{
ID: &userID,
})
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err)
}
if user == nil || user.Role != store.RoleHost {
return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized")
}
identityProviderID, err := util.ConvertStringToInt32(c.Param("idpId"))
if err != nil {
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("idpId"))).SetInternal(err)
}
identityProviderPatch := &UpdateIdentityProviderRequest{
ID: identityProviderID,
}
if err := json.NewDecoder(c.Request().Body).Decode(identityProviderPatch); err != nil {
return echo.NewHTTPError(http.StatusBadRequest, "Malformatted patch identity provider request").SetInternal(err)
}
identityProvider, err := s.Store.UpdateIdentityProvider(ctx, &store.UpdateIdentityProvider{
ID: identityProviderPatch.ID,
Type: store.IdentityProviderType(identityProviderPatch.Type),
Name: identityProviderPatch.Name,
IdentifierFilter: identityProviderPatch.IdentifierFilter,
Config: convertIdentityProviderConfigToStore(identityProviderPatch.Config),
})
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to patch identity provider").SetInternal(err)
}
return c.JSON(http.StatusOK, convertIdentityProviderFromStore(identityProvider))
}
func convertIdentityProviderFromStore(identityProvider *store.IdentityProvider) *IdentityProvider {
return &IdentityProvider{
ID: identityProvider.ID,
Name: identityProvider.Name,
Type: IdentityProviderType(identityProvider.Type),
IdentifierFilter: identityProvider.IdentifierFilter,
Config: convertIdentityProviderConfigFromStore(identityProvider.Config),
}
}
func convertIdentityProviderConfigFromStore(config *store.IdentityProviderConfig) *IdentityProviderConfig {
return &IdentityProviderConfig{
OAuth2Config: &IdentityProviderOAuth2Config{
ClientID: config.OAuth2Config.ClientID,
ClientSecret: config.OAuth2Config.ClientSecret,
AuthURL: config.OAuth2Config.AuthURL,
TokenURL: config.OAuth2Config.TokenURL,
UserInfoURL: config.OAuth2Config.UserInfoURL,
Scopes: config.OAuth2Config.Scopes,
FieldMapping: &FieldMapping{
Identifier: config.OAuth2Config.FieldMapping.Identifier,
DisplayName: config.OAuth2Config.FieldMapping.DisplayName,
Email: config.OAuth2Config.FieldMapping.Email,
},
},
}
}
func convertIdentityProviderConfigToStore(config *IdentityProviderConfig) *store.IdentityProviderConfig {
return &store.IdentityProviderConfig{
OAuth2Config: &store.IdentityProviderOAuth2Config{
ClientID: config.OAuth2Config.ClientID,
ClientSecret: config.OAuth2Config.ClientSecret,
AuthURL: config.OAuth2Config.AuthURL,
TokenURL: config.OAuth2Config.TokenURL,
UserInfoURL: config.OAuth2Config.UserInfoURL,
Scopes: config.OAuth2Config.Scopes,
FieldMapping: &store.FieldMapping{
Identifier: config.OAuth2Config.FieldMapping.Identifier,
DisplayName: config.OAuth2Config.FieldMapping.DisplayName,
Email: config.OAuth2Config.FieldMapping.Email,
},
},
}
}

156
server/route/api/v1/jwt.go Normal file
View File

@@ -0,0 +1,156 @@
package v1
import (
"fmt"
"net/http"
"strings"
"github.com/golang-jwt/jwt/v5"
"github.com/labstack/echo/v4"
"github.com/pkg/errors"
"go.uber.org/zap"
"github.com/usememos/memos/internal/log"
"github.com/usememos/memos/internal/util"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/server/route/api/auth"
"github.com/usememos/memos/store"
)
const (
// The key name used to store user id in the context
// user id is extracted from the jwt token subject field.
userIDContextKey = "user-id"
)
func extractTokenFromHeader(c echo.Context) (string, error) {
authHeader := c.Request().Header.Get("Authorization")
if authHeader == "" {
return "", nil
}
authHeaderParts := strings.Fields(authHeader)
if len(authHeaderParts) != 2 || strings.ToLower(authHeaderParts[0]) != "bearer" {
return "", errors.New("Authorization header format must be Bearer {token}")
}
return authHeaderParts[1], nil
}
func findAccessToken(c echo.Context) string {
// Check the HTTP request header first.
accessToken, _ := extractTokenFromHeader(c)
if accessToken == "" {
// Check the cookie.
cookie, _ := c.Cookie(auth.AccessTokenCookieName)
if cookie != nil {
accessToken = cookie.Value
}
}
return accessToken
}
// JWTMiddleware validates the access token.
func JWTMiddleware(server *APIV1Service, next echo.HandlerFunc, secret string) echo.HandlerFunc {
return func(c echo.Context) error {
ctx := c.Request().Context()
path := c.Request().URL.Path
method := c.Request().Method
if server.defaultAuthSkipper(c) {
return next(c)
}
// Skip validation for server status endpoints.
if util.HasPrefixes(path, "/api/v1/ping", "/api/v1/status") && method == http.MethodGet {
return next(c)
}
accessToken := findAccessToken(c)
if accessToken == "" {
// Allow the user to access the public endpoints.
if util.HasPrefixes(path, "/o") {
return next(c)
}
// When the request is not authenticated, we allow the user to access the memo endpoints for those public memos.
if util.HasPrefixes(path, "/api/v1/idp", "/api/v1/memo", "/api/v1/user") && path != "/api/v1/user" && method == http.MethodGet {
return next(c)
}
return echo.NewHTTPError(http.StatusUnauthorized, "Missing access token")
}
userID, err := getUserIDFromAccessToken(accessToken, secret)
if err != nil {
err = removeAccessTokenAndCookies(c, server.Store, userID, accessToken)
if err != nil {
log.Error("fail to remove AccessToken and Cookies", zap.Error(err))
}
return echo.NewHTTPError(http.StatusUnauthorized, "Invalid or expired access token")
}
accessTokens, err := server.Store.GetUserAccessTokens(ctx, userID)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to get user access tokens.").WithInternal(err)
}
if !validateAccessToken(accessToken, accessTokens) {
err = removeAccessTokenAndCookies(c, server.Store, userID, accessToken)
if err != nil {
log.Error("fail to remove AccessToken and Cookies", zap.Error(err))
}
return echo.NewHTTPError(http.StatusUnauthorized, "Invalid access token.")
}
// Even if there is no error, we still need to make sure the user still exists.
user, err := server.Store.GetUser(ctx, &store.FindUser{
ID: &userID,
})
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Server error to find user ID: %d", userID)).SetInternal(err)
}
if user == nil {
return echo.NewHTTPError(http.StatusUnauthorized, fmt.Sprintf("Failed to find user ID: %d", userID))
}
// Stores userID into context.
c.Set(userIDContextKey, userID)
return next(c)
}
}
func getUserIDFromAccessToken(accessToken, secret string) (int32, error) {
claims := &auth.ClaimsMessage{}
_, err := jwt.ParseWithClaims(accessToken, claims, func(t *jwt.Token) (any, error) {
if t.Method.Alg() != jwt.SigningMethodHS256.Name {
return nil, errors.Errorf("unexpected access token signing method=%v, expect %v", t.Header["alg"], jwt.SigningMethodHS256)
}
if kid, ok := t.Header["kid"].(string); ok {
if kid == "v1" {
return []byte(secret), nil
}
}
return nil, errors.Errorf("unexpected access token kid=%v", t.Header["kid"])
})
if err != nil {
return 0, errors.Wrap(err, "Invalid or expired access token")
}
// We either have a valid access token or we will attempt to generate new access token.
userID, err := util.ConvertStringToInt32(claims.Subject)
if err != nil {
return 0, errors.Wrap(err, "Malformed ID in the token")
}
return userID, nil
}
func (*APIV1Service) defaultAuthSkipper(c echo.Context) bool {
path := c.Path()
return util.HasPrefixes(path, "/api/v1/auth")
}
func validateAccessToken(accessTokenString string, userAccessTokens []*storepb.AccessTokensUserSetting_AccessToken) bool {
for _, userAccessToken := range userAccessTokens {
if accessTokenString == userAccessToken.AccessToken {
return true
}
}
return false
}

1065
server/route/api/v1/memo.go Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,97 @@
package v1
import (
"encoding/json"
"fmt"
"net/http"
"github.com/labstack/echo/v4"
"github.com/usememos/memos/internal/util"
"github.com/usememos/memos/store"
)
type MemoOrganizer struct {
MemoID int32 `json:"memoId"`
UserID int32 `json:"userId"`
Pinned bool `json:"pinned"`
}
type UpsertMemoOrganizerRequest struct {
Pinned bool `json:"pinned"`
}
func (s *APIV1Service) registerMemoOrganizerRoutes(g *echo.Group) {
g.POST("/memo/:memoId/organizer", s.CreateMemoOrganizer)
}
// CreateMemoOrganizer godoc
//
// @Summary Organize memo (pin/unpin)
// @Tags memo-organizer
// @Accept json
// @Produce json
// @Param memoId path int true "ID of memo to organize"
// @Param body body UpsertMemoOrganizerRequest true "Memo organizer object"
// @Success 200 {object} store.Memo "Memo information"
// @Failure 400 {object} nil "ID is not a number: %s | Malformatted post memo organizer request"
// @Failure 401 {object} nil "Missing user in session | Unauthorized"
// @Failure 404 {object} nil "Memo not found: %v"
// @Failure 500 {object} nil "Failed to find memo | Failed to upsert memo organizer | Failed to find memo by ID: %v | Failed to compose memo response"
// @Router /api/v1/memo/{memoId}/organizer [POST]
func (s *APIV1Service) CreateMemoOrganizer(c echo.Context) error {
ctx := c.Request().Context()
memoID, err := util.ConvertStringToInt32(c.Param("memoId"))
if err != nil {
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("memoId"))).SetInternal(err)
}
userID, ok := c.Get(userIDContextKey).(int32)
if !ok {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
}
memo, err := s.Store.GetMemo(ctx, &store.FindMemo{
ID: &memoID,
})
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find memo").SetInternal(err)
}
if memo == nil {
return echo.NewHTTPError(http.StatusNotFound, fmt.Sprintf("Memo not found: %v", memoID))
}
if memo.CreatorID != userID {
return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized")
}
request := &UpsertMemoOrganizerRequest{}
if err := json.NewDecoder(c.Request().Body).Decode(request); err != nil {
return echo.NewHTTPError(http.StatusBadRequest, "Malformatted post memo organizer request").SetInternal(err)
}
upsert := &store.MemoOrganizer{
MemoID: memoID,
UserID: userID,
Pinned: request.Pinned,
}
_, err = s.Store.UpsertMemoOrganizer(ctx, upsert)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to upsert memo organizer").SetInternal(err)
}
memo, err = s.Store.GetMemo(ctx, &store.FindMemo{
ID: &memoID,
})
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Failed to find memo by ID: %v", memoID)).SetInternal(err)
}
if memo == nil {
return echo.NewHTTPError(http.StatusNotFound, fmt.Sprintf("Memo not found: %v", memoID))
}
memoResponse, err := s.convertMemoFromStore(ctx, memo)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to compose memo response").SetInternal(err)
}
return c.JSON(http.StatusOK, memoResponse)
}

View File

@@ -0,0 +1,156 @@
package v1
import (
"encoding/json"
"fmt"
"net/http"
"github.com/labstack/echo/v4"
"github.com/usememos/memos/internal/util"
"github.com/usememos/memos/store"
)
type MemoRelationType string
const (
MemoRelationReference MemoRelationType = "REFERENCE"
MemoRelationComment MemoRelationType = "COMMENT"
)
func (t MemoRelationType) String() string {
return string(t)
}
type MemoRelation struct {
MemoID int32 `json:"memoId"`
RelatedMemoID int32 `json:"relatedMemoId"`
Type MemoRelationType `json:"type"`
}
type UpsertMemoRelationRequest struct {
RelatedMemoID int32 `json:"relatedMemoId"`
Type MemoRelationType `json:"type"`
}
func (s *APIV1Service) registerMemoRelationRoutes(g *echo.Group) {
g.GET("/memo/:memoId/relation", s.GetMemoRelationList)
g.POST("/memo/:memoId/relation", s.CreateMemoRelation)
g.DELETE("/memo/:memoId/relation/:relatedMemoId/type/:relationType", s.DeleteMemoRelation)
}
// GetMemoRelationList godoc
//
// @Summary Get a list of Memo Relations
// @Tags memo-relation
// @Accept json
// @Produce json
// @Param memoId path int true "ID of memo to find relations"
// @Success 200 {object} []store.MemoRelation "Memo relation information list"
// @Failure 400 {object} nil "ID is not a number: %s"
// @Failure 500 {object} nil "Failed to list memo relations"
// @Router /api/v1/memo/{memoId}/relation [GET]
func (s *APIV1Service) GetMemoRelationList(c echo.Context) error {
ctx := c.Request().Context()
memoID, err := util.ConvertStringToInt32(c.Param("memoId"))
if err != nil {
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("memoId"))).SetInternal(err)
}
memoRelationList, err := s.Store.ListMemoRelations(ctx, &store.FindMemoRelation{
MemoID: &memoID,
})
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to list memo relations").SetInternal(err)
}
return c.JSON(http.StatusOK, memoRelationList)
}
// CreateMemoRelation godoc
//
// @Summary Create Memo Relation
// @Description Create a relation between two memos
// @Tags memo-relation
// @Accept json
// @Produce json
// @Param memoId path int true "ID of memo to relate"
// @Param body body UpsertMemoRelationRequest true "Memo relation object"
// @Success 200 {object} store.MemoRelation "Memo relation information"
// @Failure 400 {object} nil "ID is not a number: %s | Malformatted post memo relation request"
// @Failure 500 {object} nil "Failed to upsert memo relation"
// @Router /api/v1/memo/{memoId}/relation [POST]
//
// NOTES:
// - Currently not secured
// - It's possible to create relations to memos that doesn't exist, which will trigger 404 errors when the frontend tries to load them.
// - It's possible to create multiple relations, though the interface only shows first.
func (s *APIV1Service) CreateMemoRelation(c echo.Context) error {
ctx := c.Request().Context()
memoID, err := util.ConvertStringToInt32(c.Param("memoId"))
if err != nil {
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("memoId"))).SetInternal(err)
}
request := &UpsertMemoRelationRequest{}
if err := json.NewDecoder(c.Request().Body).Decode(request); err != nil {
return echo.NewHTTPError(http.StatusBadRequest, "Malformatted post memo relation request").SetInternal(err)
}
memoRelation, err := s.Store.UpsertMemoRelation(ctx, &store.MemoRelation{
MemoID: memoID,
RelatedMemoID: request.RelatedMemoID,
Type: store.MemoRelationType(request.Type),
})
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to upsert memo relation").SetInternal(err)
}
return c.JSON(http.StatusOK, memoRelation)
}
// DeleteMemoRelation godoc
//
// @Summary Delete a Memo Relation
// @Description Removes a relation between two memos
// @Tags memo-relation
// @Accept json
// @Produce json
// @Param memoId path int true "ID of memo to find relations"
// @Param relatedMemoId path int true "ID of memo to remove relation to"
// @Param relationType path MemoRelationType true "Type of relation to remove"
// @Success 200 {boolean} true "Memo relation deleted"
// @Failure 400 {object} nil "Memo ID is not a number: %s | Related memo ID is not a number: %s"
// @Failure 500 {object} nil "Failed to delete memo relation"
// @Router /api/v1/memo/{memoId}/relation/{relatedMemoId}/type/{relationType} [DELETE]
//
// NOTES:
// - Currently not secured.
// - Will always return true, even if the relation doesn't exist.
func (s *APIV1Service) DeleteMemoRelation(c echo.Context) error {
ctx := c.Request().Context()
memoID, err := util.ConvertStringToInt32(c.Param("memoId"))
if err != nil {
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Memo ID is not a number: %s", c.Param("memoId"))).SetInternal(err)
}
relatedMemoID, err := util.ConvertStringToInt32(c.Param("relatedMemoId"))
if err != nil {
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Related memo ID is not a number: %s", c.Param("relatedMemoId"))).SetInternal(err)
}
relationType := store.MemoRelationType(c.Param("relationType"))
if err := s.Store.DeleteMemoRelation(ctx, &store.DeleteMemoRelation{
MemoID: &memoID,
RelatedMemoID: &relatedMemoID,
Type: &relationType,
}); err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to delete memo relation").SetInternal(err)
}
return c.JSON(http.StatusOK, true)
}
func convertMemoRelationFromStore(memoRelation *store.MemoRelation) *MemoRelation {
return &MemoRelation{
MemoID: memoRelation.MemoID,
RelatedMemoID: memoRelation.RelatedMemoID,
Type: MemoRelationType(memoRelation.Type),
}
}

View File

@@ -0,0 +1,506 @@
package v1
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"os"
"path/filepath"
"regexp"
"strconv"
"strings"
"time"
"github.com/labstack/echo/v4"
"github.com/lithammer/shortuuid/v4"
"github.com/pkg/errors"
"go.uber.org/zap"
"github.com/usememos/memos/internal/log"
"github.com/usememos/memos/internal/util"
"github.com/usememos/memos/plugin/storage/s3"
"github.com/usememos/memos/store"
)
type Resource struct {
ID int32 `json:"id"`
Name string `json:"name"`
// Standard fields
CreatorID int32 `json:"creatorId"`
CreatedTs int64 `json:"createdTs"`
UpdatedTs int64 `json:"updatedTs"`
// Domain specific fields
Filename string `json:"filename"`
Blob []byte `json:"-"`
InternalPath string `json:"-"`
ExternalLink string `json:"externalLink"`
Type string `json:"type"`
Size int64 `json:"size"`
}
type CreateResourceRequest struct {
Filename string `json:"filename"`
ExternalLink string `json:"externalLink"`
Type string `json:"type"`
}
type FindResourceRequest struct {
ID *int32 `json:"id"`
CreatorID *int32 `json:"creatorId"`
Filename *string `json:"filename"`
}
type UpdateResourceRequest struct {
Filename *string `json:"filename"`
}
const (
// The upload memory buffer is 32 MiB.
// It should be kept low, so RAM usage doesn't get out of control.
// This is unrelated to maximum upload size limit, which is now set through system setting.
maxUploadBufferSizeBytes = 32 << 20
MebiByte = 1024 * 1024
)
var fileKeyPattern = regexp.MustCompile(`\{[a-z]{1,9}\}`)
func (s *APIV1Service) registerResourceRoutes(g *echo.Group) {
g.GET("/resource", s.GetResourceList)
g.POST("/resource", s.CreateResource)
g.POST("/resource/blob", s.UploadResource)
g.PATCH("/resource/:resourceId", s.UpdateResource)
g.DELETE("/resource/:resourceId", s.DeleteResource)
}
// GetResourceList godoc
//
// @Summary Get a list of resources
// @Tags resource
// @Produce json
// @Param limit query int false "Limit"
// @Param offset query int false "Offset"
// @Success 200 {object} []store.Resource "Resource list"
// @Failure 401 {object} nil "Missing user in session"
// @Failure 500 {object} nil "Failed to fetch resource list"
// @Router /api/v1/resource [GET]
func (s *APIV1Service) GetResourceList(c echo.Context) error {
ctx := c.Request().Context()
userID, ok := c.Get(userIDContextKey).(int32)
if !ok {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
}
find := &store.FindResource{
CreatorID: &userID,
}
if limit, err := strconv.Atoi(c.QueryParam("limit")); err == nil {
find.Limit = &limit
}
if offset, err := strconv.Atoi(c.QueryParam("offset")); err == nil {
find.Offset = &offset
}
list, err := s.Store.ListResources(ctx, find)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to fetch resource list").SetInternal(err)
}
resourceMessageList := []*Resource{}
for _, resource := range list {
resourceMessageList = append(resourceMessageList, convertResourceFromStore(resource))
}
return c.JSON(http.StatusOK, resourceMessageList)
}
// CreateResource godoc
//
// @Summary Create resource
// @Tags resource
// @Accept json
// @Produce json
// @Param body body CreateResourceRequest true "Request object."
// @Success 200 {object} store.Resource "Created resource"
// @Failure 400 {object} nil "Malformatted post resource request | Invalid external link | Invalid external link scheme | Failed to request %s | Failed to read %s | Failed to read mime from %s"
// @Failure 401 {object} nil "Missing user in session"
// @Failure 500 {object} nil "Failed to save resource | Failed to create resource | Failed to create activity"
// @Router /api/v1/resource [POST]
func (s *APIV1Service) CreateResource(c echo.Context) error {
ctx := c.Request().Context()
userID, ok := c.Get(userIDContextKey).(int32)
if !ok {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
}
request := &CreateResourceRequest{}
if err := json.NewDecoder(c.Request().Body).Decode(request); err != nil {
return echo.NewHTTPError(http.StatusBadRequest, "Malformatted post resource request").SetInternal(err)
}
create := &store.Resource{
ResourceName: shortuuid.New(),
CreatorID: userID,
Filename: request.Filename,
ExternalLink: request.ExternalLink,
Type: request.Type,
}
if request.ExternalLink != "" {
// Only allow those external links scheme with http/https
linkURL, err := url.Parse(request.ExternalLink)
if err != nil {
return echo.NewHTTPError(http.StatusBadRequest, "Invalid external link").SetInternal(err)
}
if linkURL.Scheme != "http" && linkURL.Scheme != "https" {
return echo.NewHTTPError(http.StatusBadRequest, "Invalid external link scheme")
}
}
resource, err := s.Store.CreateResource(ctx, create)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create resource").SetInternal(err)
}
return c.JSON(http.StatusOK, convertResourceFromStore(resource))
}
// UploadResource godoc
//
// @Summary Upload resource
// @Tags resource
// @Accept multipart/form-data
// @Produce json
// @Param file formData file true "File to upload"
// @Success 200 {object} store.Resource "Created resource"
// @Failure 400 {object} nil "Upload file not found | File size exceeds allowed limit of %d MiB | Failed to parse upload data"
// @Failure 401 {object} nil "Missing user in session"
// @Failure 500 {object} nil "Failed to get uploading file | Failed to open file | Failed to save resource | Failed to create resource | Failed to create activity"
// @Router /api/v1/resource/blob [POST]
func (s *APIV1Service) UploadResource(c echo.Context) error {
ctx := c.Request().Context()
userID, ok := c.Get(userIDContextKey).(int32)
if !ok {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
}
maxUploadSetting, err := s.Store.GetWorkspaceSetting(ctx, &store.FindWorkspaceSetting{Name: SystemSettingMaxUploadSizeMiBName.String()})
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to get max upload size").SetInternal(err)
}
var settingMaxUploadSizeBytes int
if maxUploadSetting != nil {
if settingMaxUploadSizeMiB, err := strconv.Atoi(maxUploadSetting.Value); err == nil {
settingMaxUploadSizeBytes = settingMaxUploadSizeMiB * MebiByte
} else {
log.Warn("Failed to parse max upload size", zap.Error(err))
settingMaxUploadSizeBytes = 0
}
} else {
// Default to 32 MiB.
settingMaxUploadSizeBytes = 32 * MebiByte
}
file, err := c.FormFile("file")
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to get uploading file").SetInternal(err)
}
if file == nil {
return echo.NewHTTPError(http.StatusBadRequest, "Upload file not found").SetInternal(err)
}
if file.Size > int64(settingMaxUploadSizeBytes) {
message := fmt.Sprintf("File size exceeds allowed limit of %d MiB", settingMaxUploadSizeBytes/MebiByte)
return echo.NewHTTPError(http.StatusBadRequest, message).SetInternal(err)
}
if err := c.Request().ParseMultipartForm(maxUploadBufferSizeBytes); err != nil {
return echo.NewHTTPError(http.StatusBadRequest, "Failed to parse upload data").SetInternal(err)
}
sourceFile, err := file.Open()
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to open file").SetInternal(err)
}
defer sourceFile.Close()
create := &store.Resource{
ResourceName: shortuuid.New(),
CreatorID: userID,
Filename: file.Filename,
Type: file.Header.Get("Content-Type"),
Size: file.Size,
}
err = SaveResourceBlob(ctx, s.Store, create, sourceFile)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to save resource").SetInternal(err)
}
resource, err := s.Store.CreateResource(ctx, create)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create resource").SetInternal(err)
}
return c.JSON(http.StatusOK, convertResourceFromStore(resource))
}
// DeleteResource godoc
//
// @Summary Delete a resource
// @Tags resource
// @Produce json
// @Param resourceId path int true "Resource ID"
// @Success 200 {boolean} true "Resource deleted"
// @Failure 400 {object} nil "ID is not a number: %s"
// @Failure 401 {object} nil "Missing user in session"
// @Failure 404 {object} nil "Resource not found: %d"
// @Failure 500 {object} nil "Failed to find resource | Failed to delete resource"
// @Router /api/v1/resource/{resourceId} [DELETE]
func (s *APIV1Service) DeleteResource(c echo.Context) error {
ctx := c.Request().Context()
userID, ok := c.Get(userIDContextKey).(int32)
if !ok {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
}
resourceID, err := util.ConvertStringToInt32(c.Param("resourceId"))
if err != nil {
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("resourceId"))).SetInternal(err)
}
resource, err := s.Store.GetResource(ctx, &store.FindResource{
ID: &resourceID,
CreatorID: &userID,
})
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find resource").SetInternal(err)
}
if resource == nil {
return echo.NewHTTPError(http.StatusNotFound, fmt.Sprintf("Resource not found: %d", resourceID))
}
if err := s.Store.DeleteResource(ctx, &store.DeleteResource{
ID: resourceID,
}); err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to delete resource").SetInternal(err)
}
return c.JSON(http.StatusOK, true)
}
// UpdateResource godoc
//
// @Summary Update a resource
// @Tags resource
// @Produce json
// @Param resourceId path int true "Resource ID"
// @Param patch body UpdateResourceRequest true "Patch resource request"
// @Success 200 {object} store.Resource "Updated resource"
// @Failure 400 {object} nil "ID is not a number: %s | Malformatted patch resource request"
// @Failure 401 {object} nil "Missing user in session | Unauthorized"
// @Failure 404 {object} nil "Resource not found: %d"
// @Failure 500 {object} nil "Failed to find resource | Failed to patch resource"
// @Router /api/v1/resource/{resourceId} [PATCH]
func (s *APIV1Service) UpdateResource(c echo.Context) error {
ctx := c.Request().Context()
userID, ok := c.Get(userIDContextKey).(int32)
if !ok {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
}
resourceID, err := util.ConvertStringToInt32(c.Param("resourceId"))
if err != nil {
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("resourceId"))).SetInternal(err)
}
resource, err := s.Store.GetResource(ctx, &store.FindResource{
ID: &resourceID,
})
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find resource").SetInternal(err)
}
if resource == nil {
return echo.NewHTTPError(http.StatusNotFound, fmt.Sprintf("Resource not found: %d", resourceID))
}
if resource.CreatorID != userID {
return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized")
}
request := &UpdateResourceRequest{}
if err := json.NewDecoder(c.Request().Body).Decode(request); err != nil {
return echo.NewHTTPError(http.StatusBadRequest, "Malformatted patch resource request").SetInternal(err)
}
currentTs := time.Now().Unix()
update := &store.UpdateResource{
ID: resourceID,
UpdatedTs: &currentTs,
}
if request.Filename != nil && *request.Filename != "" {
update.Filename = request.Filename
}
resource, err = s.Store.UpdateResource(ctx, update)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to patch resource").SetInternal(err)
}
return c.JSON(http.StatusOK, convertResourceFromStore(resource))
}
func replacePathTemplate(path, filename string) string {
t := time.Now()
path = fileKeyPattern.ReplaceAllStringFunc(path, func(s string) string {
switch s {
case "{filename}":
return filename
case "{timestamp}":
return fmt.Sprintf("%d", t.Unix())
case "{year}":
return fmt.Sprintf("%d", t.Year())
case "{month}":
return fmt.Sprintf("%02d", t.Month())
case "{day}":
return fmt.Sprintf("%02d", t.Day())
case "{hour}":
return fmt.Sprintf("%02d", t.Hour())
case "{minute}":
return fmt.Sprintf("%02d", t.Minute())
case "{second}":
return fmt.Sprintf("%02d", t.Second())
case "{uuid}":
return util.GenUUID()
}
return s
})
return path
}
func convertResourceFromStore(resource *store.Resource) *Resource {
return &Resource{
ID: resource.ID,
Name: resource.ResourceName,
CreatorID: resource.CreatorID,
CreatedTs: resource.CreatedTs,
UpdatedTs: resource.UpdatedTs,
Filename: resource.Filename,
Blob: resource.Blob,
InternalPath: resource.InternalPath,
ExternalLink: resource.ExternalLink,
Type: resource.Type,
Size: resource.Size,
}
}
// SaveResourceBlob save the blob of resource based on the storage config
//
// Depend on the storage config, some fields of *store.ResourceCreate will be changed:
// 1. *DatabaseStorage*: `create.Blob`.
// 2. *LocalStorage*: `create.InternalPath`.
// 3. Others( external service): `create.ExternalLink`.
func SaveResourceBlob(ctx context.Context, s *store.Store, create *store.Resource, r io.Reader) error {
systemSettingStorageServiceID, err := s.GetWorkspaceSetting(ctx, &store.FindWorkspaceSetting{Name: SystemSettingStorageServiceIDName.String()})
if err != nil {
return errors.Wrap(err, "Failed to find SystemSettingStorageServiceIDName")
}
storageServiceID := DefaultStorage
if systemSettingStorageServiceID != nil {
err = json.Unmarshal([]byte(systemSettingStorageServiceID.Value), &storageServiceID)
if err != nil {
return errors.Wrap(err, "Failed to unmarshal storage service id")
}
}
// `DatabaseStorage` means store blob into database
if storageServiceID == DatabaseStorage {
fileBytes, err := io.ReadAll(r)
if err != nil {
return errors.Wrap(err, "Failed to read file")
}
create.Blob = fileBytes
return nil
} else if storageServiceID == LocalStorage {
// `LocalStorage` means save blob into local disk
systemSettingLocalStoragePath, err := s.GetWorkspaceSetting(ctx, &store.FindWorkspaceSetting{Name: SystemSettingLocalStoragePathName.String()})
if err != nil {
return errors.Wrap(err, "Failed to find SystemSettingLocalStoragePathName")
}
localStoragePath := "assets/{timestamp}_{filename}"
if systemSettingLocalStoragePath != nil && systemSettingLocalStoragePath.Value != "" {
err = json.Unmarshal([]byte(systemSettingLocalStoragePath.Value), &localStoragePath)
if err != nil {
return errors.Wrap(err, "Failed to unmarshal SystemSettingLocalStoragePathName")
}
}
internalPath := localStoragePath
if !strings.Contains(internalPath, "{filename}") {
internalPath = filepath.Join(internalPath, "{filename}")
}
internalPath = replacePathTemplate(internalPath, create.Filename)
internalPath = filepath.ToSlash(internalPath)
create.InternalPath = internalPath
osPath := filepath.FromSlash(internalPath)
if !filepath.IsAbs(osPath) {
osPath = filepath.Join(s.Profile.Data, osPath)
}
dir := filepath.Dir(osPath)
if err = os.MkdirAll(dir, os.ModePerm); err != nil {
return errors.Wrap(err, "Failed to create directory")
}
dst, err := os.Create(osPath)
if err != nil {
return errors.Wrap(err, "Failed to create file")
}
defer dst.Close()
_, err = io.Copy(dst, r)
if err != nil {
return errors.Wrap(err, "Failed to copy file")
}
return nil
}
// Others: store blob into external service, such as S3
storage, err := s.GetStorage(ctx, &store.FindStorage{ID: &storageServiceID})
if err != nil {
return errors.Wrap(err, "Failed to find StorageServiceID")
}
if storage == nil {
return errors.Errorf("Storage %d not found", storageServiceID)
}
storageMessage, err := ConvertStorageFromStore(storage)
if err != nil {
return errors.Wrap(err, "Failed to ConvertStorageFromStore")
}
if storageMessage.Type != StorageS3 {
return errors.Errorf("Unsupported storage type: %s", storageMessage.Type)
}
s3Config := storageMessage.Config.S3Config
s3Client, err := s3.NewClient(ctx, &s3.Config{
AccessKey: s3Config.AccessKey,
SecretKey: s3Config.SecretKey,
EndPoint: s3Config.EndPoint,
Region: s3Config.Region,
Bucket: s3Config.Bucket,
URLPrefix: s3Config.URLPrefix,
URLSuffix: s3Config.URLSuffix,
PreSign: s3Config.PreSign,
})
if err != nil {
return errors.Wrap(err, "Failed to create s3 client")
}
filePath := s3Config.Path
if !strings.Contains(filePath, "{filename}") {
filePath = filepath.Join(filePath, "{filename}")
}
filePath = replacePathTemplate(filePath, create.Filename)
link, err := s3Client.UploadFile(ctx, filePath, create.Type, r)
if err != nil {
return errors.Wrap(err, "Failed to upload via s3 client")
}
create.ExternalLink = link
return nil
}

View File

@@ -0,0 +1,316 @@
package v1
import (
"encoding/json"
"fmt"
"net/http"
"github.com/labstack/echo/v4"
"github.com/usememos/memos/internal/util"
"github.com/usememos/memos/store"
)
const (
// LocalStorage means the storage service is local file system.
LocalStorage int32 = -1
// DatabaseStorage means the storage service is database.
DatabaseStorage int32 = 0
// Default storage service is database.
DefaultStorage int32 = DatabaseStorage
)
type StorageType string
const (
StorageS3 StorageType = "S3"
)
func (t StorageType) String() string {
return string(t)
}
type StorageConfig struct {
S3Config *StorageS3Config `json:"s3Config"`
}
type StorageS3Config struct {
EndPoint string `json:"endPoint"`
Path string `json:"path"`
Region string `json:"region"`
AccessKey string `json:"accessKey"`
SecretKey string `json:"secretKey"`
Bucket string `json:"bucket"`
URLPrefix string `json:"urlPrefix"`
URLSuffix string `json:"urlSuffix"`
PreSign bool `json:"presign"`
}
type Storage struct {
ID int32 `json:"id"`
Name string `json:"name"`
Type StorageType `json:"type"`
Config *StorageConfig `json:"config"`
}
type CreateStorageRequest struct {
Name string `json:"name"`
Type StorageType `json:"type"`
Config *StorageConfig `json:"config"`
}
type UpdateStorageRequest struct {
Type StorageType `json:"type"`
Name *string `json:"name"`
Config *StorageConfig `json:"config"`
}
func (s *APIV1Service) registerStorageRoutes(g *echo.Group) {
g.GET("/storage", s.GetStorageList)
g.POST("/storage", s.CreateStorage)
g.PATCH("/storage/:storageId", s.UpdateStorage)
g.DELETE("/storage/:storageId", s.DeleteStorage)
}
// GetStorageList godoc
//
// @Summary Get a list of storages
// @Tags storage
// @Produce json
// @Success 200 {object} []store.Storage "List of storages"
// @Failure 401 {object} nil "Missing user in session | Unauthorized"
// @Failure 500 {object} nil "Failed to find user | Failed to convert storage"
// @Router /api/v1/storage [GET]
func (s *APIV1Service) GetStorageList(c echo.Context) error {
ctx := c.Request().Context()
userID, ok := c.Get(userIDContextKey).(int32)
if !ok {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
}
user, err := s.Store.GetUser(ctx, &store.FindUser{
ID: &userID,
})
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err)
}
// We should only show storage list to host user.
if user == nil || user.Role != store.RoleHost {
return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized")
}
list, err := s.Store.ListStorages(ctx, &store.FindStorage{})
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find storage list").SetInternal(err)
}
storageList := []*Storage{}
for _, storage := range list {
storageMessage, err := ConvertStorageFromStore(storage)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to convert storage").SetInternal(err)
}
storageList = append(storageList, storageMessage)
}
return c.JSON(http.StatusOK, storageList)
}
// CreateStorage godoc
//
// @Summary Create storage
// @Tags storage
// @Accept json
// @Produce json
// @Param body body CreateStorageRequest true "Request object."
// @Success 200 {object} store.Storage "Created storage"
// @Failure 400 {object} nil "Malformatted post storage request"
// @Failure 401 {object} nil "Missing user in session"
// @Failure 500 {object} nil "Failed to find user | Failed to create storage | Failed to convert storage"
// @Router /api/v1/storage [POST]
func (s *APIV1Service) CreateStorage(c echo.Context) error {
ctx := c.Request().Context()
userID, ok := c.Get(userIDContextKey).(int32)
if !ok {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
}
user, err := s.Store.GetUser(ctx, &store.FindUser{
ID: &userID,
})
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err)
}
if user == nil || user.Role != store.RoleHost {
return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized")
}
create := &CreateStorageRequest{}
if err := json.NewDecoder(c.Request().Body).Decode(create); err != nil {
return echo.NewHTTPError(http.StatusBadRequest, "Malformatted post storage request").SetInternal(err)
}
configString := ""
if create.Type == StorageS3 && create.Config.S3Config != nil {
configBytes, err := json.Marshal(create.Config.S3Config)
if err != nil {
return echo.NewHTTPError(http.StatusBadRequest, "Malformatted post storage request").SetInternal(err)
}
configString = string(configBytes)
}
storage, err := s.Store.CreateStorage(ctx, &store.Storage{
Name: create.Name,
Type: create.Type.String(),
Config: configString,
})
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create storage").SetInternal(err)
}
storageMessage, err := ConvertStorageFromStore(storage)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to convert storage").SetInternal(err)
}
return c.JSON(http.StatusOK, storageMessage)
}
// DeleteStorage godoc
//
// @Summary Delete a storage
// @Tags storage
// @Produce json
// @Param storageId path int true "Storage ID"
// @Success 200 {boolean} true "Storage deleted"
// @Failure 400 {object} nil "ID is not a number: %s | Storage service %d is using"
// @Failure 401 {object} nil "Missing user in session | Unauthorized"
// @Failure 500 {object} nil "Failed to find user | Failed to find storage | Failed to unmarshal storage service id | Failed to delete storage"
// @Router /api/v1/storage/{storageId} [DELETE]
//
// NOTES:
// - error message "Storage service %d is using" probably should be "Storage service %d is in use".
func (s *APIV1Service) DeleteStorage(c echo.Context) error {
ctx := c.Request().Context()
userID, ok := c.Get(userIDContextKey).(int32)
if !ok {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
}
user, err := s.Store.GetUser(ctx, &store.FindUser{
ID: &userID,
})
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err)
}
if user == nil || user.Role != store.RoleHost {
return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized")
}
storageID, err := util.ConvertStringToInt32(c.Param("storageId"))
if err != nil {
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("storageId"))).SetInternal(err)
}
systemSetting, err := s.Store.GetWorkspaceSetting(ctx, &store.FindWorkspaceSetting{Name: SystemSettingStorageServiceIDName.String()})
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find storage").SetInternal(err)
}
if systemSetting != nil {
storageServiceID := DefaultStorage
err = json.Unmarshal([]byte(systemSetting.Value), &storageServiceID)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to unmarshal storage service id").SetInternal(err)
}
if storageServiceID == storageID {
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Storage service %d is using", storageID))
}
}
if err = s.Store.DeleteStorage(ctx, &store.DeleteStorage{ID: storageID}); err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to delete storage").SetInternal(err)
}
return c.JSON(http.StatusOK, true)
}
// UpdateStorage godoc
//
// @Summary Update a storage
// @Tags storage
// @Produce json
// @Param storageId path int true "Storage ID"
// @Param patch body UpdateStorageRequest true "Patch request"
// @Success 200 {object} store.Storage "Updated resource"
// @Failure 400 {object} nil "ID is not a number: %s | Malformatted patch storage request | Malformatted post storage request"
// @Failure 401 {object} nil "Missing user in session | Unauthorized"
// @Failure 500 {object} nil "Failed to find user | Failed to patch storage | Failed to convert storage"
// @Router /api/v1/storage/{storageId} [PATCH]
func (s *APIV1Service) UpdateStorage(c echo.Context) error {
ctx := c.Request().Context()
userID, ok := c.Get(userIDContextKey).(int32)
if !ok {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
}
user, err := s.Store.GetUser(ctx, &store.FindUser{
ID: &userID,
})
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err)
}
if user == nil || user.Role != store.RoleHost {
return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized")
}
storageID, err := util.ConvertStringToInt32(c.Param("storageId"))
if err != nil {
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("storageId"))).SetInternal(err)
}
update := &UpdateStorageRequest{}
if err := json.NewDecoder(c.Request().Body).Decode(update); err != nil {
return echo.NewHTTPError(http.StatusBadRequest, "Malformatted patch storage request").SetInternal(err)
}
storageUpdate := &store.UpdateStorage{
ID: storageID,
}
if update.Name != nil {
storageUpdate.Name = update.Name
}
if update.Config != nil {
if update.Type == StorageS3 {
configBytes, err := json.Marshal(update.Config.S3Config)
if err != nil {
return echo.NewHTTPError(http.StatusBadRequest, "Malformatted post storage request").SetInternal(err)
}
configString := string(configBytes)
storageUpdate.Config = &configString
}
}
storage, err := s.Store.UpdateStorage(ctx, storageUpdate)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to patch storage").SetInternal(err)
}
storageMessage, err := ConvertStorageFromStore(storage)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to convert storage").SetInternal(err)
}
return c.JSON(http.StatusOK, storageMessage)
}
func ConvertStorageFromStore(storage *store.Storage) (*Storage, error) {
storageMessage := &Storage{
ID: storage.ID,
Name: storage.Name,
Type: StorageType(storage.Type),
Config: &StorageConfig{},
}
if storageMessage.Type == StorageS3 {
s3Config := &StorageS3Config{}
if err := json.Unmarshal([]byte(storage.Config), s3Config); err != nil {
return nil, err
}
storageMessage.Config = &StorageConfig{
S3Config: s3Config,
}
}
return storageMessage, nil
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,167 @@
package v1
import (
"encoding/json"
"net/http"
"github.com/labstack/echo/v4"
"github.com/usememos/memos/server/profile"
"github.com/usememos/memos/store"
)
type SystemStatus struct {
Host *User `json:"host"`
Profile profile.Profile `json:"profile"`
DBSize int64 `json:"dbSize"`
// System settings
// Disable password login.
DisablePasswordLogin bool `json:"disablePasswordLogin"`
// Disable public memos.
DisablePublicMemos bool `json:"disablePublicMemos"`
// Max upload size.
MaxUploadSizeMiB int `json:"maxUploadSizeMiB"`
// Customized server profile, including server name and external url.
CustomizedProfile CustomizedProfile `json:"customizedProfile"`
// Storage service ID.
StorageServiceID int32 `json:"storageServiceId"`
// Local storage path.
LocalStoragePath string `json:"localStoragePath"`
// Memo display with updated timestamp.
MemoDisplayWithUpdatedTs bool `json:"memoDisplayWithUpdatedTs"`
}
func (s *APIV1Service) registerSystemRoutes(g *echo.Group) {
g.GET("/ping", s.PingSystem)
g.GET("/status", s.GetSystemStatus)
g.POST("/system/vacuum", s.ExecVacuum)
}
// PingSystem godoc
//
// @Summary Ping the system
// @Tags system
// @Produce json
// @Success 200 {boolean} true "If succeed to ping the system"
// @Router /api/v1/ping [GET]
func (*APIV1Service) PingSystem(c echo.Context) error {
return c.JSON(http.StatusOK, true)
}
// GetSystemStatus godoc
//
// @Summary Get system GetSystemStatus
// @Tags system
// @Produce json
// @Success 200 {object} SystemStatus "System GetSystemStatus"
// @Failure 401 {object} nil "Missing user in session | Unauthorized"
// @Failure 500 {object} nil "Failed to find host user | Failed to find system setting list | Failed to unmarshal system setting customized profile value"
// @Router /api/v1/status [GET]
func (s *APIV1Service) GetSystemStatus(c echo.Context) error {
ctx := c.Request().Context()
systemStatus := SystemStatus{
Profile: profile.Profile{
Mode: s.Profile.Mode,
Version: s.Profile.Version,
},
MaxUploadSizeMiB: 32,
CustomizedProfile: CustomizedProfile{
Name: "Memos",
Locale: "en",
Appearance: "system",
},
StorageServiceID: DefaultStorage,
LocalStoragePath: "assets/{timestamp}_{filename}",
}
hostUserType := store.RoleHost
hostUser, err := s.Store.GetUser(ctx, &store.FindUser{
Role: &hostUserType,
})
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find host user").SetInternal(err)
}
if hostUser != nil {
systemStatus.Host = &User{ID: hostUser.ID}
}
workspaceGeneralSetting, err := s.Store.GetWorkspaceGeneralSetting(ctx)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find workspace general setting").SetInternal(err)
}
systemStatus.DisablePasswordLogin = workspaceGeneralSetting.DisallowPasswordLogin
systemSettingList, err := s.Store.ListWorkspaceSettings(ctx, &store.FindWorkspaceSetting{})
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find system setting list").SetInternal(err)
}
for _, systemSetting := range systemSettingList {
if systemSetting.Name == SystemSettingServerIDName.String() || systemSetting.Name == SystemSettingSecretSessionName.String() || systemSetting.Name == SystemSettingTelegramBotTokenName.String() {
continue
}
var baseValue any
err := json.Unmarshal([]byte(systemSetting.Value), &baseValue)
if err != nil {
// Skip invalid value.
continue
}
switch systemSetting.Name {
case SystemSettingDisablePublicMemosName.String():
systemStatus.DisablePublicMemos = baseValue.(bool)
case SystemSettingMaxUploadSizeMiBName.String():
systemStatus.MaxUploadSizeMiB = int(baseValue.(float64))
case SystemSettingCustomizedProfileName.String():
customizedProfile := CustomizedProfile{}
if err := json.Unmarshal([]byte(systemSetting.Value), &customizedProfile); err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to unmarshal system setting customized profile value").SetInternal(err)
}
systemStatus.CustomizedProfile = customizedProfile
case SystemSettingStorageServiceIDName.String():
systemStatus.StorageServiceID = int32(baseValue.(float64))
case SystemSettingLocalStoragePathName.String():
systemStatus.LocalStoragePath = baseValue.(string)
case SystemSettingMemoDisplayWithUpdatedTsName.String():
systemStatus.MemoDisplayWithUpdatedTs = baseValue.(bool)
default:
// Skip unknown system setting.
}
}
return c.JSON(http.StatusOK, systemStatus)
}
// ExecVacuum godoc
//
// @Summary Vacuum the database
// @Tags system
// @Produce json
// @Success 200 {boolean} true "Database vacuumed"
// @Failure 401 {object} nil "Missing user in session | Unauthorized"
// @Failure 500 {object} nil "Failed to find user | Failed to ExecVacuum database"
// @Router /api/v1/system/vacuum [POST]
func (s *APIV1Service) ExecVacuum(c echo.Context) error {
ctx := c.Request().Context()
userID, ok := c.Get(userIDContextKey).(int32)
if !ok {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
}
user, err := s.Store.GetUser(ctx, &store.FindUser{
ID: &userID,
})
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err)
}
if user == nil || user.Role != store.RoleHost {
return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized")
}
if err := s.Store.Vacuum(ctx); err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to vacuum database").SetInternal(err)
}
return c.JSON(http.StatusOK, true)
}

View File

@@ -0,0 +1,249 @@
package v1
import (
"encoding/json"
"net/http"
"path/filepath"
"strings"
"github.com/labstack/echo/v4"
"github.com/pkg/errors"
"github.com/usememos/memos/store"
)
type SystemSettingName string
const (
// SystemSettingServerIDName is the name of server id.
SystemSettingServerIDName SystemSettingName = "server-id"
// SystemSettingSecretSessionName is the name of secret session.
SystemSettingSecretSessionName SystemSettingName = "secret-session"
// SystemSettingDisablePublicMemosName is the name of disable public memos setting.
SystemSettingDisablePublicMemosName SystemSettingName = "disable-public-memos"
// SystemSettingMaxUploadSizeMiBName is the name of max upload size setting.
SystemSettingMaxUploadSizeMiBName SystemSettingName = "max-upload-size-mib"
// SystemSettingCustomizedProfileName is the name of customized server profile.
SystemSettingCustomizedProfileName SystemSettingName = "customized-profile"
// SystemSettingStorageServiceIDName is the name of storage service ID.
SystemSettingStorageServiceIDName SystemSettingName = "storage-service-id"
// SystemSettingLocalStoragePathName is the name of local storage path.
SystemSettingLocalStoragePathName SystemSettingName = "local-storage-path"
// SystemSettingTelegramBotTokenName is the name of Telegram Bot Token.
SystemSettingTelegramBotTokenName SystemSettingName = "telegram-bot-token"
// SystemSettingMemoDisplayWithUpdatedTsName is the name of memo display with updated ts.
SystemSettingMemoDisplayWithUpdatedTsName SystemSettingName = "memo-display-with-updated-ts"
)
const systemSettingUnmarshalError = `failed to unmarshal value from system setting "%v"`
// CustomizedProfile is the struct definition for SystemSettingCustomizedProfileName system setting item.
type CustomizedProfile struct {
// Name is the server name, default is `memos`
Name string `json:"name"`
// LogoURL is the url of logo image.
LogoURL string `json:"logoUrl"`
// Description is the server description.
Description string `json:"description"`
// Locale is the server default locale.
Locale string `json:"locale"`
// Appearance is the server default appearance.
Appearance string `json:"appearance"`
}
func (key SystemSettingName) String() string {
return string(key)
}
type SystemSetting struct {
Name SystemSettingName `json:"name"`
// Value is a JSON string with basic value.
Value string `json:"value"`
Description string `json:"description"`
}
type UpsertSystemSettingRequest struct {
Name SystemSettingName `json:"name"`
Value string `json:"value"`
Description string `json:"description"`
}
func (s *APIV1Service) registerSystemSettingRoutes(g *echo.Group) {
g.GET("/system/setting", s.GetSystemSettingList)
g.POST("/system/setting", s.CreateSystemSetting)
}
// GetSystemSettingList godoc
//
// @Summary Get a list of system settings
// @Tags system-setting
// @Produce json
// @Success 200 {object} []SystemSetting "System setting list"
// @Failure 401 {object} nil "Missing user in session | Unauthorized"
// @Failure 500 {object} nil "Failed to find user | Failed to find system setting list"
// @Router /api/v1/system/setting [GET]
func (s *APIV1Service) GetSystemSettingList(c echo.Context) error {
ctx := c.Request().Context()
userID, ok := c.Get(userIDContextKey).(int32)
if !ok {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
}
user, err := s.Store.GetUser(ctx, &store.FindUser{
ID: &userID,
})
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err)
}
if user == nil || user.Role != store.RoleHost {
return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized")
}
list, err := s.Store.ListWorkspaceSettings(ctx, &store.FindWorkspaceSetting{})
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find system setting list").SetInternal(err)
}
systemSettingList := make([]*SystemSetting, 0, len(list))
for _, systemSetting := range list {
systemSettingList = append(systemSettingList, convertSystemSettingFromStore(systemSetting))
}
return c.JSON(http.StatusOK, systemSettingList)
}
// CreateSystemSetting godoc
//
// @Summary Create system setting
// @Tags system-setting
// @Accept json
// @Produce json
// @Param body body UpsertSystemSettingRequest true "Request object."
// @Failure 400 {object} nil "Malformatted post system setting request | invalid system setting"
// @Failure 401 {object} nil "Missing user in session | Unauthorized"
// @Failure 403 {object} nil "Cannot disable passwords if no SSO identity provider is configured."
// @Failure 500 {object} nil "Failed to find user | Failed to upsert system setting"
// @Router /api/v1/system/setting [POST]
func (s *APIV1Service) CreateSystemSetting(c echo.Context) error {
ctx := c.Request().Context()
userID, ok := c.Get(userIDContextKey).(int32)
if !ok {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
}
user, err := s.Store.GetUser(ctx, &store.FindUser{
ID: &userID,
})
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err)
}
if user == nil || user.Role != store.RoleHost {
return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized")
}
systemSettingUpsert := &UpsertSystemSettingRequest{}
if err := json.NewDecoder(c.Request().Body).Decode(systemSettingUpsert); err != nil {
return echo.NewHTTPError(http.StatusBadRequest, "Malformatted post system setting request").SetInternal(err)
}
if err := systemSettingUpsert.Validate(); err != nil {
return echo.NewHTTPError(http.StatusBadRequest, "invalid system setting").SetInternal(err)
}
systemSetting, err := s.Store.UpsertWorkspaceSetting(ctx, &store.WorkspaceSetting{
Name: systemSettingUpsert.Name.String(),
Value: systemSettingUpsert.Value,
Description: systemSettingUpsert.Description,
})
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to upsert system setting").SetInternal(err)
}
return c.JSON(http.StatusOK, convertSystemSettingFromStore(systemSetting))
}
func (upsert UpsertSystemSettingRequest) Validate() error {
switch settingName := upsert.Name; settingName {
case SystemSettingServerIDName:
return errors.Errorf("updating %v is not allowed", settingName)
case SystemSettingDisablePublicMemosName:
var value bool
if err := json.Unmarshal([]byte(upsert.Value), &value); err != nil {
return errors.Errorf(systemSettingUnmarshalError, settingName)
}
case SystemSettingMaxUploadSizeMiBName:
var value int
if err := json.Unmarshal([]byte(upsert.Value), &value); err != nil {
return errors.Errorf(systemSettingUnmarshalError, settingName)
}
case SystemSettingCustomizedProfileName:
customizedProfile := CustomizedProfile{
Name: "Memos",
LogoURL: "",
Description: "",
Locale: "en",
Appearance: "system",
}
if err := json.Unmarshal([]byte(upsert.Value), &customizedProfile); err != nil {
return errors.Errorf(systemSettingUnmarshalError, settingName)
}
case SystemSettingStorageServiceIDName:
// Note: 0 is the default value(database) for storage service ID.
value := 0
if err := json.Unmarshal([]byte(upsert.Value), &value); err != nil {
return errors.Errorf(systemSettingUnmarshalError, settingName)
}
return nil
case SystemSettingLocalStoragePathName:
value := ""
if err := json.Unmarshal([]byte(upsert.Value), &value); err != nil {
return errors.Errorf(systemSettingUnmarshalError, settingName)
}
trimmedValue := strings.TrimSpace(value)
switch {
case trimmedValue != value:
return errors.New("local storage path must not contain leading or trailing whitespace")
case trimmedValue == "":
return errors.New("local storage path can't be empty")
case strings.Contains(trimmedValue, "\\"):
return errors.New("local storage path must use forward slashes `/`")
case strings.Contains(trimmedValue, "../"):
return errors.New("local storage path is not allowed to contain `../`")
case strings.HasPrefix(trimmedValue, "./"):
return errors.New("local storage path is not allowed to start with `./`")
case filepath.IsAbs(trimmedValue) || trimmedValue[0] == '/':
return errors.New("local storage path must be a relative path")
case !strings.Contains(trimmedValue, "{filename}"):
return errors.New("local storage path must contain `{filename}`")
}
case SystemSettingTelegramBotTokenName:
if upsert.Value == "" {
return nil
}
// Bot Token with Reverse Proxy shoule like `http.../bot<token>`
if strings.HasPrefix(upsert.Value, "http") {
slashIndex := strings.LastIndexAny(upsert.Value, "/")
if strings.HasPrefix(upsert.Value[slashIndex:], "/bot") {
return nil
}
return errors.New("token start with `http` must end with `/bot<token>`")
}
fragments := strings.Split(upsert.Value, ":")
if len(fragments) != 2 {
return errors.Errorf(systemSettingUnmarshalError, settingName)
}
case SystemSettingMemoDisplayWithUpdatedTsName:
var value bool
if err := json.Unmarshal([]byte(upsert.Value), &value); err != nil {
return errors.Errorf(systemSettingUnmarshalError, settingName)
}
default:
return errors.New("invalid system setting name")
}
return nil
}
func convertSystemSettingFromStore(systemSetting *store.WorkspaceSetting) *SystemSetting {
return &SystemSetting{
Name: SystemSettingName(systemSetting.Name),
Value: systemSetting.Value,
Description: systemSetting.Description,
}
}

218
server/route/api/v1/tag.go Normal file
View File

@@ -0,0 +1,218 @@
package v1
import (
"encoding/json"
"fmt"
"net/http"
"regexp"
"sort"
"github.com/labstack/echo/v4"
"golang.org/x/exp/slices"
"github.com/usememos/memos/store"
)
type Tag struct {
Name string
CreatorID int32
}
type UpsertTagRequest struct {
Name string `json:"name"`
}
type DeleteTagRequest struct {
Name string `json:"name"`
}
func (s *APIV1Service) registerTagRoutes(g *echo.Group) {
g.GET("/tag", s.GetTagList)
g.POST("/tag", s.CreateTag)
g.GET("/tag/suggestion", s.GetTagSuggestion)
g.POST("/tag/delete", s.DeleteTag)
}
// GetTagList godoc
//
// @Summary Get a list of tags
// @Tags tag
// @Produce json
// @Success 200 {object} []string "Tag list"
// @Failure 400 {object} nil "Missing user id to find tag"
// @Failure 500 {object} nil "Failed to find tag list"
// @Router /api/v1/tag [GET]
func (s *APIV1Service) GetTagList(c echo.Context) error {
ctx := c.Request().Context()
userID, ok := c.Get(userIDContextKey).(int32)
if !ok {
return echo.NewHTTPError(http.StatusBadRequest, "Missing user id to find tag")
}
list, err := s.Store.ListTags(ctx, &store.FindTag{
CreatorID: userID,
})
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find tag list").SetInternal(err)
}
tagNameList := []string{}
for _, tag := range list {
tagNameList = append(tagNameList, tag.Name)
}
return c.JSON(http.StatusOK, tagNameList)
}
// CreateTag godoc
//
// @Summary Create a tag
// @Tags tag
// @Accept json
// @Produce json
// @Param body body UpsertTagRequest true "Request object."
// @Success 200 {object} string "Created tag name"
// @Failure 400 {object} nil "Malformatted post tag request | Tag name shouldn't be empty"
// @Failure 401 {object} nil "Missing user in session"
// @Failure 500 {object} nil "Failed to upsert tag | Failed to create activity"
// @Router /api/v1/tag [POST]
func (s *APIV1Service) CreateTag(c echo.Context) error {
ctx := c.Request().Context()
userID, ok := c.Get(userIDContextKey).(int32)
if !ok {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
}
tagUpsert := &UpsertTagRequest{}
if err := json.NewDecoder(c.Request().Body).Decode(tagUpsert); err != nil {
return echo.NewHTTPError(http.StatusBadRequest, "Malformatted post tag request").SetInternal(err)
}
if tagUpsert.Name == "" {
return echo.NewHTTPError(http.StatusBadRequest, "Tag name shouldn't be empty")
}
tag, err := s.Store.UpsertTag(ctx, &store.Tag{
Name: tagUpsert.Name,
CreatorID: userID,
})
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to upsert tag").SetInternal(err)
}
tagMessage := convertTagFromStore(tag)
return c.JSON(http.StatusOK, tagMessage.Name)
}
// DeleteTag godoc
//
// @Summary Delete a tag
// @Tags tag
// @Accept json
// @Produce json
// @Param body body DeleteTagRequest true "Request object."
// @Success 200 {boolean} true "Tag deleted"
// @Failure 400 {object} nil "Malformatted post tag request | Tag name shouldn't be empty"
// @Failure 401 {object} nil "Missing user in session"
// @Failure 500 {object} nil "Failed to delete tag name: %v"
// @Router /api/v1/tag/delete [POST]
func (s *APIV1Service) DeleteTag(c echo.Context) error {
ctx := c.Request().Context()
userID, ok := c.Get(userIDContextKey).(int32)
if !ok {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
}
tagDelete := &DeleteTagRequest{}
if err := json.NewDecoder(c.Request().Body).Decode(tagDelete); err != nil {
return echo.NewHTTPError(http.StatusBadRequest, "Malformatted post tag request").SetInternal(err)
}
if tagDelete.Name == "" {
return echo.NewHTTPError(http.StatusBadRequest, "Tag name shouldn't be empty")
}
err := s.Store.DeleteTag(ctx, &store.DeleteTag{
Name: tagDelete.Name,
CreatorID: userID,
})
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Failed to delete tag name: %v", tagDelete.Name)).SetInternal(err)
}
return c.JSON(http.StatusOK, true)
}
// GetTagSuggestion godoc
//
// @Summary Get a list of tags suggested from other memos contents
// @Tags tag
// @Produce json
// @Success 200 {object} []string "Tag list"
// @Failure 400 {object} nil "Missing user session"
// @Failure 500 {object} nil "Failed to find memo list | Failed to find tag list"
// @Router /api/v1/tag/suggestion [GET]
func (s *APIV1Service) GetTagSuggestion(c echo.Context) error {
ctx := c.Request().Context()
userID, ok := c.Get(userIDContextKey).(int32)
if !ok {
return echo.NewHTTPError(http.StatusBadRequest, "Missing user session")
}
normalRowStatus := store.Normal
memoFind := &store.FindMemo{
CreatorID: &userID,
ContentSearch: []string{"#"},
RowStatus: &normalRowStatus,
}
memoMessageList, err := s.Store.ListMemos(ctx, memoFind)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find memo list").SetInternal(err)
}
list, err := s.Store.ListTags(ctx, &store.FindTag{
CreatorID: userID,
})
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find tag list").SetInternal(err)
}
tagNameList := []string{}
for _, tag := range list {
tagNameList = append(tagNameList, tag.Name)
}
tagMapSet := make(map[string]bool)
for _, memo := range memoMessageList {
for _, tag := range findTagListFromMemoContent(memo.Content) {
if !slices.Contains(tagNameList, tag) {
tagMapSet[tag] = true
}
}
}
tagList := []string{}
for tag := range tagMapSet {
tagList = append(tagList, tag)
}
sort.Strings(tagList)
return c.JSON(http.StatusOK, tagList)
}
func convertTagFromStore(tag *store.Tag) *Tag {
return &Tag{
Name: tag.Name,
CreatorID: tag.CreatorID,
}
}
var tagRegexp = regexp.MustCompile(`#([^\s#,]+)`)
func findTagListFromMemoContent(memoContent string) []string {
tagMapSet := make(map[string]bool)
matches := tagRegexp.FindAllStringSubmatch(memoContent, -1)
for _, v := range matches {
tagName := v[1]
tagMapSet[tagName] = true
}
tagList := []string{}
for tag := range tagMapSet {
tagList = append(tagList, tag)
}
sort.Strings(tagList)
return tagList
}

View File

@@ -0,0 +1,47 @@
package v1
import (
"testing"
)
func TestFindTagListFromMemoContent(t *testing.T) {
tests := []struct {
memoContent string
want []string
}{
{
memoContent: "#tag1 ",
want: []string{"tag1"},
},
{
memoContent: "#tag1 #tag2 ",
want: []string{"tag1", "tag2"},
},
{
memoContent: "#tag1 #tag2 \n#tag3 ",
want: []string{"tag1", "tag2", "tag3"},
},
{
memoContent: "#tag1 #tag2 \n#tag3 #tag4 ",
want: []string{"tag1", "tag2", "tag3", "tag4"},
},
{
memoContent: "#tag1 #tag2 \n#tag3 #tag4 ",
want: []string{"tag1", "tag2", "tag3", "tag4"},
},
{
memoContent: "#tag1 123123#tag2 \n#tag3 #tag4 ",
want: []string{"tag1", "tag2", "tag3", "tag4"},
},
{
memoContent: "#tag1 http://123123.com?123123#tag2 \n#tag3 #tag4 http://123123.com?123123#tag2) ",
want: []string{"tag1", "tag2", "tag2)", "tag3", "tag4"},
},
}
for _, test := range tests {
result := findTagListFromMemoContent(test.memoContent)
if len(result) != len(test.want) {
t.Errorf("Find tag list %s: got result %v, want %v.", test.memoContent, result, test.want)
}
}
}

499
server/route/api/v1/user.go Normal file
View File

@@ -0,0 +1,499 @@
package v1
import (
"encoding/json"
"fmt"
"net/http"
"strings"
"time"
"github.com/labstack/echo/v4"
"github.com/pkg/errors"
"golang.org/x/crypto/bcrypt"
"github.com/usememos/memos/internal/util"
"github.com/usememos/memos/store"
)
// Role is the type of a role.
type Role string
const (
// RoleHost is the HOST role.
RoleHost Role = "HOST"
// RoleAdmin is the ADMIN role.
RoleAdmin Role = "ADMIN"
// RoleUser is the USER role.
RoleUser Role = "USER"
)
func (role Role) String() string {
return string(role)
}
type User struct {
ID int32 `json:"id"`
// Standard fields
RowStatus RowStatus `json:"rowStatus"`
CreatedTs int64 `json:"createdTs"`
UpdatedTs int64 `json:"updatedTs"`
// Domain specific fields
Username string `json:"username"`
Role Role `json:"role"`
Email string `json:"email"`
Nickname string `json:"nickname"`
PasswordHash string `json:"-"`
AvatarURL string `json:"avatarUrl"`
}
type CreateUserRequest struct {
Username string `json:"username"`
Role Role `json:"role"`
Email string `json:"email"`
Nickname string `json:"nickname"`
Password string `json:"password"`
}
type UpdateUserRequest struct {
RowStatus *RowStatus `json:"rowStatus"`
Username *string `json:"username"`
Email *string `json:"email"`
Nickname *string `json:"nickname"`
Password *string `json:"password"`
AvatarURL *string `json:"avatarUrl"`
}
func (s *APIV1Service) registerUserRoutes(g *echo.Group) {
g.GET("/user", s.GetUserList)
g.POST("/user", s.CreateUser)
g.GET("/user/me", s.GetCurrentUser)
// NOTE: This should be moved to /api/v2/user/:username
g.GET("/user/name/:username", s.GetUserByUsername)
g.GET("/user/:id", s.GetUserByID)
g.PATCH("/user/:id", s.UpdateUser)
g.DELETE("/user/:id", s.DeleteUser)
}
// GetUserList godoc
//
// @Summary Get a list of users
// @Tags user
// @Produce json
// @Success 200 {object} []store.User "User list"
// @Failure 500 {object} nil "Failed to fetch user list"
// @Router /api/v1/user [GET]
func (s *APIV1Service) GetUserList(c echo.Context) error {
ctx := c.Request().Context()
userID, ok := c.Get(userIDContextKey).(int32)
if !ok {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing auth session")
}
currentUser, err := s.Store.GetUser(ctx, &store.FindUser{
ID: &userID,
})
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user by id").SetInternal(err)
}
if currentUser == nil {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing auth session")
}
if currentUser.Role != store.RoleHost && currentUser.Role != store.RoleAdmin {
return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized to list users")
}
list, err := s.Store.ListUsers(ctx, &store.FindUser{})
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to fetch user list").SetInternal(err)
}
userMessageList := make([]*User, 0, len(list))
for _, user := range list {
userMessage := convertUserFromStore(user)
// data desensitize
userMessage.Email = ""
userMessageList = append(userMessageList, userMessage)
}
return c.JSON(http.StatusOK, userMessageList)
}
// CreateUser godoc
//
// @Summary Create a user
// @Tags user
// @Accept json
// @Produce json
// @Param body body CreateUserRequest true "Request object"
// @Success 200 {object} store.User "Created user"
// @Failure 400 {object} nil "Malformatted post user request | Invalid user create format"
// @Failure 401 {object} nil "Missing auth session | Unauthorized to create user"
// @Failure 403 {object} nil "Could not create host user"
// @Failure 500 {object} nil "Failed to find user by id | Failed to generate password hash | Failed to create user | Failed to create activity"
// @Router /api/v1/user [POST]
func (s *APIV1Service) CreateUser(c echo.Context) error {
ctx := c.Request().Context()
userID, ok := c.Get(userIDContextKey).(int32)
if !ok {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing auth session")
}
currentUser, err := s.Store.GetUser(ctx, &store.FindUser{
ID: &userID,
})
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user by id").SetInternal(err)
}
if currentUser == nil {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing auth session")
}
if currentUser.Role != store.RoleHost {
return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized to create user")
}
userCreate := &CreateUserRequest{}
if err := json.NewDecoder(c.Request().Body).Decode(userCreate); err != nil {
return echo.NewHTTPError(http.StatusBadRequest, "Malformatted post user request").SetInternal(err)
}
if err := userCreate.Validate(); err != nil {
return echo.NewHTTPError(http.StatusBadRequest, "Invalid user create format").SetInternal(err)
}
if !util.ResourceNameMatcher.MatchString(strings.ToLower(userCreate.Username)) {
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Invalid username %s", userCreate.Username)).SetInternal(err)
}
// Disallow host user to be created.
if userCreate.Role == RoleHost {
return echo.NewHTTPError(http.StatusForbidden, "Could not create host user")
}
passwordHash, err := bcrypt.GenerateFromPassword([]byte(userCreate.Password), bcrypt.DefaultCost)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to generate password hash").SetInternal(err)
}
user, err := s.Store.CreateUser(ctx, &store.User{
Username: userCreate.Username,
Role: store.Role(userCreate.Role),
Email: userCreate.Email,
Nickname: userCreate.Nickname,
PasswordHash: string(passwordHash),
})
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create user").SetInternal(err)
}
userMessage := convertUserFromStore(user)
return c.JSON(http.StatusOK, userMessage)
}
// GetCurrentUser godoc
//
// @Summary Get current user
// @Tags user
// @Produce json
// @Success 200 {object} store.User "Current user"
// @Failure 401 {object} nil "Missing auth session"
// @Failure 500 {object} nil "Failed to find user | Failed to find userSettingList"
// @Router /api/v1/user/me [GET]
func (s *APIV1Service) GetCurrentUser(c echo.Context) error {
ctx := c.Request().Context()
userID, ok := c.Get(userIDContextKey).(int32)
if !ok {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing auth session")
}
user, err := s.Store.GetUser(ctx, &store.FindUser{ID: &userID})
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err)
}
if user == nil {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing auth session")
}
userMessage := convertUserFromStore(user)
return c.JSON(http.StatusOK, userMessage)
}
// GetUserByUsername godoc
//
// @Summary Get user by username
// @Tags user
// @Produce json
// @Param username path string true "Username"
// @Success 200 {object} store.User "Requested user"
// @Failure 404 {object} nil "User not found"
// @Failure 500 {object} nil "Failed to find user"
// @Router /api/v1/user/name/{username} [GET]
func (s *APIV1Service) GetUserByUsername(c echo.Context) error {
ctx := c.Request().Context()
username := c.Param("username")
user, err := s.Store.GetUser(ctx, &store.FindUser{Username: &username})
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err)
}
if user == nil {
return echo.NewHTTPError(http.StatusNotFound, "User not found")
}
userMessage := convertUserFromStore(user)
// data desensitize
userMessage.Email = ""
return c.JSON(http.StatusOK, userMessage)
}
// GetUserByID godoc
//
// @Summary Get user by id
// @Tags user
// @Produce json
// @Param id path int true "User ID"
// @Success 200 {object} store.User "Requested user"
// @Failure 400 {object} nil "Malformatted user id"
// @Failure 404 {object} nil "User not found"
// @Failure 500 {object} nil "Failed to find user"
// @Router /api/v1/user/{id} [GET]
func (s *APIV1Service) GetUserByID(c echo.Context) error {
ctx := c.Request().Context()
id, err := util.ConvertStringToInt32(c.Param("id"))
if err != nil {
return echo.NewHTTPError(http.StatusBadRequest, "Malformatted user id").SetInternal(err)
}
user, err := s.Store.GetUser(ctx, &store.FindUser{ID: &id})
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err)
}
if user == nil {
return echo.NewHTTPError(http.StatusNotFound, "User not found")
}
userMessage := convertUserFromStore(user)
userID, ok := c.Get(userIDContextKey).(int32)
if !ok || userID != user.ID {
// Data desensitize.
userMessage.Email = ""
}
return c.JSON(http.StatusOK, userMessage)
}
// DeleteUser godoc
//
// @Summary Delete a user
// @Tags user
// @Produce json
// @Param id path string true "User ID"
// @Success 200 {boolean} true "User deleted"
// @Failure 400 {object} nil "ID is not a number: %s | Current session user not found with ID: %d"
// @Failure 401 {object} nil "Missing user in session"
// @Failure 403 {object} nil "Unauthorized to delete user"
// @Failure 500 {object} nil "Failed to find user | Failed to delete user"
// @Router /api/v1/user/{id} [DELETE]
func (s *APIV1Service) DeleteUser(c echo.Context) error {
ctx := c.Request().Context()
currentUserID, ok := c.Get(userIDContextKey).(int32)
if !ok {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
}
currentUser, err := s.Store.GetUser(ctx, &store.FindUser{
ID: &currentUserID,
})
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err)
}
if currentUser == nil {
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Current session user not found with ID: %d", currentUserID)).SetInternal(err)
} else if currentUser.Role != store.RoleHost {
return echo.NewHTTPError(http.StatusForbidden, "Unauthorized to delete user").SetInternal(err)
}
userID, err := util.ConvertStringToInt32(c.Param("id"))
if err != nil {
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("id"))).SetInternal(err)
}
if currentUserID == userID {
return echo.NewHTTPError(http.StatusBadRequest, "Cannot delete current user")
}
findUser, err := s.Store.GetUser(ctx, &store.FindUser{ID: &userID})
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err)
}
if s.Profile.Mode == "demo" && findUser.Username == "memos-demo" {
return echo.NewHTTPError(http.StatusForbidden, "Unauthorized to delete this user in demo mode")
}
if err := s.Store.DeleteUser(ctx, &store.DeleteUser{
ID: userID,
}); err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to delete user").SetInternal(err)
}
return c.JSON(http.StatusOK, true)
}
// UpdateUser godoc
//
// @Summary Update a user
// @Tags user
// @Produce json
// @Param id path string true "User ID"
// @Param patch body UpdateUserRequest true "Patch request"
// @Success 200 {object} store.User "Updated user"
// @Failure 400 {object} nil "ID is not a number: %s | Current session user not found with ID: %d | Malformatted patch user request | Invalid update user request"
// @Failure 401 {object} nil "Missing user in session"
// @Failure 403 {object} nil "Unauthorized to update user"
// @Failure 500 {object} nil "Failed to find user | Failed to generate password hash | Failed to patch user | Failed to find userSettingList"
// @Router /api/v1/user/{id} [PATCH]
func (s *APIV1Service) UpdateUser(c echo.Context) error {
ctx := c.Request().Context()
userID, err := util.ConvertStringToInt32(c.Param("id"))
if err != nil {
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("id"))).SetInternal(err)
}
currentUserID, ok := c.Get(userIDContextKey).(int32)
if !ok {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
}
currentUser, err := s.Store.GetUser(ctx, &store.FindUser{ID: &currentUserID})
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err)
}
if currentUser == nil {
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Current session user not found with ID: %d", currentUserID)).SetInternal(err)
} else if currentUser.Role != store.RoleHost && currentUserID != userID {
return echo.NewHTTPError(http.StatusForbidden, "Unauthorized to update user").SetInternal(err)
}
request := &UpdateUserRequest{}
if err := json.NewDecoder(c.Request().Body).Decode(request); err != nil {
return echo.NewHTTPError(http.StatusBadRequest, "Malformatted patch user request").SetInternal(err)
}
if err := request.Validate(); err != nil {
return echo.NewHTTPError(http.StatusBadRequest, "Invalid update user request").SetInternal(err)
}
if s.Profile.Mode == "demo" && *request.Username == "memos-demo" {
return echo.NewHTTPError(http.StatusForbidden, "Unauthorized to update user in demo mode")
}
currentTs := time.Now().Unix()
userUpdate := &store.UpdateUser{
ID: userID,
UpdatedTs: &currentTs,
}
if request.RowStatus != nil {
rowStatus := store.RowStatus(request.RowStatus.String())
userUpdate.RowStatus = &rowStatus
if rowStatus == store.Archived && currentUserID == userID {
return echo.NewHTTPError(http.StatusBadRequest, "Cannot archive current user")
}
}
if request.Username != nil {
if !util.ResourceNameMatcher.MatchString(strings.ToLower(*request.Username)) {
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Invalid username %s", *request.Username)).SetInternal(err)
}
userUpdate.Username = request.Username
}
if request.Email != nil {
userUpdate.Email = request.Email
}
if request.Nickname != nil {
userUpdate.Nickname = request.Nickname
}
if request.Password != nil {
passwordHash, err := bcrypt.GenerateFromPassword([]byte(*request.Password), bcrypt.DefaultCost)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to generate password hash").SetInternal(err)
}
passwordHashStr := string(passwordHash)
userUpdate.PasswordHash = &passwordHashStr
}
if request.AvatarURL != nil {
userUpdate.AvatarURL = request.AvatarURL
}
user, err := s.Store.UpdateUser(ctx, userUpdate)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to patch user").SetInternal(err)
}
userMessage := convertUserFromStore(user)
return c.JSON(http.StatusOK, userMessage)
}
func (create CreateUserRequest) Validate() error {
if len(create.Username) < 3 {
return errors.New("username is too short, minimum length is 3")
}
if len(create.Username) > 32 {
return errors.New("username is too long, maximum length is 32")
}
if len(create.Password) < 3 {
return errors.New("password is too short, minimum length is 3")
}
if len(create.Password) > 512 {
return errors.New("password is too long, maximum length is 512")
}
if len(create.Nickname) > 64 {
return errors.New("nickname is too long, maximum length is 64")
}
if create.Email != "" {
if len(create.Email) > 256 {
return errors.New("email is too long, maximum length is 256")
}
if !util.ValidateEmail(create.Email) {
return errors.New("invalid email format")
}
}
return nil
}
func (update UpdateUserRequest) Validate() error {
if update.Username != nil && len(*update.Username) < 3 {
return errors.New("username is too short, minimum length is 3")
}
if update.Username != nil && len(*update.Username) > 32 {
return errors.New("username is too long, maximum length is 32")
}
if update.Password != nil && len(*update.Password) < 3 {
return errors.New("password is too short, minimum length is 3")
}
if update.Password != nil && len(*update.Password) > 512 {
return errors.New("password is too long, maximum length is 512")
}
if update.Nickname != nil && len(*update.Nickname) > 64 {
return errors.New("nickname is too long, maximum length is 64")
}
if update.AvatarURL != nil {
if len(*update.AvatarURL) > 2<<20 {
return errors.New("avatar is too large, maximum is 2MB")
}
}
if update.Email != nil && *update.Email != "" {
if len(*update.Email) > 256 {
return errors.New("email is too long, maximum length is 256")
}
if !util.ValidateEmail(*update.Email) {
return errors.New("invalid email format")
}
}
return nil
}
func convertUserFromStore(user *store.User) *User {
return &User{
ID: user.ID,
RowStatus: RowStatus(user.RowStatus),
CreatedTs: user.CreatedTs,
UpdatedTs: user.UpdatedTs,
Username: user.Username,
Role: Role(user.Role),
Email: user.Email,
Nickname: user.Nickname,
PasswordHash: user.PasswordHash,
AvatarURL: user.AvatarURL,
}
}

95
server/route/api/v1/v1.go Normal file
View File

@@ -0,0 +1,95 @@
package v1
import (
"net/http"
"time"
"github.com/labstack/echo/v4"
"github.com/labstack/echo/v4/middleware"
"github.com/usememos/memos/plugin/telegram"
"github.com/usememos/memos/server/profile"
"github.com/usememos/memos/server/route/resource"
"github.com/usememos/memos/server/route/rss"
"github.com/usememos/memos/store"
)
type APIV1Service struct {
Secret string
Profile *profile.Profile
Store *store.Store
telegramBot *telegram.Bot
}
// @title memos API
// @version 1.0
// @description A privacy-first, lightweight note-taking service.
//
// @contact.name API Support
// @contact.url https://github.com/orgs/usememos/discussions
//
// @license.name MIT License
// @license.url https://github.com/usememos/memos/blob/main/LICENSE
//
// @BasePath /
//
// @externalDocs.url https://usememos.com/
// @externalDocs.description Find out more about Memos.
func NewAPIV1Service(secret string, profile *profile.Profile, store *store.Store, telegramBot *telegram.Bot) *APIV1Service {
return &APIV1Service{
Secret: secret,
Profile: profile,
Store: store,
telegramBot: telegramBot,
}
}
func (s *APIV1Service) Register(rootGroup *echo.Group) {
// Register API v1 routes.
apiV1Group := rootGroup.Group("/api/v1")
apiV1Group.Use(middleware.RateLimiterWithConfig(middleware.RateLimiterConfig{
Store: middleware.NewRateLimiterMemoryStoreWithConfig(
middleware.RateLimiterMemoryStoreConfig{Rate: 30, Burst: 100, ExpiresIn: 3 * time.Minute},
),
IdentifierExtractor: func(ctx echo.Context) (string, error) {
id := ctx.RealIP()
return id, nil
},
ErrorHandler: func(context echo.Context, err error) error {
return context.JSON(http.StatusForbidden, nil)
},
DenyHandler: func(context echo.Context, identifier string, err error) error {
return context.JSON(http.StatusTooManyRequests, nil)
},
}))
apiV1Group.Use(func(next echo.HandlerFunc) echo.HandlerFunc {
return JWTMiddleware(s, next, s.Secret)
})
s.registerSystemRoutes(apiV1Group)
s.registerSystemSettingRoutes(apiV1Group)
s.registerAuthRoutes(apiV1Group)
s.registerIdentityProviderRoutes(apiV1Group)
s.registerUserRoutes(apiV1Group)
s.registerTagRoutes(apiV1Group)
s.registerStorageRoutes(apiV1Group)
s.registerResourceRoutes(apiV1Group)
s.registerMemoRoutes(apiV1Group)
s.registerMemoOrganizerRoutes(apiV1Group)
s.registerMemoRelationRoutes(apiV1Group)
// Register public routes.
publicGroup := rootGroup.Group("/o")
publicGroup.Use(func(next echo.HandlerFunc) echo.HandlerFunc {
return JWTMiddleware(s, next, s.Secret)
})
s.registerGetterPublicRoutes(publicGroup)
// Create and register resource public routes.
resource.NewResourceService(s.Profile, s.Store).RegisterRoutes(publicGroup)
// Create and register rss public routes.
rss.NewRSSService(s.Profile, s.Store).RegisterRoutes(rootGroup)
// programmatically set API version same as the server version
SwaggerInfo.Version = s.Profile.Version
}

162
server/route/api/v2/acl.go Normal file
View File

@@ -0,0 +1,162 @@
package v2
import (
"context"
"net/http"
"strings"
"github.com/golang-jwt/jwt/v5"
"github.com/pkg/errors"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
"github.com/usememos/memos/internal/util"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/server/route/api/auth"
"github.com/usememos/memos/store"
)
// ContextKey is the key type of context value.
type ContextKey int
const (
// The key name used to store username in the context
// user id is extracted from the jwt token subject field.
usernameContextKey ContextKey = iota
)
// GRPCAuthInterceptor is the auth interceptor for gRPC server.
type GRPCAuthInterceptor struct {
Store *store.Store
secret string
}
// NewGRPCAuthInterceptor returns a new API auth interceptor.
func NewGRPCAuthInterceptor(store *store.Store, secret string) *GRPCAuthInterceptor {
return &GRPCAuthInterceptor{
Store: store,
secret: secret,
}
}
// AuthenticationInterceptor is the unary interceptor for gRPC API.
func (in *GRPCAuthInterceptor) AuthenticationInterceptor(ctx context.Context, request any, serverInfo *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
return nil, status.Errorf(codes.Unauthenticated, "failed to parse metadata from incoming context")
}
accessToken, err := getTokenFromMetadata(md)
if err != nil {
return nil, status.Errorf(codes.Unauthenticated, err.Error())
}
username, err := in.authenticate(ctx, accessToken)
if err != nil {
if isUnauthorizeAllowedMethod(serverInfo.FullMethod) {
return handler(ctx, request)
}
return nil, err
}
user, err := in.Store.GetUser(ctx, &store.FindUser{
Username: &username,
})
if err != nil {
return nil, errors.Wrap(err, "failed to get user")
}
if user == nil {
return nil, errors.Errorf("user %q not exists", username)
}
if user.RowStatus == store.Archived {
return nil, errors.Errorf("user %q is archived", username)
}
if isOnlyForAdminAllowedMethod(serverInfo.FullMethod) && user.Role != store.RoleHost && user.Role != store.RoleAdmin {
return nil, errors.Errorf("user %q is not admin", username)
}
// Stores userID into context.
childCtx := context.WithValue(ctx, usernameContextKey, username)
return handler(childCtx, request)
}
func (in *GRPCAuthInterceptor) authenticate(ctx context.Context, accessToken string) (string, error) {
if accessToken == "" {
return "", status.Errorf(codes.Unauthenticated, "access token not found")
}
claims := &auth.ClaimsMessage{}
_, err := jwt.ParseWithClaims(accessToken, claims, func(t *jwt.Token) (any, error) {
if t.Method.Alg() != jwt.SigningMethodHS256.Name {
return nil, status.Errorf(codes.Unauthenticated, "unexpected access token signing method=%v, expect %v", t.Header["alg"], jwt.SigningMethodHS256)
}
if kid, ok := t.Header["kid"].(string); ok {
if kid == "v1" {
return []byte(in.secret), nil
}
}
return nil, status.Errorf(codes.Unauthenticated, "unexpected access token kid=%v", t.Header["kid"])
})
if err != nil {
return "", status.Errorf(codes.Unauthenticated, "Invalid or expired access token")
}
// We either have a valid access token or we will attempt to generate new access token.
userID, err := util.ConvertStringToInt32(claims.Subject)
if err != nil {
return "", errors.Wrap(err, "malformed ID in the token")
}
user, err := in.Store.GetUser(ctx, &store.FindUser{
ID: &userID,
})
if err != nil {
return "", errors.Wrap(err, "failed to get user")
}
if user == nil {
return "", errors.Errorf("user %q not exists", userID)
}
if user.RowStatus == store.Archived {
return "", errors.Errorf("user %q is archived", userID)
}
accessTokens, err := in.Store.GetUserAccessTokens(ctx, user.ID)
if err != nil {
return "", errors.Wrapf(err, "failed to get user access tokens")
}
if !validateAccessToken(accessToken, accessTokens) {
return "", status.Errorf(codes.Unauthenticated, "invalid access token")
}
return user.Username, nil
}
func getTokenFromMetadata(md metadata.MD) (string, error) {
// Check the HTTP request header first.
authorizationHeaders := md.Get("Authorization")
if len(md.Get("Authorization")) > 0 {
authHeaderParts := strings.Fields(authorizationHeaders[0])
if len(authHeaderParts) != 2 || strings.ToLower(authHeaderParts[0]) != "bearer" {
return "", errors.New("authorization header format must be Bearer {token}")
}
return authHeaderParts[1], nil
}
// Check the cookie header.
var accessToken string
for _, t := range append(md.Get("grpcgateway-cookie"), md.Get("cookie")...) {
header := http.Header{}
header.Add("Cookie", t)
request := http.Request{Header: header}
if v, _ := request.Cookie(auth.AccessTokenCookieName); v != nil {
accessToken = v.Value
}
}
return accessToken, nil
}
func validateAccessToken(accessTokenString string, userAccessTokens []*storepb.AccessTokensUserSetting_AccessToken) bool {
for _, userAccessToken := range userAccessTokens {
if accessTokenString == userAccessToken.AccessToken {
return true
}
}
return false
}

View File

@@ -0,0 +1,37 @@
package v2
import "strings"
var authenticationAllowlistMethods = map[string]bool{
"/memos.api.v2.WorkspaceService/GetWorkspaceProfile": true,
"/memos.api.v2.WorkspaceSettingService/GetWorkspaceSetting": true,
"/memos.api.v2.AuthService/GetAuthStatus": true,
"/memos.api.v2.AuthService/SignIn": true,
"/memos.api.v2.AuthService/SignInWithSSO": true,
"/memos.api.v2.AuthService/SignOut": true,
"/memos.api.v2.AuthService/SignUp": true,
"/memos.api.v2.UserService/GetUser": true,
"/memos.api.v2.MemoService/ListMemos": true,
"/memos.api.v2.MemoService/GetMemo": true,
"/memos.api.v2.MemoService/GetMemoByName": true,
"/memos.api.v2.MemoService/ListMemoResources": true,
"/memos.api.v2.MemoService/ListMemoRelations": true,
"/memos.api.v2.MemoService/ListMemoComments": true,
}
// isUnauthorizeAllowedMethod returns whether the method is exempted from authentication.
func isUnauthorizeAllowedMethod(fullMethodName string) bool {
if strings.HasPrefix(fullMethodName, "/grpc.reflection") {
return true
}
return authenticationAllowlistMethods[fullMethodName]
}
var allowedMethodsOnlyForAdmin = map[string]bool{
"/memos.api.v2.UserService/CreateUser": true,
}
// isOnlyForAdminAllowedMethod returns true if the method is allowed to be called only by admin.
func isOnlyForAdminAllowedMethod(methodName string) bool {
return allowedMethodsOnlyForAdmin[methodName]
}

View File

@@ -0,0 +1,58 @@
package v2
import (
"context"
"time"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/timestamppb"
apiv2pb "github.com/usememos/memos/proto/gen/api/v2"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/store"
)
func (s *APIV2Service) GetActivity(ctx context.Context, request *apiv2pb.GetActivityRequest) (*apiv2pb.GetActivityResponse, error) {
activity, err := s.Store.GetActivity(ctx, &store.FindActivity{
ID: &request.Id,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get activity: %v", err)
}
activityMessage, err := s.convertActivityFromStore(ctx, activity)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to convert activity from store: %v", err)
}
return &apiv2pb.GetActivityResponse{
Activity: activityMessage,
}, nil
}
func (*APIV2Service) convertActivityFromStore(_ context.Context, activity *store.Activity) (*apiv2pb.Activity, error) {
return &apiv2pb.Activity{
Id: activity.ID,
CreatorId: activity.CreatorID,
Type: activity.Type.String(),
Level: activity.Level.String(),
CreateTime: timestamppb.New(time.Unix(activity.CreatedTs, 0)),
Payload: convertActivityPayloadFromStore(activity.Payload),
}, nil
}
func convertActivityPayloadFromStore(payload *storepb.ActivityPayload) *apiv2pb.ActivityPayload {
v2Payload := &apiv2pb.ActivityPayload{}
if payload.MemoComment != nil {
v2Payload.MemoComment = &apiv2pb.ActivityMemoCommentPayload{
MemoId: payload.MemoComment.MemoId,
RelatedMemoId: payload.MemoComment.RelatedMemoId,
}
}
if payload.VersionUpdate != nil {
v2Payload.VersionUpdate = &apiv2pb.ActivityVersionUpdatePayload{
Version: payload.VersionUpdate.Version,
}
}
return v2Payload
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,262 @@
package v2
import (
"context"
"fmt"
"regexp"
"strings"
"time"
"github.com/pkg/errors"
"golang.org/x/crypto/bcrypt"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
"github.com/usememos/memos/server/route/api/auth"
"github.com/usememos/memos/internal/util"
"github.com/usememos/memos/plugin/idp"
"github.com/usememos/memos/plugin/idp/oauth2"
apiv2pb "github.com/usememos/memos/proto/gen/api/v2"
"github.com/usememos/memos/store"
)
func (s *APIV2Service) GetAuthStatus(ctx context.Context, _ *apiv2pb.GetAuthStatusRequest) (*apiv2pb.GetAuthStatusResponse, error) {
user, err := getCurrentUser(ctx, s.Store)
if err != nil {
return nil, status.Errorf(codes.Unauthenticated, "failed to get current user: %v", err)
}
if user == nil {
// Set the cookie header to expire access token.
if err := s.clearAccessTokenCookie(ctx); err != nil {
return nil, status.Errorf(codes.Internal, "failed to set grpc header")
}
return nil, status.Errorf(codes.Unauthenticated, "user not found")
}
return &apiv2pb.GetAuthStatusResponse{
User: convertUserFromStore(user),
}, nil
}
func (s *APIV2Service) SignIn(ctx context.Context, request *apiv2pb.SignInRequest) (*apiv2pb.SignInResponse, error) {
user, err := s.Store.GetUser(ctx, &store.FindUser{
Username: &request.Username,
})
if err != nil {
return nil, status.Errorf(codes.Internal, fmt.Sprintf("failed to find user by username %s", request.Username))
}
if user == nil {
return nil, status.Errorf(codes.InvalidArgument, fmt.Sprintf("user not found with username %s", request.Username))
} else if user.RowStatus == store.Archived {
return nil, status.Errorf(codes.PermissionDenied, fmt.Sprintf("user has been archived with username %s", request.Username))
}
// Compare the stored hashed password, with the hashed version of the password that was received.
if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(request.Password)); err != nil {
return nil, status.Errorf(codes.InvalidArgument, "unmatched email and password")
}
expireTime := time.Now().Add(auth.AccessTokenDuration)
if request.NeverExpire {
// Set the expire time to 100 years.
expireTime = time.Now().Add(100 * 365 * 24 * time.Hour)
}
if err := s.doSignIn(ctx, user, expireTime); err != nil {
return nil, status.Errorf(codes.Internal, fmt.Sprintf("failed to sign in, err: %s", err))
}
return &apiv2pb.SignInResponse{
User: convertUserFromStore(user),
}, nil
}
func (s *APIV2Service) SignInWithSSO(ctx context.Context, request *apiv2pb.SignInWithSSORequest) (*apiv2pb.SignInWithSSOResponse, error) {
identityProvider, err := s.Store.GetIdentityProvider(ctx, &store.FindIdentityProvider{
ID: &request.IdpId,
})
if err != nil {
return nil, status.Errorf(codes.Internal, fmt.Sprintf("failed to get identity provider, err: %s", err))
}
if identityProvider == nil {
return nil, status.Errorf(codes.InvalidArgument, fmt.Sprintf("identity provider not found with id %d", request.IdpId))
}
var userInfo *idp.IdentityProviderUserInfo
if identityProvider.Type == store.IdentityProviderOAuth2Type {
oauth2IdentityProvider, err := oauth2.NewIdentityProvider(identityProvider.Config.OAuth2Config)
if err != nil {
return nil, status.Errorf(codes.Internal, fmt.Sprintf("failed to create oauth2 identity provider, err: %s", err))
}
token, err := oauth2IdentityProvider.ExchangeToken(ctx, request.RedirectUri, request.Code)
if err != nil {
return nil, status.Errorf(codes.Internal, fmt.Sprintf("failed to exchange token, err: %s", err))
}
userInfo, err = oauth2IdentityProvider.UserInfo(token)
if err != nil {
return nil, status.Errorf(codes.Internal, fmt.Sprintf("failed to get user info, err: %s", err))
}
}
identifierFilter := identityProvider.IdentifierFilter
if identifierFilter != "" {
identifierFilterRegex, err := regexp.Compile(identifierFilter)
if err != nil {
return nil, status.Errorf(codes.Internal, fmt.Sprintf("failed to compile identifier filter regex, err: %s", err))
}
if !identifierFilterRegex.MatchString(userInfo.Identifier) {
return nil, status.Errorf(codes.PermissionDenied, fmt.Sprintf("identifier %s is not allowed", userInfo.Identifier))
}
}
user, err := s.Store.GetUser(ctx, &store.FindUser{
Username: &userInfo.Identifier,
})
if err != nil {
return nil, status.Errorf(codes.Internal, fmt.Sprintf("failed to find user by username %s", userInfo.Identifier))
}
if user == nil {
userCreate := &store.User{
Username: userInfo.Identifier,
// The new signup user should be normal user by default.
Role: store.RoleUser,
Nickname: userInfo.DisplayName,
Email: userInfo.Email,
}
password, err := util.RandomString(20)
if err != nil {
return nil, status.Errorf(codes.Internal, fmt.Sprintf("failed to generate random password, err: %s", err))
}
passwordHash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return nil, status.Errorf(codes.Internal, fmt.Sprintf("failed to generate password hash, err: %s", err))
}
userCreate.PasswordHash = string(passwordHash)
user, err = s.Store.CreateUser(ctx, userCreate)
if err != nil {
return nil, status.Errorf(codes.Internal, fmt.Sprintf("failed to create user, err: %s", err))
}
}
if user.RowStatus == store.Archived {
return nil, status.Errorf(codes.PermissionDenied, fmt.Sprintf("user has been archived with username %s", userInfo.Identifier))
}
if err := s.doSignIn(ctx, user, time.Now().Add(auth.AccessTokenDuration)); err != nil {
return nil, status.Errorf(codes.Internal, fmt.Sprintf("failed to sign in, err: %s", err))
}
return &apiv2pb.SignInWithSSOResponse{
User: convertUserFromStore(user),
}, nil
}
func (s *APIV2Service) doSignIn(ctx context.Context, user *store.User, expireTime time.Time) error {
accessToken, err := auth.GenerateAccessToken(user.Email, user.ID, expireTime, []byte(s.Secret))
if err != nil {
return status.Errorf(codes.Internal, fmt.Sprintf("failed to generate tokens, err: %s", err))
}
if err := s.UpsertAccessTokenToStore(ctx, user, accessToken, "user login"); err != nil {
return status.Errorf(codes.Internal, fmt.Sprintf("failed to upsert access token to store, err: %s", err))
}
cookie, err := s.buildAccessTokenCookie(ctx, accessToken, expireTime)
if err != nil {
return status.Errorf(codes.Internal, fmt.Sprintf("failed to build access token cookie, err: %s", err))
}
if err := grpc.SetHeader(ctx, metadata.New(map[string]string{
"Set-Cookie": cookie,
})); err != nil {
return status.Errorf(codes.Internal, "failed to set grpc header, error: %v", err)
}
return nil
}
func (s *APIV2Service) SignUp(ctx context.Context, request *apiv2pb.SignUpRequest) (*apiv2pb.SignUpResponse, error) {
workspaceGeneralSetting, err := s.Store.GetWorkspaceGeneralSetting(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, fmt.Sprintf("failed to get workspace setting, err: %s", err))
}
if workspaceGeneralSetting.DisallowSignup || workspaceGeneralSetting.DisallowPasswordLogin {
return nil, status.Errorf(codes.PermissionDenied, "sign up is not allowed")
}
passwordHash, err := bcrypt.GenerateFromPassword([]byte(request.Password), bcrypt.DefaultCost)
if err != nil {
return nil, status.Errorf(codes.Internal, fmt.Sprintf("failed to generate password hash, err: %s", err))
}
create := &store.User{
Username: request.Username,
Nickname: request.Username,
PasswordHash: string(passwordHash),
}
hostUserType := store.RoleHost
existedHostUsers, err := s.Store.ListUsers(ctx, &store.FindUser{
Role: &hostUserType,
})
if err != nil {
return nil, status.Errorf(codes.Internal, fmt.Sprintf("failed to list users, err: %s", err))
}
if len(existedHostUsers) == 0 {
// Change the default role to host if there is no host user.
create.Role = store.RoleHost
} else {
create.Role = store.RoleUser
}
user, err := s.Store.CreateUser(ctx, create)
if err != nil {
return nil, status.Errorf(codes.Internal, fmt.Sprintf("failed to create user, err: %s", err))
}
if err := s.doSignIn(ctx, user, time.Now().Add(auth.AccessTokenDuration)); err != nil {
return nil, status.Errorf(codes.Internal, fmt.Sprintf("failed to sign in, err: %s", err))
}
return &apiv2pb.SignUpResponse{
User: convertUserFromStore(user),
}, nil
}
func (s *APIV2Service) SignOut(ctx context.Context, _ *apiv2pb.SignOutRequest) (*apiv2pb.SignOutResponse, error) {
if err := s.clearAccessTokenCookie(ctx); err != nil {
return nil, status.Errorf(codes.Internal, "failed to set grpc header, error: %v", err)
}
return &apiv2pb.SignOutResponse{}, nil
}
func (s *APIV2Service) clearAccessTokenCookie(ctx context.Context) error {
cookie, err := s.buildAccessTokenCookie(ctx, "", time.Time{})
if err != nil {
return errors.Wrap(err, "failed to build access token cookie")
}
if err := grpc.SetHeader(ctx, metadata.New(map[string]string{
"Set-Cookie": cookie,
})); err != nil {
return errors.Wrap(err, "failed to set grpc header")
}
return nil
}
func (s *APIV2Service) buildAccessTokenCookie(ctx context.Context, accessToken string, expireTime time.Time) (string, error) {
attrs := []string{
fmt.Sprintf("%s=%s", auth.AccessTokenCookieName, accessToken),
"Path=/",
"HttpOnly",
}
if expireTime.IsZero() {
attrs = append(attrs, "Expires=Thu, 01 Jan 1970 00:00:00 GMT")
} else {
attrs = append(attrs, "Expires="+expireTime.Format(time.RFC1123))
}
workspaceGeneralSetting, err := s.Store.GetWorkspaceGeneralSetting(ctx)
if err != nil {
return "", errors.Wrap(err, "failed to get workspace setting")
}
if strings.HasPrefix(workspaceGeneralSetting.InstanceUrl, "https://") {
attrs = append(attrs, "SameSite=None")
attrs = append(attrs, "Secure")
} else {
attrs = append(attrs, "SameSite=Strict")
}
return strings.Join(attrs, "; "), nil
}

View File

@@ -0,0 +1,74 @@
package v2
import (
"context"
"encoding/base64"
"github.com/pkg/errors"
"google.golang.org/protobuf/proto"
apiv2pb "github.com/usememos/memos/proto/gen/api/v2"
"github.com/usememos/memos/store"
)
func convertRowStatusFromStore(rowStatus store.RowStatus) apiv2pb.RowStatus {
switch rowStatus {
case store.Normal:
return apiv2pb.RowStatus_ACTIVE
case store.Archived:
return apiv2pb.RowStatus_ARCHIVED
default:
return apiv2pb.RowStatus_ROW_STATUS_UNSPECIFIED
}
}
func convertRowStatusToStore(rowStatus apiv2pb.RowStatus) store.RowStatus {
switch rowStatus {
case apiv2pb.RowStatus_ACTIVE:
return store.Normal
case apiv2pb.RowStatus_ARCHIVED:
return store.Archived
default:
return store.Normal
}
}
func getCurrentUser(ctx context.Context, s *store.Store) (*store.User, error) {
username, ok := ctx.Value(usernameContextKey).(string)
if !ok {
return nil, nil
}
user, err := s.GetUser(ctx, &store.FindUser{
Username: &username,
})
if err != nil {
return nil, err
}
return user, nil
}
func getPageToken(limit int, offset int) (string, error) {
return marshalPageToken(&apiv2pb.PageToken{
Limit: int32(limit),
Offset: int32(offset),
})
}
func marshalPageToken(pageToken *apiv2pb.PageToken) (string, error) {
b, err := proto.Marshal(pageToken)
if err != nil {
return "", errors.Wrapf(err, "failed to marshal page token")
}
return base64.StdEncoding.EncodeToString(b), nil
}
func unmarshalPageToken(s string, pageToken *apiv2pb.PageToken) error {
b, err := base64.StdEncoding.DecodeString(s)
if err != nil {
return errors.Wrapf(err, "failed to decode page token")
}
if err := proto.Unmarshal(b, pageToken); err != nil {
return errors.Wrapf(err, "failed to unmarshal page token")
}
return nil
}

View File

@@ -0,0 +1,138 @@
package v2
import (
"context"
"fmt"
"time"
"github.com/pkg/errors"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/timestamppb"
apiv2pb "github.com/usememos/memos/proto/gen/api/v2"
"github.com/usememos/memos/store"
)
func (s *APIV2Service) ListInboxes(ctx context.Context, _ *apiv2pb.ListInboxesRequest) (*apiv2pb.ListInboxesResponse, error) {
user, err := getCurrentUser(ctx, s.Store)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user")
}
inboxes, err := s.Store.ListInboxes(ctx, &store.FindInbox{
ReceiverID: &user.ID,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list inbox: %v", err)
}
response := &apiv2pb.ListInboxesResponse{
Inboxes: []*apiv2pb.Inbox{},
}
for _, inbox := range inboxes {
inboxMessage, err := s.convertInboxFromStore(ctx, inbox)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to convert inbox from store: %v", err)
}
response.Inboxes = append(response.Inboxes, inboxMessage)
}
return response, nil
}
func (s *APIV2Service) UpdateInbox(ctx context.Context, request *apiv2pb.UpdateInboxRequest) (*apiv2pb.UpdateInboxResponse, error) {
if request.UpdateMask == nil || len(request.UpdateMask.Paths) == 0 {
return nil, status.Errorf(codes.InvalidArgument, "update mask is required")
}
inboxID, err := ExtractInboxIDFromName(request.Inbox.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid inbox name: %v", err)
}
update := &store.UpdateInbox{
ID: inboxID,
}
for _, field := range request.UpdateMask.Paths {
if field == "status" {
if request.Inbox.Status == apiv2pb.Inbox_STATUS_UNSPECIFIED {
return nil, status.Errorf(codes.InvalidArgument, "status is required")
}
update.Status = convertInboxStatusToStore(request.Inbox.Status)
}
}
inbox, err := s.Store.UpdateInbox(ctx, update)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to update inbox: %v", err)
}
inboxMessage, err := s.convertInboxFromStore(ctx, inbox)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to convert inbox from store: %v", err)
}
return &apiv2pb.UpdateInboxResponse{
Inbox: inboxMessage,
}, nil
}
func (s *APIV2Service) DeleteInbox(ctx context.Context, request *apiv2pb.DeleteInboxRequest) (*apiv2pb.DeleteInboxResponse, error) {
inboxID, err := ExtractInboxIDFromName(request.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid inbox name: %v", err)
}
if err := s.Store.DeleteInbox(ctx, &store.DeleteInbox{
ID: inboxID,
}); err != nil {
return nil, status.Errorf(codes.Internal, "failed to update inbox: %v", err)
}
return &apiv2pb.DeleteInboxResponse{}, nil
}
func (s *APIV2Service) convertInboxFromStore(ctx context.Context, inbox *store.Inbox) (*apiv2pb.Inbox, error) {
sender, err := s.Store.GetUser(ctx, &store.FindUser{
ID: &inbox.SenderID,
})
if err != nil {
return nil, errors.Wrap(err, "failed to get sender")
}
receiver, err := s.Store.GetUser(ctx, &store.FindUser{
ID: &inbox.ReceiverID,
})
if err != nil {
return nil, errors.Wrap(err, "failed to get receiver")
}
return &apiv2pb.Inbox{
Name: fmt.Sprintf("inboxes/%d", inbox.ID),
Sender: fmt.Sprintf("users/%s", sender.Username),
Receiver: fmt.Sprintf("users/%s", receiver.Username),
Status: convertInboxStatusFromStore(inbox.Status),
CreateTime: timestamppb.New(time.Unix(inbox.CreatedTs, 0)),
Type: apiv2pb.Inbox_Type(inbox.Message.Type),
ActivityId: inbox.Message.ActivityId,
}, nil
}
func convertInboxStatusFromStore(status store.InboxStatus) apiv2pb.Inbox_Status {
switch status {
case store.UNREAD:
return apiv2pb.Inbox_UNREAD
case store.ARCHIVED:
return apiv2pb.Inbox_ARCHIVED
default:
return apiv2pb.Inbox_STATUS_UNSPECIFIED
}
}
func convertInboxStatusToStore(status apiv2pb.Inbox_Status) store.InboxStatus {
switch status {
case apiv2pb.Inbox_UNREAD:
return store.UNREAD
case apiv2pb.Inbox_ARCHIVED:
return store.ARCHIVED
default:
return store.UNREAD
}
}

View File

@@ -0,0 +1,100 @@
package v2
import (
"context"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
apiv2pb "github.com/usememos/memos/proto/gen/api/v2"
"github.com/usememos/memos/store"
)
func (s *APIV2Service) SetMemoRelations(ctx context.Context, request *apiv2pb.SetMemoRelationsRequest) (*apiv2pb.SetMemoRelationsResponse, error) {
referenceType := store.MemoRelationReference
// Delete all reference relations first.
if err := s.Store.DeleteMemoRelation(ctx, &store.DeleteMemoRelation{
MemoID: &request.Id,
Type: &referenceType,
}); err != nil {
return nil, status.Errorf(codes.Internal, "failed to delete memo relation")
}
for _, relation := range request.Relations {
// Ignore reflexive relations.
if request.Id == relation.RelatedMemoId {
continue
}
// Ignore comment relations as there's no need to update a comment's relation.
// Inserting/Deleting a comment is handled elsewhere.
if relation.Type == apiv2pb.MemoRelation_COMMENT {
continue
}
if _, err := s.Store.UpsertMemoRelation(ctx, &store.MemoRelation{
MemoID: request.Id,
RelatedMemoID: relation.RelatedMemoId,
Type: convertMemoRelationTypeToStore(relation.Type),
}); err != nil {
return nil, status.Errorf(codes.Internal, "failed to upsert memo relation")
}
}
return &apiv2pb.SetMemoRelationsResponse{}, nil
}
func (s *APIV2Service) ListMemoRelations(ctx context.Context, request *apiv2pb.ListMemoRelationsRequest) (*apiv2pb.ListMemoRelationsResponse, error) {
relationList := []*apiv2pb.MemoRelation{}
tempList, err := s.Store.ListMemoRelations(ctx, &store.FindMemoRelation{
MemoID: &request.Id,
})
if err != nil {
return nil, err
}
for _, relation := range tempList {
relationList = append(relationList, convertMemoRelationFromStore(relation))
}
tempList, err = s.Store.ListMemoRelations(ctx, &store.FindMemoRelation{
RelatedMemoID: &request.Id,
})
if err != nil {
return nil, err
}
for _, relation := range tempList {
relationList = append(relationList, convertMemoRelationFromStore(relation))
}
response := &apiv2pb.ListMemoRelationsResponse{
Relations: relationList,
}
return response, nil
}
func convertMemoRelationFromStore(memoRelation *store.MemoRelation) *apiv2pb.MemoRelation {
return &apiv2pb.MemoRelation{
MemoId: memoRelation.MemoID,
RelatedMemoId: memoRelation.RelatedMemoID,
Type: convertMemoRelationTypeFromStore(memoRelation.Type),
}
}
func convertMemoRelationTypeFromStore(relationType store.MemoRelationType) apiv2pb.MemoRelation_Type {
switch relationType {
case store.MemoRelationReference:
return apiv2pb.MemoRelation_REFERENCE
case store.MemoRelationComment:
return apiv2pb.MemoRelation_COMMENT
default:
return apiv2pb.MemoRelation_TYPE_UNSPECIFIED
}
}
func convertMemoRelationTypeToStore(relationType apiv2pb.MemoRelation_Type) store.MemoRelationType {
switch relationType {
case apiv2pb.MemoRelation_REFERENCE:
return store.MemoRelationReference
case apiv2pb.MemoRelation_COMMENT:
return store.MemoRelationComment
default:
return store.MemoRelationReference
}
}

View File

@@ -0,0 +1,73 @@
package v2
import (
"context"
"slices"
"time"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
apiv2pb "github.com/usememos/memos/proto/gen/api/v2"
"github.com/usememos/memos/store"
)
func (s *APIV2Service) SetMemoResources(ctx context.Context, request *apiv2pb.SetMemoResourcesRequest) (*apiv2pb.SetMemoResourcesResponse, error) {
resources, err := s.Store.ListResources(ctx, &store.FindResource{
MemoID: &request.Id,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list resources")
}
// Delete resources that are not in the request.
for _, resource := range resources {
found := false
for _, requestResource := range request.Resources {
if resource.ID == int32(requestResource.Id) {
found = true
break
}
}
if !found {
if err = s.Store.DeleteResource(ctx, &store.DeleteResource{
ID: int32(resource.ID),
MemoID: &request.Id,
}); err != nil {
return nil, status.Errorf(codes.Internal, "failed to delete resource")
}
}
}
slices.Reverse(request.Resources)
// Update resources' memo_id in the request.
for index, resource := range request.Resources {
updatedTs := time.Now().Unix() + int64(index)
if _, err := s.Store.UpdateResource(ctx, &store.UpdateResource{
ID: resource.Id,
MemoID: &request.Id,
UpdatedTs: &updatedTs,
}); err != nil {
return nil, status.Errorf(codes.Internal, "failed to update resource: %v", err)
}
}
return &apiv2pb.SetMemoResourcesResponse{}, nil
}
func (s *APIV2Service) ListMemoResources(ctx context.Context, request *apiv2pb.ListMemoResourcesRequest) (*apiv2pb.ListMemoResourcesResponse, error) {
resources, err := s.Store.ListResources(ctx, &store.FindResource{
MemoID: &request.Id,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list resources")
}
response := &apiv2pb.ListMemoResourcesResponse{
Resources: []*apiv2pb.Resource{},
}
for _, resource := range resources {
response.Resources = append(response.Resources, s.convertResourceFromStore(ctx, resource))
}
return response, nil
}

View File

@@ -0,0 +1,856 @@
package v2
import (
"archive/zip"
"bytes"
"context"
"encoding/json"
"fmt"
"time"
"github.com/google/cel-go/cel"
"github.com/lithammer/shortuuid/v4"
"github.com/pkg/errors"
"go.uber.org/zap"
expr "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/usememos/memos/internal/log"
"github.com/usememos/memos/internal/util"
"github.com/usememos/memos/plugin/webhook"
apiv2pb "github.com/usememos/memos/proto/gen/api/v2"
storepb "github.com/usememos/memos/proto/gen/store"
apiv1 "github.com/usememos/memos/server/route/api/v1"
"github.com/usememos/memos/store"
)
const (
DefaultPageSize = 10
MaxContentLength = 8 * 1024
ChunkSize = 64 * 1024 // 64 KiB
)
func (s *APIV2Service) CreateMemo(ctx context.Context, request *apiv2pb.CreateMemoRequest) (*apiv2pb.CreateMemoResponse, error) {
user, err := getCurrentUser(ctx, s.Store)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user")
}
if user == nil {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
if len(request.Content) > MaxContentLength {
return nil, status.Errorf(codes.InvalidArgument, "content too long")
}
create := &store.Memo{
ResourceName: shortuuid.New(),
CreatorID: user.ID,
Content: request.Content,
Visibility: convertVisibilityToStore(request.Visibility),
}
// Find disable public memos system setting.
disablePublicMemosSystem, err := s.getDisablePublicMemosSystemSettingValue(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get system setting")
}
if disablePublicMemosSystem && create.Visibility == store.Public {
return nil, status.Errorf(codes.PermissionDenied, "disable public memos system setting is enabled")
}
memo, err := s.Store.CreateMemo(ctx, create)
if err != nil {
return nil, err
}
memoMessage, err := s.convertMemoFromStore(ctx, memo)
if err != nil {
return nil, errors.Wrap(err, "failed to convert memo")
}
// Try to dispatch webhook when memo is created.
if err := s.DispatchMemoCreatedWebhook(ctx, memoMessage); err != nil {
log.Warn("Failed to dispatch memo created webhook", zap.Error(err))
}
response := &apiv2pb.CreateMemoResponse{
Memo: memoMessage,
}
return response, nil
}
func (s *APIV2Service) ListMemos(ctx context.Context, request *apiv2pb.ListMemosRequest) (*apiv2pb.ListMemosResponse, error) {
memoFind := &store.FindMemo{
// Exclude comments by default.
ExcludeComments: true,
}
if err := s.buildMemoFindWithFilter(ctx, memoFind, request.Filter); err != nil {
return nil, status.Errorf(codes.InvalidArgument, "failed to build find memos with filter")
}
var limit, offset int
if request.PageToken != "" {
var pageToken apiv2pb.PageToken
if err := unmarshalPageToken(request.PageToken, &pageToken); err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid page token: %v", err)
}
limit = int(pageToken.Limit)
offset = int(pageToken.Offset)
} else {
limit = int(request.PageSize)
}
if limit <= 0 {
limit = DefaultPageSize
}
limitPlusOne := limit + 1
memoFind.Limit = &limitPlusOne
memoFind.Offset = &offset
memos, err := s.Store.ListMemos(ctx, memoFind)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list memos")
}
memoMessages := []*apiv2pb.Memo{}
nextPageToken := ""
if len(memos) == limitPlusOne {
memos = memos[:limit]
nextPageToken, err = getPageToken(limit, offset+limit)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get next page token, error: %v", err)
}
}
for _, memo := range memos {
memoMessage, err := s.convertMemoFromStore(ctx, memo)
if err != nil {
return nil, errors.Wrap(err, "failed to convert memo")
}
memoMessages = append(memoMessages, memoMessage)
}
response := &apiv2pb.ListMemosResponse{
Memos: memoMessages,
NextPageToken: nextPageToken,
}
return response, nil
}
func (s *APIV2Service) GetMemo(ctx context.Context, request *apiv2pb.GetMemoRequest) (*apiv2pb.GetMemoResponse, error) {
memo, err := s.Store.GetMemo(ctx, &store.FindMemo{
ID: &request.Id,
})
if err != nil {
return nil, err
}
if memo == nil {
return nil, status.Errorf(codes.NotFound, "memo not found")
}
if memo.Visibility != store.Public {
user, err := getCurrentUser(ctx, s.Store)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user")
}
if user == nil {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
if memo.Visibility == store.Private && memo.CreatorID != user.ID {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
}
memoMessage, err := s.convertMemoFromStore(ctx, memo)
if err != nil {
return nil, errors.Wrap(err, "failed to convert memo")
}
response := &apiv2pb.GetMemoResponse{
Memo: memoMessage,
}
return response, nil
}
func (s *APIV2Service) GetMemoByName(ctx context.Context, request *apiv2pb.GetMemoByNameRequest) (*apiv2pb.GetMemoByNameResponse, error) {
memo, err := s.Store.GetMemo(ctx, &store.FindMemo{
ResourceName: &request.Name,
})
if err != nil {
return nil, err
}
if memo == nil {
return nil, status.Errorf(codes.NotFound, "memo not found")
}
if memo.Visibility != store.Public {
user, err := getCurrentUser(ctx, s.Store)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user")
}
if user == nil {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
if memo.Visibility == store.Private && memo.CreatorID != user.ID {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
}
memoMessage, err := s.convertMemoFromStore(ctx, memo)
if err != nil {
return nil, errors.Wrap(err, "failed to convert memo")
}
response := &apiv2pb.GetMemoByNameResponse{
Memo: memoMessage,
}
return response, nil
}
func (s *APIV2Service) UpdateMemo(ctx context.Context, request *apiv2pb.UpdateMemoRequest) (*apiv2pb.UpdateMemoResponse, error) {
if request.UpdateMask == nil || len(request.UpdateMask.Paths) == 0 {
return nil, status.Errorf(codes.InvalidArgument, "update mask is required")
}
memo, err := s.Store.GetMemo(ctx, &store.FindMemo{
ID: &request.Memo.Id,
})
if err != nil {
return nil, err
}
if memo == nil {
return nil, status.Errorf(codes.NotFound, "memo not found")
}
user, _ := getCurrentUser(ctx, s.Store)
if memo.CreatorID != user.ID {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
currentTs := time.Now().Unix()
update := &store.UpdateMemo{
ID: request.Memo.Id,
UpdatedTs: &currentTs,
}
for _, path := range request.UpdateMask.Paths {
if path == "content" {
update.Content = &request.Memo.Content
} else if path == "resource_name" {
update.ResourceName = &request.Memo.Name
if !util.ResourceNameMatcher.MatchString(*update.ResourceName) {
return nil, status.Errorf(codes.InvalidArgument, "invalid resource name")
}
} else if path == "visibility" {
visibility := convertVisibilityToStore(request.Memo.Visibility)
// Find disable public memos system setting.
disablePublicMemosSystem, err := s.getDisablePublicMemosSystemSettingValue(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get system setting")
}
if disablePublicMemosSystem && visibility == store.Public {
return nil, status.Errorf(codes.PermissionDenied, "disable public memos system setting is enabled")
}
update.Visibility = &visibility
} else if path == "row_status" {
rowStatus := convertRowStatusToStore(request.Memo.RowStatus)
update.RowStatus = &rowStatus
} else if path == "created_ts" {
createdTs := request.Memo.CreateTime.AsTime().Unix()
update.CreatedTs = &createdTs
} else if path == "pinned" {
if _, err := s.Store.UpsertMemoOrganizer(ctx, &store.MemoOrganizer{
MemoID: request.Memo.Id,
UserID: user.ID,
Pinned: request.Memo.Pinned,
}); err != nil {
return nil, status.Errorf(codes.Internal, "failed to upsert memo organizer")
}
}
}
if update.Content != nil && len(*update.Content) > MaxContentLength {
return nil, status.Errorf(codes.InvalidArgument, "content too long")
}
if err = s.Store.UpdateMemo(ctx, update); err != nil {
return nil, status.Errorf(codes.Internal, "failed to update memo")
}
memo, err = s.Store.GetMemo(ctx, &store.FindMemo{
ID: &request.Memo.Id,
})
if err != nil {
return nil, errors.Wrap(err, "failed to get memo")
}
memoMessage, err := s.convertMemoFromStore(ctx, memo)
if err != nil {
return nil, errors.Wrap(err, "failed to convert memo")
}
// Try to dispatch webhook when memo is updated.
if err := s.DispatchMemoUpdatedWebhook(ctx, memoMessage); err != nil {
log.Warn("Failed to dispatch memo updated webhook", zap.Error(err))
}
return &apiv2pb.UpdateMemoResponse{
Memo: memoMessage,
}, nil
}
func (s *APIV2Service) DeleteMemo(ctx context.Context, request *apiv2pb.DeleteMemoRequest) (*apiv2pb.DeleteMemoResponse, error) {
memo, err := s.Store.GetMemo(ctx, &store.FindMemo{
ID: &request.Id,
})
if err != nil {
return nil, err
}
if memo == nil {
return nil, status.Errorf(codes.NotFound, "memo not found")
}
user, _ := getCurrentUser(ctx, s.Store)
if memo.CreatorID != user.ID {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
if memoMessage, err := s.convertMemoFromStore(ctx, memo); err == nil {
// Try to dispatch webhook when memo is deleted.
if err := s.DispatchMemoDeletedWebhook(ctx, memoMessage); err != nil {
log.Warn("Failed to dispatch memo deleted webhook", zap.Error(err))
}
}
if err = s.Store.DeleteMemo(ctx, &store.DeleteMemo{
ID: request.Id,
}); err != nil {
return nil, status.Errorf(codes.Internal, "failed to delete memo")
}
return &apiv2pb.DeleteMemoResponse{}, nil
}
func (s *APIV2Service) CreateMemoComment(ctx context.Context, request *apiv2pb.CreateMemoCommentRequest) (*apiv2pb.CreateMemoCommentResponse, error) {
relatedMemo, err := s.Store.GetMemo(ctx, &store.FindMemo{ID: &request.Id})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get memo")
}
// Create the comment memo first.
createMemoResponse, err := s.CreateMemo(ctx, request.Create)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to create memo")
}
// Build the relation between the comment memo and the original memo.
memo := createMemoResponse.Memo
_, err = s.Store.UpsertMemoRelation(ctx, &store.MemoRelation{
MemoID: memo.Id,
RelatedMemoID: request.Id,
Type: store.MemoRelationComment,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to create memo relation")
}
if memo.Visibility != apiv2pb.Visibility_PRIVATE && memo.CreatorId != relatedMemo.CreatorID {
activity, err := s.Store.CreateActivity(ctx, &store.Activity{
CreatorID: memo.CreatorId,
Type: store.ActivityTypeMemoComment,
Level: store.ActivityLevelInfo,
Payload: &storepb.ActivityPayload{
MemoComment: &storepb.ActivityMemoCommentPayload{
MemoId: memo.Id,
RelatedMemoId: request.Id,
},
},
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to create activity")
}
if _, err := s.Store.CreateInbox(ctx, &store.Inbox{
SenderID: memo.CreatorId,
ReceiverID: relatedMemo.CreatorID,
Status: store.UNREAD,
Message: &storepb.InboxMessage{
Type: storepb.InboxMessage_TYPE_MEMO_COMMENT,
ActivityId: &activity.ID,
},
}); err != nil {
return nil, status.Errorf(codes.Internal, "failed to create inbox")
}
}
response := &apiv2pb.CreateMemoCommentResponse{
Memo: memo,
}
return response, nil
}
func (s *APIV2Service) ListMemoComments(ctx context.Context, request *apiv2pb.ListMemoCommentsRequest) (*apiv2pb.ListMemoCommentsResponse, error) {
memoRelationComment := store.MemoRelationComment
memoRelations, err := s.Store.ListMemoRelations(ctx, &store.FindMemoRelation{
RelatedMemoID: &request.Id,
Type: &memoRelationComment,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list memo relations")
}
var memos []*apiv2pb.Memo
for _, memoRelation := range memoRelations {
memo, err := s.Store.GetMemo(ctx, &store.FindMemo{
ID: &memoRelation.MemoID,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get memo")
}
if memo != nil {
memoMessage, err := s.convertMemoFromStore(ctx, memo)
if err != nil {
return nil, errors.Wrap(err, "failed to convert memo")
}
memos = append(memos, memoMessage)
}
}
response := &apiv2pb.ListMemoCommentsResponse{
Memos: memos,
}
return response, nil
}
func (s *APIV2Service) GetUserMemosStats(ctx context.Context, request *apiv2pb.GetUserMemosStatsRequest) (*apiv2pb.GetUserMemosStatsResponse, error) {
username, err := ExtractUsernameFromName(request.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid username")
}
user, err := s.Store.GetUser(ctx, &store.FindUser{
Username: &username,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user")
}
if user == nil {
return nil, status.Errorf(codes.NotFound, "user not found")
}
normalRowStatus := store.Normal
memoFind := &store.FindMemo{
CreatorID: &user.ID,
RowStatus: &normalRowStatus,
ExcludeComments: true,
ExcludeContent: true,
}
if err := s.buildMemoFindWithFilter(ctx, memoFind, request.Filter); err != nil {
return nil, status.Errorf(codes.InvalidArgument, "failed to build find memos with filter")
}
memos, err := s.Store.ListMemos(ctx, memoFind)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list memos")
}
location, err := time.LoadLocation(request.Timezone)
if err != nil {
return nil, status.Errorf(codes.Internal, "invalid timezone location")
}
displayWithUpdatedTs, err := s.getMemoDisplayWithUpdatedTsSettingValue(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get memo display with updated ts setting value")
}
stats := make(map[string]int32)
for _, memo := range memos {
displayTs := memo.CreatedTs
if displayWithUpdatedTs {
displayTs = memo.UpdatedTs
}
stats[time.Unix(displayTs, 0).In(location).Format("2006-01-02")]++
}
response := &apiv2pb.GetUserMemosStatsResponse{
Stats: stats,
}
return response, nil
}
func (s *APIV2Service) ExportMemos(ctx context.Context, request *apiv2pb.ExportMemosRequest) (*apiv2pb.ExportMemosResponse, error) {
normalRowStatus := store.Normal
memoFind := &store.FindMemo{
RowStatus: &normalRowStatus,
// Exclude comments by default.
ExcludeComments: true,
}
if err := s.buildMemoFindWithFilter(ctx, memoFind, request.Filter); err != nil {
return nil, status.Errorf(codes.Internal, "failed to build find memos with filter")
}
memos, err := s.Store.ListMemos(ctx, memoFind)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list memos")
}
buf := new(bytes.Buffer)
writer := zip.NewWriter(buf)
for _, memo := range memos {
memoMessage, err := s.convertMemoFromStore(ctx, memo)
if err != nil {
return nil, errors.Wrap(err, "failed to convert memo")
}
file, err := writer.Create(time.Unix(memo.CreatedTs, 0).Format(time.RFC3339) + ".md")
if err != nil {
return nil, status.Errorf(codes.Internal, "Failed to create memo file")
}
_, err = file.Write([]byte(memoMessage.Content))
if err != nil {
return nil, status.Errorf(codes.Internal, "Failed to write to memo file")
}
}
if err := writer.Close(); err != nil {
return nil, status.Errorf(codes.Internal, "Failed to close zip file writer")
}
return &apiv2pb.ExportMemosResponse{
Content: buf.Bytes(),
}, nil
}
func (s *APIV2Service) convertMemoFromStore(ctx context.Context, memo *store.Memo) (*apiv2pb.Memo, error) {
displayTs := memo.CreatedTs
if displayWithUpdatedTs, err := s.getMemoDisplayWithUpdatedTsSettingValue(ctx); err == nil && displayWithUpdatedTs {
displayTs = memo.UpdatedTs
}
creator, err := s.Store.GetUser(ctx, &store.FindUser{ID: &memo.CreatorID})
if err != nil {
return nil, errors.Wrap(err, "failed to get creator")
}
listMemoRelationsResponse, err := s.ListMemoRelations(ctx, &apiv2pb.ListMemoRelationsRequest{Id: memo.ID})
if err != nil {
return nil, errors.Wrap(err, "failed to list memo relations")
}
listMemoResourcesResponse, err := s.ListMemoResources(ctx, &apiv2pb.ListMemoResourcesRequest{Id: memo.ID})
if err != nil {
return nil, errors.Wrap(err, "failed to list memo resources")
}
listMemoReactionsResponse, err := s.ListMemoReactions(ctx, &apiv2pb.ListMemoReactionsRequest{Id: memo.ID})
if err != nil {
return nil, errors.Wrap(err, "failed to list memo reactions")
}
return &apiv2pb.Memo{
Id: int32(memo.ID),
Name: memo.ResourceName,
RowStatus: convertRowStatusFromStore(memo.RowStatus),
Creator: fmt.Sprintf("%s%s", UserNamePrefix, creator.Username),
CreatorId: int32(memo.CreatorID),
CreateTime: timestamppb.New(time.Unix(memo.CreatedTs, 0)),
UpdateTime: timestamppb.New(time.Unix(memo.UpdatedTs, 0)),
DisplayTime: timestamppb.New(time.Unix(displayTs, 0)),
Content: memo.Content,
Visibility: convertVisibilityFromStore(memo.Visibility),
Pinned: memo.Pinned,
ParentId: memo.ParentID,
Relations: listMemoRelationsResponse.Relations,
Resources: listMemoResourcesResponse.Resources,
Reactions: listMemoReactionsResponse.Reactions,
}, nil
}
func (s *APIV2Service) getMemoDisplayWithUpdatedTsSettingValue(ctx context.Context) (bool, error) {
memoDisplayWithUpdatedTsSetting, err := s.Store.GetWorkspaceSetting(ctx, &store.FindWorkspaceSetting{
Name: apiv1.SystemSettingMemoDisplayWithUpdatedTsName.String(),
})
if err != nil {
return false, errors.Wrap(err, "failed to find system setting")
}
if memoDisplayWithUpdatedTsSetting == nil {
return false, nil
}
memoDisplayWithUpdatedTs := false
if err := json.Unmarshal([]byte(memoDisplayWithUpdatedTsSetting.Value), &memoDisplayWithUpdatedTs); err != nil {
return false, errors.Wrap(err, "failed to unmarshal system setting value")
}
return memoDisplayWithUpdatedTs, nil
}
func (s *APIV2Service) getDisablePublicMemosSystemSettingValue(ctx context.Context) (bool, error) {
disablePublicMemosSystemSetting, err := s.Store.GetWorkspaceSetting(ctx, &store.FindWorkspaceSetting{
Name: apiv1.SystemSettingDisablePublicMemosName.String(),
})
if err != nil {
return false, errors.Wrap(err, "failed to find system setting")
}
if disablePublicMemosSystemSetting == nil {
return false, nil
}
disablePublicMemos := false
if err := json.Unmarshal([]byte(disablePublicMemosSystemSetting.Value), &disablePublicMemos); err != nil {
return false, errors.Wrap(err, "failed to unmarshal system setting value")
}
return disablePublicMemos, nil
}
func convertVisibilityFromStore(visibility store.Visibility) apiv2pb.Visibility {
switch visibility {
case store.Private:
return apiv2pb.Visibility_PRIVATE
case store.Protected:
return apiv2pb.Visibility_PROTECTED
case store.Public:
return apiv2pb.Visibility_PUBLIC
default:
return apiv2pb.Visibility_VISIBILITY_UNSPECIFIED
}
}
func convertVisibilityToStore(visibility apiv2pb.Visibility) store.Visibility {
switch visibility {
case apiv2pb.Visibility_PRIVATE:
return store.Private
case apiv2pb.Visibility_PROTECTED:
return store.Protected
case apiv2pb.Visibility_PUBLIC:
return store.Public
default:
return store.Private
}
}
func (s *APIV2Service) buildMemoFindWithFilter(ctx context.Context, find *store.FindMemo, filter string) error {
user, _ := getCurrentUser(ctx, s.Store)
if find == nil {
find = &store.FindMemo{}
}
if filter != "" {
filter, err := parseListMemosFilter(filter)
if err != nil {
return status.Errorf(codes.InvalidArgument, "invalid filter: %v", err)
}
if len(filter.ContentSearch) > 0 {
find.ContentSearch = filter.ContentSearch
}
if len(filter.Visibilities) > 0 {
find.VisibilityList = filter.Visibilities
}
if filter.OrderByPinned {
find.OrderByPinned = filter.OrderByPinned
}
if filter.DisplayTimeAfter != nil {
displayWithUpdatedTs, err := s.getMemoDisplayWithUpdatedTsSettingValue(ctx)
if err != nil {
return status.Errorf(codes.Internal, "failed to get memo display with updated ts setting value")
}
if displayWithUpdatedTs {
find.UpdatedTsAfter = filter.DisplayTimeAfter
} else {
find.CreatedTsAfter = filter.DisplayTimeAfter
}
}
if filter.DisplayTimeBefore != nil {
displayWithUpdatedTs, err := s.getMemoDisplayWithUpdatedTsSettingValue(ctx)
if err != nil {
return status.Errorf(codes.Internal, "failed to get memo display with updated ts setting value")
}
if displayWithUpdatedTs {
find.UpdatedTsBefore = filter.DisplayTimeBefore
} else {
find.CreatedTsBefore = filter.DisplayTimeBefore
}
}
if filter.Creator != nil {
username, err := ExtractUsernameFromName(*filter.Creator)
if err != nil {
return status.Errorf(codes.InvalidArgument, "invalid creator name")
}
user, err := s.Store.GetUser(ctx, &store.FindUser{
Username: &username,
})
if err != nil {
return status.Errorf(codes.Internal, "failed to get user")
}
if user == nil {
return status.Errorf(codes.NotFound, "user not found")
}
find.CreatorID = &user.ID
}
if filter.RowStatus != nil {
find.RowStatus = filter.RowStatus
}
}
// If the user is not authenticated, only public memos are visible.
if user == nil {
if filter == "" {
// If no filter is provided, return an error.
return status.Errorf(codes.InvalidArgument, "filter is required")
}
find.VisibilityList = []store.Visibility{store.Public}
} else if find.CreatorID != nil && *find.CreatorID != user.ID {
find.VisibilityList = []store.Visibility{store.Public, store.Protected}
}
displayWithUpdatedTs, err := s.getMemoDisplayWithUpdatedTsSettingValue(ctx)
if err != nil {
return status.Errorf(codes.Internal, "failed to get memo display with updated ts setting value")
}
if displayWithUpdatedTs {
find.OrderByUpdatedTs = true
}
return nil
}
// ListMemosFilterCELAttributes are the CEL attributes for ListMemosFilter.
var ListMemosFilterCELAttributes = []cel.EnvOption{
cel.Variable("content_search", cel.ListType(cel.StringType)),
cel.Variable("visibilities", cel.ListType(cel.StringType)),
cel.Variable("order_by_pinned", cel.BoolType),
cel.Variable("display_time_before", cel.IntType),
cel.Variable("display_time_after", cel.IntType),
cel.Variable("creator", cel.StringType),
cel.Variable("row_status", cel.StringType),
}
type ListMemosFilter struct {
ContentSearch []string
Visibilities []store.Visibility
OrderByPinned bool
DisplayTimeBefore *int64
DisplayTimeAfter *int64
Creator *string
RowStatus *store.RowStatus
}
func parseListMemosFilter(expression string) (*ListMemosFilter, error) {
e, err := cel.NewEnv(ListMemosFilterCELAttributes...)
if err != nil {
return nil, err
}
ast, issues := e.Compile(expression)
if issues != nil {
return nil, errors.Errorf("found issue %v", issues)
}
filter := &ListMemosFilter{}
expr, err := cel.AstToParsedExpr(ast)
if err != nil {
return nil, err
}
callExpr := expr.GetExpr().GetCallExpr()
findField(callExpr, filter)
return filter, nil
}
func findField(callExpr *expr.Expr_Call, filter *ListMemosFilter) {
if len(callExpr.Args) == 2 {
idExpr := callExpr.Args[0].GetIdentExpr()
if idExpr != nil {
if idExpr.Name == "content_search" {
contentSearch := []string{}
for _, expr := range callExpr.Args[1].GetListExpr().GetElements() {
value := expr.GetConstExpr().GetStringValue()
contentSearch = append(contentSearch, value)
}
filter.ContentSearch = contentSearch
} else if idExpr.Name == "visibilities" {
visibilities := []store.Visibility{}
for _, expr := range callExpr.Args[1].GetListExpr().GetElements() {
value := expr.GetConstExpr().GetStringValue()
visibilities = append(visibilities, store.Visibility(value))
}
filter.Visibilities = visibilities
} else if idExpr.Name == "order_by_pinned" {
value := callExpr.Args[1].GetConstExpr().GetBoolValue()
filter.OrderByPinned = value
} else if idExpr.Name == "display_time_before" {
displayTimeBefore := callExpr.Args[1].GetConstExpr().GetInt64Value()
filter.DisplayTimeBefore = &displayTimeBefore
} else if idExpr.Name == "display_time_after" {
displayTimeAfter := callExpr.Args[1].GetConstExpr().GetInt64Value()
filter.DisplayTimeAfter = &displayTimeAfter
} else if idExpr.Name == "creator" {
creator := callExpr.Args[1].GetConstExpr().GetStringValue()
filter.Creator = &creator
} else if idExpr.Name == "row_status" {
rowStatus := store.RowStatus(callExpr.Args[1].GetConstExpr().GetStringValue())
filter.RowStatus = &rowStatus
}
return
}
}
for _, arg := range callExpr.Args {
callExpr := arg.GetCallExpr()
if callExpr != nil {
findField(callExpr, filter)
}
}
}
// DispatchMemoCreatedWebhook dispatches webhook when memo is created.
func (s *APIV2Service) DispatchMemoCreatedWebhook(ctx context.Context, memo *apiv2pb.Memo) error {
return s.dispatchMemoRelatedWebhook(ctx, memo, "memos.memo.created")
}
// DispatchMemoUpdatedWebhook dispatches webhook when memo is updated.
func (s *APIV2Service) DispatchMemoUpdatedWebhook(ctx context.Context, memo *apiv2pb.Memo) error {
return s.dispatchMemoRelatedWebhook(ctx, memo, "memos.memo.updated")
}
// DispatchMemoDeletedWebhook dispatches webhook when memo is deleted.
func (s *APIV2Service) DispatchMemoDeletedWebhook(ctx context.Context, memo *apiv2pb.Memo) error {
return s.dispatchMemoRelatedWebhook(ctx, memo, "memos.memo.deleted")
}
func (s *APIV2Service) dispatchMemoRelatedWebhook(ctx context.Context, memo *apiv2pb.Memo, activityType string) error {
webhooks, err := s.Store.ListWebhooks(ctx, &store.FindWebhook{
CreatorID: &memo.CreatorId,
})
if err != nil {
return err
}
for _, hook := range webhooks {
payload := convertMemoToWebhookPayload(memo)
payload.ActivityType = activityType
payload.URL = hook.Url
err := webhook.Post(*payload)
if err != nil {
return errors.Wrap(err, "failed to post webhook")
}
}
return nil
}
func convertMemoToWebhookPayload(memo *apiv2pb.Memo) *webhook.WebhookPayload {
return &webhook.WebhookPayload{
CreatorID: memo.CreatorId,
CreatedTs: time.Now().Unix(),
Memo: &webhook.Memo{
ID: memo.Id,
CreatorID: memo.CreatorId,
CreatedTs: memo.CreateTime.Seconds,
UpdatedTs: memo.UpdateTime.Seconds,
Content: memo.Content,
Visibility: memo.Visibility.String(),
Pinned: memo.Pinned,
ResourceList: func() []*webhook.Resource {
resources := []*webhook.Resource{}
for _, resource := range memo.Resources {
resources = append(resources, &webhook.Resource{
ID: resource.Id,
Filename: resource.Filename,
ExternalLink: resource.ExternalLink,
Type: resource.Type,
Size: resource.Size,
})
}
return resources
}(),
RelationList: func() []*webhook.MemoRelation {
relations := []*webhook.MemoRelation{}
for _, relation := range memo.Relations {
relations = append(relations, &webhook.MemoRelation{
MemoID: relation.MemoId,
RelatedMemoID: relation.RelatedMemoId,
Type: relation.Type.String(),
})
}
return relations
}(),
},
}
}

View File

@@ -0,0 +1,83 @@
package v2
import (
"context"
"fmt"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
apiv2pb "github.com/usememos/memos/proto/gen/api/v2"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/store"
)
func (s *APIV2Service) ListMemoReactions(ctx context.Context, request *apiv2pb.ListMemoReactionsRequest) (*apiv2pb.ListMemoReactionsResponse, error) {
contentID := fmt.Sprintf("memos/%d", request.Id)
reactions, err := s.Store.ListReactions(ctx, &store.FindReaction{
ContentID: &contentID,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list reactions")
}
response := &apiv2pb.ListMemoReactionsResponse{
Reactions: []*apiv2pb.Reaction{},
}
for _, reaction := range reactions {
reactionMessage, err := s.convertReactionFromStore(ctx, reaction)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to convert reaction")
}
response.Reactions = append(response.Reactions, reactionMessage)
}
return response, nil
}
func (s *APIV2Service) UpsertMemoReaction(ctx context.Context, request *apiv2pb.UpsertMemoReactionRequest) (*apiv2pb.UpsertMemoReactionResponse, error) {
user, err := getCurrentUser(ctx, s.Store)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get current user")
}
reaction, err := s.Store.UpsertReaction(ctx, &storepb.Reaction{
CreatorId: user.ID,
ContentId: request.Reaction.ContentId,
ReactionType: storepb.Reaction_Type(request.Reaction.ReactionType),
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to upsert reaction")
}
reactionMessage, err := s.convertReactionFromStore(ctx, reaction)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to convert reaction")
}
return &apiv2pb.UpsertMemoReactionResponse{
Reaction: reactionMessage,
}, nil
}
func (s *APIV2Service) DeleteMemoReaction(ctx context.Context, request *apiv2pb.DeleteMemoReactionRequest) (*apiv2pb.DeleteMemoReactionResponse, error) {
if err := s.Store.DeleteReaction(ctx, &store.DeleteReaction{
ID: request.Id,
}); err != nil {
return nil, status.Errorf(codes.Internal, "failed to delete reaction")
}
return &apiv2pb.DeleteMemoReactionResponse{}, nil
}
func (s *APIV2Service) convertReactionFromStore(ctx context.Context, reaction *storepb.Reaction) (*apiv2pb.Reaction, error) {
creator, err := s.Store.GetUser(ctx, &store.FindUser{
ID: &reaction.CreatorId,
})
if err != nil {
return nil, err
}
return &apiv2pb.Reaction{
Id: reaction.Id,
Creator: fmt.Sprintf("%s%s", UserNamePrefix, creator.Username),
ContentId: reaction.ContentId,
ReactionType: apiv2pb.Reaction_Type(reaction.ReactionType),
}, nil
}

View File

@@ -0,0 +1,66 @@
package v2
import (
"fmt"
"strings"
"github.com/pkg/errors"
"github.com/usememos/memos/internal/util"
)
const (
WorkspaceSettingNamePrefix = "settings/"
UserNamePrefix = "users/"
InboxNamePrefix = "inboxes/"
)
// GetNameParentTokens returns the tokens from a resource name.
func GetNameParentTokens(name string, tokenPrefixes ...string) ([]string, error) {
parts := strings.Split(name, "/")
if len(parts) != 2*len(tokenPrefixes) {
return nil, errors.Errorf("invalid request %q", name)
}
var tokens []string
for i, tokenPrefix := range tokenPrefixes {
if fmt.Sprintf("%s/", parts[2*i]) != tokenPrefix {
return nil, errors.Errorf("invalid prefix %q in request %q", tokenPrefix, name)
}
if parts[2*i+1] == "" {
return nil, errors.Errorf("invalid request %q with empty prefix %q", name, tokenPrefix)
}
tokens = append(tokens, parts[2*i+1])
}
return tokens, nil
}
func ExtractWorkspaceSettingKeyFromName(name string) (string, error) {
tokens, err := GetNameParentTokens(name, WorkspaceSettingNamePrefix)
if err != nil {
return "", err
}
return tokens[0], nil
}
// ExtractUsernameFromName returns the username from a resource name.
func ExtractUsernameFromName(name string) (string, error) {
tokens, err := GetNameParentTokens(name, UserNamePrefix)
if err != nil {
return "", err
}
return tokens[0], nil
}
// ExtractInboxIDFromName returns the inbox ID from a resource name.
func ExtractInboxIDFromName(name string) (int32, error) {
tokens, err := GetNameParentTokens(name, InboxNamePrefix)
if err != nil {
return 0, err
}
id, err := util.ConvertStringToInt32(tokens[0])
if err != nil {
return 0, errors.Errorf("invalid inbox ID %q", tokens[0])
}
return id, nil
}

View File

@@ -0,0 +1,176 @@
package v2
import (
"context"
"net/url"
"time"
"github.com/lithammer/shortuuid/v4"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/timestamppb"
apiv2pb "github.com/usememos/memos/proto/gen/api/v2"
"github.com/usememos/memos/store"
)
func (s *APIV2Service) CreateResource(ctx context.Context, request *apiv2pb.CreateResourceRequest) (*apiv2pb.CreateResourceResponse, error) {
user, err := getCurrentUser(ctx, s.Store)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
}
if request.ExternalLink != "" {
// Only allow those external links scheme with http/https
linkURL, err := url.Parse(request.ExternalLink)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid external link: %v", err)
}
if linkURL.Scheme != "http" && linkURL.Scheme != "https" {
return nil, status.Errorf(codes.InvalidArgument, "invalid external link scheme: %v", linkURL.Scheme)
}
}
create := &store.Resource{
ResourceName: shortuuid.New(),
CreatorID: user.ID,
Filename: request.Filename,
ExternalLink: request.ExternalLink,
Type: request.Type,
}
if request.MemoId != nil {
create.MemoID = request.MemoId
}
resource, err := s.Store.CreateResource(ctx, create)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to create resource: %v", err)
}
return &apiv2pb.CreateResourceResponse{
Resource: s.convertResourceFromStore(ctx, resource),
}, nil
}
func (s *APIV2Service) ListResources(ctx context.Context, _ *apiv2pb.ListResourcesRequest) (*apiv2pb.ListResourcesResponse, error) {
user, err := getCurrentUser(ctx, s.Store)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
}
resources, err := s.Store.ListResources(ctx, &store.FindResource{
CreatorID: &user.ID,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list resources: %v", err)
}
response := &apiv2pb.ListResourcesResponse{}
for _, resource := range resources {
response.Resources = append(response.Resources, s.convertResourceFromStore(ctx, resource))
}
return response, nil
}
func (s *APIV2Service) GetResource(ctx context.Context, request *apiv2pb.GetResourceRequest) (*apiv2pb.GetResourceResponse, error) {
resource, err := s.Store.GetResource(ctx, &store.FindResource{
ID: &request.Id,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get resource: %v", err)
}
if resource == nil {
return nil, status.Errorf(codes.NotFound, "resource not found")
}
return &apiv2pb.GetResourceResponse{
Resource: s.convertResourceFromStore(ctx, resource),
}, nil
}
func (s *APIV2Service) GetResourceByName(ctx context.Context, request *apiv2pb.GetResourceByNameRequest) (*apiv2pb.GetResourceByNameResponse, error) {
resource, err := s.Store.GetResource(ctx, &store.FindResource{
ResourceName: &request.Name,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get resource: %v", err)
}
if resource == nil {
return nil, status.Errorf(codes.NotFound, "resource not found")
}
return &apiv2pb.GetResourceByNameResponse{
Resource: s.convertResourceFromStore(ctx, resource),
}, nil
}
func (s *APIV2Service) UpdateResource(ctx context.Context, request *apiv2pb.UpdateResourceRequest) (*apiv2pb.UpdateResourceResponse, error) {
if request.UpdateMask == nil || len(request.UpdateMask.Paths) == 0 {
return nil, status.Errorf(codes.InvalidArgument, "update mask is required")
}
currentTs := time.Now().Unix()
update := &store.UpdateResource{
ID: request.Resource.Id,
UpdatedTs: &currentTs,
}
for _, field := range request.UpdateMask.Paths {
if field == "filename" {
update.Filename = &request.Resource.Filename
} else if field == "memo_id" {
update.MemoID = request.Resource.MemoId
}
}
resource, err := s.Store.UpdateResource(ctx, update)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to update resource: %v", err)
}
return &apiv2pb.UpdateResourceResponse{
Resource: s.convertResourceFromStore(ctx, resource),
}, nil
}
func (s *APIV2Service) DeleteResource(ctx context.Context, request *apiv2pb.DeleteResourceRequest) (*apiv2pb.DeleteResourceResponse, error) {
user, err := getCurrentUser(ctx, s.Store)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
}
resource, err := s.Store.GetResource(ctx, &store.FindResource{
ID: &request.Id,
CreatorID: &user.ID,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to find resource: %v", err)
}
if resource == nil {
return nil, status.Errorf(codes.NotFound, "resource not found")
}
// Delete the resource from the database.
if err := s.Store.DeleteResource(ctx, &store.DeleteResource{
ID: resource.ID,
}); err != nil {
return nil, status.Errorf(codes.Internal, "failed to delete resource: %v", err)
}
return &apiv2pb.DeleteResourceResponse{}, nil
}
func (s *APIV2Service) convertResourceFromStore(ctx context.Context, resource *store.Resource) *apiv2pb.Resource {
var memoID *int32
if resource.MemoID != nil {
memo, _ := s.Store.GetMemo(ctx, &store.FindMemo{
ID: resource.MemoID,
})
if memo != nil {
memoID = &memo.ID
}
}
return &apiv2pb.Resource{
Id: resource.ID,
Name: resource.ResourceName,
CreateTime: timestamppb.New(time.Unix(resource.CreatedTs, 0)),
Filename: resource.Filename,
ExternalLink: resource.ExternalLink,
Type: resource.Type,
Size: resource.Size,
MemoId: memoID,
}
}

View File

@@ -0,0 +1,271 @@
package v2
import (
"context"
"fmt"
"slices"
"sort"
"github.com/pkg/errors"
"github.com/yourselfhosted/gomark/ast"
"github.com/yourselfhosted/gomark/parser"
"github.com/yourselfhosted/gomark/parser/tokenizer"
"github.com/yourselfhosted/gomark/restore"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
apiv2pb "github.com/usememos/memos/proto/gen/api/v2"
"github.com/usememos/memos/store"
)
func (s *APIV2Service) UpsertTag(ctx context.Context, request *apiv2pb.UpsertTagRequest) (*apiv2pb.UpsertTagResponse, error) {
user, err := getCurrentUser(ctx, s.Store)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user")
}
tag, err := s.Store.UpsertTag(ctx, &store.Tag{
Name: request.Name,
CreatorID: user.ID,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to upsert tag: %v", err)
}
t, err := s.convertTagFromStore(ctx, tag)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to convert tag: %v", err)
}
return &apiv2pb.UpsertTagResponse{
Tag: t,
}, nil
}
func (s *APIV2Service) BatchUpsertTag(ctx context.Context, request *apiv2pb.BatchUpsertTagRequest) (*apiv2pb.BatchUpsertTagResponse, error) {
for _, r := range request.Requests {
if _, err := s.UpsertTag(ctx, r); err != nil {
return nil, status.Errorf(codes.Internal, "failed to batch upsert tags: %v", err)
}
}
return &apiv2pb.BatchUpsertTagResponse{}, nil
}
func (s *APIV2Service) ListTags(ctx context.Context, request *apiv2pb.ListTagsRequest) (*apiv2pb.ListTagsResponse, error) {
username, err := ExtractUsernameFromName(request.User)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid username: %v", err)
}
user, err := s.Store.GetUser(ctx, &store.FindUser{
Username: &username,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
}
if user == nil {
return nil, status.Errorf(codes.NotFound, "user not found")
}
tags, err := s.Store.ListTags(ctx, &store.FindTag{
CreatorID: user.ID,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list tags: %v", err)
}
response := &apiv2pb.ListTagsResponse{}
for _, tag := range tags {
t, err := s.convertTagFromStore(ctx, tag)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to convert tag: %v", err)
}
response.Tags = append(response.Tags, t)
}
return response, nil
}
func (s *APIV2Service) RenameTag(ctx context.Context, request *apiv2pb.RenameTagRequest) (*apiv2pb.RenameTagResponse, error) {
username, err := ExtractUsernameFromName(request.User)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid username: %v", err)
}
user, err := s.Store.GetUser(ctx, &store.FindUser{
Username: &username,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
}
if user == nil {
return nil, status.Errorf(codes.NotFound, "user not found")
}
// Find all related memos.
memos, err := s.Store.ListMemos(ctx, &store.FindMemo{
CreatorID: &user.ID,
ContentSearch: []string{fmt.Sprintf("#%s", request.OldName)},
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list memos: %v", err)
}
// Replace tag name in memo content.
for _, memo := range memos {
nodes, err := parser.Parse(tokenizer.Tokenize(memo.Content))
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to parse memo: %v", err)
}
traverseASTNodes(nodes, func(node ast.Node) {
if tag, ok := node.(*ast.Tag); ok && tag.Content == request.OldName {
tag.Content = request.NewName
}
})
content := restore.Restore(nodes)
if err := s.Store.UpdateMemo(ctx, &store.UpdateMemo{
ID: memo.ID,
Content: &content,
}); err != nil {
return nil, status.Errorf(codes.Internal, "failed to update memo: %v", err)
}
}
// Delete old tag and create new tag.
if err := s.Store.DeleteTag(ctx, &store.DeleteTag{
CreatorID: user.ID,
Name: request.OldName,
}); err != nil {
return nil, status.Errorf(codes.Internal, "failed to delete tag: %v", err)
}
tag, err := s.Store.UpsertTag(ctx, &store.Tag{
CreatorID: user.ID,
Name: request.NewName,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to upsert tag: %v", err)
}
tagMessage, err := s.convertTagFromStore(ctx, tag)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to convert tag: %v", err)
}
return &apiv2pb.RenameTagResponse{Tag: tagMessage}, nil
}
func (s *APIV2Service) DeleteTag(ctx context.Context, request *apiv2pb.DeleteTagRequest) (*apiv2pb.DeleteTagResponse, error) {
username, err := ExtractUsernameFromName(request.Tag.Creator)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid username: %v", err)
}
user, err := s.Store.GetUser(ctx, &store.FindUser{
Username: &username,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
}
if user == nil {
return nil, status.Errorf(codes.NotFound, "user not found")
}
if err := s.Store.DeleteTag(ctx, &store.DeleteTag{
Name: request.Tag.Name,
CreatorID: user.ID,
}); err != nil {
return nil, status.Errorf(codes.Internal, "failed to delete tag: %v", err)
}
return &apiv2pb.DeleteTagResponse{}, nil
}
func (s *APIV2Service) GetTagSuggestions(ctx context.Context, request *apiv2pb.GetTagSuggestionsRequest) (*apiv2pb.GetTagSuggestionsResponse, error) {
username, err := ExtractUsernameFromName(request.User)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid username: %v", err)
}
user, err := s.Store.GetUser(ctx, &store.FindUser{
Username: &username,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
}
if user == nil {
return nil, status.Errorf(codes.NotFound, "user not found")
}
normalRowStatus := store.Normal
memoFind := &store.FindMemo{
CreatorID: &user.ID,
ContentSearch: []string{"#"},
RowStatus: &normalRowStatus,
}
memos, err := s.Store.ListMemos(ctx, memoFind)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list memos: %v", err)
}
tagList, err := s.Store.ListTags(ctx, &store.FindTag{
CreatorID: user.ID,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list tags: %v", err)
}
tagNameList := []string{}
for _, tag := range tagList {
tagNameList = append(tagNameList, tag.Name)
}
tagMapSet := make(map[string]bool)
for _, memo := range memos {
nodes, err := parser.Parse(tokenizer.Tokenize(memo.Content))
if err != nil {
return nil, errors.Wrap(err, "failed to parse memo content")
}
// Dynamically upsert tags from memo content.
traverseASTNodes(nodes, func(node ast.Node) {
if tagNode, ok := node.(*ast.Tag); ok {
tag := tagNode.Content
if !slices.Contains(tagNameList, tag) {
tagMapSet[tag] = true
}
}
})
}
suggestions := []string{}
for tag := range tagMapSet {
suggestions = append(suggestions, tag)
}
sort.Strings(suggestions)
return &apiv2pb.GetTagSuggestionsResponse{
Tags: suggestions,
}, nil
}
func (s *APIV2Service) convertTagFromStore(ctx context.Context, tag *store.Tag) (*apiv2pb.Tag, error) {
user, err := s.Store.GetUser(ctx, &store.FindUser{
ID: &tag.CreatorID,
})
if err != nil {
return nil, errors.Wrap(err, "failed to get user")
}
return &apiv2pb.Tag{
Name: tag.Name,
Creator: fmt.Sprintf("%s%s", UserNamePrefix, user.Username),
}, nil
}
func traverseASTNodes(nodes []ast.Node, fn func(ast.Node)) {
for _, node := range nodes {
fn(node)
switch n := node.(type) {
case *ast.Paragraph:
traverseASTNodes(n.Children, fn)
case *ast.Heading:
traverseASTNodes(n.Children, fn)
case *ast.Blockquote:
traverseASTNodes(n.Children, fn)
case *ast.OrderedList:
traverseASTNodes(n.Children, fn)
case *ast.UnorderedList:
traverseASTNodes(n.Children, fn)
case *ast.TaskList:
traverseASTNodes(n.Children, fn)
case *ast.Bold:
traverseASTNodes(n.Children, fn)
}
}
}

View File

@@ -0,0 +1,552 @@
package v2
import (
"context"
"fmt"
"net/http"
"strings"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/labstack/echo/v4"
"github.com/pkg/errors"
"golang.org/x/crypto/bcrypt"
"golang.org/x/exp/slices"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/usememos/memos/internal/util"
apiv2pb "github.com/usememos/memos/proto/gen/api/v2"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/server/route/api/auth"
"github.com/usememos/memos/store"
)
func (s *APIV2Service) ListUsers(ctx context.Context, _ *apiv2pb.ListUsersRequest) (*apiv2pb.ListUsersResponse, error) {
currentUser, err := getCurrentUser(ctx, s.Store)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
}
if currentUser.Role != store.RoleHost && currentUser.Role != store.RoleAdmin {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
users, err := s.Store.ListUsers(ctx, &store.FindUser{})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list users: %v", err)
}
response := &apiv2pb.ListUsersResponse{
Users: []*apiv2pb.User{},
}
for _, user := range users {
response.Users = append(response.Users, convertUserFromStore(user))
}
return response, nil
}
func (s *APIV2Service) GetUser(ctx context.Context, request *apiv2pb.GetUserRequest) (*apiv2pb.GetUserResponse, error) {
username, err := ExtractUsernameFromName(request.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "name is required")
}
user, err := s.Store.GetUser(ctx, &store.FindUser{
Username: &username,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
}
if user == nil {
return nil, status.Errorf(codes.NotFound, "user not found")
}
userMessage := convertUserFromStore(user)
response := &apiv2pb.GetUserResponse{
User: userMessage,
}
return response, nil
}
func (s *APIV2Service) CreateUser(ctx context.Context, request *apiv2pb.CreateUserRequest) (*apiv2pb.CreateUserResponse, error) {
currentUser, err := getCurrentUser(ctx, s.Store)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
}
if currentUser.Role != store.RoleHost {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
username, err := ExtractUsernameFromName(request.User.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "name is required")
}
if !util.ResourceNameMatcher.MatchString(strings.ToLower(username)) {
return nil, status.Errorf(codes.InvalidArgument, "invalid username: %s", username)
}
passwordHash, err := bcrypt.GenerateFromPassword([]byte(request.User.Password), bcrypt.DefaultCost)
if err != nil {
return nil, echo.NewHTTPError(http.StatusInternalServerError, "failed to generate password hash").SetInternal(err)
}
user, err := s.Store.CreateUser(ctx, &store.User{
Username: username,
Role: convertUserRoleToStore(request.User.Role),
Email: request.User.Email,
Nickname: request.User.Nickname,
PasswordHash: string(passwordHash),
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to create user: %v", err)
}
response := &apiv2pb.CreateUserResponse{
User: convertUserFromStore(user),
}
return response, nil
}
func (s *APIV2Service) UpdateUser(ctx context.Context, request *apiv2pb.UpdateUserRequest) (*apiv2pb.UpdateUserResponse, error) {
username, err := ExtractUsernameFromName(request.User.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "name is required")
}
currentUser, err := getCurrentUser(ctx, s.Store)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
}
if currentUser.Username != username && currentUser.Role != store.RoleAdmin && currentUser.Role != store.RoleHost {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
if request.UpdateMask == nil || len(request.UpdateMask.Paths) == 0 {
return nil, status.Errorf(codes.InvalidArgument, "update mask is empty")
}
user, err := s.Store.GetUser(ctx, &store.FindUser{Username: &username})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
}
if user == nil {
return nil, status.Errorf(codes.NotFound, "user not found")
}
if s.Profile.Mode == "demo" && user.Username == "memos-demo" {
return nil, status.Errorf(codes.PermissionDenied, "unauthorized to update user in demo mode")
}
currentTs := time.Now().Unix()
update := &store.UpdateUser{
ID: user.ID,
UpdatedTs: &currentTs,
}
for _, field := range request.UpdateMask.Paths {
if field == "username" {
if !util.ResourceNameMatcher.MatchString(strings.ToLower(request.User.Username)) {
return nil, status.Errorf(codes.InvalidArgument, "invalid username: %s", request.User.Username)
}
update.Username = &request.User.Username
} else if field == "nickname" {
update.Nickname = &request.User.Nickname
} else if field == "email" {
update.Email = &request.User.Email
} else if field == "avatar_url" {
update.AvatarURL = &request.User.AvatarUrl
} else if field == "role" {
role := convertUserRoleToStore(request.User.Role)
update.Role = &role
} else if field == "password" {
passwordHash, err := bcrypt.GenerateFromPassword([]byte(request.User.Password), bcrypt.DefaultCost)
if err != nil {
return nil, echo.NewHTTPError(http.StatusInternalServerError, "failed to generate password hash").SetInternal(err)
}
passwordHashStr := string(passwordHash)
update.PasswordHash = &passwordHashStr
} else if field == "row_status" {
rowStatus := convertRowStatusToStore(request.User.RowStatus)
update.RowStatus = &rowStatus
} else {
return nil, status.Errorf(codes.InvalidArgument, "invalid update path: %s", field)
}
}
updatedUser, err := s.Store.UpdateUser(ctx, update)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to update user: %v", err)
}
response := &apiv2pb.UpdateUserResponse{
User: convertUserFromStore(updatedUser),
}
return response, nil
}
func (s *APIV2Service) DeleteUser(ctx context.Context, request *apiv2pb.DeleteUserRequest) (*apiv2pb.DeleteUserResponse, error) {
username, err := ExtractUsernameFromName(request.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "name is required")
}
currentUser, err := getCurrentUser(ctx, s.Store)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
}
if currentUser.Username != username && currentUser.Role != store.RoleAdmin && currentUser.Role != store.RoleHost {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
user, err := s.Store.GetUser(ctx, &store.FindUser{Username: &username})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
}
if user == nil {
return nil, status.Errorf(codes.NotFound, "user not found")
}
if s.Profile.Mode == "demo" && user.Username == "memos-demo" {
return nil, status.Errorf(codes.PermissionDenied, "unauthorized to delete this user in demo mode")
}
if err := s.Store.DeleteUser(ctx, &store.DeleteUser{
ID: user.ID,
}); err != nil {
return nil, status.Errorf(codes.Internal, "failed to delete user: %v", err)
}
return &apiv2pb.DeleteUserResponse{}, nil
}
func getDefaultUserSetting() *apiv2pb.UserSetting {
return &apiv2pb.UserSetting{
Locale: "en",
Appearance: "system",
MemoVisibility: "PRIVATE",
CompactView: false,
}
}
func (s *APIV2Service) GetUserSetting(ctx context.Context, _ *apiv2pb.GetUserSettingRequest) (*apiv2pb.GetUserSettingResponse, error) {
user, err := getCurrentUser(ctx, s.Store)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
}
userSettings, err := s.Store.ListUserSettings(ctx, &store.FindUserSetting{
UserID: &user.ID,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list user settings: %v", err)
}
userSettingMessage := getDefaultUserSetting()
for _, setting := range userSettings {
if setting.Key == storepb.UserSettingKey_USER_SETTING_LOCALE {
userSettingMessage.Locale = setting.GetLocale()
} else if setting.Key == storepb.UserSettingKey_USER_SETTING_APPEARANCE {
userSettingMessage.Appearance = setting.GetAppearance()
} else if setting.Key == storepb.UserSettingKey_USER_SETTING_MEMO_VISIBILITY {
userSettingMessage.MemoVisibility = setting.GetMemoVisibility()
} else if setting.Key == storepb.UserSettingKey_USER_SETTING_TELEGRAM_USER_ID {
userSettingMessage.TelegramUserId = setting.GetTelegramUserId()
} else if setting.Key == storepb.UserSettingKey_USER_SETTING_COMPACT_VIEW {
userSettingMessage.CompactView = setting.GetCompactView()
}
}
return &apiv2pb.GetUserSettingResponse{
Setting: userSettingMessage,
}, nil
}
func (s *APIV2Service) UpdateUserSetting(ctx context.Context, request *apiv2pb.UpdateUserSettingRequest) (*apiv2pb.UpdateUserSettingResponse, error) {
user, err := getCurrentUser(ctx, s.Store)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
}
if request.UpdateMask == nil || len(request.UpdateMask.Paths) == 0 {
return nil, status.Errorf(codes.InvalidArgument, "update mask is empty")
}
for _, field := range request.UpdateMask.Paths {
if field == "locale" {
if _, err := s.Store.UpsertUserSetting(ctx, &storepb.UserSetting{
UserId: user.ID,
Key: storepb.UserSettingKey_USER_SETTING_LOCALE,
Value: &storepb.UserSetting_Locale{
Locale: request.Setting.Locale,
},
}); err != nil {
return nil, status.Errorf(codes.Internal, "failed to upsert user setting: %v", err)
}
} else if field == "appearance" {
if _, err := s.Store.UpsertUserSetting(ctx, &storepb.UserSetting{
UserId: user.ID,
Key: storepb.UserSettingKey_USER_SETTING_APPEARANCE,
Value: &storepb.UserSetting_Appearance{
Appearance: request.Setting.Appearance,
},
}); err != nil {
return nil, status.Errorf(codes.Internal, "failed to upsert user setting: %v", err)
}
} else if field == "memo_visibility" {
if _, err := s.Store.UpsertUserSetting(ctx, &storepb.UserSetting{
UserId: user.ID,
Key: storepb.UserSettingKey_USER_SETTING_MEMO_VISIBILITY,
Value: &storepb.UserSetting_MemoVisibility{
MemoVisibility: request.Setting.MemoVisibility,
},
}); err != nil {
return nil, status.Errorf(codes.Internal, "failed to upsert user setting: %v", err)
}
} else if field == "telegram_user_id" {
if _, err := s.Store.UpsertUserSetting(ctx, &storepb.UserSetting{
UserId: user.ID,
Key: storepb.UserSettingKey_USER_SETTING_TELEGRAM_USER_ID,
Value: &storepb.UserSetting_TelegramUserId{
TelegramUserId: request.Setting.TelegramUserId,
},
}); err != nil {
return nil, status.Errorf(codes.Internal, "failed to upsert user setting: %v", err)
}
} else if field == "compact_view" {
if _, err := s.Store.UpsertUserSetting(ctx, &storepb.UserSetting{
UserId: user.ID,
Key: storepb.UserSettingKey_USER_SETTING_COMPACT_VIEW,
Value: &storepb.UserSetting_CompactView{
CompactView: request.Setting.CompactView,
},
}); err != nil {
return nil, status.Errorf(codes.Internal, "failed to upsert user setting: %v", err)
}
} else {
return nil, status.Errorf(codes.InvalidArgument, "invalid update path: %s", field)
}
}
userSettingResponse, err := s.GetUserSetting(ctx, &apiv2pb.GetUserSettingRequest{})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user setting: %v", err)
}
return &apiv2pb.UpdateUserSettingResponse{
Setting: userSettingResponse.Setting,
}, nil
}
func (s *APIV2Service) ListUserAccessTokens(ctx context.Context, request *apiv2pb.ListUserAccessTokensRequest) (*apiv2pb.ListUserAccessTokensResponse, error) {
user, err := getCurrentUser(ctx, s.Store)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
}
if user == nil {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
userID := user.ID
username, err := ExtractUsernameFromName(request.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "name is required")
}
// List access token for other users need to be verified.
if user.Username != username {
// Normal users can only list their access tokens.
if user.Role == store.RoleUser {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
// The request user must be exist.
requestUser, err := s.Store.GetUser(ctx, &store.FindUser{Username: &username})
if requestUser == nil || err != nil {
return nil, status.Errorf(codes.NotFound, "fail to find user %s", username)
}
userID = requestUser.ID
}
userAccessTokens, err := s.Store.GetUserAccessTokens(ctx, userID)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list access tokens: %v", err)
}
accessTokens := []*apiv2pb.UserAccessToken{}
for _, userAccessToken := range userAccessTokens {
claims := &auth.ClaimsMessage{}
_, err := jwt.ParseWithClaims(userAccessToken.AccessToken, claims, func(t *jwt.Token) (any, error) {
if t.Method.Alg() != jwt.SigningMethodHS256.Name {
return nil, errors.Errorf("unexpected access token signing method=%v, expect %v", t.Header["alg"], jwt.SigningMethodHS256)
}
if kid, ok := t.Header["kid"].(string); ok {
if kid == "v1" {
return []byte(s.Secret), nil
}
}
return nil, errors.Errorf("unexpected access token kid=%v", t.Header["kid"])
})
if err != nil {
// If the access token is invalid or expired, just ignore it.
continue
}
userAccessToken := &apiv2pb.UserAccessToken{
AccessToken: userAccessToken.AccessToken,
Description: userAccessToken.Description,
IssuedAt: timestamppb.New(claims.IssuedAt.Time),
}
if claims.ExpiresAt != nil {
userAccessToken.ExpiresAt = timestamppb.New(claims.ExpiresAt.Time)
}
accessTokens = append(accessTokens, userAccessToken)
}
// Sort by issued time in descending order.
slices.SortFunc(accessTokens, func(i, j *apiv2pb.UserAccessToken) int {
return int(i.IssuedAt.Seconds - j.IssuedAt.Seconds)
})
response := &apiv2pb.ListUserAccessTokensResponse{
AccessTokens: accessTokens,
}
return response, nil
}
func (s *APIV2Service) CreateUserAccessToken(ctx context.Context, request *apiv2pb.CreateUserAccessTokenRequest) (*apiv2pb.CreateUserAccessTokenResponse, error) {
user, err := getCurrentUser(ctx, s.Store)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
}
expiresAt := time.Time{}
if request.ExpiresAt != nil {
expiresAt = request.ExpiresAt.AsTime()
}
accessToken, err := auth.GenerateAccessToken(user.Username, user.ID, expiresAt, []byte(s.Secret))
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to generate access token: %v", err)
}
claims := &auth.ClaimsMessage{}
_, err = jwt.ParseWithClaims(accessToken, claims, func(t *jwt.Token) (any, error) {
if t.Method.Alg() != jwt.SigningMethodHS256.Name {
return nil, errors.Errorf("unexpected access token signing method=%v, expect %v", t.Header["alg"], jwt.SigningMethodHS256)
}
if kid, ok := t.Header["kid"].(string); ok {
if kid == "v1" {
return []byte(s.Secret), nil
}
}
return nil, errors.Errorf("unexpected access token kid=%v", t.Header["kid"])
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to parse access token: %v", err)
}
// Upsert the access token to user setting store.
if err := s.UpsertAccessTokenToStore(ctx, user, accessToken, request.Description); err != nil {
return nil, status.Errorf(codes.Internal, "failed to upsert access token to store: %v", err)
}
userAccessToken := &apiv2pb.UserAccessToken{
AccessToken: accessToken,
Description: request.Description,
IssuedAt: timestamppb.New(claims.IssuedAt.Time),
}
if claims.ExpiresAt != nil {
userAccessToken.ExpiresAt = timestamppb.New(claims.ExpiresAt.Time)
}
response := &apiv2pb.CreateUserAccessTokenResponse{
AccessToken: userAccessToken,
}
return response, nil
}
func (s *APIV2Service) DeleteUserAccessToken(ctx context.Context, request *apiv2pb.DeleteUserAccessTokenRequest) (*apiv2pb.DeleteUserAccessTokenResponse, error) {
user, err := getCurrentUser(ctx, s.Store)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
}
userAccessTokens, err := s.Store.GetUserAccessTokens(ctx, user.ID)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list access tokens: %v", err)
}
updatedUserAccessTokens := []*storepb.AccessTokensUserSetting_AccessToken{}
for _, userAccessToken := range userAccessTokens {
if userAccessToken.AccessToken == request.AccessToken {
continue
}
updatedUserAccessTokens = append(updatedUserAccessTokens, userAccessToken)
}
if _, err := s.Store.UpsertUserSetting(ctx, &storepb.UserSetting{
UserId: user.ID,
Key: storepb.UserSettingKey_USER_SETTING_ACCESS_TOKENS,
Value: &storepb.UserSetting_AccessTokens{
AccessTokens: &storepb.AccessTokensUserSetting{
AccessTokens: updatedUserAccessTokens,
},
},
}); err != nil {
return nil, status.Errorf(codes.Internal, "failed to upsert user setting: %v", err)
}
return &apiv2pb.DeleteUserAccessTokenResponse{}, nil
}
func (s *APIV2Service) UpsertAccessTokenToStore(ctx context.Context, user *store.User, accessToken, description string) error {
userAccessTokens, err := s.Store.GetUserAccessTokens(ctx, user.ID)
if err != nil {
return errors.Wrap(err, "failed to get user access tokens")
}
userAccessToken := storepb.AccessTokensUserSetting_AccessToken{
AccessToken: accessToken,
Description: description,
}
userAccessTokens = append(userAccessTokens, &userAccessToken)
if _, err := s.Store.UpsertUserSetting(ctx, &storepb.UserSetting{
UserId: user.ID,
Key: storepb.UserSettingKey_USER_SETTING_ACCESS_TOKENS,
Value: &storepb.UserSetting_AccessTokens{
AccessTokens: &storepb.AccessTokensUserSetting{
AccessTokens: userAccessTokens,
},
},
}); err != nil {
return errors.Wrap(err, "failed to upsert user setting")
}
return nil
}
func convertUserFromStore(user *store.User) *apiv2pb.User {
return &apiv2pb.User{
Name: fmt.Sprintf("%s%s", UserNamePrefix, user.Username),
Id: user.ID,
RowStatus: convertRowStatusFromStore(user.RowStatus),
CreateTime: timestamppb.New(time.Unix(user.CreatedTs, 0)),
UpdateTime: timestamppb.New(time.Unix(user.UpdatedTs, 0)),
Role: convertUserRoleFromStore(user.Role),
Username: user.Username,
Email: user.Email,
Nickname: user.Nickname,
AvatarUrl: user.AvatarURL,
}
}
func convertUserRoleFromStore(role store.Role) apiv2pb.User_Role {
switch role {
case store.RoleHost:
return apiv2pb.User_HOST
case store.RoleAdmin:
return apiv2pb.User_ADMIN
case store.RoleUser:
return apiv2pb.User_USER
default:
return apiv2pb.User_ROLE_UNSPECIFIED
}
}
func convertUserRoleToStore(role apiv2pb.User_Role) store.Role {
switch role {
case apiv2pb.User_HOST:
return store.RoleHost
case apiv2pb.User_ADMIN:
return store.RoleAdmin
case apiv2pb.User_USER:
return store.RoleUser
default:
return store.RoleUser
}
}

146
server/route/api/v2/v2.go Normal file
View File

@@ -0,0 +1,146 @@
package v2
import (
"context"
"fmt"
"net"
"github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
"github.com/improbable-eng/grpc-web/go/grpcweb"
"github.com/labstack/echo/v4"
"github.com/pkg/errors"
"go.uber.org/zap"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/reflection"
"github.com/usememos/memos/internal/log"
apiv2pb "github.com/usememos/memos/proto/gen/api/v2"
"github.com/usememos/memos/server/profile"
"github.com/usememos/memos/store"
)
type APIV2Service struct {
apiv2pb.UnimplementedWorkspaceServiceServer
apiv2pb.UnimplementedWorkspaceSettingServiceServer
apiv2pb.UnimplementedAuthServiceServer
apiv2pb.UnimplementedUserServiceServer
apiv2pb.UnimplementedMemoServiceServer
apiv2pb.UnimplementedResourceServiceServer
apiv2pb.UnimplementedTagServiceServer
apiv2pb.UnimplementedInboxServiceServer
apiv2pb.UnimplementedActivityServiceServer
apiv2pb.UnimplementedWebhookServiceServer
Secret string
Profile *profile.Profile
Store *store.Store
grpcServer *grpc.Server
grpcServerPort int
}
func NewAPIV2Service(secret string, profile *profile.Profile, store *store.Store, grpcServerPort int) *APIV2Service {
grpc.EnableTracing = true
authProvider := NewGRPCAuthInterceptor(store, secret)
grpcServer := grpc.NewServer(
grpc.ChainUnaryInterceptor(
authProvider.AuthenticationInterceptor,
),
)
apiv2Service := &APIV2Service{
Secret: secret,
Profile: profile,
Store: store,
grpcServer: grpcServer,
grpcServerPort: grpcServerPort,
}
apiv2pb.RegisterWorkspaceServiceServer(grpcServer, apiv2Service)
apiv2pb.RegisterWorkspaceSettingServiceServer(grpcServer, apiv2Service)
apiv2pb.RegisterAuthServiceServer(grpcServer, apiv2Service)
apiv2pb.RegisterUserServiceServer(grpcServer, apiv2Service)
apiv2pb.RegisterMemoServiceServer(grpcServer, apiv2Service)
apiv2pb.RegisterTagServiceServer(grpcServer, apiv2Service)
apiv2pb.RegisterResourceServiceServer(grpcServer, apiv2Service)
apiv2pb.RegisterInboxServiceServer(grpcServer, apiv2Service)
apiv2pb.RegisterActivityServiceServer(grpcServer, apiv2Service)
apiv2pb.RegisterWebhookServiceServer(grpcServer, apiv2Service)
reflection.Register(grpcServer)
return apiv2Service
}
func (s *APIV2Service) GetGRPCServer() *grpc.Server {
return s.grpcServer
}
// RegisterGateway registers the gRPC-Gateway with the given Echo instance.
func (s *APIV2Service) RegisterGateway(ctx context.Context, e *echo.Echo) error {
// Create a client connection to the gRPC Server we just started.
// This is where the gRPC-Gateway proxies the requests.
conn, err := grpc.DialContext(
ctx,
fmt.Sprintf(":%d", s.grpcServerPort),
grpc.WithTransportCredentials(insecure.NewCredentials()),
)
if err != nil {
return err
}
gwMux := runtime.NewServeMux()
if err := apiv2pb.RegisterWorkspaceServiceHandler(context.Background(), gwMux, conn); err != nil {
return err
}
if err := apiv2pb.RegisterWorkspaceSettingServiceHandler(context.Background(), gwMux, conn); err != nil {
return err
}
if err := apiv2pb.RegisterAuthServiceHandler(context.Background(), gwMux, conn); err != nil {
return err
}
if err := apiv2pb.RegisterUserServiceHandler(context.Background(), gwMux, conn); err != nil {
return err
}
if err := apiv2pb.RegisterMemoServiceHandler(context.Background(), gwMux, conn); err != nil {
return err
}
if err := apiv2pb.RegisterTagServiceHandler(context.Background(), gwMux, conn); err != nil {
return err
}
if err := apiv2pb.RegisterResourceServiceHandler(context.Background(), gwMux, conn); err != nil {
return err
}
if err := apiv2pb.RegisterInboxServiceHandler(context.Background(), gwMux, conn); err != nil {
return err
}
if err := apiv2pb.RegisterActivityServiceHandler(context.Background(), gwMux, conn); err != nil {
return err
}
if err := apiv2pb.RegisterWebhookServiceHandler(context.Background(), gwMux, conn); err != nil {
return err
}
e.Any("/api/v2/*", echo.WrapHandler(gwMux))
// GRPC web proxy.
options := []grpcweb.Option{
grpcweb.WithCorsForRegisteredEndpointsOnly(false),
grpcweb.WithOriginFunc(func(origin string) bool {
return true
}),
}
wrappedGrpc := grpcweb.WrapServer(s.grpcServer, options...)
e.Any("/memos.api.v2.*", echo.WrapHandler(wrappedGrpc))
// Start gRPC server.
listen, err := net.Listen("tcp", fmt.Sprintf("%s:%d", s.Profile.Addr, s.grpcServerPort))
if err != nil {
return errors.Wrap(err, "failed to start gRPC server")
}
go func() {
if err := s.grpcServer.Serve(listen); err != nil {
log.Error("grpc server listen error", zap.Error(err))
}
}()
return nil
}

View File

@@ -0,0 +1,120 @@
package v2
import (
"context"
"time"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/timestamppb"
apiv2pb "github.com/usememos/memos/proto/gen/api/v2"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/store"
)
func (s *APIV2Service) CreateWebhook(ctx context.Context, request *apiv2pb.CreateWebhookRequest) (*apiv2pb.CreateWebhookResponse, error) {
currentUser, err := getCurrentUser(ctx, s.Store)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
}
webhook, err := s.Store.CreateWebhook(ctx, &storepb.Webhook{
CreatorId: currentUser.ID,
Name: request.Name,
Url: request.Url,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to create webhook, error: %+v", err)
}
return &apiv2pb.CreateWebhookResponse{
Webhook: convertWebhookFromStore(webhook),
}, nil
}
func (s *APIV2Service) ListWebhooks(ctx context.Context, request *apiv2pb.ListWebhooksRequest) (*apiv2pb.ListWebhooksResponse, error) {
webhooks, err := s.Store.ListWebhooks(ctx, &store.FindWebhook{
CreatorID: &request.CreatorId,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list webhooks, error: %+v", err)
}
response := &apiv2pb.ListWebhooksResponse{
Webhooks: []*apiv2pb.Webhook{},
}
for _, webhook := range webhooks {
response.Webhooks = append(response.Webhooks, convertWebhookFromStore(webhook))
}
return response, nil
}
func (s *APIV2Service) GetWebhook(ctx context.Context, request *apiv2pb.GetWebhookRequest) (*apiv2pb.GetWebhookResponse, error) {
currentUser, err := getCurrentUser(ctx, s.Store)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
}
webhook, err := s.Store.GetWebhooks(ctx, &store.FindWebhook{
ID: &request.Id,
CreatorID: &currentUser.ID,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get webhook, error: %+v", err)
}
if webhook == nil {
return nil, status.Errorf(codes.NotFound, "webhook not found")
}
return &apiv2pb.GetWebhookResponse{
Webhook: convertWebhookFromStore(webhook),
}, nil
}
func (s *APIV2Service) UpdateWebhook(ctx context.Context, request *apiv2pb.UpdateWebhookRequest) (*apiv2pb.UpdateWebhookResponse, error) {
if request.UpdateMask == nil || len(request.UpdateMask.Paths) == 0 {
return nil, status.Errorf(codes.InvalidArgument, "update_mask is required")
}
update := &store.UpdateWebhook{}
for _, field := range request.UpdateMask.Paths {
switch field {
case "row_status":
rowStatus := storepb.RowStatus(storepb.RowStatus_value[request.Webhook.RowStatus.String()])
update.RowStatus = &rowStatus
case "name":
update.Name = &request.Webhook.Name
case "url":
update.URL = &request.Webhook.Url
}
}
webhook, err := s.Store.UpdateWebhook(ctx, update)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to update webhook, error: %+v", err)
}
return &apiv2pb.UpdateWebhookResponse{
Webhook: convertWebhookFromStore(webhook),
}, nil
}
func (s *APIV2Service) DeleteWebhook(ctx context.Context, request *apiv2pb.DeleteWebhookRequest) (*apiv2pb.DeleteWebhookResponse, error) {
err := s.Store.DeleteWebhook(ctx, &store.DeleteWebhook{
ID: request.Id,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to delete webhook, error: %+v", err)
}
return &apiv2pb.DeleteWebhookResponse{}, nil
}
func convertWebhookFromStore(webhook *storepb.Webhook) *apiv2pb.Webhook {
return &apiv2pb.Webhook{
Id: webhook.Id,
CreatedTime: timestamppb.New(time.Unix(webhook.CreatedTs, 0)),
UpdatedTime: timestamppb.New(time.Unix(webhook.UpdatedTs, 0)),
RowStatus: apiv2pb.RowStatus(webhook.RowStatus),
CreatorId: webhook.CreatorId,
Name: webhook.Name,
Url: webhook.Url,
}
}

View File

@@ -0,0 +1,17 @@
package v2
import (
"context"
apiv2pb "github.com/usememos/memos/proto/gen/api/v2"
)
func (s *APIV2Service) GetWorkspaceProfile(_ context.Context, _ *apiv2pb.GetWorkspaceProfileRequest) (*apiv2pb.GetWorkspaceProfileResponse, error) {
workspaceProfile := &apiv2pb.WorkspaceProfile{
Version: s.Profile.Version,
Mode: s.Profile.Mode,
}
return &apiv2pb.GetWorkspaceProfileResponse{
WorkspaceProfile: workspaceProfile,
}, nil
}

View File

@@ -0,0 +1,95 @@
package v2
import (
"context"
"fmt"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
apiv2pb "github.com/usememos/memos/proto/gen/api/v2"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/store"
)
func (s *APIV2Service) GetWorkspaceSetting(ctx context.Context, request *apiv2pb.GetWorkspaceSettingRequest) (*apiv2pb.GetWorkspaceSettingResponse, error) {
settingKeyString, err := ExtractWorkspaceSettingKeyFromName(request.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid workspace setting name: %v", err)
}
settingKey := storepb.WorkspaceSettingKey(storepb.WorkspaceSettingKey_value[settingKeyString])
workspaceSetting, err := s.Store.GetWorkspaceSettingV1(ctx, &store.FindWorkspaceSettingV1{
Key: settingKey,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get workspace setting: %v", err)
}
if workspaceSetting == nil {
return nil, status.Errorf(codes.NotFound, "workspace setting not found")
}
return &apiv2pb.GetWorkspaceSettingResponse{
Setting: convertWorkspaceSettingFromStore(workspaceSetting),
}, nil
}
func (s *APIV2Service) SetWorkspaceSetting(ctx context.Context, request *apiv2pb.SetWorkspaceSettingRequest) (*apiv2pb.SetWorkspaceSettingResponse, error) {
user, err := getCurrentUser(ctx, s.Store)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
}
if user.Role != store.RoleHost {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
if _, err := s.Store.UpsertWorkspaceSettingV1(ctx, convertWorkspaceSettingToStore(request.Setting)); err != nil {
return nil, status.Errorf(codes.Internal, "failed to upsert workspace setting: %v", err)
}
return &apiv2pb.SetWorkspaceSettingResponse{}, nil
}
func convertWorkspaceSettingFromStore(setting *storepb.WorkspaceSetting) *apiv2pb.WorkspaceSetting {
return &apiv2pb.WorkspaceSetting{
Name: fmt.Sprintf("%s%s", WorkspaceSettingNamePrefix, setting.Key.String()),
Value: &apiv2pb.WorkspaceSetting_GeneralSetting{
GeneralSetting: convertWorkspaceGeneralSettingFromStore(setting.GetGeneral()),
},
}
}
func convertWorkspaceSettingToStore(setting *apiv2pb.WorkspaceSetting) *storepb.WorkspaceSetting {
settingKeyString, _ := ExtractWorkspaceSettingKeyFromName(setting.Name)
return &storepb.WorkspaceSetting{
Key: storepb.WorkspaceSettingKey(storepb.WorkspaceSettingKey_value[settingKeyString]),
Value: &storepb.WorkspaceSetting_General{
General: convertWorkspaceGeneralSettingToStore(setting.GetGeneralSetting()),
},
}
}
func convertWorkspaceGeneralSettingFromStore(setting *storepb.WorkspaceGeneralSetting) *apiv2pb.WorkspaceGeneralSetting {
if setting == nil {
return nil
}
return &apiv2pb.WorkspaceGeneralSetting{
InstanceUrl: setting.InstanceUrl,
DisallowSignup: setting.DisallowSignup,
DisallowPasswordLogin: setting.DisallowPasswordLogin,
AdditionalScript: setting.AdditionalScript,
AdditionalStyle: setting.AdditionalStyle,
}
}
func convertWorkspaceGeneralSettingToStore(setting *apiv2pb.WorkspaceGeneralSetting) *storepb.WorkspaceGeneralSetting {
if setting == nil {
return nil
}
return &storepb.WorkspaceGeneralSetting{
InstanceUrl: setting.InstanceUrl,
DisallowSignup: setting.DisallowSignup,
DisallowPasswordLogin: setting.DisallowPasswordLogin,
AdditionalScript: setting.AdditionalScript,
AdditionalStyle: setting.AdditionalStyle,
}
}