refactor: update memo tags

This commit is contained in:
Steven
2024-05-08 20:03:01 +08:00
parent 2c270438ec
commit d0655ece53
61 changed files with 1815 additions and 2855 deletions

View File

@ -6,13 +6,16 @@ import (
"context"
"fmt"
"log/slog"
"slices"
"time"
"github.com/google/cel-go/cel"
"github.com/lithammer/shortuuid/v4"
"github.com/pkg/errors"
"github.com/yourselfhosted/gomark/ast"
"github.com/yourselfhosted/gomark/parser"
"github.com/yourselfhosted/gomark/parser/tokenizer"
"github.com/yourselfhosted/gomark/restore"
expr "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
@ -59,6 +62,11 @@ func (s *APIV1Service) CreateMemo(ctx context.Context, request *v1pb.CreateMemoR
if len(create.Content) > contentLengthLimit {
return nil, status.Errorf(codes.InvalidArgument, "content too long (max %d characters)", contentLengthLimit)
}
tags, err := ExtractTagsFromContent(create.Content)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "failed to extract tags")
}
create.Tags = tags
memo, err := s.Store.CreateMemo(ctx, create)
if err != nil {
@ -215,7 +223,10 @@ func (s *APIV1Service) UpdateMemo(ctx context.Context, request *v1pb.UpdateMemoR
return nil, status.Errorf(codes.NotFound, "memo not found")
}
user, _ := getCurrentUser(ctx, s.Store)
user, err := getCurrentUser(ctx, s.Store)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get current user")
}
if memo.CreatorID != user.ID {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
@ -227,7 +238,19 @@ func (s *APIV1Service) UpdateMemo(ctx context.Context, request *v1pb.UpdateMemoR
}
for _, path := range request.UpdateMask.Paths {
if path == "content" {
contentLengthLimit, err := s.getContentLengthLimit(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get content length limit")
}
if len(request.Memo.Content) > contentLengthLimit {
return nil, status.Errorf(codes.InvalidArgument, "content too long (max %d characters)", contentLengthLimit)
}
update.Content = &request.Memo.Content
tags, err := ExtractTagsFromContent(*update.Content)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "failed to extract tags")
}
update.Tags = &tags
} else if path == "uid" {
update.UID = &request.Memo.Name
if !util.UIDMatcher.MatchString(*update.UID) {
@ -259,13 +282,6 @@ func (s *APIV1Service) UpdateMemo(ctx context.Context, request *v1pb.UpdateMemoR
}
}
}
contentLengthLimit, err := s.getContentLengthLimit(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get content length limit")
}
if update.Content != nil && len(*update.Content) > contentLengthLimit {
return nil, status.Errorf(codes.InvalidArgument, "content too long (max %d characters)", contentLengthLimit)
}
if err = s.Store.UpdateMemo(ctx, update); err != nil {
return nil, status.Errorf(codes.Internal, "failed to update memo")
@ -531,6 +547,158 @@ func (s *APIV1Service) ExportMemos(ctx context.Context, request *v1pb.ExportMemo
}, nil
}
func (s *APIV1Service) ListMemoTags(ctx context.Context, request *v1pb.ListMemoTagsRequest) (*v1pb.ListMemoTagsResponse, error) {
user, err := getCurrentUser(ctx, s.Store)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get current user")
}
normalRowStatus := store.Normal
memoFind := &store.FindMemo{
CreatorID: &user.ID,
RowStatus: &normalRowStatus,
ExcludeComments: true,
// Default exclude content for performance.
ExcludeContent: true,
}
if (request.Parent) != "memos/-" {
memoID, err := ExtractMemoIDFromName(request.Parent)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err)
}
memoFind.ID = &memoID
}
if request.Rebuild {
// If rebuild is true, include content to extract tags.
memoFind.ExcludeContent = false
}
memos, err := s.Store.ListMemos(ctx, memoFind)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list memos")
}
if request.Rebuild {
for _, memo := range memos {
tags, err := ExtractTagsFromContent(memo.Content)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to extract tags")
}
memo.Tags = tags
if err := s.Store.UpdateMemo(ctx, &store.UpdateMemo{
ID: memo.ID,
Tags: &tags,
}); err != nil {
return nil, status.Errorf(codes.Internal, "failed to update memo")
}
}
}
tagAmounts := map[string]int32{}
for _, memo := range memos {
for _, tag := range memo.Tags {
tagAmounts[tag]++
}
}
return &v1pb.ListMemoTagsResponse{
TagAmounts: tagAmounts,
}, nil
}
func (s *APIV1Service) RenameMemoTag(ctx context.Context, request *v1pb.RenameMemoTagRequest) (*emptypb.Empty, error) {
user, err := getCurrentUser(ctx, s.Store)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get current user")
}
memoFind := &store.FindMemo{
CreatorID: &user.ID,
Tag: &request.OldTag,
ExcludeComments: true,
}
if (request.Parent) != "memos/-" {
memoID, err := ExtractMemoIDFromName(request.Parent)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err)
}
memoFind.ID = &memoID
}
memos, err := s.Store.ListMemos(ctx, memoFind)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list memos")
}
for _, memo := range memos {
nodes, err := parser.Parse(tokenizer.Tokenize(memo.Content))
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to parse memo: %v", err)
}
TraverseASTNodes(nodes, func(node ast.Node) {
if tag, ok := node.(*ast.Tag); ok && tag.Content == request.OldTag {
tag.Content = request.NewTag
}
})
content := restore.Restore(nodes)
tags := util.ReplaceString(memo.Tags, request.OldTag, request.NewTag)
if err := s.Store.UpdateMemo(ctx, &store.UpdateMemo{
ID: memo.ID,
Content: &content,
Tags: &tags,
}); err != nil {
return nil, status.Errorf(codes.Internal, "failed to update memo: %v", err)
}
}
return &emptypb.Empty{}, nil
}
func (s *APIV1Service) DeleteMemoTag(ctx context.Context, request *v1pb.DeleteMemoTagRequest) (*emptypb.Empty, error) {
user, err := getCurrentUser(ctx, s.Store)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get current user")
}
memoFind := &store.FindMemo{
CreatorID: &user.ID,
Tag: &request.Tag,
ExcludeContent: true,
ExcludeComments: true,
}
if (request.Parent) != "memos/-" {
memoID, err := ExtractMemoIDFromName(request.Parent)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err)
}
memoFind.ID = &memoID
}
memos, err := s.Store.ListMemos(ctx, memoFind)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list memos")
}
for _, memo := range memos {
if request.DeleteRelatedMemos {
err := s.Store.DeleteMemo(ctx, &store.DeleteMemo{ID: memo.ID})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to delete memo")
}
} else {
archived := store.Archived
err := s.Store.UpdateMemo(ctx, &store.UpdateMemo{
ID: memo.ID,
RowStatus: &archived,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to update memo")
}
}
}
return &emptypb.Empty{}, nil
}
func (s *APIV1Service) convertMemoFromStore(ctx context.Context, memo *store.Memo) (*v1pb.Memo, error) {
displayTs := memo.CreatedTs
workspaceMemoRelatedSetting, err := s.Store.GetWorkspaceMemoRelatedSetting(ctx)
@ -578,6 +746,7 @@ func (s *APIV1Service) convertMemoFromStore(ctx context.Context, memo *store.Mem
Content: memo.Content,
Nodes: convertFromASTNodes(nodes),
Visibility: convertVisibilityFromStore(memo.Visibility),
Tags: memo.Tags,
Pinned: memo.Pinned,
ParentId: memo.ParentID,
Relations: listMemoRelationsResponse.Relations,
@ -628,6 +797,9 @@ func (s *APIV1Service) buildMemoFindWithFilter(ctx context.Context, find *store.
if len(filter.Visibilities) > 0 {
find.VisibilityList = filter.Visibilities
}
if filter.Tag != nil {
find.Tag = filter.Tag
}
if filter.OrderByPinned {
find.OrderByPinned = filter.OrderByPinned
}
@ -720,6 +892,7 @@ func (s *APIV1Service) getContentLengthLimit(ctx context.Context) (int, error) {
var SearchMemosFilterCELAttributes = []cel.EnvOption{
cel.Variable("content_search", cel.ListType(cel.StringType)),
cel.Variable("visibilities", cel.ListType(cel.StringType)),
cel.Variable("tag", cel.StringType),
cel.Variable("order_by_pinned", cel.BoolType),
cel.Variable("display_time_before", cel.IntType),
cel.Variable("display_time_after", cel.IntType),
@ -734,6 +907,7 @@ var SearchMemosFilterCELAttributes = []cel.EnvOption{
type SearchMemosFilter struct {
ContentSearch []string
Visibilities []store.Visibility
Tag *string
OrderByPinned bool
DisplayTimeBefore *int64
DisplayTimeAfter *int64
@ -782,6 +956,9 @@ func findSearchMemosField(callExpr *expr.Expr_Call, filter *SearchMemosFilter) {
visibilities = append(visibilities, store.Visibility(value))
}
filter.Visibilities = visibilities
} else if idExpr.Name == "tag" {
tag := callExpr.Args[1].GetConstExpr().GetStringValue()
filter.Tag = &tag
} else if idExpr.Name == "order_by_pinned" {
value := callExpr.Args[1].GetConstExpr().GetBoolValue()
filter.OrderByPinned = value
@ -821,6 +998,45 @@ func findSearchMemosField(callExpr *expr.Expr_Call, filter *SearchMemosFilter) {
}
}
func ExtractTagsFromContent(content string) ([]string, error) {
nodes, err := parser.Parse(tokenizer.Tokenize(content))
if err != nil {
return nil, errors.Wrap(err, "failed to parse content")
}
tags := []string{}
TraverseASTNodes(nodes, func(node ast.Node) {
if tagNode, ok := node.(*ast.Tag); ok {
tag := tagNode.Content
if !slices.Contains(tags, tag) {
tags = append(tags, tag)
}
}
})
return tags, nil
}
func TraverseASTNodes(nodes []ast.Node, fn func(ast.Node)) {
for _, node := range nodes {
fn(node)
switch n := node.(type) {
case *ast.Paragraph:
TraverseASTNodes(n.Children, fn)
case *ast.Heading:
TraverseASTNodes(n.Children, fn)
case *ast.Blockquote:
TraverseASTNodes(n.Children, fn)
case *ast.OrderedList:
TraverseASTNodes(n.Children, fn)
case *ast.UnorderedList:
TraverseASTNodes(n.Children, fn)
case *ast.TaskList:
TraverseASTNodes(n.Children, fn)
case *ast.Bold:
TraverseASTNodes(n.Children, fn)
}
}
}
// DispatchMemoCreatedWebhook dispatches webhook when memo is created.
func (s *APIV1Service) DispatchMemoCreatedWebhook(ctx context.Context, memo *v1pb.Memo) error {
return s.dispatchMemoRelatedWebhook(ctx, memo, "memos.memo.created")