diff --git a/internal/cache/cache.go b/internal/cache/cache.go index a278336ae..17fa03323 100644 --- a/internal/cache/cache.go +++ b/internal/cache/cache.go @@ -51,6 +51,7 @@ func (c *Caches) Init() { log.Infof(nil, "init: %p", c) c.initAccount() + c.initAccountCounts() c.initAccountNote() c.initApplication() c.initBlock() diff --git a/internal/cache/db.go b/internal/cache/db.go index 894d74109..275a25451 100644 --- a/internal/cache/db.go +++ b/internal/cache/db.go @@ -36,6 +36,13 @@ type GTSCaches struct { // AccountNote provides access to the gtsmodel Note database cache. AccountNote structr.Cache[*gtsmodel.AccountNote] + // TEMPORARY CACHE TO ALLEVIATE SLOW COUNT QUERIES, + // (in time will be removed when these IDs are cached). + AccountCounts *simple.Cache[string, struct { + Statuses int + Pinned int + }] + // Application provides access to the gtsmodel Application database cache. Application structr.Cache[*gtsmodel.Application] @@ -192,6 +199,22 @@ func (c *Caches) initAccount() { }) } +func (c *Caches) initAccountCounts() { + // Simply use size of accounts cache, + // as this cache will be very small. + cap := c.GTS.Account.Cap() + if cap == 0 { + panic("must be initialized before accounts") + } + + log.Infof(nil, "cache size = %d", cap) + + c.GTS.AccountCounts = simple.New[string, struct { + Statuses int + Pinned int + }](0, cap) +} + func (c *Caches) initAccountNote() { // Calculate maximum cache size. cap := calculateResultCacheMax( diff --git a/internal/cache/invalidate.go b/internal/cache/invalidate.go index d85c503da..e7dfa9e8a 100644 --- a/internal/cache/invalidate.go +++ b/internal/cache/invalidate.go @@ -27,6 +27,9 @@ import ( // HOOKS TO BE CALLED ON DELETE YOU MUST FIRST POPULATE IT IN THE CACHE. func (c *Caches) OnInvalidateAccount(account *gtsmodel.Account) { + // Invalidate status counts for this account. + c.GTS.AccountCounts.Invalidate(account.ID) + // Invalidate account ID cached visibility. c.Visibility.Invalidate("ItemID", account.ID) c.Visibility.Invalidate("RequesterID", account.ID) @@ -151,6 +154,9 @@ func (c *Caches) OnInvalidatePollVote(vote *gtsmodel.PollVote) { } func (c *Caches) OnInvalidateStatus(status *gtsmodel.Status) { + // Invalidate status counts for this account. + c.GTS.AccountCounts.Invalidate(status.AccountID) + // Invalidate status ID cached visibility. c.Visibility.Invalidate("ItemID", status.ID) diff --git a/internal/db/bundb/account.go b/internal/db/bundb/account.go index 4b4c78726..e0d574f62 100644 --- a/internal/db/bundb/account.go +++ b/internal/db/bundb/account.go @@ -532,20 +532,56 @@ func (a *accountDB) GetAccountFaves(ctx context.Context, accountID string) ([]*g } func (a *accountDB) CountAccountStatuses(ctx context.Context, accountID string) (int, error) { - return a.db. - NewSelect(). - TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")). - Where("? = ?", bun.Ident("status.account_id"), accountID). - Count(ctx) + counts, err := a.getAccountStatusCounts(ctx, accountID) + return counts.Statuses, err } func (a *accountDB) CountAccountPinned(ctx context.Context, accountID string) (int, error) { - return a.db. - NewSelect(). - TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")). - Where("? = ?", bun.Ident("status.account_id"), accountID). - Where("? IS NOT NULL", bun.Ident("status.pinned_at")). - Count(ctx) + counts, err := a.getAccountStatusCounts(ctx, accountID) + return counts.Pinned, err +} + +func (a *accountDB) getAccountStatusCounts(ctx context.Context, accountID string) (struct { + Statuses int + Pinned int +}, error) { + // Check for an already cached copy of account status counts. + counts, ok := a.state.Caches.GTS.AccountCounts.Get(accountID) + if ok { + return counts, nil + } + + if err := a.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { + var err error + + // Scan database for account statuses. + counts.Statuses, err = tx.NewSelect(). + Table("statuses"). + Where("? = ?", bun.Ident("account_id"), accountID). + Count(ctx) + if err != nil { + return err + } + + // Scan database for pinned statuses. + counts.Pinned, err = tx.NewSelect(). + Table("statuses"). + Where("? = ?", bun.Ident("account_id"), accountID). + Where("? IS NOT NULL", bun.Ident("pinned_at")). + Count(ctx) + if err != nil { + return err + } + + return nil + }); err != nil { + return counts, err + } + + // Store this account counts result in the cache. + a.state.Caches.GTS.AccountCounts.Set(accountID, counts) + + return counts, nil } func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, limit int, excludeReplies bool, excludeReblogs bool, maxID string, minID string, mediaOnly bool, publicOnly bool) ([]*gtsmodel.Status, error) { diff --git a/internal/db/bundb/drivers.go b/internal/db/bundb/drivers.go index 14d84e6fa..a70b598d2 100644 --- a/internal/db/bundb/drivers.go +++ b/internal/db/bundb/drivers.go @@ -36,14 +36,14 @@ var ( sqliteDriver = getSQLiteDriver() ) +//go:linkname getSQLiteDriver modernc.org/sqlite.newDriver +func getSQLiteDriver() *sqlite.Driver + func init() { sql.Register("pgx-gts", &PostgreSQLDriver{}) sql.Register("sqlite-gts", &SQLiteDriver{}) } -//go:linkname getSQLiteDriver modernc.org/sqlite.newDriver -func getSQLiteDriver() *sqlite.Driver - // PostgreSQLDriver is our own wrapper around the // pgx/stdlib.Driver{} type in order to wrap further // SQL driver types with our own err processing. @@ -66,7 +66,10 @@ func (c *PostgreSQLConn) Begin() (driver.Tx, error) { func (c *PostgreSQLConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { tx, err := c.conn.BeginTx(ctx, opts) err = processPostgresError(err) - return tx, err + if err != nil { + return nil, err + } + return &PostgreSQLTx{tx}, nil } func (c *PostgreSQLConn) Prepare(query string) (driver.Stmt, error) { @@ -74,13 +77,16 @@ func (c *PostgreSQLConn) Prepare(query string) (driver.Stmt, error) { } func (c *PostgreSQLConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { - stmt, err := c.conn.PrepareContext(ctx, query) + st, err := c.conn.PrepareContext(ctx, query) err = processPostgresError(err) - return stmt, err + if err != nil { + return nil, err + } + return &PostgreSQLStmt{stmt: st.(stmt)}, nil } -func (c *PostgreSQLConn) Exec(query string, args []driver.NamedValue) (driver.Result, error) { - return c.ExecContext(context.Background(), query, args) +func (c *PostgreSQLConn) Exec(query string, args []driver.Value) (driver.Result, error) { + return c.ExecContext(context.Background(), query, toNamedValues(args)) } func (c *PostgreSQLConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { @@ -89,8 +95,8 @@ func (c *PostgreSQLConn) ExecContext(ctx context.Context, query string, args []d return result, err } -func (c *PostgreSQLConn) Query(query string, args []driver.NamedValue) (driver.Rows, error) { - return c.QueryContext(context.Background(), query, args) +func (c *PostgreSQLConn) Query(query string, args []driver.Value) (driver.Rows, error) { + return c.QueryContext(context.Background(), query, toNamedValues(args)) } func (c *PostgreSQLConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { @@ -115,6 +121,28 @@ func (tx *PostgreSQLTx) Rollback() error { return processPostgresError(err) } +type PostgreSQLStmt struct{ stmt } + +func (stmt *PostgreSQLStmt) Exec(args []driver.Value) (driver.Result, error) { + return stmt.ExecContext(context.Background(), toNamedValues(args)) +} + +func (stmt *PostgreSQLStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { + res, err := stmt.stmt.ExecContext(ctx, args) + err = processSQLiteError(err) + return res, err +} + +func (stmt *PostgreSQLStmt) Query(args []driver.Value) (driver.Rows, error) { + return stmt.QueryContext(context.Background(), toNamedValues(args)) +} + +func (stmt *PostgreSQLStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { + rows, err := stmt.stmt.QueryContext(ctx, args) + err = processSQLiteError(err) + return rows, err +} + // SQLiteDriver is our own wrapper around the // sqlite.Driver{} type in order to wrap further // SQL driver types with our own functionality, @@ -141,6 +169,9 @@ func (c *SQLiteConn) BeginTx(ctx context.Context, opts driver.TxOptions) (tx dri err = processSQLiteError(err) return err }) + if err != nil { + return nil, err + } return &SQLiteTx{Context: ctx, Tx: tx}, nil } @@ -148,17 +179,20 @@ func (c *SQLiteConn) Prepare(query string) (driver.Stmt, error) { return c.PrepareContext(context.Background(), query) } -func (c *SQLiteConn) PrepareContext(ctx context.Context, query string) (stmt driver.Stmt, err error) { +func (c *SQLiteConn) PrepareContext(ctx context.Context, query string) (st driver.Stmt, err error) { err = retryOnBusy(ctx, func() error { - stmt, err = c.conn.PrepareContext(ctx, query) + st, err = c.conn.PrepareContext(ctx, query) err = processSQLiteError(err) return err }) - return + if err != nil { + return nil, err + } + return &SQLiteStmt{st.(stmt)}, nil } -func (c *SQLiteConn) Exec(query string, args []driver.NamedValue) (driver.Result, error) { - return c.ExecContext(context.Background(), query, args) +func (c *SQLiteConn) Exec(query string, args []driver.Value) (driver.Result, error) { + return c.ExecContext(context.Background(), query, toNamedValues(args)) } func (c *SQLiteConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (result driver.Result, err error) { @@ -170,8 +204,8 @@ func (c *SQLiteConn) ExecContext(ctx context.Context, query string, args []drive return } -func (c *SQLiteConn) Query(query string, args []driver.NamedValue) (driver.Rows, error) { - return c.QueryContext(context.Background(), query, args) +func (c *SQLiteConn) Query(query string, args []driver.Value) (driver.Rows, error) { + return c.QueryContext(context.Background(), query, toNamedValues(args)) } func (c *SQLiteConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (rows driver.Rows, err error) { @@ -213,29 +247,64 @@ func (tx *SQLiteTx) Rollback() (err error) { return } +type SQLiteStmt struct{ stmt } + +func (stmt *SQLiteStmt) Exec(args []driver.Value) (driver.Result, error) { + return stmt.ExecContext(context.Background(), toNamedValues(args)) +} + +func (stmt *SQLiteStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (res driver.Result, err error) { + err = retryOnBusy(ctx, func() error { + res, err = stmt.stmt.ExecContext(ctx, args) + err = processSQLiteError(err) + return err + }) + return +} + +func (stmt *SQLiteStmt) Query(args []driver.Value) (driver.Rows, error) { + return stmt.QueryContext(context.Background(), toNamedValues(args)) +} + +func (stmt *SQLiteStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (rows driver.Rows, err error) { + err = retryOnBusy(ctx, func() error { + rows, err = stmt.stmt.QueryContext(ctx, args) + err = processSQLiteError(err) + return err + }) + return +} + type conn interface { driver.Conn driver.ConnPrepareContext + driver.Execer //nolint:staticcheck driver.ExecerContext + driver.Queryer //nolint:staticcheck driver.QueryerContext driver.ConnBeginTx } +type stmt interface { + driver.Stmt + driver.StmtExecContext + driver.StmtQueryContext +} + // retryOnBusy will retry given function on returned 'errBusy'. func retryOnBusy(ctx context.Context, fn func() error) error { + if err := fn(); err != errBusy { + return err + } + return retryOnBusySlow(ctx, fn) +} + +// retryOnBusySlow is the outlined form of retryOnBusy, to allow the fast path (i.e. only +// 1 attempt) to be inlined, leaving the slow retry loop to be a separate function call. +func retryOnBusySlow(ctx context.Context, fn func() error) error { var backoff time.Duration for i := 0; ; i++ { - // Perform func. - err := fn() - - if err != errBusy { - // May be nil, or may be - // some other error, either - // way return here. - return err - } - // backoff according to a multiplier of 2ms * 2^2n, // up to a maximum possible backoff time of 5 minutes. // @@ -257,11 +326,37 @@ func retryOnBusy(ctx context.Context, fn func() error) error { select { // Context cancelled. case <-ctx.Done(): + return ctx.Err() // Backoff for some time. case <-time.After(backoff): } + + // Perform func. + err := fn() + + if err != errBusy { + // May be nil, or may be + // some other error, either + // way return here. + return err + } } return gtserror.Newf("%w (waited > %s)", db.ErrBusyTimeout, backoff) } + +// toNamedValues converts older driver.Value types to driver.NamedValue types. +func toNamedValues(args []driver.Value) []driver.NamedValue { + if args == nil { + return nil + } + args2 := make([]driver.NamedValue, len(args)) + for i := range args { + args2[i] = driver.NamedValue{ + Ordinal: i + 1, + Value: args[i], + } + } + return args2 +}