mirror of
https://github.com/usememos/memos.git
synced 2025-06-05 22:09:59 +02:00
chore: rename router package
This commit is contained in:
161
server/router/api/v1/acl.go
Normal file
161
server/router/api/v1/acl.go
Normal file
@@ -0,0 +1,161 @@
|
||||
package v1
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/pkg/errors"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/metadata"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/usememos/memos/internal/util"
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
// ContextKey is the key type of context value.
|
||||
type ContextKey int
|
||||
|
||||
const (
|
||||
// The key name used to store username in the context
|
||||
// user id is extracted from the jwt token subject field.
|
||||
usernameContextKey ContextKey = iota
|
||||
)
|
||||
|
||||
// GRPCAuthInterceptor is the auth interceptor for gRPC server.
|
||||
type GRPCAuthInterceptor struct {
|
||||
Store *store.Store
|
||||
secret string
|
||||
}
|
||||
|
||||
// NewGRPCAuthInterceptor returns a new API auth interceptor.
|
||||
func NewGRPCAuthInterceptor(store *store.Store, secret string) *GRPCAuthInterceptor {
|
||||
return &GRPCAuthInterceptor{
|
||||
Store: store,
|
||||
secret: secret,
|
||||
}
|
||||
}
|
||||
|
||||
// AuthenticationInterceptor is the unary interceptor for gRPC API.
|
||||
func (in *GRPCAuthInterceptor) AuthenticationInterceptor(ctx context.Context, request any, serverInfo *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
|
||||
md, ok := metadata.FromIncomingContext(ctx)
|
||||
if !ok {
|
||||
return nil, status.Errorf(codes.Unauthenticated, "failed to parse metadata from incoming context")
|
||||
}
|
||||
accessToken, err := getTokenFromMetadata(md)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Unauthenticated, err.Error())
|
||||
}
|
||||
|
||||
username, err := in.authenticate(ctx, accessToken)
|
||||
if err != nil {
|
||||
if isUnauthorizeAllowedMethod(serverInfo.FullMethod) {
|
||||
return handler(ctx, request)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
user, err := in.Store.GetUser(ctx, &store.FindUser{
|
||||
Username: &username,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to get user")
|
||||
}
|
||||
if user == nil {
|
||||
return nil, errors.Errorf("user %q not exists", username)
|
||||
}
|
||||
if user.RowStatus == store.Archived {
|
||||
return nil, errors.Errorf("user %q is archived", username)
|
||||
}
|
||||
if isOnlyForAdminAllowedMethod(serverInfo.FullMethod) && user.Role != store.RoleHost && user.Role != store.RoleAdmin {
|
||||
return nil, errors.Errorf("user %q is not admin", username)
|
||||
}
|
||||
|
||||
// Stores userID into context.
|
||||
childCtx := context.WithValue(ctx, usernameContextKey, username)
|
||||
return handler(childCtx, request)
|
||||
}
|
||||
|
||||
func (in *GRPCAuthInterceptor) authenticate(ctx context.Context, accessToken string) (string, error) {
|
||||
if accessToken == "" {
|
||||
return "", status.Errorf(codes.Unauthenticated, "access token not found")
|
||||
}
|
||||
claims := &ClaimsMessage{}
|
||||
_, err := jwt.ParseWithClaims(accessToken, claims, func(t *jwt.Token) (any, error) {
|
||||
if t.Method.Alg() != jwt.SigningMethodHS256.Name {
|
||||
return nil, status.Errorf(codes.Unauthenticated, "unexpected access token signing method=%v, expect %v", t.Header["alg"], jwt.SigningMethodHS256)
|
||||
}
|
||||
if kid, ok := t.Header["kid"].(string); ok {
|
||||
if kid == "v1" {
|
||||
return []byte(in.secret), nil
|
||||
}
|
||||
}
|
||||
return nil, status.Errorf(codes.Unauthenticated, "unexpected access token kid=%v", t.Header["kid"])
|
||||
})
|
||||
if err != nil {
|
||||
return "", status.Errorf(codes.Unauthenticated, "Invalid or expired access token")
|
||||
}
|
||||
|
||||
// We either have a valid access token or we will attempt to generate new access token.
|
||||
userID, err := util.ConvertStringToInt32(claims.Subject)
|
||||
if err != nil {
|
||||
return "", errors.Wrap(err, "malformed ID in the token")
|
||||
}
|
||||
user, err := in.Store.GetUser(ctx, &store.FindUser{
|
||||
ID: &userID,
|
||||
})
|
||||
if err != nil {
|
||||
return "", errors.Wrap(err, "failed to get user")
|
||||
}
|
||||
if user == nil {
|
||||
return "", errors.Errorf("user %q not exists", userID)
|
||||
}
|
||||
if user.RowStatus == store.Archived {
|
||||
return "", errors.Errorf("user %q is archived", userID)
|
||||
}
|
||||
|
||||
accessTokens, err := in.Store.GetUserAccessTokens(ctx, user.ID)
|
||||
if err != nil {
|
||||
return "", errors.Wrapf(err, "failed to get user access tokens")
|
||||
}
|
||||
if !validateAccessToken(accessToken, accessTokens) {
|
||||
return "", status.Errorf(codes.Unauthenticated, "invalid access token")
|
||||
}
|
||||
|
||||
return user.Username, nil
|
||||
}
|
||||
|
||||
func getTokenFromMetadata(md metadata.MD) (string, error) {
|
||||
// Check the HTTP request header first.
|
||||
authorizationHeaders := md.Get("Authorization")
|
||||
if len(md.Get("Authorization")) > 0 {
|
||||
authHeaderParts := strings.Fields(authorizationHeaders[0])
|
||||
if len(authHeaderParts) != 2 || strings.ToLower(authHeaderParts[0]) != "bearer" {
|
||||
return "", errors.New("authorization header format must be Bearer {token}")
|
||||
}
|
||||
return authHeaderParts[1], nil
|
||||
}
|
||||
// Check the cookie header.
|
||||
var accessToken string
|
||||
for _, t := range append(md.Get("grpcgateway-cookie"), md.Get("cookie")...) {
|
||||
header := http.Header{}
|
||||
header.Add("Cookie", t)
|
||||
request := http.Request{Header: header}
|
||||
if v, _ := request.Cookie(AccessTokenCookieName); v != nil {
|
||||
accessToken = v.Value
|
||||
}
|
||||
}
|
||||
return accessToken, nil
|
||||
}
|
||||
|
||||
func validateAccessToken(accessTokenString string, userAccessTokens []*storepb.AccessTokensUserSetting_AccessToken) bool {
|
||||
for _, userAccessToken := range userAccessTokens {
|
||||
if accessTokenString == userAccessToken.AccessToken {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
38
server/router/api/v1/acl_config.go
Normal file
38
server/router/api/v1/acl_config.go
Normal file
@@ -0,0 +1,38 @@
|
||||
package v1
|
||||
|
||||
var authenticationAllowlistMethods = map[string]bool{
|
||||
"/memos.api.v1.WorkspaceService/GetWorkspaceProfile": true,
|
||||
"/memos.api.v1.WorkspaceSettingService/GetWorkspaceSetting": true,
|
||||
"/memos.api.v1.WorkspaceSettingService/ListWorkspaceSettings": true,
|
||||
"/memos.api.v1.IdentityProviderService/ListIdentityProviders": true,
|
||||
"/memos.api.v1.IdentityProviderService/GetIdentityProvider": true,
|
||||
"/memos.api.v1.AuthService/GetAuthStatus": true,
|
||||
"/memos.api.v1.AuthService/SignIn": true,
|
||||
"/memos.api.v1.AuthService/SignInWithSSO": true,
|
||||
"/memos.api.v1.AuthService/SignOut": true,
|
||||
"/memos.api.v1.AuthService/SignUp": true,
|
||||
"/memos.api.v1.UserService/GetUser": true,
|
||||
"/memos.api.v1.UserService/GetUserAvatar": true,
|
||||
"/memos.api.v1.UserService/SearchUsers": true,
|
||||
"/memos.api.v1.MemoService/ListMemos": true,
|
||||
"/memos.api.v1.MemoService/GetMemo": true,
|
||||
"/memos.api.v1.MemoService/SearchMemos": true,
|
||||
"/memos.api.v1.MemoService/ListMemoResources": true,
|
||||
"/memos.api.v1.MemoService/ListMemoRelations": true,
|
||||
"/memos.api.v1.MemoService/ListMemoComments": true,
|
||||
"/memos.api.v1.LinkService/GetLinkMetadata": true,
|
||||
}
|
||||
|
||||
// isUnauthorizeAllowedMethod returns whether the method is exempted from authentication.
|
||||
func isUnauthorizeAllowedMethod(fullMethodName string) bool {
|
||||
return authenticationAllowlistMethods[fullMethodName]
|
||||
}
|
||||
|
||||
var allowedMethodsOnlyForAdmin = map[string]bool{
|
||||
"/memos.api.v1.UserService/CreateUser": true,
|
||||
}
|
||||
|
||||
// isOnlyForAdminAllowedMethod returns true if the method is allowed to be called only by admin.
|
||||
func isOnlyForAdminAllowedMethod(methodName string) bool {
|
||||
return allowedMethodsOnlyForAdmin[methodName]
|
||||
}
|
56
server/router/api/v1/activity_service.go
Normal file
56
server/router/api/v1/activity_service.go
Normal file
@@ -0,0 +1,56 @@
|
||||
package v1
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
v1pb "github.com/usememos/memos/proto/gen/api/v1"
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (s *APIV1Service) GetActivity(ctx context.Context, request *v1pb.GetActivityRequest) (*v1pb.Activity, error) {
|
||||
activity, err := s.Store.GetActivity(ctx, &store.FindActivity{
|
||||
ID: &request.Id,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get activity: %v", err)
|
||||
}
|
||||
|
||||
activityMessage, err := s.convertActivityFromStore(ctx, activity)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to convert activity from store: %v", err)
|
||||
}
|
||||
return activityMessage, nil
|
||||
}
|
||||
|
||||
func (*APIV1Service) convertActivityFromStore(_ context.Context, activity *store.Activity) (*v1pb.Activity, error) {
|
||||
return &v1pb.Activity{
|
||||
Id: activity.ID,
|
||||
CreatorId: activity.CreatorID,
|
||||
Type: activity.Type.String(),
|
||||
Level: activity.Level.String(),
|
||||
CreateTime: timestamppb.New(time.Unix(activity.CreatedTs, 0)),
|
||||
Payload: convertActivityPayloadFromStore(activity.Payload),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func convertActivityPayloadFromStore(payload *storepb.ActivityPayload) *v1pb.ActivityPayload {
|
||||
v2Payload := &v1pb.ActivityPayload{}
|
||||
if payload.MemoComment != nil {
|
||||
v2Payload.MemoComment = &v1pb.ActivityMemoCommentPayload{
|
||||
MemoId: payload.MemoComment.MemoId,
|
||||
RelatedMemoId: payload.MemoComment.RelatedMemoId,
|
||||
}
|
||||
}
|
||||
if payload.VersionUpdate != nil {
|
||||
v2Payload.VersionUpdate = &v1pb.ActivityVersionUpdatePayload{
|
||||
Version: payload.VersionUpdate.Version,
|
||||
}
|
||||
}
|
||||
return v2Payload
|
||||
}
|
63
server/router/api/v1/auth.go
Normal file
63
server/router/api/v1/auth.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package v1
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
)
|
||||
|
||||
const (
|
||||
// issuer is the issuer of the jwt token.
|
||||
Issuer = "memos"
|
||||
// Signing key section. For now, this is only used for signing, not for verifying since we only
|
||||
// have 1 version. But it will be used to maintain backward compatibility if we change the signing mechanism.
|
||||
KeyID = "v1"
|
||||
// AccessTokenAudienceName is the audience name of the access token.
|
||||
AccessTokenAudienceName = "user.access-token"
|
||||
AccessTokenDuration = 7 * 24 * time.Hour
|
||||
|
||||
// CookieExpDuration expires slightly earlier than the jwt expiration. Client would be logged out if the user
|
||||
// cookie expires, thus the client would always logout first before attempting to make a request with the expired jwt.
|
||||
CookieExpDuration = AccessTokenDuration - 1*time.Minute
|
||||
// AccessTokenCookieName is the cookie name of access token.
|
||||
AccessTokenCookieName = "memos.access-token"
|
||||
)
|
||||
|
||||
type ClaimsMessage struct {
|
||||
Name string `json:"name"`
|
||||
jwt.RegisteredClaims
|
||||
}
|
||||
|
||||
// GenerateAccessToken generates an access token.
|
||||
func GenerateAccessToken(username string, userID int32, expirationTime time.Time, secret []byte) (string, error) {
|
||||
return generateToken(username, userID, AccessTokenAudienceName, expirationTime, secret)
|
||||
}
|
||||
|
||||
// generateToken generates a jwt token.
|
||||
func generateToken(username string, userID int32, audience string, expirationTime time.Time, secret []byte) (string, error) {
|
||||
registeredClaims := jwt.RegisteredClaims{
|
||||
Issuer: Issuer,
|
||||
Audience: jwt.ClaimStrings{audience},
|
||||
IssuedAt: jwt.NewNumericDate(time.Now()),
|
||||
Subject: fmt.Sprint(userID),
|
||||
}
|
||||
if !expirationTime.IsZero() {
|
||||
registeredClaims.ExpiresAt = jwt.NewNumericDate(expirationTime)
|
||||
}
|
||||
|
||||
// Declare the token with the HS256 algorithm used for signing, and the claims.
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, &ClaimsMessage{
|
||||
Name: username,
|
||||
RegisteredClaims: registeredClaims,
|
||||
})
|
||||
token.Header["kid"] = KeyID
|
||||
|
||||
// Create the JWT string.
|
||||
tokenString, err := token.SignedString(secret)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return tokenString, nil
|
||||
}
|
264
server/router/api/v1/auth_service.go
Normal file
264
server/router/api/v1/auth_service.go
Normal file
@@ -0,0 +1,264 @@
|
||||
package v1
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/metadata"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/emptypb"
|
||||
|
||||
"github.com/usememos/memos/internal/util"
|
||||
"github.com/usememos/memos/plugin/idp"
|
||||
"github.com/usememos/memos/plugin/idp/oauth2"
|
||||
v1pb "github.com/usememos/memos/proto/gen/api/v1"
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (s *APIV1Service) GetAuthStatus(ctx context.Context, _ *v1pb.GetAuthStatusRequest) (*v1pb.User, error) {
|
||||
user, err := getCurrentUser(ctx, s.Store)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Unauthenticated, "failed to get current user: %v", err)
|
||||
}
|
||||
if user == nil {
|
||||
// Set the cookie header to expire access token.
|
||||
if err := s.clearAccessTokenCookie(ctx); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to set grpc header: %v", err)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unauthenticated, "user not found")
|
||||
}
|
||||
return convertUserFromStore(user), nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) SignIn(ctx context.Context, request *v1pb.SignInRequest) (*v1pb.User, error) {
|
||||
user, err := s.Store.GetUser(ctx, &store.FindUser{
|
||||
Username: &request.Username,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, fmt.Sprintf("failed to find user by username %s", request.Username))
|
||||
}
|
||||
if user == nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, fmt.Sprintf("user not found with username %s", request.Username))
|
||||
} else if user.RowStatus == store.Archived {
|
||||
return nil, status.Errorf(codes.PermissionDenied, fmt.Sprintf("user has been archived with username %s", request.Username))
|
||||
}
|
||||
|
||||
// Compare the stored hashed password, with the hashed version of the password that was received.
|
||||
if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(request.Password)); err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "unmatched email and password")
|
||||
}
|
||||
|
||||
expireTime := time.Now().Add(AccessTokenDuration)
|
||||
if request.NeverExpire {
|
||||
// Set the expire time to 100 years.
|
||||
expireTime = time.Now().Add(100 * 365 * 24 * time.Hour)
|
||||
}
|
||||
if err := s.doSignIn(ctx, user, expireTime); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, fmt.Sprintf("failed to sign in, err: %s", err))
|
||||
}
|
||||
return convertUserFromStore(user), nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) SignInWithSSO(ctx context.Context, request *v1pb.SignInWithSSORequest) (*v1pb.User, error) {
|
||||
identityProvider, err := s.Store.GetIdentityProvider(ctx, &store.FindIdentityProvider{
|
||||
ID: &request.IdpId,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, fmt.Sprintf("failed to get identity provider, err: %s", err))
|
||||
}
|
||||
if identityProvider == nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, fmt.Sprintf("identity provider not found with id %d", request.IdpId))
|
||||
}
|
||||
|
||||
var userInfo *idp.IdentityProviderUserInfo
|
||||
if identityProvider.Type == storepb.IdentityProvider_OAUTH2 {
|
||||
oauth2IdentityProvider, err := oauth2.NewIdentityProvider(identityProvider.Config.GetOauth2Config())
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, fmt.Sprintf("failed to create oauth2 identity provider, err: %s", err))
|
||||
}
|
||||
token, err := oauth2IdentityProvider.ExchangeToken(ctx, request.RedirectUri, request.Code)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, fmt.Sprintf("failed to exchange token, err: %s", err))
|
||||
}
|
||||
userInfo, err = oauth2IdentityProvider.UserInfo(token)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, fmt.Sprintf("failed to get user info, err: %s", err))
|
||||
}
|
||||
}
|
||||
|
||||
identifierFilter := identityProvider.IdentifierFilter
|
||||
if identifierFilter != "" {
|
||||
identifierFilterRegex, err := regexp.Compile(identifierFilter)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, fmt.Sprintf("failed to compile identifier filter regex, err: %s", err))
|
||||
}
|
||||
if !identifierFilterRegex.MatchString(userInfo.Identifier) {
|
||||
return nil, status.Errorf(codes.PermissionDenied, fmt.Sprintf("identifier %s is not allowed", userInfo.Identifier))
|
||||
}
|
||||
}
|
||||
|
||||
user, err := s.Store.GetUser(ctx, &store.FindUser{
|
||||
Username: &userInfo.Identifier,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, fmt.Sprintf("failed to find user by username %s", userInfo.Identifier))
|
||||
}
|
||||
if user == nil {
|
||||
userCreate := &store.User{
|
||||
Username: userInfo.Identifier,
|
||||
// The new signup user should be normal user by default.
|
||||
Role: store.RoleUser,
|
||||
Nickname: userInfo.DisplayName,
|
||||
Email: userInfo.Email,
|
||||
}
|
||||
password, err := util.RandomString(20)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, fmt.Sprintf("failed to generate random password, err: %s", err))
|
||||
}
|
||||
passwordHash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, fmt.Sprintf("failed to generate password hash, err: %s", err))
|
||||
}
|
||||
userCreate.PasswordHash = string(passwordHash)
|
||||
user, err = s.Store.CreateUser(ctx, userCreate)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, fmt.Sprintf("failed to create user, err: %s", err))
|
||||
}
|
||||
}
|
||||
if user.RowStatus == store.Archived {
|
||||
return nil, status.Errorf(codes.PermissionDenied, fmt.Sprintf("user has been archived with username %s", userInfo.Identifier))
|
||||
}
|
||||
|
||||
if err := s.doSignIn(ctx, user, time.Now().Add(AccessTokenDuration)); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, fmt.Sprintf("failed to sign in, err: %s", err))
|
||||
}
|
||||
return convertUserFromStore(user), nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) doSignIn(ctx context.Context, user *store.User, expireTime time.Time) error {
|
||||
accessToken, err := GenerateAccessToken(user.Email, user.ID, expireTime, []byte(s.Secret))
|
||||
if err != nil {
|
||||
return status.Errorf(codes.Internal, fmt.Sprintf("failed to generate tokens, err: %s", err))
|
||||
}
|
||||
if err := s.UpsertAccessTokenToStore(ctx, user, accessToken, "user login"); err != nil {
|
||||
return status.Errorf(codes.Internal, fmt.Sprintf("failed to upsert access token to store, err: %s", err))
|
||||
}
|
||||
|
||||
cookie, err := s.buildAccessTokenCookie(ctx, accessToken, expireTime)
|
||||
if err != nil {
|
||||
return status.Errorf(codes.Internal, fmt.Sprintf("failed to build access token cookie, err: %s", err))
|
||||
}
|
||||
if err := grpc.SetHeader(ctx, metadata.New(map[string]string{
|
||||
"Set-Cookie": cookie,
|
||||
})); err != nil {
|
||||
return status.Errorf(codes.Internal, "failed to set grpc header, error: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) SignUp(ctx context.Context, request *v1pb.SignUpRequest) (*v1pb.User, error) {
|
||||
workspaceGeneralSetting, err := s.Store.GetWorkspaceGeneralSetting(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, fmt.Sprintf("failed to get workspace setting, err: %s", err))
|
||||
}
|
||||
if workspaceGeneralSetting.DisallowSignup || workspaceGeneralSetting.DisallowPasswordLogin {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "sign up is not allowed")
|
||||
}
|
||||
|
||||
passwordHash, err := bcrypt.GenerateFromPassword([]byte(request.Password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, fmt.Sprintf("failed to generate password hash, err: %s", err))
|
||||
}
|
||||
|
||||
create := &store.User{
|
||||
Username: request.Username,
|
||||
Nickname: request.Username,
|
||||
PasswordHash: string(passwordHash),
|
||||
}
|
||||
if !util.UIDMatcher.MatchString(strings.ToLower(create.Username)) {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid username: %s", create.Username)
|
||||
}
|
||||
|
||||
hostUserType := store.RoleHost
|
||||
existedHostUsers, err := s.Store.ListUsers(ctx, &store.FindUser{
|
||||
Role: &hostUserType,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, fmt.Sprintf("failed to list users, err: %s", err))
|
||||
}
|
||||
if len(existedHostUsers) == 0 {
|
||||
// Change the default role to host if there is no host user.
|
||||
create.Role = store.RoleHost
|
||||
} else {
|
||||
create.Role = store.RoleUser
|
||||
}
|
||||
|
||||
user, err := s.Store.CreateUser(ctx, create)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, fmt.Sprintf("failed to create user, err: %s", err))
|
||||
}
|
||||
|
||||
if err := s.doSignIn(ctx, user, time.Now().Add(AccessTokenDuration)); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, fmt.Sprintf("failed to sign in, err: %s", err))
|
||||
}
|
||||
return convertUserFromStore(user), nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) SignOut(ctx context.Context, _ *v1pb.SignOutRequest) (*emptypb.Empty, error) {
|
||||
if err := s.clearAccessTokenCookie(ctx); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to set grpc header, error: %v", err)
|
||||
}
|
||||
return &emptypb.Empty{}, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) clearAccessTokenCookie(ctx context.Context) error {
|
||||
cookie, err := s.buildAccessTokenCookie(ctx, "", time.Time{})
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to build access token cookie")
|
||||
}
|
||||
if err := grpc.SetHeader(ctx, metadata.New(map[string]string{
|
||||
"Set-Cookie": cookie,
|
||||
})); err != nil {
|
||||
return errors.Wrap(err, "failed to set grpc header")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (*APIV1Service) buildAccessTokenCookie(ctx context.Context, accessToken string, expireTime time.Time) (string, error) {
|
||||
attrs := []string{
|
||||
fmt.Sprintf("%s=%s", AccessTokenCookieName, accessToken),
|
||||
"Path=/",
|
||||
"HttpOnly",
|
||||
}
|
||||
if expireTime.IsZero() {
|
||||
attrs = append(attrs, "Expires=Thu, 01 Jan 1970 00:00:00 GMT")
|
||||
} else {
|
||||
attrs = append(attrs, "Expires="+expireTime.Format(time.RFC1123))
|
||||
}
|
||||
|
||||
md, ok := metadata.FromIncomingContext(ctx)
|
||||
if !ok {
|
||||
return "", errors.New("failed to get metadata from context")
|
||||
}
|
||||
var origin string
|
||||
for _, v := range md.Get("origin") {
|
||||
origin = v
|
||||
}
|
||||
isHTTPS := strings.HasPrefix(origin, "https://")
|
||||
if isHTTPS {
|
||||
attrs = append(attrs, "SameSite=None")
|
||||
attrs = append(attrs, "Secure")
|
||||
} else {
|
||||
attrs = append(attrs, "SameSite=Strict")
|
||||
}
|
||||
return strings.Join(attrs, "; "), nil
|
||||
}
|
74
server/router/api/v1/common.go
Normal file
74
server/router/api/v1/common.go
Normal file
@@ -0,0 +1,74 @@
|
||||
package v1
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"google.golang.org/protobuf/proto"
|
||||
|
||||
v1pb "github.com/usememos/memos/proto/gen/api/v1"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func convertRowStatusFromStore(rowStatus store.RowStatus) v1pb.RowStatus {
|
||||
switch rowStatus {
|
||||
case store.Normal:
|
||||
return v1pb.RowStatus_ACTIVE
|
||||
case store.Archived:
|
||||
return v1pb.RowStatus_ARCHIVED
|
||||
default:
|
||||
return v1pb.RowStatus_ROW_STATUS_UNSPECIFIED
|
||||
}
|
||||
}
|
||||
|
||||
func convertRowStatusToStore(rowStatus v1pb.RowStatus) store.RowStatus {
|
||||
switch rowStatus {
|
||||
case v1pb.RowStatus_ACTIVE:
|
||||
return store.Normal
|
||||
case v1pb.RowStatus_ARCHIVED:
|
||||
return store.Archived
|
||||
default:
|
||||
return store.Normal
|
||||
}
|
||||
}
|
||||
|
||||
func getCurrentUser(ctx context.Context, s *store.Store) (*store.User, error) {
|
||||
username, ok := ctx.Value(usernameContextKey).(string)
|
||||
if !ok {
|
||||
return nil, nil
|
||||
}
|
||||
user, err := s.GetUser(ctx, &store.FindUser{
|
||||
Username: &username,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func getPageToken(limit int, offset int) (string, error) {
|
||||
return marshalPageToken(&v1pb.PageToken{
|
||||
Limit: int32(limit),
|
||||
Offset: int32(offset),
|
||||
})
|
||||
}
|
||||
|
||||
func marshalPageToken(pageToken *v1pb.PageToken) (string, error) {
|
||||
b, err := proto.Marshal(pageToken)
|
||||
if err != nil {
|
||||
return "", errors.Wrapf(err, "failed to marshal page token")
|
||||
}
|
||||
return base64.StdEncoding.EncodeToString(b), nil
|
||||
}
|
||||
|
||||
func unmarshalPageToken(s string, pageToken *v1pb.PageToken) error {
|
||||
b, err := base64.StdEncoding.DecodeString(s)
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "failed to decode page token")
|
||||
}
|
||||
if err := proto.Unmarshal(b, pageToken); err != nil {
|
||||
return errors.Wrapf(err, "failed to unmarshal page token")
|
||||
}
|
||||
return nil
|
||||
}
|
169
server/router/api/v1/idp_service.go
Normal file
169
server/router/api/v1/idp_service.go
Normal file
@@ -0,0 +1,169 @@
|
||||
package v1
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/emptypb"
|
||||
|
||||
v1pb "github.com/usememos/memos/proto/gen/api/v1"
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (s *APIV1Service) CreateIdentityProvider(ctx context.Context, request *v1pb.CreateIdentityProviderRequest) (*v1pb.IdentityProvider, error) {
|
||||
currentUser, err := getCurrentUser(ctx, s.Store)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
|
||||
}
|
||||
if currentUser.Role != store.RoleHost {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
|
||||
identityProvider, err := s.Store.CreateIdentityProvider(ctx, convertIdentityProviderToStore(request.IdentityProvider))
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to create identity provider, error: %+v", err)
|
||||
}
|
||||
return convertIdentityProviderFromStore(identityProvider), nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) ListIdentityProviders(ctx context.Context, _ *v1pb.ListIdentityProvidersRequest) (*v1pb.ListIdentityProvidersResponse, error) {
|
||||
identityProviders, err := s.Store.ListIdentityProviders(ctx, &store.FindIdentityProvider{})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to list identity providers, error: %+v", err)
|
||||
}
|
||||
|
||||
response := &v1pb.ListIdentityProvidersResponse{
|
||||
IdentityProviders: []*v1pb.IdentityProvider{},
|
||||
}
|
||||
for _, identityProvider := range identityProviders {
|
||||
response.IdentityProviders = append(response.IdentityProviders, convertIdentityProviderFromStore(identityProvider))
|
||||
}
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) GetIdentityProvider(ctx context.Context, request *v1pb.GetIdentityProviderRequest) (*v1pb.IdentityProvider, error) {
|
||||
id, err := ExtractIdentityProviderIDFromName(request.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid identity provider name: %v", err)
|
||||
}
|
||||
identityProvider, err := s.Store.GetIdentityProvider(ctx, &store.FindIdentityProvider{
|
||||
ID: &id,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get identity provider, error: %+v", err)
|
||||
}
|
||||
if identityProvider == nil {
|
||||
return nil, status.Errorf(codes.NotFound, "identity provider not found")
|
||||
}
|
||||
return convertIdentityProviderFromStore(identityProvider), nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) UpdateIdentityProvider(ctx context.Context, request *v1pb.UpdateIdentityProviderRequest) (*v1pb.IdentityProvider, error) {
|
||||
if request.UpdateMask == nil || len(request.UpdateMask.Paths) == 0 {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "update_mask is required")
|
||||
}
|
||||
|
||||
id, err := ExtractIdentityProviderIDFromName(request.IdentityProvider.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid identity provider name: %v", err)
|
||||
}
|
||||
update := &store.UpdateIdentityProviderV1{
|
||||
ID: id,
|
||||
Type: storepb.IdentityProvider_Type(storepb.IdentityProvider_Type_value[request.IdentityProvider.Type.String()]),
|
||||
}
|
||||
for _, field := range request.UpdateMask.Paths {
|
||||
switch field {
|
||||
case "title":
|
||||
update.Name = &request.IdentityProvider.Title
|
||||
case "config":
|
||||
update.Config = convertIdentityProviderConfigToStore(request.IdentityProvider.Type, request.IdentityProvider.Config)
|
||||
}
|
||||
}
|
||||
|
||||
identityProvider, err := s.Store.UpdateIdentityProvider(ctx, update)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to update identity provider, error: %+v", err)
|
||||
}
|
||||
return convertIdentityProviderFromStore(identityProvider), nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) DeleteIdentityProvider(ctx context.Context, request *v1pb.DeleteIdentityProviderRequest) (*emptypb.Empty, error) {
|
||||
id, err := ExtractIdentityProviderIDFromName(request.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid identity provider name: %v", err)
|
||||
}
|
||||
if err := s.Store.DeleteIdentityProvider(ctx, &store.DeleteIdentityProvider{ID: id}); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to delete identity provider, error: %+v", err)
|
||||
}
|
||||
return &emptypb.Empty{}, nil
|
||||
}
|
||||
|
||||
func convertIdentityProviderFromStore(identityProvider *storepb.IdentityProvider) *v1pb.IdentityProvider {
|
||||
temp := &v1pb.IdentityProvider{
|
||||
Name: fmt.Sprintf("%s%d", IdentityProviderNamePrefix, identityProvider.Id),
|
||||
Title: identityProvider.Name,
|
||||
IdentifierFilter: identityProvider.IdentifierFilter,
|
||||
Type: v1pb.IdentityProvider_Type(v1pb.IdentityProvider_Type_value[identityProvider.Type.String()]),
|
||||
}
|
||||
if identityProvider.Type == storepb.IdentityProvider_OAUTH2 {
|
||||
oauth2Config := identityProvider.Config.GetOauth2Config()
|
||||
temp.Config = &v1pb.IdentityProviderConfig{
|
||||
Config: &v1pb.IdentityProviderConfig_Oauth2Config{
|
||||
Oauth2Config: &v1pb.OAuth2Config{
|
||||
ClientId: oauth2Config.ClientId,
|
||||
ClientSecret: oauth2Config.ClientSecret,
|
||||
AuthUrl: oauth2Config.AuthUrl,
|
||||
TokenUrl: oauth2Config.TokenUrl,
|
||||
UserInfoUrl: oauth2Config.UserInfoUrl,
|
||||
Scopes: oauth2Config.Scopes,
|
||||
FieldMapping: &v1pb.FieldMapping{
|
||||
Identifier: oauth2Config.FieldMapping.Identifier,
|
||||
DisplayName: oauth2Config.FieldMapping.DisplayName,
|
||||
Email: oauth2Config.FieldMapping.Email,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
return temp
|
||||
}
|
||||
|
||||
func convertIdentityProviderToStore(identityProvider *v1pb.IdentityProvider) *storepb.IdentityProvider {
|
||||
id, _ := ExtractIdentityProviderIDFromName(identityProvider.Name)
|
||||
|
||||
temp := &storepb.IdentityProvider{
|
||||
Id: id,
|
||||
Name: identityProvider.Title,
|
||||
IdentifierFilter: identityProvider.IdentifierFilter,
|
||||
Type: storepb.IdentityProvider_Type(storepb.IdentityProvider_Type_value[identityProvider.Type.String()]),
|
||||
Config: convertIdentityProviderConfigToStore(identityProvider.Type, identityProvider.Config),
|
||||
}
|
||||
return temp
|
||||
}
|
||||
|
||||
func convertIdentityProviderConfigToStore(identityProviderType v1pb.IdentityProvider_Type, config *v1pb.IdentityProviderConfig) *storepb.IdentityProviderConfig {
|
||||
if identityProviderType == v1pb.IdentityProvider_OAUTH2 {
|
||||
oauth2Config := config.GetOauth2Config()
|
||||
return &storepb.IdentityProviderConfig{
|
||||
Config: &storepb.IdentityProviderConfig_Oauth2Config{
|
||||
Oauth2Config: &storepb.OAuth2Config{
|
||||
ClientId: oauth2Config.ClientId,
|
||||
ClientSecret: oauth2Config.ClientSecret,
|
||||
AuthUrl: oauth2Config.AuthUrl,
|
||||
TokenUrl: oauth2Config.TokenUrl,
|
||||
UserInfoUrl: oauth2Config.UserInfoUrl,
|
||||
Scopes: oauth2Config.Scopes,
|
||||
FieldMapping: &storepb.FieldMapping{
|
||||
Identifier: oauth2Config.FieldMapping.Identifier,
|
||||
DisplayName: oauth2Config.FieldMapping.DisplayName,
|
||||
Email: oauth2Config.FieldMapping.Email,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
115
server/router/api/v1/inbox_service.go
Normal file
115
server/router/api/v1/inbox_service.go
Normal file
@@ -0,0 +1,115 @@
|
||||
package v1
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/emptypb"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
v1pb "github.com/usememos/memos/proto/gen/api/v1"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (s *APIV1Service) ListInboxes(ctx context.Context, _ *v1pb.ListInboxesRequest) (*v1pb.ListInboxesResponse, error) {
|
||||
user, err := getCurrentUser(ctx, s.Store)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get user")
|
||||
}
|
||||
|
||||
inboxes, err := s.Store.ListInboxes(ctx, &store.FindInbox{
|
||||
ReceiverID: &user.ID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to list inbox: %v", err)
|
||||
}
|
||||
|
||||
response := &v1pb.ListInboxesResponse{
|
||||
Inboxes: []*v1pb.Inbox{},
|
||||
}
|
||||
for _, inbox := range inboxes {
|
||||
response.Inboxes = append(response.Inboxes, convertInboxFromStore(inbox))
|
||||
}
|
||||
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) UpdateInbox(ctx context.Context, request *v1pb.UpdateInboxRequest) (*v1pb.Inbox, error) {
|
||||
if request.UpdateMask == nil || len(request.UpdateMask.Paths) == 0 {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "update mask is required")
|
||||
}
|
||||
|
||||
inboxID, err := ExtractInboxIDFromName(request.Inbox.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid inbox name: %v", err)
|
||||
}
|
||||
update := &store.UpdateInbox{
|
||||
ID: inboxID,
|
||||
}
|
||||
for _, field := range request.UpdateMask.Paths {
|
||||
if field == "status" {
|
||||
if request.Inbox.Status == v1pb.Inbox_STATUS_UNSPECIFIED {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "status is required")
|
||||
}
|
||||
update.Status = convertInboxStatusToStore(request.Inbox.Status)
|
||||
}
|
||||
}
|
||||
|
||||
inbox, err := s.Store.UpdateInbox(ctx, update)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to update inbox: %v", err)
|
||||
}
|
||||
|
||||
return convertInboxFromStore(inbox), nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) DeleteInbox(ctx context.Context, request *v1pb.DeleteInboxRequest) (*emptypb.Empty, error) {
|
||||
inboxID, err := ExtractInboxIDFromName(request.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid inbox name: %v", err)
|
||||
}
|
||||
|
||||
if err := s.Store.DeleteInbox(ctx, &store.DeleteInbox{
|
||||
ID: inboxID,
|
||||
}); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to update inbox: %v", err)
|
||||
}
|
||||
return &emptypb.Empty{}, nil
|
||||
}
|
||||
|
||||
func convertInboxFromStore(inbox *store.Inbox) *v1pb.Inbox {
|
||||
return &v1pb.Inbox{
|
||||
Name: fmt.Sprintf("%s%d", InboxNamePrefix, inbox.ID),
|
||||
Sender: fmt.Sprintf("%s%d", UserNamePrefix, inbox.SenderID),
|
||||
Receiver: fmt.Sprintf("%s%d", UserNamePrefix, inbox.ReceiverID),
|
||||
Status: convertInboxStatusFromStore(inbox.Status),
|
||||
CreateTime: timestamppb.New(time.Unix(inbox.CreatedTs, 0)),
|
||||
Type: v1pb.Inbox_Type(inbox.Message.Type),
|
||||
ActivityId: inbox.Message.ActivityId,
|
||||
}
|
||||
}
|
||||
|
||||
func convertInboxStatusFromStore(status store.InboxStatus) v1pb.Inbox_Status {
|
||||
switch status {
|
||||
case store.UNREAD:
|
||||
return v1pb.Inbox_UNREAD
|
||||
case store.ARCHIVED:
|
||||
return v1pb.Inbox_ARCHIVED
|
||||
default:
|
||||
return v1pb.Inbox_STATUS_UNSPECIFIED
|
||||
}
|
||||
}
|
||||
|
||||
func convertInboxStatusToStore(status v1pb.Inbox_Status) store.InboxStatus {
|
||||
switch status {
|
||||
case v1pb.Inbox_UNREAD:
|
||||
return store.UNREAD
|
||||
case v1pb.Inbox_ARCHIVED:
|
||||
return store.ARCHIVED
|
||||
default:
|
||||
return store.UNREAD
|
||||
}
|
||||
}
|
48
server/router/api/v1/logger_interceptor.go
Normal file
48
server/router/api/v1/logger_interceptor.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package v1
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
type LoggerInterceptor struct {
|
||||
}
|
||||
|
||||
func NewLoggerInterceptor() *LoggerInterceptor {
|
||||
return &LoggerInterceptor{}
|
||||
}
|
||||
|
||||
func (in *LoggerInterceptor) LoggerInterceptor(ctx context.Context, request any, serverInfo *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
|
||||
resp, err := handler(ctx, request)
|
||||
in.loggerInterceptorDo(ctx, serverInfo.FullMethod, err)
|
||||
return resp, err
|
||||
}
|
||||
|
||||
func (*LoggerInterceptor) loggerInterceptorDo(ctx context.Context, fullMethod string, err error) {
|
||||
st := status.Convert(err)
|
||||
var logLevel slog.Level
|
||||
var logMsg string
|
||||
switch st.Code() {
|
||||
case codes.OK:
|
||||
logLevel = slog.LevelInfo
|
||||
logMsg = "OK"
|
||||
case codes.Unauthenticated, codes.OutOfRange, codes.PermissionDenied, codes.NotFound:
|
||||
logLevel = slog.LevelInfo
|
||||
logMsg = "client error"
|
||||
case codes.Internal, codes.Unknown, codes.DataLoss, codes.Unavailable, codes.DeadlineExceeded:
|
||||
logLevel = slog.LevelError
|
||||
logMsg = "server error"
|
||||
default:
|
||||
logLevel = slog.LevelError
|
||||
logMsg = "unknown error"
|
||||
}
|
||||
logAttrs := []slog.Attr{slog.String("method", fullMethod)}
|
||||
if err != nil {
|
||||
logAttrs = append(logAttrs, slog.String("error", err.Error()))
|
||||
}
|
||||
slog.LogAttrs(ctx, logLevel, logMsg, logAttrs...)
|
||||
}
|
235
server/router/api/v1/markdown_service.go
Normal file
235
server/router/api/v1/markdown_service.go
Normal file
@@ -0,0 +1,235 @@
|
||||
package v1
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/yourselfhosted/gomark/ast"
|
||||
"github.com/yourselfhosted/gomark/parser"
|
||||
"github.com/yourselfhosted/gomark/parser/tokenizer"
|
||||
"github.com/yourselfhosted/gomark/restore"
|
||||
|
||||
getter "github.com/usememos/memos/plugin/http-getter"
|
||||
v1pb "github.com/usememos/memos/proto/gen/api/v1"
|
||||
)
|
||||
|
||||
func (*APIV1Service) ParseMarkdown(_ context.Context, request *v1pb.ParseMarkdownRequest) (*v1pb.ParseMarkdownResponse, error) {
|
||||
rawNodes, err := parser.Parse(tokenizer.Tokenize(request.Markdown))
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to parse memo content")
|
||||
}
|
||||
|
||||
nodes := convertFromASTNodes(rawNodes)
|
||||
return &v1pb.ParseMarkdownResponse{
|
||||
Nodes: nodes,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (*APIV1Service) RestoreMarkdown(_ context.Context, request *v1pb.RestoreMarkdownRequest) (*v1pb.RestoreMarkdownResponse, error) {
|
||||
markdown := restore.Restore(convertToASTNodes(request.Nodes))
|
||||
return &v1pb.RestoreMarkdownResponse{
|
||||
Markdown: markdown,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (*APIV1Service) GetLinkMetadata(_ context.Context, request *v1pb.GetLinkMetadataRequest) (*v1pb.LinkMetadata, error) {
|
||||
htmlMeta, err := getter.GetHTMLMeta(request.Link)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &v1pb.LinkMetadata{
|
||||
Title: htmlMeta.Title,
|
||||
Description: htmlMeta.Description,
|
||||
Image: htmlMeta.Image,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func convertFromASTNode(rawNode ast.Node) *v1pb.Node {
|
||||
node := &v1pb.Node{
|
||||
Type: v1pb.NodeType(v1pb.NodeType_value[string(rawNode.Type())]),
|
||||
}
|
||||
|
||||
switch n := rawNode.(type) {
|
||||
case *ast.LineBreak:
|
||||
node.Node = &v1pb.Node_LineBreakNode{}
|
||||
case *ast.Paragraph:
|
||||
children := convertFromASTNodes(n.Children)
|
||||
node.Node = &v1pb.Node_ParagraphNode{ParagraphNode: &v1pb.ParagraphNode{Children: children}}
|
||||
case *ast.CodeBlock:
|
||||
node.Node = &v1pb.Node_CodeBlockNode{CodeBlockNode: &v1pb.CodeBlockNode{Language: n.Language, Content: n.Content}}
|
||||
case *ast.Heading:
|
||||
children := convertFromASTNodes(n.Children)
|
||||
node.Node = &v1pb.Node_HeadingNode{HeadingNode: &v1pb.HeadingNode{Level: int32(n.Level), Children: children}}
|
||||
case *ast.HorizontalRule:
|
||||
node.Node = &v1pb.Node_HorizontalRuleNode{HorizontalRuleNode: &v1pb.HorizontalRuleNode{Symbol: n.Symbol}}
|
||||
case *ast.Blockquote:
|
||||
children := convertFromASTNodes(n.Children)
|
||||
node.Node = &v1pb.Node_BlockquoteNode{BlockquoteNode: &v1pb.BlockquoteNode{Children: children}}
|
||||
case *ast.OrderedList:
|
||||
children := convertFromASTNodes(n.Children)
|
||||
node.Node = &v1pb.Node_OrderedListNode{OrderedListNode: &v1pb.OrderedListNode{Number: n.Number, Indent: int32(n.Indent), Children: children}}
|
||||
case *ast.UnorderedList:
|
||||
children := convertFromASTNodes(n.Children)
|
||||
node.Node = &v1pb.Node_UnorderedListNode{UnorderedListNode: &v1pb.UnorderedListNode{Symbol: n.Symbol, Indent: int32(n.Indent), Children: children}}
|
||||
case *ast.TaskList:
|
||||
children := convertFromASTNodes(n.Children)
|
||||
node.Node = &v1pb.Node_TaskListNode{TaskListNode: &v1pb.TaskListNode{Symbol: n.Symbol, Indent: int32(n.Indent), Complete: n.Complete, Children: children}}
|
||||
case *ast.MathBlock:
|
||||
node.Node = &v1pb.Node_MathBlockNode{MathBlockNode: &v1pb.MathBlockNode{Content: n.Content}}
|
||||
case *ast.Table:
|
||||
node.Node = &v1pb.Node_TableNode{TableNode: convertTableFromASTNode(n)}
|
||||
case *ast.EmbeddedContent:
|
||||
node.Node = &v1pb.Node_EmbeddedContentNode{EmbeddedContentNode: &v1pb.EmbeddedContentNode{ResourceName: n.ResourceName, Params: n.Params}}
|
||||
case *ast.Text:
|
||||
node.Node = &v1pb.Node_TextNode{TextNode: &v1pb.TextNode{Content: n.Content}}
|
||||
case *ast.Bold:
|
||||
children := convertFromASTNodes(n.Children)
|
||||
node.Node = &v1pb.Node_BoldNode{BoldNode: &v1pb.BoldNode{Symbol: n.Symbol, Children: children}}
|
||||
case *ast.Italic:
|
||||
node.Node = &v1pb.Node_ItalicNode{ItalicNode: &v1pb.ItalicNode{Symbol: n.Symbol, Content: n.Content}}
|
||||
case *ast.BoldItalic:
|
||||
node.Node = &v1pb.Node_BoldItalicNode{BoldItalicNode: &v1pb.BoldItalicNode{Symbol: n.Symbol, Content: n.Content}}
|
||||
case *ast.Code:
|
||||
node.Node = &v1pb.Node_CodeNode{CodeNode: &v1pb.CodeNode{Content: n.Content}}
|
||||
case *ast.Image:
|
||||
node.Node = &v1pb.Node_ImageNode{ImageNode: &v1pb.ImageNode{AltText: n.AltText, Url: n.URL}}
|
||||
case *ast.Link:
|
||||
node.Node = &v1pb.Node_LinkNode{LinkNode: &v1pb.LinkNode{Text: n.Text, Url: n.URL}}
|
||||
case *ast.AutoLink:
|
||||
node.Node = &v1pb.Node_AutoLinkNode{AutoLinkNode: &v1pb.AutoLinkNode{Url: n.URL, IsRawText: n.IsRawText}}
|
||||
case *ast.Tag:
|
||||
node.Node = &v1pb.Node_TagNode{TagNode: &v1pb.TagNode{Content: n.Content}}
|
||||
case *ast.Strikethrough:
|
||||
node.Node = &v1pb.Node_StrikethroughNode{StrikethroughNode: &v1pb.StrikethroughNode{Content: n.Content}}
|
||||
case *ast.EscapingCharacter:
|
||||
node.Node = &v1pb.Node_EscapingCharacterNode{EscapingCharacterNode: &v1pb.EscapingCharacterNode{Symbol: n.Symbol}}
|
||||
case *ast.Math:
|
||||
node.Node = &v1pb.Node_MathNode{MathNode: &v1pb.MathNode{Content: n.Content}}
|
||||
case *ast.Highlight:
|
||||
node.Node = &v1pb.Node_HighlightNode{HighlightNode: &v1pb.HighlightNode{Content: n.Content}}
|
||||
case *ast.Subscript:
|
||||
node.Node = &v1pb.Node_SubscriptNode{SubscriptNode: &v1pb.SubscriptNode{Content: n.Content}}
|
||||
case *ast.Superscript:
|
||||
node.Node = &v1pb.Node_SuperscriptNode{SuperscriptNode: &v1pb.SuperscriptNode{Content: n.Content}}
|
||||
case *ast.ReferencedContent:
|
||||
node.Node = &v1pb.Node_ReferencedContentNode{ReferencedContentNode: &v1pb.ReferencedContentNode{ResourceName: n.ResourceName, Params: n.Params}}
|
||||
case *ast.Spoiler:
|
||||
node.Node = &v1pb.Node_SpoilerNode{SpoilerNode: &v1pb.SpoilerNode{Content: n.Content}}
|
||||
default:
|
||||
node.Node = &v1pb.Node_TextNode{TextNode: &v1pb.TextNode{}}
|
||||
}
|
||||
return node
|
||||
}
|
||||
|
||||
func convertFromASTNodes(rawNodes []ast.Node) []*v1pb.Node {
|
||||
nodes := []*v1pb.Node{}
|
||||
for _, rawNode := range rawNodes {
|
||||
node := convertFromASTNode(rawNode)
|
||||
nodes = append(nodes, node)
|
||||
}
|
||||
return nodes
|
||||
}
|
||||
|
||||
func convertTableFromASTNode(node *ast.Table) *v1pb.TableNode {
|
||||
table := &v1pb.TableNode{
|
||||
Header: node.Header,
|
||||
Delimiter: node.Delimiter,
|
||||
}
|
||||
for _, row := range node.Rows {
|
||||
table.Rows = append(table.Rows, &v1pb.TableNode_Row{Cells: row})
|
||||
}
|
||||
return table
|
||||
}
|
||||
|
||||
func convertToASTNode(node *v1pb.Node) ast.Node {
|
||||
switch n := node.Node.(type) {
|
||||
case *v1pb.Node_LineBreakNode:
|
||||
return &ast.LineBreak{}
|
||||
case *v1pb.Node_ParagraphNode:
|
||||
children := convertToASTNodes(n.ParagraphNode.Children)
|
||||
return &ast.Paragraph{Children: children}
|
||||
case *v1pb.Node_CodeBlockNode:
|
||||
return &ast.CodeBlock{Language: n.CodeBlockNode.Language, Content: n.CodeBlockNode.Content}
|
||||
case *v1pb.Node_HeadingNode:
|
||||
children := convertToASTNodes(n.HeadingNode.Children)
|
||||
return &ast.Heading{Level: int(n.HeadingNode.Level), Children: children}
|
||||
case *v1pb.Node_HorizontalRuleNode:
|
||||
return &ast.HorizontalRule{Symbol: n.HorizontalRuleNode.Symbol}
|
||||
case *v1pb.Node_BlockquoteNode:
|
||||
children := convertToASTNodes(n.BlockquoteNode.Children)
|
||||
return &ast.Blockquote{Children: children}
|
||||
case *v1pb.Node_OrderedListNode:
|
||||
children := convertToASTNodes(n.OrderedListNode.Children)
|
||||
return &ast.OrderedList{Number: n.OrderedListNode.Number, Indent: int(n.OrderedListNode.Indent), Children: children}
|
||||
case *v1pb.Node_UnorderedListNode:
|
||||
children := convertToASTNodes(n.UnorderedListNode.Children)
|
||||
return &ast.UnorderedList{Symbol: n.UnorderedListNode.Symbol, Indent: int(n.UnorderedListNode.Indent), Children: children}
|
||||
case *v1pb.Node_TaskListNode:
|
||||
children := convertToASTNodes(n.TaskListNode.Children)
|
||||
return &ast.TaskList{Symbol: n.TaskListNode.Symbol, Indent: int(n.TaskListNode.Indent), Complete: n.TaskListNode.Complete, Children: children}
|
||||
case *v1pb.Node_MathBlockNode:
|
||||
return &ast.MathBlock{Content: n.MathBlockNode.Content}
|
||||
case *v1pb.Node_TableNode:
|
||||
return convertTableToASTNode(n.TableNode)
|
||||
case *v1pb.Node_EmbeddedContentNode:
|
||||
return &ast.EmbeddedContent{ResourceName: n.EmbeddedContentNode.ResourceName, Params: n.EmbeddedContentNode.Params}
|
||||
case *v1pb.Node_TextNode:
|
||||
return &ast.Text{Content: n.TextNode.Content}
|
||||
case *v1pb.Node_BoldNode:
|
||||
children := convertToASTNodes(n.BoldNode.Children)
|
||||
return &ast.Bold{Symbol: n.BoldNode.Symbol, Children: children}
|
||||
case *v1pb.Node_ItalicNode:
|
||||
return &ast.Italic{Symbol: n.ItalicNode.Symbol, Content: n.ItalicNode.Content}
|
||||
case *v1pb.Node_BoldItalicNode:
|
||||
return &ast.BoldItalic{Symbol: n.BoldItalicNode.Symbol, Content: n.BoldItalicNode.Content}
|
||||
case *v1pb.Node_CodeNode:
|
||||
return &ast.Code{Content: n.CodeNode.Content}
|
||||
case *v1pb.Node_ImageNode:
|
||||
return &ast.Image{AltText: n.ImageNode.AltText, URL: n.ImageNode.Url}
|
||||
case *v1pb.Node_LinkNode:
|
||||
return &ast.Link{Text: n.LinkNode.Text, URL: n.LinkNode.Url}
|
||||
case *v1pb.Node_AutoLinkNode:
|
||||
return &ast.AutoLink{URL: n.AutoLinkNode.Url, IsRawText: n.AutoLinkNode.IsRawText}
|
||||
case *v1pb.Node_TagNode:
|
||||
return &ast.Tag{Content: n.TagNode.Content}
|
||||
case *v1pb.Node_StrikethroughNode:
|
||||
return &ast.Strikethrough{Content: n.StrikethroughNode.Content}
|
||||
case *v1pb.Node_EscapingCharacterNode:
|
||||
return &ast.EscapingCharacter{Symbol: n.EscapingCharacterNode.Symbol}
|
||||
case *v1pb.Node_MathNode:
|
||||
return &ast.Math{Content: n.MathNode.Content}
|
||||
case *v1pb.Node_HighlightNode:
|
||||
return &ast.Highlight{Content: n.HighlightNode.Content}
|
||||
case *v1pb.Node_SubscriptNode:
|
||||
return &ast.Subscript{Content: n.SubscriptNode.Content}
|
||||
case *v1pb.Node_SuperscriptNode:
|
||||
return &ast.Superscript{Content: n.SuperscriptNode.Content}
|
||||
case *v1pb.Node_ReferencedContentNode:
|
||||
return &ast.ReferencedContent{ResourceName: n.ReferencedContentNode.ResourceName, Params: n.ReferencedContentNode.Params}
|
||||
case *v1pb.Node_SpoilerNode:
|
||||
return &ast.Spoiler{Content: n.SpoilerNode.Content}
|
||||
default:
|
||||
return &ast.Text{}
|
||||
}
|
||||
}
|
||||
|
||||
func convertToASTNodes(nodes []*v1pb.Node) []ast.Node {
|
||||
rawNodes := []ast.Node{}
|
||||
for _, node := range nodes {
|
||||
rawNode := convertToASTNode(node)
|
||||
rawNodes = append(rawNodes, rawNode)
|
||||
}
|
||||
return rawNodes
|
||||
}
|
||||
|
||||
func convertTableToASTNode(node *v1pb.TableNode) *ast.Table {
|
||||
table := &ast.Table{
|
||||
Header: node.Header,
|
||||
Delimiter: node.Delimiter,
|
||||
}
|
||||
for _, row := range node.Rows {
|
||||
table.Rows = append(table.Rows, row.Cells)
|
||||
}
|
||||
return table
|
||||
}
|
114
server/router/api/v1/memo_relation_service.go
Normal file
114
server/router/api/v1/memo_relation_service.go
Normal file
@@ -0,0 +1,114 @@
|
||||
package v1
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/emptypb"
|
||||
|
||||
v1pb "github.com/usememos/memos/proto/gen/api/v1"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (s *APIV1Service) SetMemoRelations(ctx context.Context, request *v1pb.SetMemoRelationsRequest) (*emptypb.Empty, error) {
|
||||
id, err := ExtractMemoIDFromName(request.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err)
|
||||
}
|
||||
referenceType := store.MemoRelationReference
|
||||
// Delete all reference relations first.
|
||||
if err := s.Store.DeleteMemoRelation(ctx, &store.DeleteMemoRelation{
|
||||
MemoID: &id,
|
||||
Type: &referenceType,
|
||||
}); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to delete memo relation")
|
||||
}
|
||||
|
||||
for _, relation := range request.Relations {
|
||||
// Ignore reflexive relations.
|
||||
if request.Name == relation.RelatedMemo {
|
||||
continue
|
||||
}
|
||||
// Ignore comment relations as there's no need to update a comment's relation.
|
||||
// Inserting/Deleting a comment is handled elsewhere.
|
||||
if relation.Type == v1pb.MemoRelation_COMMENT {
|
||||
continue
|
||||
}
|
||||
relatedMemoID, err := ExtractMemoIDFromName(relation.RelatedMemo)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid related memo name: %v", err)
|
||||
}
|
||||
if _, err := s.Store.UpsertMemoRelation(ctx, &store.MemoRelation{
|
||||
MemoID: id,
|
||||
RelatedMemoID: relatedMemoID,
|
||||
Type: convertMemoRelationTypeToStore(relation.Type),
|
||||
}); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to upsert memo relation")
|
||||
}
|
||||
}
|
||||
|
||||
return &emptypb.Empty{}, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) ListMemoRelations(ctx context.Context, request *v1pb.ListMemoRelationsRequest) (*v1pb.ListMemoRelationsResponse, error) {
|
||||
id, err := ExtractMemoIDFromName(request.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err)
|
||||
}
|
||||
relationList := []*v1pb.MemoRelation{}
|
||||
tempList, err := s.Store.ListMemoRelations(ctx, &store.FindMemoRelation{
|
||||
MemoID: &id,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, relation := range tempList {
|
||||
relationList = append(relationList, convertMemoRelationFromStore(relation))
|
||||
}
|
||||
tempList, err = s.Store.ListMemoRelations(ctx, &store.FindMemoRelation{
|
||||
RelatedMemoID: &id,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, relation := range tempList {
|
||||
relationList = append(relationList, convertMemoRelationFromStore(relation))
|
||||
}
|
||||
|
||||
response := &v1pb.ListMemoRelationsResponse{
|
||||
Relations: relationList,
|
||||
}
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func convertMemoRelationFromStore(memoRelation *store.MemoRelation) *v1pb.MemoRelation {
|
||||
return &v1pb.MemoRelation{
|
||||
Memo: fmt.Sprintf("%s%d", MemoNamePrefix, memoRelation.MemoID),
|
||||
RelatedMemo: fmt.Sprintf("%s%d", MemoNamePrefix, memoRelation.RelatedMemoID),
|
||||
Type: convertMemoRelationTypeFromStore(memoRelation.Type),
|
||||
}
|
||||
}
|
||||
|
||||
func convertMemoRelationTypeFromStore(relationType store.MemoRelationType) v1pb.MemoRelation_Type {
|
||||
switch relationType {
|
||||
case store.MemoRelationReference:
|
||||
return v1pb.MemoRelation_REFERENCE
|
||||
case store.MemoRelationComment:
|
||||
return v1pb.MemoRelation_COMMENT
|
||||
default:
|
||||
return v1pb.MemoRelation_TYPE_UNSPECIFIED
|
||||
}
|
||||
}
|
||||
|
||||
func convertMemoRelationTypeToStore(relationType v1pb.MemoRelation_Type) store.MemoRelationType {
|
||||
switch relationType {
|
||||
case v1pb.MemoRelation_REFERENCE:
|
||||
return store.MemoRelationReference
|
||||
case v1pb.MemoRelation_COMMENT:
|
||||
return store.MemoRelationComment
|
||||
default:
|
||||
return store.MemoRelationReference
|
||||
}
|
||||
}
|
86
server/router/api/v1/memo_resource_service.go
Normal file
86
server/router/api/v1/memo_resource_service.go
Normal file
@@ -0,0 +1,86 @@
|
||||
package v1
|
||||
|
||||
import (
|
||||
"context"
|
||||
"slices"
|
||||
"time"
|
||||
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/emptypb"
|
||||
|
||||
v1pb "github.com/usememos/memos/proto/gen/api/v1"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (s *APIV1Service) SetMemoResources(ctx context.Context, request *v1pb.SetMemoResourcesRequest) (*emptypb.Empty, error) {
|
||||
memoID, err := ExtractMemoIDFromName(request.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err)
|
||||
}
|
||||
resources, err := s.Store.ListResources(ctx, &store.FindResource{
|
||||
MemoID: &memoID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to list resources")
|
||||
}
|
||||
|
||||
// Delete resources that are not in the request.
|
||||
for _, resource := range resources {
|
||||
found := false
|
||||
for _, requestResource := range request.Resources {
|
||||
if resource.UID == requestResource.Uid {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
if err = s.Store.DeleteResource(ctx, &store.DeleteResource{
|
||||
ID: int32(resource.ID),
|
||||
MemoID: &memoID,
|
||||
}); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to delete resource")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
slices.Reverse(request.Resources)
|
||||
// Update resources' memo_id in the request.
|
||||
for index, resource := range request.Resources {
|
||||
id, err := ExtractResourceIDFromName(resource.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid resource name: %v", err)
|
||||
}
|
||||
updatedTs := time.Now().Unix() + int64(index)
|
||||
if _, err := s.Store.UpdateResource(ctx, &store.UpdateResource{
|
||||
ID: id,
|
||||
MemoID: &memoID,
|
||||
UpdatedTs: &updatedTs,
|
||||
}); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to update resource: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return &emptypb.Empty{}, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) ListMemoResources(ctx context.Context, request *v1pb.ListMemoResourcesRequest) (*v1pb.ListMemoResourcesResponse, error) {
|
||||
id, err := ExtractMemoIDFromName(request.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err)
|
||||
}
|
||||
resources, err := s.Store.ListResources(ctx, &store.FindResource{
|
||||
MemoID: &id,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to list resources")
|
||||
}
|
||||
|
||||
response := &v1pb.ListMemoResourcesResponse{
|
||||
Resources: []*v1pb.Resource{},
|
||||
}
|
||||
for _, resource := range resources {
|
||||
response.Resources = append(response.Resources, s.convertResourceFromStore(ctx, resource))
|
||||
}
|
||||
return response, nil
|
||||
}
|
874
server/router/api/v1/memo_service.go
Normal file
874
server/router/api/v1/memo_service.go
Normal file
@@ -0,0 +1,874 @@
|
||||
package v1
|
||||
|
||||
import (
|
||||
"archive/zip"
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
"github.com/google/cel-go/cel"
|
||||
"github.com/lithammer/shortuuid/v4"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/yourselfhosted/gomark/parser"
|
||||
"github.com/yourselfhosted/gomark/parser/tokenizer"
|
||||
expr "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/emptypb"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/usememos/memos/internal/util"
|
||||
"github.com/usememos/memos/plugin/webhook"
|
||||
v1pb "github.com/usememos/memos/proto/gen/api/v1"
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultPageSize = 10
|
||||
MaxContentLength = 8 * 1024
|
||||
ChunkSize = 64 * 1024 // 64 KiB
|
||||
)
|
||||
|
||||
func (s *APIV1Service) CreateMemo(ctx context.Context, request *v1pb.CreateMemoRequest) (*v1pb.Memo, error) {
|
||||
user, err := getCurrentUser(ctx, s.Store)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get user")
|
||||
}
|
||||
if user == nil {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
if len(request.Content) > MaxContentLength {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "content too long")
|
||||
}
|
||||
|
||||
create := &store.Memo{
|
||||
UID: shortuuid.New(),
|
||||
CreatorID: user.ID,
|
||||
Content: request.Content,
|
||||
Visibility: convertVisibilityToStore(request.Visibility),
|
||||
}
|
||||
workspaceMemoRelatedSetting, err := s.Store.GetWorkspaceMemoRelatedSetting(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get workspace memo related setting")
|
||||
}
|
||||
if workspaceMemoRelatedSetting.DisallowPublicVisible && create.Visibility == store.Public {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "disable public memos system setting is enabled")
|
||||
}
|
||||
|
||||
memo, err := s.Store.CreateMemo(ctx, create)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
memoMessage, err := s.convertMemoFromStore(ctx, memo)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to convert memo")
|
||||
}
|
||||
// Try to dispatch webhook when memo is created.
|
||||
if err := s.DispatchMemoCreatedWebhook(ctx, memoMessage); err != nil {
|
||||
slog.Warn("Failed to dispatch memo created webhook", err)
|
||||
}
|
||||
|
||||
return memoMessage, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) ListMemos(ctx context.Context, request *v1pb.ListMemosRequest) (*v1pb.ListMemosResponse, error) {
|
||||
memoFind := &store.FindMemo{
|
||||
// Exclude comments by default.
|
||||
ExcludeComments: true,
|
||||
}
|
||||
if err := s.buildMemoFindWithFilter(ctx, memoFind, request.Filter); err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "failed to build find memos with filter")
|
||||
}
|
||||
|
||||
var limit, offset int
|
||||
if request.PageToken != "" {
|
||||
var pageToken v1pb.PageToken
|
||||
if err := unmarshalPageToken(request.PageToken, &pageToken); err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid page token: %v", err)
|
||||
}
|
||||
limit = int(pageToken.Limit)
|
||||
offset = int(pageToken.Offset)
|
||||
} else {
|
||||
limit = int(request.PageSize)
|
||||
}
|
||||
if limit <= 0 {
|
||||
limit = DefaultPageSize
|
||||
}
|
||||
limitPlusOne := limit + 1
|
||||
memoFind.Limit = &limitPlusOne
|
||||
memoFind.Offset = &offset
|
||||
memos, err := s.Store.ListMemos(ctx, memoFind)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to list memos: %v", err)
|
||||
}
|
||||
|
||||
memoMessages := []*v1pb.Memo{}
|
||||
nextPageToken := ""
|
||||
if len(memos) == limitPlusOne {
|
||||
memos = memos[:limit]
|
||||
nextPageToken, err = getPageToken(limit, offset+limit)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get next page token, error: %v", err)
|
||||
}
|
||||
}
|
||||
for _, memo := range memos {
|
||||
memoMessage, err := s.convertMemoFromStore(ctx, memo)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to convert memo")
|
||||
}
|
||||
memoMessages = append(memoMessages, memoMessage)
|
||||
}
|
||||
|
||||
response := &v1pb.ListMemosResponse{
|
||||
Memos: memoMessages,
|
||||
NextPageToken: nextPageToken,
|
||||
}
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) SearchMemos(ctx context.Context, request *v1pb.SearchMemosRequest) (*v1pb.SearchMemosResponse, error) {
|
||||
defaultSearchLimit := 10
|
||||
memoFind := &store.FindMemo{
|
||||
// Exclude comments by default.
|
||||
ExcludeComments: true,
|
||||
Limit: &defaultSearchLimit,
|
||||
}
|
||||
err := s.buildMemoFindWithFilter(ctx, memoFind, request.Filter)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "failed to build find memos with filter")
|
||||
}
|
||||
|
||||
memos, err := s.Store.ListMemos(ctx, memoFind)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to search memos")
|
||||
}
|
||||
|
||||
memoMessages := []*v1pb.Memo{}
|
||||
for _, memo := range memos {
|
||||
memoMessage, err := s.convertMemoFromStore(ctx, memo)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to convert memo")
|
||||
}
|
||||
memoMessages = append(memoMessages, memoMessage)
|
||||
}
|
||||
|
||||
response := &v1pb.SearchMemosResponse{
|
||||
Memos: memoMessages,
|
||||
}
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) GetMemo(ctx context.Context, request *v1pb.GetMemoRequest) (*v1pb.Memo, error) {
|
||||
id, err := ExtractMemoIDFromName(request.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err)
|
||||
}
|
||||
memo, err := s.Store.GetMemo(ctx, &store.FindMemo{
|
||||
ID: &id,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if memo == nil {
|
||||
return nil, status.Errorf(codes.NotFound, "memo not found")
|
||||
}
|
||||
if memo.Visibility != store.Public {
|
||||
user, err := getCurrentUser(ctx, s.Store)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get user")
|
||||
}
|
||||
if user == nil {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
if memo.Visibility == store.Private && memo.CreatorID != user.ID {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
}
|
||||
|
||||
memoMessage, err := s.convertMemoFromStore(ctx, memo)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to convert memo")
|
||||
}
|
||||
return memoMessage, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) UpdateMemo(ctx context.Context, request *v1pb.UpdateMemoRequest) (*v1pb.Memo, error) {
|
||||
id, err := ExtractMemoIDFromName(request.Memo.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err)
|
||||
}
|
||||
if request.UpdateMask == nil || len(request.UpdateMask.Paths) == 0 {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "update mask is required")
|
||||
}
|
||||
|
||||
memo, err := s.Store.GetMemo(ctx, &store.FindMemo{ID: &id})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if memo == nil {
|
||||
return nil, status.Errorf(codes.NotFound, "memo not found")
|
||||
}
|
||||
|
||||
user, _ := getCurrentUser(ctx, s.Store)
|
||||
if memo.CreatorID != user.ID {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
|
||||
currentTs := time.Now().Unix()
|
||||
update := &store.UpdateMemo{
|
||||
ID: id,
|
||||
UpdatedTs: ¤tTs,
|
||||
}
|
||||
for _, path := range request.UpdateMask.Paths {
|
||||
if path == "content" {
|
||||
update.Content = &request.Memo.Content
|
||||
} else if path == "uid" {
|
||||
update.UID = &request.Memo.Name
|
||||
if !util.UIDMatcher.MatchString(*update.UID) {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid resource name")
|
||||
}
|
||||
} else if path == "visibility" {
|
||||
workspaceMemoRelatedSetting, err := s.Store.GetWorkspaceMemoRelatedSetting(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get workspace memo related setting")
|
||||
}
|
||||
visibility := convertVisibilityToStore(request.Memo.Visibility)
|
||||
if workspaceMemoRelatedSetting.DisallowPublicVisible && visibility == store.Public {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "disable public memos system setting is enabled")
|
||||
}
|
||||
update.Visibility = &visibility
|
||||
} else if path == "row_status" {
|
||||
rowStatus := convertRowStatusToStore(request.Memo.RowStatus)
|
||||
update.RowStatus = &rowStatus
|
||||
} else if path == "created_ts" {
|
||||
createdTs := request.Memo.CreateTime.AsTime().Unix()
|
||||
update.CreatedTs = &createdTs
|
||||
} else if path == "pinned" {
|
||||
if _, err := s.Store.UpsertMemoOrganizer(ctx, &store.MemoOrganizer{
|
||||
MemoID: id,
|
||||
UserID: user.ID,
|
||||
Pinned: request.Memo.Pinned,
|
||||
}); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to upsert memo organizer")
|
||||
}
|
||||
}
|
||||
}
|
||||
if update.Content != nil && len(*update.Content) > MaxContentLength {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "content too long")
|
||||
}
|
||||
|
||||
if err = s.Store.UpdateMemo(ctx, update); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to update memo")
|
||||
}
|
||||
|
||||
memo, err = s.Store.GetMemo(ctx, &store.FindMemo{
|
||||
ID: &id,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to get memo")
|
||||
}
|
||||
memoMessage, err := s.convertMemoFromStore(ctx, memo)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to convert memo")
|
||||
}
|
||||
// Try to dispatch webhook when memo is updated.
|
||||
if err := s.DispatchMemoUpdatedWebhook(ctx, memoMessage); err != nil {
|
||||
slog.Warn("Failed to dispatch memo updated webhook", err)
|
||||
}
|
||||
|
||||
return memoMessage, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) DeleteMemo(ctx context.Context, request *v1pb.DeleteMemoRequest) (*emptypb.Empty, error) {
|
||||
id, err := ExtractMemoIDFromName(request.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err)
|
||||
}
|
||||
memo, err := s.Store.GetMemo(ctx, &store.FindMemo{
|
||||
ID: &id,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if memo == nil {
|
||||
return nil, status.Errorf(codes.NotFound, "memo not found")
|
||||
}
|
||||
|
||||
user, _ := getCurrentUser(ctx, s.Store)
|
||||
if memo.CreatorID != user.ID {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
|
||||
if memoMessage, err := s.convertMemoFromStore(ctx, memo); err == nil {
|
||||
// Try to dispatch webhook when memo is deleted.
|
||||
if err := s.DispatchMemoDeletedWebhook(ctx, memoMessage); err != nil {
|
||||
slog.Warn("Failed to dispatch memo deleted webhook", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err = s.Store.DeleteMemo(ctx, &store.DeleteMemo{ID: id}); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to delete memo")
|
||||
}
|
||||
|
||||
return &emptypb.Empty{}, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) CreateMemoComment(ctx context.Context, request *v1pb.CreateMemoCommentRequest) (*v1pb.Memo, error) {
|
||||
id, err := ExtractMemoIDFromName(request.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err)
|
||||
}
|
||||
relatedMemo, err := s.Store.GetMemo(ctx, &store.FindMemo{ID: &id})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get memo")
|
||||
}
|
||||
|
||||
// Create the comment memo first.
|
||||
memo, err := s.CreateMemo(ctx, request.Comment)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to create memo")
|
||||
}
|
||||
|
||||
// Build the relation between the comment memo and the original memo.
|
||||
memoID, err := ExtractMemoIDFromName(memo.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err)
|
||||
}
|
||||
_, err = s.Store.UpsertMemoRelation(ctx, &store.MemoRelation{
|
||||
MemoID: memoID,
|
||||
RelatedMemoID: relatedMemo.ID,
|
||||
Type: store.MemoRelationComment,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to create memo relation")
|
||||
}
|
||||
creatorID, err := ExtractUserIDFromName(memo.Creator)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid memo creator")
|
||||
}
|
||||
if memo.Visibility != v1pb.Visibility_PRIVATE && creatorID != relatedMemo.CreatorID {
|
||||
activity, err := s.Store.CreateActivity(ctx, &store.Activity{
|
||||
CreatorID: creatorID,
|
||||
Type: store.ActivityTypeMemoComment,
|
||||
Level: store.ActivityLevelInfo,
|
||||
Payload: &storepb.ActivityPayload{
|
||||
MemoComment: &storepb.ActivityMemoCommentPayload{
|
||||
MemoId: memoID,
|
||||
RelatedMemoId: relatedMemo.ID,
|
||||
},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to create activity")
|
||||
}
|
||||
if _, err := s.Store.CreateInbox(ctx, &store.Inbox{
|
||||
SenderID: creatorID,
|
||||
ReceiverID: relatedMemo.CreatorID,
|
||||
Status: store.UNREAD,
|
||||
Message: &storepb.InboxMessage{
|
||||
Type: storepb.InboxMessage_TYPE_MEMO_COMMENT,
|
||||
ActivityId: &activity.ID,
|
||||
},
|
||||
}); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to create inbox")
|
||||
}
|
||||
}
|
||||
|
||||
return memo, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) ListMemoComments(ctx context.Context, request *v1pb.ListMemoCommentsRequest) (*v1pb.ListMemoCommentsResponse, error) {
|
||||
id, err := ExtractMemoIDFromName(request.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err)
|
||||
}
|
||||
memoRelationComment := store.MemoRelationComment
|
||||
memoRelations, err := s.Store.ListMemoRelations(ctx, &store.FindMemoRelation{
|
||||
RelatedMemoID: &id,
|
||||
Type: &memoRelationComment,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to list memo relations")
|
||||
}
|
||||
|
||||
var memos []*v1pb.Memo
|
||||
for _, memoRelation := range memoRelations {
|
||||
memo, err := s.Store.GetMemo(ctx, &store.FindMemo{
|
||||
ID: &memoRelation.MemoID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get memo")
|
||||
}
|
||||
if memo != nil {
|
||||
memoMessage, err := s.convertMemoFromStore(ctx, memo)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to convert memo")
|
||||
}
|
||||
memos = append(memos, memoMessage)
|
||||
}
|
||||
}
|
||||
|
||||
response := &v1pb.ListMemoCommentsResponse{
|
||||
Memos: memos,
|
||||
}
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) GetUserMemosStats(ctx context.Context, request *v1pb.GetUserMemosStatsRequest) (*v1pb.GetUserMemosStatsResponse, error) {
|
||||
userID, err := ExtractUserIDFromName(request.Name)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "invalid user name")
|
||||
}
|
||||
user, err := s.Store.GetUser(ctx, &store.FindUser{
|
||||
ID: &userID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get user")
|
||||
}
|
||||
if user == nil {
|
||||
return nil, status.Errorf(codes.NotFound, "user not found")
|
||||
}
|
||||
|
||||
normalRowStatus := store.Normal
|
||||
memoFind := &store.FindMemo{
|
||||
CreatorID: &user.ID,
|
||||
RowStatus: &normalRowStatus,
|
||||
ExcludeComments: true,
|
||||
ExcludeContent: true,
|
||||
}
|
||||
if err := s.buildMemoFindWithFilter(ctx, memoFind, request.Filter); err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "failed to build find memos with filter")
|
||||
}
|
||||
|
||||
memos, err := s.Store.ListMemos(ctx, memoFind)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to list memos: %v", err)
|
||||
}
|
||||
|
||||
location, err := time.LoadLocation(request.Timezone)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "invalid timezone location")
|
||||
}
|
||||
|
||||
workspaceMemoRelatedSetting, err := s.Store.GetWorkspaceMemoRelatedSetting(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get workspace memo related setting")
|
||||
}
|
||||
stats := make(map[string]int32)
|
||||
for _, memo := range memos {
|
||||
displayTs := memo.CreatedTs
|
||||
if workspaceMemoRelatedSetting.DisplayWithUpdateTime {
|
||||
displayTs = memo.UpdatedTs
|
||||
}
|
||||
stats[time.Unix(displayTs, 0).In(location).Format("2006-01-02")]++
|
||||
}
|
||||
|
||||
response := &v1pb.GetUserMemosStatsResponse{
|
||||
Stats: stats,
|
||||
}
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) ExportMemos(ctx context.Context, request *v1pb.ExportMemosRequest) (*v1pb.ExportMemosResponse, error) {
|
||||
normalRowStatus := store.Normal
|
||||
memoFind := &store.FindMemo{
|
||||
RowStatus: &normalRowStatus,
|
||||
// Exclude comments by default.
|
||||
ExcludeComments: true,
|
||||
}
|
||||
if err := s.buildMemoFindWithFilter(ctx, memoFind, request.Filter); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to build find memos with filter: %v", err)
|
||||
}
|
||||
|
||||
memos, err := s.Store.ListMemos(ctx, memoFind)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to list memos: %v", err)
|
||||
}
|
||||
|
||||
buf := new(bytes.Buffer)
|
||||
writer := zip.NewWriter(buf)
|
||||
for _, memo := range memos {
|
||||
memoMessage, err := s.convertMemoFromStore(ctx, memo)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to convert memo")
|
||||
}
|
||||
file, err := writer.Create(time.Unix(memo.CreatedTs, 0).Format(time.RFC3339) + ".md")
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "Failed to create memo file")
|
||||
}
|
||||
_, err = file.Write([]byte(memoMessage.Content))
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "Failed to write to memo file")
|
||||
}
|
||||
}
|
||||
if err := writer.Close(); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "Failed to close zip file writer")
|
||||
}
|
||||
|
||||
return &v1pb.ExportMemosResponse{
|
||||
Content: buf.Bytes(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) convertMemoFromStore(ctx context.Context, memo *store.Memo) (*v1pb.Memo, error) {
|
||||
displayTs := memo.CreatedTs
|
||||
workspaceMemoRelatedSetting, err := s.Store.GetWorkspaceMemoRelatedSetting(ctx)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to get workspace memo related setting")
|
||||
}
|
||||
if workspaceMemoRelatedSetting.DisplayWithUpdateTime {
|
||||
displayTs = memo.UpdatedTs
|
||||
}
|
||||
|
||||
creator, err := s.Store.GetUser(ctx, &store.FindUser{ID: &memo.CreatorID})
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to get creator")
|
||||
}
|
||||
|
||||
name := fmt.Sprintf("%s%d", MemoNamePrefix, memo.ID)
|
||||
listMemoRelationsResponse, err := s.ListMemoRelations(ctx, &v1pb.ListMemoRelationsRequest{Name: name})
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to list memo relations")
|
||||
}
|
||||
|
||||
listMemoResourcesResponse, err := s.ListMemoResources(ctx, &v1pb.ListMemoResourcesRequest{Name: name})
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to list memo resources")
|
||||
}
|
||||
|
||||
listMemoReactionsResponse, err := s.ListMemoReactions(ctx, &v1pb.ListMemoReactionsRequest{Name: name})
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to list memo reactions")
|
||||
}
|
||||
|
||||
nodes, err := parser.Parse(tokenizer.Tokenize(memo.Content))
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to parse content")
|
||||
}
|
||||
|
||||
return &v1pb.Memo{
|
||||
Name: name,
|
||||
Uid: memo.UID,
|
||||
RowStatus: convertRowStatusFromStore(memo.RowStatus),
|
||||
Creator: fmt.Sprintf("%s%d", UserNamePrefix, creator.ID),
|
||||
CreateTime: timestamppb.New(time.Unix(memo.CreatedTs, 0)),
|
||||
UpdateTime: timestamppb.New(time.Unix(memo.UpdatedTs, 0)),
|
||||
DisplayTime: timestamppb.New(time.Unix(displayTs, 0)),
|
||||
Content: memo.Content,
|
||||
Nodes: convertFromASTNodes(nodes),
|
||||
Visibility: convertVisibilityFromStore(memo.Visibility),
|
||||
Pinned: memo.Pinned,
|
||||
ParentId: memo.ParentID,
|
||||
Relations: listMemoRelationsResponse.Relations,
|
||||
Resources: listMemoResourcesResponse.Resources,
|
||||
Reactions: listMemoReactionsResponse.Reactions,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func convertVisibilityFromStore(visibility store.Visibility) v1pb.Visibility {
|
||||
switch visibility {
|
||||
case store.Private:
|
||||
return v1pb.Visibility_PRIVATE
|
||||
case store.Protected:
|
||||
return v1pb.Visibility_PROTECTED
|
||||
case store.Public:
|
||||
return v1pb.Visibility_PUBLIC
|
||||
default:
|
||||
return v1pb.Visibility_VISIBILITY_UNSPECIFIED
|
||||
}
|
||||
}
|
||||
|
||||
func convertVisibilityToStore(visibility v1pb.Visibility) store.Visibility {
|
||||
switch visibility {
|
||||
case v1pb.Visibility_PRIVATE:
|
||||
return store.Private
|
||||
case v1pb.Visibility_PROTECTED:
|
||||
return store.Protected
|
||||
case v1pb.Visibility_PUBLIC:
|
||||
return store.Public
|
||||
default:
|
||||
return store.Private
|
||||
}
|
||||
}
|
||||
|
||||
func (s *APIV1Service) buildMemoFindWithFilter(ctx context.Context, find *store.FindMemo, filter string) error {
|
||||
user, _ := getCurrentUser(ctx, s.Store)
|
||||
if find == nil {
|
||||
find = &store.FindMemo{}
|
||||
}
|
||||
if filter != "" {
|
||||
filter, err := parseSearchMemosFilter(filter)
|
||||
if err != nil {
|
||||
return status.Errorf(codes.InvalidArgument, "invalid filter: %v", err)
|
||||
}
|
||||
if len(filter.ContentSearch) > 0 {
|
||||
find.ContentSearch = filter.ContentSearch
|
||||
}
|
||||
if len(filter.Visibilities) > 0 {
|
||||
find.VisibilityList = filter.Visibilities
|
||||
}
|
||||
if filter.OrderByPinned {
|
||||
find.OrderByPinned = filter.OrderByPinned
|
||||
}
|
||||
if filter.DisplayTimeAfter != nil {
|
||||
workspaceMemoRelatedSetting, err := s.Store.GetWorkspaceMemoRelatedSetting(ctx)
|
||||
if err != nil {
|
||||
return status.Errorf(codes.Internal, "failed to get workspace memo related setting")
|
||||
}
|
||||
if workspaceMemoRelatedSetting.DisplayWithUpdateTime {
|
||||
find.UpdatedTsAfter = filter.DisplayTimeAfter
|
||||
} else {
|
||||
find.CreatedTsAfter = filter.DisplayTimeAfter
|
||||
}
|
||||
}
|
||||
if filter.DisplayTimeBefore != nil {
|
||||
workspaceMemoRelatedSetting, err := s.Store.GetWorkspaceMemoRelatedSetting(ctx)
|
||||
if err != nil {
|
||||
return status.Errorf(codes.Internal, "failed to get workspace memo related setting")
|
||||
}
|
||||
if workspaceMemoRelatedSetting.DisplayWithUpdateTime {
|
||||
find.UpdatedTsBefore = filter.DisplayTimeBefore
|
||||
} else {
|
||||
find.CreatedTsBefore = filter.DisplayTimeBefore
|
||||
}
|
||||
}
|
||||
if filter.Creator != nil {
|
||||
userID, err := ExtractUserIDFromName(*filter.Creator)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "invalid user name")
|
||||
}
|
||||
user, err := s.Store.GetUser(ctx, &store.FindUser{
|
||||
ID: &userID,
|
||||
})
|
||||
if err != nil {
|
||||
return status.Errorf(codes.Internal, "failed to get user")
|
||||
}
|
||||
if user == nil {
|
||||
return status.Errorf(codes.NotFound, "user not found")
|
||||
}
|
||||
find.CreatorID = &user.ID
|
||||
}
|
||||
if filter.UID != nil {
|
||||
find.UID = filter.UID
|
||||
}
|
||||
if filter.RowStatus != nil {
|
||||
find.RowStatus = filter.RowStatus
|
||||
}
|
||||
if filter.Random {
|
||||
find.Random = filter.Random
|
||||
}
|
||||
if filter.Limit != nil {
|
||||
find.Limit = filter.Limit
|
||||
}
|
||||
if filter.IncludeComments {
|
||||
find.ExcludeComments = false
|
||||
}
|
||||
}
|
||||
|
||||
// If the user is not authenticated, only public memos are visible.
|
||||
if user == nil {
|
||||
if filter == "" {
|
||||
// If no filter is provided, return an error.
|
||||
return status.Errorf(codes.InvalidArgument, "filter is required")
|
||||
}
|
||||
|
||||
find.VisibilityList = []store.Visibility{store.Public}
|
||||
} else if find.CreatorID != nil && *find.CreatorID != user.ID {
|
||||
find.VisibilityList = []store.Visibility{store.Public, store.Protected}
|
||||
}
|
||||
|
||||
workspaceMemoRelatedSetting, err := s.Store.GetWorkspaceMemoRelatedSetting(ctx)
|
||||
if err != nil {
|
||||
return status.Errorf(codes.Internal, "failed to get workspace memo related setting")
|
||||
}
|
||||
if workspaceMemoRelatedSetting.DisplayWithUpdateTime {
|
||||
find.OrderByUpdatedTs = true
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SearchMemosFilterCELAttributes are the CEL attributes.
|
||||
var SearchMemosFilterCELAttributes = []cel.EnvOption{
|
||||
cel.Variable("content_search", cel.ListType(cel.StringType)),
|
||||
cel.Variable("visibilities", cel.ListType(cel.StringType)),
|
||||
cel.Variable("order_by_pinned", cel.BoolType),
|
||||
cel.Variable("display_time_before", cel.IntType),
|
||||
cel.Variable("display_time_after", cel.IntType),
|
||||
cel.Variable("creator", cel.StringType),
|
||||
cel.Variable("uid", cel.StringType),
|
||||
cel.Variable("row_status", cel.StringType),
|
||||
cel.Variable("random", cel.BoolType),
|
||||
cel.Variable("limit", cel.IntType),
|
||||
cel.Variable("include_comments", cel.BoolType),
|
||||
}
|
||||
|
||||
type SearchMemosFilter struct {
|
||||
ContentSearch []string
|
||||
Visibilities []store.Visibility
|
||||
OrderByPinned bool
|
||||
DisplayTimeBefore *int64
|
||||
DisplayTimeAfter *int64
|
||||
Creator *string
|
||||
UID *string
|
||||
RowStatus *store.RowStatus
|
||||
Random bool
|
||||
Limit *int
|
||||
IncludeComments bool
|
||||
}
|
||||
|
||||
func parseSearchMemosFilter(expression string) (*SearchMemosFilter, error) {
|
||||
e, err := cel.NewEnv(SearchMemosFilterCELAttributes...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ast, issues := e.Compile(expression)
|
||||
if issues != nil {
|
||||
return nil, errors.Errorf("found issue %v", issues)
|
||||
}
|
||||
filter := &SearchMemosFilter{}
|
||||
expr, err := cel.AstToParsedExpr(ast)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
callExpr := expr.GetExpr().GetCallExpr()
|
||||
findSearchMemosField(callExpr, filter)
|
||||
return filter, nil
|
||||
}
|
||||
|
||||
func findSearchMemosField(callExpr *expr.Expr_Call, filter *SearchMemosFilter) {
|
||||
if len(callExpr.Args) == 2 {
|
||||
idExpr := callExpr.Args[0].GetIdentExpr()
|
||||
if idExpr != nil {
|
||||
if idExpr.Name == "content_search" {
|
||||
contentSearch := []string{}
|
||||
for _, expr := range callExpr.Args[1].GetListExpr().GetElements() {
|
||||
value := expr.GetConstExpr().GetStringValue()
|
||||
contentSearch = append(contentSearch, value)
|
||||
}
|
||||
filter.ContentSearch = contentSearch
|
||||
} else if idExpr.Name == "visibilities" {
|
||||
visibilities := []store.Visibility{}
|
||||
for _, expr := range callExpr.Args[1].GetListExpr().GetElements() {
|
||||
value := expr.GetConstExpr().GetStringValue()
|
||||
visibilities = append(visibilities, store.Visibility(value))
|
||||
}
|
||||
filter.Visibilities = visibilities
|
||||
} else if idExpr.Name == "order_by_pinned" {
|
||||
value := callExpr.Args[1].GetConstExpr().GetBoolValue()
|
||||
filter.OrderByPinned = value
|
||||
} else if idExpr.Name == "display_time_before" {
|
||||
displayTimeBefore := callExpr.Args[1].GetConstExpr().GetInt64Value()
|
||||
filter.DisplayTimeBefore = &displayTimeBefore
|
||||
} else if idExpr.Name == "display_time_after" {
|
||||
displayTimeAfter := callExpr.Args[1].GetConstExpr().GetInt64Value()
|
||||
filter.DisplayTimeAfter = &displayTimeAfter
|
||||
} else if idExpr.Name == "creator" {
|
||||
creator := callExpr.Args[1].GetConstExpr().GetStringValue()
|
||||
filter.Creator = &creator
|
||||
} else if idExpr.Name == "uid" {
|
||||
uid := callExpr.Args[1].GetConstExpr().GetStringValue()
|
||||
filter.UID = &uid
|
||||
} else if idExpr.Name == "row_status" {
|
||||
rowStatus := store.RowStatus(callExpr.Args[1].GetConstExpr().GetStringValue())
|
||||
filter.RowStatus = &rowStatus
|
||||
} else if idExpr.Name == "random" {
|
||||
value := callExpr.Args[1].GetConstExpr().GetBoolValue()
|
||||
filter.Random = value
|
||||
} else if idExpr.Name == "limit" {
|
||||
limit := int(callExpr.Args[1].GetConstExpr().GetInt64Value())
|
||||
filter.Limit = &limit
|
||||
} else if idExpr.Name == "include_comments" {
|
||||
value := callExpr.Args[1].GetConstExpr().GetBoolValue()
|
||||
filter.IncludeComments = value
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
for _, arg := range callExpr.Args {
|
||||
callExpr := arg.GetCallExpr()
|
||||
if callExpr != nil {
|
||||
findSearchMemosField(callExpr, filter)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// DispatchMemoCreatedWebhook dispatches webhook when memo is created.
|
||||
func (s *APIV1Service) DispatchMemoCreatedWebhook(ctx context.Context, memo *v1pb.Memo) error {
|
||||
return s.dispatchMemoRelatedWebhook(ctx, memo, "memos.memo.created")
|
||||
}
|
||||
|
||||
// DispatchMemoUpdatedWebhook dispatches webhook when memo is updated.
|
||||
func (s *APIV1Service) DispatchMemoUpdatedWebhook(ctx context.Context, memo *v1pb.Memo) error {
|
||||
return s.dispatchMemoRelatedWebhook(ctx, memo, "memos.memo.updated")
|
||||
}
|
||||
|
||||
// DispatchMemoDeletedWebhook dispatches webhook when memo is deleted.
|
||||
func (s *APIV1Service) DispatchMemoDeletedWebhook(ctx context.Context, memo *v1pb.Memo) error {
|
||||
return s.dispatchMemoRelatedWebhook(ctx, memo, "memos.memo.deleted")
|
||||
}
|
||||
|
||||
func (s *APIV1Service) dispatchMemoRelatedWebhook(ctx context.Context, memo *v1pb.Memo, activityType string) error {
|
||||
creatorID, err := ExtractUserIDFromName(memo.Creator)
|
||||
if err != nil {
|
||||
return status.Errorf(codes.InvalidArgument, "invalid memo creator")
|
||||
}
|
||||
webhooks, err := s.Store.ListWebhooks(ctx, &store.FindWebhook{
|
||||
CreatorID: &creatorID,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, hook := range webhooks {
|
||||
payload, err := convertMemoToWebhookPayload(memo)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to convert memo to webhook payload")
|
||||
}
|
||||
payload.ActivityType = activityType
|
||||
payload.URL = hook.URL
|
||||
if err := webhook.Post(*payload); err != nil {
|
||||
return errors.Wrap(err, "failed to post webhook")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func convertMemoToWebhookPayload(memo *v1pb.Memo) (*webhook.WebhookPayload, error) {
|
||||
creatorID, err := ExtractUserIDFromName(memo.Creator)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "invalid memo creator")
|
||||
}
|
||||
id, err := ExtractMemoIDFromName(memo.Name)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "invalid memo name")
|
||||
}
|
||||
return &webhook.WebhookPayload{
|
||||
CreatorID: creatorID,
|
||||
CreatedTs: time.Now().Unix(),
|
||||
Memo: &webhook.Memo{
|
||||
ID: id,
|
||||
CreatorID: creatorID,
|
||||
CreatedTs: memo.CreateTime.Seconds,
|
||||
UpdatedTs: memo.UpdateTime.Seconds,
|
||||
Content: memo.Content,
|
||||
Visibility: memo.Visibility.String(),
|
||||
Pinned: memo.Pinned,
|
||||
ResourceList: func() []*webhook.Resource {
|
||||
resources := []*webhook.Resource{}
|
||||
for _, resource := range memo.Resources {
|
||||
resources = append(resources, &webhook.Resource{
|
||||
UID: resource.Uid,
|
||||
Filename: resource.Filename,
|
||||
ExternalLink: resource.ExternalLink,
|
||||
Type: resource.Type,
|
||||
Size: resource.Size,
|
||||
})
|
||||
}
|
||||
return resources
|
||||
}(),
|
||||
},
|
||||
}, nil
|
||||
}
|
81
server/router/api/v1/reaction_service.go
Normal file
81
server/router/api/v1/reaction_service.go
Normal file
@@ -0,0 +1,81 @@
|
||||
package v1
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/emptypb"
|
||||
|
||||
v1pb "github.com/usememos/memos/proto/gen/api/v1"
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (s *APIV1Service) ListMemoReactions(ctx context.Context, request *v1pb.ListMemoReactionsRequest) (*v1pb.ListMemoReactionsResponse, error) {
|
||||
reactions, err := s.Store.ListReactions(ctx, &store.FindReaction{
|
||||
ContentID: &request.Name,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to list reactions")
|
||||
}
|
||||
|
||||
response := &v1pb.ListMemoReactionsResponse{
|
||||
Reactions: []*v1pb.Reaction{},
|
||||
}
|
||||
for _, reaction := range reactions {
|
||||
reactionMessage, err := s.convertReactionFromStore(ctx, reaction)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to convert reaction")
|
||||
}
|
||||
response.Reactions = append(response.Reactions, reactionMessage)
|
||||
}
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) UpsertMemoReaction(ctx context.Context, request *v1pb.UpsertMemoReactionRequest) (*v1pb.Reaction, error) {
|
||||
user, err := getCurrentUser(ctx, s.Store)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get current user")
|
||||
}
|
||||
reaction, err := s.Store.UpsertReaction(ctx, &store.Reaction{
|
||||
CreatorID: user.ID,
|
||||
ContentID: request.Reaction.ContentId,
|
||||
ReactionType: storepb.ReactionType(request.Reaction.ReactionType),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to upsert reaction")
|
||||
}
|
||||
|
||||
reactionMessage, err := s.convertReactionFromStore(ctx, reaction)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to convert reaction")
|
||||
}
|
||||
return reactionMessage, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) DeleteMemoReaction(ctx context.Context, request *v1pb.DeleteMemoReactionRequest) (*emptypb.Empty, error) {
|
||||
if err := s.Store.DeleteReaction(ctx, &store.DeleteReaction{
|
||||
ID: request.ReactionId,
|
||||
}); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to delete reaction")
|
||||
}
|
||||
|
||||
return &emptypb.Empty{}, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) convertReactionFromStore(ctx context.Context, reaction *store.Reaction) (*v1pb.Reaction, error) {
|
||||
creator, err := s.Store.GetUser(ctx, &store.FindUser{
|
||||
ID: &reaction.CreatorID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &v1pb.Reaction{
|
||||
Id: reaction.ID,
|
||||
Creator: fmt.Sprintf("%s%d", UserNamePrefix, creator.ID),
|
||||
ContentId: reaction.ContentID,
|
||||
ReactionType: v1pb.Reaction_Type(reaction.ReactionType),
|
||||
}, nil
|
||||
}
|
125
server/router/api/v1/resource_name.go
Normal file
125
server/router/api/v1/resource_name.go
Normal file
@@ -0,0 +1,125 @@
|
||||
package v1
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/usememos/memos/internal/util"
|
||||
)
|
||||
|
||||
const (
|
||||
WorkspaceSettingNamePrefix = "settings/"
|
||||
UserNamePrefix = "users/"
|
||||
MemoNamePrefix = "memos/"
|
||||
ResourceNamePrefix = "resources/"
|
||||
InboxNamePrefix = "inboxes/"
|
||||
StorageNamePrefix = "storages/"
|
||||
IdentityProviderNamePrefix = "identityProviders/"
|
||||
)
|
||||
|
||||
// GetNameParentTokens returns the tokens from a resource name.
|
||||
func GetNameParentTokens(name string, tokenPrefixes ...string) ([]string, error) {
|
||||
parts := strings.Split(name, "/")
|
||||
if len(parts) != 2*len(tokenPrefixes) {
|
||||
return nil, errors.Errorf("invalid request %q", name)
|
||||
}
|
||||
|
||||
var tokens []string
|
||||
for i, tokenPrefix := range tokenPrefixes {
|
||||
if fmt.Sprintf("%s/", parts[2*i]) != tokenPrefix {
|
||||
return nil, errors.Errorf("invalid prefix %q in request %q", tokenPrefix, name)
|
||||
}
|
||||
if parts[2*i+1] == "" {
|
||||
return nil, errors.Errorf("invalid request %q with empty prefix %q", name, tokenPrefix)
|
||||
}
|
||||
tokens = append(tokens, parts[2*i+1])
|
||||
}
|
||||
return tokens, nil
|
||||
}
|
||||
|
||||
func ExtractWorkspaceSettingKeyFromName(name string) (string, error) {
|
||||
tokens, err := GetNameParentTokens(name, WorkspaceSettingNamePrefix)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return tokens[0], nil
|
||||
}
|
||||
|
||||
// ExtractUserIDFromName returns the uid from a resource name.
|
||||
func ExtractUserIDFromName(name string) (int32, error) {
|
||||
tokens, err := GetNameParentTokens(name, UserNamePrefix)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
id, err := util.ConvertStringToInt32(tokens[0])
|
||||
if err != nil {
|
||||
return 0, errors.Errorf("invalid user ID %q", tokens[0])
|
||||
}
|
||||
return id, nil
|
||||
}
|
||||
|
||||
// ExtractMemoIDFromName returns the memo ID from a resource name.
|
||||
func ExtractMemoIDFromName(name string) (int32, error) {
|
||||
tokens, err := GetNameParentTokens(name, MemoNamePrefix)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
id, err := util.ConvertStringToInt32(tokens[0])
|
||||
if err != nil {
|
||||
return 0, errors.Errorf("invalid memo ID %q", tokens[0])
|
||||
}
|
||||
return id, nil
|
||||
}
|
||||
|
||||
// ExtractResourceIDFromName returns the resource ID from a resource name.
|
||||
func ExtractResourceIDFromName(name string) (int32, error) {
|
||||
tokens, err := GetNameParentTokens(name, ResourceNamePrefix)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
id, err := util.ConvertStringToInt32(tokens[0])
|
||||
if err != nil {
|
||||
return 0, errors.Errorf("invalid resource ID %q", tokens[0])
|
||||
}
|
||||
return id, nil
|
||||
}
|
||||
|
||||
// ExtractInboxIDFromName returns the inbox ID from a resource name.
|
||||
func ExtractInboxIDFromName(name string) (int32, error) {
|
||||
tokens, err := GetNameParentTokens(name, InboxNamePrefix)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
id, err := util.ConvertStringToInt32(tokens[0])
|
||||
if err != nil {
|
||||
return 0, errors.Errorf("invalid inbox ID %q", tokens[0])
|
||||
}
|
||||
return id, nil
|
||||
}
|
||||
|
||||
// ExtractStorageIDFromName returns the storage ID from a resource name.
|
||||
func ExtractStorageIDFromName(name string) (int32, error) {
|
||||
tokens, err := GetNameParentTokens(name, StorageNamePrefix)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
id, err := util.ConvertStringToInt32(tokens[0])
|
||||
if err != nil {
|
||||
return 0, errors.Errorf("invalid storage ID %q", tokens[0])
|
||||
}
|
||||
return id, nil
|
||||
}
|
||||
|
||||
func ExtractIdentityProviderIDFromName(name string) (int32, error) {
|
||||
tokens, err := GetNameParentTokens(name, IdentityProviderNamePrefix)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
id, err := util.ConvertStringToInt32(tokens[0])
|
||||
if err != nil {
|
||||
return 0, errors.Errorf("invalid identity provider ID %q", tokens[0])
|
||||
}
|
||||
return id, nil
|
||||
}
|
466
server/router/api/v1/resource_service.go
Normal file
466
server/router/api/v1/resource_service.go
Normal file
@@ -0,0 +1,466 @@
|
||||
package v1
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/cel-go/cel"
|
||||
"github.com/lithammer/shortuuid/v4"
|
||||
"github.com/pkg/errors"
|
||||
expr "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
|
||||
"google.golang.org/genproto/googleapis/api/httpbody"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/emptypb"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/usememos/memos/internal/util"
|
||||
"github.com/usememos/memos/plugin/storage/s3"
|
||||
v1pb "github.com/usememos/memos/proto/gen/api/v1"
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
const (
|
||||
// The upload memory buffer is 32 MiB.
|
||||
// It should be kept low, so RAM usage doesn't get out of control.
|
||||
// This is unrelated to maximum upload size limit, which is now set through system setting.
|
||||
MaxUploadBufferSizeBytes = 32 << 20
|
||||
MebiByte = 1024 * 1024
|
||||
)
|
||||
|
||||
func (s *APIV1Service) CreateResource(ctx context.Context, request *v1pb.CreateResourceRequest) (*v1pb.Resource, error) {
|
||||
user, err := getCurrentUser(ctx, s.Store)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
|
||||
}
|
||||
|
||||
create := &store.Resource{
|
||||
UID: shortuuid.New(),
|
||||
CreatorID: user.ID,
|
||||
Filename: request.Resource.Filename,
|
||||
Type: request.Resource.Type,
|
||||
}
|
||||
if request.Resource.ExternalLink != "" {
|
||||
// Only allow those external links scheme with http/https
|
||||
linkURL, err := url.Parse(request.Resource.ExternalLink)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid external link: %v", err)
|
||||
}
|
||||
if linkURL.Scheme != "http" && linkURL.Scheme != "https" {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid external link scheme: %v", linkURL.Scheme)
|
||||
}
|
||||
create.ExternalLink = request.Resource.ExternalLink
|
||||
} else {
|
||||
workspaceStorageSetting, err := s.Store.GetWorkspaceStorageSetting(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get workspace storage setting: %v", err)
|
||||
}
|
||||
size := binary.Size(request.Resource.Content)
|
||||
uploadSizeLimit := int(workspaceStorageSetting.UploadSizeLimitMb) * MebiByte
|
||||
if uploadSizeLimit == 0 {
|
||||
uploadSizeLimit = MaxUploadBufferSizeBytes
|
||||
}
|
||||
if size > uploadSizeLimit {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "file size exceeds the limit")
|
||||
}
|
||||
create.Size = int64(size)
|
||||
create.Blob = request.Resource.Content
|
||||
if err := SaveResourceBlob(ctx, s.Store, create); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to save resource blob: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if request.Resource.Memo != nil {
|
||||
memoID, err := ExtractMemoIDFromName(*request.Resource.Memo)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid memo id: %v", err)
|
||||
}
|
||||
create.MemoID = &memoID
|
||||
}
|
||||
resource, err := s.Store.CreateResource(ctx, create)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to create resource: %v", err)
|
||||
}
|
||||
|
||||
return s.convertResourceFromStore(ctx, resource), nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) ListResources(ctx context.Context, _ *v1pb.ListResourcesRequest) (*v1pb.ListResourcesResponse, error) {
|
||||
user, err := getCurrentUser(ctx, s.Store)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
|
||||
}
|
||||
resources, err := s.Store.ListResources(ctx, &store.FindResource{
|
||||
CreatorID: &user.ID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to list resources: %v", err)
|
||||
}
|
||||
|
||||
response := &v1pb.ListResourcesResponse{}
|
||||
for _, resource := range resources {
|
||||
response.Resources = append(response.Resources, s.convertResourceFromStore(ctx, resource))
|
||||
}
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) SearchResources(ctx context.Context, request *v1pb.SearchResourcesRequest) (*v1pb.SearchResourcesResponse, error) {
|
||||
if request.Filter == "" {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "filter is empty")
|
||||
}
|
||||
filter, err := parseSearchResourcesFilter(request.Filter)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "failed to parse filter: %v", err)
|
||||
}
|
||||
resourceFind := &store.FindResource{}
|
||||
if filter.UID != nil {
|
||||
resourceFind.UID = filter.UID
|
||||
}
|
||||
user, err := getCurrentUser(ctx, s.Store)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
|
||||
}
|
||||
resourceFind.CreatorID = &user.ID
|
||||
resources, err := s.Store.ListResources(ctx, resourceFind)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to search resources: %v", err)
|
||||
}
|
||||
|
||||
response := &v1pb.SearchResourcesResponse{}
|
||||
for _, resource := range resources {
|
||||
response.Resources = append(response.Resources, s.convertResourceFromStore(ctx, resource))
|
||||
}
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) GetResource(ctx context.Context, request *v1pb.GetResourceRequest) (*v1pb.Resource, error) {
|
||||
id, err := ExtractResourceIDFromName(request.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid resource id: %v", err)
|
||||
}
|
||||
resource, err := s.Store.GetResource(ctx, &store.FindResource{
|
||||
ID: &id,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get resource: %v", err)
|
||||
}
|
||||
if resource == nil {
|
||||
return nil, status.Errorf(codes.NotFound, "resource not found")
|
||||
}
|
||||
|
||||
return s.convertResourceFromStore(ctx, resource), nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) GetResourceBinary(ctx context.Context, request *v1pb.GetResourceBinaryRequest) (*httpbody.HttpBody, error) {
|
||||
resourceFind := &store.FindResource{
|
||||
GetBlob: true,
|
||||
}
|
||||
if request.Name != "" {
|
||||
id, err := ExtractResourceIDFromName(request.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid resource id: %v", err)
|
||||
}
|
||||
resourceFind.ID = &id
|
||||
}
|
||||
if request.Uid != "" {
|
||||
resourceFind.UID = &request.Uid
|
||||
}
|
||||
resource, err := s.Store.GetResource(ctx, resourceFind)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get resource: %v", err)
|
||||
}
|
||||
if resource == nil {
|
||||
return nil, status.Errorf(codes.NotFound, "resource not found")
|
||||
}
|
||||
// Check the related memo visibility.
|
||||
if resource.MemoID != nil {
|
||||
memo, err := s.Store.GetMemo(ctx, &store.FindMemo{
|
||||
ID: resource.MemoID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to find memo by ID: %v", resource.MemoID)
|
||||
}
|
||||
if memo != nil && memo.Visibility != store.Public {
|
||||
user, err := getCurrentUser(ctx, s.Store)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
|
||||
}
|
||||
if memo.Visibility == store.Private && user.ID != resource.CreatorID {
|
||||
return nil, status.Errorf(codes.Unauthenticated, "unauthorized access")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
blob := resource.Blob
|
||||
if resource.InternalPath != "" {
|
||||
resourcePath := filepath.FromSlash(resource.InternalPath)
|
||||
if !filepath.IsAbs(resourcePath) {
|
||||
resourcePath = filepath.Join(s.Profile.Data, resourcePath)
|
||||
}
|
||||
|
||||
file, err := os.Open(resourcePath)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to open the file: %v", err)
|
||||
}
|
||||
defer file.Close()
|
||||
blob, err = io.ReadAll(file)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to read the file: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
httpBody := &httpbody.HttpBody{
|
||||
ContentType: resource.Type,
|
||||
Data: blob,
|
||||
}
|
||||
return httpBody, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) UpdateResource(ctx context.Context, request *v1pb.UpdateResourceRequest) (*v1pb.Resource, error) {
|
||||
id, err := ExtractResourceIDFromName(request.Resource.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid resource id: %v", err)
|
||||
}
|
||||
if request.UpdateMask == nil || len(request.UpdateMask.Paths) == 0 {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "update mask is required")
|
||||
}
|
||||
|
||||
currentTs := time.Now().Unix()
|
||||
update := &store.UpdateResource{
|
||||
ID: id,
|
||||
UpdatedTs: ¤tTs,
|
||||
}
|
||||
for _, field := range request.UpdateMask.Paths {
|
||||
if field == "filename" {
|
||||
update.Filename = &request.Resource.Filename
|
||||
} else if field == "memo" {
|
||||
if request.Resource.Memo == nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "memo is required")
|
||||
}
|
||||
memoID, err := ExtractMemoIDFromName(*request.Resource.Memo)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid memo id: %v", err)
|
||||
}
|
||||
update.MemoID = &memoID
|
||||
}
|
||||
}
|
||||
|
||||
resource, err := s.Store.UpdateResource(ctx, update)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to update resource: %v", err)
|
||||
}
|
||||
return s.convertResourceFromStore(ctx, resource), nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) DeleteResource(ctx context.Context, request *v1pb.DeleteResourceRequest) (*emptypb.Empty, error) {
|
||||
id, err := ExtractResourceIDFromName(request.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid resource id: %v", err)
|
||||
}
|
||||
user, err := getCurrentUser(ctx, s.Store)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
|
||||
}
|
||||
resource, err := s.Store.GetResource(ctx, &store.FindResource{
|
||||
ID: &id,
|
||||
CreatorID: &user.ID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to find resource: %v", err)
|
||||
}
|
||||
if resource == nil {
|
||||
return nil, status.Errorf(codes.NotFound, "resource not found")
|
||||
}
|
||||
// Delete the resource from the database.
|
||||
if err := s.Store.DeleteResource(ctx, &store.DeleteResource{
|
||||
ID: resource.ID,
|
||||
}); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to delete resource: %v", err)
|
||||
}
|
||||
return &emptypb.Empty{}, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) convertResourceFromStore(ctx context.Context, resource *store.Resource) *v1pb.Resource {
|
||||
resourceMessage := &v1pb.Resource{
|
||||
Name: fmt.Sprintf("%s%d", ResourceNamePrefix, resource.ID),
|
||||
Uid: resource.UID,
|
||||
CreateTime: timestamppb.New(time.Unix(resource.CreatedTs, 0)),
|
||||
Filename: resource.Filename,
|
||||
ExternalLink: resource.ExternalLink,
|
||||
Type: resource.Type,
|
||||
Size: resource.Size,
|
||||
}
|
||||
if resource.MemoID != nil {
|
||||
memo, _ := s.Store.GetMemo(ctx, &store.FindMemo{
|
||||
ID: resource.MemoID,
|
||||
})
|
||||
if memo != nil {
|
||||
memoName := fmt.Sprintf("%s%d", MemoNamePrefix, memo.ID)
|
||||
resourceMessage.Memo = &memoName
|
||||
}
|
||||
}
|
||||
|
||||
return resourceMessage
|
||||
}
|
||||
|
||||
// SaveResourceBlob save the blob of resource based on the storage config.
|
||||
func SaveResourceBlob(ctx context.Context, s *store.Store, create *store.Resource) error {
|
||||
workspaceStorageSetting, err := s.GetWorkspaceStorageSetting(ctx)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "Failed to find workspace storage setting")
|
||||
}
|
||||
|
||||
if workspaceStorageSetting.StorageType == storepb.WorkspaceStorageSetting_STORAGE_TYPE_LOCAL {
|
||||
filepathTemplate := "assets/{timestamp}_{filename}"
|
||||
if workspaceStorageSetting.FilepathTemplate != "" {
|
||||
filepathTemplate = workspaceStorageSetting.FilepathTemplate
|
||||
}
|
||||
|
||||
internalPath := filepathTemplate
|
||||
if !strings.Contains(internalPath, "{filename}") {
|
||||
internalPath = filepath.Join(internalPath, "{filename}")
|
||||
}
|
||||
internalPath = replacePathTemplate(internalPath, create.Filename)
|
||||
internalPath = filepath.ToSlash(internalPath)
|
||||
|
||||
// Ensure the directory exists.
|
||||
osPath := filepath.FromSlash(internalPath)
|
||||
if !filepath.IsAbs(osPath) {
|
||||
osPath = filepath.Join(s.Profile.Data, osPath)
|
||||
}
|
||||
dir := filepath.Dir(osPath)
|
||||
if err = os.MkdirAll(dir, os.ModePerm); err != nil {
|
||||
return errors.Wrap(err, "Failed to create directory")
|
||||
}
|
||||
dst, err := os.Create(osPath)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "Failed to create file")
|
||||
}
|
||||
defer dst.Close()
|
||||
|
||||
// Write the blob to the file.
|
||||
if err := os.WriteFile(osPath, create.Blob, 0644); err != nil {
|
||||
return errors.Wrap(err, "Failed to write file")
|
||||
}
|
||||
create.InternalPath = internalPath
|
||||
create.Blob = nil
|
||||
} else if workspaceStorageSetting.StorageType == storepb.WorkspaceStorageSetting_STORAGE_TYPE_S3 {
|
||||
s3Config := workspaceStorageSetting.S3Config
|
||||
if s3Config == nil {
|
||||
return errors.Errorf("No actived external storage found")
|
||||
}
|
||||
s3Client, err := s3.NewClient(ctx, &s3.Config{
|
||||
AccessKeyID: s3Config.AccessKeyId,
|
||||
AcesssKeySecret: s3Config.AccessKeySecret,
|
||||
Endpoint: s3Config.Endpoint,
|
||||
Region: s3Config.Region,
|
||||
Bucket: s3Config.Bucket,
|
||||
})
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "Failed to create s3 client")
|
||||
}
|
||||
|
||||
filepathTemplate := workspaceStorageSetting.FilepathTemplate
|
||||
if !strings.Contains(filepathTemplate, "{filename}") {
|
||||
filepathTemplate = filepath.Join(filepathTemplate, "{filename}")
|
||||
}
|
||||
filepathTemplate = replacePathTemplate(filepathTemplate, create.Filename)
|
||||
r := bytes.NewReader(create.Blob)
|
||||
link, err := s3Client.UploadFile(ctx, filepathTemplate, create.Type, r)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "Failed to upload via s3 client")
|
||||
}
|
||||
|
||||
create.ExternalLink = link
|
||||
create.Blob = nil
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
var fileKeyPattern = regexp.MustCompile(`\{[a-z]{1,9}\}`)
|
||||
|
||||
func replacePathTemplate(path, filename string) string {
|
||||
t := time.Now()
|
||||
path = fileKeyPattern.ReplaceAllStringFunc(path, func(s string) string {
|
||||
switch s {
|
||||
case "{filename}":
|
||||
return filename
|
||||
case "{timestamp}":
|
||||
return fmt.Sprintf("%d", t.Unix())
|
||||
case "{year}":
|
||||
return fmt.Sprintf("%d", t.Year())
|
||||
case "{month}":
|
||||
return fmt.Sprintf("%02d", t.Month())
|
||||
case "{day}":
|
||||
return fmt.Sprintf("%02d", t.Day())
|
||||
case "{hour}":
|
||||
return fmt.Sprintf("%02d", t.Hour())
|
||||
case "{minute}":
|
||||
return fmt.Sprintf("%02d", t.Minute())
|
||||
case "{second}":
|
||||
return fmt.Sprintf("%02d", t.Second())
|
||||
case "{uuid}":
|
||||
return util.GenUUID()
|
||||
}
|
||||
return s
|
||||
})
|
||||
return path
|
||||
}
|
||||
|
||||
// SearchResourcesFilterCELAttributes are the CEL attributes for SearchResourcesFilter.
|
||||
var SearchResourcesFilterCELAttributes = []cel.EnvOption{
|
||||
cel.Variable("uid", cel.StringType),
|
||||
}
|
||||
|
||||
type SearchResourcesFilter struct {
|
||||
UID *string
|
||||
}
|
||||
|
||||
func parseSearchResourcesFilter(expression string) (*SearchResourcesFilter, error) {
|
||||
e, err := cel.NewEnv(SearchResourcesFilterCELAttributes...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ast, issues := e.Compile(expression)
|
||||
if issues != nil {
|
||||
return nil, errors.Errorf("found issue %v", issues)
|
||||
}
|
||||
filter := &SearchResourcesFilter{}
|
||||
expr, err := cel.AstToParsedExpr(ast)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
callExpr := expr.GetExpr().GetCallExpr()
|
||||
findSearchResourcesField(callExpr, filter)
|
||||
return filter, nil
|
||||
}
|
||||
|
||||
func findSearchResourcesField(callExpr *expr.Expr_Call, filter *SearchResourcesFilter) {
|
||||
if len(callExpr.Args) == 2 {
|
||||
idExpr := callExpr.Args[0].GetIdentExpr()
|
||||
if idExpr != nil {
|
||||
if idExpr.Name == "uid" {
|
||||
uid := callExpr.Args[1].GetConstExpr().GetStringValue()
|
||||
filter.UID = &uid
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
for _, arg := range callExpr.Args {
|
||||
callExpr := arg.GetCallExpr()
|
||||
if callExpr != nil {
|
||||
findSearchResourcesField(callExpr, filter)
|
||||
}
|
||||
}
|
||||
}
|
260
server/router/api/v1/tag_service.go
Normal file
260
server/router/api/v1/tag_service.go
Normal file
@@ -0,0 +1,260 @@
|
||||
package v1
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"slices"
|
||||
"sort"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/yourselfhosted/gomark/ast"
|
||||
"github.com/yourselfhosted/gomark/parser"
|
||||
"github.com/yourselfhosted/gomark/parser/tokenizer"
|
||||
"github.com/yourselfhosted/gomark/restore"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/emptypb"
|
||||
|
||||
v1pb "github.com/usememos/memos/proto/gen/api/v1"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (s *APIV1Service) UpsertTag(ctx context.Context, request *v1pb.UpsertTagRequest) (*v1pb.Tag, error) {
|
||||
user, err := getCurrentUser(ctx, s.Store)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get user")
|
||||
}
|
||||
|
||||
tag, err := s.Store.UpsertTag(ctx, &store.Tag{
|
||||
Name: request.Name,
|
||||
CreatorID: user.ID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to upsert tag: %v", err)
|
||||
}
|
||||
|
||||
tagMessage, err := s.convertTagFromStore(ctx, tag)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to convert tag: %v", err)
|
||||
}
|
||||
return tagMessage, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) BatchUpsertTag(ctx context.Context, request *v1pb.BatchUpsertTagRequest) (*emptypb.Empty, error) {
|
||||
for _, r := range request.Requests {
|
||||
if _, err := s.UpsertTag(ctx, r); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to batch upsert tags: %v", err)
|
||||
}
|
||||
}
|
||||
return &emptypb.Empty{}, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) ListTags(ctx context.Context, _ *v1pb.ListTagsRequest) (*v1pb.ListTagsResponse, error) {
|
||||
user, err := getCurrentUser(ctx, s.Store)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get user")
|
||||
}
|
||||
if user == nil {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
tagFind := &store.FindTag{
|
||||
CreatorID: user.ID,
|
||||
}
|
||||
tags, err := s.Store.ListTags(ctx, tagFind)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to list tags: %v", err)
|
||||
}
|
||||
|
||||
response := &v1pb.ListTagsResponse{}
|
||||
for _, tag := range tags {
|
||||
t, err := s.convertTagFromStore(ctx, tag)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to convert tag: %v", err)
|
||||
}
|
||||
response.Tags = append(response.Tags, t)
|
||||
}
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) RenameTag(ctx context.Context, request *v1pb.RenameTagRequest) (*emptypb.Empty, error) {
|
||||
userID, err := ExtractUserIDFromName(request.User)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid user name: %v", err)
|
||||
}
|
||||
user, err := s.Store.GetUser(ctx, &store.FindUser{
|
||||
ID: &userID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
|
||||
}
|
||||
if user == nil {
|
||||
return nil, status.Errorf(codes.NotFound, "user not found")
|
||||
}
|
||||
|
||||
// Find all related memos.
|
||||
memos, err := s.Store.ListMemos(ctx, &store.FindMemo{
|
||||
CreatorID: &user.ID,
|
||||
ContentSearch: []string{fmt.Sprintf("#%s", request.OldName)},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to list memos: %v", err)
|
||||
}
|
||||
// Replace tag name in memo content.
|
||||
for _, memo := range memos {
|
||||
nodes, err := parser.Parse(tokenizer.Tokenize(memo.Content))
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to parse memo: %v", err)
|
||||
}
|
||||
TraverseASTNodes(nodes, func(node ast.Node) {
|
||||
if tag, ok := node.(*ast.Tag); ok && tag.Content == request.OldName {
|
||||
tag.Content = request.NewName
|
||||
}
|
||||
})
|
||||
content := restore.Restore(nodes)
|
||||
if err := s.Store.UpdateMemo(ctx, &store.UpdateMemo{
|
||||
ID: memo.ID,
|
||||
Content: &content,
|
||||
}); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to update memo: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Delete old tag and create new tag.
|
||||
if err := s.Store.DeleteTag(ctx, &store.DeleteTag{
|
||||
CreatorID: user.ID,
|
||||
Name: request.OldName,
|
||||
}); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to delete tag: %v", err)
|
||||
}
|
||||
if _, err := s.Store.UpsertTag(ctx, &store.Tag{
|
||||
CreatorID: user.ID,
|
||||
Name: request.NewName,
|
||||
}); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to upsert tag: %v", err)
|
||||
}
|
||||
|
||||
return &emptypb.Empty{}, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) DeleteTag(ctx context.Context, request *v1pb.DeleteTagRequest) (*emptypb.Empty, error) {
|
||||
userID, err := ExtractUserIDFromName(request.Tag.Creator)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid user name: %v", err)
|
||||
}
|
||||
user, err := s.Store.GetUser(ctx, &store.FindUser{
|
||||
ID: &userID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
|
||||
}
|
||||
if user == nil {
|
||||
return nil, status.Errorf(codes.NotFound, "user not found")
|
||||
}
|
||||
if err := s.Store.DeleteTag(ctx, &store.DeleteTag{
|
||||
Name: request.Tag.Name,
|
||||
CreatorID: user.ID,
|
||||
}); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to delete tag: %v", err)
|
||||
}
|
||||
|
||||
return &emptypb.Empty{}, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) GetTagSuggestions(ctx context.Context, request *v1pb.GetTagSuggestionsRequest) (*v1pb.GetTagSuggestionsResponse, error) {
|
||||
userID, err := ExtractUserIDFromName(request.User)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid user name: %v", err)
|
||||
}
|
||||
user, err := s.Store.GetUser(ctx, &store.FindUser{
|
||||
ID: &userID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
|
||||
}
|
||||
if user == nil {
|
||||
return nil, status.Errorf(codes.NotFound, "user not found")
|
||||
}
|
||||
normalRowStatus := store.Normal
|
||||
memoFind := &store.FindMemo{
|
||||
CreatorID: &user.ID,
|
||||
ContentSearch: []string{"#"},
|
||||
RowStatus: &normalRowStatus,
|
||||
}
|
||||
memos, err := s.Store.ListMemos(ctx, memoFind)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to list memos: %v", err)
|
||||
}
|
||||
|
||||
tagList, err := s.Store.ListTags(ctx, &store.FindTag{
|
||||
CreatorID: user.ID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to list tags: %v", err)
|
||||
}
|
||||
|
||||
tagNameList := []string{}
|
||||
for _, tag := range tagList {
|
||||
tagNameList = append(tagNameList, tag.Name)
|
||||
}
|
||||
tagMapSet := make(map[string]bool)
|
||||
for _, memo := range memos {
|
||||
nodes, err := parser.Parse(tokenizer.Tokenize(memo.Content))
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to parse memo content")
|
||||
}
|
||||
|
||||
// Dynamically upsert tags from memo content.
|
||||
TraverseASTNodes(nodes, func(node ast.Node) {
|
||||
if tagNode, ok := node.(*ast.Tag); ok {
|
||||
tag := tagNode.Content
|
||||
if !slices.Contains(tagNameList, tag) {
|
||||
tagMapSet[tag] = true
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
suggestions := []string{}
|
||||
for tag := range tagMapSet {
|
||||
suggestions = append(suggestions, tag)
|
||||
}
|
||||
sort.Strings(suggestions)
|
||||
|
||||
return &v1pb.GetTagSuggestionsResponse{
|
||||
Tags: suggestions,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) convertTagFromStore(ctx context.Context, tag *store.Tag) (*v1pb.Tag, error) {
|
||||
user, err := s.Store.GetUser(ctx, &store.FindUser{
|
||||
ID: &tag.CreatorID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to get user")
|
||||
}
|
||||
return &v1pb.Tag{
|
||||
Name: tag.Name,
|
||||
Creator: fmt.Sprintf("%s%d", UserNamePrefix, user.ID),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func TraverseASTNodes(nodes []ast.Node, fn func(ast.Node)) {
|
||||
for _, node := range nodes {
|
||||
fn(node)
|
||||
switch n := node.(type) {
|
||||
case *ast.Paragraph:
|
||||
TraverseASTNodes(n.Children, fn)
|
||||
case *ast.Heading:
|
||||
TraverseASTNodes(n.Children, fn)
|
||||
case *ast.Blockquote:
|
||||
TraverseASTNodes(n.Children, fn)
|
||||
case *ast.OrderedList:
|
||||
TraverseASTNodes(n.Children, fn)
|
||||
case *ast.UnorderedList:
|
||||
TraverseASTNodes(n.Children, fn)
|
||||
case *ast.TaskList:
|
||||
TraverseASTNodes(n.Children, fn)
|
||||
case *ast.Bold:
|
||||
TraverseASTNodes(n.Children, fn)
|
||||
}
|
||||
}
|
||||
}
|
620
server/router/api/v1/user_service.go
Normal file
620
server/router/api/v1/user_service.go
Normal file
@@ -0,0 +1,620 @@
|
||||
package v1
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/google/cel-go/cel"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/pkg/errors"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
expr "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
|
||||
"google.golang.org/genproto/googleapis/api/httpbody"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/emptypb"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/usememos/memos/internal/util"
|
||||
v1pb "github.com/usememos/memos/proto/gen/api/v1"
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (s *APIV1Service) ListUsers(ctx context.Context, _ *v1pb.ListUsersRequest) (*v1pb.ListUsersResponse, error) {
|
||||
currentUser, err := getCurrentUser(ctx, s.Store)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
|
||||
}
|
||||
if currentUser.Role != store.RoleHost && currentUser.Role != store.RoleAdmin {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
|
||||
users, err := s.Store.ListUsers(ctx, &store.FindUser{})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to list users: %v", err)
|
||||
}
|
||||
|
||||
response := &v1pb.ListUsersResponse{
|
||||
Users: []*v1pb.User{},
|
||||
}
|
||||
for _, user := range users {
|
||||
response.Users = append(response.Users, convertUserFromStore(user))
|
||||
}
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) SearchUsers(ctx context.Context, request *v1pb.SearchUsersRequest) (*v1pb.SearchUsersResponse, error) {
|
||||
if request.Filter == "" {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "filter is empty")
|
||||
}
|
||||
filter, err := parseSearchUsersFilter(request.Filter)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "failed to parse filter: %v", err)
|
||||
}
|
||||
userFind := &store.FindUser{}
|
||||
if filter.Username != nil {
|
||||
userFind.Username = filter.Username
|
||||
}
|
||||
if filter.Random {
|
||||
userFind.Random = true
|
||||
}
|
||||
if filter.Limit != nil {
|
||||
userFind.Limit = filter.Limit
|
||||
}
|
||||
|
||||
users, err := s.Store.ListUsers(ctx, userFind)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to search users: %v", err)
|
||||
}
|
||||
|
||||
response := &v1pb.SearchUsersResponse{
|
||||
Users: []*v1pb.User{},
|
||||
}
|
||||
for _, user := range users {
|
||||
response.Users = append(response.Users, convertUserFromStore(user))
|
||||
}
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) GetUser(ctx context.Context, request *v1pb.GetUserRequest) (*v1pb.User, error) {
|
||||
userID, err := ExtractUserIDFromName(request.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid user name: %v", err)
|
||||
}
|
||||
user, err := s.Store.GetUser(ctx, &store.FindUser{
|
||||
ID: &userID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
|
||||
}
|
||||
if user == nil {
|
||||
return nil, status.Errorf(codes.NotFound, "user not found")
|
||||
}
|
||||
|
||||
return convertUserFromStore(user), nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) GetUserAvatarBinary(ctx context.Context, request *v1pb.GetUserAvatarBinaryRequest) (*httpbody.HttpBody, error) {
|
||||
userID, err := ExtractUserIDFromName(request.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid user name: %v", err)
|
||||
}
|
||||
user, err := s.Store.GetUser(ctx, &store.FindUser{
|
||||
ID: &userID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
|
||||
}
|
||||
if user == nil {
|
||||
return nil, status.Errorf(codes.NotFound, "user not found")
|
||||
}
|
||||
if user.AvatarURL == "" {
|
||||
return nil, status.Errorf(codes.NotFound, "avatar not found")
|
||||
}
|
||||
|
||||
imageType, base64Data, err := extractImageInfo(user.AvatarURL)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to extract image info: %v", err)
|
||||
}
|
||||
imageData, err := base64.StdEncoding.DecodeString(base64Data)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to decode string: %v", err)
|
||||
}
|
||||
httpBody := &httpbody.HttpBody{
|
||||
ContentType: imageType,
|
||||
Data: imageData,
|
||||
}
|
||||
return httpBody, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) CreateUser(ctx context.Context, request *v1pb.CreateUserRequest) (*v1pb.User, error) {
|
||||
currentUser, err := getCurrentUser(ctx, s.Store)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
|
||||
}
|
||||
if currentUser.Role != store.RoleHost {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
if !util.UIDMatcher.MatchString(strings.ToLower(request.User.Username)) {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid username: %s", request.User.Username)
|
||||
}
|
||||
passwordHash, err := bcrypt.GenerateFromPassword([]byte(request.User.Password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return nil, echo.NewHTTPError(http.StatusInternalServerError, "failed to generate password hash").SetInternal(err)
|
||||
}
|
||||
|
||||
user, err := s.Store.CreateUser(ctx, &store.User{
|
||||
Username: request.User.Username,
|
||||
Role: convertUserRoleToStore(request.User.Role),
|
||||
Email: request.User.Email,
|
||||
Nickname: request.User.Nickname,
|
||||
PasswordHash: string(passwordHash),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to create user: %v", err)
|
||||
}
|
||||
|
||||
return convertUserFromStore(user), nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) UpdateUser(ctx context.Context, request *v1pb.UpdateUserRequest) (*v1pb.User, error) {
|
||||
userID, err := ExtractUserIDFromName(request.User.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid user name: %v", err)
|
||||
}
|
||||
currentUser, err := getCurrentUser(ctx, s.Store)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
|
||||
}
|
||||
if currentUser.ID != userID && currentUser.Role != store.RoleAdmin && currentUser.Role != store.RoleHost {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
if request.UpdateMask == nil || len(request.UpdateMask.Paths) == 0 {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "update mask is empty")
|
||||
}
|
||||
|
||||
user, err := s.Store.GetUser(ctx, &store.FindUser{ID: &userID})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
|
||||
}
|
||||
if user == nil {
|
||||
return nil, status.Errorf(codes.NotFound, "user not found")
|
||||
}
|
||||
|
||||
currentTs := time.Now().Unix()
|
||||
update := &store.UpdateUser{
|
||||
ID: user.ID,
|
||||
UpdatedTs: ¤tTs,
|
||||
}
|
||||
for _, field := range request.UpdateMask.Paths {
|
||||
if field == "username" {
|
||||
if !util.UIDMatcher.MatchString(strings.ToLower(request.User.Username)) {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid username: %s", request.User.Username)
|
||||
}
|
||||
update.Username = &request.User.Username
|
||||
} else if field == "nickname" {
|
||||
update.Nickname = &request.User.Nickname
|
||||
} else if field == "email" {
|
||||
update.Email = &request.User.Email
|
||||
} else if field == "avatar_url" {
|
||||
update.AvatarURL = &request.User.AvatarUrl
|
||||
} else if field == "description" {
|
||||
update.Description = &request.User.Description
|
||||
} else if field == "role" {
|
||||
role := convertUserRoleToStore(request.User.Role)
|
||||
update.Role = &role
|
||||
} else if field == "password" {
|
||||
passwordHash, err := bcrypt.GenerateFromPassword([]byte(request.User.Password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return nil, echo.NewHTTPError(http.StatusInternalServerError, "failed to generate password hash").SetInternal(err)
|
||||
}
|
||||
passwordHashStr := string(passwordHash)
|
||||
update.PasswordHash = &passwordHashStr
|
||||
} else if field == "row_status" {
|
||||
rowStatus := convertRowStatusToStore(request.User.RowStatus)
|
||||
update.RowStatus = &rowStatus
|
||||
} else {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid update path: %s", field)
|
||||
}
|
||||
}
|
||||
|
||||
updatedUser, err := s.Store.UpdateUser(ctx, update)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to update user: %v", err)
|
||||
}
|
||||
|
||||
return convertUserFromStore(updatedUser), nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) DeleteUser(ctx context.Context, request *v1pb.DeleteUserRequest) (*emptypb.Empty, error) {
|
||||
userID, err := ExtractUserIDFromName(request.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid user name: %v", err)
|
||||
}
|
||||
currentUser, err := getCurrentUser(ctx, s.Store)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
|
||||
}
|
||||
if currentUser.ID != userID && currentUser.Role != store.RoleAdmin && currentUser.Role != store.RoleHost {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
|
||||
user, err := s.Store.GetUser(ctx, &store.FindUser{ID: &userID})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
|
||||
}
|
||||
if user == nil {
|
||||
return nil, status.Errorf(codes.NotFound, "user not found")
|
||||
}
|
||||
|
||||
if err := s.Store.DeleteUser(ctx, &store.DeleteUser{
|
||||
ID: user.ID,
|
||||
}); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to delete user: %v", err)
|
||||
}
|
||||
|
||||
return &emptypb.Empty{}, nil
|
||||
}
|
||||
|
||||
func getDefaultUserSetting() *v1pb.UserSetting {
|
||||
return &v1pb.UserSetting{
|
||||
Locale: "en",
|
||||
Appearance: "system",
|
||||
MemoVisibility: "PRIVATE",
|
||||
}
|
||||
}
|
||||
|
||||
func (s *APIV1Service) GetUserSetting(ctx context.Context, _ *v1pb.GetUserSettingRequest) (*v1pb.UserSetting, error) {
|
||||
user, err := getCurrentUser(ctx, s.Store)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
|
||||
}
|
||||
|
||||
userSettings, err := s.Store.ListUserSettings(ctx, &store.FindUserSetting{
|
||||
UserID: &user.ID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to list user settings: %v", err)
|
||||
}
|
||||
userSettingMessage := getDefaultUserSetting()
|
||||
for _, setting := range userSettings {
|
||||
if setting.Key == storepb.UserSettingKey_USER_SETTING_LOCALE {
|
||||
userSettingMessage.Locale = setting.GetLocale()
|
||||
} else if setting.Key == storepb.UserSettingKey_USER_SETTING_APPEARANCE {
|
||||
userSettingMessage.Appearance = setting.GetAppearance()
|
||||
} else if setting.Key == storepb.UserSettingKey_USER_SETTING_MEMO_VISIBILITY {
|
||||
userSettingMessage.MemoVisibility = setting.GetMemoVisibility()
|
||||
}
|
||||
}
|
||||
return userSettingMessage, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) UpdateUserSetting(ctx context.Context, request *v1pb.UpdateUserSettingRequest) (*v1pb.UserSetting, error) {
|
||||
user, err := getCurrentUser(ctx, s.Store)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
|
||||
}
|
||||
|
||||
if request.UpdateMask == nil || len(request.UpdateMask.Paths) == 0 {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "update mask is empty")
|
||||
}
|
||||
|
||||
for _, field := range request.UpdateMask.Paths {
|
||||
if field == "locale" {
|
||||
if _, err := s.Store.UpsertUserSetting(ctx, &storepb.UserSetting{
|
||||
UserId: user.ID,
|
||||
Key: storepb.UserSettingKey_USER_SETTING_LOCALE,
|
||||
Value: &storepb.UserSetting_Locale{
|
||||
Locale: request.Setting.Locale,
|
||||
},
|
||||
}); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to upsert user setting: %v", err)
|
||||
}
|
||||
} else if field == "appearance" {
|
||||
if _, err := s.Store.UpsertUserSetting(ctx, &storepb.UserSetting{
|
||||
UserId: user.ID,
|
||||
Key: storepb.UserSettingKey_USER_SETTING_APPEARANCE,
|
||||
Value: &storepb.UserSetting_Appearance{
|
||||
Appearance: request.Setting.Appearance,
|
||||
},
|
||||
}); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to upsert user setting: %v", err)
|
||||
}
|
||||
} else if field == "memo_visibility" {
|
||||
if _, err := s.Store.UpsertUserSetting(ctx, &storepb.UserSetting{
|
||||
UserId: user.ID,
|
||||
Key: storepb.UserSettingKey_USER_SETTING_MEMO_VISIBILITY,
|
||||
Value: &storepb.UserSetting_MemoVisibility{
|
||||
MemoVisibility: request.Setting.MemoVisibility,
|
||||
},
|
||||
}); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to upsert user setting: %v", err)
|
||||
}
|
||||
} else {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid update path: %s", field)
|
||||
}
|
||||
}
|
||||
|
||||
return s.GetUserSetting(ctx, &v1pb.GetUserSettingRequest{})
|
||||
}
|
||||
|
||||
func (s *APIV1Service) ListUserAccessTokens(ctx context.Context, _ *v1pb.ListUserAccessTokensRequest) (*v1pb.ListUserAccessTokensResponse, error) {
|
||||
currentUser, err := getCurrentUser(ctx, s.Store)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
|
||||
}
|
||||
if currentUser == nil {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
|
||||
userAccessTokens, err := s.Store.GetUserAccessTokens(ctx, currentUser.ID)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to list access tokens: %v", err)
|
||||
}
|
||||
|
||||
accessTokens := []*v1pb.UserAccessToken{}
|
||||
for _, userAccessToken := range userAccessTokens {
|
||||
claims := &ClaimsMessage{}
|
||||
_, err := jwt.ParseWithClaims(userAccessToken.AccessToken, claims, func(t *jwt.Token) (any, error) {
|
||||
if t.Method.Alg() != jwt.SigningMethodHS256.Name {
|
||||
return nil, errors.Errorf("unexpected access token signing method=%v, expect %v", t.Header["alg"], jwt.SigningMethodHS256)
|
||||
}
|
||||
if kid, ok := t.Header["kid"].(string); ok {
|
||||
if kid == "v1" {
|
||||
return []byte(s.Secret), nil
|
||||
}
|
||||
}
|
||||
return nil, errors.Errorf("unexpected access token kid=%v", t.Header["kid"])
|
||||
})
|
||||
if err != nil {
|
||||
// If the access token is invalid or expired, just ignore it.
|
||||
continue
|
||||
}
|
||||
|
||||
userAccessToken := &v1pb.UserAccessToken{
|
||||
AccessToken: userAccessToken.AccessToken,
|
||||
Description: userAccessToken.Description,
|
||||
IssuedAt: timestamppb.New(claims.IssuedAt.Time),
|
||||
}
|
||||
if claims.ExpiresAt != nil {
|
||||
userAccessToken.ExpiresAt = timestamppb.New(claims.ExpiresAt.Time)
|
||||
}
|
||||
accessTokens = append(accessTokens, userAccessToken)
|
||||
}
|
||||
|
||||
// Sort by issued time in descending order.
|
||||
slices.SortFunc(accessTokens, func(i, j *v1pb.UserAccessToken) int {
|
||||
return int(i.IssuedAt.Seconds - j.IssuedAt.Seconds)
|
||||
})
|
||||
response := &v1pb.ListUserAccessTokensResponse{
|
||||
AccessTokens: accessTokens,
|
||||
}
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) CreateUserAccessToken(ctx context.Context, request *v1pb.CreateUserAccessTokenRequest) (*v1pb.UserAccessToken, error) {
|
||||
user, err := getCurrentUser(ctx, s.Store)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
|
||||
}
|
||||
|
||||
expiresAt := time.Time{}
|
||||
if request.ExpiresAt != nil {
|
||||
expiresAt = request.ExpiresAt.AsTime()
|
||||
}
|
||||
|
||||
accessToken, err := GenerateAccessToken(user.Username, user.ID, expiresAt, []byte(s.Secret))
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to generate access token: %v", err)
|
||||
}
|
||||
|
||||
claims := &ClaimsMessage{}
|
||||
_, err = jwt.ParseWithClaims(accessToken, claims, func(t *jwt.Token) (any, error) {
|
||||
if t.Method.Alg() != jwt.SigningMethodHS256.Name {
|
||||
return nil, errors.Errorf("unexpected access token signing method=%v, expect %v", t.Header["alg"], jwt.SigningMethodHS256)
|
||||
}
|
||||
if kid, ok := t.Header["kid"].(string); ok {
|
||||
if kid == "v1" {
|
||||
return []byte(s.Secret), nil
|
||||
}
|
||||
}
|
||||
return nil, errors.Errorf("unexpected access token kid=%v", t.Header["kid"])
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to parse access token: %v", err)
|
||||
}
|
||||
|
||||
// Upsert the access token to user setting store.
|
||||
if err := s.UpsertAccessTokenToStore(ctx, user, accessToken, request.Description); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to upsert access token to store: %v", err)
|
||||
}
|
||||
|
||||
userAccessToken := &v1pb.UserAccessToken{
|
||||
AccessToken: accessToken,
|
||||
Description: request.Description,
|
||||
IssuedAt: timestamppb.New(claims.IssuedAt.Time),
|
||||
}
|
||||
if claims.ExpiresAt != nil {
|
||||
userAccessToken.ExpiresAt = timestamppb.New(claims.ExpiresAt.Time)
|
||||
}
|
||||
return userAccessToken, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) DeleteUserAccessToken(ctx context.Context, request *v1pb.DeleteUserAccessTokenRequest) (*emptypb.Empty, error) {
|
||||
user, err := getCurrentUser(ctx, s.Store)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
|
||||
}
|
||||
|
||||
userAccessTokens, err := s.Store.GetUserAccessTokens(ctx, user.ID)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to list access tokens: %v", err)
|
||||
}
|
||||
updatedUserAccessTokens := []*storepb.AccessTokensUserSetting_AccessToken{}
|
||||
for _, userAccessToken := range userAccessTokens {
|
||||
if userAccessToken.AccessToken == request.AccessToken {
|
||||
continue
|
||||
}
|
||||
updatedUserAccessTokens = append(updatedUserAccessTokens, userAccessToken)
|
||||
}
|
||||
if _, err := s.Store.UpsertUserSetting(ctx, &storepb.UserSetting{
|
||||
UserId: user.ID,
|
||||
Key: storepb.UserSettingKey_USER_SETTING_ACCESS_TOKENS,
|
||||
Value: &storepb.UserSetting_AccessTokens{
|
||||
AccessTokens: &storepb.AccessTokensUserSetting{
|
||||
AccessTokens: updatedUserAccessTokens,
|
||||
},
|
||||
},
|
||||
}); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to upsert user setting: %v", err)
|
||||
}
|
||||
|
||||
return &emptypb.Empty{}, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) UpsertAccessTokenToStore(ctx context.Context, user *store.User, accessToken, description string) error {
|
||||
userAccessTokens, err := s.Store.GetUserAccessTokens(ctx, user.ID)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to get user access tokens")
|
||||
}
|
||||
userAccessToken := storepb.AccessTokensUserSetting_AccessToken{
|
||||
AccessToken: accessToken,
|
||||
Description: description,
|
||||
}
|
||||
userAccessTokens = append(userAccessTokens, &userAccessToken)
|
||||
if _, err := s.Store.UpsertUserSetting(ctx, &storepb.UserSetting{
|
||||
UserId: user.ID,
|
||||
Key: storepb.UserSettingKey_USER_SETTING_ACCESS_TOKENS,
|
||||
Value: &storepb.UserSetting_AccessTokens{
|
||||
AccessTokens: &storepb.AccessTokensUserSetting{
|
||||
AccessTokens: userAccessTokens,
|
||||
},
|
||||
},
|
||||
}); err != nil {
|
||||
return errors.Wrap(err, "failed to upsert user setting")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func convertUserFromStore(user *store.User) *v1pb.User {
|
||||
userpb := &v1pb.User{
|
||||
Name: fmt.Sprintf("%s%d", UserNamePrefix, user.ID),
|
||||
Id: user.ID,
|
||||
RowStatus: convertRowStatusFromStore(user.RowStatus),
|
||||
CreateTime: timestamppb.New(time.Unix(user.CreatedTs, 0)),
|
||||
UpdateTime: timestamppb.New(time.Unix(user.UpdatedTs, 0)),
|
||||
Role: convertUserRoleFromStore(user.Role),
|
||||
Username: user.Username,
|
||||
Email: user.Email,
|
||||
Nickname: user.Nickname,
|
||||
AvatarUrl: user.AvatarURL,
|
||||
Description: user.Description,
|
||||
}
|
||||
// Use the avatar URL instead of raw base64 image data to reduce the response size.
|
||||
if user.AvatarURL != "" {
|
||||
userpb.AvatarUrl = fmt.Sprintf("/o/%s/avatar", userpb.Name)
|
||||
}
|
||||
return userpb
|
||||
}
|
||||
|
||||
func convertUserRoleFromStore(role store.Role) v1pb.User_Role {
|
||||
switch role {
|
||||
case store.RoleHost:
|
||||
return v1pb.User_HOST
|
||||
case store.RoleAdmin:
|
||||
return v1pb.User_ADMIN
|
||||
case store.RoleUser:
|
||||
return v1pb.User_USER
|
||||
default:
|
||||
return v1pb.User_ROLE_UNSPECIFIED
|
||||
}
|
||||
}
|
||||
|
||||
func convertUserRoleToStore(role v1pb.User_Role) store.Role {
|
||||
switch role {
|
||||
case v1pb.User_HOST:
|
||||
return store.RoleHost
|
||||
case v1pb.User_ADMIN:
|
||||
return store.RoleAdmin
|
||||
case v1pb.User_USER:
|
||||
return store.RoleUser
|
||||
default:
|
||||
return store.RoleUser
|
||||
}
|
||||
}
|
||||
|
||||
// SearchUsersFilterCELAttributes are the CEL attributes for SearchUsersFilter.
|
||||
var SearchUsersFilterCELAttributes = []cel.EnvOption{
|
||||
cel.Variable("username", cel.StringType),
|
||||
cel.Variable("random", cel.BoolType),
|
||||
cel.Variable("limit", cel.IntType),
|
||||
}
|
||||
|
||||
type SearchUsersFilter struct {
|
||||
Username *string
|
||||
Random bool
|
||||
Limit *int
|
||||
}
|
||||
|
||||
func parseSearchUsersFilter(expression string) (*SearchUsersFilter, error) {
|
||||
e, err := cel.NewEnv(SearchUsersFilterCELAttributes...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ast, issues := e.Compile(expression)
|
||||
if issues != nil {
|
||||
return nil, errors.Errorf("found issue %v", issues)
|
||||
}
|
||||
filter := &SearchUsersFilter{}
|
||||
expr, err := cel.AstToParsedExpr(ast)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
callExpr := expr.GetExpr().GetCallExpr()
|
||||
findSearchUsersField(callExpr, filter)
|
||||
return filter, nil
|
||||
}
|
||||
|
||||
func findSearchUsersField(callExpr *expr.Expr_Call, filter *SearchUsersFilter) {
|
||||
if len(callExpr.Args) == 2 {
|
||||
idExpr := callExpr.Args[0].GetIdentExpr()
|
||||
if idExpr != nil {
|
||||
if idExpr.Name == "username" {
|
||||
username := callExpr.Args[1].GetConstExpr().GetStringValue()
|
||||
filter.Username = &username
|
||||
} else if idExpr.Name == "random" {
|
||||
random := callExpr.Args[1].GetConstExpr().GetBoolValue()
|
||||
filter.Random = random
|
||||
} else if idExpr.Name == "limit" {
|
||||
limit := int(callExpr.Args[1].GetConstExpr().GetInt64Value())
|
||||
filter.Limit = &limit
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
for _, arg := range callExpr.Args {
|
||||
callExpr := arg.GetCallExpr()
|
||||
if callExpr != nil {
|
||||
findSearchUsersField(callExpr, filter)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func extractImageInfo(dataURI string) (string, string, error) {
|
||||
dataURIRegex := regexp.MustCompile(`^data:(?P<type>.+);base64,(?P<base64>.+)`)
|
||||
matches := dataURIRegex.FindStringSubmatch(dataURI)
|
||||
if len(matches) != 3 {
|
||||
return "", "", errors.New("Invalid data URI format")
|
||||
}
|
||||
imageType := matches[1]
|
||||
base64Data := matches[2]
|
||||
return imageType, base64Data, nil
|
||||
}
|
127
server/router/api/v1/v1.go
Normal file
127
server/router/api/v1/v1.go
Normal file
@@ -0,0 +1,127 @@
|
||||
package v1
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
|
||||
"github.com/improbable-eng/grpc-web/go/grpcweb"
|
||||
"github.com/labstack/echo/v4"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
"google.golang.org/grpc/reflection"
|
||||
|
||||
v1pb "github.com/usememos/memos/proto/gen/api/v1"
|
||||
"github.com/usememos/memos/server/profile"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
type APIV1Service struct {
|
||||
v1pb.UnimplementedWorkspaceServiceServer
|
||||
v1pb.UnimplementedWorkspaceSettingServiceServer
|
||||
v1pb.UnimplementedAuthServiceServer
|
||||
v1pb.UnimplementedUserServiceServer
|
||||
v1pb.UnimplementedMemoServiceServer
|
||||
v1pb.UnimplementedResourceServiceServer
|
||||
v1pb.UnimplementedTagServiceServer
|
||||
v1pb.UnimplementedInboxServiceServer
|
||||
v1pb.UnimplementedActivityServiceServer
|
||||
v1pb.UnimplementedWebhookServiceServer
|
||||
v1pb.UnimplementedMarkdownServiceServer
|
||||
v1pb.UnimplementedIdentityProviderServiceServer
|
||||
|
||||
Secret string
|
||||
Profile *profile.Profile
|
||||
Store *store.Store
|
||||
|
||||
grpcServer *grpc.Server
|
||||
}
|
||||
|
||||
func NewAPIV1Service(secret string, profile *profile.Profile, store *store.Store, grpcServer *grpc.Server) *APIV1Service {
|
||||
grpc.EnableTracing = true
|
||||
apiv1Service := &APIV1Service{
|
||||
Secret: secret,
|
||||
Profile: profile,
|
||||
Store: store,
|
||||
grpcServer: grpcServer,
|
||||
}
|
||||
v1pb.RegisterWorkspaceServiceServer(grpcServer, apiv1Service)
|
||||
v1pb.RegisterWorkspaceSettingServiceServer(grpcServer, apiv1Service)
|
||||
v1pb.RegisterAuthServiceServer(grpcServer, apiv1Service)
|
||||
v1pb.RegisterUserServiceServer(grpcServer, apiv1Service)
|
||||
v1pb.RegisterMemoServiceServer(grpcServer, apiv1Service)
|
||||
v1pb.RegisterTagServiceServer(grpcServer, apiv1Service)
|
||||
v1pb.RegisterResourceServiceServer(grpcServer, apiv1Service)
|
||||
v1pb.RegisterInboxServiceServer(grpcServer, apiv1Service)
|
||||
v1pb.RegisterActivityServiceServer(grpcServer, apiv1Service)
|
||||
v1pb.RegisterWebhookServiceServer(grpcServer, apiv1Service)
|
||||
v1pb.RegisterMarkdownServiceServer(grpcServer, apiv1Service)
|
||||
v1pb.RegisterIdentityProviderServiceServer(grpcServer, apiv1Service)
|
||||
reflection.Register(grpcServer)
|
||||
return apiv1Service
|
||||
}
|
||||
|
||||
// RegisterGateway registers the gRPC-Gateway with the given Echo instance.
|
||||
func (s *APIV1Service) RegisterGateway(ctx context.Context, echoServer *echo.Echo) error {
|
||||
// Create a client connection to the gRPC Server we just started.
|
||||
// This is where the gRPC-Gateway proxies the requests.
|
||||
conn, err := grpc.DialContext(
|
||||
ctx,
|
||||
fmt.Sprintf(":%d", s.Profile.Port),
|
||||
grpc.WithTransportCredentials(insecure.NewCredentials()),
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
gwMux := runtime.NewServeMux()
|
||||
if err := v1pb.RegisterWorkspaceServiceHandler(context.Background(), gwMux, conn); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := v1pb.RegisterWorkspaceSettingServiceHandler(context.Background(), gwMux, conn); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := v1pb.RegisterAuthServiceHandler(context.Background(), gwMux, conn); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := v1pb.RegisterUserServiceHandler(context.Background(), gwMux, conn); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := v1pb.RegisterMemoServiceHandler(context.Background(), gwMux, conn); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := v1pb.RegisterTagServiceHandler(context.Background(), gwMux, conn); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := v1pb.RegisterResourceServiceHandler(context.Background(), gwMux, conn); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := v1pb.RegisterInboxServiceHandler(context.Background(), gwMux, conn); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := v1pb.RegisterActivityServiceHandler(context.Background(), gwMux, conn); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := v1pb.RegisterWebhookServiceHandler(context.Background(), gwMux, conn); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := v1pb.RegisterMarkdownServiceHandler(context.Background(), gwMux, conn); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := v1pb.RegisterIdentityProviderServiceHandler(context.Background(), gwMux, conn); err != nil {
|
||||
return err
|
||||
}
|
||||
echoServer.Any("*", echo.WrapHandler(gwMux))
|
||||
|
||||
// GRPC web proxy.
|
||||
options := []grpcweb.Option{
|
||||
grpcweb.WithCorsForRegisteredEndpointsOnly(false),
|
||||
grpcweb.WithOriginFunc(func(origin string) bool {
|
||||
return true
|
||||
}),
|
||||
}
|
||||
wrappedGrpc := grpcweb.WrapServer(s.grpcServer, options...)
|
||||
echoServer.Any("/memos.api.v1.*", echo.WrapHandler(wrappedGrpc))
|
||||
|
||||
return nil
|
||||
}
|
114
server/router/api/v1/webhook_service.go
Normal file
114
server/router/api/v1/webhook_service.go
Normal file
@@ -0,0 +1,114 @@
|
||||
package v1
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/emptypb"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
v1pb "github.com/usememos/memos/proto/gen/api/v1"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (s *APIV1Service) CreateWebhook(ctx context.Context, request *v1pb.CreateWebhookRequest) (*v1pb.Webhook, error) {
|
||||
currentUser, err := getCurrentUser(ctx, s.Store)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
|
||||
}
|
||||
|
||||
webhook, err := s.Store.CreateWebhook(ctx, &store.Webhook{
|
||||
CreatorID: currentUser.ID,
|
||||
Name: request.Name,
|
||||
URL: request.Url,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to create webhook, error: %+v", err)
|
||||
}
|
||||
return convertWebhookFromStore(webhook), nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) ListWebhooks(ctx context.Context, request *v1pb.ListWebhooksRequest) (*v1pb.ListWebhooksResponse, error) {
|
||||
webhooks, err := s.Store.ListWebhooks(ctx, &store.FindWebhook{
|
||||
CreatorID: &request.CreatorId,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to list webhooks, error: %+v", err)
|
||||
}
|
||||
|
||||
response := &v1pb.ListWebhooksResponse{
|
||||
Webhooks: []*v1pb.Webhook{},
|
||||
}
|
||||
for _, webhook := range webhooks {
|
||||
response.Webhooks = append(response.Webhooks, convertWebhookFromStore(webhook))
|
||||
}
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) GetWebhook(ctx context.Context, request *v1pb.GetWebhookRequest) (*v1pb.Webhook, error) {
|
||||
currentUser, err := getCurrentUser(ctx, s.Store)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
|
||||
}
|
||||
|
||||
webhook, err := s.Store.GetWebhook(ctx, &store.FindWebhook{
|
||||
ID: &request.Id,
|
||||
CreatorID: ¤tUser.ID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get webhook, error: %+v", err)
|
||||
}
|
||||
if webhook == nil {
|
||||
return nil, status.Errorf(codes.NotFound, "webhook not found")
|
||||
}
|
||||
return convertWebhookFromStore(webhook), nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) UpdateWebhook(ctx context.Context, request *v1pb.UpdateWebhookRequest) (*v1pb.Webhook, error) {
|
||||
if request.UpdateMask == nil || len(request.UpdateMask.Paths) == 0 {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "update_mask is required")
|
||||
}
|
||||
|
||||
update := &store.UpdateWebhook{}
|
||||
for _, field := range request.UpdateMask.Paths {
|
||||
switch field {
|
||||
case "row_status":
|
||||
rowStatus := store.RowStatus(request.Webhook.RowStatus.String())
|
||||
update.RowStatus = &rowStatus
|
||||
case "name":
|
||||
update.Name = &request.Webhook.Name
|
||||
case "url":
|
||||
update.URL = &request.Webhook.Url
|
||||
}
|
||||
}
|
||||
|
||||
webhook, err := s.Store.UpdateWebhook(ctx, update)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to update webhook, error: %+v", err)
|
||||
}
|
||||
return convertWebhookFromStore(webhook), nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) DeleteWebhook(ctx context.Context, request *v1pb.DeleteWebhookRequest) (*emptypb.Empty, error) {
|
||||
err := s.Store.DeleteWebhook(ctx, &store.DeleteWebhook{
|
||||
ID: request.Id,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to delete webhook, error: %+v", err)
|
||||
}
|
||||
return &emptypb.Empty{}, nil
|
||||
}
|
||||
|
||||
func convertWebhookFromStore(webhook *store.Webhook) *v1pb.Webhook {
|
||||
return &v1pb.Webhook{
|
||||
Id: webhook.ID,
|
||||
CreatedTime: timestamppb.New(time.Unix(webhook.CreatedTs, 0)),
|
||||
UpdatedTime: timestamppb.New(time.Unix(webhook.UpdatedTs, 0)),
|
||||
RowStatus: convertRowStatusFromStore(webhook.RowStatus),
|
||||
CreatorId: webhook.CreatorID,
|
||||
Name: webhook.Name,
|
||||
Url: webhook.URL,
|
||||
}
|
||||
}
|
49
server/router/api/v1/workspace_service.go
Normal file
49
server/router/api/v1/workspace_service.go
Normal file
@@ -0,0 +1,49 @@
|
||||
package v1
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
v1pb "github.com/usememos/memos/proto/gen/api/v1"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
var ownerCache *v1pb.User
|
||||
|
||||
func (s *APIV1Service) GetWorkspaceProfile(ctx context.Context, _ *v1pb.GetWorkspaceProfileRequest) (*v1pb.WorkspaceProfile, error) {
|
||||
workspaceProfile := &v1pb.WorkspaceProfile{
|
||||
Version: s.Profile.Version,
|
||||
Mode: s.Profile.Mode,
|
||||
}
|
||||
owner, err := s.GetInstanceOwner(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get instance owner: %v", err)
|
||||
}
|
||||
if owner != nil {
|
||||
workspaceProfile.Owner = owner.Name
|
||||
}
|
||||
return workspaceProfile, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) GetInstanceOwner(ctx context.Context) (*v1pb.User, error) {
|
||||
if ownerCache != nil {
|
||||
return ownerCache, nil
|
||||
}
|
||||
|
||||
hostUserType := store.RoleHost
|
||||
user, err := s.Store.GetUser(ctx, &store.FindUser{
|
||||
Role: &hostUserType,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "failed to find owner")
|
||||
}
|
||||
if user == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
ownerCache = convertUserFromStore(user)
|
||||
return ownerCache, nil
|
||||
}
|
225
server/router/api/v1/workspace_setting_service.go
Normal file
225
server/router/api/v1/workspace_setting_service.go
Normal file
@@ -0,0 +1,225 @@
|
||||
package v1
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
v1pb "github.com/usememos/memos/proto/gen/api/v1"
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (s *APIV1Service) ListWorkspaceSettings(ctx context.Context, _ *v1pb.ListWorkspaceSettingsRequest) (*v1pb.ListWorkspaceSettingsResponse, error) {
|
||||
workspaceSettings, err := s.Store.ListWorkspaceSettings(ctx, &store.FindWorkspaceSetting{})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get workspace setting: %v", err)
|
||||
}
|
||||
|
||||
response := &v1pb.ListWorkspaceSettingsResponse{
|
||||
Settings: []*v1pb.WorkspaceSetting{},
|
||||
}
|
||||
for _, workspaceSetting := range workspaceSettings {
|
||||
if workspaceSetting.Key == storepb.WorkspaceSettingKey_WORKSPACE_SETTING_BASIC {
|
||||
continue
|
||||
}
|
||||
response.Settings = append(response.Settings, convertWorkspaceSettingFromStore(workspaceSetting))
|
||||
}
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) GetWorkspaceSetting(ctx context.Context, request *v1pb.GetWorkspaceSettingRequest) (*v1pb.WorkspaceSetting, error) {
|
||||
settingKeyString, err := ExtractWorkspaceSettingKeyFromName(request.Name)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid workspace setting name: %v", err)
|
||||
}
|
||||
settingKey := storepb.WorkspaceSettingKey(storepb.WorkspaceSettingKey_value[settingKeyString])
|
||||
workspaceSetting, err := s.Store.GetWorkspaceSetting(ctx, &store.FindWorkspaceSetting{
|
||||
Name: settingKey.String(),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get workspace setting: %v", err)
|
||||
}
|
||||
if workspaceSetting == nil {
|
||||
return nil, status.Errorf(codes.NotFound, "workspace setting not found")
|
||||
}
|
||||
|
||||
return convertWorkspaceSettingFromStore(workspaceSetting), nil
|
||||
}
|
||||
|
||||
func (s *APIV1Service) SetWorkspaceSetting(ctx context.Context, request *v1pb.SetWorkspaceSettingRequest) (*v1pb.WorkspaceSetting, error) {
|
||||
if s.Profile.Mode == "demo" {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "setting workspace setting is not allowed in demo mode")
|
||||
}
|
||||
|
||||
user, err := getCurrentUser(ctx, s.Store)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
|
||||
}
|
||||
if user.Role != store.RoleHost {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
|
||||
}
|
||||
|
||||
workspaceSetting, err := s.Store.UpsertWorkspaceSetting(ctx, convertWorkspaceSettingToStore(request.Setting))
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to upsert workspace setting: %v", err)
|
||||
}
|
||||
|
||||
return convertWorkspaceSettingFromStore(workspaceSetting), nil
|
||||
}
|
||||
|
||||
func convertWorkspaceSettingFromStore(setting *storepb.WorkspaceSetting) *v1pb.WorkspaceSetting {
|
||||
workspaceSetting := &v1pb.WorkspaceSetting{
|
||||
Name: fmt.Sprintf("%s%s", WorkspaceSettingNamePrefix, setting.Key.String()),
|
||||
}
|
||||
switch setting.Value.(type) {
|
||||
case *storepb.WorkspaceSetting_GeneralSetting:
|
||||
workspaceSetting.Value = &v1pb.WorkspaceSetting_GeneralSetting{
|
||||
GeneralSetting: convertWorkspaceGeneralSettingFromStore(setting.GetGeneralSetting()),
|
||||
}
|
||||
case *storepb.WorkspaceSetting_StorageSetting:
|
||||
workspaceSetting.Value = &v1pb.WorkspaceSetting_StorageSetting{
|
||||
StorageSetting: convertWorkspaceStorageSettingFromStore(setting.GetStorageSetting()),
|
||||
}
|
||||
case *storepb.WorkspaceSetting_MemoRelatedSetting:
|
||||
workspaceSetting.Value = &v1pb.WorkspaceSetting_MemoRelatedSetting{
|
||||
MemoRelatedSetting: convertWorkspaceMemoRelatedSettingFromStore(setting.GetMemoRelatedSetting()),
|
||||
}
|
||||
}
|
||||
return workspaceSetting
|
||||
}
|
||||
|
||||
func convertWorkspaceSettingToStore(setting *v1pb.WorkspaceSetting) *storepb.WorkspaceSetting {
|
||||
settingKeyString, _ := ExtractWorkspaceSettingKeyFromName(setting.Name)
|
||||
workspaceSetting := &storepb.WorkspaceSetting{
|
||||
Key: storepb.WorkspaceSettingKey(storepb.WorkspaceSettingKey_value[settingKeyString]),
|
||||
Value: &storepb.WorkspaceSetting_GeneralSetting{
|
||||
GeneralSetting: convertWorkspaceGeneralSettingToStore(setting.GetGeneralSetting()),
|
||||
},
|
||||
}
|
||||
switch workspaceSetting.Key {
|
||||
case storepb.WorkspaceSettingKey_WORKSPACE_SETTING_GENERAL:
|
||||
workspaceSetting.Value = &storepb.WorkspaceSetting_GeneralSetting{
|
||||
GeneralSetting: convertWorkspaceGeneralSettingToStore(setting.GetGeneralSetting()),
|
||||
}
|
||||
case storepb.WorkspaceSettingKey_WORKSPACE_SETTING_STORAGE:
|
||||
workspaceSetting.Value = &storepb.WorkspaceSetting_StorageSetting{
|
||||
StorageSetting: convertWorkspaceStorageSettingToStore(setting.GetStorageSetting()),
|
||||
}
|
||||
case storepb.WorkspaceSettingKey_WORKSPACE_SETTING_MEMO_RELATED:
|
||||
workspaceSetting.Value = &storepb.WorkspaceSetting_MemoRelatedSetting{
|
||||
MemoRelatedSetting: convertWorkspaceMemoRelatedSettingToStore(setting.GetMemoRelatedSetting()),
|
||||
}
|
||||
}
|
||||
return workspaceSetting
|
||||
}
|
||||
|
||||
func convertWorkspaceGeneralSettingFromStore(setting *storepb.WorkspaceGeneralSetting) *v1pb.WorkspaceGeneralSetting {
|
||||
if setting == nil {
|
||||
return nil
|
||||
}
|
||||
generalSetting := &v1pb.WorkspaceGeneralSetting{
|
||||
InstanceUrl: setting.InstanceUrl,
|
||||
DisallowSignup: setting.DisallowSignup,
|
||||
DisallowPasswordLogin: setting.DisallowPasswordLogin,
|
||||
AdditionalScript: setting.AdditionalScript,
|
||||
AdditionalStyle: setting.AdditionalStyle,
|
||||
}
|
||||
if setting.CustomProfile != nil {
|
||||
generalSetting.CustomProfile = &v1pb.WorkspaceCustomProfile{
|
||||
Title: setting.CustomProfile.Title,
|
||||
Description: setting.CustomProfile.Description,
|
||||
LogoUrl: setting.CustomProfile.LogoUrl,
|
||||
Locale: setting.CustomProfile.Locale,
|
||||
Appearance: setting.CustomProfile.Appearance,
|
||||
}
|
||||
}
|
||||
return generalSetting
|
||||
}
|
||||
|
||||
func convertWorkspaceGeneralSettingToStore(setting *v1pb.WorkspaceGeneralSetting) *storepb.WorkspaceGeneralSetting {
|
||||
if setting == nil {
|
||||
return nil
|
||||
}
|
||||
generalSetting := &storepb.WorkspaceGeneralSetting{
|
||||
InstanceUrl: setting.InstanceUrl,
|
||||
DisallowSignup: setting.DisallowSignup,
|
||||
DisallowPasswordLogin: setting.DisallowPasswordLogin,
|
||||
AdditionalScript: setting.AdditionalScript,
|
||||
AdditionalStyle: setting.AdditionalStyle,
|
||||
}
|
||||
if setting.CustomProfile != nil {
|
||||
generalSetting.CustomProfile = &storepb.WorkspaceCustomProfile{
|
||||
Title: setting.CustomProfile.Title,
|
||||
Description: setting.CustomProfile.Description,
|
||||
LogoUrl: setting.CustomProfile.LogoUrl,
|
||||
Locale: setting.CustomProfile.Locale,
|
||||
Appearance: setting.CustomProfile.Appearance,
|
||||
}
|
||||
}
|
||||
return generalSetting
|
||||
}
|
||||
|
||||
func convertWorkspaceStorageSettingFromStore(settingpb *storepb.WorkspaceStorageSetting) *v1pb.WorkspaceStorageSetting {
|
||||
if settingpb == nil {
|
||||
return nil
|
||||
}
|
||||
setting := &v1pb.WorkspaceStorageSetting{
|
||||
StorageType: v1pb.WorkspaceStorageSetting_StorageType(settingpb.StorageType),
|
||||
FilepathTemplate: settingpb.FilepathTemplate,
|
||||
UploadSizeLimitMb: settingpb.UploadSizeLimitMb,
|
||||
}
|
||||
if settingpb.S3Config != nil {
|
||||
setting.S3Config = &v1pb.WorkspaceStorageSetting_S3Config{
|
||||
AccessKeyId: settingpb.S3Config.AccessKeyId,
|
||||
AccessKeySecret: settingpb.S3Config.AccessKeySecret,
|
||||
Endpoint: settingpb.S3Config.Endpoint,
|
||||
Region: settingpb.S3Config.Region,
|
||||
Bucket: settingpb.S3Config.Bucket,
|
||||
}
|
||||
}
|
||||
return setting
|
||||
}
|
||||
|
||||
func convertWorkspaceStorageSettingToStore(setting *v1pb.WorkspaceStorageSetting) *storepb.WorkspaceStorageSetting {
|
||||
if setting == nil {
|
||||
return nil
|
||||
}
|
||||
settingpb := &storepb.WorkspaceStorageSetting{
|
||||
StorageType: storepb.WorkspaceStorageSetting_StorageType(setting.StorageType),
|
||||
FilepathTemplate: setting.FilepathTemplate,
|
||||
UploadSizeLimitMb: setting.UploadSizeLimitMb,
|
||||
}
|
||||
if setting.S3Config != nil {
|
||||
settingpb.S3Config = &storepb.WorkspaceStorageSetting_S3Config{
|
||||
AccessKeyId: setting.S3Config.AccessKeyId,
|
||||
AccessKeySecret: setting.S3Config.AccessKeySecret,
|
||||
Endpoint: setting.S3Config.Endpoint,
|
||||
Region: setting.S3Config.Region,
|
||||
Bucket: setting.S3Config.Bucket,
|
||||
}
|
||||
}
|
||||
return settingpb
|
||||
}
|
||||
|
||||
func convertWorkspaceMemoRelatedSettingFromStore(setting *storepb.WorkspaceMemoRelatedSetting) *v1pb.WorkspaceMemoRelatedSetting {
|
||||
if setting == nil {
|
||||
return nil
|
||||
}
|
||||
return &v1pb.WorkspaceMemoRelatedSetting{
|
||||
DisallowPublicVisible: setting.DisallowPublicVisible,
|
||||
DisplayWithUpdateTime: setting.DisplayWithUpdateTime,
|
||||
}
|
||||
}
|
||||
|
||||
func convertWorkspaceMemoRelatedSettingToStore(setting *v1pb.WorkspaceMemoRelatedSetting) *storepb.WorkspaceMemoRelatedSetting {
|
||||
if setting == nil {
|
||||
return nil
|
||||
}
|
||||
return &storepb.WorkspaceMemoRelatedSetting{
|
||||
DisallowPublicVisible: setting.DisallowPublicVisible,
|
||||
DisplayWithUpdateTime: setting.DisplayWithUpdateTime,
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user