diff --git a/internal/cache/gts.go b/internal/cache/gts.go index 339605354..507947305 100644 --- a/internal/cache/gts.go +++ b/internal/cache/gts.go @@ -952,6 +952,7 @@ func (c *GTSCaches) initStatus() { {Name: "ID"}, {Name: "URI"}, {Name: "URL"}, + {Name: "PollID"}, {Name: "BoostOfID.AccountID"}, {Name: "ThreadID", Multi: true}, }, copyF, cap) diff --git a/internal/db/bundb/migrations/20231215115920_add_status_poll_index.go b/internal/db/bundb/migrations/20231215115920_add_status_poll_index.go new file mode 100644 index 000000000..54b585d60 --- /dev/null +++ b/internal/db/bundb/migrations/20231215115920_add_status_poll_index.go @@ -0,0 +1,66 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package migrations + +import ( + "context" + + "github.com/uptrace/bun" +) + +func init() { + up := func(ctx context.Context, db *bun.DB) error { + return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { + type spec struct { + index string + table string + columns []string + } + + for _, spec := range []spec{ + { + index: "statuses_poll_id_idx", + table: "statuses", + columns: []string{"poll_id"}, + }, + } { + if _, err := tx. + NewCreateIndex(). + Table(spec.table). + Index(spec.index). + Column(spec.columns...). + IfNotExists(). + Exec(ctx); err != nil { + return err + } + } + + return nil + }) + } + + down := func(ctx context.Context, db *bun.DB) error { + return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { + return nil + }) + } + + if err := Migrations.Register(up, down); err != nil { + panic(err) + } +} diff --git a/internal/db/bundb/poll.go b/internal/db/bundb/poll.go index 830fb88ec..3e77fb6c5 100644 --- a/internal/db/bundb/poll.go +++ b/internal/db/bundb/poll.go @@ -50,20 +50,6 @@ func (p *pollDB) GetPollByID(ctx context.Context, id string) (*gtsmodel.Poll, er ) } -func (p *pollDB) GetPollByStatusID(ctx context.Context, statusID string) (*gtsmodel.Poll, error) { - return p.getPoll( - ctx, - "StatusID", - func(poll *gtsmodel.Poll) error { - return p.db.NewSelect(). - Model(poll). - Where("? = ?", bun.Ident("poll.status_id"), statusID). - Scan(ctx) - }, - statusID, - ) -} - func (p *pollDB) getPoll(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Poll) error, keyParts ...any) (*gtsmodel.Poll, error) { // Fetch poll from database cache with loader callback poll, err := p.state.Caches.GTS.Poll().Load(lookup, func() (*gtsmodel.Poll, error) { diff --git a/internal/db/bundb/poll_test.go b/internal/db/bundb/poll_test.go index 479557c55..6bdbdb983 100644 --- a/internal/db/bundb/poll_test.go +++ b/internal/db/bundb/poll_test.go @@ -67,10 +67,6 @@ func (suite *PollTestSuite) TestGetPollBy() { "id": func() (*gtsmodel.Poll, error) { return suite.db.GetPollByID(ctx, poll.ID) }, - - "status_id": func() (*gtsmodel.Poll, error) { - return suite.db.GetPollByStatusID(ctx, poll.StatusID) - }, } { // Clear database caches. @@ -287,10 +283,6 @@ func (suite *PollTestSuite) TestDeletePoll() { // Ensure that afterwards we cannot fetch poll. _, err = suite.db.GetPollByID(ctx, poll.ID) suite.ErrorIs(err, db.ErrNoEntries) - - // Or again by the status it's attached to. - _, err = suite.db.GetPollByStatusID(ctx, poll.StatusID) - suite.ErrorIs(err, db.ErrNoEntries) } } diff --git a/internal/db/bundb/status.go b/internal/db/bundb/status.go index dd161e1ec..da252c7f7 100644 --- a/internal/db/bundb/status.go +++ b/internal/db/bundb/status.go @@ -87,6 +87,17 @@ func (s *statusDB) GetStatusByURL(ctx context.Context, url string) (*gtsmodel.St ) } +func (s *statusDB) GetStatusByPollID(ctx context.Context, pollID string) (*gtsmodel.Status, error) { + return s.getStatus( + ctx, + "PollID", + func(status *gtsmodel.Status) error { + return s.db.NewSelect().Model(status).Where("? = ?", bun.Ident("status.poll_id"), pollID).Scan(ctx) + }, + pollID, + ) +} + func (s *statusDB) GetStatusBoost(ctx context.Context, boostOfID string, byAccountID string) (*gtsmodel.Status, error) { return s.getStatus( ctx, diff --git a/internal/db/poll.go b/internal/db/poll.go index b59d27c73..ac0229855 100644 --- a/internal/db/poll.go +++ b/internal/db/poll.go @@ -27,9 +27,6 @@ type Poll interface { // GetPollByID fetches the Poll with given ID from the database. GetPollByID(ctx context.Context, id string) (*gtsmodel.Poll, error) - // GetPollByStatusID fetches the Poll with given status ID column value from the database. - GetPollByStatusID(ctx context.Context, statusID string) (*gtsmodel.Poll, error) - // GetOpenPolls fetches all local Polls in the database with an unset `closed_at` column. GetOpenPolls(ctx context.Context) ([]*gtsmodel.Poll, error) diff --git a/internal/db/status.go b/internal/db/status.go index 1ebf503a8..8034d39e7 100644 --- a/internal/db/status.go +++ b/internal/db/status.go @@ -25,15 +25,18 @@ import ( // Status contains functions for getting statuses, creating statuses, and checking various other fields on statuses. type Status interface { - // GetStatusByID returns one status from the database, with no rel fields populated, only their linking ID / URIs + // GetStatusByID fetches the status from the database with matching id column. GetStatusByID(ctx context.Context, id string) (*gtsmodel.Status, error) - // GetStatusByURI returns one status from the database, with no rel fields populated, only their linking ID / URIs + // GetStatusByURI fetches the status from the database with matching uri column. GetStatusByURI(ctx context.Context, uri string) (*gtsmodel.Status, error) - // GetStatusByURL returns one status from the database, with no rel fields populated, only their linking ID / URIs + // GetStatusByURL fetches the status from the database with matching url column. GetStatusByURL(ctx context.Context, uri string) (*gtsmodel.Status, error) + // GetStatusByPollID fetches the status from the database with matching poll_id column. + GetStatusByPollID(ctx context.Context, pollID string) (*gtsmodel.Status, error) + // GetStatusBoost fetches the status whose boost_of_id column refers to boostOfID, authored by given account ID. GetStatusBoost(ctx context.Context, boostOfID string, byAccountID string) (*gtsmodel.Status, error) diff --git a/internal/federation/dereferencing/status.go b/internal/federation/dereferencing/status.go index 8a8ec60b1..2a2b99d25 100644 --- a/internal/federation/dereferencing/status.go +++ b/internal/federation/dereferencing/status.go @@ -40,14 +40,25 @@ import ( // statusUpToDate returns whether the given status model is both updateable // (i.e. remote status) and whether it needs an update based on `fetched_at`. -func statusUpToDate(status *gtsmodel.Status) bool { +func statusUpToDate(status *gtsmodel.Status, force bool) bool { if *status.Local { // Can't update local statuses. return true } - // If this status was updated recently (last interval), we return as-is. - if next := status.FetchedAt.Add(2 * time.Hour); time.Now().Before(next) { + // Default limit we allow + // statuses to be refreshed. + limit := 2 * time.Hour + + if force { + // We specifically allow the force flag + // to force an early refresh (on a much + // smaller cooldown period). + limit = 5 * time.Minute + } + + // If this status was updated recently (within limit), return as-is. + if next := status.FetchedAt.Add(limit); time.Now().Before(next) { return true } @@ -125,7 +136,7 @@ func (d *Dereferencer) getStatusByURI(ctx context.Context, requestUser string, u } // Check whether needs update. - if statusUpToDate(status) { + if statusUpToDate(status, false) { // This is existing up-to-date status, ensure it is populated. if err := d.state.DB.PopulateStatus(ctx, status); err != nil { log.Errorf(ctx, "error populating existing status: %v", err) @@ -159,8 +170,8 @@ func (d *Dereferencer) RefreshStatus( statusable ap.Statusable, force bool, ) (*gtsmodel.Status, ap.Statusable, error) { - // Check whether needs update. - if !force && statusUpToDate(status) { + // Check whether status needs update. + if statusUpToDate(status, force) { return status, nil, nil } @@ -204,8 +215,8 @@ func (d *Dereferencer) RefreshStatusAsync( statusable ap.Statusable, force bool, ) { - // Check whether needs update. - if !force && statusUpToDate(status) { + // Check whether status needs update. + if statusUpToDate(status, force) { return } diff --git a/internal/processing/common/status.go b/internal/processing/common/status.go index 0a1f495fb..ae03a5306 100644 --- a/internal/processing/common/status.go +++ b/internal/processing/common/status.go @@ -30,10 +30,12 @@ import ( // GetTargetStatusBy fetches the target status with db load function, given the authorized (or, nil) requester's // account. This returns an approprate gtserror.WithCode accounting for not found and visibility to requester. +// The refresh argument allows specifying whether the returned copy should be force refreshed. func (p *Processor) GetTargetStatusBy( ctx context.Context, requester *gtsmodel.Account, getTargetFromDB func() (*gtsmodel.Status, error), + refresh bool, ) ( status *gtsmodel.Status, visible bool, @@ -61,47 +63,52 @@ func (p *Processor) GetTargetStatusBy( } if requester != nil && visible { - // Ensure remote status is up-to-date. - p.federator.RefreshStatusAsync(ctx, - requester.Username, - target, - nil, - false, - ) + // We only bother refreshing if this status + // is visible to requester, AND there *is* + // a requester (i.e. request is authorized) + // to prevent a possible DOS vector. + + if refresh { + // Refresh required, forcibly do synchronously. + _, _, err := p.federator.RefreshStatus(ctx, + requester.Username, + target, + nil, + true, // force + ) + if err != nil { + log.Errorf(ctx, "error refreshing status: %v", err) + } + } else { + // Only refresh async *if* out-of-date. + p.federator.RefreshStatusAsync(ctx, + requester.Username, + target, + nil, + false, // force + ) + } } return target, visible, nil } -// GetTargetStatusByID is a call-through to GetTargetStatus() using the db GetStatusByID() function. -func (p *Processor) GetTargetStatusByID( - ctx context.Context, - requester *gtsmodel.Account, - targetID string, -) ( - status *gtsmodel.Status, - visible bool, - errWithCode gtserror.WithCode, -) { - return p.GetTargetStatusBy(ctx, requester, func() (*gtsmodel.Status, error) { - return p.state.DB.GetStatusByID(ctx, targetID) - }) -} - -// GetVisibleTargetStatus calls GetTargetStatusByID(), +// GetVisibleTargetStatus calls GetTargetStatusBy(), // but converts a non-visible result to not-found error. -func (p *Processor) GetVisibleTargetStatus( +func (p *Processor) GetVisibleTargetStatusBy( ctx context.Context, requester *gtsmodel.Account, - targetID string, + getTargetFromDB func() (*gtsmodel.Status, error), + refresh bool, ) ( status *gtsmodel.Status, errWithCode gtserror.WithCode, ) { // Fetch the target status by ID from the database. - target, visible, errWithCode := p.GetTargetStatusByID(ctx, + target, visible, errWithCode := p.GetTargetStatusBy(ctx, requester, - targetID, + getTargetFromDB, + refresh, ) if errWithCode != nil { return nil, errWithCode @@ -119,6 +126,22 @@ func (p *Processor) GetVisibleTargetStatus( return target, nil } +// GetVisibleTargetStatus calls GetVisibleTargetStatusBy(), +// passing in a database function that fetches by status ID. +func (p *Processor) GetVisibleTargetStatus( + ctx context.Context, + requester *gtsmodel.Account, + targetID string, + refresh bool, +) ( + status *gtsmodel.Status, + errWithCode gtserror.WithCode, +) { + return p.GetVisibleTargetStatusBy(ctx, requester, func() (*gtsmodel.Status, error) { + return p.state.DB.GetStatusByID(ctx, targetID) + }, refresh) +} + // UnwrapIfBoost "unwraps" the given status if // it's a boost wrapper, by returning the boosted // status it targets (pending visibility checks). @@ -132,9 +155,10 @@ func (p *Processor) UnwrapIfBoost( if status.BoostOfID == "" { return status, nil } - return p.GetVisibleTargetStatus(ctx, - requester, status.BoostOfID, + requester, + status.BoostOfID, + false, ) } diff --git a/internal/processing/fedi/status.go b/internal/processing/fedi/status.go index 1c1af9cb4..2674ebf68 100644 --- a/internal/processing/fedi/status.go +++ b/internal/processing/fedi/status.go @@ -100,6 +100,7 @@ func (p *Processor) StatusRepliesGet( status, errWithCode := p.c.GetVisibleTargetStatus(ctx, requester, statusID, + false, // refresh ) if errWithCode != nil { return nil, errWithCode diff --git a/internal/processing/polls/poll.go b/internal/processing/polls/poll.go index 3b258b76c..19cf555e5 100644 --- a/internal/processing/polls/poll.go +++ b/internal/processing/polls/poll.go @@ -19,11 +19,8 @@ package polls import ( "context" - "errors" apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" - "github.com/superseriousbusiness/gotosocial/internal/db" - "github.com/superseriousbusiness/gotosocial/internal/gtscontext" "github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/processing/common" @@ -48,35 +45,24 @@ func New(common *common.Processor, state *state.State, converter *typeutils.Conv } // getTargetPoll fetches a target poll ID for requesting account, taking visibility of the poll's originating status into account. -func (p *Processor) getTargetPoll(ctx context.Context, requestingAccount *gtsmodel.Account, targetID string) (*gtsmodel.Poll, gtserror.WithCode) { - // Load the requested poll with ID. - // (barebones as we fetch status below) - poll, err := p.state.DB.GetPollByID( - gtscontext.SetBarebones(ctx), - targetID, +func (p *Processor) getTargetPoll(ctx context.Context, requester *gtsmodel.Account, targetID string) (*gtsmodel.Poll, gtserror.WithCode) { + // Load the status the poll is attached to by the poll ID, + // checking for visibility and ensuring it is up-to-date. + status, errWithCode := p.c.GetVisibleTargetStatusBy(ctx, + requester, + func() (*gtsmodel.Status, error) { + return p.state.DB.GetStatusByPollID(ctx, targetID) + }, + true, // refresh ) - if err != nil && !errors.Is(err, db.ErrNoEntries) { - return nil, gtserror.NewErrorInternalError(err) - } - - if poll == nil { - // No poll could be found for given ID. - const text = "target poll not found" - return nil, gtserror.NewErrorNotFound( - errors.New(text), - text, - ) - } - - // Check that we can see + fetch the originating status for requesting account. - status, errWithCode := p.c.GetVisibleTargetStatus(ctx, requestingAccount, poll.StatusID) if errWithCode != nil { return nil, errWithCode } - // Update poll status. + // Return most up-to-date + // copy of the status poll. + poll := status.Poll poll.Status = status - return poll, nil } diff --git a/internal/processing/status/bookmark.go b/internal/processing/status/bookmark.go index 634529ba4..224445838 100644 --- a/internal/processing/status/bookmark.go +++ b/internal/processing/status/bookmark.go @@ -30,7 +30,11 @@ import ( ) func (p *Processor) getBookmarkableStatus(ctx context.Context, requestingAccount *gtsmodel.Account, targetStatusID string) (*gtsmodel.Status, string, gtserror.WithCode) { - targetStatus, errWithCode := p.c.GetVisibleTargetStatus(ctx, requestingAccount, targetStatusID) + targetStatus, errWithCode := p.c.GetVisibleTargetStatus(ctx, + requestingAccount, + targetStatusID, + false, // refresh + ) if errWithCode != nil { return nil, "", errWithCode } diff --git a/internal/processing/status/boost.go b/internal/processing/status/boost.go index 2062fb802..2fc96091e 100644 --- a/internal/processing/status/boost.go +++ b/internal/processing/status/boost.go @@ -43,6 +43,7 @@ func (p *Processor) BoostCreate( ctx, requester, targetID, + false, // refresh ) if errWithCode != nil { return nil, errWithCode @@ -112,6 +113,7 @@ func (p *Processor) BoostRemove( ctx, requester, targetID, + false, // refresh ) if errWithCode != nil { return nil, errWithCode diff --git a/internal/processing/status/fave.go b/internal/processing/status/fave.go index dbeba7fe9..7ac270e8c 100644 --- a/internal/processing/status/fave.go +++ b/internal/processing/status/fave.go @@ -47,6 +47,7 @@ func (p *Processor) getFaveableStatus( ctx, requester, targetID, + false, // refresh ) if errWithCode != nil { return nil, nil, errWithCode @@ -149,7 +150,11 @@ func (p *Processor) FaveRemove(ctx context.Context, requestingAccount *gtsmodel. // FavedBy returns a slice of accounts that have liked the given status, filtered according to privacy settings. func (p *Processor) FavedBy(ctx context.Context, requestingAccount *gtsmodel.Account, targetStatusID string) ([]*apimodel.Account, gtserror.WithCode) { - targetStatus, errWithCode := p.c.GetVisibleTargetStatus(ctx, requestingAccount, targetStatusID) + targetStatus, errWithCode := p.c.GetVisibleTargetStatus(ctx, + requestingAccount, + targetStatusID, + false, // refresh + ) if errWithCode != nil { return nil, errWithCode } diff --git a/internal/processing/status/get.go b/internal/processing/status/get.go index c182bd148..f8c037404 100644 --- a/internal/processing/status/get.go +++ b/internal/processing/status/get.go @@ -28,7 +28,11 @@ import ( // Get gets the given status, taking account of privacy settings and blocks etc. func (p *Processor) Get(ctx context.Context, requestingAccount *gtsmodel.Account, targetStatusID string) (*apimodel.Status, gtserror.WithCode) { - targetStatus, errWithCode := p.c.GetVisibleTargetStatus(ctx, requestingAccount, targetStatusID) + targetStatus, errWithCode := p.c.GetVisibleTargetStatus(ctx, + requestingAccount, + targetStatusID, + false, // refresh + ) if errWithCode != nil { return nil, errWithCode } @@ -38,7 +42,11 @@ func (p *Processor) Get(ctx context.Context, requestingAccount *gtsmodel.Account // WebGet gets the given status for web use, taking account of privacy settings. func (p *Processor) WebGet(ctx context.Context, targetStatusID string) (*apimodel.Status, gtserror.WithCode) { - targetStatus, errWithCode := p.c.GetVisibleTargetStatus(ctx, nil, targetStatusID) + targetStatus, errWithCode := p.c.GetVisibleTargetStatus(ctx, + nil, // requester + targetStatusID, + false, // refresh + ) if errWithCode != nil { return nil, errWithCode } @@ -57,7 +65,11 @@ func (p *Processor) contextGet( targetStatusID string, convert func(context.Context, *gtsmodel.Status, *gtsmodel.Account) (*apimodel.Status, error), ) (*apimodel.Context, gtserror.WithCode) { - targetStatus, errWithCode := p.c.GetVisibleTargetStatus(ctx, requestingAccount, targetStatusID) + targetStatus, errWithCode := p.c.GetVisibleTargetStatus(ctx, + requestingAccount, + targetStatusID, + false, // refresh + ) if errWithCode != nil { return nil, errWithCode } diff --git a/internal/processing/status/mute.go b/internal/processing/status/mute.go index 1663ee0bc..fb4f3b384 100644 --- a/internal/processing/status/mute.go +++ b/internal/processing/status/mute.go @@ -41,7 +41,11 @@ func (p *Processor) getMuteableStatus( requestingAccount *gtsmodel.Account, targetStatusID string, ) (*gtsmodel.Status, gtserror.WithCode) { - targetStatus, errWithCode := p.c.GetVisibleTargetStatus(ctx, requestingAccount, targetStatusID) + targetStatus, errWithCode := p.c.GetVisibleTargetStatus(ctx, + requestingAccount, + targetStatusID, + false, // refresh + ) if errWithCode != nil { return nil, errWithCode } diff --git a/internal/processing/status/pin.go b/internal/processing/status/pin.go index b31288a64..f08b9652c 100644 --- a/internal/processing/status/pin.go +++ b/internal/processing/status/pin.go @@ -39,7 +39,11 @@ const allowedPinnedCount = 10 // - Status is public, unlisted, or followers-only. // - Status is not a boost. func (p *Processor) getPinnableStatus(ctx context.Context, requestingAccount *gtsmodel.Account, targetStatusID string) (*gtsmodel.Status, gtserror.WithCode) { - targetStatus, errWithCode := p.c.GetVisibleTargetStatus(ctx, requestingAccount, targetStatusID) + targetStatus, errWithCode := p.c.GetVisibleTargetStatus(ctx, + requestingAccount, + targetStatusID, + false, // refresh + ) if errWithCode != nil { return nil, errWithCode }