chore: rename router package

This commit is contained in:
Steven
2024-05-01 10:28:32 +08:00
parent 8ae4bc95dc
commit 20dd3e17f7
32 changed files with 4 additions and 216 deletions

161
server/router/api/v1/acl.go Normal file
View File

@@ -0,0 +1,161 @@
package v1
import (
"context"
"net/http"
"strings"
"github.com/golang-jwt/jwt/v5"
"github.com/pkg/errors"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
"github.com/usememos/memos/internal/util"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/store"
)
// ContextKey is the key type of context value.
type ContextKey int
const (
// The key name used to store username in the context
// user id is extracted from the jwt token subject field.
usernameContextKey ContextKey = iota
)
// GRPCAuthInterceptor is the auth interceptor for gRPC server.
type GRPCAuthInterceptor struct {
Store *store.Store
secret string
}
// NewGRPCAuthInterceptor returns a new API auth interceptor.
func NewGRPCAuthInterceptor(store *store.Store, secret string) *GRPCAuthInterceptor {
return &GRPCAuthInterceptor{
Store: store,
secret: secret,
}
}
// AuthenticationInterceptor is the unary interceptor for gRPC API.
func (in *GRPCAuthInterceptor) AuthenticationInterceptor(ctx context.Context, request any, serverInfo *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
return nil, status.Errorf(codes.Unauthenticated, "failed to parse metadata from incoming context")
}
accessToken, err := getTokenFromMetadata(md)
if err != nil {
return nil, status.Errorf(codes.Unauthenticated, err.Error())
}
username, err := in.authenticate(ctx, accessToken)
if err != nil {
if isUnauthorizeAllowedMethod(serverInfo.FullMethod) {
return handler(ctx, request)
}
return nil, err
}
user, err := in.Store.GetUser(ctx, &store.FindUser{
Username: &username,
})
if err != nil {
return nil, errors.Wrap(err, "failed to get user")
}
if user == nil {
return nil, errors.Errorf("user %q not exists", username)
}
if user.RowStatus == store.Archived {
return nil, errors.Errorf("user %q is archived", username)
}
if isOnlyForAdminAllowedMethod(serverInfo.FullMethod) && user.Role != store.RoleHost && user.Role != store.RoleAdmin {
return nil, errors.Errorf("user %q is not admin", username)
}
// Stores userID into context.
childCtx := context.WithValue(ctx, usernameContextKey, username)
return handler(childCtx, request)
}
func (in *GRPCAuthInterceptor) authenticate(ctx context.Context, accessToken string) (string, error) {
if accessToken == "" {
return "", status.Errorf(codes.Unauthenticated, "access token not found")
}
claims := &ClaimsMessage{}
_, err := jwt.ParseWithClaims(accessToken, claims, func(t *jwt.Token) (any, error) {
if t.Method.Alg() != jwt.SigningMethodHS256.Name {
return nil, status.Errorf(codes.Unauthenticated, "unexpected access token signing method=%v, expect %v", t.Header["alg"], jwt.SigningMethodHS256)
}
if kid, ok := t.Header["kid"].(string); ok {
if kid == "v1" {
return []byte(in.secret), nil
}
}
return nil, status.Errorf(codes.Unauthenticated, "unexpected access token kid=%v", t.Header["kid"])
})
if err != nil {
return "", status.Errorf(codes.Unauthenticated, "Invalid or expired access token")
}
// We either have a valid access token or we will attempt to generate new access token.
userID, err := util.ConvertStringToInt32(claims.Subject)
if err != nil {
return "", errors.Wrap(err, "malformed ID in the token")
}
user, err := in.Store.GetUser(ctx, &store.FindUser{
ID: &userID,
})
if err != nil {
return "", errors.Wrap(err, "failed to get user")
}
if user == nil {
return "", errors.Errorf("user %q not exists", userID)
}
if user.RowStatus == store.Archived {
return "", errors.Errorf("user %q is archived", userID)
}
accessTokens, err := in.Store.GetUserAccessTokens(ctx, user.ID)
if err != nil {
return "", errors.Wrapf(err, "failed to get user access tokens")
}
if !validateAccessToken(accessToken, accessTokens) {
return "", status.Errorf(codes.Unauthenticated, "invalid access token")
}
return user.Username, nil
}
func getTokenFromMetadata(md metadata.MD) (string, error) {
// Check the HTTP request header first.
authorizationHeaders := md.Get("Authorization")
if len(md.Get("Authorization")) > 0 {
authHeaderParts := strings.Fields(authorizationHeaders[0])
if len(authHeaderParts) != 2 || strings.ToLower(authHeaderParts[0]) != "bearer" {
return "", errors.New("authorization header format must be Bearer {token}")
}
return authHeaderParts[1], nil
}
// Check the cookie header.
var accessToken string
for _, t := range append(md.Get("grpcgateway-cookie"), md.Get("cookie")...) {
header := http.Header{}
header.Add("Cookie", t)
request := http.Request{Header: header}
if v, _ := request.Cookie(AccessTokenCookieName); v != nil {
accessToken = v.Value
}
}
return accessToken, nil
}
func validateAccessToken(accessTokenString string, userAccessTokens []*storepb.AccessTokensUserSetting_AccessToken) bool {
for _, userAccessToken := range userAccessTokens {
if accessTokenString == userAccessToken.AccessToken {
return true
}
}
return false
}

View File

@@ -0,0 +1,38 @@
package v1
var authenticationAllowlistMethods = map[string]bool{
"/memos.api.v1.WorkspaceService/GetWorkspaceProfile": true,
"/memos.api.v1.WorkspaceSettingService/GetWorkspaceSetting": true,
"/memos.api.v1.WorkspaceSettingService/ListWorkspaceSettings": true,
"/memos.api.v1.IdentityProviderService/ListIdentityProviders": true,
"/memos.api.v1.IdentityProviderService/GetIdentityProvider": true,
"/memos.api.v1.AuthService/GetAuthStatus": true,
"/memos.api.v1.AuthService/SignIn": true,
"/memos.api.v1.AuthService/SignInWithSSO": true,
"/memos.api.v1.AuthService/SignOut": true,
"/memos.api.v1.AuthService/SignUp": true,
"/memos.api.v1.UserService/GetUser": true,
"/memos.api.v1.UserService/GetUserAvatar": true,
"/memos.api.v1.UserService/SearchUsers": true,
"/memos.api.v1.MemoService/ListMemos": true,
"/memos.api.v1.MemoService/GetMemo": true,
"/memos.api.v1.MemoService/SearchMemos": true,
"/memos.api.v1.MemoService/ListMemoResources": true,
"/memos.api.v1.MemoService/ListMemoRelations": true,
"/memos.api.v1.MemoService/ListMemoComments": true,
"/memos.api.v1.LinkService/GetLinkMetadata": true,
}
// isUnauthorizeAllowedMethod returns whether the method is exempted from authentication.
func isUnauthorizeAllowedMethod(fullMethodName string) bool {
return authenticationAllowlistMethods[fullMethodName]
}
var allowedMethodsOnlyForAdmin = map[string]bool{
"/memos.api.v1.UserService/CreateUser": true,
}
// isOnlyForAdminAllowedMethod returns true if the method is allowed to be called only by admin.
func isOnlyForAdminAllowedMethod(methodName string) bool {
return allowedMethodsOnlyForAdmin[methodName]
}

View File

@@ -0,0 +1,56 @@
package v1
import (
"context"
"time"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/timestamppb"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/store"
)
func (s *APIV1Service) GetActivity(ctx context.Context, request *v1pb.GetActivityRequest) (*v1pb.Activity, error) {
activity, err := s.Store.GetActivity(ctx, &store.FindActivity{
ID: &request.Id,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get activity: %v", err)
}
activityMessage, err := s.convertActivityFromStore(ctx, activity)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to convert activity from store: %v", err)
}
return activityMessage, nil
}
func (*APIV1Service) convertActivityFromStore(_ context.Context, activity *store.Activity) (*v1pb.Activity, error) {
return &v1pb.Activity{
Id: activity.ID,
CreatorId: activity.CreatorID,
Type: activity.Type.String(),
Level: activity.Level.String(),
CreateTime: timestamppb.New(time.Unix(activity.CreatedTs, 0)),
Payload: convertActivityPayloadFromStore(activity.Payload),
}, nil
}
func convertActivityPayloadFromStore(payload *storepb.ActivityPayload) *v1pb.ActivityPayload {
v2Payload := &v1pb.ActivityPayload{}
if payload.MemoComment != nil {
v2Payload.MemoComment = &v1pb.ActivityMemoCommentPayload{
MemoId: payload.MemoComment.MemoId,
RelatedMemoId: payload.MemoComment.RelatedMemoId,
}
}
if payload.VersionUpdate != nil {
v2Payload.VersionUpdate = &v1pb.ActivityVersionUpdatePayload{
Version: payload.VersionUpdate.Version,
}
}
return v2Payload
}

View File

@@ -0,0 +1,63 @@
package v1
import (
"fmt"
"time"
"github.com/golang-jwt/jwt/v5"
)
const (
// issuer is the issuer of the jwt token.
Issuer = "memos"
// Signing key section. For now, this is only used for signing, not for verifying since we only
// have 1 version. But it will be used to maintain backward compatibility if we change the signing mechanism.
KeyID = "v1"
// AccessTokenAudienceName is the audience name of the access token.
AccessTokenAudienceName = "user.access-token"
AccessTokenDuration = 7 * 24 * time.Hour
// CookieExpDuration expires slightly earlier than the jwt expiration. Client would be logged out if the user
// cookie expires, thus the client would always logout first before attempting to make a request with the expired jwt.
CookieExpDuration = AccessTokenDuration - 1*time.Minute
// AccessTokenCookieName is the cookie name of access token.
AccessTokenCookieName = "memos.access-token"
)
type ClaimsMessage struct {
Name string `json:"name"`
jwt.RegisteredClaims
}
// GenerateAccessToken generates an access token.
func GenerateAccessToken(username string, userID int32, expirationTime time.Time, secret []byte) (string, error) {
return generateToken(username, userID, AccessTokenAudienceName, expirationTime, secret)
}
// generateToken generates a jwt token.
func generateToken(username string, userID int32, audience string, expirationTime time.Time, secret []byte) (string, error) {
registeredClaims := jwt.RegisteredClaims{
Issuer: Issuer,
Audience: jwt.ClaimStrings{audience},
IssuedAt: jwt.NewNumericDate(time.Now()),
Subject: fmt.Sprint(userID),
}
if !expirationTime.IsZero() {
registeredClaims.ExpiresAt = jwt.NewNumericDate(expirationTime)
}
// Declare the token with the HS256 algorithm used for signing, and the claims.
token := jwt.NewWithClaims(jwt.SigningMethodHS256, &ClaimsMessage{
Name: username,
RegisteredClaims: registeredClaims,
})
token.Header["kid"] = KeyID
// Create the JWT string.
tokenString, err := token.SignedString(secret)
if err != nil {
return "", err
}
return tokenString, nil
}

View File

@@ -0,0 +1,264 @@
package v1
import (
"context"
"fmt"
"regexp"
"strings"
"time"
"github.com/pkg/errors"
"golang.org/x/crypto/bcrypt"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/emptypb"
"github.com/usememos/memos/internal/util"
"github.com/usememos/memos/plugin/idp"
"github.com/usememos/memos/plugin/idp/oauth2"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/store"
)
func (s *APIV1Service) GetAuthStatus(ctx context.Context, _ *v1pb.GetAuthStatusRequest) (*v1pb.User, error) {
user, err := getCurrentUser(ctx, s.Store)
if err != nil {
return nil, status.Errorf(codes.Unauthenticated, "failed to get current user: %v", err)
}
if user == nil {
// Set the cookie header to expire access token.
if err := s.clearAccessTokenCookie(ctx); err != nil {
return nil, status.Errorf(codes.Internal, "failed to set grpc header: %v", err)
}
return nil, status.Errorf(codes.Unauthenticated, "user not found")
}
return convertUserFromStore(user), nil
}
func (s *APIV1Service) SignIn(ctx context.Context, request *v1pb.SignInRequest) (*v1pb.User, error) {
user, err := s.Store.GetUser(ctx, &store.FindUser{
Username: &request.Username,
})
if err != nil {
return nil, status.Errorf(codes.Internal, fmt.Sprintf("failed to find user by username %s", request.Username))
}
if user == nil {
return nil, status.Errorf(codes.InvalidArgument, fmt.Sprintf("user not found with username %s", request.Username))
} else if user.RowStatus == store.Archived {
return nil, status.Errorf(codes.PermissionDenied, fmt.Sprintf("user has been archived with username %s", request.Username))
}
// Compare the stored hashed password, with the hashed version of the password that was received.
if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(request.Password)); err != nil {
return nil, status.Errorf(codes.InvalidArgument, "unmatched email and password")
}
expireTime := time.Now().Add(AccessTokenDuration)
if request.NeverExpire {
// Set the expire time to 100 years.
expireTime = time.Now().Add(100 * 365 * 24 * time.Hour)
}
if err := s.doSignIn(ctx, user, expireTime); err != nil {
return nil, status.Errorf(codes.Internal, fmt.Sprintf("failed to sign in, err: %s", err))
}
return convertUserFromStore(user), nil
}
func (s *APIV1Service) SignInWithSSO(ctx context.Context, request *v1pb.SignInWithSSORequest) (*v1pb.User, error) {
identityProvider, err := s.Store.GetIdentityProvider(ctx, &store.FindIdentityProvider{
ID: &request.IdpId,
})
if err != nil {
return nil, status.Errorf(codes.Internal, fmt.Sprintf("failed to get identity provider, err: %s", err))
}
if identityProvider == nil {
return nil, status.Errorf(codes.InvalidArgument, fmt.Sprintf("identity provider not found with id %d", request.IdpId))
}
var userInfo *idp.IdentityProviderUserInfo
if identityProvider.Type == storepb.IdentityProvider_OAUTH2 {
oauth2IdentityProvider, err := oauth2.NewIdentityProvider(identityProvider.Config.GetOauth2Config())
if err != nil {
return nil, status.Errorf(codes.Internal, fmt.Sprintf("failed to create oauth2 identity provider, err: %s", err))
}
token, err := oauth2IdentityProvider.ExchangeToken(ctx, request.RedirectUri, request.Code)
if err != nil {
return nil, status.Errorf(codes.Internal, fmt.Sprintf("failed to exchange token, err: %s", err))
}
userInfo, err = oauth2IdentityProvider.UserInfo(token)
if err != nil {
return nil, status.Errorf(codes.Internal, fmt.Sprintf("failed to get user info, err: %s", err))
}
}
identifierFilter := identityProvider.IdentifierFilter
if identifierFilter != "" {
identifierFilterRegex, err := regexp.Compile(identifierFilter)
if err != nil {
return nil, status.Errorf(codes.Internal, fmt.Sprintf("failed to compile identifier filter regex, err: %s", err))
}
if !identifierFilterRegex.MatchString(userInfo.Identifier) {
return nil, status.Errorf(codes.PermissionDenied, fmt.Sprintf("identifier %s is not allowed", userInfo.Identifier))
}
}
user, err := s.Store.GetUser(ctx, &store.FindUser{
Username: &userInfo.Identifier,
})
if err != nil {
return nil, status.Errorf(codes.Internal, fmt.Sprintf("failed to find user by username %s", userInfo.Identifier))
}
if user == nil {
userCreate := &store.User{
Username: userInfo.Identifier,
// The new signup user should be normal user by default.
Role: store.RoleUser,
Nickname: userInfo.DisplayName,
Email: userInfo.Email,
}
password, err := util.RandomString(20)
if err != nil {
return nil, status.Errorf(codes.Internal, fmt.Sprintf("failed to generate random password, err: %s", err))
}
passwordHash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return nil, status.Errorf(codes.Internal, fmt.Sprintf("failed to generate password hash, err: %s", err))
}
userCreate.PasswordHash = string(passwordHash)
user, err = s.Store.CreateUser(ctx, userCreate)
if err != nil {
return nil, status.Errorf(codes.Internal, fmt.Sprintf("failed to create user, err: %s", err))
}
}
if user.RowStatus == store.Archived {
return nil, status.Errorf(codes.PermissionDenied, fmt.Sprintf("user has been archived with username %s", userInfo.Identifier))
}
if err := s.doSignIn(ctx, user, time.Now().Add(AccessTokenDuration)); err != nil {
return nil, status.Errorf(codes.Internal, fmt.Sprintf("failed to sign in, err: %s", err))
}
return convertUserFromStore(user), nil
}
func (s *APIV1Service) doSignIn(ctx context.Context, user *store.User, expireTime time.Time) error {
accessToken, err := GenerateAccessToken(user.Email, user.ID, expireTime, []byte(s.Secret))
if err != nil {
return status.Errorf(codes.Internal, fmt.Sprintf("failed to generate tokens, err: %s", err))
}
if err := s.UpsertAccessTokenToStore(ctx, user, accessToken, "user login"); err != nil {
return status.Errorf(codes.Internal, fmt.Sprintf("failed to upsert access token to store, err: %s", err))
}
cookie, err := s.buildAccessTokenCookie(ctx, accessToken, expireTime)
if err != nil {
return status.Errorf(codes.Internal, fmt.Sprintf("failed to build access token cookie, err: %s", err))
}
if err := grpc.SetHeader(ctx, metadata.New(map[string]string{
"Set-Cookie": cookie,
})); err != nil {
return status.Errorf(codes.Internal, "failed to set grpc header, error: %v", err)
}
return nil
}
func (s *APIV1Service) SignUp(ctx context.Context, request *v1pb.SignUpRequest) (*v1pb.User, error) {
workspaceGeneralSetting, err := s.Store.GetWorkspaceGeneralSetting(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, fmt.Sprintf("failed to get workspace setting, err: %s", err))
}
if workspaceGeneralSetting.DisallowSignup || workspaceGeneralSetting.DisallowPasswordLogin {
return nil, status.Errorf(codes.PermissionDenied, "sign up is not allowed")
}
passwordHash, err := bcrypt.GenerateFromPassword([]byte(request.Password), bcrypt.DefaultCost)
if err != nil {
return nil, status.Errorf(codes.Internal, fmt.Sprintf("failed to generate password hash, err: %s", err))
}
create := &store.User{
Username: request.Username,
Nickname: request.Username,
PasswordHash: string(passwordHash),
}
if !util.UIDMatcher.MatchString(strings.ToLower(create.Username)) {
return nil, status.Errorf(codes.InvalidArgument, "invalid username: %s", create.Username)
}
hostUserType := store.RoleHost
existedHostUsers, err := s.Store.ListUsers(ctx, &store.FindUser{
Role: &hostUserType,
})
if err != nil {
return nil, status.Errorf(codes.Internal, fmt.Sprintf("failed to list users, err: %s", err))
}
if len(existedHostUsers) == 0 {
// Change the default role to host if there is no host user.
create.Role = store.RoleHost
} else {
create.Role = store.RoleUser
}
user, err := s.Store.CreateUser(ctx, create)
if err != nil {
return nil, status.Errorf(codes.Internal, fmt.Sprintf("failed to create user, err: %s", err))
}
if err := s.doSignIn(ctx, user, time.Now().Add(AccessTokenDuration)); err != nil {
return nil, status.Errorf(codes.Internal, fmt.Sprintf("failed to sign in, err: %s", err))
}
return convertUserFromStore(user), nil
}
func (s *APIV1Service) SignOut(ctx context.Context, _ *v1pb.SignOutRequest) (*emptypb.Empty, error) {
if err := s.clearAccessTokenCookie(ctx); err != nil {
return nil, status.Errorf(codes.Internal, "failed to set grpc header, error: %v", err)
}
return &emptypb.Empty{}, nil
}
func (s *APIV1Service) clearAccessTokenCookie(ctx context.Context) error {
cookie, err := s.buildAccessTokenCookie(ctx, "", time.Time{})
if err != nil {
return errors.Wrap(err, "failed to build access token cookie")
}
if err := grpc.SetHeader(ctx, metadata.New(map[string]string{
"Set-Cookie": cookie,
})); err != nil {
return errors.Wrap(err, "failed to set grpc header")
}
return nil
}
func (*APIV1Service) buildAccessTokenCookie(ctx context.Context, accessToken string, expireTime time.Time) (string, error) {
attrs := []string{
fmt.Sprintf("%s=%s", AccessTokenCookieName, accessToken),
"Path=/",
"HttpOnly",
}
if expireTime.IsZero() {
attrs = append(attrs, "Expires=Thu, 01 Jan 1970 00:00:00 GMT")
} else {
attrs = append(attrs, "Expires="+expireTime.Format(time.RFC1123))
}
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
return "", errors.New("failed to get metadata from context")
}
var origin string
for _, v := range md.Get("origin") {
origin = v
}
isHTTPS := strings.HasPrefix(origin, "https://")
if isHTTPS {
attrs = append(attrs, "SameSite=None")
attrs = append(attrs, "Secure")
} else {
attrs = append(attrs, "SameSite=Strict")
}
return strings.Join(attrs, "; "), nil
}

