From 89ee9d50047bd5b3ab1bd3c140a8c97d26050094 Mon Sep 17 00:00:00 2001
From: tobi <31960611+tsmethurst@users.noreply.github.com>
Date: Fri, 21 Jul 2023 14:56:38 +0200
Subject: [PATCH] [bugfix] Return all accounts when list accounts limit <= 0
(#2014)
---
docs/api/swagger.yaml | 6 +-
internal/api/client/lists/listaccounts.go | 30 ++-
.../api/client/lists/listaccounts_test.go | 246 ++++++++++++++++++
internal/api/client/lists/lists_test.go | 108 ++++++++
internal/api/util/parsequery.go | 15 +-
internal/processing/list/get.go | 69 ++++-
6 files changed, 451 insertions(+), 23 deletions(-)
create mode 100644 internal/api/client/lists/listaccounts_test.go
create mode 100644 internal/api/client/lists/lists_test.go
diff --git a/docs/api/swagger.yaml b/docs/api/swagger.yaml
index a0ee30cab..eb9ec82ee 100644
--- a/docs/api/swagger.yaml
+++ b/docs/api/swagger.yaml
@@ -4873,9 +4873,11 @@ paths:
in: query
name: min_id
type: string
- - default: 20
- description: Number of accounts to return.
+ - default: 40
+ description: 'Number of accounts to return. If set to 0 explicitly, all accounts in the list will be returned, and pagination headers will not be used. This is a workaround for Mastodon API peculiarities: https://docs.joinmastodon.org/methods/lists/#query-parameters.'
in: query
+ maximum: 80
+ minimum: 0
name: limit
type: integer
produces:
diff --git a/internal/api/client/lists/listaccounts.go b/internal/api/client/lists/listaccounts.go
index da902384f..6feffb1e8 100644
--- a/internal/api/client/lists/listaccounts.go
+++ b/internal/api/client/lists/listaccounts.go
@@ -79,8 +79,13 @@ import (
// -
// name: limit
// type: integer
-// description: Number of accounts to return.
-// default: 20
+// description: >-
+// Number of accounts to return.
+// If set to 0 explicitly, all accounts in the list will be returned, and pagination headers will not be used.
+// This is a workaround for Mastodon API peculiarities: https://docs.joinmastodon.org/methods/lists/#query-parameters.
+// default: 40
+// minimum: 0
+// maximum: 80
// in: query
// required: false
//
@@ -129,14 +134,31 @@ func (m *Module) ListAccountsGETHandler(c *gin.Context) {
return
}
- limit, errWithCode := apiutil.ParseLimit(c.Query(apiutil.LimitKey), 20, 40, 1)
+ limit, errWithCode := apiutil.ParseLimit(c.Query(apiutil.LimitKey), 40, 80, 0)
if errWithCode != nil {
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
return
}
+ var (
+ ctx = c.Request.Context()
+ )
+
+ if limit == 0 {
+ // Return all accounts in the list without pagination.
+ accounts, errWithCode := m.processor.List().GetAllListAccounts(ctx, authed.Account, targetListID)
+ if errWithCode != nil {
+ apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
+ return
+ }
+
+ c.JSON(http.StatusOK, accounts)
+ return
+ }
+
+ // Return subset of accounts in the list with pagination.
resp, errWithCode := m.processor.List().GetListAccounts(
- c.Request.Context(),
+ ctx,
authed.Account,
targetListID,
c.Query(MaxIDKey),
diff --git a/internal/api/client/lists/listaccounts_test.go b/internal/api/client/lists/listaccounts_test.go
new file mode 100644
index 000000000..64e9ef768
--- /dev/null
+++ b/internal/api/client/lists/listaccounts_test.go
@@ -0,0 +1,246 @@
+// GoToSocial
+// Copyright (C) GoToSocial Authors admin@gotosocial.org
+// SPDX-License-Identifier: AGPL-3.0-or-later
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+
+package lists_test
+
+import (
+ "encoding/json"
+ "fmt"
+ "io/ioutil"
+ "net/http"
+ "net/http/httptest"
+ "strconv"
+ "testing"
+
+ "github.com/stretchr/testify/suite"
+ "github.com/superseriousbusiness/gotosocial/internal/api/client/lists"
+ apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
+ apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
+ "github.com/superseriousbusiness/gotosocial/internal/config"
+ "github.com/superseriousbusiness/gotosocial/internal/gtserror"
+ "github.com/superseriousbusiness/gotosocial/internal/oauth"
+ "github.com/superseriousbusiness/gotosocial/testrig"
+)
+
+type ListAccountsTestSuite struct {
+ ListsStandardTestSuite
+}
+
+func (suite *ListAccountsTestSuite) getListAccounts(
+ expectedHTTPStatus int,
+ expectedBody string,
+ listID string,
+ maxID string,
+ sinceID string,
+ minID string,
+ limit *int,
+) (
+ []*apimodel.Account,
+ string, // Link header
+ error,
+) {
+
+ var (
+ recorder = httptest.NewRecorder()
+ ctx, _ = testrig.CreateGinTestContext(recorder, nil)
+ )
+
+ // Prepare test context.
+ ctx.Set(oauth.SessionAuthorizedAccount, suite.testAccounts["local_account_1"])
+ ctx.Set(oauth.SessionAuthorizedToken, oauth.DBTokenToToken(suite.testTokens["local_account_1"]))
+ ctx.Set(oauth.SessionAuthorizedApplication, suite.testApplications["application_1"])
+ ctx.Set(oauth.SessionAuthorizedUser, suite.testUsers["local_account_1"])
+
+ // Inject path parameters.
+ ctx.AddParam("id", listID)
+
+ // Inject query parameters.
+ requestPath := config.GetProtocol() + "://" + config.GetHost() + "/api/" + lists.BasePath + "/" + listID + "/accounts"
+
+ if limit != nil {
+ requestPath += "?limit=" + strconv.Itoa(*limit)
+ } else {
+ requestPath += "?limit=40"
+ }
+ if maxID != "" {
+ requestPath += "&" + apiutil.MaxIDKey + "=" + maxID
+ }
+ if sinceID != "" {
+ requestPath += "&" + apiutil.SinceIDKey + "=" + sinceID
+ }
+ if minID != "" {
+ requestPath += "&" + apiutil.MinIDKey + "=" + minID
+ }
+
+ // Prepare test context request.
+ request := httptest.NewRequest(http.MethodGet, requestPath, nil)
+ request.Header.Set("accept", "application/json")
+ ctx.Request = request
+
+ // trigger the handler
+ suite.listsModule.ListAccountsGETHandler(ctx)
+
+ // read the response
+ result := recorder.Result()
+ defer result.Body.Close()
+
+ b, err := ioutil.ReadAll(result.Body)
+ if err != nil {
+ return nil, "", err
+ }
+
+ errs := gtserror.MultiError{}
+
+ // check code + body
+ if resultCode := recorder.Code; expectedHTTPStatus != resultCode {
+ errs = append(errs, fmt.Sprintf("expected %d got %d", expectedHTTPStatus, resultCode))
+ }
+
+ // if we got an expected body, return early
+ if expectedBody != "" {
+ if string(b) != expectedBody {
+ errs = append(errs, fmt.Sprintf("expected %s got %s", expectedBody, string(b)))
+ }
+ return nil, "", errs.Combine()
+ }
+
+ resp := []*apimodel.Account{}
+ if err := json.Unmarshal(b, &resp); err != nil {
+ return nil, "", err
+ }
+
+ return resp, result.Header.Get("Link"), nil
+}
+
+func (suite *ListAccountsTestSuite) TestGetListAccountsPaginatedDefaultLimit() {
+ var (
+ expectedHTTPStatus = 200
+ expectedBody = ""
+ listID = suite.testLists["local_account_1_list_1"].ID
+ maxID = ""
+ minID = ""
+ sinceID = ""
+ limit *int = nil
+ )
+
+ accounts, link, err := suite.getListAccounts(
+ expectedHTTPStatus,
+ expectedBody,
+ listID,
+ maxID,
+ sinceID,
+ minID,
+ limit,
+ )
+ if err != nil {
+ suite.FailNow(err.Error())
+ }
+
+ suite.Len(accounts, 2)
+ suite.Equal(
+ `; rel="next", `+
+ `; rel="prev"`,
+ link,
+ )
+}
+
+func (suite *ListAccountsTestSuite) TestGetListAccountsPaginatedNextPage() {
+ var (
+ expectedHTTPStatus = 200
+ expectedBody = ""
+ listID = suite.testLists["local_account_1_list_1"].ID
+ maxID = ""
+ minID = ""
+ sinceID = ""
+ limit *int = func() *int { l := 1; return &l }() // Set to 1.
+ )
+
+ // First response, ask for 1 account.
+ accounts, link, err := suite.getListAccounts(
+ expectedHTTPStatus,
+ expectedBody,
+ listID,
+ maxID,
+ sinceID,
+ minID,
+ limit,
+ )
+ if err != nil {
+ suite.FailNow(err.Error())
+ }
+
+ suite.Len(accounts, 1)
+ suite.Equal(
+ `; rel="next", `+
+ `; rel="prev"`,
+ link,
+ )
+
+ // Next response, ask for next 1 account.
+ maxID = "01H0G8FFM1AGQDRNGBGGX8CYJQ"
+ accounts, link, err = suite.getListAccounts(
+ expectedHTTPStatus,
+ expectedBody,
+ listID,
+ maxID,
+ sinceID,
+ minID,
+ limit,
+ )
+ if err != nil {
+ suite.FailNow(err.Error())
+ }
+
+ suite.Len(accounts, 1)
+ suite.Equal(
+ `; rel="next", `+
+ `; rel="prev"`,
+ link,
+ )
+}
+
+func (suite *ListAccountsTestSuite) TestGetListAccountsUnpaginated() {
+ var (
+ expectedHTTPStatus = 200
+ expectedBody = ""
+ listID = suite.testLists["local_account_1_list_1"].ID
+ maxID = ""
+ minID = ""
+ sinceID = ""
+ limit *int = func() *int { l := 0; return &l }() // Set to 0 explicitly.
+ )
+
+ accounts, link, err := suite.getListAccounts(
+ expectedHTTPStatus,
+ expectedBody,
+ listID,
+ maxID,
+ sinceID,
+ minID,
+ limit,
+ )
+ if err != nil {
+ suite.FailNow(err.Error())
+ }
+
+ suite.Len(accounts, 2)
+ suite.Empty(link)
+}
+
+func TestListAccountsTestSuite(t *testing.T) {
+ suite.Run(t, new(ListAccountsTestSuite))
+}
diff --git a/internal/api/client/lists/lists_test.go b/internal/api/client/lists/lists_test.go
new file mode 100644
index 000000000..ebaf6998e
--- /dev/null
+++ b/internal/api/client/lists/lists_test.go
@@ -0,0 +1,108 @@
+// GoToSocial
+// Copyright (C) GoToSocial Authors admin@gotosocial.org
+// SPDX-License-Identifier: AGPL-3.0-or-later
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+
+package lists_test
+
+import (
+ "github.com/stretchr/testify/suite"
+ "github.com/superseriousbusiness/gotosocial/internal/api/client/lists"
+ "github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/superseriousbusiness/gotosocial/internal/email"
+ "github.com/superseriousbusiness/gotosocial/internal/federation"
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/superseriousbusiness/gotosocial/internal/media"
+ "github.com/superseriousbusiness/gotosocial/internal/processing"
+ "github.com/superseriousbusiness/gotosocial/internal/state"
+ "github.com/superseriousbusiness/gotosocial/internal/storage"
+ "github.com/superseriousbusiness/gotosocial/internal/visibility"
+ "github.com/superseriousbusiness/gotosocial/testrig"
+)
+
+type ListsStandardTestSuite struct {
+ // standard suite interfaces
+ suite.Suite
+ db db.DB
+ storage *storage.Driver
+ mediaManager *media.Manager
+ federator federation.Federator
+ processor *processing.Processor
+ emailSender email.Sender
+ state state.State
+
+ // standard suite models
+ testTokens map[string]*gtsmodel.Token
+ testClients map[string]*gtsmodel.Client
+ testApplications map[string]*gtsmodel.Application
+ testUsers map[string]*gtsmodel.User
+ testAccounts map[string]*gtsmodel.Account
+ testAttachments map[string]*gtsmodel.MediaAttachment
+ testStatuses map[string]*gtsmodel.Status
+ testEmojis map[string]*gtsmodel.Emoji
+ testEmojiCategories map[string]*gtsmodel.EmojiCategory
+ testLists map[string]*gtsmodel.List
+
+ // module being tested
+ listsModule *lists.Module
+}
+
+func (suite *ListsStandardTestSuite) SetupSuite() {
+ suite.testTokens = testrig.NewTestTokens()
+ suite.testClients = testrig.NewTestClients()
+ suite.testApplications = testrig.NewTestApplications()
+ suite.testUsers = testrig.NewTestUsers()
+ suite.testAccounts = testrig.NewTestAccounts()
+ suite.testAttachments = testrig.NewTestAttachments()
+ suite.testStatuses = testrig.NewTestStatuses()
+ suite.testEmojis = testrig.NewTestEmojis()
+ suite.testEmojiCategories = testrig.NewTestEmojiCategories()
+ suite.testLists = testrig.NewTestLists()
+}
+
+func (suite *ListsStandardTestSuite) SetupTest() {
+ suite.state.Caches.Init()
+ suite.state.Caches.Start()
+ testrig.StartWorkers(&suite.state)
+
+ testrig.InitTestConfig()
+ testrig.InitTestLog()
+
+ suite.db = testrig.NewTestDB(&suite.state)
+ suite.state.DB = suite.db
+ suite.storage = testrig.NewInMemoryStorage()
+ suite.state.Storage = suite.storage
+
+ testrig.StartTimelines(
+ &suite.state,
+ visibility.NewFilter(&suite.state),
+ testrig.NewTestTypeConverter(suite.db),
+ )
+
+ suite.mediaManager = testrig.NewTestMediaManager(&suite.state)
+ suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")), suite.mediaManager)
+ suite.emailSender = testrig.NewEmailSender("../../../../web/template/", nil)
+ suite.processor = testrig.NewTestProcessor(&suite.state, suite.federator, suite.emailSender, suite.mediaManager)
+ suite.listsModule = lists.New(suite.processor)
+
+ testrig.StandardDBSetup(suite.db, nil)
+ testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media")
+}
+
+func (suite *ListsStandardTestSuite) TearDownTest() {
+ testrig.StandardDBTeardown(suite.db)
+ testrig.StandardStorageTeardown(suite.storage)
+ testrig.StopWorkers(&suite.state)
+}
diff --git a/internal/api/util/parsequery.go b/internal/api/util/parsequery.go
index 92105ef82..662870910 100644
--- a/internal/api/util/parsequery.go
+++ b/internal/api/util/parsequery.go
@@ -27,11 +27,12 @@ import (
const (
/* Common keys */
- IDKey = "id"
- LimitKey = "limit"
- LocalKey = "local"
- MaxIDKey = "max_id"
- MinIDKey = "min_id"
+ IDKey = "id"
+ LimitKey = "limit"
+ LocalKey = "local"
+ MaxIDKey = "max_id"
+ SinceIDKey = "since_id"
+ MinIDKey = "min_id"
/* Search keys */
@@ -76,10 +77,8 @@ func ParseLimit(value string, defaultValue int, max, min int) (int, gtserror.Wit
i, err := parseInt(value, defaultValue, max, min, LimitKey)
if err != nil {
return 0, err
- } else if i == 0 {
- // treat 0 as an empty query
- return defaultValue, nil
}
+
return i, nil
}
diff --git a/internal/processing/list/get.go b/internal/processing/list/get.go
index 0fc14f934..1a03898ed 100644
--- a/internal/processing/list/get.go
+++ b/internal/processing/list/get.go
@@ -75,6 +75,46 @@ func (p *Processor) GetAll(ctx context.Context, account *gtsmodel.Account) ([]*a
return apiLists, nil
}
+// GetAllListAccounts returns all accounts that are in the given list,
+// owned by the given account. There's no pagination for this endpoint.
+//
+// See https://docs.joinmastodon.org/methods/lists/#query-parameters:
+//
+// Limit: Integer. Maximum number of results. Defaults to 40 accounts.
+// Max 80 accounts. Set to 0 in order to get all accounts without pagination.
+func (p *Processor) GetAllListAccounts(
+ ctx context.Context,
+ account *gtsmodel.Account,
+ listID string,
+) ([]*apimodel.Account, gtserror.WithCode) {
+ // Ensure list exists + is owned by requesting account.
+ _, errWithCode := p.getList(
+ // Use barebones ctx; no embedded
+ // structs necessary for this call.
+ gtscontext.SetBarebones(ctx),
+ account.ID,
+ listID,
+ )
+ if errWithCode != nil {
+ return nil, errWithCode
+ }
+
+ // Get all entries for this list.
+ listEntries, err := p.state.DB.GetListEntries(ctx, listID, "", "", "", 0)
+ if err != nil && !errors.Is(err, db.ErrNoEntries) {
+ err = gtserror.Newf("error getting list entries: %w", err)
+ return nil, gtserror.NewErrorInternalError(err)
+ }
+
+ // Extract accounts from list entries + add them to response.
+ accounts := make([]*apimodel.Account, 0, len(listEntries))
+ p.accountsFromListEntries(ctx, listEntries, func(acc *apimodel.Account) {
+ accounts = append(accounts, acc)
+ })
+
+ return accounts, nil
+}
+
// GetListAccounts returns accounts that are in the given list, owned by the given account.
// The additional parameters can be used for paging.
func (p *Processor) GetListAccounts(
@@ -121,6 +161,25 @@ func (p *Processor) GetListAccounts(
prevMinIDValue = listEntries[0].ID
)
+ // Extract accounts from list entries + add them to response.
+ p.accountsFromListEntries(ctx, listEntries, func(acc *apimodel.Account) {
+ items = append(items, acc)
+ })
+
+ return util.PackagePageableResponse(util.PageableResponseParams{
+ Items: items,
+ Path: "/api/v1/lists/" + listID + "/accounts",
+ NextMaxIDValue: nextMaxIDValue,
+ PrevMinIDValue: prevMinIDValue,
+ Limit: limit,
+ })
+}
+
+func (p *Processor) accountsFromListEntries(
+ ctx context.Context,
+ listEntries []*gtsmodel.ListEntry,
+ appendAcc func(*apimodel.Account),
+) {
// For each list entry, we want the account it points to.
// To get this, we need to first get the follow that the
// list entry pertains to, then extract the target account
@@ -144,14 +203,6 @@ func (p *Processor) GetListAccounts(
continue
}
- items = append(items, apiAccount)
+ appendAcc(apiAccount)
}
-
- return util.PackagePageableResponse(util.PageableResponseParams{
- Items: items,
- Path: "/api/v1/lists/" + listID + "/accounts",
- NextMaxIDValue: nextMaxIDValue,
- PrevMinIDValue: prevMinIDValue,
- Limit: limit,
- })
}