diff --git a/.golangci.yml b/.golangci.yml index a517026db..83ea9d20d 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -83,3 +83,12 @@ linters-settings: # Enable all checks, but disable SA1012: nil context passing. # See: https://staticcheck.io/docs/configuration/options/#checks checks: ["all", "-SA1012"] + +issues: + exclude-rules: + # Exclude VSCode custom folding region comments in files that use them. + # Already fixed in go-critic and can be removed next time go-critic is updated. + - linters: + - gocritic + path: internal/db/filter.go + text: 'commentFormatting: put a space between `//` and comment text' diff --git a/docs/api/swagger.yaml b/docs/api/swagger.yaml index 2c4e88f2d..2d5e9bed8 100644 --- a/docs/api/swagger.yaml +++ b/docs/api/swagger.yaml @@ -1209,6 +1209,61 @@ definitions: type: object x-go-name: Field x-go-package: github.com/superseriousbusiness/gotosocial/internal/api/model + filterContext: + description: v1 and v2 filter APIs use the same set of contexts. + title: FilterContext represents the context in which to apply a filter. + type: string + x-go-name: FilterContext + x-go-package: github.com/superseriousbusiness/gotosocial/internal/api/model + filterV1: + description: |- + Note that v1 filters are mapped to v2 filters and v2 filter keywords internally. + If whole_word is true, client app should do: + Define ‘word constituent character’ for your app. In the official implementation, it’s [A-Za-z0-9_] in JavaScript, and [[:word:]] in Ruby. + Ruby uses the POSIX character class (Letter | Mark | Decimal_Number | Connector_Punctuation). + If the phrase starts with a word character, and if the previous character before matched range is a word character, its matched range should be treated to not match. + If the phrase ends with a word character, and if the next character after matched range is a word character, its matched range should be treated to not match. + Please check app/javascript/mastodon/selectors/index.js and app/lib/feed_manager.rb in the Mastodon source code for more details. + properties: + context: + description: The contexts in which the filter should be applied. + example: + - home + - public + items: + $ref: '#/definitions/filterContext' + minLength: 1 + type: array + uniqueItems: true + x-go-name: Context + expires_at: + description: When the filter should no longer be applied. Null if the filter does not expire. + example: "2024-02-01T02:57:49Z" + type: string + x-go-name: ExpiresAt + id: + description: The ID of the filter in the database. + type: string + x-go-name: ID + irreversible: + description: Should matching entities be removed from the user's timelines/views, instead of hidden? + example: false + type: boolean + x-go-name: Irreversible + phrase: + description: The text to be filtered. + example: fnord + type: string + x-go-name: Phrase + whole_word: + description: Should the filter consider word boundaries? + example: true + type: boolean + x-go-name: WholeWord + title: FilterV1 represents a user-defined filter for determining which statuses should not be shown to the user. + type: object + x-go-name: FilterV1 + x-go-package: github.com/superseriousbusiness/gotosocial/internal/api/model headerFilterCreateRequest: properties: header: @@ -5570,6 +5625,246 @@ paths: summary: Get an array of all hashtags that you currently have featured on your profile. tags: - featured_tags + /api/v1/filters: + get: + operationId: filtersV1Get + produces: + - application/json + responses: + "200": + description: Requested filters. + schema: + $ref: '#/definitions/filterV1' + "400": + description: bad request + "401": + description: unauthorized + "404": + description: not found + "406": + description: not acceptable + "500": + description: internal server error + security: + - OAuth2 Bearer: + - read:filters + summary: Get all filters for the authenticated account. + tags: + - filters + post: + consumes: + - application/json + - application/xml + - application/x-www-form-urlencoded + operationId: filterV1Post + parameters: + - description: The text to be filtered. + example: fnord + in: formData + maxLength: 40 + name: phrase + required: true + type: string + - description: The contexts in which the filter should be applied. + enum: + - home + - notifications + - public + - thread + - account + example: + - home + - public + in: formData + items: + $ref: '#/definitions/filterContext' + minLength: 1 + name: context + required: true + type: array + uniqueItems: true + - description: Number of seconds from now that the filter should expire. If omitted, filter never expires. + example: 86400 + in: formData + name: expires_in + type: number + - default: false + description: Should matching entities be removed from the user's timelines/views, instead of hidden? Not supported yet. + example: false + in: formData + name: irreversible + type: boolean + - default: false + description: Should the filter consider word boundaries? + example: true + in: formData + name: whole_word + type: boolean + produces: + - application/json + responses: + "200": + description: New filter. + schema: + $ref: '#/definitions/filterV1' + "400": + description: bad request + "401": + description: unauthorized + "404": + description: not found + "406": + description: not acceptable + "422": + description: unprocessable content + "500": + description: internal server error + security: + - OAuth2 Bearer: + - write:filters + summary: Create a single filter. + tags: + - filters + /api/v1/filters/{id}: + delete: + operationId: filterV1Delete + parameters: + - description: ID of the list + in: path + name: id + required: true + type: string + produces: + - application/json + responses: + "200": + description: filter deleted + "400": + description: bad request + "401": + description: unauthorized + "404": + description: not found + "406": + description: not acceptable + "500": + description: internal server error + security: + - OAuth2 Bearer: + - write:filters + summary: Delete a single filter with the given ID. + tags: + - filters + get: + operationId: filterV1Get + parameters: + - description: ID of the filter + in: path + name: id + required: true + type: string + produces: + - application/json + responses: + "200": + description: Requested filter. + schema: + $ref: '#/definitions/filterV1' + "400": + description: bad request + "401": + description: unauthorized + "404": + description: not found + "406": + description: not acceptable + "500": + description: internal server error + security: + - OAuth2 Bearer: + - read:filters + summary: Get a single filter with the given ID. + tags: + - filters + put: + consumes: + - application/json + - application/xml + - application/x-www-form-urlencoded + operationId: filterV1Put + parameters: + - description: ID of the filter. + in: path + name: id + required: true + type: string + - description: The text to be filtered. + example: fnord + in: formData + maxLength: 40 + name: phrase + required: true + type: string + - description: The contexts in which the filter should be applied. + enum: + - home + - notifications + - public + - thread + - account + example: + - home + - public + in: formData + items: + $ref: '#/definitions/filterContext' + minLength: 1 + name: context + required: true + type: array + uniqueItems: true + - description: Number of seconds from now that the filter should expire. If omitted, filter never expires. + example: 86400 + in: formData + name: expires_in + type: number + - default: false + description: Should matching entities be removed from the user's timelines/views, instead of hidden? Not supported yet. + example: false + in: formData + name: irreversible + type: boolean + - default: false + description: Should the filter consider word boundaries? + example: true + in: formData + name: whole_word + type: boolean + produces: + - application/json + responses: + "200": + description: Updated filter. + schema: + $ref: '#/definitions/filterV1' + "400": + description: bad request + "401": + description: unauthorized + "404": + description: not found + "406": + description: not acceptable + "422": + description: unprocessable content + "500": + description: internal server error + security: + - OAuth2 Bearer: + - write:filters + summary: Update a single filter with the given ID. + tags: + - filters /api/v1/follow_requests: get: description: |- @@ -7971,6 +8266,7 @@ securityDefinitions: read:blocks: grant read access to blocks read:custom_emojis: grant read access to custom_emojis read:favourites: grant read access to favourites + read:filters: grant read access to filters read:follows: grant read access to follows read:lists: grant read access to lists read:media: grant read access to media @@ -7983,6 +8279,7 @@ securityDefinitions: write: grants write access to everything write:accounts: grants write access to accounts write:blocks: grants write access to blocks + write:filters: grants write access to filters write:follows: grants write access to follows write:lists: grants write access to lists write:media: grants write access to media diff --git a/docs/swagger.go b/docs/swagger.go index 8f64bcc42..73c9a3d9a 100644 --- a/docs/swagger.go +++ b/docs/swagger.go @@ -36,6 +36,7 @@ // read:blocks: grant read access to blocks // read:custom_emojis: grant read access to custom_emojis // read:favourites: grant read access to favourites +// read:filters: grant read access to filters // read:follows: grant read access to follows // read:lists: grant read access to lists // read:media: grant read access to media @@ -48,6 +49,7 @@ // write: grants write access to everything // write:accounts: grants write access to accounts // write:blocks: grants write access to blocks +// write:filters: grants write access to filters // write:follows: grants write access to follows // write:lists: grants write access to lists // write:media: grants write access to media diff --git a/internal/api/client.go b/internal/api/client.go index 1112efa31..d41add017 100644 --- a/internal/api/client.go +++ b/internal/api/client.go @@ -29,7 +29,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/api/client/customemojis" "github.com/superseriousbusiness/gotosocial/internal/api/client/favourites" "github.com/superseriousbusiness/gotosocial/internal/api/client/featuredtags" - filter "github.com/superseriousbusiness/gotosocial/internal/api/client/filters" + filtersV1 "github.com/superseriousbusiness/gotosocial/internal/api/client/filters/v1" "github.com/superseriousbusiness/gotosocial/internal/api/client/followrequests" "github.com/superseriousbusiness/gotosocial/internal/api/client/instance" "github.com/superseriousbusiness/gotosocial/internal/api/client/lists" @@ -62,7 +62,7 @@ type Client struct { customEmojis *customemojis.Module // api/v1/custom_emojis favourites *favourites.Module // api/v1/favourites featuredTags *featuredtags.Module // api/v1/featured_tags - filters *filter.Module // api/v1/filters + filtersV1 *filtersV1.Module // api/v1/filters followRequests *followrequests.Module // api/v1/follow_requests instance *instance.Module // api/v1/instance lists *lists.Module // api/v1/lists @@ -104,7 +104,7 @@ func (c *Client) Route(r *router.Router, m ...gin.HandlerFunc) { c.customEmojis.Route(h) c.favourites.Route(h) c.featuredTags.Route(h) - c.filters.Route(h) + c.filtersV1.Route(h) c.followRequests.Route(h) c.instance.Route(h) c.lists.Route(h) @@ -134,7 +134,7 @@ func NewClient(db db.DB, p *processing.Processor) *Client { customEmojis: customemojis.New(p), favourites: favourites.New(p), featuredTags: featuredtags.New(p), - filters: filter.New(p), + filtersV1: filtersV1.New(p), followRequests: followrequests.New(p), instance: instance.New(p), lists: lists.New(p), diff --git a/internal/api/client/filters/filter.go b/internal/api/client/filters/v1/filter.go similarity index 70% rename from internal/api/client/filters/filter.go rename to internal/api/client/filters/v1/filter.go index 68c99e825..9daeb75d3 100644 --- a/internal/api/client/filters/filter.go +++ b/internal/api/client/filters/v1/filter.go @@ -15,20 +15,23 @@ // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . -package filter +package v1 import ( - "net/http" - "github.com/gin-gonic/gin" + apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util" "github.com/superseriousbusiness/gotosocial/internal/processing" + "net/http" ) const ( // BasePath is the base path for serving the filters API, minus the 'api' prefix BasePath = "/v1/filters" + // BasePathWithID is the base path with the ID key in it, for operations on an existing filter. + BasePathWithID = BasePath + "/:" + apiutil.IDKey ) +// Module implements APIs for client-side aka "v1" filtering. type Module struct { processor *processing.Processor } @@ -41,4 +44,8 @@ func New(processor *processing.Processor) *Module { func (m *Module) Route(attachHandler func(method string, path string, f ...gin.HandlerFunc) gin.IRoutes) { attachHandler(http.MethodGet, BasePath, m.FiltersGETHandler) + attachHandler(http.MethodPost, BasePath, m.FilterPOSTHandler) + attachHandler(http.MethodGet, BasePathWithID, m.FilterGETHandler) + attachHandler(http.MethodPut, BasePathWithID, m.FilterPUTHandler) + attachHandler(http.MethodDelete, BasePathWithID, m.FilterDELETEHandler) } diff --git a/internal/api/client/filters/v1/filter_test.go b/internal/api/client/filters/v1/filter_test.go new file mode 100644 index 000000000..c92e22a05 --- /dev/null +++ b/internal/api/client/filters/v1/filter_test.go @@ -0,0 +1,117 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package v1_test + +import ( + "github.com/stretchr/testify/suite" + filtersV1 "github.com/superseriousbusiness/gotosocial/internal/api/client/filters/v1" + "github.com/superseriousbusiness/gotosocial/internal/config" + "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/email" + "github.com/superseriousbusiness/gotosocial/internal/federation" + "github.com/superseriousbusiness/gotosocial/internal/filter/visibility" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/media" + "github.com/superseriousbusiness/gotosocial/internal/processing" + "github.com/superseriousbusiness/gotosocial/internal/state" + "github.com/superseriousbusiness/gotosocial/internal/storage" + "github.com/superseriousbusiness/gotosocial/internal/typeutils" + "github.com/superseriousbusiness/gotosocial/testrig" + "testing" +) + +type FiltersTestSuite struct { + suite.Suite + db db.DB + storage *storage.Driver + mediaManager *media.Manager + federator *federation.Federator + processor *processing.Processor + emailSender email.Sender + sentEmails map[string]string + state state.State + + // standard suite models + testTokens map[string]*gtsmodel.Token + testClients map[string]*gtsmodel.Client + testApplications map[string]*gtsmodel.Application + testUsers map[string]*gtsmodel.User + testAccounts map[string]*gtsmodel.Account + testStatuses map[string]*gtsmodel.Status + testFilters map[string]*gtsmodel.Filter + testFilterKeywords map[string]*gtsmodel.FilterKeyword + testFilterStatuses map[string]*gtsmodel.FilterStatus + + // module being tested + filtersModule *filtersV1.Module +} + +func (suite *FiltersTestSuite) SetupSuite() { + suite.testTokens = testrig.NewTestTokens() + suite.testClients = testrig.NewTestClients() + suite.testApplications = testrig.NewTestApplications() + suite.testUsers = testrig.NewTestUsers() + suite.testAccounts = testrig.NewTestAccounts() + suite.testStatuses = testrig.NewTestStatuses() + suite.testFilters = testrig.NewTestFilters() + suite.testFilterKeywords = testrig.NewTestFilterKeywords() + suite.testFilterStatuses = testrig.NewTestFilterStatuses() +} + +func (suite *FiltersTestSuite) SetupTest() { + suite.state.Caches.Init() + testrig.StartNoopWorkers(&suite.state) + + testrig.InitTestConfig() + config.Config(func(cfg *config.Configuration) { + cfg.WebAssetBaseDir = "../../../../../web/assets/" + cfg.WebTemplateBaseDir = "../../../../../web/templates/" + }) + testrig.InitTestLog() + + suite.db = testrig.NewTestDB(&suite.state) + suite.state.DB = suite.db + suite.storage = testrig.NewInMemoryStorage() + suite.state.Storage = suite.storage + + testrig.StartTimelines( + &suite.state, + visibility.NewFilter(&suite.state), + typeutils.NewConverter(&suite.state), + ) + + suite.mediaManager = testrig.NewTestMediaManager(&suite.state) + suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../../testrig/media")), suite.mediaManager) + suite.sentEmails = make(map[string]string) + suite.emailSender = testrig.NewEmailSender("../../../../../web/template/", suite.sentEmails) + suite.processor = testrig.NewTestProcessor(&suite.state, suite.federator, suite.emailSender, suite.mediaManager) + suite.filtersModule = filtersV1.New(suite.processor) + + testrig.StandardDBSetup(suite.db, nil) + testrig.StandardStorageSetup(suite.storage, "../../../../../testrig/media") +} + +func (suite *FiltersTestSuite) TearDownTest() { + testrig.StandardDBTeardown(suite.db) + testrig.StandardStorageTeardown(suite.storage) + testrig.StopWorkers(&suite.state) +} + +func TestFiltersTestSuite(t *testing.T) { + suite.Run(t, new(FiltersTestSuite)) +} diff --git a/internal/api/client/filters/v1/filterdelete.go b/internal/api/client/filters/v1/filterdelete.go new file mode 100644 index 000000000..d86b277a6 --- /dev/null +++ b/internal/api/client/filters/v1/filterdelete.go @@ -0,0 +1,90 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package v1 + +import ( + "net/http" + + "github.com/gin-gonic/gin" + apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" + "github.com/superseriousbusiness/gotosocial/internal/oauth" +) + +// FilterDELETEHandler swagger:operation DELETE /api/v1/filters/{id} filterV1Delete +// +// Delete a single filter with the given ID. +// +// --- +// tags: +// - filters +// +// produces: +// - application/json +// +// parameters: +// - +// name: id +// type: string +// description: ID of the list +// in: path +// required: true +// +// security: +// - OAuth2 Bearer: +// - write:filters +// +// responses: +// '200': +// description: filter deleted +// '400': +// description: bad request +// '401': +// description: unauthorized +// '404': +// description: not found +// '406': +// description: not acceptable +// '500': +// description: internal server error +func (m *Module) FilterDELETEHandler(c *gin.Context) { + authed, err := oauth.Authed(c, true, true, true, true) + if err != nil { + apiutil.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGetV1) + return + } + + if _, err := apiutil.NegotiateAccept(c, apiutil.JSONAcceptHeaders...); err != nil { + apiutil.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGetV1) + return + } + + id, errWithCode := apiutil.ParseID(c.Param(apiutil.IDKey)) + if errWithCode != nil { + apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) + return + } + + errWithCode = m.processor.FiltersV1().Delete(c.Request.Context(), authed.Account, id) + if errWithCode != nil { + apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) + return + } + + c.JSON(http.StatusOK, apiutil.EmptyJSONObject) +} diff --git a/internal/api/client/filters/v1/filterdelete_test.go b/internal/api/client/filters/v1/filterdelete_test.go new file mode 100644 index 000000000..83155f08a --- /dev/null +++ b/internal/api/client/filters/v1/filterdelete_test.go @@ -0,0 +1,112 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package v1_test + +import ( + "encoding/json" + "io" + "net/http" + "net/http/httptest" + + filtersV1 "github.com/superseriousbusiness/gotosocial/internal/api/client/filters/v1" + "github.com/superseriousbusiness/gotosocial/internal/config" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" + "github.com/superseriousbusiness/gotosocial/internal/oauth" + "github.com/superseriousbusiness/gotosocial/testrig" +) + +func (suite *FiltersTestSuite) deleteFilter( + filterKeywordID string, + expectedHTTPStatus int, + expectedBody string, +) error { + // instantiate recorder + test context + recorder := httptest.NewRecorder() + ctx, _ := testrig.CreateGinTestContext(recorder, nil) + ctx.Set(oauth.SessionAuthorizedAccount, suite.testAccounts["local_account_1"]) + ctx.Set(oauth.SessionAuthorizedToken, oauth.DBTokenToToken(suite.testTokens["local_account_1"])) + ctx.Set(oauth.SessionAuthorizedApplication, suite.testApplications["application_1"]) + ctx.Set(oauth.SessionAuthorizedUser, suite.testUsers["local_account_1"]) + + // create the request + ctx.Request = httptest.NewRequest(http.MethodDelete, config.GetProtocol()+"://"+config.GetHost()+"/api/"+filtersV1.BasePath+"/"+filterKeywordID, nil) + ctx.Request.Header.Set("accept", "application/json") + + ctx.AddParam("id", filterKeywordID) + + // trigger the handler + suite.filtersModule.FilterDELETEHandler(ctx) + + // read the response + result := recorder.Result() + defer result.Body.Close() + + b, err := io.ReadAll(result.Body) + if err != nil { + return err + } + + errs := gtserror.NewMultiError(2) + + // check code + body + if resultCode := recorder.Code; expectedHTTPStatus != resultCode { + errs.Appendf("expected %d got %d", expectedHTTPStatus, resultCode) + } + + // if we got an expected body, return early + if expectedBody != "" { + if string(b) != expectedBody { + errs.Appendf("expected %s got %s", expectedBody, string(b)) + } + return errs.Combine() + } + + resp := &struct{}{} + if err := json.Unmarshal(b, resp); err != nil { + return err + } + + return nil +} + +func (suite *FiltersTestSuite) TestDeleteFilter() { + id := suite.testFilterKeywords["local_account_1_filter_1_keyword_1"].ID + + err := suite.deleteFilter(id, http.StatusOK, "") + if err != nil { + suite.FailNow(err.Error()) + } +} + +func (suite *FiltersTestSuite) TestDeleteAnotherAccountsFilter() { + id := suite.testFilterKeywords["local_account_2_filter_1_keyword_1"].ID + + err := suite.deleteFilter(id, http.StatusNotFound, `{"error":"Not Found"}`) + if err != nil { + suite.FailNow(err.Error()) + } +} + +func (suite *FiltersTestSuite) TestDeleteNonexistentFilter() { + id := "not_even_a_real_ULID" + + err := suite.deleteFilter(id, http.StatusNotFound, `{"error":"Not Found"}`) + if err != nil { + suite.FailNow(err.Error()) + } +} diff --git a/internal/api/client/filters/v1/filterget.go b/internal/api/client/filters/v1/filterget.go new file mode 100644 index 000000000..35c44b60c --- /dev/null +++ b/internal/api/client/filters/v1/filterget.go @@ -0,0 +1,93 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package v1 + +import ( + "net/http" + + "github.com/gin-gonic/gin" + apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" + "github.com/superseriousbusiness/gotosocial/internal/oauth" +) + +// FilterGETHandler swagger:operation GET /api/v1/filters/{id} filterV1Get +// +// Get a single filter with the given ID. +// +// --- +// tags: +// - filters +// +// produces: +// - application/json +// +// parameters: +// - +// name: id +// type: string +// description: ID of the filter +// in: path +// required: true +// +// security: +// - OAuth2 Bearer: +// - read:filters +// +// responses: +// '200': +// name: filter +// description: Requested filter. +// schema: +// "$ref": "#/definitions/filterV1" +// '400': +// description: bad request +// '401': +// description: unauthorized +// '404': +// description: not found +// '406': +// description: not acceptable +// '500': +// description: internal server error +func (m *Module) FilterGETHandler(c *gin.Context) { + authed, err := oauth.Authed(c, true, true, true, true) + if err != nil { + apiutil.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGetV1) + return + } + + if _, err := apiutil.NegotiateAccept(c, apiutil.JSONAcceptHeaders...); err != nil { + apiutil.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGetV1) + return + } + + id, errWithCode := apiutil.ParseID(c.Param(apiutil.IDKey)) + if errWithCode != nil { + apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) + return + } + + apiFilter, errWithCode := m.processor.FiltersV1().Get(c.Request.Context(), authed.Account, id) + if errWithCode != nil { + apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) + return + } + + c.JSON(http.StatusOK, apiFilter) +} diff --git a/internal/api/client/filters/v1/filterget_test.go b/internal/api/client/filters/v1/filterget_test.go new file mode 100644 index 000000000..a9dbf6dbb --- /dev/null +++ b/internal/api/client/filters/v1/filterget_test.go @@ -0,0 +1,121 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package v1_test + +import ( + "encoding/json" + "io" + "net/http" + "net/http/httptest" + + filtersV1 "github.com/superseriousbusiness/gotosocial/internal/api/client/filters/v1" + apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" + "github.com/superseriousbusiness/gotosocial/internal/config" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/oauth" + "github.com/superseriousbusiness/gotosocial/testrig" +) + +func (suite *FiltersTestSuite) getFilter( + filterKeywordID string, + expectedHTTPStatus int, + expectedBody string, +) (*apimodel.FilterV1, error) { + // instantiate recorder + test context + recorder := httptest.NewRecorder() + ctx, _ := testrig.CreateGinTestContext(recorder, nil) + ctx.Set(oauth.SessionAuthorizedAccount, suite.testAccounts["local_account_1"]) + ctx.Set(oauth.SessionAuthorizedToken, oauth.DBTokenToToken(suite.testTokens["local_account_1"])) + ctx.Set(oauth.SessionAuthorizedApplication, suite.testApplications["application_1"]) + ctx.Set(oauth.SessionAuthorizedUser, suite.testUsers["local_account_1"]) + + // create the request + ctx.Request = httptest.NewRequest(http.MethodGet, config.GetProtocol()+"://"+config.GetHost()+"/api/"+filtersV1.BasePath+"/"+filterKeywordID, nil) + ctx.Request.Header.Set("accept", "application/json") + + ctx.AddParam("id", filterKeywordID) + + // trigger the handler + suite.filtersModule.FilterGETHandler(ctx) + + // read the response + result := recorder.Result() + defer result.Body.Close() + + b, err := io.ReadAll(result.Body) + if err != nil { + return nil, err + } + + errs := gtserror.NewMultiError(2) + + // check code + body + if resultCode := recorder.Code; expectedHTTPStatus != resultCode { + errs.Appendf("expected %d got %d", expectedHTTPStatus, resultCode) + } + + // if we got an expected body, return early + if expectedBody != "" { + if string(b) != expectedBody { + errs.Appendf("expected %s got %s", expectedBody, string(b)) + } + return nil, errs.Combine() + } + + resp := &apimodel.FilterV1{} + if err := json.Unmarshal(b, resp); err != nil { + return nil, err + } + + return resp, nil +} + +func (suite *FiltersTestSuite) TestGetFilter() { + // v1 filters map to individual filter keywords, but also use the settings of the associated filter. + expectedFilterGtsModel := suite.testFilters["local_account_1_filter_1"] + expectedFilterKeywordGtsModel := suite.testFilterKeywords["local_account_1_filter_1_keyword_1"] + + filter, err := suite.getFilter(expectedFilterKeywordGtsModel.ID, http.StatusOK, "") + if err != nil { + suite.FailNow(err.Error()) + } + + suite.NotEmpty(filter) + suite.Equal(expectedFilterGtsModel.Action == gtsmodel.FilterActionHide, filter.Irreversible) + suite.Equal(expectedFilterKeywordGtsModel.ID, filter.ID) + suite.Equal(expectedFilterKeywordGtsModel.Keyword, filter.Phrase) +} + +func (suite *FiltersTestSuite) TestGetAnotherAccountsFilter() { + id := suite.testFilterKeywords["local_account_2_filter_1_keyword_1"].ID + + _, err := suite.getFilter(id, http.StatusNotFound, `{"error":"Not Found"}`) + if err != nil { + suite.FailNow(err.Error()) + } +} + +func (suite *FiltersTestSuite) TestGetNonexistentFilter() { + id := "not_even_a_real_ULID" + + _, err := suite.getFilter(id, http.StatusNotFound, `{"error":"Not Found"}`) + if err != nil { + suite.FailNow(err.Error()) + } +} diff --git a/internal/api/client/filters/v1/filterpost.go b/internal/api/client/filters/v1/filterpost.go new file mode 100644 index 000000000..b0a626199 --- /dev/null +++ b/internal/api/client/filters/v1/filterpost.go @@ -0,0 +1,147 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package v1 + +import ( + "net/http" + + "github.com/gin-gonic/gin" + apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" + apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" + "github.com/superseriousbusiness/gotosocial/internal/oauth" +) + +// FilterPOSTHandler swagger:operation POST /api/v1/filters filterV1Post +// +// Create a single filter. +// +// --- +// tags: +// - filters +// +// consumes: +// - application/json +// - application/xml +// - application/x-www-form-urlencoded +// +// produces: +// - application/json +// +// parameters: +// - +// name: phrase +// in: formData +// required: true +// description: The text to be filtered. +// maxLength: 40 +// type: string +// example: "fnord" +// - +// name: context +// in: formData +// required: true +// description: The contexts in which the filter should be applied. +// enum: +// - home +// - notifications +// - public +// - thread +// - account +// example: +// - home +// - public +// items: +// $ref: '#/definitions/filterContext' +// minLength: 1 +// type: array +// uniqueItems: true +// - +// name: expires_in +// in: formData +// description: Number of seconds from now that the filter should expire. If omitted, filter never expires. +// type: number +// example: 86400 +// - +// name: irreversible +// in: formData +// description: Should matching entities be removed from the user's timelines/views, instead of hidden? Not supported yet. +// type: boolean +// default: false +// example: false +// - +// name: whole_word +// in: formData +// description: Should the filter consider word boundaries? +// type: boolean +// default: false +// example: true +// +// security: +// - OAuth2 Bearer: +// - write:filters +// +// responses: +// '200': +// name: filter +// description: New filter. +// schema: +// "$ref": "#/definitions/filterV1" +// '400': +// description: bad request +// '401': +// description: unauthorized +// '404': +// description: not found +// '406': +// description: not acceptable +// '422': +// description: unprocessable content +// '500': +// description: internal server error +func (m *Module) FilterPOSTHandler(c *gin.Context) { + authed, err := oauth.Authed(c, true, true, true, true) + if err != nil { + apiutil.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGetV1) + return + } + + if _, err := apiutil.NegotiateAccept(c, apiutil.JSONAcceptHeaders...); err != nil { + apiutil.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGetV1) + return + } + + form := &apimodel.FilterCreateUpdateRequestV1{} + if err := c.ShouldBind(form); err != nil { + apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1) + return + } + + if err := validateNormalizeCreateUpdateFilter(form); err != nil { + apiutil.ErrorHandler(c, gtserror.NewErrorUnprocessableEntity(err, err.Error()), m.processor.InstanceGetV1) + return + } + + apiFilter, errWithCode := m.processor.FiltersV1().Create(c.Request.Context(), authed.Account, form) + if errWithCode != nil { + apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) + return + } + + apiutil.JSON(c, http.StatusOK, apiFilter) +} diff --git a/internal/api/client/filters/v1/filterpost_test.go b/internal/api/client/filters/v1/filterpost_test.go new file mode 100644 index 000000000..729b2bd72 --- /dev/null +++ b/internal/api/client/filters/v1/filterpost_test.go @@ -0,0 +1,239 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package v1_test + +import ( + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "net/url" + "strconv" + "strings" + + filtersV1 "github.com/superseriousbusiness/gotosocial/internal/api/client/filters/v1" + apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" + "github.com/superseriousbusiness/gotosocial/internal/config" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" + "github.com/superseriousbusiness/gotosocial/internal/oauth" + "github.com/superseriousbusiness/gotosocial/testrig" +) + +func (suite *FiltersTestSuite) postFilter( + phrase *string, + context *[]string, + irreversible *bool, + wholeWord *bool, + expiresIn *int, + requestJson *string, + expectedHTTPStatus int, + expectedBody string, +) (*apimodel.FilterV1, error) { + // instantiate recorder + test context + recorder := httptest.NewRecorder() + ctx, _ := testrig.CreateGinTestContext(recorder, nil) + ctx.Set(oauth.SessionAuthorizedAccount, suite.testAccounts["local_account_1"]) + ctx.Set(oauth.SessionAuthorizedToken, oauth.DBTokenToToken(suite.testTokens["local_account_1"])) + ctx.Set(oauth.SessionAuthorizedApplication, suite.testApplications["application_1"]) + ctx.Set(oauth.SessionAuthorizedUser, suite.testUsers["local_account_1"]) + + // create the request + ctx.Request = httptest.NewRequest(http.MethodPost, config.GetProtocol()+"://"+config.GetHost()+"/api/"+filtersV1.BasePath, nil) + ctx.Request.Header.Set("accept", "application/json") + if requestJson != nil { + ctx.Request.Header.Set("content-type", "application/json") + ctx.Request.Body = io.NopCloser(strings.NewReader(*requestJson)) + } else { + ctx.Request.Form = make(url.Values) + if phrase != nil { + ctx.Request.Form["phrase"] = []string{*phrase} + } + if context != nil { + ctx.Request.Form["context[]"] = *context + } + if irreversible != nil { + ctx.Request.Form["irreversible"] = []string{strconv.FormatBool(*irreversible)} + } + if wholeWord != nil { + ctx.Request.Form["whole_word"] = []string{strconv.FormatBool(*wholeWord)} + } + if expiresIn != nil { + ctx.Request.Form["expires_in"] = []string{strconv.Itoa(*expiresIn)} + } + } + + // trigger the handler + suite.filtersModule.FilterPOSTHandler(ctx) + + // read the response + result := recorder.Result() + defer result.Body.Close() + + b, err := io.ReadAll(result.Body) + if err != nil { + return nil, err + } + + errs := gtserror.NewMultiError(2) + + // check code + body + if resultCode := recorder.Code; expectedHTTPStatus != resultCode { + errs.Appendf("expected %d got %d", expectedHTTPStatus, resultCode) + } + + // if we got an expected body, return early + if expectedBody != "" { + if string(b) != expectedBody { + errs.Appendf("expected %s got %s", expectedBody, string(b)) + } + return nil, errs.Combine() + } + + resp := &apimodel.FilterV1{} + if err := json.Unmarshal(b, resp); err != nil { + return nil, err + } + + return resp, nil +} + +func (suite *FiltersTestSuite) TestPostFilterFull() { + phrase := "GNU/Linux" + context := []string{"home", "public"} + irreversible := false + wholeWord := true + expiresIn := 86400 + filter, err := suite.postFilter(&phrase, &context, &irreversible, &wholeWord, &expiresIn, nil, http.StatusOK, "") + if err != nil { + suite.FailNow(err.Error()) + } + + suite.Equal(phrase, filter.Phrase) + filterContext := make([]string, 0, len(filter.Context)) + for _, c := range filter.Context { + filterContext = append(filterContext, string(c)) + } + suite.ElementsMatch(context, filterContext) + suite.Equal(irreversible, filter.Irreversible) + suite.Equal(wholeWord, filter.WholeWord) + if suite.NotNil(filter.ExpiresAt) { + suite.NotEmpty(*filter.ExpiresAt) + } +} + +func (suite *FiltersTestSuite) TestPostFilterFullJSON() { + // Use a numeric literal with a fractional part to test the JSON-specific handling for non-integer "expires_in". + requestJson := `{ + "phrase":"GNU/Linux", + "context": ["home", "public"], + "irreversible": false, + "whole_word": true, + "expires_in": 86400.1 + }` + filter, err := suite.postFilter(nil, nil, nil, nil, nil, &requestJson, http.StatusOK, "") + if err != nil { + suite.FailNow(err.Error()) + } + + suite.Equal("GNU/Linux", filter.Phrase) + suite.ElementsMatch( + []apimodel.FilterContext{ + apimodel.FilterContextHome, + apimodel.FilterContextPublic, + }, + filter.Context, + ) + suite.Equal(false, filter.Irreversible) + suite.Equal(true, filter.WholeWord) + if suite.NotNil(filter.ExpiresAt) { + suite.NotEmpty(*filter.ExpiresAt) + } +} + +func (suite *FiltersTestSuite) TestPostFilterMinimal() { + phrase := "GNU/Linux" + context := []string{"home"} + filter, err := suite.postFilter(&phrase, &context, nil, nil, nil, nil, http.StatusOK, "") + if err != nil { + suite.FailNow(err.Error()) + } + + suite.Equal(phrase, filter.Phrase) + filterContext := make([]string, 0, len(filter.Context)) + for _, c := range filter.Context { + filterContext = append(filterContext, string(c)) + } + suite.ElementsMatch(context, filterContext) + suite.False(filter.Irreversible) + suite.False(filter.WholeWord) + suite.Nil(filter.ExpiresAt) +} + +func (suite *FiltersTestSuite) TestPostFilterEmptyPhrase() { + phrase := "" + context := []string{"home"} + _, err := suite.postFilter(&phrase, &context, nil, nil, nil, nil, http.StatusUnprocessableEntity, "") + if err != nil { + suite.FailNow(err.Error()) + } +} + +func (suite *FiltersTestSuite) TestPostFilterMissingPhrase() { + context := []string{"home"} + _, err := suite.postFilter(nil, &context, nil, nil, nil, nil, http.StatusUnprocessableEntity, "") + if err != nil { + suite.FailNow(err.Error()) + } +} + +func (suite *FiltersTestSuite) TestPostFilterEmptyContext() { + phrase := "GNU/Linux" + context := []string{} + _, err := suite.postFilter(&phrase, &context, nil, nil, nil, nil, http.StatusUnprocessableEntity, "") + if err != nil { + suite.FailNow(err.Error()) + } +} + +func (suite *FiltersTestSuite) TestPostFilterMissingContext() { + phrase := "GNU/Linux" + _, err := suite.postFilter(&phrase, nil, nil, nil, nil, nil, http.StatusUnprocessableEntity, "") + if err != nil { + suite.FailNow(err.Error()) + } +} + +// There should be a filter with this phrase as its title in our test fixtures. Creating another should fail. +func (suite *FiltersTestSuite) TestPostFilterTitleConflict() { + phrase := "fnord" + _, err := suite.postFilter(&phrase, nil, nil, nil, nil, nil, http.StatusUnprocessableEntity, "") + if err != nil { + suite.FailNow(err.Error()) + } +} + +// FUTURE: this should be removed once we support server-side filters. +func (suite *FiltersTestSuite) TestPostFilterIrreversibleNotSupported() { + phrase := "GNU/Linux" + context := []string{"home"} + irreversible := true + _, err := suite.postFilter(&phrase, &context, &irreversible, nil, nil, nil, http.StatusUnprocessableEntity, "") + if err != nil { + suite.FailNow(err.Error()) + } +} diff --git a/internal/api/client/filters/v1/filterput.go b/internal/api/client/filters/v1/filterput.go new file mode 100644 index 000000000..c686e4515 --- /dev/null +++ b/internal/api/client/filters/v1/filterput.go @@ -0,0 +1,159 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package v1 + +import ( + "net/http" + + "github.com/gin-gonic/gin" + apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" + apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" + "github.com/superseriousbusiness/gotosocial/internal/oauth" +) + +// FilterPUTHandler swagger:operation PUT /api/v1/filters/{id} filterV1Put +// +// Update a single filter with the given ID. +// +// --- +// tags: +// - filters +// +// consumes: +// - application/json +// - application/xml +// - application/x-www-form-urlencoded +// +// produces: +// - application/json +// +// parameters: +// - +// name: id +// in: path +// type: string +// required: true +// description: ID of the filter. +// - +// name: phrase +// in: formData +// required: true +// description: The text to be filtered. +// maxLength: 40 +// type: string +// example: "fnord" +// - +// name: context +// in: formData +// required: true +// description: The contexts in which the filter should be applied. +// enum: +// - home +// - notifications +// - public +// - thread +// - account +// example: +// - home +// - public +// items: +// $ref: '#/definitions/filterContext' +// minLength: 1 +// type: array +// uniqueItems: true +// - +// name: expires_in +// in: formData +// description: Number of seconds from now that the filter should expire. If omitted, filter never expires. +// type: number +// example: 86400 +// - +// name: irreversible +// in: formData +// description: Should matching entities be removed from the user's timelines/views, instead of hidden? Not supported yet. +// type: boolean +// default: false +// example: false +// - +// name: whole_word +// in: formData +// description: Should the filter consider word boundaries? +// type: boolean +// default: false +// example: true +// +// security: +// - OAuth2 Bearer: +// - write:filters +// +// responses: +// '200': +// name: filter +// description: Updated filter. +// schema: +// "$ref": "#/definitions/filterV1" +// '400': +// description: bad request +// '401': +// description: unauthorized +// '404': +// description: not found +// '406': +// description: not acceptable +// '422': +// description: unprocessable content +// '500': +// description: internal server error +func (m *Module) FilterPUTHandler(c *gin.Context) { + authed, err := oauth.Authed(c, true, true, true, true) + if err != nil { + apiutil.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGetV1) + return + } + + if _, err := apiutil.NegotiateAccept(c, apiutil.JSONAcceptHeaders...); err != nil { + apiutil.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGetV1) + return + } + + id, errWithCode := apiutil.ParseID(c.Param(apiutil.IDKey)) + if errWithCode != nil { + apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) + return + } + + form := &apimodel.FilterCreateUpdateRequestV1{} + if err := c.ShouldBind(form); err != nil { + apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1) + return + } + + if err := validateNormalizeCreateUpdateFilter(form); err != nil { + apiutil.ErrorHandler(c, gtserror.NewErrorUnprocessableEntity(err, err.Error()), m.processor.InstanceGetV1) + return + } + + apiFilter, errWithCode := m.processor.FiltersV1().Update(c.Request.Context(), authed.Account, id, form) + if errWithCode != nil { + apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) + return + } + + apiutil.JSON(c, http.StatusOK, apiFilter) +} diff --git a/internal/api/client/filters/v1/filterput_test.go b/internal/api/client/filters/v1/filterput_test.go new file mode 100644 index 000000000..0308e53d9 --- /dev/null +++ b/internal/api/client/filters/v1/filterput_test.go @@ -0,0 +1,269 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package v1_test + +import ( + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "net/url" + "strconv" + "strings" + + filtersV1 "github.com/superseriousbusiness/gotosocial/internal/api/client/filters/v1" + apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" + "github.com/superseriousbusiness/gotosocial/internal/config" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" + "github.com/superseriousbusiness/gotosocial/internal/oauth" + "github.com/superseriousbusiness/gotosocial/testrig" +) + +func (suite *FiltersTestSuite) putFilter( + filterKeywordID string, + phrase *string, + context *[]string, + irreversible *bool, + wholeWord *bool, + expiresIn *int, + requestJson *string, + expectedHTTPStatus int, + expectedBody string, +) (*apimodel.FilterV1, error) { + // instantiate recorder + test context + recorder := httptest.NewRecorder() + ctx, _ := testrig.CreateGinTestContext(recorder, nil) + ctx.Set(oauth.SessionAuthorizedAccount, suite.testAccounts["local_account_1"]) + ctx.Set(oauth.SessionAuthorizedToken, oauth.DBTokenToToken(suite.testTokens["local_account_1"])) + ctx.Set(oauth.SessionAuthorizedApplication, suite.testApplications["application_1"]) + ctx.Set(oauth.SessionAuthorizedUser, suite.testUsers["local_account_1"]) + + // create the request + ctx.Request = httptest.NewRequest(http.MethodPut, config.GetProtocol()+"://"+config.GetHost()+"/api/"+filtersV1.BasePath+"/"+filterKeywordID, nil) + ctx.Request.Header.Set("accept", "application/json") + if requestJson != nil { + ctx.Request.Header.Set("content-type", "application/json") + ctx.Request.Body = io.NopCloser(strings.NewReader(*requestJson)) + } else { + ctx.Request.Form = make(url.Values) + if phrase != nil { + ctx.Request.Form["phrase"] = []string{*phrase} + } + if context != nil { + ctx.Request.Form["context[]"] = *context + } + if irreversible != nil { + ctx.Request.Form["irreversible"] = []string{strconv.FormatBool(*irreversible)} + } + if wholeWord != nil { + ctx.Request.Form["whole_word"] = []string{strconv.FormatBool(*wholeWord)} + } + if expiresIn != nil { + ctx.Request.Form["expires_in"] = []string{strconv.Itoa(*expiresIn)} + } + } + + ctx.AddParam("id", filterKeywordID) + + // trigger the handler + suite.filtersModule.FilterPUTHandler(ctx) + + // read the response + result := recorder.Result() + defer result.Body.Close() + + b, err := io.ReadAll(result.Body) + if err != nil { + return nil, err + } + + errs := gtserror.NewMultiError(2) + + // check code + body + if resultCode := recorder.Code; expectedHTTPStatus != resultCode { + errs.Appendf("expected %d got %d", expectedHTTPStatus, resultCode) + } + + // if we got an expected body, return early + if expectedBody != "" { + if string(b) != expectedBody { + errs.Appendf("expected %s got %s", expectedBody, string(b)) + } + return nil, errs.Combine() + } + + resp := &apimodel.FilterV1{} + if err := json.Unmarshal(b, resp); err != nil { + return nil, err + } + + return resp, nil +} + +func (suite *FiltersTestSuite) TestPutFilterFull() { + id := suite.testFilterKeywords["local_account_1_filter_1_keyword_1"].ID + phrase := "GNU/Linux" + context := []string{"home", "public"} + irreversible := false + wholeWord := true + expiresIn := 86400 + filter, err := suite.putFilter(id, &phrase, &context, &irreversible, &wholeWord, &expiresIn, nil, http.StatusOK, "") + if err != nil { + suite.FailNow(err.Error()) + } + + suite.Equal(phrase, filter.Phrase) + filterContext := make([]string, 0, len(filter.Context)) + for _, c := range filter.Context { + filterContext = append(filterContext, string(c)) + } + suite.ElementsMatch(context, filterContext) + suite.Equal(irreversible, filter.Irreversible) + suite.Equal(wholeWord, filter.WholeWord) + if suite.NotNil(filter.ExpiresAt) { + suite.NotEmpty(*filter.ExpiresAt) + } +} + +func (suite *FiltersTestSuite) TestPutFilterFullJSON() { + id := suite.testFilterKeywords["local_account_1_filter_1_keyword_1"].ID + // Use a numeric literal with a fractional part to test the JSON-specific handling for non-integer "expires_in". + requestJson := `{ + "phrase":"GNU/Linux", + "context": ["home", "public"], + "irreversible": false, + "whole_word": true, + "expires_in": 86400.1 + }` + filter, err := suite.putFilter(id, nil, nil, nil, nil, nil, &requestJson, http.StatusOK, "") + if err != nil { + suite.FailNow(err.Error()) + } + + suite.Equal("GNU/Linux", filter.Phrase) + suite.ElementsMatch( + []apimodel.FilterContext{ + apimodel.FilterContextHome, + apimodel.FilterContextPublic, + }, + filter.Context, + ) + suite.Equal(false, filter.Irreversible) + suite.Equal(true, filter.WholeWord) + if suite.NotNil(filter.ExpiresAt) { + suite.NotEmpty(*filter.ExpiresAt) + } +} + +func (suite *FiltersTestSuite) TestPutFilterMinimal() { + id := suite.testFilterKeywords["local_account_1_filter_1_keyword_1"].ID + phrase := "GNU/Linux" + context := []string{"home"} + filter, err := suite.putFilter(id, &phrase, &context, nil, nil, nil, nil, http.StatusOK, "") + if err != nil { + suite.FailNow(err.Error()) + } + + suite.Equal(phrase, filter.Phrase) + filterContext := make([]string, 0, len(filter.Context)) + for _, c := range filter.Context { + filterContext = append(filterContext, string(c)) + } + suite.ElementsMatch(context, filterContext) + suite.False(filter.Irreversible) + suite.False(filter.WholeWord) + suite.Nil(filter.ExpiresAt) +} + +func (suite *FiltersTestSuite) TestPutFilterEmptyPhrase() { + id := suite.testFilterKeywords["local_account_1_filter_1_keyword_1"].ID + phrase := "" + context := []string{"home"} + _, err := suite.putFilter(id, &phrase, &context, nil, nil, nil, nil, http.StatusUnprocessableEntity, "") + if err != nil { + suite.FailNow(err.Error()) + } +} + +func (suite *FiltersTestSuite) TestPutFilterMissingPhrase() { + id := suite.testFilterKeywords["local_account_1_filter_1_keyword_1"].ID + context := []string{"home"} + _, err := suite.putFilter(id, nil, &context, nil, nil, nil, nil, http.StatusUnprocessableEntity, "") + if err != nil { + suite.FailNow(err.Error()) + } +} + +func (suite *FiltersTestSuite) TestPutFilterEmptyContext() { + id := suite.testFilterKeywords["local_account_1_filter_1_keyword_1"].ID + phrase := "GNU/Linux" + context := []string{} + _, err := suite.putFilter(id, &phrase, &context, nil, nil, nil, nil, http.StatusUnprocessableEntity, "") + if err != nil { + suite.FailNow(err.Error()) + } +} + +func (suite *FiltersTestSuite) TestPutFilterMissingContext() { + id := suite.testFilterKeywords["local_account_1_filter_1_keyword_1"].ID + phrase := "GNU/Linux" + _, err := suite.putFilter(id, &phrase, nil, nil, nil, nil, nil, http.StatusUnprocessableEntity, "") + if err != nil { + suite.FailNow(err.Error()) + } +} + +// There should be a filter with this phrase as its title in our test fixtures. Changing ours to that title should fail. +func (suite *FiltersTestSuite) TestPutFilterTitleConflict() { + id := suite.testFilterKeywords["local_account_1_filter_1_keyword_1"].ID + phrase := "metasyntactic variables" + _, err := suite.putFilter(id, &phrase, nil, nil, nil, nil, nil, http.StatusUnprocessableEntity, "") + if err != nil { + suite.FailNow(err.Error()) + } +} + +// FUTURE: this should be removed once we support server-side filters. +func (suite *FiltersTestSuite) TestPutFilterIrreversibleNotSupported() { + id := suite.testFilterKeywords["local_account_1_filter_1_keyword_1"].ID + irreversible := true + _, err := suite.putFilter(id, nil, nil, &irreversible, nil, nil, nil, http.StatusUnprocessableEntity, "") + if err != nil { + suite.FailNow(err.Error()) + } +} + +func (suite *FiltersTestSuite) TestPutAnotherAccountsFilter() { + id := suite.testFilterKeywords["local_account_2_filter_1_keyword_1"].ID + phrase := "GNU/Linux" + context := []string{"home"} + _, err := suite.putFilter(id, &phrase, &context, nil, nil, nil, nil, http.StatusNotFound, `{"error":"Not Found"}`) + if err != nil { + suite.FailNow(err.Error()) + } +} + +func (suite *FiltersTestSuite) TestPutNonexistentFilter() { + id := "not_even_a_real_ULID" + phrase := "GNU/Linux" + context := []string{"home"} + _, err := suite.putFilter(id, &phrase, &context, nil, nil, nil, nil, http.StatusNotFound, `{"error":"Not Found"}`) + if err != nil { + suite.FailNow(err.Error()) + } +} diff --git a/internal/api/client/filters/filtersget.go b/internal/api/client/filters/v1/filtersget.go similarity index 60% rename from internal/api/client/filters/filtersget.go rename to internal/api/client/filters/v1/filtersget.go index 38dd330a7..84d638676 100644 --- a/internal/api/client/filters/filtersget.go +++ b/internal/api/client/filters/v1/filtersget.go @@ -15,7 +15,7 @@ // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . -package filter +package v1 import ( "net/http" @@ -26,9 +26,40 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/oauth" ) -// FiltersGETHandler returns a list of filters set by/for the authed account +// FiltersGETHandler swagger:operation GET /api/v1/filters filtersV1Get +// +// Get all filters for the authenticated account. +// +// --- +// tags: +// - filters +// +// produces: +// - application/json +// +// security: +// - OAuth2 Bearer: +// - read:filters +// +// responses: +// '200': +// name: filter +// description: Requested filters. +// schema: +// "$ref": "#/definitions/filterV1" +// '400': +// description: bad request +// '401': +// description: unauthorized +// '404': +// description: not found +// '406': +// description: not acceptable +// '500': +// description: internal server error func (m *Module) FiltersGETHandler(c *gin.Context) { - if _, err := oauth.Authed(c, true, true, true, true); err != nil { + authed, err := oauth.Authed(c, true, true, true, true) + if err != nil { apiutil.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGetV1) return } @@ -38,5 +69,11 @@ func (m *Module) FiltersGETHandler(c *gin.Context) { return } - apiutil.Data(c, http.StatusOK, apiutil.AppJSON, apiutil.EmptyJSONArray) + apiFilters, errWithCode := m.processor.FiltersV1().GetAll(c.Request.Context(), authed.Account) + if errWithCode != nil { + apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) + return + } + + c.JSON(http.StatusOK, apiFilters) } diff --git a/internal/api/client/filters/v1/filtersget_test.go b/internal/api/client/filters/v1/filtersget_test.go new file mode 100644 index 000000000..a568239ef --- /dev/null +++ b/internal/api/client/filters/v1/filtersget_test.go @@ -0,0 +1,114 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package v1_test + +import ( + "encoding/json" + "io" + "net/http" + "net/http/httptest" + + filtersV1 "github.com/superseriousbusiness/gotosocial/internal/api/client/filters/v1" + apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" + "github.com/superseriousbusiness/gotosocial/internal/config" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" + "github.com/superseriousbusiness/gotosocial/internal/oauth" + "github.com/superseriousbusiness/gotosocial/testrig" +) + +func (suite *FiltersTestSuite) getFilters( + expectedHTTPStatus int, + expectedBody string, +) ([]*apimodel.FilterV1, error) { + // instantiate recorder + test context + recorder := httptest.NewRecorder() + ctx, _ := testrig.CreateGinTestContext(recorder, nil) + ctx.Set(oauth.SessionAuthorizedAccount, suite.testAccounts["local_account_1"]) + ctx.Set(oauth.SessionAuthorizedToken, oauth.DBTokenToToken(suite.testTokens["local_account_1"])) + ctx.Set(oauth.SessionAuthorizedApplication, suite.testApplications["application_1"]) + ctx.Set(oauth.SessionAuthorizedUser, suite.testUsers["local_account_1"]) + + // create the request + ctx.Request = httptest.NewRequest(http.MethodGet, config.GetProtocol()+"://"+config.GetHost()+"/api/"+filtersV1.BasePath, nil) + ctx.Request.Header.Set("accept", "application/json") + + // trigger the handler + suite.filtersModule.FiltersGETHandler(ctx) + + // read the response + result := recorder.Result() + defer result.Body.Close() + + b, err := io.ReadAll(result.Body) + if err != nil { + return nil, err + } + + errs := gtserror.NewMultiError(2) + + // check code + body + if resultCode := recorder.Code; expectedHTTPStatus != resultCode { + errs.Appendf("expected %d got %d", expectedHTTPStatus, resultCode) + } + + // if we got an expected body, return early + if expectedBody != "" { + if string(b) != expectedBody { + errs.Appendf("expected %s got %s", expectedBody, string(b)) + } + return nil, errs.Combine() + } + + resp := make([]*apimodel.FilterV1, 0) + if err := json.Unmarshal(b, &resp); err != nil { + return nil, err + } + + return resp, nil +} + +func (suite *FiltersTestSuite) TestGetFilters() { + // v1 filters map to individual filter keywords. + expectedFilterIDs := make([]string, 0, len(suite.testFilterKeywords)) + expectedFilterKeywords := make([]string, 0, len(suite.testFilterKeywords)) + for _, filterKeyword := range suite.testFilterKeywords { + if filterKeyword.AccountID == suite.testAccounts["local_account_1"].ID { + expectedFilterIDs = append(expectedFilterIDs, filterKeyword.ID) + expectedFilterKeywords = append(expectedFilterKeywords, filterKeyword.Keyword) + } + } + suite.NotEmpty(expectedFilterIDs) + suite.NotEmpty(expectedFilterKeywords) + + // Fetch all filters for the logged-in account. + filters, err := suite.getFilters(http.StatusOK, "") + if err != nil { + suite.FailNow(err.Error()) + } + suite.NotEmpty(filters) + + // Check that we got the right ones. + actualFilterIDs := make([]string, 0, len(filters)) + actualFilterKeywords := make([]string, 0, len(filters)) + for _, filter := range filters { + actualFilterIDs = append(actualFilterIDs, filter.ID) + actualFilterKeywords = append(actualFilterKeywords, filter.Phrase) + } + suite.ElementsMatch(expectedFilterIDs, actualFilterIDs) + suite.ElementsMatch(expectedFilterKeywords, actualFilterKeywords) +} diff --git a/internal/api/client/filters/v1/validate.go b/internal/api/client/filters/v1/validate.go new file mode 100644 index 000000000..b539c9563 --- /dev/null +++ b/internal/api/client/filters/v1/validate.go @@ -0,0 +1,68 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package v1 + +import ( + "errors" + "fmt" + "strconv" + + "github.com/superseriousbusiness/gotosocial/internal/api/model" + "github.com/superseriousbusiness/gotosocial/internal/util" + "github.com/superseriousbusiness/gotosocial/internal/validate" +) + +func validateNormalizeCreateUpdateFilter(form *model.FilterCreateUpdateRequestV1) error { + if err := validate.FilterKeyword(form.Phrase); err != nil { + return err + } + if err := validate.FilterContexts(form.Context); err != nil { + return err + } + + // Apply defaults for missing fields. + form.WholeWord = util.Ptr(util.PtrValueOr(form.WholeWord, false)) + form.Irreversible = util.Ptr(util.PtrValueOr(form.Irreversible, false)) + + if *form.Irreversible { + return errors.New("irreversible aka server-side drop filters are not supported yet") + } + + // Normalize filter expiry if necessary. + // If we parsed this as JSON, expires_in + // may be either a float64 or a string. + if ei := form.ExpiresInI; ei != nil { + switch e := ei.(type) { + case float64: + form.ExpiresIn = util.Ptr(int(e)) + + case string: + expiresIn, err := strconv.Atoi(e) + if err != nil { + return fmt.Errorf("could not parse expires_in value %s as integer: %w", e, err) + } + + form.ExpiresIn = &expiresIn + + default: + return fmt.Errorf("could not parse expires_in type %T as integer", ei) + } + } + + return nil +} diff --git a/internal/api/model/filter.go b/internal/api/model/filter.go index 4a5d29690..027dea48c 100644 --- a/internal/api/model/filter.go +++ b/internal/api/model/filter.go @@ -17,29 +17,23 @@ package model -// Filter represents a user-defined filter for determining which statuses should not be shown to the user. -// If whole_word is true , client app should do: -// Define ‘word constituent character’ for your app. In the official implementation, it’s [A-Za-z0-9_] in JavaScript, and [[:word:]] in Ruby. -// Ruby uses the POSIX character class (Letter | Mark | Decimal_Number | Connector_Punctuation). -// If the phrase starts with a word character, and if the previous character before matched range is a word character, its matched range should be treated to not match. -// If the phrase ends with a word character, and if the next character after matched range is a word character, its matched range should be treated to not match. -// Please check app/javascript/mastodon/selectors/index.js and app/lib/feed_manager.rb in the Mastodon source code for more details. -type Filter struct { - // The ID of the filter in the database. - ID string `json:"id"` - // The text to be filtered. - Phrase string `json:"text"` - // The contexts in which the filter should be applied. - // Array of String (Enumerable anyOf) - // home = home timeline and lists - // notifications = notifications timeline - // public = public timelines - // thread = expanded thread of a detailed status - Context []string `json:"context"` - // Should the filter consider word boundaries? - WholeWord bool `json:"whole_word"` - // When the filter should no longer be applied (ISO 8601 Datetime), or null if the filter does not expire - ExpiresAt string `json:"expires_at,omitempty"` - // Should matching entities in home and notifications be dropped by the server? - Irreversible bool `json:"irreversible"` -} +// FilterContext represents the context in which to apply a filter. +// v1 and v2 filter APIs use the same set of contexts. +// +// swagger:model filterContext +type FilterContext string + +const ( + // FilterContextHome means this filter should be applied to the home timeline and lists. + FilterContextHome FilterContext = "home" + // FilterContextNotifications means this filter should be applied to the notifications timeline. + FilterContextNotifications FilterContext = "notifications" + // FilterContextPublic means this filter should be applied to public timelines. + FilterContextPublic FilterContext = "public" + // FilterContextThread means this filter should be applied to the expanded thread of a detailed status. + FilterContextThread FilterContext = "thread" + // FilterContextAccount means this filter should be applied when viewing a profile. + FilterContextAccount FilterContext = "account" + + FilterContextNumValues = 5 +) diff --git a/internal/api/model/filterv1.go b/internal/api/model/filterv1.go new file mode 100644 index 000000000..52250f537 --- /dev/null +++ b/internal/api/model/filterv1.go @@ -0,0 +1,99 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package model + +// FilterV1 represents a user-defined filter for determining which statuses should not be shown to the user. +// Note that v1 filters are mapped to v2 filters and v2 filter keywords internally. +// If whole_word is true, client app should do: +// Define ‘word constituent character’ for your app. In the official implementation, it’s [A-Za-z0-9_] in JavaScript, and [[:word:]] in Ruby. +// Ruby uses the POSIX character class (Letter | Mark | Decimal_Number | Connector_Punctuation). +// If the phrase starts with a word character, and if the previous character before matched range is a word character, its matched range should be treated to not match. +// If the phrase ends with a word character, and if the next character after matched range is a word character, its matched range should be treated to not match. +// Please check app/javascript/mastodon/selectors/index.js and app/lib/feed_manager.rb in the Mastodon source code for more details. +// +// swagger:model filterV1 +// +// --- +// tags: +// - filters +type FilterV1 struct { + // The ID of the filter in the database. + ID string `json:"id"` + // The text to be filtered. + // + // Example: fnord + Phrase string `json:"phrase"` + // The contexts in which the filter should be applied. + // + // Minimum length: 1 + // Unique: true + // Enum: + // - home + // - notifications + // - public + // - thread + // - account + // Example: ["home", "public"] + Context []FilterContext `json:"context"` + // Should the filter consider word boundaries? + // + // Example: true + WholeWord bool `json:"whole_word"` + // Should matching entities be removed from the user's timelines/views, instead of hidden? + // + // Example: false + Irreversible bool `json:"irreversible"` + // When the filter should no longer be applied. Null if the filter does not expire. + // + // Example: 2024-02-01T02:57:49Z + ExpiresAt *string `json:"expires_at"` +} + +// FilterCreateUpdateRequestV1 captures params for creating or updating a v1 filter. +// +// swagger:ignore +type FilterCreateUpdateRequestV1 struct { + // The text to be filtered. + // + // Required: true + // Maximum length: 40 + // Example: fnord + Phrase string `form:"phrase" json:"phrase" xml:"phrase"` + // The contexts in which the filter should be applied. + // + // Required: true + // Minimum length: 1 + // Unique: true + // Enum: home,notifications,public,thread,account + // Example: ["home", "public"] + Context []FilterContext `form:"context[]" json:"context" xml:"context"` + // Should matching entities be removed from the user's timelines/views, instead of hidden? + // + // Example: false + Irreversible *bool `form:"irreversible" json:"irreversible" xml:"irreversible"` + // Should the filter consider word boundaries? + // + // Example: true + WholeWord *bool `form:"whole_word" json:"whole_word" xml:"whole_word"` + // Number of seconds from now that the filter should expire. If omitted, filter never expires. + ExpiresIn *int `json:"-" form:"expires_in" xml:"expires_in"` + // Number of seconds from now that the filter should expire. If omitted, filter never expires. + // + // Example: 86400 + ExpiresInI interface{} `json:"expires_in"` +} diff --git a/internal/cache/cache.go b/internal/cache/cache.go index 17fa03323..9b70a565c 100644 --- a/internal/cache/cache.go +++ b/internal/cache/cache.go @@ -61,6 +61,9 @@ func (c *Caches) Init() { c.initDomainBlock() c.initEmoji() c.initEmojiCategory() + c.initFilter() + c.initFilterKeyword() + c.initFilterStatus() c.initFollow() c.initFollowIDs() c.initFollowRequest() @@ -119,6 +122,9 @@ func (c *Caches) Sweep(threshold float64) { c.GTS.BlockIDs.Trim(threshold) c.GTS.Emoji.Trim(threshold) c.GTS.EmojiCategory.Trim(threshold) + c.GTS.Filter.Trim(threshold) + c.GTS.FilterKeyword.Trim(threshold) + c.GTS.FilterStatus.Trim(threshold) c.GTS.Follow.Trim(threshold) c.GTS.FollowIDs.Trim(threshold) c.GTS.FollowRequest.Trim(threshold) diff --git a/internal/cache/db.go b/internal/cache/db.go index 275a25451..dc9e385cd 100644 --- a/internal/cache/db.go +++ b/internal/cache/db.go @@ -67,6 +67,15 @@ type GTSCaches struct { // EmojiCategory provides access to the gtsmodel EmojiCategory database cache. EmojiCategory structr.Cache[*gtsmodel.EmojiCategory] + // Filter provides access to the gtsmodel Filter database cache. + Filter structr.Cache[*gtsmodel.Filter] + + // FilterKeyword provides access to the gtsmodel FilterKeyword database cache. + FilterKeyword structr.Cache[*gtsmodel.FilterKeyword] + + // FilterStatus provides access to the gtsmodel FilterStatus database cache. + FilterStatus structr.Cache[*gtsmodel.FilterStatus] + // Follow provides access to the gtsmodel Follow database cache. Follow structr.Cache[*gtsmodel.Follow] @@ -409,6 +418,105 @@ func (c *Caches) initEmojiCategory() { }) } +func (c *Caches) initFilter() { + // Calculate maximum cache size. + cap := calculateResultCacheMax( + sizeofFilter(), // model in-mem size. + config.GetCacheFilterMemRatio(), + ) + + log.Infof(nil, "cache size = %d", cap) + + copyF := func(filter1 *gtsmodel.Filter) *gtsmodel.Filter { + filter2 := new(gtsmodel.Filter) + *filter2 = *filter1 + + // Don't include ptr fields that + // will be populated separately. + // See internal/db/bundb/filter.go. + filter2.Keywords = nil + filter2.Statuses = nil + + return filter2 + } + + c.GTS.Filter.Init(structr.Config[*gtsmodel.Filter]{ + Indices: []structr.IndexConfig{ + {Fields: "ID"}, + {Fields: "AccountID", Multiple: true}, + }, + MaxSize: cap, + IgnoreErr: ignoreErrors, + CopyValue: copyF, + }) +} + +func (c *Caches) initFilterKeyword() { + // Calculate maximum cache size. + cap := calculateResultCacheMax( + sizeofFilterKeyword(), // model in-mem size. + config.GetCacheFilterKeywordMemRatio(), + ) + + log.Infof(nil, "cache size = %d", cap) + + copyF := func(filterKeyword1 *gtsmodel.FilterKeyword) *gtsmodel.FilterKeyword { + filterKeyword2 := new(gtsmodel.FilterKeyword) + *filterKeyword2 = *filterKeyword1 + + // Don't include ptr fields that + // will be populated separately. + // See internal/db/bundb/filter.go. + filterKeyword2.Filter = nil + + return filterKeyword2 + } + + c.GTS.FilterKeyword.Init(structr.Config[*gtsmodel.FilterKeyword]{ + Indices: []structr.IndexConfig{ + {Fields: "ID"}, + {Fields: "AccountID", Multiple: true}, + {Fields: "FilterID", Multiple: true}, + }, + MaxSize: cap, + IgnoreErr: ignoreErrors, + CopyValue: copyF, + }) +} + +func (c *Caches) initFilterStatus() { + // Calculate maximum cache size. + cap := calculateResultCacheMax( + sizeofFilterStatus(), // model in-mem size. + config.GetCacheFilterStatusMemRatio(), + ) + + log.Infof(nil, "cache size = %d", cap) + + copyF := func(filterStatus1 *gtsmodel.FilterStatus) *gtsmodel.FilterStatus { + filterStatus2 := new(gtsmodel.FilterStatus) + *filterStatus2 = *filterStatus1 + + // Don't include ptr fields that + // will be populated separately. + // See internal/db/bundb/filter.go. + filterStatus2.Filter = nil + + return filterStatus2 + } + + c.GTS.FilterStatus.Init(structr.Config[*gtsmodel.FilterStatus]{ + Indices: []structr.IndexConfig{ + {Fields: "ID"}, + {Fields: "AccountID", Multiple: true}, + {Fields: "FilterID", Multiple: true}, + }, + MaxSize: cap, + IgnoreErr: ignoreErrors, + CopyValue: copyF, + }) +} + func (c *Caches) initFollow() { // Calculate maximum cache size. cap := calculateResultCacheMax( diff --git a/internal/cache/size.go b/internal/cache/size.go index 62fb31469..f9d88491d 100644 --- a/internal/cache/size.go +++ b/internal/cache/size.go @@ -309,6 +309,38 @@ func sizeofEmojiCategory() uintptr { })) } +func sizeofFilter() uintptr { + return uintptr(size.Of(>smodel.Filter{ + ID: exampleID, + CreatedAt: exampleTime, + UpdatedAt: exampleTime, + ExpiresAt: exampleTime, + AccountID: exampleID, + Title: exampleTextSmall, + Action: gtsmodel.FilterActionHide, + })) +} + +func sizeofFilterKeyword() uintptr { + return uintptr(size.Of(>smodel.FilterKeyword{ + ID: exampleID, + CreatedAt: exampleTime, + UpdatedAt: exampleTime, + FilterID: exampleID, + Keyword: exampleTextSmall, + })) +} + +func sizeofFilterStatus() uintptr { + return uintptr(size.Of(>smodel.FilterStatus{ + ID: exampleID, + CreatedAt: exampleTime, + UpdatedAt: exampleTime, + FilterID: exampleID, + StatusID: exampleID, + })) +} + func sizeofFollow() uintptr { return uintptr(size.Of(>smodel.Follow{ ID: exampleID, diff --git a/internal/config/config.go b/internal/config/config.go index c810222a1..ea84a4af7 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -201,6 +201,9 @@ type CacheConfiguration struct { BoostOfIDsMemRatio float64 `name:"boost-of-ids-mem-ratio"` EmojiMemRatio float64 `name:"emoji-mem-ratio"` EmojiCategoryMemRatio float64 `name:"emoji-category-mem-ratio"` + FilterMemRatio float64 `name:"filter-mem-ratio"` + FilterKeywordMemRatio float64 `name:"filter-keyword-mem-ratio"` + FilterStatusMemRatio float64 `name:"filter-status-mem-ratio"` FollowMemRatio float64 `name:"follow-mem-ratio"` FollowIDsMemRatio float64 `name:"follow-ids-mem-ratio"` FollowRequestMemRatio float64 `name:"follow-request-mem-ratio"` diff --git a/internal/config/defaults.go b/internal/config/defaults.go index 78474539f..c98b54b0b 100644 --- a/internal/config/defaults.go +++ b/internal/config/defaults.go @@ -165,6 +165,9 @@ var Defaults = Configuration{ BoostOfIDsMemRatio: 3, EmojiMemRatio: 3, EmojiCategoryMemRatio: 0.1, + FilterMemRatio: 0.5, + FilterKeywordMemRatio: 0.5, + FilterStatusMemRatio: 0.5, FollowMemRatio: 2, FollowIDsMemRatio: 4, FollowRequestMemRatio: 2, diff --git a/internal/config/helpers.gen.go b/internal/config/helpers.gen.go index f458074b1..c5d4c992b 100644 --- a/internal/config/helpers.gen.go +++ b/internal/config/helpers.gen.go @@ -2975,6 +2975,81 @@ func GetCacheEmojiCategoryMemRatio() float64 { return global.GetCacheEmojiCatego // SetCacheEmojiCategoryMemRatio safely sets the value for global configuration 'Cache.EmojiCategoryMemRatio' field func SetCacheEmojiCategoryMemRatio(v float64) { global.SetCacheEmojiCategoryMemRatio(v) } +// GetCacheFilterMemRatio safely fetches the Configuration value for state's 'Cache.FilterMemRatio' field +func (st *ConfigState) GetCacheFilterMemRatio() (v float64) { + st.mutex.RLock() + v = st.config.Cache.FilterMemRatio + st.mutex.RUnlock() + return +} + +// SetCacheFilterMemRatio safely sets the Configuration value for state's 'Cache.FilterMemRatio' field +func (st *ConfigState) SetCacheFilterMemRatio(v float64) { + st.mutex.Lock() + defer st.mutex.Unlock() + st.config.Cache.FilterMemRatio = v + st.reloadToViper() +} + +// CacheFilterMemRatioFlag returns the flag name for the 'Cache.FilterMemRatio' field +func CacheFilterMemRatioFlag() string { return "cache-filter-mem-ratio" } + +// GetCacheFilterMemRatio safely fetches the value for global configuration 'Cache.FilterMemRatio' field +func GetCacheFilterMemRatio() float64 { return global.GetCacheFilterMemRatio() } + +// SetCacheFilterMemRatio safely sets the value for global configuration 'Cache.FilterMemRatio' field +func SetCacheFilterMemRatio(v float64) { global.SetCacheFilterMemRatio(v) } + +// GetCacheFilterKeywordMemRatio safely fetches the Configuration value for state's 'Cache.FilterKeywordMemRatio' field +func (st *ConfigState) GetCacheFilterKeywordMemRatio() (v float64) { + st.mutex.RLock() + v = st.config.Cache.FilterKeywordMemRatio + st.mutex.RUnlock() + return +} + +// SetCacheFilterKeywordMemRatio safely sets the Configuration value for state's 'Cache.FilterKeywordMemRatio' field +func (st *ConfigState) SetCacheFilterKeywordMemRatio(v float64) { + st.mutex.Lock() + defer st.mutex.Unlock() + st.config.Cache.FilterKeywordMemRatio = v + st.reloadToViper() +} + +// CacheFilterKeywordMemRatioFlag returns the flag name for the 'Cache.FilterKeywordMemRatio' field +func CacheFilterKeywordMemRatioFlag() string { return "cache-filter-keyword-mem-ratio" } + +// GetCacheFilterKeywordMemRatio safely fetches the value for global configuration 'Cache.FilterKeywordMemRatio' field +func GetCacheFilterKeywordMemRatio() float64 { return global.GetCacheFilterKeywordMemRatio() } + +// SetCacheFilterKeywordMemRatio safely sets the value for global configuration 'Cache.FilterKeywordMemRatio' field +func SetCacheFilterKeywordMemRatio(v float64) { global.SetCacheFilterKeywordMemRatio(v) } + +// GetCacheFilterStatusMemRatio safely fetches the Configuration value for state's 'Cache.FilterStatusMemRatio' field +func (st *ConfigState) GetCacheFilterStatusMemRatio() (v float64) { + st.mutex.RLock() + v = st.config.Cache.FilterStatusMemRatio + st.mutex.RUnlock() + return +} + +// SetCacheFilterStatusMemRatio safely sets the Configuration value for state's 'Cache.FilterStatusMemRatio' field +func (st *ConfigState) SetCacheFilterStatusMemRatio(v float64) { + st.mutex.Lock() + defer st.mutex.Unlock() + st.config.Cache.FilterStatusMemRatio = v + st.reloadToViper() +} + +// CacheFilterStatusMemRatioFlag returns the flag name for the 'Cache.FilterStatusMemRatio' field +func CacheFilterStatusMemRatioFlag() string { return "cache-filter-status-mem-ratio" } + +// GetCacheFilterStatusMemRatio safely fetches the value for global configuration 'Cache.FilterStatusMemRatio' field +func GetCacheFilterStatusMemRatio() float64 { return global.GetCacheFilterStatusMemRatio() } + +// SetCacheFilterStatusMemRatio safely sets the value for global configuration 'Cache.FilterStatusMemRatio' field +func SetCacheFilterStatusMemRatio(v float64) { global.SetCacheFilterStatusMemRatio(v) } + // GetCacheFollowMemRatio safely fetches the Configuration value for state's 'Cache.FollowMemRatio' field func (st *ConfigState) GetCacheFollowMemRatio() (v float64) { st.mutex.RLock() diff --git a/internal/db/bundb/bundb.go b/internal/db/bundb/bundb.go index 4ecbec7b9..c49da272b 100644 --- a/internal/db/bundb/bundb.go +++ b/internal/db/bundb/bundb.go @@ -62,6 +62,7 @@ type DBService struct { db.Emoji db.HeaderFilter db.Instance + db.Filter db.List db.Marker db.Media @@ -200,6 +201,10 @@ func NewBunDBService(ctx context.Context, state *state.State) (db.DB, error) { db: db, state: state, }, + Filter: &filterDB{ + db: db, + state: state, + }, List: &listDB{ db: db, state: state, diff --git a/internal/db/bundb/filter.go b/internal/db/bundb/filter.go new file mode 100644 index 000000000..bcd572f34 --- /dev/null +++ b/internal/db/bundb/filter.go @@ -0,0 +1,339 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package bundb + +import ( + "context" + "slices" + "time" + + "github.com/superseriousbusiness/gotosocial/internal/gtscontext" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/state" + "github.com/superseriousbusiness/gotosocial/internal/util" + "github.com/uptrace/bun" +) + +type filterDB struct { + db *bun.DB + state *state.State +} + +func (f *filterDB) GetFilterByID(ctx context.Context, id string) (*gtsmodel.Filter, error) { + filter, err := f.state.Caches.GTS.Filter.LoadOne( + "ID", + func() (*gtsmodel.Filter, error) { + var filter gtsmodel.Filter + err := f.db. + NewSelect(). + Model(&filter). + Where("? = ?", bun.Ident("id"), id). + Scan(ctx) + return &filter, err + }, + id, + ) + if err != nil { + // already processed + return nil, err + } + + if !gtscontext.Barebones(ctx) { + if err := f.populateFilter(ctx, filter); err != nil { + return nil, err + } + } + + return filter, nil +} + +func (f *filterDB) GetFiltersForAccountID(ctx context.Context, accountID string) ([]*gtsmodel.Filter, error) { + // Fetch IDs of all filters owned by this account. + var filterIDs []string + if err := f.db. + NewSelect(). + Model((*gtsmodel.Filter)(nil)). + Column("id"). + Where("? = ?", bun.Ident("account_id"), accountID). + Scan(ctx, &filterIDs); err != nil { + return nil, err + } + if len(filterIDs) == 0 { + return nil, nil + } + + // Get each filter by ID from the cache or DB. + uncachedFilterIDs := make([]string, 0, len(filterIDs)) + filters, err := f.state.Caches.GTS.Filter.Load( + "ID", + func(load func(keyParts ...any) bool) { + for _, id := range filterIDs { + if !load(id) { + uncachedFilterIDs = append(uncachedFilterIDs, id) + } + } + }, + func() ([]*gtsmodel.Filter, error) { + uncachedFilters := make([]*gtsmodel.Filter, 0, len(uncachedFilterIDs)) + if err := f.db. + NewSelect(). + Model(&uncachedFilters). + Where("? IN (?)", bun.Ident("id"), bun.In(uncachedFilterIDs)). + Scan(ctx); err != nil { + return nil, err + } + return uncachedFilters, nil + }, + ) + if err != nil { + return nil, err + } + + // Put the filter structs in the same order as the filter IDs. + util.OrderBy(filters, filterIDs, func(filter *gtsmodel.Filter) string { return filter.ID }) + + if gtscontext.Barebones(ctx) { + return filters, nil + } + + // Populate the filters. Remove any that we can't populate from the return slice. + errs := gtserror.NewMultiError(len(filters)) + filters = slices.DeleteFunc(filters, func(filter *gtsmodel.Filter) bool { + if err := f.populateFilter(ctx, filter); err != nil { + errs.Appendf("error populating filter %s: %w", filter.ID, err) + return true + } + return false + }) + + return filters, errs.Combine() +} + +func (f *filterDB) populateFilter(ctx context.Context, filter *gtsmodel.Filter) error { + var err error + errs := gtserror.NewMultiError(2) + + if filter.Keywords == nil { + // Filter keywords are not set, fetch from the database. + filter.Keywords, err = f.state.DB.GetFilterKeywordsForFilterID( + gtscontext.SetBarebones(ctx), + filter.ID, + ) + if err != nil { + errs.Appendf("error populating filter keywords: %w", err) + } + for i := range filter.Keywords { + filter.Keywords[i].Filter = filter + } + } + + if filter.Statuses == nil { + // Filter statuses are not set, fetch from the database. + filter.Statuses, err = f.state.DB.GetFilterStatusesForFilterID( + gtscontext.SetBarebones(ctx), + filter.ID, + ) + if err != nil { + errs.Appendf("error populating filter statuses: %w", err) + } + for i := range filter.Statuses { + filter.Statuses[i].Filter = filter + } + } + + return errs.Combine() +} + +func (f *filterDB) PutFilter(ctx context.Context, filter *gtsmodel.Filter) error { + // Update database. + if err := f.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { + if _, err := tx. + NewInsert(). + Model(filter). + Exec(ctx); err != nil { + return err + } + + if len(filter.Keywords) > 0 { + if _, err := tx. + NewInsert(). + Model(&filter.Keywords). + Exec(ctx); err != nil { + return err + } + } + + if len(filter.Statuses) > 0 { + if _, err := tx. + NewInsert(). + Model(&filter.Statuses). + Exec(ctx); err != nil { + return err + } + } + + return nil + }); err != nil { + return err + } + + // Update cache. + f.state.Caches.GTS.Filter.Put(filter) + f.state.Caches.GTS.FilterKeyword.Put(filter.Keywords...) + f.state.Caches.GTS.FilterStatus.Put(filter.Statuses...) + + return nil +} + +func (f *filterDB) UpdateFilter( + ctx context.Context, + filter *gtsmodel.Filter, + filterColumns []string, + filterKeywordColumns []string, + deleteFilterKeywordIDs []string, + deleteFilterStatusIDs []string, +) error { + updatedAt := time.Now() + filter.UpdatedAt = updatedAt + for _, filterKeyword := range filter.Keywords { + filterKeyword.UpdatedAt = updatedAt + } + for _, filterStatus := range filter.Statuses { + filterStatus.UpdatedAt = updatedAt + } + + // If we're updating by column, ensure "updated_at" is included. + if len(filterColumns) > 0 { + filterColumns = append(filterColumns, "updated_at") + } + if len(filterKeywordColumns) > 0 { + filterKeywordColumns = append(filterKeywordColumns, "updated_at") + } + + // Update database. + if err := f.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { + if _, err := tx. + NewUpdate(). + Model(filter). + Column(filterColumns...). + Where("? = ?", bun.Ident("id"), filter.ID). + Exec(ctx); err != nil { + return err + } + + if len(filter.Keywords) > 0 { + if _, err := NewUpsert(tx). + Model(&filter.Keywords). + Constraint("id"). + Column(filterKeywordColumns...). + Exec(ctx); err != nil { + return err + } + } + + if len(filter.Statuses) > 0 { + if _, err := tx. + NewInsert(). + Ignore(). + Model(&filter.Statuses). + Exec(ctx); err != nil { + return err + } + } + + if len(deleteFilterKeywordIDs) > 0 { + if _, err := tx. + NewDelete(). + Model((*gtsmodel.FilterKeyword)(nil)). + Where("? = (?)", bun.Ident("id"), bun.In(deleteFilterKeywordIDs)). + Exec(ctx); err != nil { + return err + } + } + + if len(deleteFilterStatusIDs) > 0 { + if _, err := tx. + NewDelete(). + Model((*gtsmodel.FilterStatus)(nil)). + Where("? = (?)", bun.Ident("id"), bun.In(deleteFilterStatusIDs)). + Exec(ctx); err != nil { + return err + } + } + + return nil + }); err != nil { + return err + } + + // Update cache. + f.state.Caches.GTS.Filter.Put(filter) + f.state.Caches.GTS.FilterKeyword.Put(filter.Keywords...) + f.state.Caches.GTS.FilterStatus.Put(filter.Statuses...) + // TODO: (Vyr) replace with cache multi-invalidate call + for _, id := range deleteFilterKeywordIDs { + f.state.Caches.GTS.FilterKeyword.Invalidate("ID", id) + } + for _, id := range deleteFilterStatusIDs { + f.state.Caches.GTS.FilterStatus.Invalidate("ID", id) + } + + return nil +} + +func (f *filterDB) DeleteFilterByID(ctx context.Context, id string) error { + if err := f.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { + // Delete all keywords attached to filter. + if _, err := tx. + NewDelete(). + Model((*gtsmodel.FilterKeyword)(nil)). + Where("? = ?", bun.Ident("filter_id"), id). + Exec(ctx); err != nil { + return err + } + + // Delete all statuses attached to filter. + if _, err := tx. + NewDelete(). + Model((*gtsmodel.FilterStatus)(nil)). + Where("? = ?", bun.Ident("filter_id"), id). + Exec(ctx); err != nil { + return err + } + + // Delete the filter itself. + _, err := tx. + NewDelete(). + Model((*gtsmodel.Filter)(nil)). + Where("? = ?", bun.Ident("id"), id). + Exec(ctx) + return err + }); err != nil { + return err + } + + // Invalidate this filter. + f.state.Caches.GTS.Filter.Invalidate("ID", id) + + // Invalidate all keywords and statuses for this filter. + f.state.Caches.GTS.FilterKeyword.Invalidate("FilterID", id) + f.state.Caches.GTS.FilterStatus.Invalidate("FilterID", id) + + return nil +} diff --git a/internal/db/bundb/filter_test.go b/internal/db/bundb/filter_test.go new file mode 100644 index 000000000..7940b6651 --- /dev/null +++ b/internal/db/bundb/filter_test.go @@ -0,0 +1,252 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package bundb_test + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/suite" + "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/util" +) + +type FilterTestSuite struct { + BunDBStandardTestSuite +} + +// TestFilterCRUD tests CRUD and read-all operations on filters. +func (suite *FilterTestSuite) TestFilterCRUD() { + t := suite.T() + + // Create new example filter with attached keyword. + filter := >smodel.Filter{ + ID: "01HNEJNVZZVXJTRB3FX3K2B1YF", + AccountID: "01HNEJXCPRTJVJY9MV0VVHGD47", + Title: "foss jail", + Action: gtsmodel.FilterActionWarn, + ContextHome: util.Ptr(true), + ContextPublic: util.Ptr(true), + } + filterKeyword := >smodel.FilterKeyword{ + ID: "01HNEK4RW5QEAMG9Y4ET6ST0J4", + AccountID: filter.AccountID, + FilterID: filter.ID, + Keyword: "GNU/Linux", + } + filter.Keywords = []*gtsmodel.FilterKeyword{filterKeyword} + + // Create new cancellable test context. + ctx := context.Background() + ctx, cncl := context.WithCancel(ctx) + defer cncl() + + // Insert the example filter into db. + if err := suite.db.PutFilter(ctx, filter); err != nil { + t.Fatalf("error inserting filter: %v", err) + } + + // Now fetch newly created filter. + check, err := suite.db.GetFilterByID(ctx, filter.ID) + if err != nil { + t.Fatalf("error fetching filter: %v", err) + } + + // Check all expected fields match. + suite.Equal(filter.ID, check.ID) + suite.Equal(filter.AccountID, check.AccountID) + suite.Equal(filter.Title, check.Title) + suite.Equal(filter.Action, check.Action) + suite.Equal(filter.ContextHome, check.ContextHome) + suite.Equal(filter.ContextNotifications, check.ContextNotifications) + suite.Equal(filter.ContextPublic, check.ContextPublic) + suite.Equal(filter.ContextThread, check.ContextThread) + suite.Equal(filter.ContextAccount, check.ContextAccount) + suite.NotZero(check.CreatedAt) + suite.NotZero(check.UpdatedAt) + + suite.Equal(len(filter.Keywords), len(check.Keywords)) + suite.Equal(filter.Keywords[0].ID, check.Keywords[0].ID) + suite.Equal(filter.Keywords[0].AccountID, check.Keywords[0].AccountID) + suite.Equal(filter.Keywords[0].FilterID, check.Keywords[0].FilterID) + suite.Equal(filter.Keywords[0].Keyword, check.Keywords[0].Keyword) + suite.Equal(filter.Keywords[0].FilterID, check.Keywords[0].FilterID) + suite.NotZero(check.Keywords[0].CreatedAt) + suite.NotZero(check.Keywords[0].UpdatedAt) + + suite.Equal(len(filter.Statuses), len(check.Statuses)) + + // Fetch all filters. + all, err := suite.db.GetFiltersForAccountID(ctx, filter.AccountID) + if err != nil { + t.Fatalf("error fetching filters: %v", err) + } + + // Ensure the result contains our example filter. + suite.Len(all, 1) + suite.Equal(filter.ID, all[0].ID) + + suite.Len(all[0].Keywords, 1) + suite.Equal(filter.Keywords[0].ID, all[0].Keywords[0].ID) + + suite.Empty(all[0].Statuses) + + // Update the filter context and add another keyword and a status. + check.ContextNotifications = util.Ptr(true) + + newKeyword := >smodel.FilterKeyword{ + ID: "01HNEMY810E5XKWDDMN5ZRE749", + FilterID: filter.ID, + AccountID: filter.AccountID, + Keyword: "tux", + } + check.Keywords = append(check.Keywords, newKeyword) + + newStatus := >smodel.FilterStatus{ + ID: "01HNEMYD5XE7C8HH8TNCZ76FN2", + FilterID: filter.ID, + AccountID: filter.AccountID, + StatusID: "01HNEKZW34SQZ8PSDQ0Z10NZES", + } + check.Statuses = append(check.Statuses, newStatus) + + if err := suite.db.UpdateFilter(ctx, check, nil, nil, nil, nil); err != nil { + t.Fatalf("error updating filter: %v", err) + } + // Now fetch newly updated filter. + check, err = suite.db.GetFilterByID(ctx, filter.ID) + if err != nil { + t.Fatalf("error fetching updated filter: %v", err) + } + + // Ensure expected fields were modified on check filter. + suite.True(check.UpdatedAt.After(filter.UpdatedAt)) + if suite.NotNil(check.ContextHome) { + suite.True(*check.ContextHome) + } + if suite.NotNil(check.ContextNotifications) { + suite.True(*check.ContextNotifications) + } + if suite.NotNil(check.ContextPublic) { + suite.True(*check.ContextPublic) + } + if suite.NotNil(check.ContextThread) { + suite.False(*check.ContextThread) + } + if suite.NotNil(check.ContextAccount) { + suite.False(*check.ContextAccount) + } + + // Ensure keyword entries were added. + suite.Len(check.Keywords, 2) + checkFilterKeywordIDs := make([]string, 0, 2) + for _, checkFilterKeyword := range check.Keywords { + checkFilterKeywordIDs = append(checkFilterKeywordIDs, checkFilterKeyword.ID) + } + suite.ElementsMatch([]string{filterKeyword.ID, newKeyword.ID}, checkFilterKeywordIDs) + + // Ensure status entry was added. + suite.Len(check.Statuses, 1) + checkFilterStatusIDs := make([]string, 0, 1) + for _, checkFilterStatus := range check.Statuses { + checkFilterStatusIDs = append(checkFilterStatusIDs, checkFilterStatus.ID) + } + suite.ElementsMatch([]string{newStatus.ID}, checkFilterStatusIDs) + + // Update one filter keyword and delete another. Don't change the filter or the filter status. + filterKeyword.WholeWord = util.Ptr(true) + check.Keywords = []*gtsmodel.FilterKeyword{filterKeyword} + check.Statuses = nil + + if err := suite.db.UpdateFilter(ctx, check, nil, nil, []string{newKeyword.ID}, nil); err != nil { + t.Fatalf("error updating filter: %v", err) + } + check, err = suite.db.GetFilterByID(ctx, filter.ID) + if err != nil { + t.Fatalf("error fetching updated filter: %v", err) + } + + // Ensure expected fields were not modified. + suite.Equal(filter.Title, check.Title) + suite.Equal(gtsmodel.FilterActionWarn, check.Action) + if suite.NotNil(check.ContextHome) { + suite.True(*check.ContextHome) + } + if suite.NotNil(check.ContextNotifications) { + suite.True(*check.ContextNotifications) + } + if suite.NotNil(check.ContextPublic) { + suite.True(*check.ContextPublic) + } + if suite.NotNil(check.ContextThread) { + suite.False(*check.ContextThread) + } + if suite.NotNil(check.ContextAccount) { + suite.False(*check.ContextAccount) + } + + // Ensure only changed field of keyword was modified, and other keyword was deleted. + suite.Len(check.Keywords, 1) + suite.Equal(filterKeyword.ID, check.Keywords[0].ID) + suite.Equal("GNU/Linux", check.Keywords[0].Keyword) + if suite.NotNil(check.Keywords[0].WholeWord) { + suite.True(*check.Keywords[0].WholeWord) + } + + // Ensure status entry was not deleted. + suite.Len(check.Statuses, 1) + suite.Equal(newStatus.ID, check.Statuses[0].ID) + + // Add another status entry for the same status ID. It should be ignored without problems. + redundantStatus := >smodel.FilterStatus{ + ID: "01HQXJ5Y405XZSQ67C2BSQ6HJ0", + FilterID: filter.ID, + AccountID: filter.AccountID, + StatusID: newStatus.StatusID, + } + check.Statuses = []*gtsmodel.FilterStatus{redundantStatus} + if err := suite.db.UpdateFilter(ctx, check, nil, nil, nil, nil); err != nil { + t.Fatalf("error updating filter: %v", err) + } + check, err = suite.db.GetFilterByID(ctx, filter.ID) + if err != nil { + t.Fatalf("error fetching updated filter: %v", err) + } + + // Ensure status entry was not deleted, updated, or duplicated. + suite.Len(check.Statuses, 1) + suite.Equal(newStatus.ID, check.Statuses[0].ID) + suite.Equal(newStatus.StatusID, check.Statuses[0].StatusID) + + // Now delete the filter from the DB. + if err := suite.db.DeleteFilterByID(ctx, filter.ID); err != nil { + t.Fatalf("error deleting filter: %v", err) + } + + // Ensure we can't refetch it. + _, err = suite.db.GetFilterByID(ctx, filter.ID) + if !errors.Is(err, db.ErrNoEntries) { + t.Fatalf("fetching deleted filter returned unexpected error: %v", err) + } +} + +func TestFilterTestSuite(t *testing.T) { + suite.Run(t, new(FilterTestSuite)) +} diff --git a/internal/db/bundb/filterkeyword.go b/internal/db/bundb/filterkeyword.go new file mode 100644 index 000000000..703d58d43 --- /dev/null +++ b/internal/db/bundb/filterkeyword.go @@ -0,0 +1,191 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package bundb + +import ( + "context" + "slices" + "time" + + "github.com/superseriousbusiness/gotosocial/internal/gtscontext" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/util" + "github.com/uptrace/bun" +) + +func (f *filterDB) GetFilterKeywordByID(ctx context.Context, id string) (*gtsmodel.FilterKeyword, error) { + filterKeyword, err := f.state.Caches.GTS.FilterKeyword.LoadOne( + "ID", + func() (*gtsmodel.FilterKeyword, error) { + var filterKeyword gtsmodel.FilterKeyword + err := f.db. + NewSelect(). + Model(&filterKeyword). + Where("? = ?", bun.Ident("id"), id). + Scan(ctx) + return &filterKeyword, err + }, + id, + ) + if err != nil { + return nil, err + } + + if !gtscontext.Barebones(ctx) { + err = f.populateFilterKeyword(ctx, filterKeyword) + if err != nil { + return nil, err + } + } + + return filterKeyword, nil +} + +func (f *filterDB) populateFilterKeyword(ctx context.Context, filterKeyword *gtsmodel.FilterKeyword) error { + if filterKeyword.Filter == nil { + // Filter is not set, fetch from the cache or database. + filter, err := f.state.DB.GetFilterByID( + // Don't populate the filter with all of its keywords and statuses or we'll just end up back here. + gtscontext.SetBarebones(ctx), + filterKeyword.FilterID, + ) + if err != nil { + return err + } + filterKeyword.Filter = filter + } + + return nil +} + +func (f *filterDB) GetFilterKeywordsForFilterID(ctx context.Context, filterID string) ([]*gtsmodel.FilterKeyword, error) { + return f.getFilterKeywords(ctx, "filter_id", filterID) +} + +func (f *filterDB) GetFilterKeywordsForAccountID(ctx context.Context, accountID string) ([]*gtsmodel.FilterKeyword, error) { + return f.getFilterKeywords(ctx, "account_id", accountID) +} + +func (f *filterDB) getFilterKeywords(ctx context.Context, idColumn string, id string) ([]*gtsmodel.FilterKeyword, error) { + var filterKeywordIDs []string + if err := f.db. + NewSelect(). + Model((*gtsmodel.FilterKeyword)(nil)). + Column("id"). + Where("? = ?", bun.Ident(idColumn), id). + Scan(ctx, &filterKeywordIDs); err != nil { + return nil, err + } + if len(filterKeywordIDs) == 0 { + return nil, nil + } + + // Get each filter keyword by ID from the cache or DB. + uncachedFilterKeywordIDs := make([]string, 0, len(filterKeywordIDs)) + filterKeywords, err := f.state.Caches.GTS.FilterKeyword.Load( + "ID", + func(load func(keyParts ...any) bool) { + for _, id := range filterKeywordIDs { + if !load(id) { + uncachedFilterKeywordIDs = append(uncachedFilterKeywordIDs, id) + } + } + }, + func() ([]*gtsmodel.FilterKeyword, error) { + uncachedFilterKeywords := make([]*gtsmodel.FilterKeyword, 0, len(uncachedFilterKeywordIDs)) + if err := f.db. + NewSelect(). + Model(&uncachedFilterKeywords). + Where("? IN (?)", bun.Ident("id"), bun.In(uncachedFilterKeywordIDs)). + Scan(ctx); err != nil { + return nil, err + } + return uncachedFilterKeywords, nil + }, + ) + if err != nil { + return nil, err + } + + // Put the filter keyword structs in the same order as the filter keyword IDs. + util.OrderBy(filterKeywords, filterKeywordIDs, func(filterKeyword *gtsmodel.FilterKeyword) string { + return filterKeyword.ID + }) + + if gtscontext.Barebones(ctx) { + return filterKeywords, nil + } + + // Populate the filter keywords. Remove any that we can't populate from the return slice. + errs := gtserror.NewMultiError(len(filterKeywords)) + filterKeywords = slices.DeleteFunc(filterKeywords, func(filterKeyword *gtsmodel.FilterKeyword) bool { + if err := f.populateFilterKeyword(ctx, filterKeyword); err != nil { + errs.Appendf( + "error populating filter keyword %s: %w", + filterKeyword.ID, + err, + ) + return true + } + return false + }) + + return filterKeywords, errs.Combine() +} + +func (f *filterDB) PutFilterKeyword(ctx context.Context, filterKeyword *gtsmodel.FilterKeyword) error { + return f.state.Caches.GTS.FilterKeyword.Store(filterKeyword, func() error { + _, err := f.db. + NewInsert(). + Model(filterKeyword). + Exec(ctx) + return err + }) +} + +func (f *filterDB) UpdateFilterKeyword(ctx context.Context, filterKeyword *gtsmodel.FilterKeyword, columns ...string) error { + filterKeyword.UpdatedAt = time.Now() + if len(columns) > 0 { + columns = append(columns, "updated_at") + } + + return f.state.Caches.GTS.FilterKeyword.Store(filterKeyword, func() error { + _, err := f.db. + NewUpdate(). + Model(filterKeyword). + Where("? = ?", bun.Ident("id"), filterKeyword.ID). + Column(columns...). + Exec(ctx) + return err + }) +} + +func (f *filterDB) DeleteFilterKeywordByID(ctx context.Context, id string) error { + if _, err := f.db. + NewDelete(). + Model((*gtsmodel.FilterKeyword)(nil)). + Where("? = ?", bun.Ident("id"), id). + Exec(ctx); err != nil { + return err + } + + f.state.Caches.GTS.FilterKeyword.Invalidate("ID", id) + + return nil +} diff --git a/internal/db/bundb/filterkeyword_test.go b/internal/db/bundb/filterkeyword_test.go new file mode 100644 index 000000000..91c8d192c --- /dev/null +++ b/internal/db/bundb/filterkeyword_test.go @@ -0,0 +1,143 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package bundb_test + +import ( + "context" + "errors" + + "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/util" +) + +// TestFilterKeywordCRUD tests CRUD and read-all operations on filter keywords. +func (suite *FilterTestSuite) TestFilterKeywordCRUD() { + t := suite.T() + + // Create new filter. + filter := >smodel.Filter{ + ID: "01HNEJNVZZVXJTRB3FX3K2B1YF", + AccountID: "01HNEJXCPRTJVJY9MV0VVHGD47", + Title: "foss jail", + Action: gtsmodel.FilterActionWarn, + ContextHome: util.Ptr(true), + ContextPublic: util.Ptr(true), + } + + // Create new cancellable test context. + ctx := context.Background() + ctx, cncl := context.WithCancel(ctx) + defer cncl() + + // Insert the new filter into the DB. + err := suite.db.PutFilter(ctx, filter) + if err != nil { + t.Fatalf("error inserting filter: %v", err) + } + + // There should be no filter keywords yet. + all, err := suite.db.GetFilterKeywordsForAccountID(ctx, filter.AccountID) + if err != nil { + t.Fatalf("error fetching filter keywords: %v", err) + } + suite.Empty(all) + + // Add a filter keyword to it. + filterKeyword := >smodel.FilterKeyword{ + ID: "01HNEK4RW5QEAMG9Y4ET6ST0J4", + AccountID: filter.AccountID, + FilterID: filter.ID, + Keyword: "GNU/Linux", + } + + // Insert the new filter keyword into the DB. + err = suite.db.PutFilterKeyword(ctx, filterKeyword) + if err != nil { + t.Fatalf("error inserting filter keyword: %v", err) + } + + // Try to find it again and ensure it has the fields we expect. + check, err := suite.db.GetFilterKeywordByID(ctx, filterKeyword.ID) + if err != nil { + t.Fatalf("error fetching filter keyword: %v", err) + } + suite.Equal(filterKeyword.ID, check.ID) + suite.NotZero(check.CreatedAt) + suite.NotZero(check.UpdatedAt) + suite.Equal(filterKeyword.AccountID, check.AccountID) + suite.Equal(filterKeyword.FilterID, check.FilterID) + suite.Equal(filterKeyword.Keyword, check.Keyword) + suite.Equal(filterKeyword.WholeWord, check.WholeWord) + + // Loading filter keywords by account ID should find the one we inserted. + all, err = suite.db.GetFilterKeywordsForAccountID(ctx, filter.AccountID) + if err != nil { + t.Fatalf("error fetching filter keywords: %v", err) + } + suite.Len(all, 1) + suite.Equal(filterKeyword.ID, all[0].ID) + + // Loading filter keywords by filter ID should also find the one we inserted. + all, err = suite.db.GetFilterKeywordsForFilterID(ctx, filter.ID) + if err != nil { + t.Fatalf("error fetching filter keywords: %v", err) + } + suite.Len(all, 1) + suite.Equal(filterKeyword.ID, all[0].ID) + + // Modify the filter keyword. + filterKeyword.WholeWord = util.Ptr(true) + err = suite.db.UpdateFilterKeyword(ctx, filterKeyword) + if err != nil { + t.Fatalf("error updating filter keyword: %v", err) + } + + // Try to find it again and ensure it has the updated fields we expect. + check, err = suite.db.GetFilterKeywordByID(ctx, filterKeyword.ID) + if err != nil { + t.Fatalf("error fetching filter keyword: %v", err) + } + suite.Equal(filterKeyword.ID, check.ID) + suite.NotZero(check.CreatedAt) + suite.True(check.UpdatedAt.After(check.CreatedAt)) + suite.Equal(filterKeyword.AccountID, check.AccountID) + suite.Equal(filterKeyword.FilterID, check.FilterID) + suite.Equal(filterKeyword.Keyword, check.Keyword) + suite.Equal(filterKeyword.WholeWord, check.WholeWord) + + // Delete the filter keyword from the DB. + err = suite.db.DeleteFilterKeywordByID(ctx, filter.ID) + if err != nil { + t.Fatalf("error deleting filter keyword: %v", err) + } + + // Ensure we can't refetch it. + check, err = suite.db.GetFilterKeywordByID(ctx, filter.ID) + if !errors.Is(err, db.ErrNoEntries) { + t.Fatalf("fetching deleted filter keyword returned unexpected error: %v", err) + } + suite.Nil(check) + + // Ensure the filter itself is still there. + checkFilter, err := suite.db.GetFilterByID(ctx, filter.ID) + if err != nil { + t.Fatalf("error fetching filter: %v", err) + } + suite.Equal(filter.ID, checkFilter.ID) +} diff --git a/internal/db/bundb/filterstatus.go b/internal/db/bundb/filterstatus.go new file mode 100644 index 000000000..1e98f5958 --- /dev/null +++ b/internal/db/bundb/filterstatus.go @@ -0,0 +1,191 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package bundb + +import ( + "context" + "slices" + "time" + + "github.com/superseriousbusiness/gotosocial/internal/gtscontext" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/util" + "github.com/uptrace/bun" +) + +func (f *filterDB) GetFilterStatusByID(ctx context.Context, id string) (*gtsmodel.FilterStatus, error) { + filterStatus, err := f.state.Caches.GTS.FilterStatus.LoadOne( + "ID", + func() (*gtsmodel.FilterStatus, error) { + var filterStatus gtsmodel.FilterStatus + err := f.db. + NewSelect(). + Model(&filterStatus). + Where("? = ?", bun.Ident("id"), id). + Scan(ctx) + return &filterStatus, err + }, + id, + ) + if err != nil { + return nil, err + } + + if !gtscontext.Barebones(ctx) { + err = f.populateFilterStatus(ctx, filterStatus) + if err != nil { + return nil, err + } + } + + return filterStatus, nil +} + +func (f *filterDB) populateFilterStatus(ctx context.Context, filterStatus *gtsmodel.FilterStatus) error { + if filterStatus.Filter == nil { + // Filter is not set, fetch from the cache or database. + filter, err := f.state.DB.GetFilterByID( + // Don't populate the filter with all of its keywords and statuses or we'll just end up back here. + gtscontext.SetBarebones(ctx), + filterStatus.FilterID, + ) + if err != nil { + return err + } + filterStatus.Filter = filter + } + + return nil +} + +func (f *filterDB) GetFilterStatusesForFilterID(ctx context.Context, filterID string) ([]*gtsmodel.FilterStatus, error) { + return f.getFilterStatuses(ctx, "filter_id", filterID) +} + +func (f *filterDB) GetFilterStatusesForAccountID(ctx context.Context, accountID string) ([]*gtsmodel.FilterStatus, error) { + return f.getFilterStatuses(ctx, "account_id", accountID) +} + +func (f *filterDB) getFilterStatuses(ctx context.Context, idColumn string, id string) ([]*gtsmodel.FilterStatus, error) { + var filterStatusIDs []string + if err := f.db. + NewSelect(). + Model((*gtsmodel.FilterStatus)(nil)). + Column("id"). + Where("? = ?", bun.Ident(idColumn), id). + Scan(ctx, &filterStatusIDs); err != nil { + return nil, err + } + if len(filterStatusIDs) == 0 { + return nil, nil + } + + // Get each filter status by ID from the cache or DB. + uncachedFilterStatusIDs := make([]string, 0, len(filterStatusIDs)) + filterStatuses, err := f.state.Caches.GTS.FilterStatus.Load( + "ID", + func(load func(keyParts ...any) bool) { + for _, id := range filterStatusIDs { + if !load(id) { + uncachedFilterStatusIDs = append(uncachedFilterStatusIDs, id) + } + } + }, + func() ([]*gtsmodel.FilterStatus, error) { + uncachedFilterStatuses := make([]*gtsmodel.FilterStatus, 0, len(uncachedFilterStatusIDs)) + if err := f.db. + NewSelect(). + Model(&uncachedFilterStatuses). + Where("? IN (?)", bun.Ident("id"), bun.In(uncachedFilterStatusIDs)). + Scan(ctx); err != nil { + return nil, err + } + return uncachedFilterStatuses, nil + }, + ) + if err != nil { + return nil, err + } + + // Put the filter status structs in the same order as the filter status IDs. + util.OrderBy(filterStatuses, filterStatusIDs, func(filterStatus *gtsmodel.FilterStatus) string { + return filterStatus.ID + }) + + if gtscontext.Barebones(ctx) { + return filterStatuses, nil + } + + // Populate the filter statuses. Remove any that we can't populate from the return slice. + errs := gtserror.NewMultiError(len(filterStatuses)) + filterStatuses = slices.DeleteFunc(filterStatuses, func(filterStatus *gtsmodel.FilterStatus) bool { + if err := f.populateFilterStatus(ctx, filterStatus); err != nil { + errs.Appendf( + "error populating filter status %s: %w", + filterStatus.ID, + err, + ) + return true + } + return false + }) + + return filterStatuses, errs.Combine() +} + +func (f *filterDB) PutFilterStatus(ctx context.Context, filterStatus *gtsmodel.FilterStatus) error { + return f.state.Caches.GTS.FilterStatus.Store(filterStatus, func() error { + _, err := f.db. + NewInsert(). + Model(filterStatus). + Exec(ctx) + return err + }) +} + +func (f *filterDB) UpdateFilterStatus(ctx context.Context, filterStatus *gtsmodel.FilterStatus, columns ...string) error { + filterStatus.UpdatedAt = time.Now() + if len(columns) > 0 { + columns = append(columns, "updated_at") + } + + return f.state.Caches.GTS.FilterStatus.Store(filterStatus, func() error { + _, err := f.db. + NewUpdate(). + Model(filterStatus). + Where("? = ?", bun.Ident("id"), filterStatus.ID). + Column(columns...). + Exec(ctx) + return err + }) +} + +func (f *filterDB) DeleteFilterStatusByID(ctx context.Context, id string) error { + if _, err := f.db. + NewDelete(). + Model((*gtsmodel.FilterStatus)(nil)). + Where("? = ?", bun.Ident("id"), id). + Exec(ctx); err != nil { + return err + } + + f.state.Caches.GTS.FilterStatus.Invalidate("ID", id) + + return nil +} diff --git a/internal/db/bundb/filterstatus_test.go b/internal/db/bundb/filterstatus_test.go new file mode 100644 index 000000000..48ddb1bed --- /dev/null +++ b/internal/db/bundb/filterstatus_test.go @@ -0,0 +1,122 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package bundb_test + +import ( + "context" + "errors" + + "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/util" +) + +// TestFilterStatusCRD tests CRD (no U) and read-all operations on filter statuses. +func (suite *FilterTestSuite) TestFilterStatusCRD() { + t := suite.T() + + // Create new filter. + filter := >smodel.Filter{ + ID: "01HNEJNVZZVXJTRB3FX3K2B1YF", + AccountID: "01HNEJXCPRTJVJY9MV0VVHGD47", + Title: "foss jail", + Action: gtsmodel.FilterActionWarn, + ContextHome: util.Ptr(true), + ContextPublic: util.Ptr(true), + } + + // Create new cancellable test context. + ctx := context.Background() + ctx, cncl := context.WithCancel(ctx) + defer cncl() + + // Insert the new filter into the DB. + err := suite.db.PutFilter(ctx, filter) + if err != nil { + t.Fatalf("error inserting filter: %v", err) + } + + // There should be no filter statuses yet. + all, err := suite.db.GetFilterStatusesForAccountID(ctx, filter.AccountID) + if err != nil { + t.Fatalf("error fetching filter statuses: %v", err) + } + suite.Empty(all) + + // Add a filter status to it. + filterStatus := >smodel.FilterStatus{ + ID: "01HNEK4RW5QEAMG9Y4ET6ST0J4", + AccountID: filter.AccountID, + FilterID: filter.ID, + StatusID: "01HQXGMQ3QFXRT4GX9WNQ8KC0X", + } + + // Insert the new filter status into the DB. + err = suite.db.PutFilterStatus(ctx, filterStatus) + if err != nil { + t.Fatalf("error inserting filter status: %v", err) + } + + // Try to find it again and ensure it has the fields we expect. + check, err := suite.db.GetFilterStatusByID(ctx, filterStatus.ID) + if err != nil { + t.Fatalf("error fetching filter status: %v", err) + } + suite.Equal(filterStatus.ID, check.ID) + suite.NotZero(check.CreatedAt) + suite.NotZero(check.UpdatedAt) + suite.Equal(filterStatus.AccountID, check.AccountID) + suite.Equal(filterStatus.FilterID, check.FilterID) + suite.Equal(filterStatus.StatusID, check.StatusID) + + // Loading filter statuses by account ID should find the one we inserted. + all, err = suite.db.GetFilterStatusesForAccountID(ctx, filter.AccountID) + if err != nil { + t.Fatalf("error fetching filter statuses: %v", err) + } + suite.Len(all, 1) + suite.Equal(filterStatus.ID, all[0].ID) + + // Loading filter statuses by filter ID should also find the one we inserted. + all, err = suite.db.GetFilterStatusesForFilterID(ctx, filter.ID) + if err != nil { + t.Fatalf("error fetching filter statuses: %v", err) + } + suite.Len(all, 1) + suite.Equal(filterStatus.ID, all[0].ID) + + // Delete the filter status from the DB. + err = suite.db.DeleteFilterStatusByID(ctx, filter.ID) + if err != nil { + t.Fatalf("error deleting filter status: %v", err) + } + + // Ensure we can't refetch it. + check, err = suite.db.GetFilterStatusByID(ctx, filter.ID) + if !errors.Is(err, db.ErrNoEntries) { + t.Fatalf("fetching deleted filter status returned unexpected error: %v", err) + } + suite.Nil(check) + + // Ensure the filter itself is still there. + checkFilter, err := suite.db.GetFilterByID(ctx, filter.ID) + if err != nil { + t.Fatalf("error fetching filter: %v", err) + } + suite.Equal(filter.ID, checkFilter.ID) +} diff --git a/internal/db/bundb/migrations/20240126064004_add_filters.go b/internal/db/bundb/migrations/20240126064004_add_filters.go new file mode 100644 index 000000000..3ad22f9d8 --- /dev/null +++ b/internal/db/bundb/migrations/20240126064004_add_filters.go @@ -0,0 +1,97 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package migrations + +import ( + "context" + + gtsmodel "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/uptrace/bun" +) + +func init() { + up := func(ctx context.Context, db *bun.DB) error { + return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { + // Filter table. + if _, err := tx. + NewCreateTable(). + Model(>smodel.Filter{}). + IfNotExists(). + Exec(ctx); err != nil { + return err + } + + // Filter keyword table. + if _, err := tx. + NewCreateTable(). + Model(>smodel.FilterKeyword{}). + IfNotExists(). + Exec(ctx); err != nil { + return err + } + + // Filter status table. + if _, err := tx. + NewCreateTable(). + Model(>smodel.FilterStatus{}). + IfNotExists(). + Exec(ctx); err != nil { + return err + } + + // Add indexes to the filter tables. + for table, indexes := range map[string]map[string][]string{ + "filters": { + "filters_account_id_idx": {"account_id"}, + }, + "filter_keywords": { + "filter_keywords_account_id_idx": {"account_id"}, + "filter_keywords_filter_id_idx": {"filter_id"}, + }, + "filter_statuses": { + "filter_statuses_account_id_idx": {"account_id"}, + "filter_statuses_filter_id_idx": {"filter_id"}, + }, + } { + for index, columns := range indexes { + if _, err := tx. + NewCreateIndex(). + Table(table). + Index(index). + Column(columns...). + IfNotExists(). + Exec(ctx); err != nil { + return err + } + } + } + + return nil + }) + } + + down := func(ctx context.Context, db *bun.DB) error { + return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { + return nil + }) + } + + if err := Migrations.Register(up, down); err != nil { + panic(err) + } +} diff --git a/internal/db/bundb/upsert.go b/internal/db/bundb/upsert.go new file mode 100644 index 000000000..34724446c --- /dev/null +++ b/internal/db/bundb/upsert.go @@ -0,0 +1,230 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package bundb + +import ( + "context" + "database/sql" + "reflect" + "strings" + + "github.com/superseriousbusiness/gotosocial/internal/gtserror" + "github.com/uptrace/bun" + "github.com/uptrace/bun/dialect" +) + +// UpsertQuery is a wrapper around an insert query that can update if an insert fails. +// Doesn't implement the full set of Bun query methods, but we can add more if we need them. +// See https://bun.uptrace.dev/guide/query-insert.html#upsert +type UpsertQuery struct { + db bun.IDB + model interface{} + constraints []string + columns []string +} + +func NewUpsert(idb bun.IDB) *UpsertQuery { + // note: passing in rawtx as conn iface so no double query-hook + // firing when passed through the bun.Tx.Query___() functions. + return &UpsertQuery{db: idb} +} + +// Model sets the model or models to upsert. +func (u *UpsertQuery) Model(model interface{}) *UpsertQuery { + u.model = model + return u +} + +// Constraint sets the columns or indexes that are used to check for conflicts. +// This is required. +func (u *UpsertQuery) Constraint(constraints ...string) *UpsertQuery { + u.constraints = constraints + return u +} + +// Column sets the columns to update if an insert does't happen. +// If empty, all columns not being used for constraints will be updated. +// Cannot overlap with Constraint. +func (u *UpsertQuery) Column(columns ...string) *UpsertQuery { + u.columns = columns + return u +} + +// insertDialect errors if we're using a dialect in which we don't know how to upsert. +func (u *UpsertQuery) insertDialect() error { + dialectName := u.db.Dialect().Name() + switch dialectName { + case dialect.PG, dialect.SQLite: + return nil + default: + // FUTURE: MySQL has its own variation on upserts, but the syntax is different. + return gtserror.Newf("UpsertQuery: upsert not supported by SQL dialect: %s", dialectName) + } +} + +// insertConstraints checks that we have constraints and returns them. +func (u *UpsertQuery) insertConstraints() ([]string, error) { + if len(u.constraints) == 0 { + return nil, gtserror.New("UpsertQuery: upserts require at least one constraint column or index, none provided") + } + return u.constraints, nil +} + +// insertColumns returns the non-constraint columns we'll be updating. +func (u *UpsertQuery) insertColumns(constraints []string) ([]string, error) { + // Constraints as a set. + constraintSet := make(map[string]struct{}, len(constraints)) + for _, constraint := range constraints { + constraintSet[constraint] = struct{}{} + } + + var columns []string + var err error + if len(u.columns) == 0 { + columns, err = u.insertColumnsDefault(constraintSet) + } else { + columns, err = u.insertColumnsSpecified(constraintSet) + } + if err != nil { + return nil, err + } + if len(columns) == 0 { + return nil, gtserror.New("UpsertQuery: there are no columns to update when upserting") + } + + return columns, nil +} + +// hasElem returns whether the type has an element and can call [reflect.Type.Elem] without panicking. +func hasElem(modelType reflect.Type) bool { + switch modelType.Kind() { + case reflect.Array, reflect.Chan, reflect.Map, reflect.Pointer, reflect.Slice: + return true + default: + return false + } +} + +// insertColumnsDefault returns all non-constraint columns from the model schema. +func (u *UpsertQuery) insertColumnsDefault(constraintSet map[string]struct{}) ([]string, error) { + // Get underlying struct type. + modelType := reflect.TypeOf(u.model) + for hasElem(modelType) { + modelType = modelType.Elem() + } + + table := u.db.Dialect().Tables().Get(modelType) + if table == nil { + return nil, gtserror.Newf("UpsertQuery: couldn't find the table schema for model: %v", u.model) + } + + columns := make([]string, 0, len(u.columns)) + for _, field := range table.Fields { + column := field.Name + if _, overlaps := constraintSet[column]; !overlaps { + columns = append(columns, column) + } + } + + return columns, nil +} + +// insertColumnsSpecified ensures constraints and specified columns to update don't overlap. +func (u *UpsertQuery) insertColumnsSpecified(constraintSet map[string]struct{}) ([]string, error) { + overlapping := make([]string, 0, min(len(u.constraints), len(u.columns))) + for _, column := range u.columns { + if _, overlaps := constraintSet[column]; overlaps { + overlapping = append(overlapping, column) + } + } + + if len(overlapping) > 0 { + return nil, gtserror.Newf( + "UpsertQuery: the following columns can't be used for both constraints and columns to update: %s", + strings.Join(overlapping, ", "), + ) + } + + return u.columns, nil +} + +// insert tries to create a Bun insert query from an upsert query. +func (u *UpsertQuery) insertQuery() (*bun.InsertQuery, error) { + var err error + + err = u.insertDialect() + if err != nil { + return nil, err + } + + constraints, err := u.insertConstraints() + if err != nil { + return nil, err + } + + columns, err := u.insertColumns(constraints) + if err != nil { + return nil, err + } + + // Build the parts of the query that need us to generate SQL. + constraintIDPlaceholders := make([]string, 0, len(constraints)) + constraintIDs := make([]interface{}, 0, len(constraints)) + for _, constraint := range constraints { + constraintIDPlaceholders = append(constraintIDPlaceholders, "?") + constraintIDs = append(constraintIDs, bun.Ident(constraint)) + } + onSQL := "conflict (" + strings.Join(constraintIDPlaceholders, ", ") + ") do update" + + setClauses := make([]string, 0, len(columns)) + setIDs := make([]interface{}, 0, 2*len(columns)) + for _, column := range columns { + // "excluded" is a special table that contains only the row involved in a conflict. + setClauses = append(setClauses, "? = excluded.?") + setIDs = append(setIDs, bun.Ident(column), bun.Ident(column)) + } + setSQL := strings.Join(setClauses, ", ") + + insertQuery := u.db. + NewInsert(). + Model(u.model). + On(onSQL, constraintIDs...). + Set(setSQL, setIDs...) + + return insertQuery, nil +} + +// Exec builds a Bun insert query from the upsert query, and executes it. +func (u *UpsertQuery) Exec(ctx context.Context, dest ...interface{}) (sql.Result, error) { + insertQuery, err := u.insertQuery() + if err != nil { + return nil, err + } + + return insertQuery.Exec(ctx, dest...) +} + +// Scan builds a Bun insert query from the upsert query, and scans it. +func (u *UpsertQuery) Scan(ctx context.Context, dest ...interface{}) error { + insertQuery, err := u.insertQuery() + if err != nil { + return err + } + + return insertQuery.Scan(ctx, dest...) +} diff --git a/internal/db/db.go b/internal/db/db.go index 361687e94..f23324777 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -32,6 +32,7 @@ type DB interface { Emoji HeaderFilter Instance + Filter List Marker Media diff --git a/internal/db/filter.go b/internal/db/filter.go new file mode 100644 index 000000000..18943b4f9 --- /dev/null +++ b/internal/db/filter.go @@ -0,0 +1,101 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package db + +import ( + "context" + + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +) + +// Filter contains methods for creating, reading, updating, and deleting filters and their keyword and status entries. +type Filter interface { + // + + // GetFilterByID gets one filter with the given id. + GetFilterByID(ctx context.Context, id string) (*gtsmodel.Filter, error) + + // GetFiltersForAccountID gets all filters owned by the given accountID. + GetFiltersForAccountID(ctx context.Context, accountID string) ([]*gtsmodel.Filter, error) + + // PutFilter puts a new filter in the database, adding any attached keywords or statuses. + // It uses a transaction to ensure no partial updates. + PutFilter(ctx context.Context, filter *gtsmodel.Filter) error + + // UpdateFilter updates the given filter, + // upserts any attached keywords and inserts any new statuses (existing statuses cannot be updated), + // and deletes indicated filter keywords and statuses by ID. + // It uses a transaction to ensure no partial updates. + // The column lists are optional; if not specified, all columns will be updated. + UpdateFilter( + ctx context.Context, + filter *gtsmodel.Filter, + filterColumns []string, + filterKeywordColumns []string, + deleteFilterKeywordIDs []string, + deleteFilterStatusIDs []string, + ) error + + // DeleteFilterByID deletes one filter with the given ID. + // It uses a transaction to ensure no partial updates. + DeleteFilterByID(ctx context.Context, id string) error + + // + + // + + // GetFilterKeywordByID gets one filter keyword with the given ID. + GetFilterKeywordByID(ctx context.Context, id string) (*gtsmodel.FilterKeyword, error) + + // GetFilterKeywordsForFilterID gets filter keywords from the given filterID. + GetFilterKeywordsForFilterID(ctx context.Context, filterID string) ([]*gtsmodel.FilterKeyword, error) + + // GetFilterKeywordsForAccountID gets filter keywords from the given accountID. + GetFilterKeywordsForAccountID(ctx context.Context, accountID string) ([]*gtsmodel.FilterKeyword, error) + + // PutFilterKeyword inserts a single filter keyword into the database. + PutFilterKeyword(ctx context.Context, filterKeyword *gtsmodel.FilterKeyword) error + + // UpdateFilterKeyword updates the given filter keyword. + // Columns is optional, if not specified all will be updated. + UpdateFilterKeyword(ctx context.Context, filterKeyword *gtsmodel.FilterKeyword, columns ...string) error + + // DeleteFilterKeywordByID deletes one filter keyword with the given id. + DeleteFilterKeywordByID(ctx context.Context, id string) error + + // + + // + + // GetFilterStatusByID gets one filter status with the given ID. + GetFilterStatusByID(ctx context.Context, id string) (*gtsmodel.FilterStatus, error) + + // GetFilterStatusesForFilterID gets filter statuses from the given filterID. + GetFilterStatusesForFilterID(ctx context.Context, filterID string) ([]*gtsmodel.FilterStatus, error) + + // GetFilterStatusesForAccountID gets filter keywords from the given accountID. + GetFilterStatusesForAccountID(ctx context.Context, accountID string) ([]*gtsmodel.FilterStatus, error) + + // PutFilterStatus inserts a single filter status into the database. + PutFilterStatus(ctx context.Context, filterStatus *gtsmodel.FilterStatus) error + + // DeleteFilterStatusByID deletes one filter status with the given id. + DeleteFilterStatusByID(ctx context.Context, id string) error + + // +} diff --git a/internal/gtsmodel/filter.go b/internal/gtsmodel/filter.go new file mode 100644 index 000000000..db0a15dfd --- /dev/null +++ b/internal/gtsmodel/filter.go @@ -0,0 +1,71 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package gtsmodel + +import "time" + +// Filter stores a filter created by a local account. +type Filter struct { + ID string `bun:"type:CHAR(26),pk,nullzero,notnull,unique"` // id of this item in the database + CreatedAt time.Time `bun:"type:timestamptz,nullzero,notnull,default:current_timestamp"` // when was item created + UpdatedAt time.Time `bun:"type:timestamptz,nullzero,notnull,default:current_timestamp"` // when was item last updated + ExpiresAt time.Time `bun:"type:timestamptz,nullzero"` // Time filter should expire. If null, should not expire. + AccountID string `bun:"type:CHAR(26),notnull,nullzero"` // ID of the local account that created the filter. + Title string `bun:",nullzero,notnull,unique"` // The name of the filter. + Action FilterAction `bun:",nullzero,notnull"` // The action to take. + Keywords []*FilterKeyword `bun:"-"` // Keywords for this filter. + Statuses []*FilterStatus `bun:"-"` // Statuses for this filter. + ContextHome *bool `bun:",nullzero,notnull,default:false"` // Apply filter to home timeline and lists. + ContextNotifications *bool `bun:",nullzero,notnull,default:false"` // Apply filter to notifications. + ContextPublic *bool `bun:",nullzero,notnull,default:false"` // Apply filter to home timeline and lists. + ContextThread *bool `bun:",nullzero,notnull,default:false"` // Apply filter when viewing a status's associated thread. + ContextAccount *bool `bun:",nullzero,notnull,default:false"` // Apply filter when viewing an account profile. +} + +// FilterKeyword stores a single keyword to filter statuses against. +type FilterKeyword struct { + ID string `bun:"type:CHAR(26),pk,nullzero,notnull,unique"` // id of this item in the database + CreatedAt time.Time `bun:"type:timestamptz,nullzero,notnull,default:current_timestamp"` // when was item created + UpdatedAt time.Time `bun:"type:timestamptz,nullzero,notnull,default:current_timestamp"` // when was item last updated + AccountID string `bun:"type:CHAR(26),notnull,nullzero"` // ID of the local account that created the filter keyword. + FilterID string `bun:"type:CHAR(26),notnull,nullzero,unique:filter_keywords_filter_id_keyword_uniq"` // ID of the filter that this keyword belongs to. + Filter *Filter `bun:"-"` // Filter corresponding to FilterID + Keyword string `bun:",nullzero,notnull,unique:filter_keywords_filter_id_keyword_uniq"` // The keyword or phrase to filter against. + WholeWord *bool `bun:",nullzero,notnull,default:false"` // Should the filter consider word boundaries? +} + +// FilterStatus stores a single status to filter. +type FilterStatus struct { + ID string `bun:"type:CHAR(26),pk,nullzero,notnull,unique"` // id of this item in the database + CreatedAt time.Time `bun:"type:timestamptz,nullzero,notnull,default:current_timestamp"` // when was item created + UpdatedAt time.Time `bun:"type:timestamptz,nullzero,notnull,default:current_timestamp"` // when was item last updated + AccountID string `bun:"type:CHAR(26),notnull,nullzero"` // ID of the local account that created the filter keyword. + FilterID string `bun:"type:CHAR(26),notnull,nullzero,unique:filter_statuses_filter_id_status_id_uniq"` // ID of the filter that this keyword belongs to. + Filter *Filter `bun:"-"` // Filter corresponding to FilterID + StatusID string `bun:"type:CHAR(26),notnull,nullzero,unique:filter_statuses_filter_id_status_id_uniq"` // ID of the status to filter. +} + +// FilterAction represents the action to take on a filtered status. +type FilterAction string + +const ( + // FilterActionWarn means that the status should be shown behind a warning. + FilterActionWarn FilterAction = "warn" + // FilterActionHide means that the status should be removed from timeline results entirely. + FilterActionHide FilterAction = "hide" +) diff --git a/internal/processing/filters/v1/convert.go b/internal/processing/filters/v1/convert.go new file mode 100644 index 000000000..1e0db5ff1 --- /dev/null +++ b/internal/processing/filters/v1/convert.go @@ -0,0 +1,38 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package v1 + +import ( + "context" + "fmt" + + apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +) + +// apiFilter is a shortcut to return the API v1 filter version of the given +// filter keyword, or return an appropriate error if conversion fails. +func (p *Processor) apiFilter(ctx context.Context, filterKeyword *gtsmodel.FilterKeyword) (*apimodel.FilterV1, gtserror.WithCode) { + apiFilter, err := p.converter.FilterKeywordToAPIFilterV1(ctx, filterKeyword) + if err != nil { + return nil, gtserror.NewErrorInternalError(fmt.Errorf("error converting filter keyword to API v1 filter: %w", err)) + } + + return apiFilter, nil +} diff --git a/internal/processing/filters/v1/create.go b/internal/processing/filters/v1/create.go new file mode 100644 index 000000000..e36d6800a --- /dev/null +++ b/internal/processing/filters/v1/create.go @@ -0,0 +1,87 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package v1 + +import ( + "context" + "errors" + "fmt" + "time" + + apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" + "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/id" + "github.com/superseriousbusiness/gotosocial/internal/util" +) + +// Create a new filter and filter keyword for the given account, using the provided parameters. +// These params should have already been validated by the time they reach this function. +func (p *Processor) Create(ctx context.Context, account *gtsmodel.Account, form *apimodel.FilterCreateUpdateRequestV1) (*apimodel.FilterV1, gtserror.WithCode) { + filter := >smodel.Filter{ + ID: id.NewULID(), + AccountID: account.ID, + Title: form.Phrase, + Action: gtsmodel.FilterActionWarn, + } + if *form.Irreversible { + filter.Action = gtsmodel.FilterActionHide + } + if form.ExpiresIn != nil { + filter.ExpiresAt = time.Now().Add(time.Second * time.Duration(*form.ExpiresIn)) + } + for _, context := range form.Context { + switch context { + case apimodel.FilterContextHome: + filter.ContextHome = util.Ptr(true) + case apimodel.FilterContextNotifications: + filter.ContextNotifications = util.Ptr(true) + case apimodel.FilterContextPublic: + filter.ContextPublic = util.Ptr(true) + case apimodel.FilterContextThread: + filter.ContextThread = util.Ptr(true) + case apimodel.FilterContextAccount: + filter.ContextAccount = util.Ptr(true) + default: + return nil, gtserror.NewErrorUnprocessableEntity( + fmt.Errorf("unsupported filter context '%s'", context), + ) + } + } + + filterKeyword := >smodel.FilterKeyword{ + ID: id.NewULID(), + AccountID: account.ID, + FilterID: filter.ID, + Filter: filter, + Keyword: form.Phrase, + WholeWord: util.Ptr(util.PtrValueOr(form.WholeWord, false)), + } + filter.Keywords = []*gtsmodel.FilterKeyword{filterKeyword} + + if err := p.state.DB.PutFilter(ctx, filter); err != nil { + if errors.Is(err, db.ErrAlreadyExists) { + err = errors.New("you already have a filter with this title") + return nil, gtserror.NewErrorConflict(err, err.Error()) + } + return nil, gtserror.NewErrorInternalError(err) + } + + return p.apiFilter(ctx, filterKeyword) +} diff --git a/internal/processing/filters/v1/delete.go b/internal/processing/filters/v1/delete.go new file mode 100644 index 000000000..f2312f039 --- /dev/null +++ b/internal/processing/filters/v1/delete.go @@ -0,0 +1,67 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package v1 + +import ( + "context" + "errors" + + "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtscontext" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +) + +// Delete an existing filter keyword and (if empty afterwards) filter for the given account. +func (p *Processor) Delete( + ctx context.Context, + account *gtsmodel.Account, + filterKeywordID string, +) gtserror.WithCode { + // Get enough of the filter keyword that we can look up its filter ID. + filterKeyword, err := p.state.DB.GetFilterKeywordByID(gtscontext.SetBarebones(ctx), filterKeywordID) + if err != nil { + if errors.Is(err, db.ErrNoEntries) { + return gtserror.NewErrorNotFound(err) + } + return gtserror.NewErrorInternalError(err) + } + if filterKeyword.AccountID != account.ID { + return gtserror.NewErrorNotFound(nil) + } + + // Get the filter for this keyword. + filter, err := p.state.DB.GetFilterByID(ctx, filterKeyword.FilterID) + if err != nil { + return gtserror.NewErrorNotFound(err) + } + + if len(filter.Keywords) > 1 || len(filter.Statuses) > 0 { + // The filter has other keywords or statuses. Delete only the requested filter keyword. + if err := p.state.DB.DeleteFilterKeywordByID(ctx, filterKeyword.ID); err != nil { + return gtserror.NewErrorInternalError(err) + } + } else { + // Delete the entire filter. + if err := p.state.DB.DeleteFilterByID(ctx, filter.ID); err != nil { + return gtserror.NewErrorInternalError(err) + } + } + + return nil +} diff --git a/internal/processing/filters/v1/filters.go b/internal/processing/filters/v1/filters.go new file mode 100644 index 000000000..d46c9e72c --- /dev/null +++ b/internal/processing/filters/v1/filters.go @@ -0,0 +1,35 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package v1 + +import ( + "github.com/superseriousbusiness/gotosocial/internal/state" + "github.com/superseriousbusiness/gotosocial/internal/typeutils" +) + +type Processor struct { + state *state.State + converter *typeutils.Converter +} + +func New(state *state.State, converter *typeutils.Converter) Processor { + return Processor{ + state: state, + converter: converter, + } +} diff --git a/internal/processing/filters/v1/get.go b/internal/processing/filters/v1/get.go new file mode 100644 index 000000000..39575dd94 --- /dev/null +++ b/internal/processing/filters/v1/get.go @@ -0,0 +1,78 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package v1 + +import ( + "context" + "errors" + "slices" + "strings" + + apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" + "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +) + +// Get looks up a filter keyword by ID and returns it as a v1 filter. +func (p *Processor) Get(ctx context.Context, account *gtsmodel.Account, filterKeywordID string) (*apimodel.FilterV1, gtserror.WithCode) { + filterKeyword, err := p.state.DB.GetFilterKeywordByID(ctx, filterKeywordID) + if err != nil { + if errors.Is(err, db.ErrNoEntries) { + return nil, gtserror.NewErrorNotFound(err) + } + return nil, gtserror.NewErrorInternalError(err) + } + if filterKeyword.AccountID != account.ID { + return nil, gtserror.NewErrorNotFound(nil) + } + + return p.apiFilter(ctx, filterKeyword) +} + +// GetAll looks up all filter keywords for the current account and returns them as v1 filters. +func (p *Processor) GetAll(ctx context.Context, account *gtsmodel.Account) ([]*apimodel.FilterV1, gtserror.WithCode) { + filters, err := p.state.DB.GetFilterKeywordsForAccountID( + ctx, + account.ID, + ) + if err != nil { + if errors.Is(err, db.ErrNoEntries) { + return nil, nil + } + return nil, gtserror.NewErrorInternalError(err) + } + + apiFilters := make([]*apimodel.FilterV1, 0, len(filters)) + for _, list := range filters { + apiFilter, errWithCode := p.apiFilter(ctx, list) + if errWithCode != nil { + return nil, errWithCode + } + + apiFilters = append(apiFilters, apiFilter) + } + + // Sort them by ID so that they're in a stable order. + // Clients may opt to sort them lexically in a locale-aware manner. + slices.SortFunc(apiFilters, func(lhs *apimodel.FilterV1, rhs *apimodel.FilterV1) int { + return strings.Compare(lhs.ID, rhs.ID) + }) + + return apiFilters, nil +} diff --git a/internal/processing/filters/v1/update.go b/internal/processing/filters/v1/update.go new file mode 100644 index 000000000..1fe49721b --- /dev/null +++ b/internal/processing/filters/v1/update.go @@ -0,0 +1,165 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package v1 + +import ( + "context" + "errors" + "fmt" + "strings" + "time" + + apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" + "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtscontext" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/util" +) + +// Update an existing filter and filter keyword for the given account, using the provided parameters. +// These params should have already been validated by the time they reach this function. +func (p *Processor) Update( + ctx context.Context, + account *gtsmodel.Account, + filterKeywordID string, + form *apimodel.FilterCreateUpdateRequestV1, +) (*apimodel.FilterV1, gtserror.WithCode) { + // Get enough of the filter keyword that we can look up its filter ID. + filterKeyword, err := p.state.DB.GetFilterKeywordByID(gtscontext.SetBarebones(ctx), filterKeywordID) + if err != nil { + if errors.Is(err, db.ErrNoEntries) { + return nil, gtserror.NewErrorNotFound(err) + } + return nil, gtserror.NewErrorInternalError(err) + } + if filterKeyword.AccountID != account.ID { + return nil, gtserror.NewErrorNotFound(nil) + } + + // Get the filter for this keyword. + filter, err := p.state.DB.GetFilterByID(ctx, filterKeyword.FilterID) + if err != nil { + if errors.Is(err, db.ErrNoEntries) { + return nil, gtserror.NewErrorNotFound(err) + } + return nil, gtserror.NewErrorInternalError(err) + } + + title := form.Phrase + action := gtsmodel.FilterActionWarn + if *form.Irreversible { + action = gtsmodel.FilterActionHide + } + expiresAt := time.Time{} + if form.ExpiresIn != nil { + expiresAt = time.Now().Add(time.Second * time.Duration(*form.ExpiresIn)) + } + contextHome := false + contextNotifications := false + contextPublic := false + contextThread := false + contextAccount := false + for _, context := range form.Context { + switch context { + case apimodel.FilterContextHome: + contextHome = true + case apimodel.FilterContextNotifications: + contextNotifications = true + case apimodel.FilterContextPublic: + contextPublic = true + case apimodel.FilterContextThread: + contextThread = true + case apimodel.FilterContextAccount: + contextAccount = true + default: + return nil, gtserror.NewErrorUnprocessableEntity( + fmt.Errorf("unsupported filter context '%s'", context), + ) + } + } + + // v1 filter APIs can't change certain fields for a filter with multiple keywords or any statuses, + // since it would be an unexpected side effect on filters that, to the v1 API, appear separate. + // See https://docs.joinmastodon.org/methods/filters/#update-v1 + if len(filter.Keywords) > 1 || len(filter.Statuses) > 0 { + forbiddenFields := make([]string, 0, 4) + if title != filter.Title { + forbiddenFields = append(forbiddenFields, "phrase") + } + if action != filter.Action { + forbiddenFields = append(forbiddenFields, "irreversible") + } + if expiresAt != filter.ExpiresAt { + forbiddenFields = append(forbiddenFields, "expires_in") + } + if contextHome != util.PtrValueOr(filter.ContextHome, false) || + contextNotifications != util.PtrValueOr(filter.ContextNotifications, false) || + contextPublic != util.PtrValueOr(filter.ContextPublic, false) || + contextThread != util.PtrValueOr(filter.ContextThread, false) || + contextAccount != util.PtrValueOr(filter.ContextAccount, false) { + forbiddenFields = append(forbiddenFields, "context") + } + if len(forbiddenFields) > 0 { + return nil, gtserror.NewErrorUnprocessableEntity( + fmt.Errorf("v1 filter backwards compatibility: can't change these fields for a filter with multiple keywords or any statuses: %s", strings.Join(forbiddenFields, ", ")), + ) + } + } + + // Now that we've checked that the changes are legal, apply them to the filter and keyword. + filter.Title = title + filter.Action = action + filter.ExpiresAt = expiresAt + filter.ContextHome = &contextHome + filter.ContextNotifications = &contextNotifications + filter.ContextPublic = &contextPublic + filter.ContextThread = &contextThread + filter.ContextAccount = &contextAccount + filterKeyword.Keyword = form.Phrase + filterKeyword.WholeWord = util.Ptr(util.PtrValueOr(form.WholeWord, false)) + + // We only want to update the relevant filter keyword. + filter.Keywords = []*gtsmodel.FilterKeyword{filterKeyword} + filter.Statuses = nil + filterKeyword.Filter = filter + + filterColumns := []string{ + "title", + "action", + "expires_at", + "context_home", + "context_notifications", + "context_public", + "context_thread", + "context_account", + } + filterKeywordColumns := []string{ + "keyword", + "whole_word", + } + if err := p.state.DB.UpdateFilter(ctx, filter, filterColumns, filterKeywordColumns, nil, nil); err != nil { + if errors.Is(err, db.ErrAlreadyExists) { + err = errors.New("you already have a filter with this title") + return nil, gtserror.NewErrorConflict(err, err.Error()) + } + return nil, gtserror.NewErrorInternalError(err) + } + + return p.apiFilter(ctx, filterKeyword) +} diff --git a/internal/processing/processor.go b/internal/processing/processor.go index bb46d31a9..4aaa94fb7 100644 --- a/internal/processing/processor.go +++ b/internal/processing/processor.go @@ -29,6 +29,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/processing/admin" "github.com/superseriousbusiness/gotosocial/internal/processing/common" "github.com/superseriousbusiness/gotosocial/internal/processing/fedi" + filtersv1 "github.com/superseriousbusiness/gotosocial/internal/processing/filters/v1" "github.com/superseriousbusiness/gotosocial/internal/processing/list" "github.com/superseriousbusiness/gotosocial/internal/processing/markers" "github.com/superseriousbusiness/gotosocial/internal/processing/media" @@ -68,20 +69,21 @@ type Processor struct { SUB-PROCESSORS */ - account account.Processor - admin admin.Processor - fedi fedi.Processor - list list.Processor - markers markers.Processor - media media.Processor - polls polls.Processor - report report.Processor - search search.Processor - status status.Processor - stream stream.Processor - timeline timeline.Processor - user user.Processor - workers workers.Processor + account account.Processor + admin admin.Processor + fedi fedi.Processor + filtersv1 filtersv1.Processor + list list.Processor + markers markers.Processor + media media.Processor + polls polls.Processor + report report.Processor + search search.Processor + status status.Processor + stream stream.Processor + timeline timeline.Processor + user user.Processor + workers workers.Processor } func (p *Processor) Account() *account.Processor { @@ -96,6 +98,10 @@ func (p *Processor) Fedi() *fedi.Processor { return &p.fedi } +func (p *Processor) FiltersV1() *filtersv1.Processor { + return &p.filtersv1 +} + func (p *Processor) List() *list.Processor { return &p.list } @@ -177,6 +183,7 @@ func NewProcessor( processor.account = account.New(&common, state, converter, mediaManager, oauthServer, federator, filter, parseMentionFunc) processor.admin = admin.New(state, cleaner, converter, mediaManager, federator.TransportController(), emailSender) processor.fedi = fedi.New(state, &common, converter, federator, filter) + processor.filtersv1 = filtersv1.New(state, converter) processor.list = list.New(state, converter) processor.markers = markers.New(state, converter) processor.polls = polls.New(&common, state, converter) diff --git a/internal/processing/status/get.go b/internal/processing/status/get.go index 475ab0128..7256d2f82 100644 --- a/internal/processing/status/get.go +++ b/internal/processing/status/get.go @@ -111,7 +111,6 @@ func (p *Processor) contextGet( TopoSort(descendants, targetStatus.AccountID) - //goland:noinspection GoImportUsedAsName context := &apimodel.Context{ Ancestors: make([]apimodel.Status, 0, len(ancestors)), Descendants: make([]apimodel.Status, 0, len(descendants)), diff --git a/internal/typeutils/internaltofrontend.go b/internal/typeutils/internaltofrontend.go index d74f4d86e..df4598deb 100644 --- a/internal/typeutils/internaltofrontend.go +++ b/internal/typeutils/internaltofrontend.go @@ -1617,6 +1617,59 @@ func (c *Converter) convertAttachmentsToAPIAttachments(ctx context.Context, atta return apiAttachments, errs.Combine() } +// FilterToAPIFiltersV1 converts one GTS model filter into an API v1 filter list +func (c *Converter) FilterToAPIFiltersV1(ctx context.Context, filter *gtsmodel.Filter) ([]*apimodel.FilterV1, error) { + apiFilters := make([]*apimodel.FilterV1, 0, len(filter.Keywords)) + for _, filterKeyword := range filter.Keywords { + apiFilter, err := c.FilterKeywordToAPIFilterV1(ctx, filterKeyword) + if err != nil { + return nil, err + } + apiFilters = append(apiFilters, apiFilter) + } + return apiFilters, nil +} + +// FilterKeywordToAPIFilterV1 converts one GTS model filter and filter keyword into an API v1 filter +func (c *Converter) FilterKeywordToAPIFilterV1(ctx context.Context, filterKeyword *gtsmodel.FilterKeyword) (*apimodel.FilterV1, error) { + if filterKeyword.Filter == nil { + return nil, gtserror.New("FilterKeyword model's Filter field isn't populated, but needs to be") + } + filter := filterKeyword.Filter + + apiContexts := make([]apimodel.FilterContext, 0, apimodel.FilterContextNumValues) + if util.PtrValueOr(filter.ContextHome, false) { + apiContexts = append(apiContexts, apimodel.FilterContextHome) + } + if util.PtrValueOr(filter.ContextNotifications, false) { + apiContexts = append(apiContexts, apimodel.FilterContextNotifications) + } + if util.PtrValueOr(filter.ContextPublic, false) { + apiContexts = append(apiContexts, apimodel.FilterContextPublic) + } + if util.PtrValueOr(filter.ContextThread, false) { + apiContexts = append(apiContexts, apimodel.FilterContextThread) + } + if util.PtrValueOr(filter.ContextAccount, false) { + apiContexts = append(apiContexts, apimodel.FilterContextAccount) + } + + var expiresAt *string + if !filter.ExpiresAt.IsZero() { + expiresAt = util.Ptr(util.FormatISO8601(filter.ExpiresAt)) + } + + return &apimodel.FilterV1{ + // v1 filters have a single keyword each, so we use the filter keyword ID as the v1 filter ID. + ID: filterKeyword.ID, + Phrase: filterKeyword.Keyword, + Context: apiContexts, + WholeWord: util.PtrValueOr(filterKeyword.WholeWord, false), + ExpiresAt: expiresAt, + Irreversible: filter.Action == gtsmodel.FilterActionHide, + }, nil +} + // convertEmojisToAPIEmojis will convert a slice of GTS model emojis to frontend API model emojis, falling back to IDs if no GTS models supplied. func (c *Converter) convertEmojisToAPIEmojis(ctx context.Context, emojis []*gtsmodel.Emoji, emojiIDs []string) ([]apimodel.Emoji, error) { var errs gtserror.MultiError diff --git a/internal/validate/formvalidation.go b/internal/validate/formvalidation.go index 3d1d05072..b0332a572 100644 --- a/internal/validate/formvalidation.go +++ b/internal/validate/formvalidation.go @@ -44,6 +44,7 @@ const ( maximumProfileFieldLength = 255 maximumProfileFields = 6 maximumListTitleLength = 200 + maximumFilterKeywordLength = 40 ) // Password returns a helpful error if the given password @@ -306,3 +307,44 @@ func MarkerName(name string) error { } return fmt.Errorf("marker timeline name '%s' was not recognized, valid options are '%s', '%s'", name, apimodel.MarkerNameHome, apimodel.MarkerNameNotifications) } + +// FilterKeyword validates the title of a new or updated List. +func FilterKeyword(keyword string) error { + if keyword == "" { + return fmt.Errorf("filter keyword must be provided, and must be no more than %d chars", maximumFilterKeywordLength) + } + + if length := len([]rune(keyword)); length > maximumFilterKeywordLength { + return fmt.Errorf("filter keyword length must be no more than %d chars, provided keyword was %d chars", maximumFilterKeywordLength, length) + } + + return nil +} + +// FilterContexts validates the context of a new or updated filter. +func FilterContexts(contexts []apimodel.FilterContext) error { + if len(contexts) == 0 { + return fmt.Errorf("at least one filter context is required") + } + for _, context := range contexts { + switch context { + case apimodel.FilterContextHome, + apimodel.FilterContextNotifications, + apimodel.FilterContextPublic, + apimodel.FilterContextThread, + apimodel.FilterContextAccount: + continue + default: + return fmt.Errorf( + "filter context '%s' was not recognized, valid options are '%s', '%s', '%s', '%s', '%s'", + context, + apimodel.FilterContextHome, + apimodel.FilterContextNotifications, + apimodel.FilterContextPublic, + apimodel.FilterContextThread, + apimodel.FilterContextAccount, + ) + } + } + return nil +} diff --git a/test/envparsing.sh b/test/envparsing.sh index 90a5e62c9..617bfc63f 100755 --- a/test/envparsing.sh +++ b/test/envparsing.sh @@ -31,6 +31,9 @@ EXPECT=$(cat << "EOF" "boost-of-ids-mem-ratio": 3, "emoji-category-mem-ratio": 0.1, "emoji-mem-ratio": 3, + "filter-keyword-mem-ratio": 0.5, + "filter-mem-ratio": 0.5, + "filter-status-mem-ratio": 0.5, "follow-ids-mem-ratio": 4, "follow-mem-ratio": 2, "follow-request-ids-mem-ratio": 2, diff --git a/testrig/db.go b/testrig/db.go index 17c8f83b0..a83d93b16 100644 --- a/testrig/db.go +++ b/testrig/db.go @@ -37,6 +37,9 @@ var testModels = []interface{}{ >smodel.Block{}, >smodel.DomainBlock{}, >smodel.EmailDomainBlock{}, + >smodel.Filter{}, + >smodel.FilterKeyword{}, + >smodel.FilterStatus{}, >smodel.Follow{}, >smodel.FollowRequest{}, >smodel.List{}, @@ -329,6 +332,24 @@ func StandardDBSetup(db db.DB, accounts map[string]*gtsmodel.Account) { } } + for _, v := range NewTestFilters() { + if err := db.Put(ctx, v); err != nil { + log.Panic(nil, err) + } + } + + for _, v := range NewTestFilterKeywords() { + if err := db.Put(ctx, v); err != nil { + log.Panic(nil, err) + } + } + + for _, v := range NewTestFilterStatuses() { + if err := db.Put(ctx, v); err != nil { + log.Panic(nil, err) + } + } + if err := db.CreateInstanceAccount(ctx); err != nil { log.Panic(nil, err) } diff --git a/testrig/testmodels.go b/testrig/testmodels.go index b350cafe5..929317904 100644 --- a/testrig/testmodels.go +++ b/testrig/testmodels.go @@ -3263,6 +3263,87 @@ func NewTestDereferenceRequests(accounts map[string]*gtsmodel.Account) map[strin } } +func NewTestFilters() map[string]*gtsmodel.Filter { + return map[string]*gtsmodel.Filter{ + "local_account_1_filter_1": { + ID: "01HN26VM6KZTW1ANNRVSBMA461", + CreatedAt: TimeMustParse("2024-01-25T12:20:03+02:00"), + UpdatedAt: TimeMustParse("2024-01-25T12:20:03+02:00"), + AccountID: "01F8MH1H7YV1Z7D2C8K2730QBF", + Title: "fnord", + Action: gtsmodel.FilterActionWarn, + ContextHome: util.Ptr(true), + ContextPublic: util.Ptr(true), + }, + "local_account_1_filter_2": { + ID: "01HN277FSPQAWXZXK92QPPYF79", + CreatedAt: TimeMustParse("2024-01-25T12:20:03+02:00"), + UpdatedAt: TimeMustParse("2024-01-25T12:20:03+02:00"), + AccountID: "01F8MH1H7YV1Z7D2C8K2730QBF", + Title: "metasyntactic variables", + Action: gtsmodel.FilterActionWarn, + ContextHome: util.Ptr(true), + ContextPublic: util.Ptr(true), + }, + "local_account_2_filter_1": { + ID: "01HNGFYJBED9FS0VWRVMY4TKXH", + CreatedAt: TimeMustParse("2024-01-25T12:20:03+02:00"), + UpdatedAt: TimeMustParse("2024-01-25T12:20:03+02:00"), + AccountID: "01F8MH1VYJAE00TVVGMM5JNJ8X", + Title: "gamer words", + Action: gtsmodel.FilterActionWarn, + ContextHome: util.Ptr(true), + ContextPublic: util.Ptr(true), + }, + } +} + +func NewTestFilterKeywords() map[string]*gtsmodel.FilterKeyword { + return map[string]*gtsmodel.FilterKeyword{ + "local_account_1_filter_1_keyword_1": { + ID: "01HN272TAVWAXX72ZX4M8JZ0PS", + CreatedAt: TimeMustParse("2024-01-25T12:20:03+02:00"), + UpdatedAt: TimeMustParse("2024-01-25T12:20:03+02:00"), + AccountID: "01F8MH1H7YV1Z7D2C8K2730QBF", + FilterID: "01HN26VM6KZTW1ANNRVSBMA461", + Keyword: "fnord", + WholeWord: util.Ptr(true), + }, + "local_account_1_filter_2_keyword_1": { + ID: "01HN277Y11ENG4EC1ERMAC9FH4", + CreatedAt: TimeMustParse("2024-01-25T12:20:03+02:00"), + UpdatedAt: TimeMustParse("2024-01-25T12:20:03+02:00"), + AccountID: "01F8MH1H7YV1Z7D2C8K2730QBF", + FilterID: "01HN277FSPQAWXZXK92QPPYF79", + Keyword: "foo", + WholeWord: util.Ptr(true), + }, + "local_account_1_filter_2_keyword_2": { + ID: "01HN278494N88BA2FY4DZ5JTNS", + CreatedAt: TimeMustParse("2024-01-25T12:20:03+02:00"), + UpdatedAt: TimeMustParse("2024-01-25T12:20:03+02:00"), + AccountID: "01F8MH1H7YV1Z7D2C8K2730QBF", + FilterID: "01HN277FSPQAWXZXK92QPPYF79", + Keyword: "bar", + WholeWord: util.Ptr(true), + }, + "local_account_2_filter_1_keyword_1": { + ID: "01HNGG51HV2JT67XQ5MQ7RA1WE", + CreatedAt: TimeMustParse("2024-01-25T12:20:03+02:00"), + UpdatedAt: TimeMustParse("2024-01-25T12:20:03+02:00"), + AccountID: "01F8MH1VYJAE00TVVGMM5JNJ8X", + FilterID: "01HNGFYJBED9FS0VWRVMY4TKXH", + Keyword: "Virtual Boy", + WholeWord: util.Ptr(true), + }, + } +} + +func NewTestFilterStatuses() map[string]*gtsmodel.FilterStatus { + // FUTURE: (filters v2) test filter statuses + return map[string]*gtsmodel.FilterStatus{} +} + // GetSignatureForActivity prepares a mock HTTP request as if it were going to deliver activity to destination signed for privkey and pubKeyID, signs the request and returns the header values. func GetSignatureForActivity(activity pub.Activity, pubKeyID string, privkey *rsa.PrivateKey, destination *url.URL) (signatureHeader string, digestHeader string, dateHeader string) { // convert the activity into json bytes