View File

@@ -0,0 +1,74 @@
package v1
import (
"context"
"encoding/base64"
"github.com/pkg/errors"
"google.golang.org/protobuf/proto"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
"github.com/usememos/memos/store"
)
func convertRowStatusFromStore(rowStatus store.RowStatus) v1pb.RowStatus {
switch rowStatus {
case store.Normal:
return v1pb.RowStatus_ACTIVE
case store.Archived:
return v1pb.RowStatus_ARCHIVED
default:
return v1pb.RowStatus_ROW_STATUS_UNSPECIFIED
}
}
func convertRowStatusToStore(rowStatus v1pb.RowStatus) store.RowStatus {
switch rowStatus {
case v1pb.RowStatus_ACTIVE:
return store.Normal
case v1pb.RowStatus_ARCHIVED:
return store.Archived
default:
return store.Normal
}
}
func getCurrentUser(ctx context.Context, s *store.Store) (*store.User, error) {
username, ok := ctx.Value(usernameContextKey).(string)
if !ok {
return nil, nil
}
user, err := s.GetUser(ctx, &store.FindUser{
Username: &username,
})
if err != nil {
return nil, err
}
return user, nil
}
func getPageToken(limit int, offset int) (string, error) {
return marshalPageToken(&v1pb.PageToken{
Limit: int32(limit),
Offset: int32(offset),
})
}
func marshalPageToken(pageToken *v1pb.PageToken) (string, error) {
b, err := proto.Marshal(pageToken)
if err != nil {
return "", errors.Wrapf(err, "failed to marshal page token")
}
return base64.StdEncoding.EncodeToString(b), nil
}
func unmarshalPageToken(s string, pageToken *v1pb.PageToken) error {
b, err := base64.StdEncoding.DecodeString(s)
if err != nil {
return errors.Wrapf(err, "failed to decode page token")
}
if err := proto.Unmarshal(b, pageToken); err != nil {
return errors.Wrapf(err, "failed to unmarshal page token")
}
return nil
}

View File

@@ -0,0 +1,169 @@
package v1
import (
"context"
"fmt"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/emptypb"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/store"
)
func (s *APIV1Service) CreateIdentityProvider(ctx context.Context, request *v1pb.CreateIdentityProviderRequest) (*v1pb.IdentityProvider, error) {
currentUser, err := getCurrentUser(ctx, s.Store)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
}
if currentUser.Role != store.RoleHost {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
identityProvider, err := s.Store.CreateIdentityProvider(ctx, convertIdentityProviderToStore(request.IdentityProvider))
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to create identity provider, error: %+v", err)
}
return convertIdentityProviderFromStore(identityProvider), nil
}
func (s *APIV1Service) ListIdentityProviders(ctx context.Context, _ *v1pb.ListIdentityProvidersRequest) (*v1pb.ListIdentityProvidersResponse, error) {
identityProviders, err := s.Store.ListIdentityProviders(ctx, &store.FindIdentityProvider{})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list identity providers, error: %+v", err)
}
response := &v1pb.ListIdentityProvidersResponse{
IdentityProviders: []*v1pb.IdentityProvider{},
}
for _, identityProvider := range identityProviders {
response.IdentityProviders = append(response.IdentityProviders, convertIdentityProviderFromStore(identityProvider))
}
return response, nil
}
func (s *APIV1Service) GetIdentityProvider(ctx context.Context, request *v1pb.GetIdentityProviderRequest) (*v1pb.IdentityProvider, error) {
id, err := ExtractIdentityProviderIDFromName(request.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid identity provider name: %v", err)
}
identityProvider, err := s.Store.GetIdentityProvider(ctx, &store.FindIdentityProvider{
ID: &id,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get identity provider, error: %+v", err)
}
if identityProvider == nil {
return nil, status.Errorf(codes.NotFound, "identity provider not found")
}
return convertIdentityProviderFromStore(identityProvider), nil
}
func (s *APIV1Service) UpdateIdentityProvider(ctx context.Context, request *v1pb.UpdateIdentityProviderRequest) (*v1pb.IdentityProvider, error) {
if request.UpdateMask == nil || len(request.UpdateMask.Paths) == 0 {
return nil, status.Errorf(codes.InvalidArgument, "update_mask is required")
}
id, err := ExtractIdentityProviderIDFromName(request.IdentityProvider.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid identity provider name: %v", err)
}
update := &store.UpdateIdentityProviderV1{
ID: id,
Type: storepb.IdentityProvider_Type(storepb.IdentityProvider_Type_value[request.IdentityProvider.Type.String()]),
}
for _, field := range request.UpdateMask.Paths {
switch field {
case "title":
update.Name = &request.IdentityProvider.Title
case "config":
update.Config = convertIdentityProviderConfigToStore(request.IdentityProvider.Type, request.IdentityProvider.Config)
}
}
identityProvider, err := s.Store.UpdateIdentityProvider(ctx, update)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to update identity provider, error: %+v", err)
}
return convertIdentityProviderFromStore(identityProvider), nil
}
func (s *APIV1Service) DeleteIdentityProvider(ctx context.Context, request *v1pb.DeleteIdentityProviderRequest) (*emptypb.Empty, error) {
id, err := ExtractIdentityProviderIDFromName(request.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid identity provider name: %v", err)
}
if err := s.Store.DeleteIdentityProvider(ctx, &store.DeleteIdentityProvider{ID: id}); err != nil {
return nil, status.Errorf(codes.Internal, "failed to delete identity provider, error: %+v", err)
}
return &emptypb.Empty{}, nil
}
func convertIdentityProviderFromStore(identityProvider *storepb.IdentityProvider) *v1pb.IdentityProvider {
temp := &v1pb.IdentityProvider{
Name: fmt.Sprintf("%s%d", IdentityProviderNamePrefix, identityProvider.Id),
Title: identityProvider.Name,
IdentifierFilter: identityProvider.IdentifierFilter,
Type: v1pb.IdentityProvider_Type(v1pb.IdentityProvider_Type_value[identityProvider.Type.String()]),
}
if identityProvider.Type == storepb.IdentityProvider_OAUTH2 {
oauth2Config := identityProvider.Config.GetOauth2Config()
temp.Config = &v1pb.IdentityProviderConfig{
Config: &v1pb.IdentityProviderConfig_Oauth2Config{
Oauth2Config: &v1pb.OAuth2Config{
ClientId: oauth2Config.ClientId,
ClientSecret: oauth2Config.ClientSecret,
AuthUrl: oauth2Config.AuthUrl,
TokenUrl: oauth2Config.TokenUrl,
UserInfoUrl: oauth2Config.UserInfoUrl,
Scopes: oauth2Config.Scopes,
FieldMapping: &v1pb.FieldMapping{
Identifier: oauth2Config.FieldMapping.Identifier,
DisplayName: oauth2Config.FieldMapping.DisplayName,
Email: oauth2Config.FieldMapping.Email,
},
},
},
}
}
return temp
}
func convertIdentityProviderToStore(identityProvider *v1pb.IdentityProvider) *storepb.IdentityProvider {
id, _ := ExtractIdentityProviderIDFromName(identityProvider.Name)
temp := &storepb.IdentityProvider{
Id: id,
Name: identityProvider.Title,
IdentifierFilter: identityProvider.IdentifierFilter,
Type: storepb.IdentityProvider_Type(storepb.IdentityProvider_Type_value[identityProvider.Type.String()]),
Config: convertIdentityProviderConfigToStore(identityProvider.Type, identityProvider.Config),
}
return temp
}
func convertIdentityProviderConfigToStore(identityProviderType v1pb.IdentityProvider_Type, config *v1pb.IdentityProviderConfig) *storepb.IdentityProviderConfig {
if identityProviderType == v1pb.IdentityProvider_OAUTH2 {
oauth2Config := config.GetOauth2Config()
return &storepb.IdentityProviderConfig{
Config: &storepb.IdentityProviderConfig_Oauth2Config{
Oauth2Config: &storepb.OAuth2Config{
ClientId: oauth2Config.ClientId,
ClientSecret: oauth2Config.ClientSecret,
AuthUrl: oauth2Config.AuthUrl,
TokenUrl: oauth2Config.TokenUrl,
UserInfoUrl: oauth2Config.UserInfoUrl,
Scopes: oauth2Config.Scopes,
FieldMapping: &storepb.FieldMapping{
Identifier: oauth2Config.FieldMapping.Identifier,
DisplayName: oauth2Config.FieldMapping.DisplayName,
Email: oauth2Config.FieldMapping.Email,
},
},
},
}
}
return nil
}

View File

@@ -0,0 +1,115 @@
package v1
import (
"context"
"fmt"
"time"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/emptypb"
"google.golang.org/protobuf/types/known/timestamppb"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
"github.com/usememos/memos/store"
)
func (s *APIV1Service) ListInboxes(ctx context.Context, _ *v1pb.ListInboxesRequest) (*v1pb.ListInboxesResponse, error) {
user, err := getCurrentUser(ctx, s.Store)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user")
}
inboxes, err := s.Store.ListInboxes(ctx, &store.FindInbox{
ReceiverID: &user.ID,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list inbox: %v", err)
}
response := &v1pb.ListInboxesResponse{
Inboxes: []*v1pb.Inbox{},
}
for _, inbox := range inboxes {
response.Inboxes = append(response.Inboxes, convertInboxFromStore(inbox))
}
return response, nil
}
func (s *APIV1Service) UpdateInbox(ctx context.Context, request *v1pb.UpdateInboxRequest) (*v1pb.Inbox, error) {
if request.UpdateMask == nil || len(request.UpdateMask.Paths) == 0 {
return nil, status.Errorf(codes.InvalidArgument, "update mask is required")
}
inboxID, err := ExtractInboxIDFromName(request.Inbox.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid inbox name: %v", err)
}
update := &store.UpdateInbox{
ID: inboxID,
}
for _, field := range request.UpdateMask.Paths {
if field == "status" {
if request.Inbox.Status == v1pb.Inbox_STATUS_UNSPECIFIED {
return nil, status.Errorf(codes.InvalidArgument, "status is required")
}
update.Status = convertInboxStatusToStore(request.Inbox.Status)
}
}
inbox, err := s.Store.UpdateInbox(ctx, update)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to update inbox: %v", err)
}
return convertInboxFromStore(inbox), nil
}
func (s *APIV1Service) DeleteInbox(ctx context.Context, request *v1pb.DeleteInboxRequest) (*emptypb.Empty, error) {
inboxID, err := ExtractInboxIDFromName(request.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid inbox name: %v", err)
}
if err := s.Store.DeleteInbox(ctx, &store.DeleteInbox{
ID: inboxID,
}); err != nil {
return nil, status.Errorf(codes.Internal, "failed to update inbox: %v", err)
}
return &emptypb.Empty{}, nil
}
func convertInboxFromStore(inbox *store.Inbox) *v1pb.Inbox {
return &v1pb.Inbox{
Name: fmt.Sprintf("%s%d", InboxNamePrefix, inbox.ID),
Sender: fmt.Sprintf("%s%d", UserNamePrefix, inbox.SenderID),
Receiver: fmt.Sprintf("%s%d", UserNamePrefix, inbox.ReceiverID),
Status: convertInboxStatusFromStore(inbox.Status),
CreateTime: timestamppb.New(time.Unix(inbox.CreatedTs, 0)),
Type: v1pb.Inbox_Type(inbox.Message.Type),
ActivityId: inbox.Message.ActivityId,
}
}
func convertInboxStatusFromStore(status store.InboxStatus) v1pb.Inbox_Status {
switch status {
case store.UNREAD:
return v1pb.Inbox_UNREAD
case store.ARCHIVED:
return v1pb.Inbox_ARCHIVED
default:
return v1pb.Inbox_STATUS_UNSPECIFIED
}
}
func convertInboxStatusToStore(status v1pb.Inbox_Status) store.InboxStatus {
switch status {
case v1pb.Inbox_UNREAD:
return store.UNREAD
case v1pb.Inbox_ARCHIVED:
return store.ARCHIVED
default:
return store.UNREAD
}
}

View File

