From 89ee9d50047bd5b3ab1bd3c140a8c97d26050094 Mon Sep 17 00:00:00 2001 From: tobi <31960611+tsmethurst@users.noreply.github.com> Date: Fri, 21 Jul 2023 14:56:38 +0200 Subject: [PATCH] [bugfix] Return all accounts when list accounts limit <= 0 (#2014) --- docs/api/swagger.yaml | 6 +- internal/api/client/lists/listaccounts.go | 30 ++- .../api/client/lists/listaccounts_test.go | 246 ++++++++++++++++++ internal/api/client/lists/lists_test.go | 108 ++++++++ internal/api/util/parsequery.go | 15 +- internal/processing/list/get.go | 69 ++++- 6 files changed, 451 insertions(+), 23 deletions(-) create mode 100644 internal/api/client/lists/listaccounts_test.go create mode 100644 internal/api/client/lists/lists_test.go diff --git a/docs/api/swagger.yaml b/docs/api/swagger.yaml index a0ee30cab..eb9ec82ee 100644 --- a/docs/api/swagger.yaml +++ b/docs/api/swagger.yaml @@ -4873,9 +4873,11 @@ paths: in: query name: min_id type: string - - default: 20 - description: Number of accounts to return. + - default: 40 + description: 'Number of accounts to return. If set to 0 explicitly, all accounts in the list will be returned, and pagination headers will not be used. This is a workaround for Mastodon API peculiarities: https://docs.joinmastodon.org/methods/lists/#query-parameters.' in: query + maximum: 80 + minimum: 0 name: limit type: integer produces: diff --git a/internal/api/client/lists/listaccounts.go b/internal/api/client/lists/listaccounts.go index da902384f..6feffb1e8 100644 --- a/internal/api/client/lists/listaccounts.go +++ b/internal/api/client/lists/listaccounts.go @@ -79,8 +79,13 @@ import ( // - // name: limit // type: integer -// description: Number of accounts to return. -// default: 20 +// description: >- +// Number of accounts to return. +// If set to 0 explicitly, all accounts in the list will be returned, and pagination headers will not be used. +// This is a workaround for Mastodon API peculiarities: https://docs.joinmastodon.org/methods/lists/#query-parameters. +// default: 40 +// minimum: 0 +// maximum: 80 // in: query // required: false // @@ -129,14 +134,31 @@ func (m *Module) ListAccountsGETHandler(c *gin.Context) { return } - limit, errWithCode := apiutil.ParseLimit(c.Query(apiutil.LimitKey), 20, 40, 1) + limit, errWithCode := apiutil.ParseLimit(c.Query(apiutil.LimitKey), 40, 80, 0) if errWithCode != nil { apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) return } + var ( + ctx = c.Request.Context() + ) + + if limit == 0 { + // Return all accounts in the list without pagination. + accounts, errWithCode := m.processor.List().GetAllListAccounts(ctx, authed.Account, targetListID) + if errWithCode != nil { + apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) + return + } + + c.JSON(http.StatusOK, accounts) + return + } + + // Return subset of accounts in the list with pagination. resp, errWithCode := m.processor.List().GetListAccounts( - c.Request.Context(), + ctx, authed.Account, targetListID, c.Query(MaxIDKey), diff --git a/internal/api/client/lists/listaccounts_test.go b/internal/api/client/lists/listaccounts_test.go new file mode 100644 index 000000000..64e9ef768 --- /dev/null +++ b/internal/api/client/lists/listaccounts_test.go @@ -0,0 +1,246 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package lists_test + +import ( + "encoding/json" + "fmt" + "io/ioutil" + "net/http" + "net/http/httptest" + "strconv" + "testing" + + "github.com/stretchr/testify/suite" + "github.com/superseriousbusiness/gotosocial/internal/api/client/lists" + apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" + apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util" + "github.com/superseriousbusiness/gotosocial/internal/config" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" + "github.com/superseriousbusiness/gotosocial/internal/oauth" + "github.com/superseriousbusiness/gotosocial/testrig" +) + +type ListAccountsTestSuite struct { + ListsStandardTestSuite +} + +func (suite *ListAccountsTestSuite) getListAccounts( + expectedHTTPStatus int, + expectedBody string, + listID string, + maxID string, + sinceID string, + minID string, + limit *int, +) ( + []*apimodel.Account, + string, // Link header + error, +) { + + var ( + recorder = httptest.NewRecorder() + ctx, _ = testrig.CreateGinTestContext(recorder, nil) + ) + + // Prepare test context. + ctx.Set(oauth.SessionAuthorizedAccount, suite.testAccounts["local_account_1"]) + ctx.Set(oauth.SessionAuthorizedToken, oauth.DBTokenToToken(suite.testTokens["local_account_1"])) + ctx.Set(oauth.SessionAuthorizedApplication, suite.testApplications["application_1"]) + ctx.Set(oauth.SessionAuthorizedUser, suite.testUsers["local_account_1"]) + + // Inject path parameters. + ctx.AddParam("id", listID) + + // Inject query parameters. + requestPath := config.GetProtocol() + "://" + config.GetHost() + "/api/" + lists.BasePath + "/" + listID + "/accounts" + + if limit != nil { + requestPath += "?limit=" + strconv.Itoa(*limit) + } else { + requestPath += "?limit=40" + } + if maxID != "" { + requestPath += "&" + apiutil.MaxIDKey + "=" + maxID + } + if sinceID != "" { + requestPath += "&" + apiutil.SinceIDKey + "=" + sinceID + } + if minID != "" { + requestPath += "&" + apiutil.MinIDKey + "=" + minID + } + + // Prepare test context request. + request := httptest.NewRequest(http.MethodGet, requestPath, nil) + request.Header.Set("accept", "application/json") + ctx.Request = request + + // trigger the handler + suite.listsModule.ListAccountsGETHandler(ctx) + + // read the response + result := recorder.Result() + defer result.Body.Close() + + b, err := ioutil.ReadAll(result.Body) + if err != nil { + return nil, "", err + } + + errs := gtserror.MultiError{} + + // check code + body + if resultCode := recorder.Code; expectedHTTPStatus != resultCode { + errs = append(errs, fmt.Sprintf("expected %d got %d", expectedHTTPStatus, resultCode)) + } + + // if we got an expected body, return early + if expectedBody != "" { + if string(b) != expectedBody { + errs = append(errs, fmt.Sprintf("expected %s got %s", expectedBody, string(b))) + } + return nil, "", errs.Combine() + } + + resp := []*apimodel.Account{} + if err := json.Unmarshal(b, &resp); err != nil { + return nil, "", err + } + + return resp, result.Header.Get("Link"), nil +} + +func (suite *ListAccountsTestSuite) TestGetListAccountsPaginatedDefaultLimit() { + var ( + expectedHTTPStatus = 200 + expectedBody = "" + listID = suite.testLists["local_account_1_list_1"].ID + maxID = "" + minID = "" + sinceID = "" + limit *int = nil + ) + + accounts, link, err := suite.getListAccounts( + expectedHTTPStatus, + expectedBody, + listID, + maxID, + sinceID, + minID, + limit, + ) + if err != nil { + suite.FailNow(err.Error()) + } + + suite.Len(accounts, 2) + suite.Equal( + `; rel="next", `+ + `; rel="prev"`, + link, + ) +} + +func (suite *ListAccountsTestSuite) TestGetListAccountsPaginatedNextPage() { + var ( + expectedHTTPStatus = 200 + expectedBody = "" + listID = suite.testLists["local_account_1_list_1"].ID + maxID = "" + minID = "" + sinceID = "" + limit *int = func() *int { l := 1; return &l }() // Set to 1. + ) + + // First response, ask for 1 account. + accounts, link, err := suite.getListAccounts( + expectedHTTPStatus, + expectedBody, + listID, + maxID, + sinceID, + minID, + limit, + ) + if err != nil { + suite.FailNow(err.Error()) + } + + suite.Len(accounts, 1) + suite.Equal( + `; rel="next", `+ + `; rel="prev"`, + link, + ) + + // Next response, ask for next 1 account. + maxID = "01H0G8FFM1AGQDRNGBGGX8CYJQ" + accounts, link, err = suite.getListAccounts( + expectedHTTPStatus, + expectedBody, + listID, + maxID, + sinceID, + minID, + limit, + ) + if err != nil { + suite.FailNow(err.Error()) + } + + suite.Len(accounts, 1) + suite.Equal( + `; rel="next", `+ + `; rel="prev"`, + link, + ) +} + +func (suite *ListAccountsTestSuite) TestGetListAccountsUnpaginated() { + var ( + expectedHTTPStatus = 200 + expectedBody = "" + listID = suite.testLists["local_account_1_list_1"].ID + maxID = "" + minID = "" + sinceID = "" + limit *int = func() *int { l := 0; return &l }() // Set to 0 explicitly. + ) + + accounts, link, err := suite.getListAccounts( + expectedHTTPStatus, + expectedBody, + listID, + maxID, + sinceID, + minID, + limit, + ) + if err != nil { + suite.FailNow(err.Error()) + } + + suite.Len(accounts, 2) + suite.Empty(link) +} + +func TestListAccountsTestSuite(t *testing.T) { + suite.Run(t, new(ListAccountsTestSuite)) +} diff --git a/internal/api/client/lists/lists_test.go b/internal/api/client/lists/lists_test.go new file mode 100644 index 000000000..ebaf6998e --- /dev/null +++ b/internal/api/client/lists/lists_test.go @@ -0,0 +1,108 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package lists_test + +import ( + "github.com/stretchr/testify/suite" + "github.com/superseriousbusiness/gotosocial/internal/api/client/lists" + "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/email" + "github.com/superseriousbusiness/gotosocial/internal/federation" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/media" + "github.com/superseriousbusiness/gotosocial/internal/processing" + "github.com/superseriousbusiness/gotosocial/internal/state" + "github.com/superseriousbusiness/gotosocial/internal/storage" + "github.com/superseriousbusiness/gotosocial/internal/visibility" + "github.com/superseriousbusiness/gotosocial/testrig" +) + +type ListsStandardTestSuite struct { + // standard suite interfaces + suite.Suite + db db.DB + storage *storage.Driver + mediaManager *media.Manager + federator federation.Federator + processor *processing.Processor + emailSender email.Sender + state state.State + + // standard suite models + testTokens map[string]*gtsmodel.Token + testClients map[string]*gtsmodel.Client + testApplications map[string]*gtsmodel.Application + testUsers map[string]*gtsmodel.User + testAccounts map[string]*gtsmodel.Account + testAttachments map[string]*gtsmodel.MediaAttachment + testStatuses map[string]*gtsmodel.Status + testEmojis map[string]*gtsmodel.Emoji + testEmojiCategories map[string]*gtsmodel.EmojiCategory + testLists map[string]*gtsmodel.List + + // module being tested + listsModule *lists.Module +} + +func (suite *ListsStandardTestSuite) SetupSuite() { + suite.testTokens = testrig.NewTestTokens() + suite.testClients = testrig.NewTestClients() + suite.testApplications = testrig.NewTestApplications() + suite.testUsers = testrig.NewTestUsers() + suite.testAccounts = testrig.NewTestAccounts() + suite.testAttachments = testrig.NewTestAttachments() + suite.testStatuses = testrig.NewTestStatuses() + suite.testEmojis = testrig.NewTestEmojis() + suite.testEmojiCategories = testrig.NewTestEmojiCategories() + suite.testLists = testrig.NewTestLists() +} + +func (suite *ListsStandardTestSuite) SetupTest() { + suite.state.Caches.Init() + suite.state.Caches.Start() + testrig.StartWorkers(&suite.state) + + testrig.InitTestConfig() + testrig.InitTestLog() + + suite.db = testrig.NewTestDB(&suite.state) + suite.state.DB = suite.db + suite.storage = testrig.NewInMemoryStorage() + suite.state.Storage = suite.storage + + testrig.StartTimelines( + &suite.state, + visibility.NewFilter(&suite.state), + testrig.NewTestTypeConverter(suite.db), + ) + + suite.mediaManager = testrig.NewTestMediaManager(&suite.state) + suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")), suite.mediaManager) + suite.emailSender = testrig.NewEmailSender("../../../../web/template/", nil) + suite.processor = testrig.NewTestProcessor(&suite.state, suite.federator, suite.emailSender, suite.mediaManager) + suite.listsModule = lists.New(suite.processor) + + testrig.StandardDBSetup(suite.db, nil) + testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media") +} + +func (suite *ListsStandardTestSuite) TearDownTest() { + testrig.StandardDBTeardown(suite.db) + testrig.StandardStorageTeardown(suite.storage) + testrig.StopWorkers(&suite.state) +} diff --git a/internal/api/util/parsequery.go b/internal/api/util/parsequery.go index 92105ef82..662870910 100644 --- a/internal/api/util/parsequery.go +++ b/internal/api/util/parsequery.go @@ -27,11 +27,12 @@ import ( const ( /* Common keys */ - IDKey = "id" - LimitKey = "limit" - LocalKey = "local" - MaxIDKey = "max_id" - MinIDKey = "min_id" + IDKey = "id" + LimitKey = "limit" + LocalKey = "local" + MaxIDKey = "max_id" + SinceIDKey = "since_id" + MinIDKey = "min_id" /* Search keys */ @@ -76,10 +77,8 @@ func ParseLimit(value string, defaultValue int, max, min int) (int, gtserror.Wit i, err := parseInt(value, defaultValue, max, min, LimitKey) if err != nil { return 0, err - } else if i == 0 { - // treat 0 as an empty query - return defaultValue, nil } + return i, nil } diff --git a/internal/processing/list/get.go b/internal/processing/list/get.go index 0fc14f934..1a03898ed 100644 --- a/internal/processing/list/get.go +++ b/internal/processing/list/get.go @@ -75,6 +75,46 @@ func (p *Processor) GetAll(ctx context.Context, account *gtsmodel.Account) ([]*a return apiLists, nil } +// GetAllListAccounts returns all accounts that are in the given list, +// owned by the given account. There's no pagination for this endpoint. +// +// See https://docs.joinmastodon.org/methods/lists/#query-parameters: +// +// Limit: Integer. Maximum number of results. Defaults to 40 accounts. +// Max 80 accounts. Set to 0 in order to get all accounts without pagination. +func (p *Processor) GetAllListAccounts( + ctx context.Context, + account *gtsmodel.Account, + listID string, +) ([]*apimodel.Account, gtserror.WithCode) { + // Ensure list exists + is owned by requesting account. + _, errWithCode := p.getList( + // Use barebones ctx; no embedded + // structs necessary for this call. + gtscontext.SetBarebones(ctx), + account.ID, + listID, + ) + if errWithCode != nil { + return nil, errWithCode + } + + // Get all entries for this list. + listEntries, err := p.state.DB.GetListEntries(ctx, listID, "", "", "", 0) + if err != nil && !errors.Is(err, db.ErrNoEntries) { + err = gtserror.Newf("error getting list entries: %w", err) + return nil, gtserror.NewErrorInternalError(err) + } + + // Extract accounts from list entries + add them to response. + accounts := make([]*apimodel.Account, 0, len(listEntries)) + p.accountsFromListEntries(ctx, listEntries, func(acc *apimodel.Account) { + accounts = append(accounts, acc) + }) + + return accounts, nil +} + // GetListAccounts returns accounts that are in the given list, owned by the given account. // The additional parameters can be used for paging. func (p *Processor) GetListAccounts( @@ -121,6 +161,25 @@ func (p *Processor) GetListAccounts( prevMinIDValue = listEntries[0].ID ) + // Extract accounts from list entries + add them to response. + p.accountsFromListEntries(ctx, listEntries, func(acc *apimodel.Account) { + items = append(items, acc) + }) + + return util.PackagePageableResponse(util.PageableResponseParams{ + Items: items, + Path: "/api/v1/lists/" + listID + "/accounts", + NextMaxIDValue: nextMaxIDValue, + PrevMinIDValue: prevMinIDValue, + Limit: limit, + }) +} + +func (p *Processor) accountsFromListEntries( + ctx context.Context, + listEntries []*gtsmodel.ListEntry, + appendAcc func(*apimodel.Account), +) { // For each list entry, we want the account it points to. // To get this, we need to first get the follow that the // list entry pertains to, then extract the target account @@ -144,14 +203,6 @@ func (p *Processor) GetListAccounts( continue } - items = append(items, apiAccount) + appendAcc(apiAccount) } - - return util.PackagePageableResponse(util.PageableResponseParams{ - Items: items, - Path: "/api/v1/lists/" + listID + "/accounts", - NextMaxIDValue: nextMaxIDValue, - PrevMinIDValue: prevMinIDValue, - Limit: limit, - }) }