diff --git a/internal/ap/collections.go b/internal/ap/collections.go index c55bbe70b..62c81fd57 100644 --- a/internal/ap/collections.go +++ b/internal/ap/collections.go @@ -105,6 +105,15 @@ func (iter *regularCollectionIterator) PrevItem() TypeOrIRI { return cur } +func (iter *regularCollectionIterator) TotalItems() int { + totalItems := iter.GetActivityStreamsTotalItems() + if totalItems == nil || !totalItems.IsXMLSchemaNonNegativeInteger() { + return -1 + } + + return totalItems.Get() +} + func (iter *regularCollectionIterator) initItems() bool { if iter.once { return (iter.items != nil) @@ -147,6 +156,15 @@ func (iter *orderedCollectionIterator) PrevItem() TypeOrIRI { return cur } +func (iter *orderedCollectionIterator) TotalItems() int { + totalItems := iter.GetActivityStreamsTotalItems() + if totalItems == nil || !totalItems.IsXMLSchemaNonNegativeInteger() { + return -1 + } + + return totalItems.Get() +} + func (iter *orderedCollectionIterator) initItems() bool { if iter.once { return (iter.items != nil) @@ -203,6 +221,15 @@ func (iter *regularCollectionPageIterator) PrevItem() TypeOrIRI { return cur } +func (iter *regularCollectionPageIterator) TotalItems() int { + totalItems := iter.GetActivityStreamsTotalItems() + if totalItems == nil || !totalItems.IsXMLSchemaNonNegativeInteger() { + return -1 + } + + return totalItems.Get() +} + func (iter *regularCollectionPageIterator) initItems() bool { if iter.once { return (iter.items != nil) @@ -259,6 +286,15 @@ func (iter *orderedCollectionPageIterator) PrevItem() TypeOrIRI { return cur } +func (iter *orderedCollectionPageIterator) TotalItems() int { + totalItems := iter.GetActivityStreamsTotalItems() + if totalItems == nil || !totalItems.IsXMLSchemaNonNegativeInteger() { + return -1 + } + + return totalItems.Get() +} + func (iter *orderedCollectionPageIterator) initItems() bool { if iter.once { return (iter.items != nil) diff --git a/internal/ap/interfaces.go b/internal/ap/interfaces.go index 05f6742cc..8f2e17c09 100644 --- a/internal/ap/interfaces.go +++ b/internal/ap/interfaces.go @@ -307,6 +307,12 @@ type CollectionIterator interface { NextItem() TypeOrIRI PrevItem() TypeOrIRI + + // TotalItems returns the total items + // present in the collection, derived + // from the totalItems property, or -1 + // if totalItems not present / readable. + TotalItems() int } // CollectionPageIterator represents the minimum interface for interacting with a wrapped @@ -319,6 +325,12 @@ type CollectionPageIterator interface { NextItem() TypeOrIRI PrevItem() TypeOrIRI + + // TotalItems returns the total items + // present in the collection, derived + // from the totalItems property, or -1 + // if totalItems not present / readable. + TotalItems() int } // Flaggable represents the minimum interface for an activitystreams 'Flag' activity. diff --git a/internal/api/client/statuses/statuspin_test.go b/internal/api/client/statuses/statuspin_test.go index 034860031..39909e6c4 100644 --- a/internal/api/client/statuses/statuspin_test.go +++ b/internal/api/client/statuses/statuspin_test.go @@ -48,11 +48,12 @@ func (suite *StatusPinTestSuite) createPin( expectedHTTPStatus int, expectedBody string, targetStatusID string, + requestingAcct *gtsmodel.Account, ) (*apimodel.Status, error) { // instantiate recorder + test context recorder := httptest.NewRecorder() ctx, _ := testrig.CreateGinTestContext(recorder, nil) - ctx.Set(oauth.SessionAuthorizedAccount, suite.testAccounts["local_account_1"]) + ctx.Set(oauth.SessionAuthorizedAccount, requestingAcct) ctx.Set(oauth.SessionAuthorizedToken, oauth.DBTokenToToken(suite.testTokens["local_account_1"])) ctx.Set(oauth.SessionAuthorizedApplication, suite.testApplications["application_1"]) ctx.Set(oauth.SessionAuthorizedUser, suite.testUsers["local_account_1"]) @@ -101,8 +102,10 @@ func (suite *StatusPinTestSuite) createPin( func (suite *StatusPinTestSuite) TestPinStatusPublicOK() { // Pin an unpinned public status that this account owns. targetStatus := suite.testStatuses["local_account_1_status_1"] + testAccount := new(gtsmodel.Account) + *testAccount = *suite.testAccounts["local_account_1"] - resp, err := suite.createPin(http.StatusOK, "", targetStatus.ID) + resp, err := suite.createPin(http.StatusOK, "", targetStatus.ID, testAccount) if err != nil { suite.FailNow(err.Error()) } @@ -113,8 +116,10 @@ func (suite *StatusPinTestSuite) TestPinStatusPublicOK() { func (suite *StatusPinTestSuite) TestPinStatusFollowersOnlyOK() { // Pin an unpinned followers only status that this account owns. targetStatus := suite.testStatuses["local_account_1_status_5"] + testAccount := new(gtsmodel.Account) + *testAccount = *suite.testAccounts["local_account_1"] - resp, err := suite.createPin(http.StatusOK, "", targetStatus.ID) + resp, err := suite.createPin(http.StatusOK, "", targetStatus.ID, testAccount) if err != nil { suite.FailNow(err.Error()) } @@ -127,6 +132,8 @@ func (suite *StatusPinTestSuite) TestPinStatusTwiceError() { targetStatus := >smodel.Status{} *targetStatus = *suite.testStatuses["local_account_1_status_5"] targetStatus.PinnedAt = time.Now() + testAccount := new(gtsmodel.Account) + *testAccount = *suite.testAccounts["local_account_1"] if err := suite.db.UpdateStatus(context.Background(), targetStatus, "pinned_at"); err != nil { suite.FailNow(err.Error()) @@ -136,6 +143,7 @@ func (suite *StatusPinTestSuite) TestPinStatusTwiceError() { http.StatusUnprocessableEntity, `{"error":"Unprocessable Entity: status already pinned"}`, targetStatus.ID, + testAccount, ); err != nil { suite.FailNow(err.Error()) } @@ -144,11 +152,14 @@ func (suite *StatusPinTestSuite) TestPinStatusTwiceError() { func (suite *StatusPinTestSuite) TestPinStatusOtherAccountError() { // Try to pin a status that doesn't belong to us. targetStatus := suite.testStatuses["admin_account_status_1"] + testAccount := new(gtsmodel.Account) + *testAccount = *suite.testAccounts["local_account_1"] if _, err := suite.createPin( http.StatusUnprocessableEntity, `{"error":"Unprocessable Entity: status 01F8MH75CBF9JFX4ZAD54N0W0R does not belong to account 01F8MH1H7YV1Z7D2C8K2730QBF"}`, targetStatus.ID, + testAccount, ); err != nil { suite.FailNow(err.Error()) } @@ -156,7 +167,8 @@ func (suite *StatusPinTestSuite) TestPinStatusOtherAccountError() { func (suite *StatusPinTestSuite) TestPinStatusTooManyPins() { // Test pinning too many statuses. - testAccount := suite.testAccounts["local_account_1"] + testAccount := new(gtsmodel.Account) + *testAccount = *suite.testAccounts["local_account_1"] // Spam 10 pinned statuses into the database. ctx := context.Background() @@ -181,12 +193,18 @@ func (suite *StatusPinTestSuite) TestPinStatusTooManyPins() { } } + // Regenerate account stats to set pinned count. + if err := suite.db.RegenerateAccountStats(ctx, testAccount); err != nil { + suite.FailNow(err.Error()) + } + // Try to pin one more status as a treat. targetStatus := suite.testStatuses["local_account_1_status_1"] if _, err := suite.createPin( http.StatusUnprocessableEntity, `{"error":"Unprocessable Entity: status pin limit exceeded, you've already pinned 10 status(es) out of 10"}`, targetStatus.ID, + testAccount, ); err != nil { suite.FailNow(err.Error()) } diff --git a/internal/cache/cache.go b/internal/cache/cache.go index d35162172..2e5e2c2dd 100644 --- a/internal/cache/cache.go +++ b/internal/cache/cache.go @@ -52,9 +52,9 @@ func (c *Caches) Init() { log.Infof(nil, "init: %p", c) c.initAccount() - c.initAccountCounts() c.initAccountNote() c.initAccountSettings() + c.initAccountStats() c.initApplication() c.initBlock() c.initBlockIDs() @@ -124,6 +124,7 @@ func (c *Caches) Sweep(threshold float64) { c.GTS.Account.Trim(threshold) c.GTS.AccountNote.Trim(threshold) c.GTS.AccountSettings.Trim(threshold) + c.GTS.AccountStats.Trim(threshold) c.GTS.Block.Trim(threshold) c.GTS.BlockIDs.Trim(threshold) c.GTS.Emoji.Trim(threshold) diff --git a/internal/cache/db.go b/internal/cache/db.go index cb0ed6712..d993d6143 100644 --- a/internal/cache/db.go +++ b/internal/cache/db.go @@ -20,7 +20,6 @@ package cache import ( "time" - "codeberg.org/gruf/go-cache/v3/simple" "codeberg.org/gruf/go-cache/v3/ttl" "codeberg.org/gruf/go-structr" "github.com/superseriousbusiness/gotosocial/internal/cache/domain" @@ -36,16 +35,12 @@ type GTSCaches struct { // AccountNote provides access to the gtsmodel Note database cache. AccountNote StructCache[*gtsmodel.AccountNote] - // TEMPORARY CACHE TO ALLEVIATE SLOW COUNT QUERIES, - // (in time will be removed when these IDs are cached). - AccountCounts *simple.Cache[string, struct { - Statuses int - Pinned int - }] - // AccountSettings provides access to the gtsmodel AccountSettings database cache. AccountSettings StructCache[*gtsmodel.AccountSettings] + // AccountStats provides access to the gtsmodel AccountStats database cache. + AccountStats StructCache[*gtsmodel.AccountStats] + // Application provides access to the gtsmodel Application database cache. Application StructCache[*gtsmodel.Application] @@ -200,6 +195,7 @@ func (c *Caches) initAccount() { a2.AlsoKnownAs = nil a2.Move = nil a2.Settings = nil + a2.Stats = nil return a2 } @@ -223,22 +219,6 @@ func (c *Caches) initAccount() { }) } -func (c *Caches) initAccountCounts() { - // Simply use size of accounts cache, - // as this cache will be very small. - cap := c.GTS.Account.Cap() - if cap == 0 { - panic("must be initialized before accounts") - } - - log.Infof(nil, "cache size = %d", cap) - - c.GTS.AccountCounts = simple.New[string, struct { - Statuses int - Pinned int - }](0, cap) -} - func (c *Caches) initAccountNote() { // Calculate maximum cache size. cap := calculateResultCacheMax( @@ -295,6 +275,29 @@ func (c *Caches) initAccountSettings() { }) } +func (c *Caches) initAccountStats() { + // Calculate maximum cache size. + cap := calculateResultCacheMax( + sizeofAccountStats(), // model in-mem size. + config.GetCacheAccountStatsMemRatio(), + ) + + log.Infof(nil, "cache size = %d", cap) + + c.GTS.AccountStats.Init(structr.CacheConfig[*gtsmodel.AccountStats]{ + Indices: []structr.IndexConfig{ + {Fields: "AccountID"}, + }, + MaxSize: cap, + IgnoreErr: ignoreErrors, + Copy: func(s1 *gtsmodel.AccountStats) *gtsmodel.AccountStats { + s2 := new(gtsmodel.AccountStats) + *s2 = *s1 + return s2 + }, + }) +} + func (c *Caches) initApplication() { // Calculate maximum cache size. cap := calculateResultCacheMax( diff --git a/internal/cache/invalidate.go b/internal/cache/invalidate.go index 547015eac..01d332d40 100644 --- a/internal/cache/invalidate.go +++ b/internal/cache/invalidate.go @@ -27,8 +27,8 @@ import ( // HOOKS TO BE CALLED ON DELETE YOU MUST FIRST POPULATE IT IN THE CACHE. func (c *Caches) OnInvalidateAccount(account *gtsmodel.Account) { - // Invalidate status counts for this account. - c.GTS.AccountCounts.Invalidate(account.ID) + // Invalidate stats for this account. + c.GTS.AccountStats.Invalidate("AccountID", account.ID) // Invalidate account ID cached visibility. c.Visibility.Invalidate("ItemID", account.ID) @@ -168,8 +168,8 @@ func (c *Caches) OnInvalidatePollVote(vote *gtsmodel.PollVote) { } func (c *Caches) OnInvalidateStatus(status *gtsmodel.Status) { - // Invalidate status counts for this account. - c.GTS.AccountCounts.Invalidate(status.AccountID) + // Invalidate stats for this account. + c.GTS.AccountStats.Invalidate("AccountID", status.AccountID) // Invalidate status ID cached visibility. c.Visibility.Invalidate("ItemID", status.ID) diff --git a/internal/cache/size.go b/internal/cache/size.go index 9c1a82abc..5bd99c3d8 100644 --- a/internal/cache/size.go +++ b/internal/cache/size.go @@ -264,6 +264,17 @@ func sizeofAccountSettings() uintptr { })) } +func sizeofAccountStats() uintptr { + return uintptr(size.Of(>smodel.AccountStats{ + AccountID: exampleID, + FollowersCount: util.Ptr(100), + FollowingCount: util.Ptr(100), + StatusesCount: util.Ptr(100), + StatusesPinnedCount: util.Ptr(100), + LastStatusAt: exampleTime, + })) +} + func sizeofApplication() uintptr { return uintptr(size.Of(>smodel.Application{ ID: exampleID, diff --git a/internal/config/config.go b/internal/config/config.go index 3cd67525f..a738dded4 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -195,6 +195,7 @@ type CacheConfiguration struct { AccountMemRatio float64 `name:"account-mem-ratio"` AccountNoteMemRatio float64 `name:"account-note-mem-ratio"` AccountSettingsMemRatio float64 `name:"account-settings-mem-ratio"` + AccountStatsMemRatio float64 `name:"account-stats-mem-ratio"` ApplicationMemRatio float64 `name:"application-mem-ratio"` BlockMemRatio float64 `name:"block-mem-ratio"` BlockIDsMemRatio float64 `name:"block-mem-ratio"` diff --git a/internal/config/defaults.go b/internal/config/defaults.go index f5f8fb6ac..e84e619b8 100644 --- a/internal/config/defaults.go +++ b/internal/config/defaults.go @@ -159,6 +159,7 @@ var Defaults = Configuration{ AccountMemRatio: 5, AccountNoteMemRatio: 1, AccountSettingsMemRatio: 0.1, + AccountStatsMemRatio: 2, ApplicationMemRatio: 0.1, BlockMemRatio: 2, BlockIDsMemRatio: 3, diff --git a/internal/config/helpers.gen.go b/internal/config/helpers.gen.go index a8c919834..c986dd19d 100644 --- a/internal/config/helpers.gen.go +++ b/internal/config/helpers.gen.go @@ -2825,6 +2825,31 @@ func GetCacheAccountSettingsMemRatio() float64 { return global.GetCacheAccountSe // SetCacheAccountSettingsMemRatio safely sets the value for global configuration 'Cache.AccountSettingsMemRatio' field func SetCacheAccountSettingsMemRatio(v float64) { global.SetCacheAccountSettingsMemRatio(v) } +// GetCacheAccountStatsMemRatio safely fetches the Configuration value for state's 'Cache.AccountStatsMemRatio' field +func (st *ConfigState) GetCacheAccountStatsMemRatio() (v float64) { + st.mutex.RLock() + v = st.config.Cache.AccountStatsMemRatio + st.mutex.RUnlock() + return +} + +// SetCacheAccountStatsMemRatio safely sets the Configuration value for state's 'Cache.AccountStatsMemRatio' field +func (st *ConfigState) SetCacheAccountStatsMemRatio(v float64) { + st.mutex.Lock() + defer st.mutex.Unlock() + st.config.Cache.AccountStatsMemRatio = v + st.reloadToViper() +} + +// CacheAccountStatsMemRatioFlag returns the flag name for the 'Cache.AccountStatsMemRatio' field +func CacheAccountStatsMemRatioFlag() string { return "cache-account-stats-mem-ratio" } + +// GetCacheAccountStatsMemRatio safely fetches the value for global configuration 'Cache.AccountStatsMemRatio' field +func GetCacheAccountStatsMemRatio() float64 { return global.GetCacheAccountStatsMemRatio() } + +// SetCacheAccountStatsMemRatio safely sets the value for global configuration 'Cache.AccountStatsMemRatio' field +func SetCacheAccountStatsMemRatio(v float64) { global.SetCacheAccountStatsMemRatio(v) } + // GetCacheApplicationMemRatio safely fetches the Configuration value for state's 'Cache.ApplicationMemRatio' field func (st *ConfigState) GetCacheApplicationMemRatio() (v float64) { st.mutex.RLock() diff --git a/internal/db/account.go b/internal/db/account.go index 7cdf7b57f..dec36d2ac 100644 --- a/internal/db/account.go +++ b/internal/db/account.go @@ -20,7 +20,6 @@ package db import ( "context" "net/netip" - "time" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/paging" @@ -100,12 +99,6 @@ type Account interface { // GetAccountsUsingEmoji fetches all account models using emoji with given ID stored in their 'emojis' column. GetAccountsUsingEmoji(ctx context.Context, emojiID string) ([]*gtsmodel.Account, error) - // GetAccountStatusesCount is a shortcut for the common action of counting statuses produced by accountID. - CountAccountStatuses(ctx context.Context, accountID string) (int, error) - - // CountAccountPinned returns the total number of pinned statuses owned by account with the given id. - CountAccountPinned(ctx context.Context, accountID string) (int, error) - // GetAccountStatuses is a shortcut for getting the most recent statuses. accountID is optional, if not provided // then all statuses will be returned. If limit is set to 0, the size of the returned slice will not be limited. This can // be very memory intensive so you probably shouldn't do this! @@ -128,13 +121,6 @@ type Account interface { // In the case of no statuses, this function will return db.ErrNoEntries. GetAccountWebStatuses(ctx context.Context, accountID string, limit int, maxID string) ([]*gtsmodel.Status, error) - // GetAccountLastPosted simply gets the timestamp of the most recent post by the account. - // - // If webOnly is true, then the time of the last non-reply, non-boost, public status of the account will be returned. - // - // The returned time will be zero if account has never posted anything. - GetAccountLastPosted(ctx context.Context, accountID string, webOnly bool) (time.Time, error) - // SetAccountHeaderOrAvatar sets the header or avatar for the given accountID to the given media attachment. SetAccountHeaderOrAvatar(ctx context.Context, mediaAttachment *gtsmodel.MediaAttachment, accountID string) error @@ -150,4 +136,24 @@ type Account interface { // Update local account settings. UpdateAccountSettings(ctx context.Context, settings *gtsmodel.AccountSettings, columns ...string) error + + // PopulateAccountStats gets (or creates and gets) account stats for + // the given account, and attaches them to the account model. + PopulateAccountStats(ctx context.Context, account *gtsmodel.Account) error + + // RegenerateAccountStats creates, upserts, and returns stats + // for the given account, and attaches them to the account model. + // + // Unlike GetAccountStats, it will always get the database stats fresh. + // This can be used to "refresh" stats. + // + // Because this involves database calls that can be expensive (on Postgres + // specifically), callers should prefer GetAccountStats in 99% of cases. + RegenerateAccountStats(ctx context.Context, account *gtsmodel.Account) error + + // Update account stats. + UpdateAccountStats(ctx context.Context, stats *gtsmodel.AccountStats, columns ...string) error + + // DeleteAccountStats deletes the accountStats entry for the given accountID. + DeleteAccountStats(ctx context.Context, accountID string) error } diff --git a/internal/db/bundb/account.go b/internal/db/bundb/account.go index 45e67c10b..2b3c78aff 100644 --- a/internal/db/bundb/account.go +++ b/internal/db/bundb/account.go @@ -630,6 +630,13 @@ func (a *accountDB) PopulateAccount(ctx context.Context, account *gtsmodel.Accou } } + if account.Stats == nil { + // Get / Create stats for this account. + if err := a.state.DB.PopulateAccountStats(ctx, account); err != nil { + errs.Appendf("error populating account stats: %w", err) + } + } + return errs.Combine() } @@ -735,31 +742,6 @@ func (a *accountDB) DeleteAccount(ctx context.Context, id string) error { }) } -func (a *accountDB) GetAccountLastPosted(ctx context.Context, accountID string, webOnly bool) (time.Time, error) { - createdAt := time.Time{} - - q := a.db. - NewSelect(). - TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")). - Column("status.created_at"). - Where("? = ?", bun.Ident("status.account_id"), accountID). - Order("status.id DESC"). - Limit(1) - - if webOnly { - q = q. - Where("? IS NULL", bun.Ident("status.in_reply_to_uri")). - Where("? IS NULL", bun.Ident("status.boost_of_id")). - Where("? = ?", bun.Ident("status.visibility"), gtsmodel.VisibilityPublic). - Where("? = ?", bun.Ident("status.federated"), true) - } - - if err := q.Scan(ctx, &createdAt); err != nil { - return time.Time{}, err - } - return createdAt, nil -} - func (a *accountDB) SetAccountHeaderOrAvatar(ctx context.Context, mediaAttachment *gtsmodel.MediaAttachment, accountID string) error { if *mediaAttachment.Avatar && *mediaAttachment.Header { return errors.New("one media attachment cannot be both header and avatar") @@ -845,59 +827,6 @@ func (a *accountDB) GetAccountFaves(ctx context.Context, accountID string) ([]*g return *faves, nil } -func (a *accountDB) CountAccountStatuses(ctx context.Context, accountID string) (int, error) { - counts, err := a.getAccountStatusCounts(ctx, accountID) - return counts.Statuses, err -} - -func (a *accountDB) CountAccountPinned(ctx context.Context, accountID string) (int, error) { - counts, err := a.getAccountStatusCounts(ctx, accountID) - return counts.Pinned, err -} - -func (a *accountDB) getAccountStatusCounts(ctx context.Context, accountID string) (struct { - Statuses int - Pinned int -}, error) { - // Check for an already cached copy of account status counts. - counts, ok := a.state.Caches.GTS.AccountCounts.Get(accountID) - if ok { - return counts, nil - } - - if err := a.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { - var err error - - // Scan database for account statuses. - counts.Statuses, err = tx.NewSelect(). - Table("statuses"). - Where("? = ?", bun.Ident("account_id"), accountID). - Count(ctx) - if err != nil { - return err - } - - // Scan database for pinned statuses. - counts.Pinned, err = tx.NewSelect(). - Table("statuses"). - Where("? = ?", bun.Ident("account_id"), accountID). - Where("? IS NOT NULL", bun.Ident("pinned_at")). - Count(ctx) - if err != nil { - return err - } - - return nil - }); err != nil { - return counts, err - } - - // Store this account counts result in the cache. - a.state.Caches.GTS.AccountCounts.Set(accountID, counts) - - return counts, nil -} - func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, limit int, excludeReplies bool, excludeReblogs bool, maxID string, minID string, mediaOnly bool, publicOnly bool) ([]*gtsmodel.Status, error) { // Ensure reasonable if limit < 0 { @@ -1147,3 +1076,185 @@ func (a *accountDB) UpdateAccountSettings( return nil }) } + +func (a *accountDB) PopulateAccountStats(ctx context.Context, account *gtsmodel.Account) error { + // Fetch stats from db cache with loader callback. + stats, err := a.state.Caches.GTS.AccountStats.LoadOne( + "AccountID", + func() (*gtsmodel.AccountStats, error) { + // Not cached! Perform database query. + var stats gtsmodel.AccountStats + if err := a.db. + NewSelect(). + Model(&stats). + Where("? = ?", bun.Ident("account_stats.account_id"), account.ID). + Scan(ctx); err != nil { + return nil, err + } + return &stats, nil + }, + account.ID, + ) + + if err != nil && !errors.Is(err, db.ErrNoEntries) { + // Real error. + return err + } + + if stats == nil { + // Don't have stats yet, generate them. + return a.RegenerateAccountStats(ctx, account) + } + + // We have a stats, attach + // it to the account. + account.Stats = stats + + // Check if this is a local + // stats by looking at the + // account they pertain to. + if account.IsRemote() { + // Account is remote. Updating + // stats for remote accounts is + // handled in the dereferencer. + // + // Nothing more to do! + return nil + } + + // Stats account is local, check + // if we need to regenerate. + const statsFreshness = 48 * time.Hour + expiry := stats.RegeneratedAt.Add(statsFreshness) + if time.Now().After(expiry) { + // Stats have expired, regenerate them. + return a.RegenerateAccountStats(ctx, account) + } + + // Stats are still fresh. + return nil +} + +func (a *accountDB) RegenerateAccountStats(ctx context.Context, account *gtsmodel.Account) error { + // Initialize a new stats struct. + stats := >smodel.AccountStats{ + AccountID: account.ID, + RegeneratedAt: time.Now(), + } + + // Count followers outside of transaction since + // it uses a cache + requires its own db calls. + followerIDs, err := a.state.DB.GetAccountFollowerIDs(ctx, account.ID, nil) + if err != nil { + return err + } + stats.FollowersCount = util.Ptr(len(followerIDs)) + + // Count following outside of transaction since + // it uses a cache + requires its own db calls. + followIDs, err := a.state.DB.GetAccountFollowIDs(ctx, account.ID, nil) + if err != nil { + return err + } + stats.FollowingCount = util.Ptr(len(followIDs)) + + // Count follow requests outside of transaction since + // it uses a cache + requires its own db calls. + followRequestIDs, err := a.state.DB.GetAccountFollowRequestIDs(ctx, account.ID, nil) + if err != nil { + return err + } + stats.FollowRequestsCount = util.Ptr(len(followRequestIDs)) + + // Populate remaining stats struct fields. + // This can be done inside a transaction. + if err := a.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { + var err error + + // Scan database for account statuses. + statusesCount, err := tx.NewSelect(). + Table("statuses"). + Where("? = ?", bun.Ident("account_id"), account.ID). + Count(ctx) + if err != nil { + return err + } + stats.StatusesCount = &statusesCount + + // Scan database for pinned statuses. + statusesPinnedCount, err := tx.NewSelect(). + Table("statuses"). + Where("? = ?", bun.Ident("account_id"), account.ID). + Where("? IS NOT NULL", bun.Ident("pinned_at")). + Count(ctx) + if err != nil { + return err + } + stats.StatusesPinnedCount = &statusesPinnedCount + + // Scan database for last status. + lastStatusAt := time.Time{} + err = tx. + NewSelect(). + TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")). + Column("status.created_at"). + Where("? = ?", bun.Ident("status.account_id"), account.ID). + Order("status.id DESC"). + Limit(1). + Scan(ctx, &lastStatusAt) + if err != nil && !errors.Is(err, db.ErrNoEntries) { + return err + } + stats.LastStatusAt = lastStatusAt + + return nil + }); err != nil { + return err + } + + // Upsert this stats in case a race + // meant someone else inserted it first. + if err := a.state.Caches.GTS.AccountStats.Store(stats, func() error { + if _, err := NewUpsert(a.db). + Model(stats). + Constraint("account_id"). + Exec(ctx); err != nil { + return err + } + return nil + }); err != nil { + return err + } + + account.Stats = stats + return nil +} + +func (a *accountDB) UpdateAccountStats(ctx context.Context, stats *gtsmodel.AccountStats, columns ...string) error { + return a.state.Caches.GTS.AccountStats.Store(stats, func() error { + if _, err := a.db. + NewUpdate(). + Model(stats). + Column(columns...). + Where("? = ?", bun.Ident("account_stats.account_id"), stats.AccountID). + Exec(ctx); err != nil { + return err + } + + return nil + }) +} + +func (a *accountDB) DeleteAccountStats(ctx context.Context, accountID string) error { + defer a.state.Caches.GTS.AccountStats.Invalidate("AccountID", accountID) + + if _, err := a.db. + NewDelete(). + Table("account_stats"). + Where("? = ?", bun.Ident("account_id"), accountID). + Exec(ctx); err != nil { + return err + } + + return nil +} diff --git a/internal/db/bundb/account_test.go b/internal/db/bundb/account_test.go index dd96543b6..ea211e16f 100644 --- a/internal/db/bundb/account_test.go +++ b/internal/db/bundb/account_test.go @@ -220,6 +220,8 @@ func (suite *AccountTestSuite) TestGetAccountBy() { a2.Emojis = nil a1.Settings = nil a2.Settings = nil + a1.Stats = nil + a2.Stats = nil // Clear database-set fields. a1.CreatedAt = time.Time{} @@ -413,18 +415,6 @@ func (suite *AccountTestSuite) TestUpdateAccount() { suite.WithinDuration(time.Now(), noCache.UpdatedAt, 5*time.Second) } -func (suite *AccountTestSuite) TestGetAccountLastPosted() { - lastPosted, err := suite.db.GetAccountLastPosted(context.Background(), suite.testAccounts["local_account_1"].ID, false) - suite.NoError(err) - suite.EqualValues(1702200240, lastPosted.Unix()) -} - -func (suite *AccountTestSuite) TestGetAccountLastPostedWebOnly() { - lastPosted, err := suite.db.GetAccountLastPosted(context.Background(), suite.testAccounts["local_account_1"].ID, true) - suite.NoError(err) - suite.EqualValues(1702200240, lastPosted.Unix()) -} - func (suite *AccountTestSuite) TestInsertAccountWithDefaults() { key, err := rsa.GenerateKey(rand.Reader, 2048) suite.NoError(err) @@ -466,22 +456,6 @@ func (suite *AccountTestSuite) TestGetAccountPinnedStatusesNothingPinned() { suite.Empty(statuses) // This account has nothing pinned. } -func (suite *AccountTestSuite) TestCountAccountPinnedSomeResults() { - testAccount := suite.testAccounts["admin_account"] - - pinned, err := suite.db.CountAccountPinned(context.Background(), testAccount.ID) - suite.NoError(err) - suite.Equal(pinned, 2) // This account has 2 statuses pinned. -} - -func (suite *AccountTestSuite) TestCountAccountPinnedNothingPinned() { - testAccount := suite.testAccounts["local_account_1"] - - pinned, err := suite.db.CountAccountPinned(context.Background(), testAccount.ID) - suite.NoError(err) - suite.Equal(pinned, 0) // This account has nothing pinned. -} - func (suite *AccountTestSuite) TestPopulateAccountWithUnknownMovedToURI() { testAccount := >smodel.Account{} *testAccount = *suite.testAccounts["local_account_1"] @@ -676,6 +650,55 @@ func (suite *AccountTestSuite) TestGetPendingAccounts() { suite.Len(accounts, 1) } +func (suite *AccountTestSuite) TestAccountStatsAll() { + ctx := context.Background() + for _, account := range suite.testAccounts { + // Get stats for the first time. They + // should all be generated now since + // they're not stored in the test rig. + if err := suite.db.PopulateAccountStats(ctx, account); err != nil { + suite.FailNow(err.Error()) + } + stats := account.Stats + suite.NotNil(stats) + suite.WithinDuration(time.Now(), stats.RegeneratedAt, 5*time.Second) + + // Get stats a second time. They shouldn't + // be regenerated since we just did it. + if err := suite.db.PopulateAccountStats(ctx, account); err != nil { + suite.FailNow(err.Error()) + } + stats2 := account.Stats + suite.NotNil(stats2) + suite.Equal(stats2.RegeneratedAt, stats.RegeneratedAt) + + // Update the stats to indicate they're out of date. + stats2.RegeneratedAt = time.Now().Add(-72 * time.Hour) + if err := suite.db.UpdateAccountStats(ctx, stats2, "regenerated_at"); err != nil { + suite.FailNow(err.Error()) + } + + // Get stats for a third time, they + // should get regenerated now, but + // only for local accounts. + if err := suite.db.PopulateAccountStats(ctx, account); err != nil { + suite.FailNow(err.Error()) + } + stats3 := account.Stats + suite.NotNil(stats3) + if account.IsLocal() { + suite.True(stats3.RegeneratedAt.After(stats.RegeneratedAt)) + } else { + suite.False(stats3.RegeneratedAt.After(stats.RegeneratedAt)) + } + + // Now delete the stats. + if err := suite.db.DeleteAccountStats(ctx, account.ID); err != nil { + suite.FailNow(err.Error()) + } + } +} + func TestAccountTestSuite(t *testing.T) { suite.Run(t, new(AccountTestSuite)) } diff --git a/internal/db/bundb/migrations/20240414122348_account_stats_model.go b/internal/db/bundb/migrations/20240414122348_account_stats_model.go new file mode 100644 index 000000000..450ca04d4 --- /dev/null +++ b/internal/db/bundb/migrations/20240414122348_account_stats_model.go @@ -0,0 +1,52 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package migrations + +import ( + "context" + + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/uptrace/bun" +) + +func init() { + up := func(ctx context.Context, db *bun.DB) error { + return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { + // Create new AccountStats table. + if _, err := tx. + NewCreateTable(). + Model(>smodel.AccountStats{}). + IfNotExists(). + Exec(ctx); err != nil { + return err + } + + return nil + }) + } + + down := func(ctx context.Context, db *bun.DB) error { + return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { + return nil + }) + } + + if err := Migrations.Register(up, down); err != nil { + panic(err) + } +} diff --git a/internal/db/bundb/relationship.go b/internal/db/bundb/relationship.go index 1c533af39..052f29cb3 100644 --- a/internal/db/bundb/relationship.go +++ b/internal/db/bundb/relationship.go @@ -112,7 +112,7 @@ func (r *relationshipDB) GetRelationship(ctx context.Context, requestingAccount } func (r *relationshipDB) GetAccountFollows(ctx context.Context, accountID string, page *paging.Page) ([]*gtsmodel.Follow, error) { - followIDs, err := r.getAccountFollowIDs(ctx, accountID, page) + followIDs, err := r.GetAccountFollowIDs(ctx, accountID, page) if err != nil { return nil, err } @@ -120,7 +120,7 @@ func (r *relationshipDB) GetAccountFollows(ctx context.Context, accountID string } func (r *relationshipDB) GetAccountLocalFollows(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) { - followIDs, err := r.getAccountLocalFollowIDs(ctx, accountID) + followIDs, err := r.GetAccountLocalFollowIDs(ctx, accountID) if err != nil { return nil, err } @@ -128,7 +128,7 @@ func (r *relationshipDB) GetAccountLocalFollows(ctx context.Context, accountID s } func (r *relationshipDB) GetAccountFollowers(ctx context.Context, accountID string, page *paging.Page) ([]*gtsmodel.Follow, error) { - followerIDs, err := r.getAccountFollowerIDs(ctx, accountID, page) + followerIDs, err := r.GetAccountFollowerIDs(ctx, accountID, page) if err != nil { return nil, err } @@ -136,7 +136,7 @@ func (r *relationshipDB) GetAccountFollowers(ctx context.Context, accountID stri } func (r *relationshipDB) GetAccountLocalFollowers(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) { - followerIDs, err := r.getAccountLocalFollowerIDs(ctx, accountID) + followerIDs, err := r.GetAccountLocalFollowerIDs(ctx, accountID) if err != nil { return nil, err } @@ -144,7 +144,7 @@ func (r *relationshipDB) GetAccountLocalFollowers(ctx context.Context, accountID } func (r *relationshipDB) GetAccountFollowRequests(ctx context.Context, accountID string, page *paging.Page) ([]*gtsmodel.FollowRequest, error) { - followReqIDs, err := r.getAccountFollowRequestIDs(ctx, accountID, page) + followReqIDs, err := r.GetAccountFollowRequestIDs(ctx, accountID, page) if err != nil { return nil, err } @@ -152,7 +152,7 @@ func (r *relationshipDB) GetAccountFollowRequests(ctx context.Context, accountID } func (r *relationshipDB) GetAccountFollowRequesting(ctx context.Context, accountID string, page *paging.Page) ([]*gtsmodel.FollowRequest, error) { - followReqIDs, err := r.getAccountFollowRequestingIDs(ctx, accountID, page) + followReqIDs, err := r.GetAccountFollowRequestingIDs(ctx, accountID, page) if err != nil { return nil, err } @@ -160,49 +160,14 @@ func (r *relationshipDB) GetAccountFollowRequesting(ctx context.Context, account } func (r *relationshipDB) GetAccountBlocks(ctx context.Context, accountID string, page *paging.Page) ([]*gtsmodel.Block, error) { - blockIDs, err := r.getAccountBlockIDs(ctx, accountID, page) + blockIDs, err := r.GetAccountBlockIDs(ctx, accountID, page) if err != nil { return nil, err } return r.GetBlocksByIDs(ctx, blockIDs) } -func (r *relationshipDB) CountAccountFollows(ctx context.Context, accountID string) (int, error) { - followIDs, err := r.getAccountFollowIDs(ctx, accountID, nil) - return len(followIDs), err -} - -func (r *relationshipDB) CountAccountLocalFollows(ctx context.Context, accountID string) (int, error) { - followIDs, err := r.getAccountLocalFollowIDs(ctx, accountID) - return len(followIDs), err -} - -func (r *relationshipDB) CountAccountFollowers(ctx context.Context, accountID string) (int, error) { - followerIDs, err := r.getAccountFollowerIDs(ctx, accountID, nil) - return len(followerIDs), err -} - -func (r *relationshipDB) CountAccountLocalFollowers(ctx context.Context, accountID string) (int, error) { - followerIDs, err := r.getAccountLocalFollowerIDs(ctx, accountID) - return len(followerIDs), err -} - -func (r *relationshipDB) CountAccountFollowRequests(ctx context.Context, accountID string) (int, error) { - followReqIDs, err := r.getAccountFollowRequestIDs(ctx, accountID, nil) - return len(followReqIDs), err -} - -func (r *relationshipDB) CountAccountFollowRequesting(ctx context.Context, accountID string) (int, error) { - followReqIDs, err := r.getAccountFollowRequestingIDs(ctx, accountID, nil) - return len(followReqIDs), err -} - -func (r *relationshipDB) CountAccountBlocks(ctx context.Context, accountID string) (int, error) { - blockIDs, err := r.getAccountBlockIDs(ctx, accountID, nil) - return len(blockIDs), err -} - -func (r *relationshipDB) getAccountFollowIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) { +func (r *relationshipDB) GetAccountFollowIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) { return loadPagedIDs(&r.state.Caches.GTS.FollowIDs, ">"+accountID, page, func() ([]string, error) { var followIDs []string @@ -217,7 +182,7 @@ func (r *relationshipDB) getAccountFollowIDs(ctx context.Context, accountID stri }) } -func (r *relationshipDB) getAccountLocalFollowIDs(ctx context.Context, accountID string) ([]string, error) { +func (r *relationshipDB) GetAccountLocalFollowIDs(ctx context.Context, accountID string) ([]string, error) { return r.state.Caches.GTS.FollowIDs.Load("l>"+accountID, func() ([]string, error) { var followIDs []string @@ -232,7 +197,7 @@ func (r *relationshipDB) getAccountLocalFollowIDs(ctx context.Context, accountID }) } -func (r *relationshipDB) getAccountFollowerIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) { +func (r *relationshipDB) GetAccountFollowerIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) { return loadPagedIDs(&r.state.Caches.GTS.FollowIDs, "<"+accountID, page, func() ([]string, error) { var followIDs []string @@ -247,7 +212,7 @@ func (r *relationshipDB) getAccountFollowerIDs(ctx context.Context, accountID st }) } -func (r *relationshipDB) getAccountLocalFollowerIDs(ctx context.Context, accountID string) ([]string, error) { +func (r *relationshipDB) GetAccountLocalFollowerIDs(ctx context.Context, accountID string) ([]string, error) { return r.state.Caches.GTS.FollowIDs.Load("l<"+accountID, func() ([]string, error) { var followIDs []string @@ -262,7 +227,7 @@ func (r *relationshipDB) getAccountLocalFollowerIDs(ctx context.Context, account }) } -func (r *relationshipDB) getAccountFollowRequestIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) { +func (r *relationshipDB) GetAccountFollowRequestIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) { return loadPagedIDs(&r.state.Caches.GTS.FollowRequestIDs, ">"+accountID, page, func() ([]string, error) { var followReqIDs []string @@ -277,7 +242,7 @@ func (r *relationshipDB) getAccountFollowRequestIDs(ctx context.Context, account }) } -func (r *relationshipDB) getAccountFollowRequestingIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) { +func (r *relationshipDB) GetAccountFollowRequestingIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) { return loadPagedIDs(&r.state.Caches.GTS.FollowRequestIDs, "<"+accountID, page, func() ([]string, error) { var followReqIDs []string @@ -292,7 +257,7 @@ func (r *relationshipDB) getAccountFollowRequestingIDs(ctx context.Context, acco }) } -func (r *relationshipDB) getAccountBlockIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) { +func (r *relationshipDB) GetAccountBlockIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) { return loadPagedIDs(&r.state.Caches.GTS.BlockIDs, accountID, page, func() ([]string, error) { var blockIDs []string diff --git a/internal/db/bundb/relationship_test.go b/internal/db/bundb/relationship_test.go index 9858e4768..f1d1a35d2 100644 --- a/internal/db/bundb/relationship_test.go +++ b/internal/db/bundb/relationship_test.go @@ -773,20 +773,6 @@ func (suite *RelationshipTestSuite) TestGetAccountFollows() { suite.Len(follows, 2) } -func (suite *RelationshipTestSuite) TestCountAccountFollowsLocalOnly() { - account := suite.testAccounts["local_account_1"] - followsCount, err := suite.db.CountAccountLocalFollows(context.Background(), account.ID) - suite.NoError(err) - suite.Equal(2, followsCount) -} - -func (suite *RelationshipTestSuite) TestCountAccountFollows() { - account := suite.testAccounts["local_account_1"] - followsCount, err := suite.db.CountAccountFollows(context.Background(), account.ID) - suite.NoError(err) - suite.Equal(2, followsCount) -} - func (suite *RelationshipTestSuite) TestGetAccountFollowers() { account := suite.testAccounts["local_account_1"] follows, err := suite.db.GetAccountFollowers(context.Background(), account.ID, nil) @@ -794,20 +780,6 @@ func (suite *RelationshipTestSuite) TestGetAccountFollowers() { suite.Len(follows, 2) } -func (suite *RelationshipTestSuite) TestCountAccountFollowers() { - account := suite.testAccounts["local_account_1"] - followsCount, err := suite.db.CountAccountFollowers(context.Background(), account.ID) - suite.NoError(err) - suite.Equal(2, followsCount) -} - -func (suite *RelationshipTestSuite) TestCountAccountFollowersLocalOnly() { - account := suite.testAccounts["local_account_1"] - followsCount, err := suite.db.CountAccountLocalFollowers(context.Background(), account.ID) - suite.NoError(err) - suite.Equal(2, followsCount) -} - func (suite *RelationshipTestSuite) TestUnfollowExisting() { originAccount := suite.testAccounts["local_account_1"] targetAccount := suite.testAccounts["admin_account"] diff --git a/internal/db/bundb/upsert.go b/internal/db/bundb/upsert.go index 34724446c..4a6395179 100644 --- a/internal/db/bundb/upsert.go +++ b/internal/db/bundb/upsert.go @@ -189,14 +189,14 @@ func (u *UpsertQuery) insertQuery() (*bun.InsertQuery, error) { constraintIDPlaceholders = append(constraintIDPlaceholders, "?") constraintIDs = append(constraintIDs, bun.Ident(constraint)) } - onSQL := "conflict (" + strings.Join(constraintIDPlaceholders, ", ") + ") do update" + onSQL := "CONFLICT (" + strings.Join(constraintIDPlaceholders, ", ") + ") DO UPDATE" setClauses := make([]string, 0, len(columns)) setIDs := make([]interface{}, 0, 2*len(columns)) for _, column := range columns { + setClauses = append(setClauses, "? = ?") // "excluded" is a special table that contains only the row involved in a conflict. - setClauses = append(setClauses, "? = excluded.?") - setIDs = append(setIDs, bun.Ident(column), bun.Ident(column)) + setIDs = append(setIDs, bun.Ident(column), bun.Ident("excluded."+column)) } setSQL := strings.Join(setClauses, ", ") diff --git a/internal/db/relationship.go b/internal/db/relationship.go index 5191701bb..cd4539791 100644 --- a/internal/db/relationship.go +++ b/internal/db/relationship.go @@ -140,44 +140,44 @@ type Relationship interface { // GetAccountFollows returns a slice of follows owned by the given accountID. GetAccountFollows(ctx context.Context, accountID string, page *paging.Page) ([]*gtsmodel.Follow, error) + // GetAccountFollowIDs is like GetAccountFollows, but returns just IDs. + GetAccountFollowIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) + // GetAccountLocalFollows returns a slice of follows owned by the given accountID, only including follows from this instance. GetAccountLocalFollows(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) + // GetAccountLocalFollowIDs is like GetAccountLocalFollows, but returns just IDs. + GetAccountLocalFollowIDs(ctx context.Context, accountID string) ([]string, error) + // GetAccountFollowers fetches follows that target given accountID. GetAccountFollowers(ctx context.Context, accountID string, page *paging.Page) ([]*gtsmodel.Follow, error) + // GetAccountFollowerIDs is like GetAccountFollowers, but returns just IDs. + GetAccountFollowerIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) + // GetAccountLocalFollowers fetches follows that target given accountID, only including follows from this instance. GetAccountLocalFollowers(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) + // GetAccountLocalFollowerIDs is like GetAccountLocalFollowers, but returns just IDs. + GetAccountLocalFollowerIDs(ctx context.Context, accountID string) ([]string, error) + // GetAccountFollowRequests returns all follow requests targeting the given account. GetAccountFollowRequests(ctx context.Context, accountID string, page *paging.Page) ([]*gtsmodel.FollowRequest, error) + // GetAccountFollowRequestIDs is like GetAccountFollowRequests, but returns just IDs. + GetAccountFollowRequestIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) + // GetAccountFollowRequesting returns all follow requests originating from the given account. GetAccountFollowRequesting(ctx context.Context, accountID string, page *paging.Page) ([]*gtsmodel.FollowRequest, error) + // GetAccountFollowRequestingIDs is like GetAccountFollowRequesting, but returns just IDs. + GetAccountFollowRequestingIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) + // GetAccountBlocks returns all blocks originating from the given account, with given optional paging parameters. GetAccountBlocks(ctx context.Context, accountID string, paging *paging.Page) ([]*gtsmodel.Block, error) - // CountAccountFollows returns the amount of accounts that the given accountID is following. - CountAccountFollows(ctx context.Context, accountID string) (int, error) - - // CountAccountLocalFollows returns the amount of accounts that the given accountID is following, only including follows from this instance. - CountAccountLocalFollows(ctx context.Context, accountID string) (int, error) - - // CountAccountFollowers returns the amounts that the given ID is followed by. - CountAccountFollowers(ctx context.Context, accountID string) (int, error) - - // CountAccountLocalFollowers returns the amounts that the given ID is followed by, only including follows from this instance. - CountAccountLocalFollowers(ctx context.Context, accountID string) (int, error) - - // CountAccountFollowRequests returns number of follow requests targeting the given account. - CountAccountFollowRequests(ctx context.Context, accountID string) (int, error) - - // CountAccountFollowerRequests returns number of follow requests originating from the given account. - CountAccountFollowRequesting(ctx context.Context, accountID string) (int, error) - - // CountAccountBlocks ... - CountAccountBlocks(ctx context.Context, accountID string) (int, error) + // GetAccountBlockIDs is like GetAccountBlocks, but returns just IDs. + GetAccountBlockIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) // GetNote gets a private note from a source account on a target account, if it exists. GetNote(ctx context.Context, sourceAccountID string, targetAccountID string) (*gtsmodel.AccountNote, error) diff --git a/internal/federation/dereferencing/account.go b/internal/federation/dereferencing/account.go index 305b3f05c..e8d32f58a 100644 --- a/internal/federation/dereferencing/account.go +++ b/internal/federation/dereferencing/account.go @@ -695,7 +695,7 @@ func (d *Dereferencer) enrichAccount( representation of the target account, derived from a combination of webfinger lookups and dereferencing. Further fetching beyond this point is for peripheral - things like account avatar, header, emojis. + things like account avatar, header, emojis, stats. */ // Ensure internal db ID is @@ -718,6 +718,11 @@ func (d *Dereferencer) enrichAccount( log.Errorf(ctx, "error fetching remote emojis for account %s: %v", uri, err) } + // Fetch followers/following count for this account. + if err := d.fetchRemoteAccountStats(ctx, latestAcc, requestUser); err != nil { + log.Errorf(ctx, "error fetching remote stats for account %s: %v", uri, err) + } + if account.IsNew() { // Prefer published/created time from // apubAcc, fall back to FetchedAt value. @@ -1036,6 +1041,113 @@ func (d *Dereferencer) fetchRemoteAccountEmojis(ctx context.Context, targetAccou return changed, nil } +func (d *Dereferencer) fetchRemoteAccountStats(ctx context.Context, account *gtsmodel.Account, requestUser string) error { + // Ensure we have a stats model for this account. + if account.Stats == nil { + if err := d.state.DB.PopulateAccountStats(ctx, account); err != nil { + return gtserror.Newf("db error getting account stats: %w", err) + } + } + + // We want to update stats by getting remote + // followers/following/statuses counts for + // this account. + // + // If we fail getting any particular stat, + // it will just fall back to counting local. + + // Followers first. + if count, err := d.countCollection( + ctx, + account.FollowersURI, + requestUser, + ); err != nil { + // Log this but don't bail. + log.Warnf(ctx, + "couldn't count followers for @%s@%s: %v", + account.Username, account.Domain, err, + ) + } else if count > 0 { + // Positive integer is useful! + account.Stats.FollowersCount = &count + } + + // Now following. + if count, err := d.countCollection( + ctx, + account.FollowingURI, + requestUser, + ); err != nil { + // Log this but don't bail. + log.Warnf(ctx, + "couldn't count following for @%s@%s: %v", + account.Username, account.Domain, err, + ) + } else if count > 0 { + // Positive integer is useful! + account.Stats.FollowingCount = &count + } + + // Now statuses count. + if count, err := d.countCollection( + ctx, + account.OutboxURI, + requestUser, + ); err != nil { + // Log this but don't bail. + log.Warnf(ctx, + "couldn't count statuses for @%s@%s: %v", + account.Username, account.Domain, err, + ) + } else if count > 0 { + // Positive integer is useful! + account.Stats.StatusesCount = &count + } + + // Update stats now. + if err := d.state.DB.UpdateAccountStats( + ctx, + account.Stats, + "followers_count", + "following_count", + "statuses_count", + ); err != nil { + return gtserror.Newf("db error updating account stats: %w", err) + } + + return nil +} + +// countCollection parses the given uriStr, +// dereferences the result as a collection +// type, and returns total items as 0, or +// a positive integer, or -1 if total items +// cannot be counted. +// +// Error will be returned for invalid non-empty +// URIs or dereferencing isses. +func (d *Dereferencer) countCollection( + ctx context.Context, + uriStr string, + requestUser string, +) (int, error) { + if uriStr == "" { + return -1, nil + } + + uri, err := url.Parse(uriStr) + if err != nil { + return -1, err + } + + collect, err := d.dereferenceCollection(ctx, requestUser, uri) + if err != nil { + return -1, err + } + + return collect.TotalItems(), nil +} + // dereferenceAccountFeatured dereferences an account's featuredCollectionURI (if not empty). For each discovered status, this status will // be dereferenced (if necessary) and marked as pinned (if necessary). Then, old pins will be removed if they're not included in new pins. func (d *Dereferencer) dereferenceAccountFeatured(ctx context.Context, requestUser string, account *gtsmodel.Account) error { diff --git a/internal/federation/dereferencing/collection.go b/internal/federation/dereferencing/collection.go index 07f56c952..1a9f1555b 100644 --- a/internal/federation/dereferencing/collection.go +++ b/internal/federation/dereferencing/collection.go @@ -40,7 +40,7 @@ func (d *Dereferencer) dereferenceCollection(ctx context.Context, username strin rsp, err := transport.Dereference(ctx, pageIRI) if err != nil { - return nil, gtserror.Newf("error deferencing %s: %w", pageIRI.String(), err) + return nil, gtserror.Newf("error dereferencing %s: %w", pageIRI.String(), err) } collect, err := ap.ResolveCollection(ctx, rsp.Body) diff --git a/internal/federation/federatingdb/accept.go b/internal/federation/federatingdb/accept.go index 7ec9346e0..50a7c2db1 100644 --- a/internal/federation/federatingdb/accept.go +++ b/internal/federation/federatingdb/accept.go @@ -89,11 +89,13 @@ func (f *federatingDB) Accept(ctx context.Context, accept vocab.ActivityStreamsA return err } + // Process side effects asynchronously. f.state.Workers.EnqueueFediAPI(ctx, messages.FromFediAPI{ - APObjectType: ap.ActivityFollow, - APActivityType: ap.ActivityAccept, - GTSModel: follow, - ReceivingAccount: receivingAcct, + APObjectType: ap.ActivityFollow, + APActivityType: ap.ActivityAccept, + GTSModel: follow, + ReceivingAccount: receivingAcct, + RequestingAccount: requestingAcct, }) } @@ -136,11 +138,13 @@ func (f *federatingDB) Accept(ctx context.Context, accept vocab.ActivityStreamsA return err } + // Process side effects asynchronously. f.state.Workers.EnqueueFediAPI(ctx, messages.FromFediAPI{ - APObjectType: ap.ActivityFollow, - APActivityType: ap.ActivityAccept, - GTSModel: follow, - ReceivingAccount: receivingAcct, + APObjectType: ap.ActivityFollow, + APActivityType: ap.ActivityAccept, + GTSModel: follow, + ReceivingAccount: receivingAcct, + RequestingAccount: requestingAcct, }) continue diff --git a/internal/federation/federatingdb/announce.go b/internal/federation/federatingdb/announce.go index e13e212da..2f5950a30 100644 --- a/internal/federation/federatingdb/announce.go +++ b/internal/federation/federatingdb/announce.go @@ -82,10 +82,11 @@ func (f *federatingDB) Announce(ctx context.Context, announce vocab.ActivityStre // This is a new boost. Process side effects asynchronously. f.state.Workers.EnqueueFediAPI(ctx, messages.FromFediAPI{ - APObjectType: ap.ActivityAnnounce, - APActivityType: ap.ActivityCreate, - GTSModel: boost, - ReceivingAccount: receivingAcct, + APObjectType: ap.ActivityAnnounce, + APActivityType: ap.ActivityCreate, + GTSModel: boost, + ReceivingAccount: receivingAcct, + RequestingAccount: requestingAcct, }) return nil diff --git a/internal/federation/federatingdb/create.go b/internal/federation/federatingdb/create.go index cacaf07cf..94261526e 100644 --- a/internal/federation/federatingdb/create.go +++ b/internal/federation/federatingdb/create.go @@ -131,10 +131,11 @@ func (f *federatingDB) activityBlock(ctx context.Context, asType vocab.Type, rec } f.state.Workers.EnqueueFediAPI(ctx, messages.FromFediAPI{ - APObjectType: ap.ActivityBlock, - APActivityType: ap.ActivityCreate, - GTSModel: block, - ReceivingAccount: receiving, + APObjectType: ap.ActivityBlock, + APActivityType: ap.ActivityCreate, + GTSModel: block, + ReceivingAccount: receiving, + RequestingAccount: requestingAccount, }) return nil @@ -307,7 +308,8 @@ func (f *federatingDB) createPollOptionables( PollID: inReplyTo.PollID, Poll: inReplyTo.Poll, }, - ReceivingAccount: receiver, + ReceivingAccount: receiver, + RequestingAccount: requester, }) return nil @@ -376,12 +378,13 @@ func (f *federatingDB) createStatusable( // Pass the statusable URI (APIri) into the processor // worker and do the rest of the processing asynchronously. f.state.Workers.EnqueueFediAPI(ctx, messages.FromFediAPI{ - APObjectType: ap.ObjectNote, - APActivityType: ap.ActivityCreate, - APIri: ap.GetJSONLDId(statusable), - APObjectModel: nil, - GTSModel: nil, - ReceivingAccount: receiver, + APObjectType: ap.ObjectNote, + APActivityType: ap.ActivityCreate, + APIri: ap.GetJSONLDId(statusable), + APObjectModel: nil, + GTSModel: nil, + ReceivingAccount: receiver, + RequestingAccount: requester, }) return nil } @@ -389,12 +392,13 @@ func (f *federatingDB) createStatusable( // Do the rest of the processing asynchronously. The processor // will handle inserting/updating + further dereferencing the status. f.state.Workers.EnqueueFediAPI(ctx, messages.FromFediAPI{ - APObjectType: ap.ObjectNote, - APActivityType: ap.ActivityCreate, - APIri: nil, - GTSModel: nil, - APObjectModel: statusable, - ReceivingAccount: receiver, + APObjectType: ap.ObjectNote, + APActivityType: ap.ActivityCreate, + APIri: nil, + GTSModel: nil, + APObjectModel: statusable, + ReceivingAccount: receiver, + RequestingAccount: requester, }) return nil @@ -436,10 +440,11 @@ func (f *federatingDB) activityFollow(ctx context.Context, asType vocab.Type, re } f.state.Workers.EnqueueFediAPI(ctx, messages.FromFediAPI{ - APObjectType: ap.ActivityFollow, - APActivityType: ap.ActivityCreate, - GTSModel: followRequest, - ReceivingAccount: receivingAccount, + APObjectType: ap.ActivityFollow, + APActivityType: ap.ActivityCreate, + GTSModel: followRequest, + ReceivingAccount: receivingAccount, + RequestingAccount: requestingAccount, }) return nil @@ -480,10 +485,11 @@ func (f *federatingDB) activityLike(ctx context.Context, asType vocab.Type, rece } f.state.Workers.EnqueueFediAPI(ctx, messages.FromFediAPI{ - APObjectType: ap.ActivityLike, - APActivityType: ap.ActivityCreate, - GTSModel: fave, - ReceivingAccount: receivingAccount, + APObjectType: ap.ActivityLike, + APActivityType: ap.ActivityCreate, + GTSModel: fave, + ReceivingAccount: receivingAccount, + RequestingAccount: requestingAccount, }) return nil @@ -531,10 +537,11 @@ func (f *federatingDB) activityFlag(ctx context.Context, asType vocab.Type, rece } f.state.Workers.EnqueueFediAPI(ctx, messages.FromFediAPI{ - APObjectType: ap.ActivityFlag, - APActivityType: ap.ActivityCreate, - GTSModel: report, - ReceivingAccount: receivingAccount, + APObjectType: ap.ActivityFlag, + APActivityType: ap.ActivityCreate, + GTSModel: report, + ReceivingAccount: receivingAccount, + RequestingAccount: requestingAccount, }) return nil diff --git a/internal/federation/federatingdb/delete.go b/internal/federation/federatingdb/delete.go index 384291463..14bc20209 100644 --- a/internal/federation/federatingdb/delete.go +++ b/internal/federation/federatingdb/delete.go @@ -63,10 +63,11 @@ func (f *federatingDB) Delete(ctx context.Context, id *url.URL) error { if a, err := f.state.DB.GetAccountByURI(ctx, id.String()); err == nil && requestingAcct.ID == a.ID { l.Debugf("deleting account: %s", a.ID) f.state.Workers.EnqueueFediAPI(ctx, messages.FromFediAPI{ - APObjectType: ap.ObjectProfile, - APActivityType: ap.ActivityDelete, - GTSModel: a, - ReceivingAccount: receivingAcct, + APObjectType: ap.ObjectProfile, + APActivityType: ap.ActivityDelete, + GTSModel: a, + ReceivingAccount: receivingAcct, + RequestingAccount: requestingAcct, }) } diff --git a/internal/federation/federatingdb/update.go b/internal/federation/federatingdb/update.go index bd8ad3106..733abba0d 100644 --- a/internal/federation/federatingdb/update.go +++ b/internal/federation/federatingdb/update.go @@ -99,11 +99,12 @@ func (f *federatingDB) updateAccountable(ctx context.Context, receivingAcct *gts // updating of eg., avatar/header, emojis, etc. The actual db // inserts/updates will take place there. f.state.Workers.EnqueueFediAPI(ctx, messages.FromFediAPI{ - APObjectType: ap.ObjectProfile, - APActivityType: ap.ActivityUpdate, - GTSModel: requestingAcct, - APObjectModel: accountable, - ReceivingAccount: receivingAcct, + APObjectType: ap.ObjectProfile, + APActivityType: ap.ActivityUpdate, + GTSModel: requestingAcct, + APObjectModel: accountable, + ReceivingAccount: receivingAcct, + RequestingAccount: requestingAcct, }) return nil @@ -155,11 +156,12 @@ func (f *federatingDB) updateStatusable(ctx context.Context, receivingAcct *gtsm // Queue an UPDATE NOTE activity to our fedi API worker, // this will handle necessary database insertions, etc. f.state.Workers.EnqueueFediAPI(ctx, messages.FromFediAPI{ - APObjectType: ap.ObjectNote, - APActivityType: ap.ActivityUpdate, - GTSModel: status, // original status - APObjectModel: (ap.Statusable)(statusable), - ReceivingAccount: receivingAcct, + APObjectType: ap.ObjectNote, + APActivityType: ap.ActivityUpdate, + GTSModel: status, // original status + APObjectModel: (ap.Statusable)(statusable), + ReceivingAccount: receivingAcct, + RequestingAccount: requestingAcct, }) return nil diff --git a/internal/gtsmodel/account.go b/internal/gtsmodel/account.go index 3bbcb37e3..79a35e561 100644 --- a/internal/gtsmodel/account.go +++ b/internal/gtsmodel/account.go @@ -80,6 +80,7 @@ type Account struct { SuspendedAt time.Time `bun:"type:timestamptz,nullzero"` // When was this account suspended (eg., don't allow it to log in/post, don't accept media/posts from this account) SuspensionOrigin string `bun:"type:CHAR(26),nullzero"` // id of the database entry that caused this account to become suspended -- can be an account ID or a domain block ID Settings *AccountSettings `bun:"-"` // gtsmodel.AccountSettings for this account. + Stats *AccountStats `bun:"-"` // gtsmodel.AccountStats for this account. } // IsLocal returns whether account is a local user account. diff --git a/internal/gtsmodel/accountstats.go b/internal/gtsmodel/accountstats.go new file mode 100644 index 000000000..92b50d5e3 --- /dev/null +++ b/internal/gtsmodel/accountstats.go @@ -0,0 +1,33 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package gtsmodel + +import "time" + +// AccountStats models statistics +// for a remote or local account. +type AccountStats struct { + AccountID string `bun:"type:CHAR(26),pk,nullzero,notnull,unique"` // AccountID of this AccountStats. + RegeneratedAt time.Time `bun:"type:timestamptz,nullzero"` // Time this stats model was last regenerated (ie., created from scratch using COUNTs). + FollowersCount *int `bun:",nullzero,notnull"` // Number of accounts following AccountID. + FollowingCount *int `bun:",nullzero,notnull"` // Number of accounts followed by AccountID. + FollowRequestsCount *int `bun:",nullzero,notnull"` // Number of pending follow requests aimed at AccountID. + StatusesCount *int `bun:",nullzero,notnull"` // Number of statuses created by AccountID. + StatusesPinnedCount *int `bun:",nullzero,notnull"` // Number of statuses pinned by AccountID. + LastStatusAt time.Time `bun:"type:timestamptz,nullzero"` // Time of most recent status created by AccountID. +} diff --git a/internal/processing/account/delete.go b/internal/processing/account/delete.go index 858e42d36..a900c566d 100644 --- a/internal/processing/account/delete.go +++ b/internal/processing/account/delete.go @@ -485,6 +485,11 @@ func (p *Processor) deleteAccountPeripheral(ctx context.Context, account *gtsmod return gtserror.Newf("error deleting poll votes by account: %w", err) } + // Delete account stats model. + if err := p.state.DB.DeleteAccountStats(ctx, account.ID); err != nil { + return gtserror.Newf("error deleting stats for account: %w", err) + } + return nil } diff --git a/internal/processing/account/move.go b/internal/processing/account/move.go index a68c8f750..602e8c021 100644 --- a/internal/processing/account/move.go +++ b/internal/processing/account/move.go @@ -113,7 +113,7 @@ func (p *Processor) MoveSelf( // in quick succession, so get a lock on // this account. lockKey := originAcct.URI - unlock := p.state.ClientLocks.Lock(lockKey) + unlock := p.state.AccountLocks.Lock(lockKey) defer unlock() // Ensure we have a valid, up-to-date representation of the target account. diff --git a/internal/processing/account/rss.go b/internal/processing/account/rss.go index f2c6cba5e..cbbb4875b 100644 --- a/internal/processing/account/rss.go +++ b/internal/processing/account/rss.go @@ -69,14 +69,18 @@ func (p *Processor) GetRSSFeedForUsername(ctx context.Context, username string) return nil, never, gtserror.NewErrorNotFound(err) } + // Ensure account stats populated. + if account.Stats == nil { + if err := p.state.DB.PopulateAccountStats(ctx, account); err != nil { + err = gtserror.Newf("db error getting account stats %s: %w", username, err) + return nil, never, gtserror.NewErrorInternalError(err) + } + } + // LastModified time is needed by callers to check freshness for cacheing. // This might be a zero time.Time if account has never posted a status that's // eligible to appear in the RSS feed; that's fine. - lastPostAt, err := p.state.DB.GetAccountLastPosted(ctx, account.ID, true) - if err != nil && !errors.Is(err, db.ErrNoEntries) { - err = gtserror.Newf("db error getting account %s last posted: %w", username, err) - return nil, never, gtserror.NewErrorInternalError(err) - } + lastPostAt := account.Stats.LastStatusAt return func() (string, gtserror.WithCode) { // Assemble author namestring once only. diff --git a/internal/processing/account/rss_test.go b/internal/processing/account/rss_test.go index 6ae285f9e..c08ef8874 100644 --- a/internal/processing/account/rss_test.go +++ b/internal/processing/account/rss_test.go @@ -19,7 +19,6 @@ package account_test import ( "context" - "fmt" "testing" "github.com/stretchr/testify/suite" @@ -32,14 +31,11 @@ type GetRSSTestSuite struct { func (suite *GetRSSTestSuite) TestGetAccountRSSAdmin() { getFeed, lastModified, err := suite.accountProcessor.GetRSSFeedForUsername(context.Background(), "admin") suite.NoError(err) - suite.EqualValues(1634733405, lastModified.Unix()) + suite.EqualValues(1634726497, lastModified.Unix()) feed, err := getFeed() suite.NoError(err) - - fmt.Println(feed) - - suite.Equal("\n \n Posts from @admin@localhost:8080\n http://localhost:8080/@admin\n Posts from @admin@localhost:8080\n Wed, 20 Oct 2021 12:36:45 +0000\n Wed, 20 Oct 2021 12:36:45 +0000\n \n open to see some puppies\n http://localhost:8080/@admin/statuses/01F8MHAAY43M6RJ473VQFCVH37\n @admin@localhost:8080 made a new post: "🐕🐕🐕🐕🐕"\n \n @admin@localhost:8080\n http://localhost:8080/@admin/statuses/01F8MHAAY43M6RJ473VQFCVH37\n Wed, 20 Oct 2021 12:36:45 +0000\n http://localhost:8080/@admin/feed.rss\n \n \n hello world! #welcome ! first post on the instance :rainbow: !\n http://localhost:8080/@admin/statuses/01F8MH75CBF9JFX4ZAD54N0W0R\n @admin@localhost:8080 posted 1 attachment: "hello world! #welcome ! first post on the instance :rainbow: !"\n !]]>\n @admin@localhost:8080\n \n http://localhost:8080/@admin/statuses/01F8MH75CBF9JFX4ZAD54N0W0R\n Wed, 20 Oct 2021 11:36:45 +0000\n http://localhost:8080/@admin/feed.rss\n \n \n", feed) + suite.Equal("\n \n Posts from @admin@localhost:8080\n http://localhost:8080/@admin\n Posts from @admin@localhost:8080\n Wed, 20 Oct 2021 10:41:37 +0000\n Wed, 20 Oct 2021 10:41:37 +0000\n \n open to see some puppies\n http://localhost:8080/@admin/statuses/01F8MHAAY43M6RJ473VQFCVH37\n @admin@localhost:8080 made a new post: "🐕🐕🐕🐕🐕"\n \n @admin@localhost:8080\n http://localhost:8080/@admin/statuses/01F8MHAAY43M6RJ473VQFCVH37\n Wed, 20 Oct 2021 12:36:45 +0000\n http://localhost:8080/@admin/feed.rss\n \n \n hello world! #welcome ! first post on the instance :rainbow: !\n http://localhost:8080/@admin/statuses/01F8MH75CBF9JFX4ZAD54N0W0R\n @admin@localhost:8080 posted 1 attachment: "hello world! #welcome ! first post on the instance :rainbow: !"\n !]]>\n @admin@localhost:8080\n \n http://localhost:8080/@admin/statuses/01F8MH75CBF9JFX4ZAD54N0W0R\n Wed, 20 Oct 2021 11:36:45 +0000\n http://localhost:8080/@admin/feed.rss\n \n \n", feed) } func (suite *GetRSSTestSuite) TestGetAccountRSSZork() { @@ -49,9 +45,6 @@ func (suite *GetRSSTestSuite) TestGetAccountRSSZork() { feed, err := getFeed() suite.NoError(err) - - fmt.Println(feed) - suite.Equal("\n \n Posts from @the_mighty_zork@localhost:8080\n http://localhost:8080/@the_mighty_zork\n Posts from @the_mighty_zork@localhost:8080\n Sun, 10 Dec 2023 09:24:00 +0000\n Sun, 10 Dec 2023 09:24:00 +0000\n \n http://localhost:8080/fileserver/01F8MH1H7YV1Z7D2C8K2730QBF/avatar/small/01F8MH58A357CV5K7R7TJMSH6S.jpg\n Avatar for @the_mighty_zork@localhost:8080\n http://localhost:8080/@the_mighty_zork\n \n \n HTML in post\n http://localhost:8080/@the_mighty_zork/statuses/01HH9KYNQPA416TNJ53NSATP40\n @the_mighty_zork@localhost:8080 made a new post: "Here's a bunch of HTML, read it and weep, weep then! ```html <section class="about-user"> <div class="col-header"> <h2>About</h2> </div> <div class="fields"> <h3 class="sr-only">Fields</h3> <dl> ...\n Here's a bunch of HTML, read it and weep, weep then!

<section class="about-user">\n    <div class="col-header">\n        <h2>About</h2>\n    </div>            \n    <div class="fields">\n        <h3 class="sr-only">Fields</h3>\n        <dl>\n            <div class="field">\n                <dt>should you follow me?</dt>\n                <dd>maybe!</dd>\n            </div>\n            <div class="field">\n                <dt>age</dt>\n                <dd>120</dd>\n            </div>\n        </dl>\n    </div>\n    <div class="bio">\n        <h3 class="sr-only">Bio</h3>\n        <p>i post about things that concern me</p>\n    </div>\n    <div class="sr-only" role="group">\n        <h3 class="sr-only">Stats</h3>\n        <span>Joined in Jun, 2022.</span>\n        <span>8 posts.</span>\n        <span>Followed by 1.</span>\n        <span>Following 1.</span>\n    </div>\n    <div class="accountstats" aria-hidden="true">\n        <b>Joined</b><time datetime="2022-06-04T13:12:00.000Z">Jun, 2022</time>\n        <b>Posts</b><span>8</span>\n        <b>Followed by</b><span>1</span>\n        <b>Following</b><span>1</span>\n    </div>\n</section>\n

There, hope you liked that!

]]>
\n @the_mighty_zork@localhost:8080\n http://localhost:8080/@the_mighty_zork/statuses/01HH9KYNQPA416TNJ53NSATP40\n Sun, 10 Dec 2023 09:24:00 +0000\n http://localhost:8080/@the_mighty_zork/feed.rss\n
\n \n introduction post\n http://localhost:8080/@the_mighty_zork/statuses/01F8MHAMCHF6Y650WCRSCP4WMY\n @the_mighty_zork@localhost:8080 made a new post: "hello everyone!"\n \n @the_mighty_zork@localhost:8080\n http://localhost:8080/@the_mighty_zork/statuses/01F8MHAMCHF6Y650WCRSCP4WMY\n Wed, 20 Oct 2021 10:40:37 +0000\n http://localhost:8080/@the_mighty_zork/feed.rss\n \n
\n
", feed) } @@ -77,9 +70,6 @@ func (suite *GetRSSTestSuite) TestGetAccountRSSZorkNoPosts() { feed, err := getFeed() suite.NoError(err) - - fmt.Println(feed) - suite.Equal("\n \n Posts from @the_mighty_zork@localhost:8080\n http://localhost:8080/@the_mighty_zork\n Posts from @the_mighty_zork@localhost:8080\n Fri, 20 May 2022 11:09:18 +0000\n Fri, 20 May 2022 11:09:18 +0000\n \n http://localhost:8080/fileserver/01F8MH1H7YV1Z7D2C8K2730QBF/avatar/small/01F8MH58A357CV5K7R7TJMSH6S.jpg\n Avatar for @the_mighty_zork@localhost:8080\n http://localhost:8080/@the_mighty_zork\n \n \n", feed) } diff --git a/internal/processing/admin/accountapprove.go b/internal/processing/admin/accountapprove.go index e34cb18e3..c8a49e089 100644 --- a/internal/processing/admin/accountapprove.go +++ b/internal/processing/admin/accountapprove.go @@ -49,7 +49,7 @@ func (p *Processor) AccountApprove( // Get a lock on the account URI, // to ensure it's not also being // rejected at the same time! - unlock := p.state.ClientLocks.Lock(user.Account.URI) + unlock := p.state.AccountLocks.Lock(user.Account.URI) defer unlock() if !*user.Approved { diff --git a/internal/processing/admin/accountreject.go b/internal/processing/admin/accountreject.go index bc7a1c20a..eee2b2ff5 100644 --- a/internal/processing/admin/accountreject.go +++ b/internal/processing/admin/accountreject.go @@ -52,7 +52,7 @@ func (p *Processor) AccountReject( // Get a lock on the account URI, // since we're going to be deleting // it and its associated user. - unlock := p.state.ClientLocks.Lock(user.Account.URI) + unlock := p.state.AccountLocks.Lock(user.Account.URI) defer unlock() // Can't reject an account with a diff --git a/internal/processing/fedi/collections.go b/internal/processing/fedi/collections.go index 0eacf45da..7a6c99adb 100644 --- a/internal/processing/fedi/collections.go +++ b/internal/processing/fedi/collections.go @@ -126,11 +126,12 @@ func (p *Processor) FollowersGet(ctx context.Context, requestedUser string, page return nil, gtserror.NewErrorInternalError(err) } - // Calculate total number of followers available for account. - total, err := p.state.DB.CountAccountFollowers(ctx, receiver.ID) - if err != nil { - err := gtserror.Newf("error counting followers: %w", err) - return nil, gtserror.NewErrorInternalError(err) + // Ensure we have stats for this account. + if receiver.Stats == nil { + if err := p.state.DB.PopulateAccountStats(ctx, receiver); err != nil { + err := gtserror.Newf("error getting stats for account %s: %w", receiver.ID, err) + return nil, gtserror.NewErrorInternalError(err) + } } var obj vocab.Type @@ -138,7 +139,7 @@ func (p *Processor) FollowersGet(ctx context.Context, requestedUser string, page // Start the AS collection params. var params ap.CollectionParams params.ID = collectionID - params.Total = total + params.Total = *receiver.Stats.FollowersCount switch { @@ -235,11 +236,12 @@ func (p *Processor) FollowingGet(ctx context.Context, requestedUser string, page return nil, gtserror.NewErrorInternalError(err) } - // Calculate total number of following available for account. - total, err := p.state.DB.CountAccountFollows(ctx, receiver.ID) - if err != nil { - err := gtserror.Newf("error counting follows: %w", err) - return nil, gtserror.NewErrorInternalError(err) + // Ensure we have stats for this account. + if receiver.Stats == nil { + if err := p.state.DB.PopulateAccountStats(ctx, receiver); err != nil { + err := gtserror.Newf("error getting stats for account %s: %w", receiver.ID, err) + return nil, gtserror.NewErrorInternalError(err) + } } var obj vocab.Type @@ -247,7 +249,7 @@ func (p *Processor) FollowingGet(ctx context.Context, requestedUser string, page // Start AS collection params. var params ap.CollectionParams params.ID = collectionID - params.Total = total + params.Total = *receiver.Stats.FollowingCount switch { case receiver.IsInstance() || diff --git a/internal/processing/status/pin.go b/internal/processing/status/pin.go index 9a4a4b266..d0688331b 100644 --- a/internal/processing/status/pin.go +++ b/internal/processing/status/pin.go @@ -82,18 +82,26 @@ func (p *Processor) PinCreate(ctx context.Context, requestingAccount *gtsmodel.A return nil, errWithCode } + // Get a lock on this account. + unlock := p.state.AccountLocks.Lock(requestingAccount.URI) + defer unlock() + if !targetStatus.PinnedAt.IsZero() { err := errors.New("status already pinned") return nil, gtserror.NewErrorUnprocessableEntity(err, err.Error()) } - pinnedCount, err := p.state.DB.CountAccountPinned(ctx, requestingAccount.ID) - if err != nil { - return nil, gtserror.NewErrorInternalError(fmt.Errorf("error checking number of pinned statuses: %w", err)) + // Ensure account stats populated. + if requestingAccount.Stats == nil { + if err := p.state.DB.PopulateAccountStats(ctx, requestingAccount); err != nil { + err = gtserror.Newf("db error getting account stats: %w", err) + return nil, gtserror.NewErrorInternalError(err) + } } + pinnedCount := *requestingAccount.Stats.StatusesPinnedCount if pinnedCount >= allowedPinnedCount { - err = fmt.Errorf("status pin limit exceeded, you've already pinned %d status(es) out of %d", pinnedCount, allowedPinnedCount) + err := fmt.Errorf("status pin limit exceeded, you've already pinned %d status(es) out of %d", pinnedCount, allowedPinnedCount) return nil, gtserror.NewErrorUnprocessableEntity(err, err.Error()) } @@ -103,6 +111,17 @@ func (p *Processor) PinCreate(ctx context.Context, requestingAccount *gtsmodel.A return nil, gtserror.NewErrorInternalError(err) } + // Update account stats. + *requestingAccount.Stats.StatusesPinnedCount++ + if err := p.state.DB.UpdateAccountStats( + ctx, + requestingAccount.Stats, + "statuses_pinned_count", + ); err != nil { + err = gtserror.Newf("db error updating stats: %w", err) + return nil, gtserror.NewErrorInternalError(err) + } + if err := p.c.InvalidateTimelinedStatus(ctx, requestingAccount.ID, targetStatusID); err != nil { err = gtserror.Newf("error invalidating status from timelines: %w", err) return nil, gtserror.NewErrorInternalError(err) @@ -128,16 +147,45 @@ func (p *Processor) PinRemove(ctx context.Context, requestingAccount *gtsmodel.A return nil, errWithCode } + // Get a lock on this account. + unlock := p.state.AccountLocks.Lock(requestingAccount.URI) + defer unlock() + if targetStatus.PinnedAt.IsZero() { + // Status already not pinned. return p.c.GetAPIStatus(ctx, requestingAccount, targetStatus) } + // Ensure account stats populated. + if requestingAccount.Stats == nil { + if err := p.state.DB.PopulateAccountStats(ctx, requestingAccount); err != nil { + err = gtserror.Newf("db error getting account stats: %w", err) + return nil, gtserror.NewErrorInternalError(err) + } + } + targetStatus.PinnedAt = time.Time{} if err := p.state.DB.UpdateStatus(ctx, targetStatus, "pinned_at"); err != nil { err = gtserror.Newf("db error unpinning status: %w", err) return nil, gtserror.NewErrorInternalError(err) } + // Update account stats. + // + // Clamp to 0 to avoid funny business. + *requestingAccount.Stats.StatusesPinnedCount-- + if *requestingAccount.Stats.StatusesPinnedCount < 0 { + *requestingAccount.Stats.StatusesPinnedCount = 0 + } + if err := p.state.DB.UpdateAccountStats( + ctx, + requestingAccount.Stats, + "statuses_pinned_count", + ); err != nil { + err = gtserror.Newf("db error updating stats: %w", err) + return nil, gtserror.NewErrorInternalError(err) + } + if err := p.c.InvalidateTimelinedStatus(ctx, requestingAccount.ID, targetStatusID); err != nil { err = gtserror.Newf("error invalidating status from timelines: %w", err) return nil, gtserror.NewErrorInternalError(err) diff --git a/internal/processing/workers/fromclientapi.go b/internal/processing/workers/fromclientapi.go index 37c330cf0..1412ea003 100644 --- a/internal/processing/workers/fromclientapi.go +++ b/internal/processing/workers/fromclientapi.go @@ -247,6 +247,11 @@ func (p *clientAPI) CreateStatus(ctx context.Context, cMsg messages.FromClientAP return gtserror.Newf("%T not parseable as *gtsmodel.Status", cMsg.GTSModel) } + // Update stats for the actor account. + if err := p.utilF.incrementStatusesCount(ctx, cMsg.OriginAccount, status); err != nil { + log.Errorf(ctx, "error updating account stats: %v", err) + } + if err := p.surface.timelineAndNotifyStatus(ctx, status); err != nil { log.Errorf(ctx, "error timelining and notifying status: %v", err) } @@ -311,6 +316,11 @@ func (p *clientAPI) CreateFollowReq(ctx context.Context, cMsg messages.FromClien return gtserror.Newf("%T not parseable as *gtsmodel.FollowRequest", cMsg.GTSModel) } + // Update stats for the target account. + if err := p.utilF.incrementFollowRequestsCount(ctx, cMsg.TargetAccount); err != nil { + log.Errorf(ctx, "error updating account stats: %v", err) + } + if err := p.surface.notifyFollowRequest(ctx, followRequest); err != nil { log.Errorf(ctx, "error notifying follow request: %v", err) } @@ -360,6 +370,11 @@ func (p *clientAPI) CreateAnnounce(ctx context.Context, cMsg messages.FromClient return gtserror.Newf("%T not parseable as *gtsmodel.Status", cMsg.GTSModel) } + // Update stats for the actor account. + if err := p.utilF.incrementStatusesCount(ctx, cMsg.OriginAccount, boost); err != nil { + log.Errorf(ctx, "error updating account stats: %v", err) + } + // Timeline and notify the boost wrapper status. if err := p.surface.timelineAndNotifyStatus(ctx, boost); err != nil { log.Errorf(ctx, "error timelining and notifying status: %v", err) @@ -485,6 +500,20 @@ func (p *clientAPI) AcceptFollow(ctx context.Context, cMsg messages.FromClientAP return gtserror.Newf("%T not parseable as *gtsmodel.Follow", cMsg.GTSModel) } + // Update stats for the target account. + if err := p.utilF.decrementFollowRequestsCount(ctx, cMsg.TargetAccount); err != nil { + log.Errorf(ctx, "error updating account stats: %v", err) + } + + if err := p.utilF.incrementFollowersCount(ctx, cMsg.TargetAccount); err != nil { + log.Errorf(ctx, "error updating account stats: %v", err) + } + + // Update stats for the origin account. + if err := p.utilF.incrementFollowingCount(ctx, cMsg.OriginAccount); err != nil { + log.Errorf(ctx, "error updating account stats: %v", err) + } + if err := p.surface.notifyFollow(ctx, follow); err != nil { log.Errorf(ctx, "error notifying follow: %v", err) } @@ -502,6 +531,11 @@ func (p *clientAPI) RejectFollowRequest(ctx context.Context, cMsg messages.FromC return gtserror.Newf("%T not parseable as *gtsmodel.FollowRequest", cMsg.GTSModel) } + // Update stats for the target account. + if err := p.utilF.decrementFollowRequestsCount(ctx, cMsg.TargetAccount); err != nil { + log.Errorf(ctx, "error updating account stats: %v", err) + } + if err := p.federate.RejectFollow( ctx, p.converter.FollowRequestToFollow(ctx, followReq), @@ -518,6 +552,16 @@ func (p *clientAPI) UndoFollow(ctx context.Context, cMsg messages.FromClientAPI) return gtserror.Newf("%T not parseable as *gtsmodel.Follow", cMsg.GTSModel) } + // Update stats for the origin account. + if err := p.utilF.decrementFollowingCount(ctx, cMsg.OriginAccount); err != nil { + log.Errorf(ctx, "error updating account stats: %v", err) + } + + // Update stats for the target account. + if err := p.utilF.decrementFollowersCount(ctx, cMsg.TargetAccount); err != nil { + log.Errorf(ctx, "error updating account stats: %v", err) + } + if err := p.federate.UndoFollow(ctx, follow); err != nil { log.Errorf(ctx, "error federating follow undo: %v", err) } @@ -565,6 +609,11 @@ func (p *clientAPI) UndoAnnounce(ctx context.Context, cMsg messages.FromClientAP return gtserror.Newf("db error deleting status: %w", err) } + // Update stats for the origin account. + if err := p.utilF.decrementStatusesCount(ctx, cMsg.OriginAccount); err != nil { + log.Errorf(ctx, "error updating account stats: %v", err) + } + if err := p.surface.deleteStatusFromTimelines(ctx, status.ID); err != nil { log.Errorf(ctx, "error removing timelined status: %v", err) } @@ -603,6 +652,11 @@ func (p *clientAPI) DeleteStatus(ctx context.Context, cMsg messages.FromClientAP log.Errorf(ctx, "error wiping status: %v", err) } + // Update stats for the origin account. + if err := p.utilF.decrementStatusesCount(ctx, cMsg.OriginAccount); err != nil { + log.Errorf(ctx, "error updating account stats: %v", err) + } + if status.InReplyToID != "" { // Interaction counts changed on the replied status; // uncache the prepared version from all timelines. diff --git a/internal/processing/workers/fromclientapi_test.go b/internal/processing/workers/fromclientapi_test.go index 3d3630b11..5e294597d 100644 --- a/internal/processing/workers/fromclientapi_test.go +++ b/internal/processing/workers/fromclientapi_test.go @@ -182,11 +182,6 @@ func (suite *FromClientAPITestSuite) TestProcessCreateStatusWithNotification() { nil, nil, ) - statusJSON = suite.statusJSON( - ctx, - status, - receivingAccount, - ) ) // Update the follow from receiving account -> posting account so @@ -212,6 +207,12 @@ func (suite *FromClientAPITestSuite) TestProcessCreateStatusWithNotification() { suite.FailNow(err.Error()) } + statusJSON := suite.statusJSON( + ctx, + status, + receivingAccount, + ) + // Check message in home stream. suite.checkStreamed( homeStream, @@ -285,11 +286,6 @@ func (suite *FromClientAPITestSuite) TestProcessCreateStatusReply() { suite.testStatuses["local_account_2_status_1"], nil, ) - statusJSON = suite.statusJSON( - ctx, - status, - receivingAccount, - ) ) // Process the new status. @@ -305,6 +301,12 @@ func (suite *FromClientAPITestSuite) TestProcessCreateStatusReply() { suite.FailNow(err.Error()) } + statusJSON := suite.statusJSON( + ctx, + status, + receivingAccount, + ) + // Check message in home stream. suite.checkStreamed( homeStream, @@ -451,11 +453,6 @@ func (suite *FromClientAPITestSuite) TestProcessCreateStatusListRepliesPolicyLis suite.testStatuses["local_account_2_status_1"], nil, ) - statusJSON = suite.statusJSON( - ctx, - status, - receivingAccount, - ) ) // Modify replies policy of test list to show replies @@ -480,6 +477,12 @@ func (suite *FromClientAPITestSuite) TestProcessCreateStatusListRepliesPolicyLis suite.FailNow(err.Error()) } + statusJSON := suite.statusJSON( + ctx, + status, + receivingAccount, + ) + // Check message in home stream. suite.checkStreamed( homeStream, @@ -518,11 +521,6 @@ func (suite *FromClientAPITestSuite) TestProcessCreateStatusListRepliesPolicyLis suite.testStatuses["local_account_2_status_1"], nil, ) - statusJSON = suite.statusJSON( - ctx, - status, - receivingAccount, - ) ) // Modify replies policy of test list to show replies @@ -552,6 +550,12 @@ func (suite *FromClientAPITestSuite) TestProcessCreateStatusListRepliesPolicyLis suite.FailNow(err.Error()) } + statusJSON := suite.statusJSON( + ctx, + status, + receivingAccount, + ) + // Check message in home stream. suite.checkStreamed( homeStream, @@ -590,11 +594,6 @@ func (suite *FromClientAPITestSuite) TestProcessCreateStatusReplyListRepliesPoli suite.testStatuses["local_account_2_status_1"], nil, ) - statusJSON = suite.statusJSON( - ctx, - status, - receivingAccount, - ) ) // Modify replies policy of test list. @@ -619,6 +618,12 @@ func (suite *FromClientAPITestSuite) TestProcessCreateStatusReplyListRepliesPoli suite.FailNow(err.Error()) } + statusJSON := suite.statusJSON( + ctx, + status, + receivingAccount, + ) + // Check message in home stream. suite.checkStreamed( homeStream, @@ -654,11 +659,6 @@ func (suite *FromClientAPITestSuite) TestProcessCreateStatusBoost() { nil, suite.testStatuses["local_account_2_status_1"], ) - statusJSON = suite.statusJSON( - ctx, - status, - receivingAccount, - ) ) // Process the new status. @@ -674,6 +674,12 @@ func (suite *FromClientAPITestSuite) TestProcessCreateStatusBoost() { suite.FailNow(err.Error()) } + statusJSON := suite.statusJSON( + ctx, + status, + receivingAccount, + ) + // Check message in home stream. suite.checkStreamed( homeStream, diff --git a/internal/processing/workers/fromfediapi.go b/internal/processing/workers/fromfediapi.go index 7b0e72490..0b1106a9e 100644 --- a/internal/processing/workers/fromfediapi.go +++ b/internal/processing/workers/fromfediapi.go @@ -122,7 +122,7 @@ func (p *Processor) ProcessFromFediAPI(ctx context.Context, fMsg messages.FromFe // UPDATE SOMETHING case ap.ActivityUpdate: - switch fMsg.APObjectType { //nolint:gocritic + switch fMsg.APObjectType { // UPDATE NOTE/STATUS case ap.ObjectNote: @@ -133,6 +133,15 @@ func (p *Processor) ProcessFromFediAPI(ctx context.Context, fMsg messages.FromFe return p.fediAPI.UpdateAccount(ctx, fMsg) } + // ACCEPT SOMETHING + case ap.ActivityAccept: + switch fMsg.APObjectType { //nolint:gocritic + + // ACCEPT FOLLOW + case ap.ActivityFollow: + return p.fediAPI.AcceptFollow(ctx, fMsg) + } + // DELETE SOMETHING case ap.ActivityDelete: switch fMsg.APObjectType { @@ -220,6 +229,11 @@ func (p *fediAPI) CreateStatus(ctx context.Context, fMsg messages.FromFediAPI) e return nil } + // Update stats for the remote account. + if err := p.utilF.incrementStatusesCount(ctx, fMsg.RequestingAccount, status); err != nil { + log.Errorf(ctx, "error updating account stats: %v", err) + } + if status.InReplyToID != "" { // Interaction counts changed on the replied status; uncache the // prepared version from all timelines. The status dereferencer @@ -290,14 +304,20 @@ func (p *fediAPI) CreateFollowReq(ctx context.Context, fMsg messages.FromFediAPI } if *followRequest.TargetAccount.Locked { - // Account on our instance is locked: just notify the follow request. + // Local account is locked: just notify the follow request. if err := p.surface.notifyFollowRequest(ctx, followRequest); err != nil { log.Errorf(ctx, "error notifying follow request: %v", err) } + + // And update stats for the local account. + if err := p.utilF.incrementFollowRequestsCount(ctx, fMsg.ReceivingAccount); err != nil { + log.Errorf(ctx, "error updating account stats: %v", err) + } + return nil } - // Account on our instance is not locked: + // Local account is not locked: // Automatically accept the follow request // and notify about the new follower. follow, err := p.state.DB.AcceptFollowRequest( @@ -309,6 +329,16 @@ func (p *fediAPI) CreateFollowReq(ctx context.Context, fMsg messages.FromFediAPI return gtserror.Newf("error accepting follow request: %w", err) } + // Update stats for the local account. + if err := p.utilF.incrementFollowersCount(ctx, fMsg.ReceivingAccount); err != nil { + log.Errorf(ctx, "error updating account stats: %v", err) + } + + // Update stats for the remote account. + if err := p.utilF.incrementFollowingCount(ctx, fMsg.RequestingAccount); err != nil { + log.Errorf(ctx, "error updating account stats: %v", err) + } + if err := p.federate.AcceptFollow(ctx, follow); err != nil { log.Errorf(ctx, "error federating follow request accept: %v", err) } @@ -369,6 +399,11 @@ func (p *fediAPI) CreateAnnounce(ctx context.Context, fMsg messages.FromFediAPI) return gtserror.Newf("error dereferencing announce: %w", err) } + // Update stats for the remote account. + if err := p.utilF.incrementStatusesCount(ctx, fMsg.RequestingAccount, boost); err != nil { + log.Errorf(ctx, "error updating account stats: %v", err) + } + // Timeline and notify the announce. if err := p.surface.timelineAndNotifyStatus(ctx, boost); err != nil { log.Errorf(ctx, "error timelining and notifying status: %v", err) @@ -509,6 +544,24 @@ func (p *fediAPI) UpdateAccount(ctx context.Context, fMsg messages.FromFediAPI) return nil } +func (p *fediAPI) AcceptFollow(ctx context.Context, fMsg messages.FromFediAPI) error { + // Update stats for the remote account. + if err := p.utilF.decrementFollowRequestsCount(ctx, fMsg.RequestingAccount); err != nil { + log.Errorf(ctx, "error updating account stats: %v", err) + } + + if err := p.utilF.incrementFollowersCount(ctx, fMsg.RequestingAccount); err != nil { + log.Errorf(ctx, "error updating account stats: %v", err) + } + + // Update stats for the local account. + if err := p.utilF.incrementFollowingCount(ctx, fMsg.ReceivingAccount); err != nil { + log.Errorf(ctx, "error updating account stats: %v", err) + } + + return nil +} + func (p *fediAPI) UpdateStatus(ctx context.Context, fMsg messages.FromFediAPI) error { // Cast the existing Status model attached to msg. existing, ok := fMsg.GTSModel.(*gtsmodel.Status) @@ -567,6 +620,11 @@ func (p *fediAPI) DeleteStatus(ctx context.Context, fMsg messages.FromFediAPI) e log.Errorf(ctx, "error wiping status: %v", err) } + // Update stats for the remote account. + if err := p.utilF.decrementStatusesCount(ctx, fMsg.RequestingAccount); err != nil { + log.Errorf(ctx, "error updating account stats: %v", err) + } + if status.InReplyToID != "" { // Interaction counts changed on the replied status; // uncache the prepared version from all timelines. diff --git a/internal/processing/workers/fromfediapi_test.go b/internal/processing/workers/fromfediapi_test.go index 51f61bd12..eb3d73e0c 100644 --- a/internal/processing/workers/fromfediapi_test.go +++ b/internal/processing/workers/fromfediapi_test.go @@ -55,10 +55,11 @@ func (suite *FromFediAPITestSuite) TestProcessFederationAnnounce() { announceStatus.Visibility = boostedStatus.Visibility err := suite.processor.Workers().ProcessFromFediAPI(context.Background(), messages.FromFediAPI{ - APObjectType: ap.ActivityAnnounce, - APActivityType: ap.ActivityCreate, - GTSModel: announceStatus, - ReceivingAccount: suite.testAccounts["local_account_1"], + APObjectType: ap.ActivityAnnounce, + APActivityType: ap.ActivityCreate, + GTSModel: announceStatus, + ReceivingAccount: suite.testAccounts["local_account_1"], + RequestingAccount: boostingAccount, }) suite.NoError(err) @@ -115,10 +116,11 @@ func (suite *FromFediAPITestSuite) TestProcessReplyMention() { // Send the replied status off to the fedi worker to be further processed. err = suite.processor.Workers().ProcessFromFediAPI(context.Background(), messages.FromFediAPI{ - APObjectType: ap.ObjectNote, - APActivityType: ap.ActivityCreate, - APObjectModel: replyingStatusable, - ReceivingAccount: suite.testAccounts["local_account_1"], + APObjectType: ap.ObjectNote, + APActivityType: ap.ActivityCreate, + APObjectModel: replyingStatusable, + ReceivingAccount: repliedAccount, + RequestingAccount: replyingAccount, }) suite.NoError(err) @@ -178,10 +180,11 @@ func (suite *FromFediAPITestSuite) TestProcessFave() { suite.NoError(err) err = suite.processor.Workers().ProcessFromFediAPI(context.Background(), messages.FromFediAPI{ - APObjectType: ap.ActivityLike, - APActivityType: ap.ActivityCreate, - GTSModel: fave, - ReceivingAccount: favedAccount, + APObjectType: ap.ActivityLike, + APActivityType: ap.ActivityCreate, + GTSModel: fave, + ReceivingAccount: favedAccount, + RequestingAccount: favingAccount, }) suite.NoError(err) @@ -247,10 +250,11 @@ func (suite *FromFediAPITestSuite) TestProcessFaveWithDifferentReceivingAccount( suite.NoError(err) err = suite.processor.Workers().ProcessFromFediAPI(context.Background(), messages.FromFediAPI{ - APObjectType: ap.ActivityLike, - APActivityType: ap.ActivityCreate, - GTSModel: fave, - ReceivingAccount: receivingAccount, + APObjectType: ap.ActivityLike, + APActivityType: ap.ActivityCreate, + GTSModel: fave, + ReceivingAccount: receivingAccount, + RequestingAccount: favingAccount, }) suite.NoError(err) @@ -318,10 +322,11 @@ func (suite *FromFediAPITestSuite) TestProcessAccountDelete() { // now they are mufos! err = suite.processor.Workers().ProcessFromFediAPI(ctx, messages.FromFediAPI{ - APObjectType: ap.ObjectProfile, - APActivityType: ap.ActivityDelete, - GTSModel: deletedAccount, - ReceivingAccount: receivingAccount, + APObjectType: ap.ObjectProfile, + APActivityType: ap.ActivityDelete, + GTSModel: deletedAccount, + ReceivingAccount: receivingAccount, + RequestingAccount: deletedAccount, }) suite.NoError(err) @@ -398,10 +403,11 @@ func (suite *FromFediAPITestSuite) TestProcessFollowRequestLocked() { suite.NoError(err) err = suite.processor.Workers().ProcessFromFediAPI(ctx, messages.FromFediAPI{ - APObjectType: ap.ActivityFollow, - APActivityType: ap.ActivityCreate, - GTSModel: satanFollowRequestTurtle, - ReceivingAccount: targetAccount, + APObjectType: ap.ActivityFollow, + APActivityType: ap.ActivityCreate, + GTSModel: satanFollowRequestTurtle, + ReceivingAccount: targetAccount, + RequestingAccount: originAccount, }) suite.NoError(err) @@ -451,10 +457,11 @@ func (suite *FromFediAPITestSuite) TestProcessFollowRequestUnlocked() { suite.NoError(err) err = suite.processor.Workers().ProcessFromFediAPI(ctx, messages.FromFediAPI{ - APObjectType: ap.ActivityFollow, - APActivityType: ap.ActivityCreate, - GTSModel: satanFollowRequestTurtle, - ReceivingAccount: targetAccount, + APObjectType: ap.ActivityFollow, + APActivityType: ap.ActivityCreate, + GTSModel: satanFollowRequestTurtle, + ReceivingAccount: targetAccount, + RequestingAccount: originAccount, }) suite.NoError(err) @@ -526,11 +533,12 @@ func (suite *FromFediAPITestSuite) TestCreateStatusFromIRI() { statusCreator := suite.testAccounts["remote_account_2"] err := suite.processor.Workers().ProcessFromFediAPI(ctx, messages.FromFediAPI{ - APObjectType: ap.ObjectNote, - APActivityType: ap.ActivityCreate, - GTSModel: nil, // gtsmodel is nil because this is a forwarded status -- we want to dereference it using the iri - ReceivingAccount: receivingAccount, - APIri: testrig.URLMustParse("http://example.org/users/Some_User/statuses/afaba698-5740-4e32-a702-af61aa543bc1"), + APObjectType: ap.ObjectNote, + APActivityType: ap.ActivityCreate, + GTSModel: nil, // gtsmodel is nil because this is a forwarded status -- we want to dereference it using the iri + ReceivingAccount: receivingAccount, + RequestingAccount: statusCreator, + APIri: testrig.URLMustParse("http://example.org/users/Some_User/statuses/afaba698-5740-4e32-a702-af61aa543bc1"), }) suite.NoError(err) diff --git a/internal/processing/workers/util.go b/internal/processing/workers/util.go index a38ecd336..cd936f428 100644 --- a/internal/processing/workers/util.go +++ b/internal/processing/workers/util.go @@ -238,3 +238,258 @@ func (u *utilF) redirectFollowers( return true } + +func (u *utilF) incrementStatusesCount( + ctx context.Context, + account *gtsmodel.Account, + status *gtsmodel.Status, +) error { + // Lock on this account since we're changing stats. + unlock := u.state.AccountLocks.Lock(account.URI) + defer unlock() + + // Populate stats. + if account.Stats == nil { + if err := u.state.DB.PopulateAccountStats(ctx, account); err != nil { + return gtserror.Newf("db error getting account stats: %w", err) + } + } + + // Update stats by incrementing status + // count by one and setting last posted. + *account.Stats.StatusesCount++ + account.Stats.LastStatusAt = status.CreatedAt + if err := u.state.DB.UpdateAccountStats( + ctx, + account.Stats, + "statuses_count", + "last_status_at", + ); err != nil { + return gtserror.Newf("db error updating account stats: %w", err) + } + + return nil +} + +func (u *utilF) decrementStatusesCount( + ctx context.Context, + account *gtsmodel.Account, +) error { + // Lock on this account since we're changing stats. + unlock := u.state.AccountLocks.Lock(account.URI) + defer unlock() + + // Populate stats. + if account.Stats == nil { + if err := u.state.DB.PopulateAccountStats(ctx, account); err != nil { + return gtserror.Newf("db error getting account stats: %w", err) + } + } + + // Update stats by decrementing + // status count by one. + // + // Clamp to 0 to avoid funny business. + *account.Stats.StatusesCount-- + if *account.Stats.StatusesCount < 0 { + *account.Stats.StatusesCount = 0 + } + if err := u.state.DB.UpdateAccountStats( + ctx, + account.Stats, + "statuses_count", + ); err != nil { + return gtserror.Newf("db error updating account stats: %w", err) + } + + return nil +} + +func (u *utilF) incrementFollowersCount( + ctx context.Context, + account *gtsmodel.Account, +) error { + // Lock on this account since we're changing stats. + unlock := u.state.AccountLocks.Lock(account.URI) + defer unlock() + + // Populate stats. + if account.Stats == nil { + if err := u.state.DB.PopulateAccountStats(ctx, account); err != nil { + return gtserror.Newf("db error getting account stats: %w", err) + } + } + + // Update stats by incrementing followers + // count by one and setting last posted. + *account.Stats.FollowersCount++ + if err := u.state.DB.UpdateAccountStats( + ctx, + account.Stats, + "followers_count", + ); err != nil { + return gtserror.Newf("db error updating account stats: %w", err) + } + + return nil +} + +func (u *utilF) decrementFollowersCount( + ctx context.Context, + account *gtsmodel.Account, +) error { + // Lock on this account since we're changing stats. + unlock := u.state.AccountLocks.Lock(account.URI) + defer unlock() + + // Populate stats. + if account.Stats == nil { + if err := u.state.DB.PopulateAccountStats(ctx, account); err != nil { + return gtserror.Newf("db error getting account stats: %w", err) + } + } + + // Update stats by decrementing + // followers count by one. + // + // Clamp to 0 to avoid funny business. + *account.Stats.FollowersCount-- + if *account.Stats.FollowersCount < 0 { + *account.Stats.FollowersCount = 0 + } + if err := u.state.DB.UpdateAccountStats( + ctx, + account.Stats, + "followers_count", + ); err != nil { + return gtserror.Newf("db error updating account stats: %w", err) + } + + return nil +} + +func (u *utilF) incrementFollowingCount( + ctx context.Context, + account *gtsmodel.Account, +) error { + // Lock on this account since we're changing stats. + unlock := u.state.AccountLocks.Lock(account.URI) + defer unlock() + + // Populate stats. + if account.Stats == nil { + if err := u.state.DB.PopulateAccountStats(ctx, account); err != nil { + return gtserror.Newf("db error getting account stats: %w", err) + } + } + + // Update stats by incrementing + // followers count by one. + *account.Stats.FollowingCount++ + if err := u.state.DB.UpdateAccountStats( + ctx, + account.Stats, + "following_count", + ); err != nil { + return gtserror.Newf("db error updating account stats: %w", err) + } + + return nil +} + +func (u *utilF) decrementFollowingCount( + ctx context.Context, + account *gtsmodel.Account, +) error { + // Lock on this account since we're changing stats. + unlock := u.state.AccountLocks.Lock(account.URI) + defer unlock() + + // Populate stats. + if account.Stats == nil { + if err := u.state.DB.PopulateAccountStats(ctx, account); err != nil { + return gtserror.Newf("db error getting account stats: %w", err) + } + } + + // Update stats by decrementing + // following count by one. + // + // Clamp to 0 to avoid funny business. + *account.Stats.FollowingCount-- + if *account.Stats.FollowingCount < 0 { + *account.Stats.FollowingCount = 0 + } + if err := u.state.DB.UpdateAccountStats( + ctx, + account.Stats, + "following_count", + ); err != nil { + return gtserror.Newf("db error updating account stats: %w", err) + } + + return nil +} + +func (u *utilF) incrementFollowRequestsCount( + ctx context.Context, + account *gtsmodel.Account, +) error { + // Lock on this account since we're changing stats. + unlock := u.state.AccountLocks.Lock(account.URI) + defer unlock() + + // Populate stats. + if account.Stats == nil { + if err := u.state.DB.PopulateAccountStats(ctx, account); err != nil { + return gtserror.Newf("db error getting account stats: %w", err) + } + } + + // Update stats by incrementing + // follow requests count by one. + *account.Stats.FollowRequestsCount++ + if err := u.state.DB.UpdateAccountStats( + ctx, + account.Stats, + "follow_requests_count", + ); err != nil { + return gtserror.Newf("db error updating account stats: %w", err) + } + + return nil +} + +func (u *utilF) decrementFollowRequestsCount( + ctx context.Context, + account *gtsmodel.Account, +) error { + // Lock on this account since we're changing stats. + unlock := u.state.AccountLocks.Lock(account.URI) + defer unlock() + + // Populate stats. + if account.Stats == nil { + if err := u.state.DB.PopulateAccountStats(ctx, account); err != nil { + return gtserror.Newf("db error getting account stats: %w", err) + } + } + + // Update stats by decrementing + // follow requests count by one. + // + // Clamp to 0 to avoid funny business. + *account.Stats.FollowRequestsCount-- + if *account.Stats.FollowRequestsCount < 0 { + *account.Stats.FollowRequestsCount = 0 + } + if err := u.state.DB.UpdateAccountStats( + ctx, + account.Stats, + "follow_requests_count", + ); err != nil { + return gtserror.Newf("db error updating account stats: %w", err) + } + + return nil +} diff --git a/internal/state/state.go b/internal/state/state.go index 6120515b9..f1eb5a9da 100644 --- a/internal/state/state.go +++ b/internal/state/state.go @@ -50,11 +50,12 @@ type State struct { // functions, and by the go-fed/activity library. FedLocks mutexes.MutexMap - // ClientLocks provides access to this state's - // mutex map of per URI client locks. - // - // Used during account migration actions. - ClientLocks mutexes.MutexMap + // AccountLocks provides access to this state's + // mutex map of per URI locks, intended for use + // when updating accounts, migrating, approving + // or rejecting an account, changing stats, + // pinned statuses, etc. + AccountLocks mutexes.MutexMap // Storage provides access to the storage driver. Storage *storage.Driver diff --git a/internal/typeutils/internaltofrontend.go b/internal/typeutils/internaltofrontend.go index e3786a9ae..94f2bcda4 100644 --- a/internal/typeutils/internaltofrontend.go +++ b/internal/typeutils/internaltofrontend.go @@ -63,18 +63,22 @@ func toMastodonVersion(in string) string { // if something goes wrong. The returned application should be ready to serialize on an API level, and may have sensitive fields // (such as client id and client secret), so serve it only to an authorized user who should have permission to see it. func (c *Converter) AccountToAPIAccountSensitive(ctx context.Context, a *gtsmodel.Account) (*apimodel.Account, error) { - // we can build this sensitive account easily by first getting the public account.... + // We can build this sensitive account model + // by first getting the public account, and + // then adding the Source object to it. apiAccount, err := c.AccountToAPIAccountPublic(ctx, a) if err != nil { return nil, err } - // then adding the Source object to it... - - // check pending follow requests aimed at this account - frc, err := c.state.DB.CountAccountFollowRequests(ctx, a.ID) - if err != nil { - return nil, fmt.Errorf("error counting follow requests: %s", err) + // Ensure account stats populated. + if a.Stats == nil { + if err := c.state.DB.PopulateAccountStats(ctx, a); err != nil { + return nil, gtserror.Newf( + "error getting stats for account %s: %w", + a.ID, err, + ) + } } statusContentType := string(apimodel.StatusContentTypeDefault) @@ -89,7 +93,7 @@ func (c *Converter) AccountToAPIAccountSensitive(ctx context.Context, a *gtsmode StatusContentType: statusContentType, Note: a.NoteRaw, Fields: c.fieldsToAPIFields(a.FieldsRaw), - FollowRequestsCount: frc, + FollowRequestsCount: *a.Stats.FollowRequestsCount, AlsoKnownAsURIs: a.AlsoKnownAsURIs, } @@ -100,8 +104,22 @@ func (c *Converter) AccountToAPIAccountSensitive(ctx context.Context, a *gtsmode // if something goes wrong. The returned account should be ready to serialize on an API level, and may NOT have sensitive fields. // In other words, this is the public record that the server has of an account. func (c *Converter) AccountToAPIAccountPublic(ctx context.Context, a *gtsmodel.Account) (*apimodel.Account, error) { - if err := c.state.DB.PopulateAccount(ctx, a); err != nil { + // Populate account struct fields. + err := c.state.DB.PopulateAccount(ctx, a) + + switch { + case err == nil: + // No problem. + + case err != nil && a.Stats != nil: + // We have stats so that's + // *maybe* OK, try to continue. log.Errorf(ctx, "error(s) populating account, will continue: %s", err) + + default: + // There was an error and we don't + // have stats, we can't continue. + return nil, gtserror.Newf("account stats not populated, could not continue: %w", err) } // Basic account stats: @@ -110,30 +128,17 @@ func (c *Converter) AccountToAPIAccountPublic(ctx context.Context, a *gtsmodel.A // - Statuses count // - Last status time - followersCount, err := c.state.DB.CountAccountFollowers(ctx, a.ID) - if err != nil && !errors.Is(err, db.ErrNoEntries) { - return nil, gtserror.Newf("error counting followers: %w", err) - } - - followingCount, err := c.state.DB.CountAccountFollows(ctx, a.ID) - if err != nil && !errors.Is(err, db.ErrNoEntries) { - return nil, gtserror.Newf("error counting following: %w", err) - } - - statusesCount, err := c.state.DB.CountAccountStatuses(ctx, a.ID) - if err != nil && !errors.Is(err, db.ErrNoEntries) { - return nil, gtserror.Newf("error counting statuses: %w", err) - } - - var lastStatusAt *string - lastPosted, err := c.state.DB.GetAccountLastPosted(ctx, a.ID, false) - if err != nil && !errors.Is(err, db.ErrNoEntries) { - return nil, gtserror.Newf("error getting last posted: %w", err) - } - - if !lastPosted.IsZero() { - lastStatusAt = util.Ptr(util.FormatISO8601(lastPosted)) - } + var ( + followersCount = *a.Stats.FollowersCount + followingCount = *a.Stats.FollowingCount + statusesCount = *a.Stats.StatusesCount + lastStatusAt = func() *string { + if a.Stats.LastStatusAt.IsZero() { + return nil + } + return util.Ptr(util.FormatISO8601(a.Stats.LastStatusAt)) + }() + ) // Profile media + nice extras: // - Avatar diff --git a/test/envparsing.sh b/test/envparsing.sh index a379750c0..11532b044 100755 --- a/test/envparsing.sh +++ b/test/envparsing.sh @@ -26,6 +26,7 @@ EXPECT=$(cat << "EOF" "account-mem-ratio": 5, "account-note-mem-ratio": 1, "account-settings-mem-ratio": 0.1, + "account-stats-mem-ratio": 2, "application-mem-ratio": 0.1, "block-mem-ratio": 3, "boost-of-ids-mem-ratio": 3,