@@ -0,0 +1,48 @@
package v1
import (
"context"
"log/slog"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
type LoggerInterceptor struct {
}
func NewLoggerInterceptor() *LoggerInterceptor {
return &LoggerInterceptor{}
}
func (in *LoggerInterceptor) LoggerInterceptor(ctx context.Context, request any, serverInfo *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
resp, err := handler(ctx, request)
in.loggerInterceptorDo(ctx, serverInfo.FullMethod, err)
return resp, err
}
func (*LoggerInterceptor) loggerInterceptorDo(ctx context.Context, fullMethod string, err error) {
st := status.Convert(err)
var logLevel slog.Level
var logMsg string
switch st.Code() {
case codes.OK:
logLevel = slog.LevelInfo
logMsg = "OK"
case codes.Unauthenticated, codes.OutOfRange, codes.PermissionDenied, codes.NotFound:
logLevel = slog.LevelInfo
logMsg = "client error"
case codes.Internal, codes.Unknown, codes.DataLoss, codes.Unavailable, codes.DeadlineExceeded:
logLevel = slog.LevelError
logMsg = "server error"
default:
logLevel = slog.LevelError
logMsg = "unknown error"
}
logAttrs := []slog.Attr{slog.String("method", fullMethod)}
if err != nil {
logAttrs = append(logAttrs, slog.String("error", err.Error()))
}
slog.LogAttrs(ctx, logLevel, logMsg, logAttrs...)
}

View File

@@ -0,0 +1,235 @@
package v1
import (
"context"
"github.com/pkg/errors"
"github.com/yourselfhosted/gomark/ast"
"github.com/yourselfhosted/gomark/parser"
"github.com/yourselfhosted/gomark/parser/tokenizer"
"github.com/yourselfhosted/gomark/restore"
getter "github.com/usememos/memos/plugin/http-getter"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
)
func (*APIV1Service) ParseMarkdown(_ context.Context, request *v1pb.ParseMarkdownRequest) (*v1pb.ParseMarkdownResponse, error) {
rawNodes, err := parser.Parse(tokenizer.Tokenize(request.Markdown))
if err != nil {
return nil, errors.Wrap(err, "failed to parse memo content")
}
nodes := convertFromASTNodes(rawNodes)
return &v1pb.ParseMarkdownResponse{
Nodes: nodes,
}, nil
}
func (*APIV1Service) RestoreMarkdown(_ context.Context, request *v1pb.RestoreMarkdownRequest) (*v1pb.RestoreMarkdownResponse, error) {
markdown := restore.Restore(convertToASTNodes(request.Nodes))
return &v1pb.RestoreMarkdownResponse{
Markdown: markdown,
}, nil
}
func (*APIV1Service) GetLinkMetadata(_ context.Context, request *v1pb.GetLinkMetadataRequest) (*v1pb.LinkMetadata, error) {
htmlMeta, err := getter.GetHTMLMeta(request.Link)
if err != nil {
return nil, err
}
return &v1pb.LinkMetadata{
Title: htmlMeta.Title,
Description: htmlMeta.Description,
Image: htmlMeta.Image,
}, nil
}
func convertFromASTNode(rawNode ast.Node) *v1pb.Node {
node := &v1pb.Node{
Type: v1pb.NodeType(v1pb.NodeType_value[string(rawNode.Type())]),
}
switch n := rawNode.(type) {
case *ast.LineBreak:
node.Node = &v1pb.Node_LineBreakNode{}
case *ast.Paragraph:
children := convertFromASTNodes(n.Children)
node.Node = &v1pb.Node_ParagraphNode{ParagraphNode: &v1pb.ParagraphNode{Children: children}}
case *ast.CodeBlock:
node.Node = &v1pb.Node_CodeBlockNode{CodeBlockNode: &v1pb.CodeBlockNode{Language: n.Language, Content: n.Content}}
case *ast.Heading:
children := convertFromASTNodes(n.Children)
node.Node = &v1pb.Node_HeadingNode{HeadingNode: &v1pb.HeadingNode{Level: int32(n.Level), Children: children}}
case *ast.HorizontalRule:
node.Node = &v1pb.Node_HorizontalRuleNode{HorizontalRuleNode: &v1pb.HorizontalRuleNode{Symbol: n.Symbol}}
case *ast.Blockquote:
children := convertFromASTNodes(n.Children)
node.Node = &v1pb.Node_BlockquoteNode{BlockquoteNode: &v1pb.BlockquoteNode{Children: children}}
case *ast.OrderedList:
children := convertFromASTNodes(n.Children)
node.Node = &v1pb.Node_OrderedListNode{OrderedListNode: &v1pb.OrderedListNode{Number: n.Number, Indent: int32(n.Indent), Children: children}}
case *ast.UnorderedList:
children := convertFromASTNodes(n.Children)
node.Node = &v1pb.Node_UnorderedListNode{UnorderedListNode: &v1pb.UnorderedListNode{Symbol: n.Symbol, Indent: int32(n.Indent), Children: children}}
case *ast.TaskList:
children := convertFromASTNodes(n.Children)
node.Node = &v1pb.Node_TaskListNode{TaskListNode: &v1pb.TaskListNode{Symbol: n.Symbol, Indent: int32(n.Indent), Complete: n.Complete, Children: children}}
case *ast.MathBlock:
node.Node = &v1pb.Node_MathBlockNode{MathBlockNode: &v1pb.MathBlockNode{Content: n.Content}}
case *ast.Table:
node.Node = &v1pb.Node_TableNode{TableNode: convertTableFromASTNode(n)}
case *ast.EmbeddedContent:
node.Node = &v1pb.Node_EmbeddedContentNode{EmbeddedContentNode: &v1pb.EmbeddedContentNode{ResourceName: n.ResourceName, Params: n.Params}}
case *ast.Text:
node.Node = &v1pb.Node_TextNode{TextNode: &v1pb.TextNode{Content: n.Content}}
case *ast.Bold:
children := convertFromASTNodes(n.Children)
node.Node = &v1pb.Node_BoldNode{BoldNode: &v1pb.BoldNode{Symbol: n.Symbol, Children: children}}
case *ast.Italic:
node.Node = &v1pb.Node_ItalicNode{ItalicNode: &v1pb.ItalicNode{Symbol: n.Symbol, Content: n.Content}}
case *ast.BoldItalic:
node.Node = &v1pb.Node_BoldItalicNode{BoldItalicNode: &v1pb.BoldItalicNode{Symbol: n.Symbol, Content: n.Content}}
case *ast.Code:
node.Node = &v1pb.Node_CodeNode{CodeNode: &v1pb.CodeNode{Content: n.Content}}
case *ast.Image:
node.Node = &v1pb.Node_ImageNode{ImageNode: &v1pb.ImageNode{AltText: n.AltText, Url: n.URL}}
case *ast.Link:
node.Node = &v1pb.Node_LinkNode{LinkNode: &v1pb.LinkNode{Text: n.Text, Url: n.URL}}
case *ast.AutoLink:
node.Node = &v1pb.Node_AutoLinkNode{AutoLinkNode: &v1pb.AutoLinkNode{Url: n.URL, IsRawText: n.IsRawText}}
case *ast.Tag:
node.Node = &v1pb.Node_TagNode{TagNode: &v1pb.TagNode{Content: n.Content}}
case *ast.Strikethrough:
node.Node = &v1pb.Node_StrikethroughNode{StrikethroughNode: &v1pb.StrikethroughNode{Content: n.Content}}
case *ast.EscapingCharacter:
node.Node = &v1pb.Node_EscapingCharacterNode{EscapingCharacterNode: &v1pb.EscapingCharacterNode{Symbol: n.Symbol}}
case *ast.Math:
node.Node = &v1pb.Node_MathNode{MathNode: &v1pb.MathNode{Content: n.Content}}
case *ast.Highlight:
node.Node = &v1pb.Node_HighlightNode{HighlightNode: &v1pb.HighlightNode{Content: n.Content}}
case *ast.Subscript:
node.Node = &v1pb.Node_SubscriptNode{SubscriptNode: &v1pb.SubscriptNode{Content: n.Content}}
case *ast.Superscript:
node.Node = &v1pb.Node_SuperscriptNode{SuperscriptNode: &v1pb.SuperscriptNode{Content: n.Content}}
case *ast.ReferencedContent:
node.Node = &v1pb.Node_ReferencedContentNode{ReferencedContentNode: &v1pb.ReferencedContentNode{ResourceName: n.ResourceName, Params: n.Params}}
case *ast.Spoiler:
node.Node = &v1pb.Node_SpoilerNode{SpoilerNode: &v1pb.SpoilerNode{Content: n.Content}}
default:
node.Node = &v1pb.Node_TextNode{TextNode: &v1pb.TextNode{}}
}
return node
}
func convertFromASTNodes(rawNodes []ast.Node) []*v1pb.Node {
nodes := []*v1pb.Node{}
for _, rawNode := range rawNodes {
node := convertFromASTNode(rawNode)
nodes = append(nodes, node)
}
return nodes
}
func convertTableFromASTNode(node *ast.Table) *v1pb.TableNode {
table := &v1pb.TableNode{
Header: node.Header,
Delimiter: node.Delimiter,
}
for _, row := range node.Rows {
table.Rows = append(table.Rows, &v1pb.TableNode_Row{Cells: row})
}
return table
}
func convertToASTNode(node *v1pb.Node) ast.Node {
switch n := node.Node.(type) {
case *v1pb.Node_LineBreakNode:
return &ast.LineBreak{}
case *v1pb.Node_ParagraphNode:
children := convertToASTNodes(n.ParagraphNode.Children)
return &ast.Paragraph{Children: children}
case *v1pb.Node_CodeBlockNode:
return &ast.CodeBlock{Language: n.CodeBlockNode.Language, Content: n.CodeBlockNode.Content}
case *v1pb.Node_HeadingNode:
children := convertToASTNodes(n.HeadingNode.Children)
return &ast.Heading{Level: int(n.HeadingNode.Level), Children: children}
case *v1pb.Node_HorizontalRuleNode:
return &ast.HorizontalRule{Symbol: n.HorizontalRuleNode.Symbol}
case *v1pb.Node_BlockquoteNode:
children := convertToASTNodes(n.BlockquoteNode.Children)
return &ast.Blockquote{Children: children}
case *v1pb.Node_OrderedListNode:
children := convertToASTNodes(n.OrderedListNode.Children)
return &ast.OrderedList{Number: n.OrderedListNode.Number, Indent: int(n.OrderedListNode.Indent), Children: children}
case *v1pb.Node_UnorderedListNode:
children := convertToASTNodes(n.UnorderedListNode.Children)
return &ast.UnorderedList{Symbol: n.UnorderedListNode.Symbol, Indent: int(n.UnorderedListNode.Indent), Children: children}
case *v1pb.Node_TaskListNode:
children := convertToASTNodes(n.TaskListNode.Children)
return &ast.TaskList{Symbol: n.TaskListNode.Symbol, Indent: int(n.TaskListNode.Indent), Complete: n.TaskListNode.Complete, Children: children}
case *v1pb.Node_MathBlockNode:
return &ast.MathBlock{Content: n.MathBlockNode.Content}
case *v1pb.Node_TableNode:
return convertTableToASTNode(n.TableNode)
case *v1pb.Node_EmbeddedContentNode:
return &ast.EmbeddedContent{ResourceName: n.EmbeddedContentNode.ResourceName, Params: n.EmbeddedContentNode.Params}
case *v1pb.Node_TextNode:
return &ast.Text{Content: n.TextNode.Content}
case *v1pb.Node_BoldNode:
children := convertToASTNodes(n.BoldNode.Children)
return &ast.Bold{Symbol: n.BoldNode.Symbol, Children: children}
case *v1pb.Node_ItalicNode:
return &ast.Italic{Symbol: n.ItalicNode.Symbol, Content: n.ItalicNode.Content}
case *v1pb.Node_BoldItalicNode:
return &ast.BoldItalic{Symbol: n.BoldItalicNode.Symbol, Content: n.BoldItalicNode.Content}
case *v1pb.Node_CodeNode:
return &ast.Code{Content: n.CodeNode.Content}
case *v1pb.Node_ImageNode:
return &ast.Image{AltText: n.ImageNode.AltText, URL: n.ImageNode.Url}
case *v1pb.Node_LinkNode:
return &ast.Link{Text: n.LinkNode.Text, URL: n.LinkNode.Url}
case *v1pb.Node_AutoLinkNode:
return &ast.AutoLink{URL: n.AutoLinkNode.Url, IsRawText: n.AutoLinkNode.IsRawText}
case *v1pb.Node_TagNode:
return &ast.Tag{Content: n.TagNode.Content}
case *v1pb.Node_StrikethroughNode:
return &ast.Strikethrough{Content: n.StrikethroughNode.Content}
case *v1pb.Node_EscapingCharacterNode:
return &ast.EscapingCharacter{Symbol: n.EscapingCharacterNode.Symbol}
case *v1pb.Node_MathNode:
return &ast.Math{Content: n.MathNode.Content}
case *v1pb.Node_HighlightNode:
return &ast.Highlight{Content: n.HighlightNode.Content}
case *v1pb.Node_SubscriptNode:
return &ast.Subscript{Content: n.SubscriptNode.Content}
case *v1pb.Node_SuperscriptNode:
return &ast.Superscript{Content: n.SuperscriptNode.Content}
case *v1pb.Node_ReferencedContentNode:
return &ast.ReferencedContent{ResourceName: n.ReferencedContentNode.ResourceName, Params: n.ReferencedContentNode.Params}
case *v1pb.Node_SpoilerNode:
return &ast.Spoiler{Content: n.SpoilerNode.Content}
default:
return &ast.Text{}
}
}
func convertToASTNodes(nodes []*v1pb.Node) []ast.Node {
rawNodes := []ast.Node{}
for _, node := range nodes {
rawNode := convertToASTNode(node)
rawNodes = append(rawNodes, rawNode)
}
return rawNodes
}
func convertTableToASTNode(node *v1pb.TableNode) *ast.Table {
table := &ast.Table{
Header: node.Header,
Delimiter: node.Delimiter,
}
for _, row := range node.Rows {
table.Rows = append(table.Rows, row.Cells)
}
return table
}

View File

@@ -0,0 +1,114 @@
package v1
import (
"context"
"fmt"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/emptypb"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
"github.com/usememos/memos/store"
)
func (s *APIV1Service) SetMemoRelations(ctx context.Context, request *v1pb.SetMemoRelationsRequest) (*emptypb.Empty, error) {
id, err := ExtractMemoIDFromName(request.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err)
}
referenceType := store.MemoRelationReference
// Delete all reference relations first.
if err := s.Store.DeleteMemoRelation(ctx, &store.DeleteMemoRelation{
MemoID: &id,
Type: &referenceType,
}); err != nil {
return nil, status.Errorf(codes.Internal, "failed to delete memo relation")
}
for _, relation := range request.Relations {
// Ignore reflexive relations.
if request.Name == relation.RelatedMemo {
continue
}
// Ignore comment relations as there's no need to update a comment's relation.
// Inserting/Deleting a comment is handled elsewhere.
if relation.Type == v1pb.MemoRelation_COMMENT {
continue
}
relatedMemoID, err := ExtractMemoIDFromName(relation.RelatedMemo)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid related memo name: %v", err)
}
if _, err := s.Store.UpsertMemoRelation(ctx, &store.MemoRelation{
MemoID: id,
RelatedMemoID: relatedMemoID,
Type: convertMemoRelationTypeToStore(relation.Type),
}); err != nil {
return nil, status.Errorf(codes.Internal, "failed to upsert memo relation")
}
}
return &emptypb.Empty{}, nil
}
func (s *APIV1Service) ListMemoRelations(ctx context.Context, request *v1pb.ListMemoRelationsRequest) (*v1pb.ListMemoRelationsResponse, error) {
id, err := ExtractMemoIDFromName(request.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err)
}
relationList := []*v1pb.MemoRelation{}
tempList, err := s.Store.ListMemoRelations(ctx, &store.FindMemoRelation{
MemoID: &id,
})
if err != nil {
return nil, err
}
for _, relation := range tempList {
relationList = append(relationList, convertMemoRelationFromStore(relation))
}
tempList, err = s.Store.ListMemoRelations(ctx, &store.FindMemoRelation{
RelatedMemoID: &id,
})
if err != nil {
return nil, err
}
for _, relation := range tempList {
relationList = append(relationList, convertMemoRelationFromStore(relation))
}
response := &v1pb.ListMemoRelationsResponse{
Relations: relationList,
}
return response, nil
}
func convertMemoRelationFromStore(memoRelation *store.MemoRelation) *v1pb.MemoRelation {
return &v1pb.MemoRelation{
Memo: fmt.Sprintf("%s%d", MemoNamePrefix, memoRelation.MemoID),
RelatedMemo: fmt.Sprintf("%s%d", MemoNamePrefix, memoRelation.RelatedMemoID),
Type: convertMemoRelationTypeFromStore(memoRelation.Type),
}
}
func convertMemoRelationTypeFromStore(relationType store.MemoRelationType) v1pb.MemoRelation_Type {
switch relationType {
case store.MemoRelationReference:
return v1pb.MemoRelation_REFERENCE
case store.MemoRelationComment:
return v1pb.MemoRelation_COMMENT
default:
return v1pb.MemoRelation_TYPE_UNSPECIFIED
}
}
func convertMemoRelationTypeToStore(relationType v1pb.MemoRelation_Type) store.MemoRelationType {
switch relationType {
case v1pb.MemoRelation_REFERENCE:
return store.MemoRelationReference
case v1pb.MemoRelation_COMMENT:
return store.MemoRelationComment
default:
return store.MemoRelationReference
}
}

View File

@@ -0,0 +1,86 @@
package v1
import (
"context"
"slices"
"time"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/emptypb"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
"github.com/usememos/memos/store"
)
func (s *APIV1Service) SetMemoResources(ctx context.Context, request *v1pb.SetMemoResourcesRequest) (*emptypb.Empty, error) {
memoID, err := ExtractMemoIDFromName(request.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err)
}
resources, err := s.Store.ListResources(ctx, &store.FindResource{
MemoID: &memoID,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list resources")
}
// Delete resources that are not in the request.
for _, resource := range resources {
found := false
for _, requestResource := range request.Resources {
if resource.UID == requestResource.Uid {
found = true
break
}
}
if !found {
if err = s.Store.DeleteResource(ctx, &store.DeleteResource{
ID: int32(resource.ID),
MemoID: &memoID,
}); err != nil {
return nil, status.Errorf(codes.Internal, "failed to delete resource")
}
}
}
slices.Reverse(request.Resources)
// Update resources' memo_id in the request.
for index, resource := range request.Resources {
id, err := ExtractResourceIDFromName(resource.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid resource name: %v", err)
}
updatedTs := time.Now().Unix() + int64(index)
if _, err := s.Store.UpdateResource(ctx, &store.UpdateResource{
ID: id,
MemoID: &memoID,
UpdatedTs: &updatedTs,
}); err != nil {
return nil, status.Errorf(codes.Internal, "failed to update resource: %v", err)
}
}
return &emptypb.Empty{}, nil
}
func (s *APIV1Service) ListMemoResources(ctx context.Context, request *v1pb.ListMemoResourcesRequest) (*v1pb.ListMemoResourcesResponse, error) {
id, err := ExtractMemoIDFromName(request.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err)
}
resources, err := s.Store.ListResources(ctx, &store.FindResource{
MemoID: &id,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list resources")
}
response := &v1pb.ListMemoResourcesResponse{
Resources: []*v1pb.Resource{},
}
for _, resource := range resources {
response.Resources = append(response.Resources, s.convertResourceFromStore(ctx, resource))
}
return response, nil
}

View File

