diff --git a/api/v2/memo_service.go b/api/v2/memo_service.go index 3bedcc72..e45adc93 100644 --- a/api/v2/memo_service.go +++ b/api/v2/memo_service.go @@ -82,9 +82,12 @@ func (s *APIV2Service) CreateMemo(ctx context.Context, request *apiv2pb.CreateMe } func (s *APIV2Service) ListMemos(ctx context.Context, request *apiv2pb.ListMemosRequest) (*apiv2pb.ListMemosResponse, error) { - memoFind, err := s.buildFindMemosWithFilter(ctx, request.Filter, true) - if err != nil { - return nil, err + memoFind := &store.FindMemo{ + // Exclude comments by default. + ExcludeComments: true, + } + if err := s.buildMemoFindWithFilter(ctx, memoFind, request.Filter); err != nil { + return nil, status.Errorf(codes.InvalidArgument, "failed to build find memos with filter") } var limit, offset int @@ -435,52 +438,8 @@ func (s *APIV2Service) GetUserMemosStats(ctx context.Context, request *apiv2pb.G ExcludeComments: true, ExcludeContent: true, } - displayWithUpdatedTs, err := s.getMemoDisplayWithUpdatedTsSettingValue(ctx) - if err != nil { - return nil, status.Errorf(codes.Internal, "failed to get memo display with updated ts setting value") - } - if displayWithUpdatedTs { - memoFind.OrderByUpdatedTs = true - } - if request.Filter != "" { - filter, err := parseListMemosFilter(request.Filter) - if err != nil { - return nil, status.Errorf(codes.InvalidArgument, "invalid filter: %v", err) - } - if len(filter.ContentSearch) > 0 { - memoFind.ContentSearch = filter.ContentSearch - } - if len(filter.Visibilities) > 0 { - memoFind.VisibilityList = filter.Visibilities - } - if filter.OrderByPinned { - memoFind.OrderByPinned = filter.OrderByPinned - } - if filter.DisplayTimeAfter != nil { - displayWithUpdatedTs, err := s.getMemoDisplayWithUpdatedTsSettingValue(ctx) - if err != nil { - return nil, status.Errorf(codes.Internal, "failed to get memo display with updated ts setting value") - } - if displayWithUpdatedTs { - memoFind.UpdatedTsAfter = filter.DisplayTimeAfter - } else { - memoFind.CreatedTsAfter = filter.DisplayTimeAfter - } - } - if filter.DisplayTimeBefore != nil { - displayWithUpdatedTs, err := s.getMemoDisplayWithUpdatedTsSettingValue(ctx) - if err != nil { - return nil, status.Errorf(codes.Internal, "failed to get memo display with updated ts setting value") - } - if displayWithUpdatedTs { - memoFind.UpdatedTsBefore = filter.DisplayTimeBefore - } else { - memoFind.CreatedTsBefore = filter.DisplayTimeBefore - } - } - if filter.RowStatus != nil { - memoFind.RowStatus = filter.RowStatus - } + if err := s.buildMemoFindWithFilter(ctx, memoFind, request.Filter); err != nil { + return nil, status.Errorf(codes.InvalidArgument, "failed to build find memos with filter") } memos, err := s.Store.ListMemos(ctx, memoFind) @@ -493,6 +452,10 @@ func (s *APIV2Service) GetUserMemosStats(ctx context.Context, request *apiv2pb.G return nil, status.Errorf(codes.Internal, "invalid timezone location") } + displayWithUpdatedTs, err := s.getMemoDisplayWithUpdatedTsSettingValue(ctx) + if err != nil { + return nil, status.Errorf(codes.Internal, "failed to get memo display with updated ts setting value") + } stats := make(map[string]int32) for _, memo := range memos { displayTs := memo.CreatedTs @@ -509,8 +472,13 @@ func (s *APIV2Service) GetUserMemosStats(ctx context.Context, request *apiv2pb.G } func (s *APIV2Service) ExportMemos(ctx context.Context, request *apiv2pb.ExportMemosRequest) (*apiv2pb.ExportMemosResponse, error) { - memoFind, err := s.buildFindMemosWithFilter(ctx, request.Filter, true) - if err != nil { + normalRowStatus := store.Normal + memoFind := &store.FindMemo{ + RowStatus: &normalRowStatus, + // Exclude comments by default. + ExcludeComments: true, + } + if err := s.buildMemoFindWithFilter(ctx, memoFind, request.Filter); err != nil { return nil, status.Errorf(codes.Internal, "failed to build find memos with filter") } @@ -521,7 +489,6 @@ func (s *APIV2Service) ExportMemos(ctx context.Context, request *apiv2pb.ExportM buf := new(bytes.Buffer) writer := zip.NewWriter(buf) - for _, memo := range memos { memoMessage, err := s.convertMemoFromStore(ctx, memo) if err != nil { @@ -536,9 +503,7 @@ func (s *APIV2Service) ExportMemos(ctx context.Context, request *apiv2pb.ExportM return nil, status.Errorf(codes.Internal, "Failed to write to memo file") } } - - err = writer.Close() - if err != nil { + if err := writer.Close(); err != nil { return nil, status.Errorf(codes.Internal, "Failed to close zip file writer") } @@ -768,88 +733,90 @@ func (s *APIV2Service) dispatchMemoRelatedWebhook(ctx context.Context, memo *api return nil } -func (s *APIV2Service) buildFindMemosWithFilter(ctx context.Context, filter string, excludeComments bool) (*store.FindMemo, error) { - memoFind := &store.FindMemo{ - // Exclude comments by default. - ExcludeComments: excludeComments, +func (s *APIV2Service) buildMemoFindWithFilter(ctx context.Context, find *store.FindMemo, filter string) error { + user, _ := getCurrentUser(ctx, s.Store) + if find == nil { + find = &store.FindMemo{} } if filter != "" { filter, err := parseListMemosFilter(filter) if err != nil { - return nil, status.Errorf(codes.InvalidArgument, "invalid filter: %v", err) + return status.Errorf(codes.InvalidArgument, "invalid filter: %v", err) } if len(filter.ContentSearch) > 0 { - memoFind.ContentSearch = filter.ContentSearch + find.ContentSearch = filter.ContentSearch } if len(filter.Visibilities) > 0 { - memoFind.VisibilityList = filter.Visibilities + find.VisibilityList = filter.Visibilities } if filter.OrderByPinned { - memoFind.OrderByPinned = filter.OrderByPinned + find.OrderByPinned = filter.OrderByPinned } if filter.DisplayTimeAfter != nil { displayWithUpdatedTs, err := s.getMemoDisplayWithUpdatedTsSettingValue(ctx) if err != nil { - return nil, status.Errorf(codes.Internal, "failed to get memo display with updated ts setting value") + return status.Errorf(codes.Internal, "failed to get memo display with updated ts setting value") } if displayWithUpdatedTs { - memoFind.UpdatedTsAfter = filter.DisplayTimeAfter + find.UpdatedTsAfter = filter.DisplayTimeAfter } else { - memoFind.CreatedTsAfter = filter.DisplayTimeAfter + find.CreatedTsAfter = filter.DisplayTimeAfter } } if filter.DisplayTimeBefore != nil { displayWithUpdatedTs, err := s.getMemoDisplayWithUpdatedTsSettingValue(ctx) if err != nil { - return nil, status.Errorf(codes.Internal, "failed to get memo display with updated ts setting value") + return status.Errorf(codes.Internal, "failed to get memo display with updated ts setting value") } if displayWithUpdatedTs { - memoFind.UpdatedTsBefore = filter.DisplayTimeBefore + find.UpdatedTsBefore = filter.DisplayTimeBefore } else { - memoFind.CreatedTsBefore = filter.DisplayTimeBefore + find.CreatedTsBefore = filter.DisplayTimeBefore } } if filter.Creator != nil { username, err := ExtractUsernameFromName(*filter.Creator) if err != nil { - return nil, status.Errorf(codes.InvalidArgument, "invalid creator name") + return status.Errorf(codes.InvalidArgument, "invalid creator name") } user, err := s.Store.GetUser(ctx, &store.FindUser{ Username: &username, }) if err != nil { - return nil, status.Errorf(codes.Internal, "failed to get user") + return status.Errorf(codes.Internal, "failed to get user") } if user == nil { - return nil, status.Errorf(codes.NotFound, "user not found") + return status.Errorf(codes.NotFound, "user not found") } - memoFind.CreatorID = &user.ID + find.CreatorID = &user.ID } if filter.RowStatus != nil { - memoFind.RowStatus = filter.RowStatus + find.RowStatus = filter.RowStatus } } else { - return nil, status.Errorf(codes.InvalidArgument, "filter is required") + // If no filter is provided, check if the user is authenticated. + if user == nil { + return status.Errorf(codes.InvalidArgument, "filter is required") + } } - user, _ := getCurrentUser(ctx, s.Store) // If the user is not authenticated, only public memos are visible. if user == nil { - memoFind.VisibilityList = []store.Visibility{store.Public} + find.VisibilityList = []store.Visibility{store.Public} } - if user != nil && memoFind.CreatorID != nil && *memoFind.CreatorID != user.ID { - memoFind.VisibilityList = []store.Visibility{store.Public, store.Protected} + if user != nil && find.CreatorID != nil && *find.CreatorID != user.ID { + find.VisibilityList = []store.Visibility{store.Public, store.Protected} } displayWithUpdatedTs, err := s.getMemoDisplayWithUpdatedTsSettingValue(ctx) if err != nil { - return nil, status.Errorf(codes.Internal, "failed to get memo display with updated ts setting value") + return status.Errorf(codes.Internal, "failed to get memo display with updated ts setting value") } if displayWithUpdatedTs { - memoFind.OrderByUpdatedTs = true + find.OrderByUpdatedTs = true } - return memoFind, nil + return nil } func convertMemoToWebhookPayload(memo *apiv2pb.Memo) *webhook.WebhookPayload {