diff --git a/cmd/gotosocial/action/admin/account/account.go b/cmd/gotosocial/action/admin/account/account.go index 9bb5b27c1..612f10cc8 100644 --- a/cmd/gotosocial/action/admin/account/account.go +++ b/cmd/gotosocial/action/admin/account/account.go @@ -28,6 +28,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db/bundb" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/validate" "golang.org/x/crypto/bcrypt" @@ -49,15 +50,11 @@ func initState(ctx context.Context) (*state.State, error) { return &state, nil } -func stopState(ctx context.Context, state *state.State) error { - if err := state.DB.Stop(ctx); err != nil { - return fmt.Errorf("error stopping dbConn: %w", err) - } - +func stopState(state *state.State) error { + err := state.DB.Close() state.Workers.Stop() state.Caches.Stop() - - return nil + return err } // Create creates a new account and user @@ -68,6 +65,13 @@ var Create action.GTSAction = func(ctx context.Context) error { return err } + defer func() { + // Ensure state gets stopped on return. + if err := stopState(state); err != nil { + log.Error(ctx, err) + } + }() + username := config.GetAdminAccountUsername() if err := validate.Username(username); err != nil { return err @@ -101,17 +105,14 @@ var Create action.GTSAction = func(ctx context.Context) error { return err } - if _, err := state.DB.NewSignup(ctx, gtsmodel.NewSignup{ + _, err = state.DB.NewSignup(ctx, gtsmodel.NewSignup{ Username: username, Email: email, Password: password, EmailVerified: true, // Assume cli user wants email marked as verified already. PreApproved: true, // Assume cli user wants account marked as approved already. - }); err != nil { - return err - } - - return stopState(ctx, state) + }) + return err } // List returns all existing local accounts. @@ -148,8 +149,7 @@ var List action.GTSAction = func(ctx context.Context) error { for _, u := range users { fmt.Fprintf(w, "%s\t%s\t%s\t%s\t%s\t%s\t%s\n", u.Account.Username, u.AccountID, fmtBool(u.Approved), fmtBool(u.Admin), fmtBool(u.Moderator), fmtDate(u.Account.SuspendedAt), fmtDate(u.ConfirmedAt)) } - w.Flush() - return nil + return w.Flush() } // Confirm sets a user to Approved, sets Email to the current @@ -160,6 +160,13 @@ var Confirm action.GTSAction = func(ctx context.Context) error { return err } + defer func() { + // Ensure state gets stopped on return. + if err := stopState(state); err != nil { + log.Error(ctx, err) + } + }() + username := config.GetAdminAccountUsername() if err := validate.Username(username); err != nil { return err @@ -178,14 +185,10 @@ var Confirm action.GTSAction = func(ctx context.Context) error { user.Approved = func() *bool { a := true; return &a }() user.Email = user.UnconfirmedEmail user.ConfirmedAt = time.Now() - if err := state.DB.UpdateUser( + return state.DB.UpdateUser( ctx, user, "approved", "email", "confirmed_at", - ); err != nil { - return err - } - - return stopState(ctx, state) + ) } // Promote sets admin + moderator flags on a user to true. @@ -195,6 +198,13 @@ var Promote action.GTSAction = func(ctx context.Context) error { return err } + defer func() { + // Ensure state gets stopped on return. + if err := stopState(state); err != nil { + log.Error(ctx, err) + } + }() + username := config.GetAdminAccountUsername() if err := validate.Username(username); err != nil { return err @@ -212,14 +222,10 @@ var Promote action.GTSAction = func(ctx context.Context) error { user.Admin = func() *bool { a := true; return &a }() user.Moderator = func() *bool { a := true; return &a }() - if err := state.DB.UpdateUser( + return state.DB.UpdateUser( ctx, user, "admin", "moderator", - ); err != nil { - return err - } - - return stopState(ctx, state) + ) } // Demote sets admin + moderator flags on a user to false. @@ -229,6 +235,13 @@ var Demote action.GTSAction = func(ctx context.Context) error { return err } + defer func() { + // Ensure state gets stopped on return. + if err := stopState(state); err != nil { + log.Error(ctx, err) + } + }() + username := config.GetAdminAccountUsername() if err := validate.Username(username); err != nil { return err @@ -246,14 +259,10 @@ var Demote action.GTSAction = func(ctx context.Context) error { user.Admin = func() *bool { a := false; return &a }() user.Moderator = func() *bool { a := false; return &a }() - if err := state.DB.UpdateUser( + return state.DB.UpdateUser( ctx, user, "admin", "moderator", - ); err != nil { - return err - } - - return stopState(ctx, state) + ) } // Disable sets Disabled to true on a user. @@ -263,6 +272,13 @@ var Disable action.GTSAction = func(ctx context.Context) error { return err } + defer func() { + // Ensure state gets stopped on return. + if err := stopState(state); err != nil { + log.Error(ctx, err) + } + }() + username := config.GetAdminAccountUsername() if err := validate.Username(username); err != nil { return err @@ -279,14 +295,10 @@ var Disable action.GTSAction = func(ctx context.Context) error { } user.Disabled = func() *bool { d := true; return &d }() - if err := state.DB.UpdateUser( + return state.DB.UpdateUser( ctx, user, "disabled", - ); err != nil { - return err - } - - return stopState(ctx, state) + ) } // Password sets the password of target account. @@ -296,6 +308,13 @@ var Password action.GTSAction = func(ctx context.Context) error { return err } + defer func() { + // Ensure state gets stopped on return. + if err := stopState(state); err != nil { + log.Error(ctx, err) + } + }() + username := config.GetAdminAccountUsername() if err := validate.Username(username); err != nil { return err @@ -322,12 +341,8 @@ var Password action.GTSAction = func(ctx context.Context) error { } user.EncryptedPassword = string(encryptedPassword) - if err := state.DB.UpdateUser( + return state.DB.UpdateUser( ctx, user, "encrypted_password", - ); err != nil { - return err - } - - return stopState(ctx, state) + ) } diff --git a/cmd/gotosocial/action/admin/media/list.go b/cmd/gotosocial/action/admin/media/list.go index e66019ecc..5b6108b11 100644 --- a/cmd/gotosocial/action/admin/media/list.go +++ b/cmd/gotosocial/action/admin/media/list.go @@ -95,12 +95,11 @@ func setupList(ctx context.Context) (*list, error) { }, nil } -func (l *list) shutdown(ctx context.Context) error { +func (l *list) shutdown() error { l.out.Flush() - err := l.dbService.Stop(ctx) + err := l.dbService.Close() l.state.Workers.Stop() l.state.Caches.Stop() - return err } @@ -112,7 +111,7 @@ var ListLocal action.GTSAction = func(ctx context.Context) error { defer func() { // Ensure lister gets shutdown on exit. - if err := list.shutdown(ctx); err != nil { + if err := list.shutdown(); err != nil { log.Error(ctx, err) } }() @@ -144,7 +143,7 @@ var ListRemote action.GTSAction = func(ctx context.Context) error { defer func() { // Ensure lister gets shutdown on exit. - if err := list.shutdown(ctx); err != nil { + if err := list.shutdown(); err != nil { log.Error(ctx, err) } }() diff --git a/cmd/gotosocial/action/admin/media/prune/all.go b/cmd/gotosocial/action/admin/media/prune/all.go index 90c08c7db..b334feb6d 100644 --- a/cmd/gotosocial/action/admin/media/prune/all.go +++ b/cmd/gotosocial/action/admin/media/prune/all.go @@ -36,7 +36,7 @@ var All action.GTSAction = func(ctx context.Context) error { defer func() { // Ensure pruner gets shutdown on exit. - if err := prune.shutdown(ctx); err != nil { + if err := prune.shutdown(); err != nil { log.Error(ctx, err) } }() diff --git a/cmd/gotosocial/action/admin/media/prune/common.go b/cmd/gotosocial/action/admin/media/prune/common.go index ad721675e..ed272984b 100644 --- a/cmd/gotosocial/action/admin/media/prune/common.go +++ b/cmd/gotosocial/action/admin/media/prune/common.go @@ -74,14 +74,14 @@ func setupPrune(ctx context.Context) (*prune, error) { }, nil } -func (p *prune) shutdown(ctx context.Context) error { +func (p *prune) shutdown() error { errs := gtserror.NewMultiError(2) if err := p.storage.Close(); err != nil { errs.Appendf("error closing storage backend: %w", err) } - if err := p.dbService.Stop(ctx); err != nil { + if err := p.dbService.Close(); err != nil { errs.Appendf("error stopping database: %w", err) } diff --git a/cmd/gotosocial/action/admin/media/prune/orphaned.go b/cmd/gotosocial/action/admin/media/prune/orphaned.go index a94c84422..e9cb27256 100644 --- a/cmd/gotosocial/action/admin/media/prune/orphaned.go +++ b/cmd/gotosocial/action/admin/media/prune/orphaned.go @@ -36,7 +36,7 @@ var Orphaned action.GTSAction = func(ctx context.Context) error { defer func() { // Ensure pruner gets shutdown on exit. - if err := prune.shutdown(ctx); err != nil { + if err := prune.shutdown(); err != nil { log.Error(ctx, err) } }() diff --git a/cmd/gotosocial/action/admin/media/prune/remote.go b/cmd/gotosocial/action/admin/media/prune/remote.go index ed521cfe8..5efa5602a 100644 --- a/cmd/gotosocial/action/admin/media/prune/remote.go +++ b/cmd/gotosocial/action/admin/media/prune/remote.go @@ -37,7 +37,7 @@ var Remote action.GTSAction = func(ctx context.Context) error { defer func() { // Ensure pruner gets shutdown on exit. - if err := prune.shutdown(ctx); err != nil { + if err := prune.shutdown(); err != nil { log.Error(ctx, err) } }() diff --git a/cmd/gotosocial/action/admin/trans/export.go b/cmd/gotosocial/action/admin/trans/export.go index 7b487561f..f76982a1b 100644 --- a/cmd/gotosocial/action/admin/trans/export.go +++ b/cmd/gotosocial/action/admin/trans/export.go @@ -52,5 +52,5 @@ var Export action.GTSAction = func(ctx context.Context) error { return err } - return dbConn.Stop(ctx) + return dbConn.Close() } diff --git a/cmd/gotosocial/action/admin/trans/import.go b/cmd/gotosocial/action/admin/trans/import.go index da426b41a..1ebf587ff 100644 --- a/cmd/gotosocial/action/admin/trans/import.go +++ b/cmd/gotosocial/action/admin/trans/import.go @@ -52,5 +52,5 @@ var Import action.GTSAction = func(ctx context.Context) error { return err } - return dbConn.Stop(ctx) + return dbConn.Close() } diff --git a/internal/api/client/media/mediacreate_test.go b/internal/api/client/media/mediacreate_test.go index 471be8557..2ffa8b8a1 100644 --- a/internal/api/client/media/mediacreate_test.go +++ b/internal/api/client/media/mediacreate_test.go @@ -110,7 +110,7 @@ func (suite *MediaCreateTestSuite) SetupSuite() { } func (suite *MediaCreateTestSuite) TearDownSuite() { - if err := suite.db.Stop(context.Background()); err != nil { + if err := suite.db.Close(); err != nil { log.Panicf(nil, "error closing db connection: %s", err) } testrig.StopWorkers(&suite.state) diff --git a/internal/api/client/media/mediaupdate_test.go b/internal/api/client/media/mediaupdate_test.go index 1af3bcf06..8140f1acc 100644 --- a/internal/api/client/media/mediaupdate_test.go +++ b/internal/api/client/media/mediaupdate_test.go @@ -19,7 +19,6 @@ package media_test import ( "bytes" - "context" "encoding/json" "fmt" "io/ioutil" @@ -107,7 +106,7 @@ func (suite *MediaUpdateTestSuite) SetupSuite() { } func (suite *MediaUpdateTestSuite) TearDownSuite() { - if err := suite.db.Stop(context.Background()); err != nil { + if err := suite.db.Close(); err != nil { log.Panicf(nil, "error closing db connection: %s", err) } testrig.StopWorkers(&suite.state) diff --git a/internal/api/fileserver/fileserver_test.go b/internal/api/fileserver/fileserver_test.go index 709458b1a..e57b86082 100644 --- a/internal/api/fileserver/fileserver_test.go +++ b/internal/api/fileserver/fileserver_test.go @@ -18,8 +18,6 @@ package fileserver_test import ( - "context" - "github.com/stretchr/testify/suite" "github.com/superseriousbusiness/gotosocial/internal/api/fileserver" "github.com/superseriousbusiness/gotosocial/internal/db" @@ -111,7 +109,7 @@ func (suite *FileserverTestSuite) SetupTest() { } func (suite *FileserverTestSuite) TearDownSuite() { - if err := suite.db.Stop(context.Background()); err != nil { + if err := suite.db.Close(); err != nil { log.Panicf(nil, "error closing db connection: %s", err) } testrig.StopWorkers(&suite.state) diff --git a/internal/db/basic.go b/internal/db/basic.go index f8c04c6b9..7cd690aef 100644 --- a/internal/db/basic.go +++ b/internal/db/basic.go @@ -33,9 +33,9 @@ type Basic interface { // For implementations that don't use tables, this can just return nil. DropTable(ctx context.Context, i interface{}) error - // Stop should stop and close the database connection cleanly, returning an error if this is not possible. + // Close should stop and close the database connection cleanly, returning an error if this is not possible. // If the database implementation doesn't need to be stopped, this can just return nil. - Stop(ctx context.Context) error + Close() error // IsHealthy should return nil if the database connection is healthy, or an error if not. IsHealthy(ctx context.Context) error diff --git a/internal/db/bundb/account.go b/internal/db/bundb/account.go index 2d9a64454..c88edebbf 100644 --- a/internal/db/bundb/account.go +++ b/internal/db/bundb/account.go @@ -37,7 +37,7 @@ import ( ) type accountDB struct { - db *WrappedDB + db *DB state *state.State } @@ -229,7 +229,7 @@ func (a *accountDB) getAccount(ctx context.Context, lookup string, dbQuery func( // Not cached! Perform database query if err := dbQuery(&account); err != nil { - return nil, a.db.ProcessError(err) + return nil, err } return &account, nil @@ -415,7 +415,7 @@ func (a *accountDB) GetAccountLastPosted(ctx context.Context, accountID string, } if err := q.Scan(ctx, &createdAt); err != nil { - return time.Time{}, a.db.ProcessError(err) + return time.Time{}, err } return createdAt, nil } @@ -440,7 +440,7 @@ func (a *accountDB) SetAccountHeaderOrAvatar(ctx context.Context, mediaAttachmen NewInsert(). Model(mediaAttachment). Exec(ctx); err != nil { - return a.db.ProcessError(err) + return err } if _, err := a.db. @@ -449,7 +449,7 @@ func (a *accountDB) SetAccountHeaderOrAvatar(ctx context.Context, mediaAttachmen Set("? = ?", column, mediaAttachment.ID). Where("? = ?", bun.Ident("account.id"), accountID). Exec(ctx); err != nil { - return a.db.ProcessError(err) + return err } return nil @@ -474,7 +474,7 @@ func (a *accountDB) GetAccountsUsingEmoji(ctx context.Context, emojiID string) ( Column("account_id"). Where("? = ?", bun.Ident("emoji_id"), emojiID). Exec(ctx, &accountIDs); err != nil { - return nil, a.db.ProcessError(err) + return nil, err } // Convert account IDs into account objects. @@ -489,7 +489,7 @@ func (a *accountDB) GetAccountFaves(ctx context.Context, accountID string) ([]*g Model(faves). Where("? = ?", bun.Ident("status_fave.account_id"), accountID). Scan(ctx); err != nil { - return nil, a.db.ProcessError(err) + return nil, err } return *faves, nil @@ -601,7 +601,7 @@ func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, li } if err := q.Scan(ctx, &statusIDs); err != nil { - return nil, a.db.ProcessError(err) + return nil, err } // If we're paging up, we still want statuses @@ -628,7 +628,7 @@ func (a *accountDB) GetAccountPinnedStatuses(ctx context.Context, accountID stri Order("status.pinned_at DESC") if err := q.Scan(ctx, &statusIDs); err != nil { - return nil, a.db.ProcessError(err) + return nil, err } return a.statusesFromIDs(ctx, statusIDs) @@ -676,7 +676,7 @@ func (a *accountDB) GetAccountWebStatuses(ctx context.Context, accountID string, q = q.Order("status.id DESC") if err := q.Scan(ctx, &statusIDs); err != nil { - return nil, a.db.ProcessError(err) + return nil, err } return a.statusesFromIDs(ctx, statusIDs) diff --git a/internal/db/bundb/admin.go b/internal/db/bundb/admin.go index 723648a9e..8af08973c 100644 --- a/internal/db/bundb/admin.go +++ b/internal/db/bundb/admin.go @@ -45,7 +45,7 @@ import ( const rsaKeyBits = 2048 type adminDB struct { - db *WrappedDB + db *DB state *state.State } @@ -314,7 +314,7 @@ func (a *adminDB) CreateInstanceInstance(ctx context.Context) error { _, err = insertQ.Exec(ctx) if err != nil { - return a.db.ProcessError(err) + return err } log.Infof(ctx, "created instance instance %s with id %s", host, i.ID) diff --git a/internal/db/bundb/application.go b/internal/db/bundb/application.go index b53d2c0b0..f7328e275 100644 --- a/internal/db/bundb/application.go +++ b/internal/db/bundb/application.go @@ -26,7 +26,7 @@ import ( ) type applicationDB struct { - db *WrappedDB + db *DB state *state.State } @@ -58,7 +58,7 @@ func (a *applicationDB) getApplication(ctx context.Context, lookup string, dbQue // Not cached! Perform database query. if err := dbQuery(&app); err != nil { - return nil, a.db.ProcessError(err) + return nil, err } return &app, nil @@ -68,7 +68,7 @@ func (a *applicationDB) getApplication(ctx context.Context, lookup string, dbQue func (a *applicationDB) PutApplication(ctx context.Context, app *gtsmodel.Application) error { return a.state.Caches.GTS.Application().Store(app, func() error { _, err := a.db.NewInsert().Model(app).Exec(ctx) - return a.db.ProcessError(err) + return err }) } @@ -78,7 +78,7 @@ func (a *applicationDB) DeleteApplicationByClientID(ctx context.Context, clientI Table("applications"). Where("? = ?", bun.Ident("client_id"), clientID). Exec(ctx); err != nil { - return a.db.ProcessError(err) + return err } // NOTE about further side effects: diff --git a/internal/db/bundb/basic.go b/internal/db/bundb/basic.go index 33d6c6cb5..eee2a12ef 100644 --- a/internal/db/bundb/basic.go +++ b/internal/db/bundb/basic.go @@ -28,12 +28,12 @@ import ( ) type basicDB struct { - db *WrappedDB + db *DB } func (b *basicDB) Put(ctx context.Context, i interface{}) error { _, err := b.db.NewInsert().Model(i).Exec(ctx) - return b.db.ProcessError(err) + return err } func (b *basicDB) GetByID(ctx context.Context, id string, i interface{}) error { @@ -43,7 +43,7 @@ func (b *basicDB) GetByID(ctx context.Context, id string, i interface{}) error { Where("id = ?", id) err := q.Scan(ctx) - return b.db.ProcessError(err) + return err } func (b *basicDB) GetWhere(ctx context.Context, where []db.Where, i interface{}) error { @@ -56,7 +56,7 @@ func (b *basicDB) GetWhere(ctx context.Context, where []db.Where, i interface{}) selectWhere(q, where) err := q.Scan(ctx) - return b.db.ProcessError(err) + return err } func (b *basicDB) GetAll(ctx context.Context, i interface{}) error { @@ -65,7 +65,7 @@ func (b *basicDB) GetAll(ctx context.Context, i interface{}) error { Model(i) err := q.Scan(ctx) - return b.db.ProcessError(err) + return err } func (b *basicDB) DeleteByID(ctx context.Context, id string, i interface{}) error { @@ -75,7 +75,7 @@ func (b *basicDB) DeleteByID(ctx context.Context, id string, i interface{}) erro Where("id = ?", id) _, err := q.Exec(ctx) - return b.db.ProcessError(err) + return err } func (b *basicDB) DeleteWhere(ctx context.Context, where []db.Where, i interface{}) error { @@ -90,7 +90,7 @@ func (b *basicDB) DeleteWhere(ctx context.Context, where []db.Where, i interface deleteWhere(q, where) _, err := q.Exec(ctx) - return b.db.ProcessError(err) + return err } func (b *basicDB) UpdateByID(ctx context.Context, i interface{}, id string, columns ...string) error { @@ -101,7 +101,7 @@ func (b *basicDB) UpdateByID(ctx context.Context, i interface{}, id string, colu Where("? = ?", bun.Ident("id"), id) _, err := q.Exec(ctx) - return b.db.ProcessError(err) + return err } func (b *basicDB) UpdateWhere(ctx context.Context, where []db.Where, key string, value interface{}, i interface{}) error { @@ -112,7 +112,7 @@ func (b *basicDB) UpdateWhere(ctx context.Context, where []db.Where, key string, q = q.Set("? = ?", bun.Ident(key), value) _, err := q.Exec(ctx) - return b.db.ProcessError(err) + return err } func (b *basicDB) CreateTable(ctx context.Context, i interface{}) error { @@ -155,14 +155,14 @@ func (b *basicDB) CreateAllTables(ctx context.Context) error { func (b *basicDB) DropTable(ctx context.Context, i interface{}) error { _, err := b.db.NewDropTable().Model(i).IfExists().Exec(ctx) - return b.db.ProcessError(err) + return err } func (b *basicDB) IsHealthy(ctx context.Context) error { - return b.db.DB.PingContext(ctx) + return b.db.PingContext(ctx) } -func (b *basicDB) Stop(ctx context.Context) error { - log.Info(ctx, "closing db connection") - return b.db.DB.Close() +func (b *basicDB) Close() error { + log.Info(nil, "closing db connection") + return b.db.Close() } diff --git a/internal/db/bundb/bundb.go b/internal/db/bundb/bundb.go index 26b31ff28..ad9053e6e 100644 --- a/internal/db/bundb/bundb.go +++ b/internal/db/bundb/bundb.go @@ -81,12 +81,12 @@ type DBService struct { db.Timeline db.User db.Tombstone - db *WrappedDB + db *DB } // GetDB returns the underlying database connection pool. // Should only be used in testing + exceptional circumstance. -func (dbService *DBService) DB() *WrappedDB { +func (dbService *DBService) DB() *DB { return dbService.db } @@ -114,7 +114,7 @@ func doMigration(ctx context.Context, db *bun.DB) error { // NewBunDBService returns a bunDB derived from the provided config, which implements the go-fed DB interface. // Under the hood, it uses https://github.com/uptrace/bun to create and maintain a database connection. func NewBunDBService(ctx context.Context, state *state.State) (db.DB, error) { - var db *WrappedDB + var db *DB var err error t := strings.ToLower(config.GetDbType()) @@ -156,7 +156,7 @@ func NewBunDBService(ctx context.Context, state *state.State) (db.DB, error) { // perform any pending database migrations: this includes // the very first 'migration' on startup which just creates // necessary tables - if err := doMigration(ctx, db.DB); err != nil { + if err := doMigration(ctx, db.bun); err != nil { return nil, fmt.Errorf("db migration error: %s", err) } @@ -258,7 +258,7 @@ func NewBunDBService(ctx context.Context, state *state.State) (db.DB, error) { return ps, nil } -func pgConn(ctx context.Context) (*WrappedDB, error) { +func pgConn(ctx context.Context) (*DB, error) { opts, err := deriveBunDBPGOptions() //nolint:contextcheck if err != nil { return nil, fmt.Errorf("could not create bundb postgres options: %s", err) @@ -273,18 +273,18 @@ func pgConn(ctx context.Context) (*WrappedDB, error) { sqldb.SetMaxIdleConns(2) // assume default 2; if max idle is less than max open, it will be automatically adjusted sqldb.SetConnMaxLifetime(5 * time.Minute) // fine to kill old connections - conn := WrapDB(bun.NewDB(sqldb, pgdialect.New())) + db := WrapDB(bun.NewDB(sqldb, pgdialect.New())) // ping to check the db is there and listening - if err := conn.DB.PingContext(ctx); err != nil { + if err := db.PingContext(ctx); err != nil { return nil, fmt.Errorf("postgres ping: %s", err) } log.Info(ctx, "connected to POSTGRES database") - return conn, nil + return db, nil } -func sqliteConn(ctx context.Context) (*WrappedDB, error) { +func sqliteConn(ctx context.Context) (*DB, error) { // validate db address has actually been set address := config.GetDbAddress() if address == "" { @@ -345,10 +345,10 @@ func sqliteConn(ctx context.Context) (*WrappedDB, error) { sqldb.SetConnMaxLifetime(0) // don't kill connections due to age // Wrap Bun database conn in our own wrapper - conn := WrapDB(bun.NewDB(sqldb, sqlitedialect.New())) + db := WrapDB(bun.NewDB(sqldb, sqlitedialect.New())) // ping to check the db is there and listening - if err := conn.DB.PingContext(ctx); err != nil { + if err := db.PingContext(ctx); err != nil { if errWithCode, ok := err.(*sqlite.Error); ok { err = errors.New(sqlite.ErrorCodeString[errWithCode.Code()]) } @@ -356,7 +356,7 @@ func sqliteConn(ctx context.Context) (*WrappedDB, error) { } log.Infof(ctx, "connected to SQLITE database with address %s", address) - return conn, nil + return db, nil } /* @@ -459,7 +459,7 @@ func deriveBunDBPGOptions() (*pgx.ConnConfig, error) { // sqlitePragmas sets desired sqlite pragmas based on configured values, and // logs the results of the pragma queries. Errors if something goes wrong. -func sqlitePragmas(ctx context.Context, db *WrappedDB) error { +func sqlitePragmas(ctx context.Context, db *DB) error { var pragmas [][]string if mode := config.GetDbSqliteJournalMode(); mode != "" { // Set the user provided SQLite journal mode diff --git a/internal/db/bundb/db.go b/internal/db/bundb/db.go new file mode 100644 index 000000000..9b6edcefe --- /dev/null +++ b/internal/db/bundb/db.go @@ -0,0 +1,354 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package bundb + +import ( + "context" + "database/sql" + "time" + + "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" + "github.com/uptrace/bun" + "github.com/uptrace/bun/dialect" + "github.com/uptrace/bun/schema" +) + +// DB wraps a bun database instance +// to provide common per-dialect SQL error +// conversions to common types, and retries +// on returned busy (SQLite only). +type DB struct { + // our own wrapped db type + // with retry backoff support. + // kept separate to the *bun.DB + // type to be passed into query + // builders as bun.IConn iface + // (this prevents double firing + // bun query hooks). + // + // also holds per-dialect + // error hook function. + raw rawdb + + // bun DB interface we use + // for dialects, and improved + // struct marshal/unmarshaling. + bun *bun.DB +} + +// WrapDB wraps a bun database instance in our database type. +func WrapDB(db *bun.DB) *DB { + var errProc func(error) error + switch name := db.Dialect().Name(); name { + case dialect.PG: + errProc = processPostgresError + case dialect.SQLite: + errProc = processSQLiteError + default: + panic("unknown dialect name: " + name.String()) + } + return &DB{ + raw: rawdb{ + errHook: errProc, + DB: db.DB, + }, + bun: db, + } +} + +// Dialect is a direct call-through to bun.DB.Dialect(). +func (db *DB) Dialect() schema.Dialect { return db.bun.Dialect() } + +// AddQueryHook is a direct call-through to bun.DB.AddQueryHook(). +func (db *DB) AddQueryHook(hook bun.QueryHook) { db.bun.AddQueryHook(hook) } + +// RegisterModels is a direct call-through to bun.DB.RegisterModels(). +func (db *DB) RegisterModel(models ...any) { db.bun.RegisterModel(models...) } + +// PingContext is a direct call-through to bun.DB.PingContext(). +func (db *DB) PingContext(ctx context.Context) error { return db.bun.PingContext(ctx) } + +// Close is a direct call-through to bun.DB.Close(). +func (db *DB) Close() error { return db.bun.Close() } + +// BeginTx wraps bun.DB.BeginTx() with retry-busy timeout. +func (db *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (tx bun.Tx, err error) { + bundb := db.bun // use *bun.DB interface to return bun.Tx type + err = retryOnBusy(ctx, func() error { + tx, err = bundb.BeginTx(ctx, opts) + err = db.raw.errHook(err) + return err + }) + return +} + +// ExecContext wraps bun.DB.ExecContext() with retry-busy timeout. +func (db *DB) ExecContext(ctx context.Context, query string, args ...any) (result sql.Result, err error) { + bundb := db.bun // use underlying *bun.DB interface for their query formatting + err = retryOnBusy(ctx, func() error { + result, err = bundb.ExecContext(ctx, query, args...) + err = db.raw.errHook(err) + return err + }) + return +} + +// QueryContext wraps bun.DB.ExecContext() with retry-busy timeout. +func (db *DB) QueryContext(ctx context.Context, query string, args ...any) (rows *sql.Rows, err error) { + bundb := db.bun // use underlying *bun.DB interface for their query formatting + err = retryOnBusy(ctx, func() error { + rows, err = bundb.QueryContext(ctx, query, args...) + err = db.raw.errHook(err) + return err + }) + return +} + +// QueryRowContext wraps bun.DB.ExecContext() with retry-busy timeout. +func (db *DB) QueryRowContext(ctx context.Context, query string, args ...any) (row *sql.Row) { + bundb := db.bun // use underlying *bun.DB interface for their query formatting + _ = retryOnBusy(ctx, func() error { + row = bundb.QueryRowContext(ctx, query, args...) + err := db.raw.errHook(row.Err()) + return err + }) + return +} + +// RunInTx is functionally the same as bun.DB.RunInTx() but with retry-busy timeouts. +func (db *DB) RunInTx(ctx context.Context, fn func(bun.Tx) error) error { + // Attempt to start new transaction. + tx, err := db.BeginTx(ctx, nil) + if err != nil { + return err + } + + var done bool + + defer func() { + if !done { + // Rollback (with retry-backoff). + _ = retryOnBusy(ctx, func() error { + err := tx.Rollback() + return db.raw.errHook(err) + }) + } + }() + + // Perform supplied transaction + if err := fn(tx); err != nil { + return db.raw.errHook(err) + } + + // Commit (with retry-backoff). + err = retryOnBusy(ctx, func() error { + err := tx.Commit() + return db.raw.errHook(err) + }) + done = true + return err +} + +func (db *DB) NewValues(model interface{}) *bun.ValuesQuery { + // note: passing in rawdb as conn iface so no double query-hook + // firing when passed through the bun.DB.Query___() functions. + return bun.NewValuesQuery(db.bun, model).Conn(&db.raw) +} + +func (db *DB) NewMerge() *bun.MergeQuery { + // note: passing in rawdb as conn iface so no double query-hook + // firing when passed through the bun.DB.Query___() functions. + return bun.NewMergeQuery(db.bun).Conn(&db.raw) +} + +func (db *DB) NewSelect() *bun.SelectQuery { + // note: passing in rawdb as conn iface so no double query-hook + // firing when passed through the bun.DB.Query___() functions. + return bun.NewSelectQuery(db.bun).Conn(&db.raw) +} + +func (db *DB) NewInsert() *bun.InsertQuery { + // note: passing in rawdb as conn iface so no double query-hook + // firing when passed through the bun.DB.Query___() functions. + return bun.NewInsertQuery(db.bun).Conn(&db.raw) +} + +func (db *DB) NewUpdate() *bun.UpdateQuery { + // note: passing in rawdb as conn iface so no double query-hook + // firing when passed through the bun.DB.Query___() functions. + return bun.NewUpdateQuery(db.bun).Conn(&db.raw) +} + +func (db *DB) NewDelete() *bun.DeleteQuery { + // note: passing in rawdb as conn iface so no double query-hook + // firing when passed through the bun.DB.Query___() functions. + return bun.NewDeleteQuery(db.bun).Conn(&db.raw) +} + +func (db *DB) NewRaw(query string, args ...interface{}) *bun.RawQuery { + // note: passing in rawdb as conn iface so no double query-hook + // firing when passed through the bun.DB.Query___() functions. + return bun.NewRawQuery(db.bun, query, args...).Conn(&db.raw) +} + +func (db *DB) NewCreateTable() *bun.CreateTableQuery { + // note: passing in rawdb as conn iface so no double query-hook + // firing when passed through the bun.DB.Query___() functions. + return bun.NewCreateTableQuery(db.bun).Conn(&db.raw) +} + +func (db *DB) NewDropTable() *bun.DropTableQuery { + // note: passing in rawdb as conn iface so no double query-hook + // firing when passed through the bun.DB.Query___() functions. + return bun.NewDropTableQuery(db.bun).Conn(&db.raw) +} + +func (db *DB) NewCreateIndex() *bun.CreateIndexQuery { + // note: passing in rawdb as conn iface so no double query-hook + // firing when passed through the bun.DB.Query___() functions. + return bun.NewCreateIndexQuery(db.bun).Conn(&db.raw) +} + +func (db *DB) NewDropIndex() *bun.DropIndexQuery { + // note: passing in rawdb as conn iface so no double query-hook + // firing when passed through the bun.DB.Query___() functions. + return bun.NewDropIndexQuery(db.bun).Conn(&db.raw) +} + +func (db *DB) NewTruncateTable() *bun.TruncateTableQuery { + // note: passing in rawdb as conn iface so no double query-hook + // firing when passed through the bun.DB.Query___() functions. + return bun.NewTruncateTableQuery(db.bun).Conn(&db.raw) +} + +func (db *DB) NewAddColumn() *bun.AddColumnQuery { + // note: passing in rawdb as conn iface so no double query-hook + // firing when passed through the bun.DB.Query___() functions. + return bun.NewAddColumnQuery(db.bun).Conn(&db.raw) +} + +func (db *DB) NewDropColumn() *bun.DropColumnQuery { + // note: passing in rawdb as conn iface so no double query-hook + // firing when passed through the bun.DB.Query___() functions. + return bun.NewDropColumnQuery(db.bun).Conn(&db.raw) +} + +// Exists checks the results of a SelectQuery for the existence of the data in question, masking ErrNoEntries errors. +func (db *DB) Exists(ctx context.Context, query *bun.SelectQuery) (bool, error) { + exists, err := query.Exists(ctx) + switch err { + case nil: + return exists, nil + case sql.ErrNoRows: + return false, nil + default: + return false, err + } +} + +// NotExists checks the results of a SelectQuery for the non-existence of the data in question, masking ErrNoEntries errors. +func (db *DB) NotExists(ctx context.Context, query *bun.SelectQuery) (bool, error) { + exists, err := db.Exists(ctx, query) + return !exists, err +} + +type rawdb struct { + // dialect specific error + // processing function hook. + errHook func(error) error + + // embedded raw + // db interface + *sql.DB +} + +// ExecContext wraps sql.DB.ExecContext() with retry-busy timeout. +func (db *rawdb) ExecContext(ctx context.Context, query string, args ...any) (result sql.Result, err error) { + err = retryOnBusy(ctx, func() error { + result, err = db.DB.ExecContext(ctx, query, args...) + err = db.errHook(err) + return err + }) + return +} + +// QueryContext wraps sql.DB.QueryContext() with retry-busy timeout. +func (db *rawdb) QueryContext(ctx context.Context, query string, args ...any) (rows *sql.Rows, err error) { + err = retryOnBusy(ctx, func() error { + rows, err = db.DB.QueryContext(ctx, query, args...) + err = db.errHook(err) + return err + }) + return +} + +// QueryRowContext wraps sql.DB.QueryRowContext() with retry-busy timeout. +func (db *rawdb) QueryRowContext(ctx context.Context, query string, args ...any) (row *sql.Row) { + _ = retryOnBusy(ctx, func() error { + row = db.DB.QueryRowContext(ctx, query, args...) + err := db.errHook(row.Err()) + return err + }) + return +} + +// retryOnBusy will retry given function on returned 'errBusy'. +func retryOnBusy(ctx context.Context, fn func() error) error { + var backoff time.Duration + + for i := 0; ; i++ { + // Perform func. + err := fn() + + if err != errBusy { + // May be nil, or may be + // some other error, either + // way return here. + return err + } + + // backoff according to a multiplier of 2ms * 2^2n, + // up to a maximum possible backoff time of 5 minutes. + // + // this works out as the following: + // 4ms + // 16ms + // 64ms + // 256ms + // 1.024s + // 4.096s + // 16.384s + // 1m5.536s + // 4m22.144s + backoff = 2 * time.Millisecond * (1 << (2*i + 1)) + if backoff >= 5*time.Minute { + break + } + + select { + // Context cancelled. + case <-ctx.Done(): + + // Backoff for some time. + case <-time.After(backoff): + } + } + + return gtserror.Newf("%w (waited > %s)", db.ErrBusyTimeout, backoff) +} diff --git a/internal/db/bundb/domain.go b/internal/db/bundb/domain.go index 07e1e9fca..c989d4fe4 100644 --- a/internal/db/bundb/domain.go +++ b/internal/db/bundb/domain.go @@ -30,7 +30,7 @@ import ( ) type domainDB struct { - db *WrappedDB + db *DB state *state.State } @@ -46,7 +46,7 @@ func (d *domainDB) CreateDomainBlock(ctx context.Context, block *gtsmodel.Domain if _, err := d.db.NewInsert(). Model(block). Exec(ctx); err != nil { - return d.db.ProcessError(err) + return err } // Clear the domain block cache (for later reload) @@ -76,7 +76,7 @@ func (d *domainDB) GetDomainBlock(ctx context.Context, domain string) (*gtsmodel Model(&block). Where("? = ?", bun.Ident("domain_block.domain"), domain) if err := q.Scan(ctx); err != nil { - return nil, d.db.ProcessError(err) + return nil, err } return &block, nil @@ -89,7 +89,7 @@ func (d *domainDB) GetDomainBlocks(ctx context.Context) ([]*gtsmodel.DomainBlock NewSelect(). Model(&blocks). Scan(ctx); err != nil { - return nil, d.db.ProcessError(err) + return nil, err } return blocks, nil @@ -103,7 +103,7 @@ func (d *domainDB) GetDomainBlockByID(ctx context.Context, id string) (*gtsmodel Model(&block). Where("? = ?", bun.Ident("domain_block.id"), id) if err := q.Scan(ctx); err != nil { - return nil, d.db.ProcessError(err) + return nil, err } return &block, nil @@ -121,7 +121,7 @@ func (d *domainDB) DeleteDomainBlock(ctx context.Context, domain string) error { Model((*gtsmodel.DomainBlock)(nil)). Where("? = ?", bun.Ident("domain_block.domain"), domain). Exec(ctx); err != nil { - return d.db.ProcessError(err) + return err } // Clear the domain block cache (for later reload) @@ -152,7 +152,7 @@ func (d *domainDB) IsDomainBlocked(ctx context.Context, domain string) (bool, er Table("domain_blocks"). Column("domain") if err := q.Scan(ctx, &domains); err != nil { - return nil, d.db.ProcessError(err) + return nil, err } return domains, nil diff --git a/internal/db/bundb/emoji.go b/internal/db/bundb/emoji.go index e675339a2..2a3d91fe4 100644 --- a/internal/db/bundb/emoji.go +++ b/internal/db/bundb/emoji.go @@ -34,14 +34,14 @@ import ( ) type emojiDB struct { - db *WrappedDB + db *DB state *state.State } func (e *emojiDB) PutEmoji(ctx context.Context, emoji *gtsmodel.Emoji) error { return e.state.Caches.GTS.Emoji().Store(emoji, func() error { _, err := e.db.NewInsert().Model(emoji).Exec(ctx) - return e.db.ProcessError(err) + return err }) } @@ -60,7 +60,7 @@ func (e *emojiDB) UpdateEmoji(ctx context.Context, emoji *gtsmodel.Emoji, column Where("? = ?", bun.Ident("emoji.id"), emoji.ID). Column(columns...). Exec(ctx) - return e.db.ProcessError(err) + return err }) } @@ -294,7 +294,7 @@ func (e *emojiDB) GetEmojisBy(ctx context.Context, domain string, includeDisable Column("subquery.emoji_ids"). TableExpr("(?) AS ?", subQuery, bun.Ident("subquery")). Scan(ctx, &emojiIDs); err != nil { - return nil, e.db.ProcessError(err) + return nil, err } if order == "DESC" { @@ -328,7 +328,7 @@ func (e *emojiDB) GetEmojis(ctx context.Context, maxID string, limit int) ([]*gt } if err := q.Scan(ctx, &emojiIDs); err != nil { - return nil, e.db.ProcessError(err) + return nil, err } return e.GetEmojisByIDs(ctx, emojiIDs) @@ -352,7 +352,7 @@ func (e *emojiDB) GetRemoteEmojis(ctx context.Context, maxID string, limit int) } if err := q.Scan(ctx, &emojiIDs); err != nil { - return nil, e.db.ProcessError(err) + return nil, err } return e.GetEmojisByIDs(ctx, emojiIDs) @@ -374,7 +374,7 @@ func (e *emojiDB) GetCachedEmojisOlderThan(ctx context.Context, olderThan time.T } if err := q.Scan(ctx, &emojiIDs); err != nil { - return nil, e.db.ProcessError(err) + return nil, err } return e.GetEmojisByIDs(ctx, emojiIDs) @@ -393,7 +393,7 @@ func (e *emojiDB) GetUseableEmojis(ctx context.Context) ([]*gtsmodel.Emoji, erro Order("emoji.shortcode ASC") if err := q.Scan(ctx, &emojiIDs); err != nil { - return nil, e.db.ProcessError(err) + return nil, err } return e.GetEmojisByIDs(ctx, emojiIDs) @@ -469,7 +469,7 @@ func (e *emojiDB) GetEmojiByStaticURL(ctx context.Context, imageStaticURL string func (e *emojiDB) PutEmojiCategory(ctx context.Context, emojiCategory *gtsmodel.EmojiCategory) error { return e.state.Caches.GTS.EmojiCategory().Store(emojiCategory, func() error { _, err := e.db.NewInsert().Model(emojiCategory).Exec(ctx) - return e.db.ProcessError(err) + return err }) } @@ -483,7 +483,7 @@ func (e *emojiDB) GetEmojiCategories(ctx context.Context) ([]*gtsmodel.EmojiCate Order("emoji_category.name ASC") if err := q.Scan(ctx, &emojiCategoryIDs); err != nil { - return nil, e.db.ProcessError(err) + return nil, err } return e.GetEmojiCategoriesByIDs(ctx, emojiCategoryIDs) @@ -524,7 +524,7 @@ func (e *emojiDB) getEmoji(ctx context.Context, lookup string, dbQuery func(*gts // Not cached! Perform database query if err := dbQuery(&emoji); err != nil { - return nil, e.db.ProcessError(err) + return nil, err } return &emoji, nil @@ -574,7 +574,7 @@ func (e *emojiDB) getEmojiCategory(ctx context.Context, lookup string, dbQuery f // Not cached! Perform database query if err := dbQuery(&category); err != nil { - return nil, e.db.ProcessError(err) + return nil, err } return &category, nil diff --git a/internal/db/bundb/errors.go b/internal/db/bundb/errors.go index 6bec8edae..46735ca80 100644 --- a/internal/db/bundb/errors.go +++ b/internal/db/bundb/errors.go @@ -32,6 +32,11 @@ var errBusy = errors.New("busy") // processPostgresError processes an error, replacing any postgres specific errors with our own error type func processPostgresError(err error) error { + // Catch nil errs. + if err == nil { + return nil + } + // Attempt to cast as postgres pgErr, ok := err.(*pgconn.PgError) if !ok { @@ -50,6 +55,11 @@ func processPostgresError(err error) error { // processSQLiteError processes an error, replacing any sqlite specific errors with our own error type func processSQLiteError(err error) error { + // Catch nil errs. + if err == nil { + return nil + } + // Attempt to cast as sqlite sqliteErr, ok := err.(*sqlite.Error) if !ok { diff --git a/internal/db/bundb/instance.go b/internal/db/bundb/instance.go index 09084642f..7f0e92634 100644 --- a/internal/db/bundb/instance.go +++ b/internal/db/bundb/instance.go @@ -34,7 +34,7 @@ import ( ) type instanceDB struct { - db *WrappedDB + db *DB state *state.State } @@ -56,7 +56,7 @@ func (i *instanceDB) CountInstanceUsers(ctx context.Context, domain string) (int count, err := q.Count(ctx) if err != nil { - return 0, i.db.ProcessError(err) + return 0, err } return count, nil } @@ -78,7 +78,7 @@ func (i *instanceDB) CountInstanceStatuses(ctx context.Context, domain string) ( count, err := q.Count(ctx) if err != nil { - return 0, i.db.ProcessError(err) + return 0, err } return count, nil } @@ -101,7 +101,7 @@ func (i *instanceDB) CountInstanceDomains(ctx context.Context, domain string) (i count, err := q.Count(ctx) if err != nil { - return 0, i.db.ProcessError(err) + return 0, err } return count, nil } @@ -148,7 +148,7 @@ func (i *instanceDB) getInstance(ctx context.Context, lookup string, dbQuery fun // Not cached! Perform database query. if err := dbQuery(&instance); err != nil { - return nil, i.db.ProcessError(err) + return nil, err } return &instance, nil @@ -211,7 +211,7 @@ func (i *instanceDB) PutInstance(ctx context.Context, instance *gtsmodel.Instanc return i.state.Caches.GTS.Instance().Store(instance, func() error { _, err := i.db.NewInsert().Model(instance).Exec(ctx) - return i.db.ProcessError(err) + return err }) } @@ -236,7 +236,7 @@ func (i *instanceDB) UpdateInstance(ctx context.Context, instance *gtsmodel.Inst Where("? = ?", bun.Ident("instance.id"), instance.ID). Column(columns...). Exec(ctx) - return i.db.ProcessError(err) + return err }) } @@ -256,7 +256,7 @@ func (i *instanceDB) GetInstancePeers(ctx context.Context, includeSuspended bool } if err := q.Scan(ctx, &instanceIDs); err != nil { - return nil, i.db.ProcessError(err) + return nil, err } if len(instanceIDs) == 0 { @@ -315,7 +315,7 @@ func (i *instanceDB) GetInstanceAccounts(ctx context.Context, domain string, max } if err := q.Scan(ctx, &accountIDs); err != nil { - return nil, i.db.ProcessError(err) + return nil, err } // Catch case of no accounts early. @@ -361,7 +361,7 @@ func (i *instanceDB) GetInstanceModeratorAddresses(ctx context.Context) ([]strin OrderExpr("? ASC", bun.Ident("user.email")) if err := q.Scan(ctx, &addresses); err != nil { - return nil, i.db.ProcessError(err) + return nil, err } if len(addresses) == 0 { diff --git a/internal/db/bundb/list.go b/internal/db/bundb/list.go index ec96f1dfc..23d9c13fb 100644 --- a/internal/db/bundb/list.go +++ b/internal/db/bundb/list.go @@ -33,7 +33,7 @@ import ( ) type listDB struct { - db *WrappedDB + db *DB state *state.State } @@ -61,7 +61,7 @@ func (l *listDB) getList(ctx context.Context, lookup string, dbQuery func(*gtsmo // Not cached! Perform database query. if err := dbQuery(&list); err != nil { - return nil, l.db.ProcessError(err) + return nil, err } return &list, nil @@ -93,7 +93,7 @@ func (l *listDB) GetListsForAccountID(ctx context.Context, accountID string) ([] Where("? = ?", bun.Ident("list.account_id"), accountID). Order("list.id DESC"). Scan(ctx, &listIDs); err != nil { - return nil, l.db.ProcessError(err) + return nil, err } if len(listIDs) == 0 { @@ -149,7 +149,7 @@ func (l *listDB) PopulateList(ctx context.Context, list *gtsmodel.List) error { func (l *listDB) PutList(ctx context.Context, list *gtsmodel.List) error { return l.state.Caches.GTS.List().Store(list, func() error { _, err := l.db.NewInsert().Model(list).Exec(ctx) - return l.db.ProcessError(err) + return err }) } @@ -176,7 +176,7 @@ func (l *listDB) UpdateList(ctx context.Context, list *gtsmodel.List, columns .. Where("? = ?", bun.Ident("list.id"), list.ID). Column(columns...). Exec(ctx) - return l.db.ProcessError(err) + return err }) } @@ -248,7 +248,7 @@ func (l *listDB) getListEntry(ctx context.Context, lookup string, dbQuery func(* // Not cached! Perform database query. if err := dbQuery(&listEntry); err != nil { - return nil, l.db.ProcessError(err) + return nil, err } return &listEntry, nil @@ -328,7 +328,7 @@ func (l *listDB) GetListEntries(ctx context.Context, } if err := q.Scan(ctx, &entryIDs); err != nil { - return nil, l.db.ProcessError(err) + return nil, err } if len(entryIDs) == 0 { @@ -369,7 +369,7 @@ func (l *listDB) GetListEntriesForFollowID(ctx context.Context, followID string) // Select only entries belonging with given followID. Where("? = ?", bun.Ident("entry.follow_id"), followID). Scan(ctx, &entryIDs); err != nil { - return nil, l.db.ProcessError(err) + return nil, err } if len(entryIDs) == 0 { @@ -486,7 +486,7 @@ func (l *listDB) DeleteListEntriesForFollowID(ctx context.Context, followID stri Where("? = ?", bun.Ident("follow_id"), followID). Order("id DESC"). Scan(ctx, &entryIDs); err != nil { - return l.db.ProcessError(err) + return err } for _, id := range entryIDs { @@ -512,7 +512,7 @@ func (l *listDB) ListIncludesAccount(ctx context.Context, listID string, account Where("? = ?", bun.Ident("follow.target_account_id"), accountID). Exists(ctx) - return exists, l.db.ProcessError(err) + return exists, err } // collate will collect the values of type T from an expected slice of length 'len', diff --git a/internal/db/bundb/marker.go b/internal/db/bundb/marker.go index 12526e659..861f7de36 100644 --- a/internal/db/bundb/marker.go +++ b/internal/db/bundb/marker.go @@ -30,7 +30,7 @@ import ( ) type markerDB struct { - db *WrappedDB + db *DB state *state.State } @@ -48,7 +48,7 @@ func (m *markerDB) GetMarker(ctx context.Context, accountID string, name gtsmode Model(&marker). Where("? = ? AND ? = ?", bun.Ident("account_id"), accountID, bun.Ident("name"), name). Scan(ctx); err != nil { - return nil, m.db.ProcessError(err) + return nil, err } return &marker, nil @@ -79,7 +79,7 @@ func (m *markerDB) UpdateMarker(ctx context.Context, marker *gtsmodel.Marker) er if _, err := m.db.NewInsert(). Model(marker). Exec(ctx); err != nil { - return m.db.ProcessError(err) + return err } return nil } @@ -94,12 +94,12 @@ func (m *markerDB) UpdateMarker(ctx context.Context, marker *gtsmodel.Marker) er Where("? = ?", bun.Ident("version"), prevMarker.Version). Exec(ctx) if err != nil { - return m.db.ProcessError(err) + return err } rowsAffected, err := result.RowsAffected() if err != nil { - return m.db.ProcessError(err) + return err } if rowsAffected == 0 { // Will trigger a rollback, although there should be no changes to roll back. diff --git a/internal/db/bundb/media.go b/internal/db/bundb/media.go index b8120b87a..7f079cb34 100644 --- a/internal/db/bundb/media.go +++ b/internal/db/bundb/media.go @@ -32,7 +32,7 @@ import ( ) type mediaDB struct { - db *WrappedDB + db *DB state *state.State } @@ -74,7 +74,7 @@ func (m *mediaDB) getAttachment(ctx context.Context, lookup string, dbQuery func // Not cached! Perform database query if err := dbQuery(&attachment); err != nil { - return nil, m.db.ProcessError(err) + return nil, err } return &attachment, nil @@ -84,7 +84,7 @@ func (m *mediaDB) getAttachment(ctx context.Context, lookup string, dbQuery func func (m *mediaDB) PutAttachment(ctx context.Context, media *gtsmodel.MediaAttachment) error { return m.state.Caches.GTS.Media().Store(media, func() error { _, err := m.db.NewInsert().Model(media).Exec(ctx) - return m.db.ProcessError(err) + return err }) } @@ -101,7 +101,7 @@ func (m *mediaDB) UpdateAttachment(ctx context.Context, media *gtsmodel.MediaAtt Where("? = ?", bun.Ident("media_attachment.id"), media.ID). Column(columns...). Exec(ctx) - return m.db.ProcessError(err) + return err }) } @@ -197,7 +197,7 @@ func (m *mediaDB) DeleteAttachment(ctx context.Context, id string) error { return nil }) - return m.db.ProcessError(err) + return err } func (m *mediaDB) CountRemoteOlderThan(ctx context.Context, olderThan time.Time) (int, error) { @@ -211,7 +211,7 @@ func (m *mediaDB) CountRemoteOlderThan(ctx context.Context, olderThan time.Time) count, err := q.Count(ctx) if err != nil { - return 0, m.db.ProcessError(err) + return 0, err } return count, nil @@ -234,7 +234,7 @@ func (m *mediaDB) GetAttachments(ctx context.Context, maxID string, limit int) ( } if err := q.Scan(ctx, &attachmentIDs); err != nil { - return nil, m.db.ProcessError(err) + return nil, err } return m.GetAttachmentsByIDs(ctx, attachmentIDs) @@ -258,7 +258,7 @@ func (m *mediaDB) GetRemoteAttachments(ctx context.Context, maxID string, limit } if err := q.Scan(ctx, &attachmentIDs); err != nil { - return nil, m.db.ProcessError(err) + return nil, err } return m.GetAttachmentsByIDs(ctx, attachmentIDs) @@ -281,7 +281,7 @@ func (m *mediaDB) GetCachedAttachmentsOlderThan(ctx context.Context, olderThan t } if err := q.Scan(ctx, &attachmentIDs); err != nil { - return nil, m.db.ProcessError(err) + return nil, err } return m.GetAttachmentsByIDs(ctx, attachmentIDs) @@ -309,7 +309,7 @@ func (m *mediaDB) GetAvatarsAndHeaders(ctx context.Context, maxID string, limit } if err := q.Scan(ctx, &attachmentIDs); err != nil { - return nil, m.db.ProcessError(err) + return nil, err } return m.GetAttachmentsByIDs(ctx, attachmentIDs) @@ -335,7 +335,7 @@ func (m *mediaDB) GetLocalUnattachedOlderThan(ctx context.Context, olderThan tim } if err := q.Scan(ctx, &attachmentIDs); err != nil { - return nil, m.db.ProcessError(err) + return nil, err } return m.GetAttachmentsByIDs(ctx, attachmentIDs) @@ -355,7 +355,7 @@ func (m *mediaDB) CountLocalUnattachedOlderThan(ctx context.Context, olderThan t count, err := q.Count(ctx) if err != nil { - return 0, m.db.ProcessError(err) + return 0, err } return count, nil diff --git a/internal/db/bundb/mention.go b/internal/db/bundb/mention.go index 12d71a95a..547d8d0a8 100644 --- a/internal/db/bundb/mention.go +++ b/internal/db/bundb/mention.go @@ -31,7 +31,7 @@ import ( ) type mentionDB struct { - db *WrappedDB + db *DB state *state.State } @@ -45,7 +45,7 @@ func (m *mentionDB) GetMention(ctx context.Context, id string) (*gtsmodel.Mentio Where("? = ?", bun.Ident("mention.id"), id) if err := q.Scan(ctx); err != nil { - return nil, m.db.ProcessError(err) + return nil, err } return &mention, nil @@ -105,7 +105,7 @@ func (m *mentionDB) GetMentions(ctx context.Context, ids []string) ([]*gtsmodel. func (m *mentionDB) PutMention(ctx context.Context, mention *gtsmodel.Mention) error { return m.state.Caches.GTS.Mention().Store(mention, func() error { _, err := m.db.NewInsert().Model(mention).Exec(ctx) - return m.db.ProcessError(err) + return err }) } @@ -129,5 +129,5 @@ func (m *mentionDB) DeleteMentionByID(ctx context.Context, id string) error { Table("mentions"). Where("? = ?", bun.Ident("id"), id). Exec(ctx) - return m.db.ProcessError(err) + return err } diff --git a/internal/db/bundb/notification.go b/internal/db/bundb/notification.go index b0757fb1e..423fd0be1 100644 --- a/internal/db/bundb/notification.go +++ b/internal/db/bundb/notification.go @@ -31,7 +31,7 @@ import ( ) type notificationDB struct { - db *WrappedDB + db *DB state *state.State } @@ -43,7 +43,7 @@ func (n *notificationDB) GetNotificationByID(ctx context.Context, id string) (*g Model(¬if). Where("? = ?", bun.Ident("notification.id"), id) if err := q.Scan(ctx); err != nil { - return nil, n.db.ProcessError(err) + return nil, err } return ¬if, nil @@ -68,7 +68,7 @@ func (n *notificationDB) GetNotification( Where("? = ?", bun.Ident("status_id"), statusID) if err := q.Scan(ctx); err != nil { - return nil, n.db.ProcessError(err) + return nil, err } return ¬if, nil @@ -140,7 +140,7 @@ func (n *notificationDB) GetAccountNotifications( } if err := q.Scan(ctx, ¬ifIDs); err != nil { - return nil, n.db.ProcessError(err) + return nil, err } if len(notifIDs) == 0 { @@ -175,7 +175,7 @@ func (n *notificationDB) GetAccountNotifications( func (n *notificationDB) PutNotification(ctx context.Context, notif *gtsmodel.Notification) error { return n.state.Caches.GTS.Notification().Store(notif, func() error { _, err := n.db.NewInsert().Model(notif).Exec(ctx) - return n.db.ProcessError(err) + return err }) } @@ -199,7 +199,7 @@ func (n *notificationDB) DeleteNotificationByID(ctx context.Context, id string) TableExpr("? AS ?", bun.Ident("notifications"), bun.Ident("notification")). Where("? = ?", bun.Ident("notification.id"), id). Exec(ctx) - return n.db.ProcessError(err) + return err } func (n *notificationDB) DeleteNotifications(ctx context.Context, types []string, targetAccountID string, originAccountID string) error { @@ -227,7 +227,7 @@ func (n *notificationDB) DeleteNotifications(ctx context.Context, types []string } if _, err := q.Exec(ctx, ¬ifIDs); err != nil { - return n.db.ProcessError(err) + return err } defer func() { @@ -252,7 +252,7 @@ func (n *notificationDB) DeleteNotifications(ctx context.Context, types []string Table("notifications"). Where("? IN (?)", bun.Ident("id"), bun.In(notifIDs)). Exec(ctx) - return n.db.ProcessError(err) + return err } func (n *notificationDB) DeleteNotificationsForStatus(ctx context.Context, statusID string) error { @@ -265,7 +265,7 @@ func (n *notificationDB) DeleteNotificationsForStatus(ctx context.Context, statu Where("? = ?", bun.Ident("status_id"), statusID) if _, err := q.Exec(ctx, ¬ifIDs); err != nil { - return n.db.ProcessError(err) + return err } defer func() { @@ -290,5 +290,5 @@ func (n *notificationDB) DeleteNotificationsForStatus(ctx context.Context, statu Table("notifications"). Where("? IN (?)", bun.Ident("id"), bun.In(notifIDs)). Exec(ctx) - return n.db.ProcessError(err) + return err } diff --git a/internal/db/bundb/relationship.go b/internal/db/bundb/relationship.go index e7b563f2e..2f93b12ad 100644 --- a/internal/db/bundb/relationship.go +++ b/internal/db/bundb/relationship.go @@ -31,7 +31,7 @@ import ( ) type relationshipDB struct { - db *WrappedDB + db *DB state *state.State } @@ -158,7 +158,7 @@ func (r *relationshipDB) GetAccountBlocks(ctx context.Context, accountID string, // Block IDs not in cache, perform DB query! q := newSelectBlocks(r.db, accountID) if _, err := q.Exec(ctx, &blockIDs); err != nil { - return nil, r.db.ProcessError(err) + return nil, err } return blockIDs, nil @@ -208,7 +208,7 @@ func (r *relationshipDB) getAccountFollowIDs(ctx context.Context, accountID stri // Follow IDs not in cache, perform DB query! q := newSelectFollows(r.db, accountID) if _, err := q.Exec(ctx, &followIDs); err != nil { - return nil, r.db.ProcessError(err) + return nil, err } return followIDs, nil @@ -222,7 +222,7 @@ func (r *relationshipDB) getAccountLocalFollowIDs(ctx context.Context, accountID // Follow IDs not in cache, perform DB query! q := newSelectLocalFollows(r.db, accountID) if _, err := q.Exec(ctx, &followIDs); err != nil { - return nil, r.db.ProcessError(err) + return nil, err } return followIDs, nil @@ -236,7 +236,7 @@ func (r *relationshipDB) getAccountFollowerIDs(ctx context.Context, accountID st // Follow IDs not in cache, perform DB query! q := newSelectFollowers(r.db, accountID) if _, err := q.Exec(ctx, &followIDs); err != nil { - return nil, r.db.ProcessError(err) + return nil, err } return followIDs, nil @@ -250,7 +250,7 @@ func (r *relationshipDB) getAccountLocalFollowerIDs(ctx context.Context, account // Follow IDs not in cache, perform DB query! q := newSelectLocalFollowers(r.db, accountID) if _, err := q.Exec(ctx, &followIDs); err != nil { - return nil, r.db.ProcessError(err) + return nil, err } return followIDs, nil @@ -264,7 +264,7 @@ func (r *relationshipDB) getAccountFollowRequestIDs(ctx context.Context, account // Follow request IDs not in cache, perform DB query! q := newSelectFollowRequests(r.db, accountID) if _, err := q.Exec(ctx, &followReqIDs); err != nil { - return nil, r.db.ProcessError(err) + return nil, err } return followReqIDs, nil @@ -278,7 +278,7 @@ func (r *relationshipDB) getAccountFollowRequestingIDs(ctx context.Context, acco // Follow request IDs not in cache, perform DB query! q := newSelectFollowRequesting(r.db, accountID) if _, err := q.Exec(ctx, &followReqIDs); err != nil { - return nil, r.db.ProcessError(err) + return nil, err } return followReqIDs, nil @@ -286,7 +286,7 @@ func (r *relationshipDB) getAccountFollowRequestingIDs(ctx context.Context, acco } // newSelectFollowRequests returns a new select query for all rows in the follow_requests table with target_account_id = accountID. -func newSelectFollowRequests(db *WrappedDB, accountID string) *bun.SelectQuery { +func newSelectFollowRequests(db *DB, accountID string) *bun.SelectQuery { return db.NewSelect(). TableExpr("?", bun.Ident("follow_requests")). ColumnExpr("?", bun.Ident("id")). @@ -295,7 +295,7 @@ func newSelectFollowRequests(db *WrappedDB, accountID string) *bun.SelectQuery { } // newSelectFollowRequesting returns a new select query for all rows in the follow_requests table with account_id = accountID. -func newSelectFollowRequesting(db *WrappedDB, accountID string) *bun.SelectQuery { +func newSelectFollowRequesting(db *DB, accountID string) *bun.SelectQuery { return db.NewSelect(). TableExpr("?", bun.Ident("follow_requests")). ColumnExpr("?", bun.Ident("id")). @@ -304,7 +304,7 @@ func newSelectFollowRequesting(db *WrappedDB, accountID string) *bun.SelectQuery } // newSelectFollows returns a new select query for all rows in the follows table with account_id = accountID. -func newSelectFollows(db *WrappedDB, accountID string) *bun.SelectQuery { +func newSelectFollows(db *DB, accountID string) *bun.SelectQuery { return db.NewSelect(). Table("follows"). Column("id"). @@ -314,7 +314,7 @@ func newSelectFollows(db *WrappedDB, accountID string) *bun.SelectQuery { // newSelectLocalFollows returns a new select query for all rows in the follows table with // account_id = accountID where the corresponding account ID has a NULL domain (i.e. is local). -func newSelectLocalFollows(db *WrappedDB, accountID string) *bun.SelectQuery { +func newSelectLocalFollows(db *DB, accountID string) *bun.SelectQuery { return db.NewSelect(). Table("follows"). Column("id"). @@ -331,7 +331,7 @@ func newSelectLocalFollows(db *WrappedDB, accountID string) *bun.SelectQuery { } // newSelectFollowers returns a new select query for all rows in the follows table with target_account_id = accountID. -func newSelectFollowers(db *WrappedDB, accountID string) *bun.SelectQuery { +func newSelectFollowers(db *DB, accountID string) *bun.SelectQuery { return db.NewSelect(). Table("follows"). Column("id"). @@ -341,7 +341,7 @@ func newSelectFollowers(db *WrappedDB, accountID string) *bun.SelectQuery { // newSelectLocalFollowers returns a new select query for all rows in the follows table with // target_account_id = accountID where the corresponding account ID has a NULL domain (i.e. is local). -func newSelectLocalFollowers(db *WrappedDB, accountID string) *bun.SelectQuery { +func newSelectLocalFollowers(db *DB, accountID string) *bun.SelectQuery { return db.NewSelect(). Table("follows"). Column("id"). @@ -358,7 +358,7 @@ func newSelectLocalFollowers(db *WrappedDB, accountID string) *bun.SelectQuery { } // newSelectBlocks returns a new select query for all rows in the blocks table with account_id = accountID. -func newSelectBlocks(db *WrappedDB, accountID string) *bun.SelectQuery { +func newSelectBlocks(db *DB, accountID string) *bun.SelectQuery { return db.NewSelect(). TableExpr("?", bun.Ident("blocks")). ColumnExpr("?", bun.Ident("?")). diff --git a/internal/db/bundb/relationship_block.go b/internal/db/bundb/relationship_block.go index 33a3b85fa..efaa6d1a9 100644 --- a/internal/db/bundb/relationship_block.go +++ b/internal/db/bundb/relationship_block.go @@ -124,7 +124,7 @@ func (r *relationshipDB) getBlock(ctx context.Context, lookup string, dbQuery fu // Not cached! Perform database query if err := dbQuery(&block); err != nil { - return nil, r.db.ProcessError(err) + return nil, err } return &block, nil @@ -180,7 +180,7 @@ func (r *relationshipDB) PopulateBlock(ctx context.Context, block *gtsmodel.Bloc func (r *relationshipDB) PutBlock(ctx context.Context, block *gtsmodel.Block) error { return r.state.Caches.GTS.Block().Store(block, func() error { _, err := r.db.NewInsert().Model(block).Exec(ctx) - return r.db.ProcessError(err) + return err }) } @@ -205,7 +205,7 @@ func (r *relationshipDB) DeleteBlockByID(ctx context.Context, id string) error { Table("blocks"). Where("? = ?", bun.Ident("id"), id). Exec(ctx) - return r.db.ProcessError(err) + return err } func (r *relationshipDB) DeleteBlockByURI(ctx context.Context, uri string) error { @@ -229,7 +229,7 @@ func (r *relationshipDB) DeleteBlockByURI(ctx context.Context, uri string) error Table("blocks"). Where("? = ?", bun.Ident("uri"), uri). Exec(ctx) - return r.db.ProcessError(err) + return err } func (r *relationshipDB) DeleteAccountBlocks(ctx context.Context, accountID string) error { @@ -246,7 +246,7 @@ func (r *relationshipDB) DeleteAccountBlocks(ctx context.Context, accountID stri accountID, ). Scan(ctx, &blockIDs); err != nil { - return r.db.ProcessError(err) + return err } defer func() { @@ -270,5 +270,5 @@ func (r *relationshipDB) DeleteAccountBlocks(ctx context.Context, accountID stri Table("blocks"). Where("? IN (?)", bun.Ident("id"), bun.In(blockIDs)). Exec(ctx) - return r.db.ProcessError(err) + return err } diff --git a/internal/db/bundb/relationship_follow.go b/internal/db/bundb/relationship_follow.go index b693269df..6c5a75e4c 100644 --- a/internal/db/bundb/relationship_follow.go +++ b/internal/db/bundb/relationship_follow.go @@ -135,7 +135,7 @@ func (r *relationshipDB) getFollow(ctx context.Context, lookup string, dbQuery f // Not cached! Perform database query if err := dbQuery(&follow); err != nil { - return nil, r.db.ProcessError(err) + return nil, err } return &follow, nil @@ -191,7 +191,7 @@ func (r *relationshipDB) PopulateFollow(ctx context.Context, follow *gtsmodel.Fo func (r *relationshipDB) PutFollow(ctx context.Context, follow *gtsmodel.Follow) error { return r.state.Caches.GTS.Follow().Store(follow, func() error { _, err := r.db.NewInsert().Model(follow).Exec(ctx) - return r.db.ProcessError(err) + return err }) } @@ -208,7 +208,7 @@ func (r *relationshipDB) UpdateFollow(ctx context.Context, follow *gtsmodel.Foll Where("? = ?", bun.Ident("follow.id"), follow.ID). Column(columns...). Exec(ctx); err != nil { - return r.db.ProcessError(err) + return err } return nil @@ -221,7 +221,7 @@ func (r *relationshipDB) deleteFollow(ctx context.Context, id string) error { Table("follows"). Where("? = ?", bun.Ident("id"), id). Exec(ctx); err != nil { - return r.db.ProcessError(err) + return err } // Delete every list entry that used this followID. @@ -311,7 +311,7 @@ func (r *relationshipDB) DeleteAccountFollows(ctx context.Context, accountID str accountID, ). Exec(ctx, &followIDs); err != nil { - return r.db.ProcessError(err) + return err } defer func() { diff --git a/internal/db/bundb/relationship_follow_req.go b/internal/db/bundb/relationship_follow_req.go index cde9dc187..51aceafe1 100644 --- a/internal/db/bundb/relationship_follow_req.go +++ b/internal/db/bundb/relationship_follow_req.go @@ -112,7 +112,7 @@ func (r *relationshipDB) getFollowRequest(ctx context.Context, lookup string, db // Not cached! Perform database query if err := dbQuery(&followReq); err != nil { - return nil, r.db.ProcessError(err) + return nil, err } return &followReq, nil @@ -168,7 +168,7 @@ func (r *relationshipDB) PopulateFollowRequest(ctx context.Context, follow *gtsm func (r *relationshipDB) PutFollowRequest(ctx context.Context, follow *gtsmodel.FollowRequest) error { return r.state.Caches.GTS.FollowRequest().Store(follow, func() error { _, err := r.db.NewInsert().Model(follow).Exec(ctx) - return r.db.ProcessError(err) + return err }) } @@ -185,7 +185,7 @@ func (r *relationshipDB) UpdateFollowRequest(ctx context.Context, followRequest Where("? = ?", bun.Ident("follow_request.id"), followRequest.ID). Column(columns...). Exec(ctx); err != nil { - return r.db.ProcessError(err) + return err } return nil @@ -220,7 +220,7 @@ func (r *relationshipDB) AcceptFollowRequest(ctx context.Context, sourceAccountI Model(follow). On("CONFLICT (?,?) DO UPDATE set ? = ?", bun.Ident("account_id"), bun.Ident("target_account_id"), bun.Ident("uri"), follow.URI). Exec(ctx) - return r.db.ProcessError(err) + return err }); err != nil { return nil, err } @@ -231,7 +231,7 @@ func (r *relationshipDB) AcceptFollowRequest(ctx context.Context, sourceAccountI Table("follow_requests"). Where("? = ?", bun.Ident("id"), followReq.ID). Exec(ctx); err != nil { - return nil, r.db.ProcessError(err) + return nil, err } // Delete original follow request notification @@ -281,7 +281,7 @@ func (r *relationshipDB) DeleteFollowRequest(ctx context.Context, sourceAccountI Table("follow_requests"). Where("? = ?", bun.Ident("id"), follow.ID). Exec(ctx) - return r.db.ProcessError(err) + return err } func (r *relationshipDB) DeleteFollowRequestByID(ctx context.Context, id string) error { @@ -305,7 +305,7 @@ func (r *relationshipDB) DeleteFollowRequestByID(ctx context.Context, id string) Table("follow_requests"). Where("? = ?", bun.Ident("id"), id). Exec(ctx) - return r.db.ProcessError(err) + return err } func (r *relationshipDB) DeleteFollowRequestByURI(ctx context.Context, uri string) error { @@ -329,7 +329,7 @@ func (r *relationshipDB) DeleteFollowRequestByURI(ctx context.Context, uri strin Table("follow_requests"). Where("? = ?", bun.Ident("uri"), uri). Exec(ctx) - return r.db.ProcessError(err) + return err } func (r *relationshipDB) DeleteAccountFollowRequests(ctx context.Context, accountID string) error { @@ -347,7 +347,7 @@ func (r *relationshipDB) DeleteAccountFollowRequests(ctx context.Context, accoun accountID, ). Exec(ctx, &followReqIDs); err != nil { - return r.db.ProcessError(err) + return err } defer func() { @@ -371,5 +371,5 @@ func (r *relationshipDB) DeleteAccountFollowRequests(ctx context.Context, accoun Table("follow_requests"). Where("? IN (?)", bun.Ident("id"), bun.In(followReqIDs)). Exec(ctx) - return r.db.ProcessError(err) + return err } diff --git a/internal/db/bundb/relationship_note.go b/internal/db/bundb/relationship_note.go index 022d9ba0c..84f0ebeab 100644 --- a/internal/db/bundb/relationship_note.go +++ b/internal/db/bundb/relationship_note.go @@ -49,7 +49,7 @@ func (r *relationshipDB) getNote(ctx context.Context, lookup string, dbQuery fun // Not cached! Perform database query if err := dbQuery(¬e); err != nil { - return nil, r.db.ProcessError(err) + return nil, err } return ¬e, nil @@ -94,6 +94,6 @@ func (r *relationshipDB) PutNote(ctx context.Context, note *gtsmodel.AccountNote On("CONFLICT (?, ?) DO UPDATE", bun.Ident("account_id"), bun.Ident("target_account_id")). Set("? = ?, ? = ?", bun.Ident("updated_at"), note.UpdatedAt, bun.Ident("comment"), note.Comment). Exec(ctx) - return r.db.ProcessError(err) + return err }) } diff --git a/internal/db/bundb/report.go b/internal/db/bundb/report.go index eaeac4860..7c1dd16e7 100644 --- a/internal/db/bundb/report.go +++ b/internal/db/bundb/report.go @@ -32,7 +32,7 @@ import ( ) type reportDB struct { - db *WrappedDB + db *DB state *state.State } @@ -94,7 +94,7 @@ func (r *reportDB) GetReports(ctx context.Context, resolved *bool, accountID str } if err := q.Scan(ctx, &reportIDs); err != nil { - return nil, r.db.ProcessError(err) + return nil, err } // Catch case of no reports early @@ -125,7 +125,7 @@ func (r *reportDB) getReport(ctx context.Context, lookup string, dbQuery func(*g // Not cached! Perform database query if err := dbQuery(&report); err != nil { - return nil, r.db.ProcessError(err) + return nil, err } return &report, nil @@ -204,7 +204,7 @@ func (r *reportDB) PopulateReport(ctx context.Context, report *gtsmodel.Report) func (r *reportDB) PutReport(ctx context.Context, report *gtsmodel.Report) error { return r.state.Caches.GTS.Report().Store(report, func() error { _, err := r.db.NewInsert().Model(report).Exec(ctx) - return r.db.ProcessError(err) + return err }) } @@ -221,7 +221,7 @@ func (r *reportDB) UpdateReport(ctx context.Context, report *gtsmodel.Report, co Where("? = ?", bun.Ident("report.id"), report.ID). Column(columns...). Exec(ctx); err != nil { - return nil, r.db.ProcessError(err) + return nil, err } r.state.Caches.GTS.Report().Invalidate("ID", report.ID) @@ -248,5 +248,5 @@ func (r *reportDB) DeleteReportByID(ctx context.Context, id string) error { TableExpr("? AS ?", bun.Ident("reports"), bun.Ident("report")). Where("? = ?", bun.Ident("report.id"), id). Exec(ctx) - return r.db.ProcessError(err) + return err } diff --git a/internal/db/bundb/search.go b/internal/db/bundb/search.go index 755f60e7d..061471c19 100644 --- a/internal/db/bundb/search.go +++ b/internal/db/bundb/search.go @@ -57,7 +57,7 @@ import ( // This isn't ideal, of course, but at least we could cover the most common use case of // a caller paging down through results. type searchDB struct { - db *WrappedDB + db *DB state *state.State } @@ -149,7 +149,7 @@ func (s *searchDB) SearchForAccounts( } if err := q.Scan(ctx, &accountIDs); err != nil { - return nil, s.db.ProcessError(err) + return nil, err } if len(accountIDs) == 0 { @@ -327,7 +327,7 @@ func (s *searchDB) SearchForStatuses( } if err := q.Scan(ctx, &statusIDs); err != nil { - return nil, s.db.ProcessError(err) + return nil, err } if len(statusIDs) == 0 { @@ -453,7 +453,7 @@ func (s *searchDB) SearchForTags( } if err := q.Scan(ctx, &tagIDs); err != nil { - return nil, s.db.ProcessError(err) + return nil, err } if len(tagIDs) == 0 { diff --git a/internal/db/bundb/session.go b/internal/db/bundb/session.go index 8d778ffa2..9310a6463 100644 --- a/internal/db/bundb/session.go +++ b/internal/db/bundb/session.go @@ -27,7 +27,7 @@ import ( ) type sessionDB struct { - db *WrappedDB + db *DB } func (s *sessionDB) GetSession(ctx context.Context) (*gtsmodel.RouterSession, error) { @@ -40,7 +40,7 @@ func (s *sessionDB) GetSession(ctx context.Context) (*gtsmodel.RouterSession, er Limit(1). Order("router_session.id DESC"). Scan(ctx); err != nil { - return nil, s.db.ProcessError(err) + return nil, err } // ... create a new one @@ -70,7 +70,7 @@ func (s *sessionDB) createSession(ctx context.Context) (*gtsmodel.RouterSession, NewInsert(). Model(rs). Exec(ctx); err != nil { - return nil, s.db.ProcessError(err) + return nil, err } return rs, nil diff --git a/internal/db/bundb/status.go b/internal/db/bundb/status.go index 311732299..0e97d32cc 100644 --- a/internal/db/bundb/status.go +++ b/internal/db/bundb/status.go @@ -33,7 +33,7 @@ import ( ) type statusDB struct { - db *WrappedDB + db *DB state *state.State } @@ -115,7 +115,7 @@ func (s *statusDB) getStatus(ctx context.Context, lookup string, dbQuery func(*g // Not cached! Perform database query. if err := dbQuery(&status); err != nil { - return nil, s.db.ProcessError(err) + return nil, err } return &status, nil @@ -287,7 +287,6 @@ func (s *statusDB) PutStatus(ctx context.Context, status *gtsmodel.Status) error }). On("CONFLICT (?, ?) DO NOTHING", bun.Ident("status_id"), bun.Ident("emoji_id")). Exec(ctx); err != nil { - err = s.db.ProcessError(err) if !errors.Is(err, db.ErrAlreadyExists) { return err } @@ -304,7 +303,6 @@ func (s *statusDB) PutStatus(ctx context.Context, status *gtsmodel.Status) error }). On("CONFLICT (?, ?) DO NOTHING", bun.Ident("status_id"), bun.Ident("tag_id")). Exec(ctx); err != nil { - err = s.db.ProcessError(err) if !errors.Is(err, db.ErrAlreadyExists) { return err } @@ -320,7 +318,6 @@ func (s *statusDB) PutStatus(ctx context.Context, status *gtsmodel.Status) error Model(a). Where("? = ?", bun.Ident("media_attachment.id"), a.ID). Exec(ctx); err != nil { - err = s.db.ProcessError(err) if !errors.Is(err, db.ErrAlreadyExists) { return err } @@ -356,7 +353,6 @@ func (s *statusDB) UpdateStatus(ctx context.Context, status *gtsmodel.Status, co }). On("CONFLICT (?, ?) DO NOTHING", bun.Ident("status_id"), bun.Ident("emoji_id")). Exec(ctx); err != nil { - err = s.db.ProcessError(err) if !errors.Is(err, db.ErrAlreadyExists) { return err } @@ -373,7 +369,6 @@ func (s *statusDB) UpdateStatus(ctx context.Context, status *gtsmodel.Status, co }). On("CONFLICT (?, ?) DO NOTHING", bun.Ident("status_id"), bun.Ident("tag_id")). Exec(ctx); err != nil { - err = s.db.ProcessError(err) if !errors.Is(err, db.ErrAlreadyExists) { return err } @@ -389,7 +384,6 @@ func (s *statusDB) UpdateStatus(ctx context.Context, status *gtsmodel.Status, co Model(a). Where("? = ?", bun.Ident("media_attachment.id"), a.ID). Exec(ctx); err != nil { - err = s.db.ProcessError(err) if !errors.Is(err, db.ErrAlreadyExists) { return err } @@ -468,7 +462,7 @@ func (s *statusDB) GetStatusesUsingEmoji(ctx context.Context, emojiID string) ([ Column("status_id"). Where("? = ?", bun.Ident("emoji_id"), emojiID). Exec(ctx, &statusIDs); err != nil { - return nil, s.db.ProcessError(err) + return nil, err } // Convert status IDs into status objects. @@ -589,7 +583,7 @@ func (s *statusDB) getStatusReplyIDs(ctx context.Context, statusID string) ([]st Where("? = ?", bun.Ident("in_reply_to_id"), statusID). Order("id DESC"). Scan(ctx, &statusIDs); err != nil { - return nil, s.db.ProcessError(err) + return nil, err } return statusIDs, nil @@ -633,7 +627,7 @@ func (s *statusDB) getStatusBoostIDs(ctx context.Context, statusID string) ([]st Where("? = ?", bun.Ident("boost_of_id"), statusID). Order("id DESC"). Scan(ctx, &statusIDs); err != nil { - return nil, s.db.ProcessError(err) + return nil, err } return statusIDs, nil diff --git a/internal/db/bundb/statusbookmark.go b/internal/db/bundb/statusbookmark.go index 8a3c4dad6..742c13966 100644 --- a/internal/db/bundb/statusbookmark.go +++ b/internal/db/bundb/statusbookmark.go @@ -29,7 +29,7 @@ import ( ) type statusBookmarkDB struct { - db *WrappedDB + db *DB state *state.State } @@ -42,7 +42,7 @@ func (s *statusBookmarkDB) GetStatusBookmark(ctx context.Context, id string) (*g Where("? = ?", bun.Ident("status_bookmark.id"), id). Scan(ctx) if err != nil { - return nil, s.db.ProcessError(err) + return nil, err } bookmark.Account, err = s.state.DB.GetAccountByID(ctx, bookmark.AccountID) @@ -75,7 +75,7 @@ func (s *statusBookmarkDB) GetStatusBookmarkID(ctx context.Context, accountID st Limit(1) if err := q.Scan(ctx, &id); err != nil { - return "", s.db.ProcessError(err) + return "", err } return id, nil @@ -114,7 +114,7 @@ func (s *statusBookmarkDB) GetStatusBookmarks(ctx context.Context, accountID str } if err := q.Scan(ctx, &ids); err != nil { - return nil, s.db.ProcessError(err) + return nil, err } bookmarks := make([]*gtsmodel.StatusBookmark, 0, len(ids)) @@ -138,7 +138,7 @@ func (s *statusBookmarkDB) PutStatusBookmark(ctx context.Context, statusBookmark Model(statusBookmark). Exec(ctx) - return s.db.ProcessError(err) + return err } func (s *statusBookmarkDB) DeleteStatusBookmark(ctx context.Context, id string) error { @@ -148,7 +148,7 @@ func (s *statusBookmarkDB) DeleteStatusBookmark(ctx context.Context, id string) Where("? = ?", bun.Ident("status_bookmark.id"), id). Exec(ctx) - return s.db.ProcessError(err) + return err } func (s *statusBookmarkDB) DeleteStatusBookmarks(ctx context.Context, targetAccountID string, originAccountID string) error { @@ -173,7 +173,7 @@ func (s *statusBookmarkDB) DeleteStatusBookmarks(ctx context.Context, targetAcco } if _, err := q.Exec(ctx); err != nil { - return s.db.ProcessError(err) + return err } return nil @@ -190,7 +190,7 @@ func (s *statusBookmarkDB) DeleteStatusBookmarksForStatus(ctx context.Context, s Where("? = ?", bun.Ident("status_bookmark.status_id"), statusID) if _, err := q.Exec(ctx); err != nil { - return s.db.ProcessError(err) + return err } return nil diff --git a/internal/db/bundb/statusfave.go b/internal/db/bundb/statusfave.go index 37b88326b..73ac62fe7 100644 --- a/internal/db/bundb/statusfave.go +++ b/internal/db/bundb/statusfave.go @@ -33,7 +33,7 @@ import ( ) type statusFaveDB struct { - db *WrappedDB + db *DB state *state.State } @@ -82,7 +82,7 @@ func (s *statusFaveDB) getStatusFave(ctx context.Context, lookup string, dbQuery // Not cached! Perform database query. if err := dbQuery(&fave); err != nil { - return nil, s.db.ProcessError(err) + return nil, err } return &fave, nil @@ -151,7 +151,7 @@ func (s *statusFaveDB) getStatusFaveIDs(ctx context.Context, statusID string) ([ Column("id"). Where("? = ?", bun.Ident("status_id"), statusID). Scan(ctx, &faveIDs); err != nil { - return nil, s.db.ProcessError(err) + return nil, err } return faveIDs, nil @@ -206,7 +206,7 @@ func (s *statusFaveDB) PutStatusFave(ctx context.Context, fave *gtsmodel.StatusF NewInsert(). Model(fave). Exec(ctx) - return s.db.ProcessError(err) + return err }) } @@ -225,7 +225,7 @@ func (s *statusFaveDB) DeleteStatusFaveByID(ctx context.Context, id string) erro // to us doing a RETURNING. err = nil } - return s.db.ProcessError(err) + return err } if statusID != "" { @@ -267,7 +267,7 @@ func (s *statusFaveDB) DeleteStatusFaves(ctx context.Context, targetAccountID st // to us doing a RETURNING. err = nil } - return s.db.ProcessError(err) + return err } // Collate (deduplicating) status IDs. @@ -292,7 +292,7 @@ func (s *statusFaveDB) DeleteStatusFavesForStatus(ctx context.Context, statusID Table("status_faves"). Where("status_id = ?", statusID). Exec(ctx); err != nil { - return s.db.ProcessError(err) + return err } // Invalidate any cached status faves for this status. diff --git a/internal/db/bundb/tag.go b/internal/db/bundb/tag.go index 043af5728..fac621f0a 100644 --- a/internal/db/bundb/tag.go +++ b/internal/db/bundb/tag.go @@ -28,7 +28,7 @@ import ( ) type tagDB struct { - conn *WrappedDB + conn *DB state *state.State } @@ -42,7 +42,7 @@ func (m *tagDB) GetTag(ctx context.Context, id string) (*gtsmodel.Tag, error) { Where("? = ?", bun.Ident("tag.id"), id) if err := q.Scan(ctx); err != nil { - return nil, m.conn.ProcessError(err) + return nil, err } return &tag, nil @@ -63,7 +63,7 @@ func (m *tagDB) GetTagByName(ctx context.Context, name string) (*gtsmodel.Tag, e Where("? = ?", bun.Ident("tag.name"), name) if err := q.Scan(ctx); err != nil { - return nil, m.conn.ProcessError(err) + return nil, err } return &tag, nil @@ -103,7 +103,7 @@ func (m *tagDB) PutTag(ctx context.Context, tag *gtsmodel.Tag) error { // Insert the copy. if err := m.state.Caches.GTS.Tag().Store(t2, func() error { _, err := m.conn.NewInsert().Model(t2).Exec(ctx) - return m.conn.ProcessError(err) + return err }); err != nil { return err // err already processed } diff --git a/internal/db/bundb/timeline.go b/internal/db/bundb/timeline.go index 62f1f642d..1230a84d4 100644 --- a/internal/db/bundb/timeline.go +++ b/internal/db/bundb/timeline.go @@ -33,7 +33,7 @@ import ( ) type timelineDB struct { - db *WrappedDB + db *DB state *state.State } @@ -119,7 +119,7 @@ func (t *timelineDB) GetHomeTimeline(ctx context.Context, accountID string, maxI }) if err := q.Scan(ctx, &statusIDs); err != nil { - return nil, t.db.ProcessError(err) + return nil, err } if len(statusIDs) == 0 { @@ -202,7 +202,7 @@ func (t *timelineDB) GetPublicTimeline(ctx context.Context, maxID string, sinceI } if err := q.Scan(ctx, &statusIDs); err != nil { - return nil, t.db.ProcessError(err) + return nil, err } statuses := make([]*gtsmodel.Status, 0, len(statusIDs)) @@ -253,7 +253,7 @@ func (t *timelineDB) GetFavedTimeline(ctx context.Context, accountID string, max err := fq.Scan(ctx) if err != nil { - return nil, "", "", t.db.ProcessError(err) + return nil, "", "", err } if len(faves) == 0 { @@ -379,7 +379,7 @@ func (t *timelineDB) GetListTimeline( } if err := q.Scan(ctx, &statusIDs); err != nil { - return nil, t.db.ProcessError(err) + return nil, err } if len(statusIDs) == 0 { @@ -487,7 +487,7 @@ func (t *timelineDB) GetTagTimeline( } if err := q.Scan(ctx, &statusIDs); err != nil { - return nil, t.db.ProcessError(err) + return nil, err } if len(statusIDs) == 0 { diff --git a/internal/db/bundb/tombstone.go b/internal/db/bundb/tombstone.go index 0050e6531..f9882d1c6 100644 --- a/internal/db/bundb/tombstone.go +++ b/internal/db/bundb/tombstone.go @@ -27,7 +27,7 @@ import ( ) type tombstoneDB struct { - db *WrappedDB + db *DB state *state.State } @@ -41,7 +41,7 @@ func (t *tombstoneDB) GetTombstoneByURI(ctx context.Context, uri string) (*gtsmo Where("? = ?", bun.Ident("tombstone.uri"), uri) if err := q.Scan(ctx); err != nil { - return nil, t.db.ProcessError(err) + return nil, err } return &tomb, nil @@ -62,7 +62,7 @@ func (t *tombstoneDB) PutTombstone(ctx context.Context, tombstone *gtsmodel.Tomb NewInsert(). Model(tombstone). Exec(ctx) - return t.db.ProcessError(err) + return err }) } @@ -74,5 +74,5 @@ func (t *tombstoneDB) DeleteTombstone(ctx context.Context, id string) error { TableExpr("? AS ?", bun.Ident("tombstones"), bun.Ident("tombstone")). Where("? = ?", bun.Ident("tombstone.id"), id). Exec(ctx) - return t.db.ProcessError(err) + return err } diff --git a/internal/db/bundb/user.go b/internal/db/bundb/user.go index 9df05596e..eaa1d8e3d 100644 --- a/internal/db/bundb/user.go +++ b/internal/db/bundb/user.go @@ -31,7 +31,7 @@ import ( ) type userDB struct { - db *WrappedDB + db *DB state *state.State } @@ -121,7 +121,7 @@ func (u *userDB) getUser(ctx context.Context, lookup string, dbQuery func(*gtsmo // Not cached! perform database query. if err := dbQuery(&user); err != nil { - return nil, u.db.ProcessError(err) + return nil, err } return &user, nil @@ -150,7 +150,7 @@ func (u *userDB) GetAllUsers(ctx context.Context) ([]*gtsmodel.User, error) { Table("users"). Column("id"). Scan(ctx, &userIDs); err != nil { - return nil, u.db.ProcessError(err) + return nil, err } // Transform user IDs into user slice. @@ -163,7 +163,7 @@ func (u *userDB) PutUser(ctx context.Context, user *gtsmodel.User) error { NewInsert(). Model(user). Exec(ctx) - return u.db.ProcessError(err) + return err }) } @@ -183,7 +183,7 @@ func (u *userDB) UpdateUser(ctx context.Context, user *gtsmodel.User, columns .. Where("? = ?", bun.Ident("user.id"), user.ID). Column(columns...). Exec(ctx) - return u.db.ProcessError(err) + return err }) } @@ -207,5 +207,5 @@ func (u *userDB) DeleteUserByID(ctx context.Context, userID string) error { TableExpr("? AS ?", bun.Ident("users"), bun.Ident("user")). Where("? = ?", bun.Ident("user.id"), userID). Exec(ctx) - return u.db.ProcessError(err) + return err } diff --git a/internal/db/bundb/wrap.go b/internal/db/bundb/wrap.go deleted file mode 100644 index a5039914a..000000000 --- a/internal/db/bundb/wrap.go +++ /dev/null @@ -1,258 +0,0 @@ -// GoToSocial -// Copyright (C) GoToSocial Authors admin@gotosocial.org -// SPDX-License-Identifier: AGPL-3.0-or-later -// -// This program is free software: you can redistribute it and/or modify -// it under the terms of the GNU Affero General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// This program is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Affero General Public License for more details. -// -// You should have received a copy of the GNU Affero General Public License -// along with this program. If not, see . - -package bundb - -import ( - "context" - "database/sql" - "time" - - "github.com/superseriousbusiness/gotosocial/internal/db" - "github.com/superseriousbusiness/gotosocial/internal/gtserror" - "github.com/uptrace/bun" - "github.com/uptrace/bun/dialect" -) - -// WrappedDB wraps a bun database instance -// to provide common per-dialect SQL error -// conversions to common types, and retries -// on returned busy errors (SQLite only for now). -type WrappedDB struct { - errHook func(error) error - *bun.DB // underlying conn -} - -// WrapDB wraps a bun database instance in our own WrappedDB type. -func WrapDB(db *bun.DB) *WrappedDB { - var errProc func(error) error - switch name := db.Dialect().Name(); name { - case dialect.PG: - errProc = processPostgresError - case dialect.SQLite: - errProc = processSQLiteError - default: - panic("unknown dialect name: " + name.String()) - } - return &WrappedDB{ - errHook: errProc, - DB: db, - } -} - -// BeginTx wraps bun.DB.BeginTx() with retry-busy timeout. -func (db *WrappedDB) BeginTx(ctx context.Context, opts *sql.TxOptions) (tx bun.Tx, err error) { - err = retryOnBusy(ctx, func() error { - tx, err = db.DB.BeginTx(ctx, opts) - err = db.ProcessError(err) - return err - }) - return -} - -// ExecContext wraps bun.DB.ExecContext() with retry-busy timeout. -func (db *WrappedDB) ExecContext(ctx context.Context, query string, args ...any) (result sql.Result, err error) { - err = retryOnBusy(ctx, func() error { - result, err = db.DB.ExecContext(ctx, query, args...) - err = db.ProcessError(err) - return err - }) - return -} - -// QueryContext wraps bun.DB.QueryContext() with retry-busy timeout. -func (db *WrappedDB) QueryContext(ctx context.Context, query string, args ...any) (rows *sql.Rows, err error) { - err = retryOnBusy(ctx, func() error { - rows, err = db.DB.QueryContext(ctx, query, args...) - err = db.ProcessError(err) - return err - }) - return -} - -// QueryRowContext wraps bun.DB.QueryRowContext() with retry-busy timeout. -func (db *WrappedDB) QueryRowContext(ctx context.Context, query string, args ...any) (row *sql.Row) { - _ = retryOnBusy(ctx, func() error { - row = db.DB.QueryRowContext(ctx, query, args...) - err := db.ProcessError(row.Err()) - return err - }) - return -} - -// RunInTx is functionally the same as bun.DB.RunInTx() but with retry-busy timeouts. -func (db *WrappedDB) RunInTx(ctx context.Context, fn func(bun.Tx) error) error { - // Attempt to start new transaction. - tx, err := db.BeginTx(ctx, nil) - if err != nil { - return err - } - - var done bool - - defer func() { - if !done { - // Rollback (with retry-backoff). - _ = retryOnBusy(ctx, func() error { - err := tx.Rollback() - return db.errHook(err) - }) - } - }() - - // Perform supplied transaction - if err := fn(tx); err != nil { - return db.errHook(err) - } - - // Commit (with retry-backoff). - err = retryOnBusy(ctx, func() error { - err := tx.Commit() - return db.errHook(err) - }) - done = true - return err -} - -func (db *WrappedDB) NewValues(model interface{}) *bun.ValuesQuery { - return bun.NewValuesQuery(db.DB, model).Conn(db) -} - -func (db *WrappedDB) NewMerge() *bun.MergeQuery { - return bun.NewMergeQuery(db.DB).Conn(db) -} - -func (db *WrappedDB) NewSelect() *bun.SelectQuery { - return bun.NewSelectQuery(db.DB).Conn(db) -} - -func (db *WrappedDB) NewInsert() *bun.InsertQuery { - return bun.NewInsertQuery(db.DB).Conn(db) -} - -func (db *WrappedDB) NewUpdate() *bun.UpdateQuery { - return bun.NewUpdateQuery(db.DB).Conn(db) -} - -func (db *WrappedDB) NewDelete() *bun.DeleteQuery { - return bun.NewDeleteQuery(db.DB).Conn(db) -} - -func (db *WrappedDB) NewRaw(query string, args ...interface{}) *bun.RawQuery { - return bun.NewRawQuery(db.DB, query, args...).Conn(db) -} - -func (db *WrappedDB) NewCreateTable() *bun.CreateTableQuery { - return bun.NewCreateTableQuery(db.DB).Conn(db) -} - -func (db *WrappedDB) NewDropTable() *bun.DropTableQuery { - return bun.NewDropTableQuery(db.DB).Conn(db) -} - -func (db *WrappedDB) NewCreateIndex() *bun.CreateIndexQuery { - return bun.NewCreateIndexQuery(db.DB).Conn(db) -} - -func (db *WrappedDB) NewDropIndex() *bun.DropIndexQuery { - return bun.NewDropIndexQuery(db.DB).Conn(db) -} - -func (db *WrappedDB) NewTruncateTable() *bun.TruncateTableQuery { - return bun.NewTruncateTableQuery(db.DB).Conn(db) -} - -func (db *WrappedDB) NewAddColumn() *bun.AddColumnQuery { - return bun.NewAddColumnQuery(db.DB).Conn(db) -} - -func (db *WrappedDB) NewDropColumn() *bun.DropColumnQuery { - return bun.NewDropColumnQuery(db.DB).Conn(db) -} - -// ProcessError processes an error to replace any known values with our own error types, -// making it easier to catch specific situations (e.g. no rows, already exists, etc) -func (db *WrappedDB) ProcessError(err error) error { - if err == nil { - return nil - } - return db.errHook(err) -} - -// Exists checks the results of a SelectQuery for the existence of the data in question, masking ErrNoEntries errors -func (db *WrappedDB) Exists(ctx context.Context, query *bun.SelectQuery) (bool, error) { - exists, err := query.Exists(ctx) - switch err { - case nil: - return exists, nil - case sql.ErrNoRows: - return false, nil - default: - return false, err - } -} - -// NotExists is the functional opposite of conn.Exists() -func (db *WrappedDB) NotExists(ctx context.Context, query *bun.SelectQuery) (bool, error) { - exists, err := db.Exists(ctx, query) - return !exists, err -} - -// retryOnBusy will retry given function on returned 'errBusy'. -func retryOnBusy(ctx context.Context, fn func() error) error { - var backoff time.Duration - - for i := 0; ; i++ { - // Perform func. - err := fn() - - if err != errBusy { - // May be nil, or may be - // some other error, either - // way return here. - return err - } - - // backoff according to a multiplier of 2ms * 2^2n, - // up to a maximum possible backoff time of 5 minutes. - // - // this works out as the following: - // 4ms - // 16ms - // 64ms - // 256ms - // 1.024s - // 4.096s - // 16.384s - // 1m5.536s - // 4m22.144s - backoff = 2 * time.Millisecond * (1 << (2*i + 1)) - if backoff >= 5*time.Minute { - break - } - - select { - // Context cancelled. - case <-ctx.Done(): - - // Backoff for some time. - case <-time.After(backoff): - } - } - - return gtserror.Newf("%w (waited > %s)", db.ErrBusyTimeout, backoff) -} diff --git a/internal/gotosocial/gotosocial.go b/internal/gotosocial/gotosocial.go index 4c6846ff9..6a66602f0 100644 --- a/internal/gotosocial/gotosocial.go +++ b/internal/gotosocial/gotosocial.go @@ -72,6 +72,5 @@ func (gts *gotosocial) Stop(ctx context.Context) error { if err := gts.apiRouter.Stop(ctx); err != nil { return err } - - return gts.db.Stop(ctx) + return gts.db.Close() }