@@ -0,0 +1,874 @@
package v1
import (
"archive/zip"
"bytes"
"context"
"fmt"
"log/slog"
"time"
"github.com/google/cel-go/cel"
"github.com/lithammer/shortuuid/v4"
"github.com/pkg/errors"
"github.com/yourselfhosted/gomark/parser"
"github.com/yourselfhosted/gomark/parser/tokenizer"
expr "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/emptypb"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/usememos/memos/internal/util"
"github.com/usememos/memos/plugin/webhook"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/store"
)
const (
DefaultPageSize = 10
MaxContentLength = 8 * 1024
ChunkSize = 64 * 1024 // 64 KiB
)
func (s *APIV1Service) CreateMemo(ctx context.Context, request *v1pb.CreateMemoRequest) (*v1pb.Memo, error) {
user, err := getCurrentUser(ctx, s.Store)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user")
}
if user == nil {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
if len(request.Content) > MaxContentLength {
return nil, status.Errorf(codes.InvalidArgument, "content too long")
}
create := &store.Memo{
UID: shortuuid.New(),
CreatorID: user.ID,
Content: request.Content,
Visibility: convertVisibilityToStore(request.Visibility),
}
workspaceMemoRelatedSetting, err := s.Store.GetWorkspaceMemoRelatedSetting(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get workspace memo related setting")
}
if workspaceMemoRelatedSetting.DisallowPublicVisible && create.Visibility == store.Public {
return nil, status.Errorf(codes.PermissionDenied, "disable public memos system setting is enabled")
}
memo, err := s.Store.CreateMemo(ctx, create)
if err != nil {
return nil, err
}
memoMessage, err := s.convertMemoFromStore(ctx, memo)
if err != nil {
return nil, errors.Wrap(err, "failed to convert memo")
}
// Try to dispatch webhook when memo is created.
if err := s.DispatchMemoCreatedWebhook(ctx, memoMessage); err != nil {
slog.Warn("Failed to dispatch memo created webhook", err)
}
return memoMessage, nil
}
func (s *APIV1Service) ListMemos(ctx context.Context, request *v1pb.ListMemosRequest) (*v1pb.ListMemosResponse, error) {
memoFind := &store.FindMemo{
// Exclude comments by default.
ExcludeComments: true,
}
if err := s.buildMemoFindWithFilter(ctx, memoFind, request.Filter); err != nil {
return nil, status.Errorf(codes.InvalidArgument, "failed to build find memos with filter")
}
var limit, offset int
if request.PageToken != "" {
var pageToken v1pb.PageToken
if err := unmarshalPageToken(request.PageToken, &pageToken); err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid page token: %v", err)
}
limit = int(pageToken.Limit)
offset = int(pageToken.Offset)
} else {
limit = int(request.PageSize)
}
if limit <= 0 {
limit = DefaultPageSize
}
limitPlusOne := limit + 1
memoFind.Limit = &limitPlusOne
memoFind.Offset = &offset
memos, err := s.Store.ListMemos(ctx, memoFind)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list memos: %v", err)
}
memoMessages := []*v1pb.Memo{}
nextPageToken := ""
if len(memos) == limitPlusOne {
memos = memos[:limit]
nextPageToken, err = getPageToken(limit, offset+limit)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get next page token, error: %v", err)
}
}
for _, memo := range memos {
memoMessage, err := s.convertMemoFromStore(ctx, memo)
if err != nil {
return nil, errors.Wrap(err, "failed to convert memo")
}
memoMessages = append(memoMessages, memoMessage)
}
response := &v1pb.ListMemosResponse{
Memos: memoMessages,
NextPageToken: nextPageToken,
}
return response, nil
}
func (s *APIV1Service) SearchMemos(ctx context.Context, request *v1pb.SearchMemosRequest) (*v1pb.SearchMemosResponse, error) {
defaultSearchLimit := 10
memoFind := &store.FindMemo{
// Exclude comments by default.
ExcludeComments: true,
Limit: &defaultSearchLimit,
}
err := s.buildMemoFindWithFilter(ctx, memoFind, request.Filter)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "failed to build find memos with filter")
}
memos, err := s.Store.ListMemos(ctx, memoFind)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to search memos")
}
memoMessages := []*v1pb.Memo{}
for _, memo := range memos {
memoMessage, err := s.convertMemoFromStore(ctx, memo)
if err != nil {
return nil, errors.Wrap(err, "failed to convert memo")
}
memoMessages = append(memoMessages, memoMessage)
}
response := &v1pb.SearchMemosResponse{
Memos: memoMessages,
}
return response, nil
}
func (s *APIV1Service) GetMemo(ctx context.Context, request *v1pb.GetMemoRequest) (*v1pb.Memo, error) {
id, err := ExtractMemoIDFromName(request.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err)
}
memo, err := s.Store.GetMemo(ctx, &store.FindMemo{
ID: &id,
})
if err != nil {
return nil, err
}
if memo == nil {
return nil, status.Errorf(codes.NotFound, "memo not found")
}
if memo.Visibility != store.Public {
user, err := getCurrentUser(ctx, s.Store)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user")
}
if user == nil {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
if memo.Visibility == store.Private && memo.CreatorID != user.ID {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
}
memoMessage, err := s.convertMemoFromStore(ctx, memo)
if err != nil {
return nil, errors.Wrap(err, "failed to convert memo")
}
return memoMessage, nil
}
func (s *APIV1Service) UpdateMemo(ctx context.Context, request *v1pb.UpdateMemoRequest) (*v1pb.Memo, error) {
id, err := ExtractMemoIDFromName(request.Memo.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err)
}
if request.UpdateMask == nil || len(request.UpdateMask.Paths) == 0 {
return nil, status.Errorf(codes.InvalidArgument, "update mask is required")
}
memo, err := s.Store.GetMemo(ctx, &store.FindMemo{ID: &id})
if err != nil {
return nil, err
}
if memo == nil {
return nil, status.Errorf(codes.NotFound, "memo not found")
}
user, _ := getCurrentUser(ctx, s.Store)
if memo.CreatorID != user.ID {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
currentTs := time.Now().Unix()
update := &store.UpdateMemo{
ID: id,
UpdatedTs: &currentTs,
}
for _, path := range request.UpdateMask.Paths {
if path == "content" {
update.Content = &request.Memo.Content
} else if path == "uid" {
update.UID = &request.Memo.Name
if !util.UIDMatcher.MatchString(*update.UID) {
return nil, status.Errorf(codes.InvalidArgument, "invalid resource name")
}
} else if path == "visibility" {
workspaceMemoRelatedSetting, err := s.Store.GetWorkspaceMemoRelatedSetting(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get workspace memo related setting")
}
visibility := convertVisibilityToStore(request.Memo.Visibility)
if workspaceMemoRelatedSetting.DisallowPublicVisible && visibility == store.Public {
return nil, status.Errorf(codes.PermissionDenied, "disable public memos system setting is enabled")
}
update.Visibility = &visibility
} else if path == "row_status" {
rowStatus := convertRowStatusToStore(request.Memo.RowStatus)
update.RowStatus = &rowStatus
} else if path == "created_ts" {
createdTs := request.Memo.CreateTime.AsTime().Unix()
update.CreatedTs = &createdTs
} else if path == "pinned" {
if _, err := s.Store.UpsertMemoOrganizer(ctx, &store.MemoOrganizer{
MemoID: id,
UserID: user.ID,
Pinned: request.Memo.Pinned,
}); err != nil {
return nil, status.Errorf(codes.Internal, "failed to upsert memo organizer")
}
}
}
if update.Content != nil && len(*update.Content) > MaxContentLength {
return nil, status.Errorf(codes.InvalidArgument, "content too long")
}
if err = s.Store.UpdateMemo(ctx, update); err != nil {
return nil, status.Errorf(codes.Internal, "failed to update memo")
}
memo, err = s.Store.GetMemo(ctx, &store.FindMemo{
ID: &id,
})
if err != nil {
return nil, errors.Wrap(err, "failed to get memo")
}
memoMessage, err := s.convertMemoFromStore(ctx, memo)
if err != nil {
return nil, errors.Wrap(err, "failed to convert memo")
}
// Try to dispatch webhook when memo is updated.
if err := s.DispatchMemoUpdatedWebhook(ctx, memoMessage); err != nil {
slog.Warn("Failed to dispatch memo updated webhook", err)
}
return memoMessage, nil
}
func (s *APIV1Service) DeleteMemo(ctx context.Context, request *v1pb.DeleteMemoRequest) (*emptypb.Empty, error) {
id, err := ExtractMemoIDFromName(request.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err)
}
memo, err := s.Store.GetMemo(ctx, &store.FindMemo{
ID: &id,
})
if err != nil {
return nil, err
}
if memo == nil {
return nil, status.Errorf(codes.NotFound, "memo not found")
}
user, _ := getCurrentUser(ctx, s.Store)
if memo.CreatorID != user.ID {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
if memoMessage, err := s.convertMemoFromStore(ctx, memo); err == nil {
// Try to dispatch webhook when memo is deleted.
if err := s.DispatchMemoDeletedWebhook(ctx, memoMessage); err != nil {
slog.Warn("Failed to dispatch memo deleted webhook", err)
}
}
if err = s.Store.DeleteMemo(ctx, &store.DeleteMemo{ID: id}); err != nil {
return nil, status.Errorf(codes.Internal, "failed to delete memo")
}
return &emptypb.Empty{}, nil
}
func (s *APIV1Service) CreateMemoComment(ctx context.Context, request *v1pb.CreateMemoCommentRequest) (*v1pb.Memo, error) {
id, err := ExtractMemoIDFromName(request.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err)
}
relatedMemo, err := s.Store.GetMemo(ctx, &store.FindMemo{ID: &id})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get memo")
}
// Create the comment memo first.
memo, err := s.CreateMemo(ctx, request.Comment)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to create memo")
}
// Build the relation between the comment memo and the original memo.
memoID, err := ExtractMemoIDFromName(memo.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err)
}
_, err = s.Store.UpsertMemoRelation(ctx, &store.MemoRelation{
MemoID: memoID,
RelatedMemoID: relatedMemo.ID,
Type: store.MemoRelationComment,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to create memo relation")
}
creatorID, err := ExtractUserIDFromName(memo.Creator)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid memo creator")
}
if memo.Visibility != v1pb.Visibility_PRIVATE && creatorID != relatedMemo.CreatorID {
activity, err := s.Store.CreateActivity(ctx, &store.Activity{
CreatorID: creatorID,
Type: store.ActivityTypeMemoComment,
Level: store.ActivityLevelInfo,
Payload: &storepb.ActivityPayload{
MemoComment: &storepb.ActivityMemoCommentPayload{
MemoId: memoID,
RelatedMemoId: relatedMemo.ID,
},
},
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to create activity")
}
if _, err := s.Store.CreateInbox(ctx, &store.Inbox{
SenderID: creatorID,
ReceiverID: relatedMemo.CreatorID,
Status: store.UNREAD,
Message: &storepb.InboxMessage{
Type: storepb.InboxMessage_TYPE_MEMO_COMMENT,
ActivityId: &activity.ID,
},
}); err != nil {
return nil, status.Errorf(codes.Internal, "failed to create inbox")
}
}
return memo, nil
}
func (s *APIV1Service) ListMemoComments(ctx context.Context, request *v1pb.ListMemoCommentsRequest) (*v1pb.ListMemoCommentsResponse, error) {
id, err := ExtractMemoIDFromName(request.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err)
}
memoRelationComment := store.MemoRelationComment
memoRelations, err := s.Store.ListMemoRelations(ctx, &store.FindMemoRelation{
RelatedMemoID: &id,
Type: &memoRelationComment,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list memo relations")
}
var memos []*v1pb.Memo
for _, memoRelation := range memoRelations {
memo, err := s.Store.GetMemo(ctx, &store.FindMemo{
ID: &memoRelation.MemoID,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get memo")
}
if memo != nil {
memoMessage, err := s.convertMemoFromStore(ctx, memo)
if err != nil {
return nil, errors.Wrap(err, "failed to convert memo")
}
memos = append(memos, memoMessage)
}
}
response := &v1pb.ListMemoCommentsResponse{
Memos: memos,
}
return response, nil
}
func (s *APIV1Service) GetUserMemosStats(ctx context.Context, request *v1pb.GetUserMemosStatsRequest) (*v1pb.GetUserMemosStatsResponse, error) {
userID, err := ExtractUserIDFromName(request.Name)
if err != nil {
return nil, errors.Wrap(err, "invalid user name")
}
user, err := s.Store.GetUser(ctx, &store.FindUser{
ID: &userID,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user")
}
if user == nil {
return nil, status.Errorf(codes.NotFound, "user not found")
}
normalRowStatus := store.Normal
memoFind := &store.FindMemo{
CreatorID: &user.ID,
RowStatus: &normalRowStatus,
ExcludeComments: true,
ExcludeContent: true,
}
if err := s.buildMemoFindWithFilter(ctx, memoFind, request.Filter); err != nil {
return nil, status.Errorf(codes.InvalidArgument, "failed to build find memos with filter")
}
memos, err := s.Store.ListMemos(ctx, memoFind)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list memos: %v", err)
}
location, err := time.LoadLocation(request.Timezone)
if err != nil {
return nil, status.Errorf(codes.Internal, "invalid timezone location")
}
workspaceMemoRelatedSetting, err := s.Store.GetWorkspaceMemoRelatedSetting(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get workspace memo related setting")
}
stats := make(map[string]int32)
for _, memo := range memos {
displayTs := memo.CreatedTs
if workspaceMemoRelatedSetting.DisplayWithUpdateTime {
displayTs = memo.UpdatedTs
}
stats[time.Unix(displayTs, 0).In(location).Format("2006-01-02")]++
}
response := &v1pb.GetUserMemosStatsResponse{
Stats: stats,
}
return response, nil
}
func (s *APIV1Service) ExportMemos(ctx context.Context, request *v1pb.ExportMemosRequest) (*v1pb.ExportMemosResponse, error) {
normalRowStatus := store.Normal
memoFind := &store.FindMemo{
RowStatus: &normalRowStatus,
// Exclude comments by default.
ExcludeComments: true,
}
if err := s.buildMemoFindWithFilter(ctx, memoFind, request.Filter); err != nil {
return nil, status.Errorf(codes.Internal, "failed to build find memos with filter: %v", err)
}
memos, err := s.Store.ListMemos(ctx, memoFind)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list memos: %v", err)
}
buf := new(bytes.Buffer)
writer := zip.NewWriter(buf)
for _, memo := range memos {
memoMessage, err := s.convertMemoFromStore(ctx, memo)
if err != nil {
return nil, errors.Wrap(err, "failed to convert memo")
}
file, err := writer.Create(time.Unix(memo.CreatedTs, 0).Format(time.RFC3339) + ".md")
if err != nil {
return nil, status.Errorf(codes.Internal, "Failed to create memo file")
}
_, err = file.Write([]byte(memoMessage.Content))
if err != nil {
return nil, status.Errorf(codes.Internal, "Failed to write to memo file")
}
}
if err := writer.Close(); err != nil {
return nil, status.Errorf(codes.Internal, "Failed to close zip file writer")
}
return &v1pb.ExportMemosResponse{
Content: buf.Bytes(),
}, nil
}
func (s *APIV1Service) convertMemoFromStore(ctx context.Context, memo *store.Memo) (*v1pb.Memo, error) {
displayTs := memo.CreatedTs
workspaceMemoRelatedSetting, err := s.Store.GetWorkspaceMemoRelatedSetting(ctx)
if err != nil {
return nil, errors.Wrap(err, "failed to get workspace memo related setting")
}
if workspaceMemoRelatedSetting.DisplayWithUpdateTime {
displayTs = memo.UpdatedTs
}
creator, err := s.Store.GetUser(ctx, &store.FindUser{ID: &memo.CreatorID})
if err != nil {
return nil, errors.Wrap(err, "failed to get creator")
}
name := fmt.Sprintf("%s%d", MemoNamePrefix, memo.ID)
listMemoRelationsResponse, err := s.ListMemoRelations(ctx, &v1pb.ListMemoRelationsRequest{Name: name})
if err != nil {
return nil, errors.Wrap(err, "failed to list memo relations")
}
listMemoResourcesResponse, err := s.ListMemoResources(ctx, &v1pb.ListMemoResourcesRequest{Name: name})
if err != nil {
return nil, errors.Wrap(err, "failed to list memo resources")
}
listMemoReactionsResponse, err := s.ListMemoReactions(ctx, &v1pb.ListMemoReactionsRequest{Name: name})
if err != nil {
return nil, errors.Wrap(err, "failed to list memo reactions")
}
nodes, err := parser.Parse(tokenizer.Tokenize(memo.Content))
if err != nil {
return nil, errors.Wrap(err, "failed to parse content")
}
return &v1pb.Memo{
Name: name,
Uid: memo.UID,
RowStatus: convertRowStatusFromStore(memo.RowStatus),
Creator: fmt.Sprintf("%s%d", UserNamePrefix, creator.ID),
CreateTime: timestamppb.New(time.Unix(memo.CreatedTs, 0)),
UpdateTime: timestamppb.New(time.Unix(memo.UpdatedTs, 0)),
DisplayTime: timestamppb.New(time.Unix(displayTs, 0)),
Content: memo.Content,
Nodes: convertFromASTNodes(nodes),
Visibility: convertVisibilityFromStore(memo.Visibility),
Pinned: memo.Pinned,
ParentId: memo.ParentID,
Relations: listMemoRelationsResponse.Relations,
Resources: listMemoResourcesResponse.Resources,
Reactions: listMemoReactionsResponse.Reactions,
}, nil
}
func convertVisibilityFromStore(visibility store.Visibility) v1pb.Visibility {
switch visibility {
case store.Private:
return v1pb.Visibility_PRIVATE
case store.Protected:
return v1pb.Visibility_PROTECTED
case store.Public:
return v1pb.Visibility_PUBLIC
default:
return v1pb.Visibility_VISIBILITY_UNSPECIFIED
}
}
func convertVisibilityToStore(visibility v1pb.Visibility) store.Visibility {
switch visibility {
case v1pb.Visibility_PRIVATE:
return store.Private
case v1pb.Visibility_PROTECTED:
return store.Protected
case v1pb.Visibility_PUBLIC:
return store.Public
default:
return store.Private
}
}
func (s *APIV1Service) buildMemoFindWithFilter(ctx context.Context, find *store.FindMemo, filter string) error {
user, _ := getCurrentUser(ctx, s.Store)
if find == nil {
find = &store.FindMemo{}
}
if filter != "" {
filter, err := parseSearchMemosFilter(filter)
if err != nil {
return status.Errorf(codes.InvalidArgument, "invalid filter: %v", err)
}
if len(filter.ContentSearch) > 0 {
find.ContentSearch = filter.ContentSearch
}
if len(filter.Visibilities) > 0 {
find.VisibilityList = filter.Visibilities
}
if filter.OrderByPinned {
find.OrderByPinned = filter.OrderByPinned
}
if filter.DisplayTimeAfter != nil {
workspaceMemoRelatedSetting, err := s.Store.GetWorkspaceMemoRelatedSetting(ctx)
if err != nil {
return status.Errorf(codes.Internal, "failed to get workspace memo related setting")
}
if workspaceMemoRelatedSetting.DisplayWithUpdateTime {
find.UpdatedTsAfter = filter.DisplayTimeAfter
} else {
find.CreatedTsAfter = filter.DisplayTimeAfter
}
}
if filter.DisplayTimeBefore != nil {
workspaceMemoRelatedSetting, err := s.Store.GetWorkspaceMemoRelatedSetting(ctx)
if err != nil {
return status.Errorf(codes.Internal, "failed to get workspace memo related setting")
}
if workspaceMemoRelatedSetting.DisplayWithUpdateTime {
find.UpdatedTsBefore = filter.DisplayTimeBefore
} else {
find.CreatedTsBefore = filter.DisplayTimeBefore
}
}
if filter.Creator != nil {
userID, err := ExtractUserIDFromName(*filter.Creator)
if err != nil {
return errors.Wrap(err, "invalid user name")
}
user, err := s.Store.GetUser(ctx, &store.FindUser{
ID: &userID,
})
if err != nil {
return status.Errorf(codes.Internal, "failed to get user")
}
if user == nil {
return status.Errorf(codes.NotFound, "user not found")
}
find.CreatorID = &user.ID
}
if filter.UID != nil {
find.UID = filter.UID
}
if filter.RowStatus != nil {
find.RowStatus = filter.RowStatus
}
if filter.Random {
find.Random = filter.Random
}
if filter.Limit != nil {
find.Limit = filter.Limit
}
if filter.IncludeComments {
find.ExcludeComments = false
}
}
// If the user is not authenticated, only public memos are visible.
if user == nil {
if filter == "" {
// If no filter is provided, return an error.
return status.Errorf(codes.InvalidArgument, "filter is required")
}
find.VisibilityList = []store.Visibility{store.Public}
} else if find.CreatorID != nil && *find.CreatorID != user.ID {
find.VisibilityList = []store.Visibility{store.Public, store.Protected}
}
workspaceMemoRelatedSetting, err := s.Store.GetWorkspaceMemoRelatedSetting(ctx)
if err != nil {
return status.Errorf(codes.Internal, "failed to get workspace memo related setting")
}
if workspaceMemoRelatedSetting.DisplayWithUpdateTime {
find.OrderByUpdatedTs = true
}
return nil
}
// SearchMemosFilterCELAttributes are the CEL attributes.
var SearchMemosFilterCELAttributes = []cel.EnvOption{
cel.Variable("content_search", cel.ListType(cel.StringType)),
cel.Variable("visibilities", cel.ListType(cel.StringType)),
cel.Variable("order_by_pinned", cel.BoolType),
cel.Variable("display_time_before", cel.IntType),
cel.Variable("display_time_after", cel.IntType),
cel.Variable("creator", cel.StringType),
cel.Variable("uid", cel.StringType),
cel.Variable("row_status", cel.StringType),
cel.Variable("random", cel.BoolType),
cel.Variable("limit", cel.IntType),
cel.Variable("include_comments", cel.BoolType),
}
type SearchMemosFilter struct {
ContentSearch []string
Visibilities []store.Visibility
OrderByPinned bool
DisplayTimeBefore *int64
DisplayTimeAfter *int64
Creator *string
UID *string
RowStatus *store.RowStatus
Random bool
Limit *int
IncludeComments bool
}
func parseSearchMemosFilter(expression string) (*SearchMemosFilter, error) {
e, err := cel.NewEnv(SearchMemosFilterCELAttributes...)
if err != nil {
return nil, err
}
ast, issues := e.Compile(expression)
if issues != nil {
return nil, errors.Errorf("found issue %v", issues)
}
filter := &SearchMemosFilter{}
expr, err := cel.AstToParsedExpr(ast)
if err != nil {
return nil, err
}
callExpr := expr.GetExpr().GetCallExpr()
findSearchMemosField(callExpr, filter)
return filter, nil
}
func findSearchMemosField(callExpr *expr.Expr_Call, filter *SearchMemosFilter) {
if len(callExpr.Args) == 2 {
idExpr := callExpr.Args[0].GetIdentExpr()
if idExpr != nil {
if idExpr.Name == "content_search" {
contentSearch := []string{}
for _, expr := range callExpr.Args[1].GetListExpr().GetElements() {
value := expr.GetConstExpr().GetStringValue()
contentSearch = append(contentSearch, value)
}
filter.ContentSearch = contentSearch
} else if idExpr.Name == "visibilities" {
visibilities := []store.Visibility{}
for _, expr := range callExpr.Args[1].GetListExpr().GetElements() {
value := expr.GetConstExpr().GetStringValue()
visibilities = append(visibilities, store.Visibility(value))
}
filter.Visibilities = visibilities
} else if idExpr.Name == "order_by_pinned" {
value := callExpr.Args[1].GetConstExpr().GetBoolValue()
filter.OrderByPinned = value
} else if idExpr.Name == "display_time_before" {
displayTimeBefore := callExpr.Args[1].GetConstExpr().GetInt64Value()
filter.DisplayTimeBefore = &displayTimeBefore
} else if idExpr.Name == "display_time_after" {
displayTimeAfter := callExpr.Args[1].GetConstExpr().GetInt64Value()
filter.DisplayTimeAfter = &displayTimeAfter
} else if idExpr.Name == "creator" {
creator := callExpr.Args[1].GetConstExpr().GetStringValue()
filter.Creator = &creator
} else if idExpr.Name == "uid" {
uid := callExpr.Args[1].GetConstExpr().GetStringValue()
filter.UID = &uid
} else if idExpr.Name == "row_status" {
rowStatus := store.RowStatus(callExpr.Args[1].GetConstExpr().GetStringValue())
filter.RowStatus = &rowStatus
} else if idExpr.Name == "random" {
value := callExpr.Args[1].GetConstExpr().GetBoolValue()
filter.Random = value
} else if idExpr.Name == "limit" {
limit := int(callExpr.Args[1].GetConstExpr().GetInt64Value())
filter.Limit = &limit
} else if idExpr.Name == "include_comments" {
value := callExpr.Args[1].GetConstExpr().GetBoolValue()
filter.IncludeComments = value
}
return
}
}
for _, arg := range callExpr.Args {
callExpr := arg.GetCallExpr()
if callExpr != nil {
findSearchMemosField(callExpr, filter)
}
}
}
// DispatchMemoCreatedWebhook dispatches webhook when memo is created.
func (s *APIV1Service) DispatchMemoCreatedWebhook(ctx context.Context, memo *v1pb.Memo) error {
return s.dispatchMemoRelatedWebhook(ctx, memo, "memos.memo.created")
}
// DispatchMemoUpdatedWebhook dispatches webhook when memo is updated.
func (s *APIV1Service) DispatchMemoUpdatedWebhook(ctx context.Context, memo *v1pb.Memo) error {
return s.dispatchMemoRelatedWebhook(ctx, memo, "memos.memo.updated")
}
// DispatchMemoDeletedWebhook dispatches webhook when memo is deleted.
func (s *APIV1Service) DispatchMemoDeletedWebhook(ctx context.Context, memo *v1pb.Memo) error {
return s.dispatchMemoRelatedWebhook(ctx, memo, "memos.memo.deleted")
}
func (s *APIV1Service) dispatchMemoRelatedWebhook(ctx context.Context, memo *v1pb.Memo, activityType string) error {
creatorID, err := ExtractUserIDFromName(memo.Creator)
if err != nil {
return status.Errorf(codes.InvalidArgument, "invalid memo creator")
}
webhooks, err := s.Store.ListWebhooks(ctx, &store.FindWebhook{
CreatorID: &creatorID,
})
if err != nil {
return err
}
for _, hook := range webhooks {
payload, err := convertMemoToWebhookPayload(memo)
if err != nil {
return errors.Wrap(err, "failed to convert memo to webhook payload")
}
payload.ActivityType = activityType
payload.URL = hook.URL
if err := webhook.Post(*payload); err != nil {
return errors.Wrap(err, "failed to post webhook")
}
}
return nil
}
func convertMemoToWebhookPayload(memo *v1pb.Memo) (*webhook.WebhookPayload, error) {
creatorID, err := ExtractUserIDFromName(memo.Creator)
if err != nil {
return nil, errors.Wrap(err, "invalid memo creator")
}
id, err := ExtractMemoIDFromName(memo.Name)
if err != nil {
return nil, errors.Wrap(err, "invalid memo name")
}
return &webhook.WebhookPayload{
CreatorID: creatorID,
CreatedTs: time.Now().Unix(),
Memo: &webhook.Memo{
ID: id,
CreatorID: creatorID,
CreatedTs: memo.CreateTime.Seconds,
UpdatedTs: memo.UpdateTime.Seconds,
Content: memo.Content,
Visibility: memo.Visibility.String(),
Pinned: memo.Pinned,
ResourceList: func() []*webhook.Resource {
resources := []*webhook.Resource{}
for _, resource := range memo.Resources {
resources = append(resources, &webhook.Resource{
UID: resource.Uid,
Filename: resource.Filename,
ExternalLink: resource.ExternalLink,
Type: resource.Type,
Size: resource.Size,
})
}
return resources
}(),
},
}, nil
}

View File

@@ -0,0 +1,81 @@
package v1
import (
"context"
"fmt"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/emptypb"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/store"
)
func (s *APIV1Service) ListMemoReactions(ctx context.Context, request *v1pb.ListMemoReactionsRequest) (*v1pb.ListMemoReactionsResponse, error) {
reactions, err := s.Store.ListReactions(ctx, &store.FindReaction{
ContentID: &request.Name,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list reactions")
}
response := &v1pb.ListMemoReactionsResponse{
Reactions: []*v1pb.Reaction{},
}
for _, reaction := range reactions {
reactionMessage, err := s.convertReactionFromStore(ctx, reaction)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to convert reaction")
}
response.Reactions = append(response.Reactions, reactionMessage)
}
return response, nil
}
func (s *APIV1Service) UpsertMemoReaction(ctx context.Context, request *v1pb.UpsertMemoReactionRequest) (*v1pb.Reaction, error) {
user, err := getCurrentUser(ctx, s.Store)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get current user")
}
reaction, err := s.Store.UpsertReaction(ctx, &store.Reaction{
CreatorID: user.ID,
ContentID: request.Reaction.ContentId,
ReactionType: storepb.ReactionType(request.Reaction.ReactionType),
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to upsert reaction")
}
reactionMessage, err := s.convertReactionFromStore(ctx, reaction)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to convert reaction")
}
return reactionMessage, nil
}
func (s *APIV1Service) DeleteMemoReaction(ctx context.Context, request *v1pb.DeleteMemoReactionRequest) (*emptypb.Empty, error) {
if err := s.Store.DeleteReaction(ctx, &store.DeleteReaction{
ID: request.ReactionId,
}); err != nil {
return nil, status.Errorf(codes.Internal, "failed to delete reaction")
}
return &emptypb.Empty{}, nil
}
func (s *APIV1Service) convertReactionFromStore(ctx context.Context, reaction *store.Reaction) (*v1pb.Reaction, error) {
creator, err := s.Store.GetUser(ctx, &store.FindUser{
ID: &reaction.CreatorID,
})
if err != nil {
return nil, err
}
return &v1pb.Reaction{
Id: reaction.ID,
Creator: fmt.Sprintf("%s%d", UserNamePrefix, creator.ID),
ContentId: reaction.ContentID,
ReactionType: v1pb.Reaction_Type(reaction.ReactionType),
}, nil
}

View File

@@ -0,0 +1,125 @@
package v1
import (
"fmt"
"strings"
"github.com/pkg/errors"
"github.com/usememos/memos/internal/util"
)
const (
WorkspaceSettingNamePrefix = "settings/"
UserNamePrefix = "users/"
MemoNamePrefix = "memos/"
ResourceNamePrefix = "resources/"
InboxNamePrefix = "inboxes/"
StorageNamePrefix = "storages/"
IdentityProviderNamePrefix = "identityProviders/"
)
// GetNameParentTokens returns the tokens from a resource name.
func GetNameParentTokens(name string, tokenPrefixes ...string) ([]string, error) {
parts := strings.Split(name, "/")
if len(parts) != 2*len(tokenPrefixes) {
return nil, errors.Errorf("invalid request %q", name)
}
var tokens []string
for i, tokenPrefix := range tokenPrefixes {
if fmt.Sprintf("%s/", parts[2*i]) != tokenPrefix {
return nil, errors.Errorf("invalid prefix %q in request %q", tokenPrefix, name)
}
if parts[2*i+1] == "" {
return nil, errors.Errorf("invalid request %q with empty prefix %q", name, tokenPrefix)
}
tokens = append(tokens, parts[2*i+1])
}
return tokens, nil
}
func ExtractWorkspaceSettingKeyFromName(name string) (string, error) {
tokens, err := GetNameParentTokens(name, WorkspaceSettingNamePrefix)
if err != nil {
return "", err
}
return tokens[0], nil
}
// ExtractUserIDFromName returns the uid from a resource name.
func ExtractUserIDFromName(name string) (int32, error) {
tokens, err := GetNameParentTokens(name, UserNamePrefix)
if err != nil {
return 0, err
}
id, err := util.ConvertStringToInt32(tokens[0])
if err != nil {
return 0, errors.Errorf("invalid user ID %q", tokens[0])
}
return id, nil
}
// ExtractMemoIDFromName returns the memo ID from a resource name.
func ExtractMemoIDFromName(name string) (int32, error) {
tokens, err := GetNameParentTokens(name, MemoNamePrefix)
if err != nil {
return 0, err
}
id, err := util.ConvertStringToInt32(tokens[0])
if err != nil {
return 0, errors.Errorf("invalid memo ID %q", tokens[0])
}
return id, nil
}
// ExtractResourceIDFromName returns the resource ID from a resource name.
func ExtractResourceIDFromName(name string) (int32, error) {
tokens, err := GetNameParentTokens(name, ResourceNamePrefix)
if err != nil {
return 0, err
}
id, err := util.ConvertStringToInt32(tokens[0])
if err != nil {
return 0, errors.Errorf("invalid resource ID %q", tokens[0])
}
return id, nil
}
// ExtractInboxIDFromName returns the inbox ID from a resource name.
func ExtractInboxIDFromName(name string) (int32, error) {
tokens, err := GetNameParentTokens(name, InboxNamePrefix)
if err != nil {
return 0, err
}
id, err := util.ConvertStringToInt32(tokens[0])
if err != nil {
return 0, errors.Errorf("invalid inbox ID %q", tokens[0])
}
return id, nil
}
// ExtractStorageIDFromName returns the storage ID from a resource name.
func ExtractStorageIDFromName(name string) (int32, error) {
tokens, err := GetNameParentTokens(name, StorageNamePrefix)
if err != nil {
return 0, err
}
id, err := util.ConvertStringToInt32(tokens[0])
if err != nil {
return 0, errors.Errorf("invalid storage ID %q", tokens[0])
}
return id, nil
}
func ExtractIdentityProviderIDFromName(name string) (int32, error) {
tokens, err := GetNameParentTokens(name, IdentityProviderNamePrefix)
if err != nil {
return 0, err
}
id, err := util.ConvertStringToInt32(tokens[0])
if err != nil {
return 0, errors.Errorf("invalid identity provider ID %q", tokens[0])
}
return id, nil
}

View File

@@ -0,0 +1,466 @@
package v1
import (
"bytes"
"context"
"encoding/binary"
"fmt"
"io"
"net/url"
"os"
"path/filepath"
"regexp"
"strings"
"time"
"github.com/google/cel-go/cel"
"github.com/lithammer/shortuuid/v4"
"github.com/pkg/errors"
expr "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
"google.golang.org/genproto/googleapis/api/httpbody"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/emptypb"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/usememos/memos/internal/util"
"github.com/usememos/memos/plugin/storage/s3"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/store"
)
const (
// The upload memory buffer is 32 MiB.
// It should be kept low, so RAM usage doesn't get out of control.
// This is unrelated to maximum upload size limit, which is now set through system setting.
MaxUploadBufferSizeBytes = 32 << 20
MebiByte = 1024 * 1024
)
func (s *APIV1Service) CreateResource(ctx context.Context, request *v1pb.CreateResourceRequest) (*v1pb.Resource, error) {
user, err := getCurrentUser(ctx, s.Store)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
}
create := &store.Resource{
UID: shortuuid.New(),
CreatorID: user.ID,
Filename: request.Resource.Filename,
Type: request.Resource.Type,
}
if request.Resource.ExternalLink != "" {
// Only allow those external links scheme with http/https
linkURL, err := url.Parse(request.Resource.ExternalLink)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid external link: %v", err)
}
if linkURL.Scheme != "http" && linkURL.Scheme != "https" {
return nil, status.Errorf(codes.InvalidArgument, "invalid external link scheme: %v", linkURL.Scheme)
}
create.ExternalLink = request.Resource.ExternalLink
} else {
workspaceStorageSetting, err := s.Store.GetWorkspaceStorageSetting(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get workspace storage setting: %v", err)
}
size := binary.Size(request.Resource.Content)
uploadSizeLimit := int(workspaceStorageSetting.UploadSizeLimitMb) * MebiByte
if uploadSizeLimit == 0 {
uploadSizeLimit = MaxUploadBufferSizeBytes
}
if size > uploadSizeLimit {
return nil, status.Errorf(codes.InvalidArgument, "file size exceeds the limit")
}
create.Size = int64(size)
create.Blob = request.Resource.Content
if err := SaveResourceBlob(ctx, s.Store, create); err != nil {
return nil, status.Errorf(codes.Internal, "failed to save resource blob: %v", err)
}
}
if request.Resource.Memo != nil {
memoID, err := ExtractMemoIDFromName(*request.Resource.Memo)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid memo id: %v", err)
}
create.MemoID = &memoID
}
resource, err := s.Store.CreateResource(ctx, create)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to create resource: %v", err)
}
return s.convertResourceFromStore(ctx, resource), nil
}
func (s *APIV1Service) ListResources(ctx context.Context, _ *v1pb.ListResourcesRequest) (*v1pb.ListResourcesResponse, error) {
user, err := getCurrentUser(ctx, s.Store)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
}
resources, err := s.Store.ListResources(ctx, &store.FindResource{
CreatorID: &user.ID,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list resources: %v", err)
}
response := &v1pb.ListResourcesResponse{}
for _, resource := range resources {
response.Resources = append(response.Resources, s.convertResourceFromStore(ctx, resource))
}
return response, nil
}
func (s *APIV1Service) SearchResources(ctx context.Context, request *v1pb.SearchResourcesRequest) (*v1pb.SearchResourcesResponse, error) {
if request.Filter == "" {
return nil, status.Errorf(codes.InvalidArgument, "filter is empty")
}
filter, err := parseSearchResourcesFilter(request.Filter)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "failed to parse filter: %v", err)
}
resourceFind := &store.FindResource{}
if filter.UID != nil {
resourceFind.UID = filter.UID
}
user, err := getCurrentUser(ctx, s.Store)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
}
resourceFind.CreatorID = &user.ID
resources, err := s.Store.ListResources(ctx, resourceFind)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to search resources: %v", err)
}
response := &v1pb.SearchResourcesResponse{}
for _, resource := range resources {
response.Resources = append(response.Resources, s.convertResourceFromStore(ctx, resource))
}
return response, nil
}
func (s *APIV1Service) GetResource(ctx context.Context, request *v1pb.GetResourceRequest) (*v1pb.Resource, error) {
id, err := ExtractResourceIDFromName(request.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid resource id: %v", err)
}
resource, err := s.Store.GetResource(ctx, &store.FindResource{
ID: &id,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get resource: %v", err)
}
if resource == nil {
return nil, status.Errorf(codes.NotFound, "resource not found")
}
return s.convertResourceFromStore(ctx, resource), nil
}
func (s *APIV1Service) GetResourceBinary(ctx context.Context, request *v1pb.GetResourceBinaryRequest) (*httpbody.HttpBody, error) {
resourceFind := &store.FindResource{
GetBlob: true,
}
if request.Name != "" {
id, err := ExtractResourceIDFromName(request.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid resource id: %v", err)
}
resourceFind.ID = &id
}
if request.Uid != "" {
resourceFind.UID = &request.Uid
}
resource, err := s.Store.GetResource(ctx, resourceFind)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get resource: %v", err)
}
if resource == nil {
return nil, status.Errorf(codes.NotFound, "resource not found")
}
// Check the related memo visibility.
if resource.MemoID != nil {
memo, err := s.Store.GetMemo(ctx, &store.FindMemo{
ID: resource.MemoID,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to find memo by ID: %v", resource.MemoID)
}
if memo != nil && memo.Visibility != store.Public {
user, err := getCurrentUser(ctx, s.Store)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
}
if memo.Visibility == store.Private && user.ID != resource.CreatorID {
return nil, status.Errorf(codes.Unauthenticated, "unauthorized access")
}
}
}
blob := resource.Blob
if resource.InternalPath != "" {
resourcePath := filepath.FromSlash(resource.InternalPath)
if !filepath.IsAbs(resourcePath) {
resourcePath = filepath.Join(s.Profile.Data, resourcePath)
}
file, err := os.Open(resourcePath)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to open the file: %v", err)
}
defer file.Close()
blob, err = io.ReadAll(file)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to read the file: %v", err)
}
}
httpBody := &httpbody.HttpBody{
ContentType: resource.Type,
Data: blob,
}
return httpBody, nil
}
func (s *APIV1Service) UpdateResource(ctx context.Context, request *v1pb.UpdateResourceRequest) (*v1pb.Resource, error) {
id, err := ExtractResourceIDFromName(request.Resource.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid resource id: %v", err)
}
if request.UpdateMask == nil || len(request.UpdateMask.Paths) == 0 {
return nil, status.Errorf(codes.InvalidArgument, "update mask is required")
}
currentTs := time.Now().Unix()
update := &store.UpdateResource{
ID: id,
UpdatedTs: &currentTs,
}
for _, field := range request.UpdateMask.Paths {
if field == "filename" {
update.Filename = &request.Resource.Filename
} else if field == "memo" {
if request.Resource.Memo == nil {
return nil, status.Errorf(codes.InvalidArgument, "memo is required")
}
memoID, err := ExtractMemoIDFromName(*request.Resource.Memo)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid memo id: %v", err)
}
update.MemoID = &memoID
}
}
resource, err := s.Store.UpdateResource(ctx, update)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to update resource: %v", err)
}
return s.convertResourceFromStore(ctx, resource), nil
}
func (s *APIV1Service) DeleteResource(ctx context.Context, request *v1pb.DeleteResourceRequest) (*emptypb.Empty, error) {
id, err := ExtractResourceIDFromName(request.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid resource id: %v", err)
}
user, err := getCurrentUser(ctx, s.Store)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
}
resource, err := s.Store.GetResource(ctx, &store.FindResource{
ID: &id,
CreatorID: &user.ID,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to find resource: %v", err)
}
if resource == nil {
return nil, status.Errorf(codes.NotFound, "resource not found")
}
// Delete the resource from the database.
if err := s.Store.DeleteResource(ctx, &store.DeleteResource{
ID: resource.ID,
}); err != nil {
return nil, status.Errorf(codes.Internal, "failed to delete resource: %v", err)
}
return &emptypb.Empty{}, nil
}
func (s *APIV1Service) convertResourceFromStore(ctx context.Context, resource *store.Resource) *v1pb.Resource {
resourceMessage := &v1pb.Resource{
Name: fmt.Sprintf("%s%d", ResourceNamePrefix, resource.ID),
Uid: resource.UID,
CreateTime: timestamppb.New(time.Unix(resource.CreatedTs, 0)),
Filename: resource.Filename,
ExternalLink: resource.ExternalLink,
Type: resource.Type,
Size: resource.Size,
}
if resource.MemoID != nil {
memo, _ := s.Store.GetMemo(ctx, &store.FindMemo{
ID: resource.MemoID,
})
if memo != nil {
memoName := fmt.Sprintf("%s%d", MemoNamePrefix, memo.ID)
resourceMessage.Memo = &memoName
}
}
return resourceMessage
}
// SaveResourceBlob save the blob of resource based on the storage config.
func SaveResourceBlob(ctx context.Context, s *store.Store, create *store.Resource) error {
workspaceStorageSetting, err := s.GetWorkspaceStorageSetting(ctx)
if err != nil {
return errors.Wrap(err, "Failed to find workspace storage setting")
}
if workspaceStorageSetting.StorageType == storepb.WorkspaceStorageSetting_STORAGE_TYPE_LOCAL {
filepathTemplate := "assets/{timestamp}_{filename}"
if workspaceStorageSetting.FilepathTemplate != "" {
filepathTemplate = workspaceStorageSetting.FilepathTemplate
}
internalPath := filepathTemplate
if !strings.Contains(internalPath, "{filename}") {
internalPath = filepath.Join(internalPath, "{filename}")
}
internalPath = replacePathTemplate(internalPath, create.Filename)
internalPath = filepath.ToSlash(internalPath)
// Ensure the directory exists.
osPath := filepath.FromSlash(internalPath)
if !filepath.IsAbs(osPath) {
osPath = filepath.Join(s.Profile.Data, osPath)
}
dir := filepath.Dir(osPath)
if err = os.MkdirAll(dir, os.ModePerm); err != nil {
return errors.Wrap(err, "Failed to create directory")
}
dst, err := os.Create(osPath)
if err != nil {
return errors.Wrap(err, "Failed to create file")
}
defer dst.Close()
// Write the blob to the file.
if err := os.WriteFile(osPath, create.Blob, 0644); err != nil {
return errors.Wrap(err, "Failed to write file")
}
create.InternalPath = internalPath
create.Blob = nil
} else if workspaceStorageSetting.StorageType == storepb.WorkspaceStorageSetting_STORAGE_TYPE_S3 {
s3Config := workspaceStorageSetting.S3Config
if s3Config == nil {
return errors.Errorf("No actived external storage found")
}
s3Client, err := s3.NewClient(ctx, &s3.Config{
AccessKeyID: s3Config.AccessKeyId,
AcesssKeySecret: s3Config.AccessKeySecret,
Endpoint: s3Config.Endpoint,
Region: s3Config.Region,
Bucket: s3Config.Bucket,
})
if err != nil {
return errors.Wrap(err, "Failed to create s3 client")
}
filepathTemplate := workspaceStorageSetting.FilepathTemplate
if !strings.Contains(filepathTemplate, "{filename}") {
filepathTemplate = filepath.Join(filepathTemplate, "{filename}")
}
filepathTemplate = replacePathTemplate(filepathTemplate, create.Filename)
r := bytes.NewReader(create.Blob)
link, err := s3Client.UploadFile(ctx, filepathTemplate, create.Type, r)
if err != nil {
return errors.Wrap(err, "Failed to upload via s3 client")
}
create.ExternalLink = link
create.Blob = nil
}
return nil
}
var fileKeyPattern = regexp.MustCompile(`\{[a-z]{1,9}\}`)
func replacePathTemplate(path, filename string) string {
t := time.Now()
path = fileKeyPattern.ReplaceAllStringFunc(path, func(s string) string {
switch s {
case "{filename}":
return filename
case "{timestamp}":
return fmt.Sprintf("%d", t.Unix())
case "{year}":
return fmt.Sprintf("%d", t.Year())
case "{month}":
return fmt.Sprintf("%02d", t.Month())
case "{day}":
return fmt.Sprintf("%02d", t.Day())
case "{hour}":
return fmt.Sprintf("%02d", t.Hour())
case "{minute}":
return fmt.Sprintf("%02d", t.Minute())
case "{second}":
return fmt.Sprintf("%02d", t.Second())
case "{uuid}":
return util.GenUUID()
}
return s
})
return path
}
// SearchResourcesFilterCELAttributes are the CEL attributes for SearchResourcesFilter.
var SearchResourcesFilterCELAttributes = []cel.EnvOption{
cel.Variable("uid", cel.StringType),
}
type SearchResourcesFilter struct {
UID *string
}
func parseSearchResourcesFilter(expression string) (*SearchResourcesFilter, error) {
e, err := cel.NewEnv(SearchResourcesFilterCELAttributes...)
if err != nil {
return nil, err
}
ast, issues := e.Compile(expression)
if issues != nil {
return nil, errors.Errorf("found issue %v", issues)
}
filter := &SearchResourcesFilter{}
expr, err := cel.AstToParsedExpr(ast)
if err != nil {
return nil, err
}
callExpr := expr.GetExpr().GetCallExpr()
findSearchResourcesField(callExpr, filter)
return filter, nil
}
func findSearchResourcesField(callExpr *expr.Expr_Call, filter *SearchResourcesFilter) {
if len(callExpr.Args) == 2 {
idExpr := callExpr.Args[0].GetIdentExpr()
if idExpr != nil {
if idExpr.Name == "uid" {
uid := callExpr.Args[1].GetConstExpr().GetStringValue()
filter.UID = &uid
}
return
}
}
for _, arg := range callExpr.Args {
callExpr := arg.GetCallExpr()
if callExpr != nil {
findSearchResourcesField(callExpr, filter)
}
}
}

View File

@@ -0,0 +1,260 @@
package v1
import (
"context"
"fmt"
"slices"
"sort"
"github.com/pkg/errors"
"github.com/yourselfhosted/gomark/ast"
"github.com/yourselfhosted/gomark/parser"
"github.com/yourselfhosted/gomark/parser/tokenizer"
"github.com/yourselfhosted/gomark/restore"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/emptypb"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
"github.com/usememos/memos/store"
)
func (s *APIV1Service) UpsertTag(ctx context.Context, request *v1pb.UpsertTagRequest) (*v1pb.Tag, error) {
user, err := getCurrentUser(ctx, s.Store)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user")
}
tag, err := s.Store.UpsertTag(ctx, &store.Tag{
Name: request.Name,
CreatorID: user.ID,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to upsert tag: %v", err)
}
tagMessage, err := s.convertTagFromStore(ctx, tag)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to convert tag: %v", err)
}
return tagMessage, nil
}
func (s *APIV1Service) BatchUpsertTag(ctx context.Context, request *v1pb.BatchUpsertTagRequest) (*emptypb.Empty, error) {
for _, r := range request.Requests {
if _, err := s.UpsertTag(ctx, r); err != nil {
return nil, status.Errorf(codes.Internal, "failed to batch upsert tags: %v", err)
}
}
return &emptypb.Empty{}, nil
}
func (s *APIV1Service) ListTags(ctx context.Context, _ *v1pb.ListTagsRequest) (*v1pb.ListTagsResponse, error) {
user, err := getCurrentUser(ctx, s.Store)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user")
}
if user == nil {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
tagFind := &store.FindTag{
CreatorID: user.ID,
}
tags, err := s.Store.ListTags(ctx, tagFind)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list tags: %v", err)
}
response := &v1pb.ListTagsResponse{}
for _, tag := range tags {
t, err := s.convertTagFromStore(ctx, tag)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to convert tag: %v", err)
}
response.Tags = append(response.Tags, t)
}
return response, nil
}
func (s *APIV1Service) RenameTag(ctx context.Context, request *v1pb.RenameTagRequest) (*emptypb.Empty, error) {
userID, err := ExtractUserIDFromName(request.User)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid user name: %v", err)
}
user, err := s.Store.GetUser(ctx, &store.FindUser{
ID: &userID,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
}
if user == nil {
return nil, status.Errorf(codes.NotFound, "user not found")
}
// Find all related memos.
memos, err := s.Store.ListMemos(ctx, &store.FindMemo{
CreatorID: &user.ID,
ContentSearch: []string{fmt.Sprintf("#%s", request.OldName)},
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list memos: %v", err)
}
// Replace tag name in memo content.
for _, memo := range memos {
nodes, err := parser.Parse(tokenizer.Tokenize(memo.Content))
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to parse memo: %v", err)
}
TraverseASTNodes(nodes, func(node ast.Node) {
if tag, ok := node.(*ast.Tag); ok && tag.Content == request.OldName {
tag.Content = request.NewName
}
})
content := restore.Restore(nodes)
if err := s.Store.UpdateMemo(ctx, &store.UpdateMemo{
ID: memo.ID,
Content: &content,
}); err != nil {
return nil, status.Errorf(codes.Internal, "failed to update memo: %v", err)
}
}
// Delete old tag and create new tag.
if err := s.Store.DeleteTag(ctx, &store.DeleteTag{
CreatorID: user.ID,
Name: request.OldName,
}); err != nil {
return nil, status.Errorf(codes.Internal, "failed to delete tag: %v", err)
}
if _, err := s.Store.UpsertTag(ctx, &store.Tag{
CreatorID: user.ID,
Name: request.NewName,
}); err != nil {
return nil, status.Errorf(codes.Internal, "failed to upsert tag: %v", err)
}
return &emptypb.Empty{}, nil
}
func (s *APIV1Service) DeleteTag(ctx context.Context, request *v1pb.DeleteTagRequest) (*emptypb.Empty, error) {
userID, err := ExtractUserIDFromName(request.Tag.Creator)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid user name: %v", err)
}
user, err := s.Store.GetUser(ctx, &store.FindUser{
ID: &userID,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
}
if user == nil {
return nil, status.Errorf(codes.NotFound, "user not found")
}
if err := s.Store.DeleteTag(ctx, &store.DeleteTag{
Name: request.Tag.Name,
CreatorID: user.ID,
}); err != nil {
return nil, status.Errorf(codes.Internal, "failed to delete tag: %v", err)
}
return &emptypb.Empty{}, nil
}
func (s *APIV1Service) GetTagSuggestions(ctx context.Context, request *v1pb.GetTagSuggestionsRequest) (*v1pb.GetTagSuggestionsResponse, error) {
userID, err := ExtractUserIDFromName(request.User)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid user name: %v", err)
}
user, err := s.Store.GetUser(ctx, &store.FindUser{
ID: &userID,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
}
if user == nil {
return nil, status.Errorf(codes.NotFound, "user not found")
}
normalRowStatus := store.Normal
memoFind := &store.FindMemo{
CreatorID: &user.ID,
ContentSearch: []string{"#"},
RowStatus: &normalRowStatus,
}
memos, err := s.Store.ListMemos(ctx, memoFind)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list memos: %v", err)
}
tagList, err := s.Store.ListTags(ctx, &store.FindTag{
CreatorID: user.ID,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list tags: %v", err)
}
tagNameList := []string{}
for _, tag := range tagList {
tagNameList = append(tagNameList, tag.Name)
}
tagMapSet := make(map[string]bool)
for _, memo := range memos {
nodes, err := parser.Parse(tokenizer.Tokenize(memo.Content))
if err != nil {
return nil, errors.Wrap(err, "failed to parse memo content")
}
// Dynamically upsert tags from memo content.
TraverseASTNodes(nodes, func(node ast.Node) {
if tagNode, ok := node.(*ast.Tag); ok {
tag := tagNode.Content
if !slices.Contains(tagNameList, tag) {
tagMapSet[tag] = true
}
}
})
}
suggestions := []string{}
for tag := range tagMapSet {
suggestions = append(suggestions, tag)
}
sort.Strings(suggestions)
return &v1pb.GetTagSuggestionsResponse{
Tags: suggestions,
}, nil
}
func (s *APIV1Service) convertTagFromStore(ctx context.Context, tag *store.Tag) (*v1pb.Tag, error) {
user, err := s.Store.GetUser(ctx, &store.FindUser{
ID: &tag.CreatorID,
})
if err != nil {
return nil, errors.Wrap(err, "failed to get user")
}
return &v1pb.Tag{
Name: tag.Name,
Creator: fmt.Sprintf("%s%d", UserNamePrefix, user.ID),
}, nil
}
func TraverseASTNodes(nodes []ast.Node, fn func(ast.Node)) {
for _, node := range nodes {
fn(node)
switch n := node.(type) {
case *ast.Paragraph:
TraverseASTNodes(n.Children, fn)
case *ast.Heading:
TraverseASTNodes(n.Children, fn)
case *ast.Blockquote:
TraverseASTNodes(n.Children, fn)
case *ast.OrderedList:
TraverseASTNodes(n.Children, fn)
case *ast.UnorderedList:
TraverseASTNodes(n.Children, fn)
case *ast.TaskList:
TraverseASTNodes(n.Children, fn)
case *ast.Bold:
TraverseASTNodes(n.Children, fn)
}
}
}

View File

@@ -0,0 +1,620 @@
package v1
import (
"context"
"encoding/base64"
"fmt"
"net/http"
"regexp"
"slices"
"strings"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/google/cel-go/cel"
"github.com/labstack/echo/v4"
"github.com/pkg/errors"
"golang.org/x/crypto/bcrypt"
expr "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
"google.golang.org/genproto/googleapis/api/httpbody"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/emptypb"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/usememos/memos/internal/util"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/store"
)
func (s *APIV1Service) ListUsers(ctx context.Context, _ *v1pb.ListUsersRequest) (*v1pb.ListUsersResponse, error) {
currentUser, err := getCurrentUser(ctx, s.Store)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
}
if currentUser.Role != store.RoleHost && currentUser.Role != store.RoleAdmin {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
users, err := s.Store.ListUsers(ctx, &store.FindUser{})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list users: %v", err)
}
response := &v1pb.ListUsersResponse{
Users: []*v1pb.User{},
}
for _, user := range users {
response.Users = append(response.Users, convertUserFromStore(user))
}
return response, nil
}
func (s *APIV1Service) SearchUsers(ctx context.Context, request *v1pb.SearchUsersRequest) (*v1pb.SearchUsersResponse, error) {
if request.Filter == "" {
return nil, status.Errorf(codes.InvalidArgument, "filter is empty")
}
filter, err := parseSearchUsersFilter(request.Filter)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "failed to parse filter: %v", err)
}
userFind := &store.FindUser{}
if filter.Username != nil {
userFind.Username = filter.Username
}
if filter.Random {
userFind.Random = true
}
if filter.Limit != nil {
userFind.Limit = filter.Limit
}
users, err := s.Store.ListUsers(ctx, userFind)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to search users: %v", err)
}
response := &v1pb.SearchUsersResponse{
Users: []*v1pb.User{},
}
for _, user := range users {
response.Users = append(response.Users, convertUserFromStore(user))
}
return response, nil
}
func (s *APIV1Service) GetUser(ctx context.Context, request *v1pb.GetUserRequest) (*v1pb.User, error) {
userID, err := ExtractUserIDFromName(request.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid user name: %v", err)
}
user, err := s.Store.GetUser(ctx, &store.FindUser{
ID: &userID,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
}
if user == nil {
return nil, status.Errorf(codes.NotFound, "user not found")
}
return convertUserFromStore(user), nil
}
func (s *APIV1Service) GetUserAvatarBinary(ctx context.Context, request *v1pb.GetUserAvatarBinaryRequest) (*httpbody.HttpBody, error) {
userID, err := ExtractUserIDFromName(request.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid user name: %v", err)
}
user, err := s.Store.GetUser(ctx, &store.FindUser{
ID: &userID,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
}
if user == nil {
return nil, status.Errorf(codes.NotFound, "user not found")
}
if user.AvatarURL == "" {
return nil, status.Errorf(codes.NotFound, "avatar not found")
}
imageType, base64Data, err := extractImageInfo(user.AvatarURL)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to extract image info: %v", err)
}
imageData, err := base64.StdEncoding.DecodeString(base64Data)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to decode string: %v", err)
}
httpBody := &httpbody.HttpBody{
ContentType: imageType,
Data: imageData,
}
return httpBody, nil
}
func (s *APIV1Service) CreateUser(ctx context.Context, request *v1pb.CreateUserRequest) (*v1pb.User, error) {
currentUser, err := getCurrentUser(ctx, s.Store)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
}
if currentUser.Role != store.RoleHost {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
if !util.UIDMatcher.MatchString(strings.ToLower(request.User.Username)) {
return nil, status.Errorf(codes.InvalidArgument, "invalid username: %s", request.User.Username)
}
passwordHash, err := bcrypt.GenerateFromPassword([]byte(request.User.Password), bcrypt.DefaultCost)
if err != nil {
return nil, echo.NewHTTPError(http.StatusInternalServerError, "failed to generate password hash").SetInternal(err)
}
user, err := s.Store.CreateUser(ctx, &store.User{
Username: request.User.Username,
Role: convertUserRoleToStore(request.User.Role),
Email: request.User.Email,
Nickname: request.User.Nickname,
PasswordHash: string(passwordHash),
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to create user: %v", err)
}
return convertUserFromStore(user), nil
}
func (s *APIV1Service) UpdateUser(ctx context.Context, request *v1pb.UpdateUserRequest) (*v1pb.User, error) {
userID, err := ExtractUserIDFromName(request.User.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid user name: %v", err)
}
currentUser, err := getCurrentUser(ctx, s.Store)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
}
if currentUser.ID != userID && currentUser.Role != store.RoleAdmin && currentUser.Role != store.RoleHost {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
if request.UpdateMask == nil || len(request.UpdateMask.Paths) == 0 {
return nil, status.Errorf(codes.InvalidArgument, "update mask is empty")
}
user, err := s.Store.GetUser(ctx, &store.FindUser{ID: &userID})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
}
if user == nil {
return nil, status.Errorf(codes.NotFound, "user not found")
}
currentTs := time.Now().Unix()
update := &store.UpdateUser{
ID: user.ID,
UpdatedTs: &currentTs,
}
for _, field := range request.UpdateMask.Paths {
if field == "username" {
if !util.UIDMatcher.MatchString(strings.ToLower(request.User.Username)) {
return nil, status.Errorf(codes.InvalidArgument, "invalid username: %s", request.User.Username)
}
update.Username = &request.User.Username
} else if field == "nickname" {
update.Nickname = &request.User.Nickname
} else if field == "email" {
update.Email = &request.User.Email
} else if field == "avatar_url" {
update.AvatarURL = &request.User.AvatarUrl
} else if field == "description" {
update.Description = &request.User.Description
} else if field == "role" {
role := convertUserRoleToStore(request.User.Role)
update.Role = &role
} else if field == "password" {
passwordHash, err := bcrypt.GenerateFromPassword([]byte(request.User.Password), bcrypt.DefaultCost)
if err != nil {
return nil, echo.NewHTTPError(http.StatusInternalServerError, "failed to generate password hash").SetInternal(err)
}
passwordHashStr := string(passwordHash)
update.PasswordHash = &passwordHashStr
} else if field == "row_status" {
rowStatus := convertRowStatusToStore(request.User.RowStatus)
update.RowStatus = &rowStatus
} else {
return nil, status.Errorf(codes.InvalidArgument, "invalid update path: %s", field)
}
}
updatedUser, err := s.Store.UpdateUser(ctx, update)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to update user: %v", err)
}
return convertUserFromStore(updatedUser), nil
}
func (s *APIV1Service) DeleteUser(ctx context.Context, request *v1pb.DeleteUserRequest) (*emptypb.Empty, error) {
userID, err := ExtractUserIDFromName(request.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid user name: %v", err)
}
currentUser, err := getCurrentUser(ctx, s.Store)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
}
if currentUser.ID != userID && currentUser.Role != store.RoleAdmin && currentUser.Role != store.RoleHost {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
user, err := s.Store.GetUser(ctx, &store.FindUser{ID: &userID})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
}
if user == nil {
return nil, status.Errorf(codes.NotFound, "user not found")
}
if err := s.Store.DeleteUser(ctx, &store.DeleteUser{
ID: user.ID,
}); err != nil {
return nil, status.Errorf(codes.Internal, "failed to delete user: %v", err)
}
return &emptypb.Empty{}, nil
}
func getDefaultUserSetting() *v1pb.UserSetting {
return &v1pb.UserSetting{
Locale: "en",
Appearance: "system",
MemoVisibility: "PRIVATE",
}
}
func (s *APIV1Service) GetUserSetting(ctx context.Context, _ *v1pb.GetUserSettingRequest) (*v1pb.UserSetting, error) {
user, err := getCurrentUser(ctx, s.Store)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
}
userSettings, err := s.Store.ListUserSettings(ctx, &store.FindUserSetting{
UserID: &user.ID,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list user settings: %v", err)
}
userSettingMessage := getDefaultUserSetting()
for _, setting := range userSettings {
if setting.Key == storepb.UserSettingKey_USER_SETTING_LOCALE {
userSettingMessage.Locale = setting.GetLocale()
} else if setting.Key == storepb.UserSettingKey_USER_SETTING_APPEARANCE {
userSettingMessage.Appearance = setting.GetAppearance()
} else if setting.Key == storepb.UserSettingKey_USER_SETTING_MEMO_VISIBILITY {
userSettingMessage.MemoVisibility = setting.GetMemoVisibility()
}
}
return userSettingMessage, nil
}
func (s *APIV1Service) UpdateUserSetting(ctx context.Context, request *v1pb.UpdateUserSettingRequest) (*v1pb.UserSetting, error) {
user, err := getCurrentUser(ctx, s.Store)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
}
if request.UpdateMask == nil || len(request.UpdateMask.Paths) == 0 {
return nil, status.Errorf(codes.InvalidArgument, "update mask is empty")
}
for _, field := range request.UpdateMask.Paths {
if field == "locale" {
if _, err := s.Store.UpsertUserSetting(ctx, &storepb.UserSetting{
UserId: user.ID,
Key: storepb.UserSettingKey_USER_SETTING_LOCALE,
Value: &storepb.UserSetting_Locale{
Locale: request.Setting.Locale,
},
}); err != nil {
return nil, status.Errorf(codes.Internal, "failed to upsert user setting: %v", err)
}
} else if field == "appearance" {
if _, err := s.Store.UpsertUserSetting(ctx, &storepb.UserSetting{
UserId: user.ID,
Key: storepb.UserSettingKey_USER_SETTING_APPEARANCE,
Value: &storepb.UserSetting_Appearance{
Appearance: request.Setting.Appearance,
},
}); err != nil {
return nil, status.Errorf(codes.Internal, "failed to upsert user setting: %v", err)
}
} else if field == "memo_visibility" {
if _, err := s.Store.UpsertUserSetting(ctx, &storepb.UserSetting{
UserId: user.ID,
Key: storepb.UserSettingKey_USER_SETTING_MEMO_VISIBILITY,
Value: &storepb.UserSetting_MemoVisibility{
MemoVisibility: request.Setting.MemoVisibility,
},
}); err != nil {
return nil, status.Errorf(codes.Internal, "failed to upsert user setting: %v", err)
}
} else {
return nil, status.Errorf(codes.InvalidArgument, "invalid update path: %s", field)
}
}
return s.GetUserSetting(ctx, &v1pb.GetUserSettingRequest{})
}
func (s *APIV1Service) ListUserAccessTokens(ctx context.Context, _ *v1pb.ListUserAccessTokensRequest) (*v1pb.ListUserAccessTokensResponse, error) {
currentUser, err := getCurrentUser(ctx, s.Store)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
}
if currentUser == nil {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
userAccessTokens, err := s.Store.GetUserAccessTokens(ctx, currentUser.ID)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list access tokens: %v", err)
}
accessTokens := []*v1pb.UserAccessToken{}
for _, userAccessToken := range userAccessTokens {
claims := &ClaimsMessage{}
_, err := jwt.ParseWithClaims(userAccessToken.AccessToken, claims, func(t *jwt.Token) (any, error) {
if t.Method.Alg() != jwt.SigningMethodHS256.Name {
return nil, errors.Errorf("unexpected access token signing method=%v, expect %v", t.Header["alg"], jwt.SigningMethodHS256)
}
if kid, ok := t.Header["kid"].(string); ok {
if kid == "v1" {
return []byte(s.Secret), nil
}
}
return nil, errors.Errorf("unexpected access token kid=%v", t.Header["kid"])
})
if err != nil {
// If the access token is invalid or expired, just ignore it.
continue
}
userAccessToken := &v1pb.UserAccessToken{
AccessToken: userAccessToken.AccessToken,
Description: userAccessToken.Description,
IssuedAt: timestamppb.New(claims.IssuedAt.Time),
}
if claims.ExpiresAt != nil {
userAccessToken.ExpiresAt = timestamppb.New(claims.ExpiresAt.Time)
}
accessTokens = append(accessTokens, userAccessToken)
}
// Sort by issued time in descending order.
slices.SortFunc(accessTokens, func(i, j *v1pb.UserAccessToken) int {
return int(i.IssuedAt.Seconds - j.IssuedAt.Seconds)
})
response := &v1pb.ListUserAccessTokensResponse{
AccessTokens: accessTokens,
}
return response, nil
}
func (s *APIV1Service) CreateUserAccessToken(ctx context.Context, request *v1pb.CreateUserAccessTokenRequest) (*v1pb.UserAccessToken, error) {
user, err := getCurrentUser(ctx, s.Store)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
}
expiresAt := time.Time{}
if request.ExpiresAt != nil {
expiresAt = request.ExpiresAt.AsTime()
}
accessToken, err := GenerateAccessToken(user.Username, user.ID, expiresAt, []byte(s.Secret))
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to generate access token: %v", err)
}
claims := &ClaimsMessage{}
_, err = jwt.ParseWithClaims(accessToken, claims, func(t *jwt.Token) (any, error) {
if t.Method.Alg() != jwt.SigningMethodHS256.Name {
return nil, errors.Errorf("unexpected access token signing method=%v, expect %v", t.Header["alg"], jwt.SigningMethodHS256)
}
if kid, ok := t.Header["kid"].(string); ok {
if kid == "v1" {
return []byte(s.Secret), nil
}
}
return nil, errors.Errorf("unexpected access token kid=%v", t.Header["kid"])
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to parse access token: %v", err)
}
// Upsert the access token to user setting store.
if err := s.UpsertAccessTokenToStore(ctx, user, accessToken, request.Description); err != nil {
return nil, status.Errorf(codes.Internal, "failed to upsert access token to store: %v", err)
}
userAccessToken := &v1pb.UserAccessToken{
AccessToken: accessToken,
Description: request.Description,
IssuedAt: timestamppb.New(claims.IssuedAt.Time),
}
if claims.ExpiresAt != nil {
userAccessToken.ExpiresAt = timestamppb.New(claims.ExpiresAt.Time)
}
return userAccessToken, nil
}
func (s *APIV1Service) DeleteUserAccessToken(ctx context.Context, request *v1pb.DeleteUserAccessTokenRequest) (*emptypb.Empty, error) {
user, err := getCurrentUser(ctx, s.Store)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
}
userAccessTokens, err := s.Store.GetUserAccessTokens(ctx, user.ID)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list access tokens: %v", err)
}
updatedUserAccessTokens := []*storepb.AccessTokensUserSetting_AccessToken{}
for _, userAccessToken := range userAccessTokens {
if userAccessToken.AccessToken == request.AccessToken {
continue
}
updatedUserAccessTokens = append(updatedUserAccessTokens, userAccessToken)
}
if _, err := s.Store.UpsertUserSetting(ctx, &storepb.UserSetting{
UserId: user.ID,
Key: storepb.UserSettingKey_USER_SETTING_ACCESS_TOKENS,
Value: &storepb.UserSetting_AccessTokens{
AccessTokens: &storepb.AccessTokensUserSetting{
AccessTokens: updatedUserAccessTokens,
},
},
}); err != nil {
return nil, status.Errorf(codes.Internal, "failed to upsert user setting: %v", err)
}
return &emptypb.Empty{}, nil
}
func (s *APIV1Service) UpsertAccessTokenToStore(ctx context.Context, user *store.User, accessToken, description string) error {
userAccessTokens, err := s.Store.GetUserAccessTokens(ctx, user.ID)
if err != nil {
return errors.Wrap(err, "failed to get user access tokens")
}
userAccessToken := storepb.AccessTokensUserSetting_AccessToken{
AccessToken: accessToken,
Description: description,
}
userAccessTokens = append(userAccessTokens, &userAccessToken)
if _, err := s.Store.UpsertUserSetting(ctx, &storepb.UserSetting{
UserId: user.ID,
Key: storepb.UserSettingKey_USER_SETTING_ACCESS_TOKENS,
Value: &storepb.UserSetting_AccessTokens{
AccessTokens: &storepb.AccessTokensUserSetting{
AccessTokens: userAccessTokens,
},
},
}); err != nil {
return errors.Wrap(err, "failed to upsert user setting")
}
return nil
}
func convertUserFromStore(user *store.User) *v1pb.User {
userpb := &v1pb.User{
Name: fmt.Sprintf("%s%d", UserNamePrefix, user.ID),
Id: user.ID,
RowStatus: convertRowStatusFromStore(user.RowStatus),
CreateTime: timestamppb.New(time.Unix(user.CreatedTs, 0)),
UpdateTime: timestamppb.New(time.Unix(user.UpdatedTs, 0)),
Role: convertUserRoleFromStore(user.Role),
Username: user.Username,
Email: user.Email,
Nickname: user.Nickname,
AvatarUrl: user.AvatarURL,
Description: user.Description,
}
// Use the avatar URL instead of raw base64 image data to reduce the response size.
if user.AvatarURL != "" {
userpb.AvatarUrl = fmt.Sprintf("/o/%s/avatar", userpb.Name)
}
return userpb
}
func convertUserRoleFromStore(role store.Role) v1pb.User_Role {
switch role {
case store.RoleHost:
return v1pb.User_HOST
case store.RoleAdmin:
return v1pb.User_ADMIN
case store.RoleUser:
return v1pb.User_USER
default:
return v1pb.User_ROLE_UNSPECIFIED
}
}
func convertUserRoleToStore(role v1pb.User_Role) store.Role {
switch role {
case v1pb.User_HOST:
return store.RoleHost
case v1pb.User_ADMIN:
return store.RoleAdmin
case v1pb.User_USER:
return store.RoleUser
default:
return store.RoleUser
}
}
// SearchUsersFilterCELAttributes are the CEL attributes for SearchUsersFilter.
var SearchUsersFilterCELAttributes = []cel.EnvOption{
cel.Variable("username", cel.StringType),
cel.Variable("random", cel.BoolType),
cel.Variable("limit", cel.IntType),
}
type SearchUsersFilter struct {
Username *string
Random bool
Limit *int
}
func parseSearchUsersFilter(expression string) (*SearchUsersFilter, error) {
e, err := cel.NewEnv(SearchUsersFilterCELAttributes...)
if err != nil {
return nil, err
}
ast, issues := e.Compile(expression)
if issues != nil {
return nil, errors.Errorf("found issue %v", issues)
}
filter := &SearchUsersFilter{}
expr, err := cel.AstToParsedExpr(ast)
if err != nil {
return nil, err
}
callExpr := expr.GetExpr().GetCallExpr()
findSearchUsersField(callExpr, filter)
return filter, nil
}
func findSearchUsersField(callExpr *expr.Expr_Call, filter *SearchUsersFilter) {
if len(callExpr.Args) == 2 {
idExpr := callExpr.Args[0].GetIdentExpr()
if idExpr != nil {
if idExpr.Name == "username" {
username := callExpr.Args[1].GetConstExpr().GetStringValue()
filter.Username = &username
} else if idExpr.Name == "random" {
random := callExpr.Args[1].GetConstExpr().GetBoolValue()
filter.Random = random
} else if idExpr.Name == "limit" {
limit := int(callExpr.Args[1].GetConstExpr().GetInt64Value())
filter.Limit = &limit
}
return
}
}
for _, arg := range callExpr.Args {
callExpr := arg.GetCallExpr()
if callExpr != nil {
findSearchUsersField(callExpr, filter)
}
}
}
func extractImageInfo(dataURI string) (string, string, error) {
dataURIRegex := regexp.MustCompile(`^data:(?P<type>.+);base64,(?P<base64>.+)`)
matches := dataURIRegex.FindStringSubmatch(dataURI)
if len(matches) != 3 {
return "", "", errors.New("Invalid data URI format")
}
imageType := matches[1]
base64Data := matches[2]
return imageType, base64Data, nil
}

127
server/router/api/v1/v1.go Normal file
View File

@@ -0,0 +1,127 @@
package v1
import (
"context"
"fmt"
"github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
"github.com/improbable-eng/grpc-web/go/grpcweb"
"github.com/labstack/echo/v4"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/reflection"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
"github.com/usememos/memos/server/profile"
"github.com/usememos/memos/store"
)
type APIV1Service struct {
v1pb.UnimplementedWorkspaceServiceServer
v1pb.UnimplementedWorkspaceSettingServiceServer
v1pb.UnimplementedAuthServiceServer
v1pb.UnimplementedUserServiceServer
v1pb.UnimplementedMemoServiceServer
v1pb.UnimplementedResourceServiceServer
v1pb.UnimplementedTagServiceServer
v1pb.UnimplementedInboxServiceServer
v1pb.UnimplementedActivityServiceServer
v1pb.UnimplementedWebhookServiceServer
v1pb.UnimplementedMarkdownServiceServer
v1pb.UnimplementedIdentityProviderServiceServer
Secret string
Profile *profile.Profile
Store *store.Store
grpcServer *grpc.Server
}
func NewAPIV1Service(secret string, profile *profile.Profile, store *store.Store, grpcServer *grpc.Server) *APIV1Service {
grpc.EnableTracing = true
apiv1Service := &APIV1Service{
Secret: secret,
Profile: profile,
Store: store,
grpcServer: grpcServer,
}
v1pb.RegisterWorkspaceServiceServer(grpcServer, apiv1Service)
v1pb.RegisterWorkspaceSettingServiceServer(grpcServer, apiv1Service)
v1pb.RegisterAuthServiceServer(grpcServer, apiv1Service)
v1pb.RegisterUserServiceServer(grpcServer, apiv1Service)
v1pb.RegisterMemoServiceServer(grpcServer, apiv1Service)
v1pb.RegisterTagServiceServer(grpcServer, apiv1Service)
v1pb.RegisterResourceServiceServer(grpcServer, apiv1Service)
v1pb.RegisterInboxServiceServer(grpcServer, apiv1Service)
v1pb.RegisterActivityServiceServer(grpcServer, apiv1Service)
v1pb.RegisterWebhookServiceServer(grpcServer, apiv1Service)
v1pb.RegisterMarkdownServiceServer(grpcServer, apiv1Service)
v1pb.RegisterIdentityProviderServiceServer(grpcServer, apiv1Service)
reflection.Register(grpcServer)
return apiv1Service
}
// RegisterGateway registers the gRPC-Gateway with the given Echo instance.
func (s *APIV1Service) RegisterGateway(ctx context.Context, echoServer *echo.Echo) error {
// Create a client connection to the gRPC Server we just started.
// This is where the gRPC-Gateway proxies the requests.
conn, err := grpc.DialContext(
ctx,
fmt.Sprintf(":%d", s.Profile.Port),
grpc.WithTransportCredentials(insecure.NewCredentials()),
)
if err != nil {
return err
}
gwMux := runtime.NewServeMux()
if err := v1pb.RegisterWorkspaceServiceHandler(context.Background(), gwMux, conn); err != nil {
return err
}
if err := v1pb.RegisterWorkspaceSettingServiceHandler(context.Background(), gwMux, conn); err != nil {
return err
}
if err := v1pb.RegisterAuthServiceHandler(context.Background(), gwMux, conn); err != nil {
return err
}
if err := v1pb.RegisterUserServiceHandler(context.Background(), gwMux, conn); err != nil {
return err
}
if err := v1pb.RegisterMemoServiceHandler(context.Background(), gwMux, conn); err != nil {
return err
}
if err := v1pb.RegisterTagServiceHandler(context.Background(), gwMux, conn); err != nil {
return err
}
if err := v1pb.RegisterResourceServiceHandler(context.Background(), gwMux, conn); err != nil {
return err
}
if err := v1pb.RegisterInboxServiceHandler(context.Background(), gwMux, conn); err != nil {
return err
}
if err := v1pb.RegisterActivityServiceHandler(context.Background(), gwMux, conn); err != nil {
return err
}
if err := v1pb.RegisterWebhookServiceHandler(context.Background(), gwMux, conn); err != nil {
return err
}
if err := v1pb.RegisterMarkdownServiceHandler(context.Background(), gwMux, conn); err != nil {
return err
}
if err := v1pb.RegisterIdentityProviderServiceHandler(context.Background(), gwMux, conn); err != nil {
return err
}
echoServer.Any("*", echo.WrapHandler(gwMux))
// GRPC web proxy.
options := []grpcweb.Option{
grpcweb.WithCorsForRegisteredEndpointsOnly(false),
grpcweb.WithOriginFunc(func(origin string) bool {
return true
}),
}
wrappedGrpc := grpcweb.WrapServer(s.grpcServer, options...)
echoServer.Any("/memos.api.v1.*", echo.WrapHandler(wrappedGrpc))
return nil
}

View File

@@ -0,0 +1,114 @@
package v1
import (
"context"
"time"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/emptypb"
"google.golang.org/protobuf/types/known/timestamppb"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
"github.com/usememos/memos/store"
)
func (s *APIV1Service) CreateWebhook(ctx context.Context, request *v1pb.CreateWebhookRequest) (*v1pb.Webhook, error) {
currentUser, err := getCurrentUser(ctx, s.Store)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
}
webhook, err := s.Store.CreateWebhook(ctx, &store.Webhook{
CreatorID: currentUser.ID,
Name: request.Name,
URL: request.Url,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to create webhook, error: %+v", err)
}
return convertWebhookFromStore(webhook), nil
}
func (s *APIV1Service) ListWebhooks(ctx context.Context, request *v1pb.ListWebhooksRequest) (*v1pb.ListWebhooksResponse, error) {
webhooks, err := s.Store.ListWebhooks(ctx, &store.FindWebhook{
CreatorID: &request.CreatorId,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list webhooks, error: %+v", err)
}
response := &v1pb.ListWebhooksResponse{
Webhooks: []*v1pb.Webhook{},
}
for _, webhook := range webhooks {
response.Webhooks = append(response.Webhooks, convertWebhookFromStore(webhook))
}
return response, nil
}
func (s *APIV1Service) GetWebhook(ctx context.Context, request *v1pb.GetWebhookRequest) (*v1pb.Webhook, error) {
currentUser, err := getCurrentUser(ctx, s.Store)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
}
webhook, err := s.Store.GetWebhook(ctx, &store.FindWebhook{
ID: &request.Id,
CreatorID: &currentUser.ID,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get webhook, error: %+v", err)
}
if webhook == nil {
return nil, status.Errorf(codes.NotFound, "webhook not found")
}
return convertWebhookFromStore(webhook), nil
}
func (s *APIV1Service) UpdateWebhook(ctx context.Context, request *v1pb.UpdateWebhookRequest) (*v1pb.Webhook, error) {
if request.UpdateMask == nil || len(request.UpdateMask.Paths) == 0 {
return nil, status.Errorf(codes.InvalidArgument, "update_mask is required")
}
update := &store.UpdateWebhook{}
for _, field := range request.UpdateMask.Paths {
switch field {
case "row_status":
rowStatus := store.RowStatus(request.Webhook.RowStatus.String())
update.RowStatus = &rowStatus
case "name":
update.Name = &request.Webhook.Name
case "url":
update.URL = &request.Webhook.Url
}
}
webhook, err := s.Store.UpdateWebhook(ctx, update)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to update webhook, error: %+v", err)
}
return convertWebhookFromStore(webhook), nil
}
func (s *APIV1Service) DeleteWebhook(ctx context.Context, request *v1pb.DeleteWebhookRequest) (*emptypb.Empty, error) {
err := s.Store.DeleteWebhook(ctx, &store.DeleteWebhook{
ID: request.Id,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to delete webhook, error: %+v", err)
}
return &emptypb.Empty{}, nil
}
func convertWebhookFromStore(webhook *store.Webhook) *v1pb.Webhook {
return &v1pb.Webhook{
Id: webhook.ID,
CreatedTime: timestamppb.New(time.Unix(webhook.CreatedTs, 0)),
UpdatedTime: timestamppb.New(time.Unix(webhook.UpdatedTs, 0)),
RowStatus: convertRowStatusFromStore(webhook.RowStatus),
CreatorId: webhook.CreatorID,
Name: webhook.Name,
Url: webhook.URL,
}
}

View File

@@ -0,0 +1,49 @@
package v1
import (
"context"
"github.com/pkg/errors"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
"github.com/usememos/memos/store"
)
var ownerCache *v1pb.User
func (s *APIV1Service) GetWorkspaceProfile(ctx context.Context, _ *v1pb.GetWorkspaceProfileRequest) (*v1pb.WorkspaceProfile, error) {
workspaceProfile := &v1pb.WorkspaceProfile{
Version: s.Profile.Version,
Mode: s.Profile.Mode,
}
owner, err := s.GetInstanceOwner(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get instance owner: %v", err)
}
if owner != nil {
workspaceProfile.Owner = owner.Name
}
return workspaceProfile, nil
}
func (s *APIV1Service) GetInstanceOwner(ctx context.Context) (*v1pb.User, error) {
if ownerCache != nil {
return ownerCache, nil
}
hostUserType := store.RoleHost
user, err := s.Store.GetUser(ctx, &store.FindUser{
Role: &hostUserType,
})
if err != nil {
return nil, errors.Wrapf(err, "failed to find owner")
}
if user == nil {
return nil, nil
}
ownerCache = convertUserFromStore(user)
return ownerCache, nil
}

View File

@@ -0,0 +1,225 @@
package v1
import (
"context"
"fmt"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/store"
)
func (s *APIV1Service) ListWorkspaceSettings(ctx context.Context, _ *v1pb.ListWorkspaceSettingsRequest) (*v1pb.ListWorkspaceSettingsResponse, error) {
workspaceSettings, err := s.Store.ListWorkspaceSettings(ctx, &store.FindWorkspaceSetting{})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get workspace setting: %v", err)
}
response := &v1pb.ListWorkspaceSettingsResponse{
Settings: []*v1pb.WorkspaceSetting{},
}
for _, workspaceSetting := range workspaceSettings {
if workspaceSetting.Key == storepb.WorkspaceSettingKey_WORKSPACE_SETTING_BASIC {
continue
}
response.Settings = append(response.Settings, convertWorkspaceSettingFromStore(workspaceSetting))
}
return response, nil
}
func (s *APIV1Service) GetWorkspaceSetting(ctx context.Context, request *v1pb.GetWorkspaceSettingRequest) (*v1pb.WorkspaceSetting, error) {
settingKeyString, err := ExtractWorkspaceSettingKeyFromName(request.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid workspace setting name: %v", err)
}
settingKey := storepb.WorkspaceSettingKey(storepb.WorkspaceSettingKey_value[settingKeyString])
workspaceSetting, err := s.Store.GetWorkspaceSetting(ctx, &store.FindWorkspaceSetting{
Name: settingKey.String(),
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get workspace setting: %v", err)
}
if workspaceSetting == nil {
return nil, status.Errorf(codes.NotFound, "workspace setting not found")
}
return convertWorkspaceSettingFromStore(workspaceSetting), nil
}
func (s *APIV1Service) SetWorkspaceSetting(ctx context.Context, request *v1pb.SetWorkspaceSettingRequest) (*v1pb.WorkspaceSetting, error) {
if s.Profile.Mode == "demo" {
return nil, status.Errorf(codes.InvalidArgument, "setting workspace setting is not allowed in demo mode")
}
user, err := getCurrentUser(ctx, s.Store)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
}
if user.Role != store.RoleHost {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
workspaceSetting, err := s.Store.UpsertWorkspaceSetting(ctx, convertWorkspaceSettingToStore(request.Setting))
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to upsert workspace setting: %v", err)
}
return convertWorkspaceSettingFromStore(workspaceSetting), nil
}
func convertWorkspaceSettingFromStore(setting *storepb.WorkspaceSetting) *v1pb.WorkspaceSetting {
workspaceSetting := &v1pb.WorkspaceSetting{
Name: fmt.Sprintf("%s%s", WorkspaceSettingNamePrefix, setting.Key.String()),
}
switch setting.Value.(type) {
case *storepb.WorkspaceSetting_GeneralSetting:
workspaceSetting.Value = &v1pb.WorkspaceSetting_GeneralSetting{
GeneralSetting: convertWorkspaceGeneralSettingFromStore(setting.GetGeneralSetting()),
}
case *storepb.WorkspaceSetting_StorageSetting:
workspaceSetting.Value = &v1pb.WorkspaceSetting_StorageSetting{
StorageSetting: convertWorkspaceStorageSettingFromStore(setting.GetStorageSetting()),
}
case *storepb.WorkspaceSetting_MemoRelatedSetting:
workspaceSetting.Value = &v1pb.WorkspaceSetting_MemoRelatedSetting{
MemoRelatedSetting: convertWorkspaceMemoRelatedSettingFromStore(setting.GetMemoRelatedSetting()),
}
}
return workspaceSetting
}
func convertWorkspaceSettingToStore(setting *v1pb.WorkspaceSetting) *storepb.WorkspaceSetting {
settingKeyString, _ := ExtractWorkspaceSettingKeyFromName(setting.Name)
workspaceSetting := &storepb.WorkspaceSetting{
Key: storepb.WorkspaceSettingKey(storepb.WorkspaceSettingKey_value[settingKeyString]),
Value: &storepb.WorkspaceSetting_GeneralSetting{
GeneralSetting: convertWorkspaceGeneralSettingToStore(setting.GetGeneralSetting()),
},
}
switch workspaceSetting.Key {
case storepb.WorkspaceSettingKey_WORKSPACE_SETTING_GENERAL:
workspaceSetting.Value = &storepb.WorkspaceSetting_GeneralSetting{
GeneralSetting: convertWorkspaceGeneralSettingToStore(setting.GetGeneralSetting()),
}
case storepb.WorkspaceSettingKey_WORKSPACE_SETTING_STORAGE:
workspaceSetting.Value = &storepb.WorkspaceSetting_StorageSetting{
StorageSetting: convertWorkspaceStorageSettingToStore(setting.GetStorageSetting()),
}
case storepb.WorkspaceSettingKey_WORKSPACE_SETTING_MEMO_RELATED:
workspaceSetting.Value = &storepb.WorkspaceSetting_MemoRelatedSetting{
MemoRelatedSetting: convertWorkspaceMemoRelatedSettingToStore(setting.GetMemoRelatedSetting()),
}
}
return workspaceSetting
}
func convertWorkspaceGeneralSettingFromStore(setting *storepb.WorkspaceGeneralSetting) *v1pb.WorkspaceGeneralSetting {
if setting == nil {
return nil
}
generalSetting := &v1pb.WorkspaceGeneralSetting{
InstanceUrl: setting.InstanceUrl,
DisallowSignup: setting.DisallowSignup,
DisallowPasswordLogin: setting.DisallowPasswordLogin,
AdditionalScript: setting.AdditionalScript,
AdditionalStyle: setting.AdditionalStyle,
}
if setting.CustomProfile != nil {
generalSetting.CustomProfile = &v1pb.WorkspaceCustomProfile{
Title: setting.CustomProfile.Title,
Description: setting.CustomProfile.Description,
LogoUrl: setting.CustomProfile.LogoUrl,
Locale: setting.CustomProfile.Locale,
Appearance: setting.CustomProfile.Appearance,
}
}
return generalSetting
}
func convertWorkspaceGeneralSettingToStore(setting *v1pb.WorkspaceGeneralSetting) *storepb.WorkspaceGeneralSetting {
if setting == nil {
return nil
}
generalSetting := &storepb.WorkspaceGeneralSetting{
InstanceUrl: setting.InstanceUrl,
DisallowSignup: setting.DisallowSignup,
DisallowPasswordLogin: setting.DisallowPasswordLogin,
AdditionalScript: setting.AdditionalScript,
AdditionalStyle: setting.AdditionalStyle,
}
if setting.CustomProfile != nil {
generalSetting.CustomProfile = &storepb.WorkspaceCustomProfile{
Title: setting.CustomProfile.Title,
Description: setting.CustomProfile.Description,
LogoUrl: setting.CustomProfile.LogoUrl,
Locale: setting.CustomProfile.Locale,
Appearance: setting.CustomProfile.Appearance,
}
}
return generalSetting
}
func convertWorkspaceStorageSettingFromStore(settingpb *storepb.WorkspaceStorageSetting) *v1pb.WorkspaceStorageSetting {
if settingpb == nil {
return nil
}
setting := &v1pb.WorkspaceStorageSetting{
StorageType: v1pb.WorkspaceStorageSetting_StorageType(settingpb.StorageType),
FilepathTemplate: settingpb.FilepathTemplate,
UploadSizeLimitMb: settingpb.UploadSizeLimitMb,
}
if settingpb.S3Config != nil {
setting.S3Config = &v1pb.WorkspaceStorageSetting_S3Config{
AccessKeyId: settingpb.S3Config.AccessKeyId,
AccessKeySecret: settingpb.S3Config.AccessKeySecret,
Endpoint: settingpb.S3Config.Endpoint,
Region: settingpb.S3Config.Region,
Bucket: settingpb.S3Config.Bucket,
}
}
return setting
}
func convertWorkspaceStorageSettingToStore(setting *v1pb.WorkspaceStorageSetting) *storepb.WorkspaceStorageSetting {
if setting == nil {
return nil
}
settingpb := &storepb.WorkspaceStorageSetting{
StorageType: storepb.WorkspaceStorageSetting_StorageType(setting.StorageType),
FilepathTemplate: setting.FilepathTemplate,
UploadSizeLimitMb: setting.UploadSizeLimitMb,
}
if setting.S3Config != nil {
settingpb.S3Config = &storepb.WorkspaceStorageSetting_S3Config{
AccessKeyId: setting.S3Config.AccessKeyId,
AccessKeySecret: setting.S3Config.AccessKeySecret,
Endpoint: setting.S3Config.Endpoint,
Region: setting.S3Config.Region,
Bucket: setting.S3Config.Bucket,
}
}
return settingpb
}
func convertWorkspaceMemoRelatedSettingFromStore(setting *storepb.WorkspaceMemoRelatedSetting) *v1pb.WorkspaceMemoRelatedSetting {
if setting == nil {
return nil
}
return &v1pb.WorkspaceMemoRelatedSetting{
DisallowPublicVisible: setting.DisallowPublicVisible,
DisplayWithUpdateTime: setting.DisplayWithUpdateTime,
}
}
func convertWorkspaceMemoRelatedSettingToStore(setting *v1pb.WorkspaceMemoRelatedSetting) *storepb.WorkspaceMemoRelatedSetting {
if setting == nil {
return nil
}
return &storepb.WorkspaceMemoRelatedSetting{
DisallowPublicVisible: setting.DisallowPublicVisible,
DisplayWithUpdateTime: setting.DisplayWithUpdateTime,
}
}