From e8a20f587c0b0129bc68f5c6092c54f2b4c3519a Mon Sep 17 00:00:00 2001 From: tobi <31960611+tsmethurst@users.noreply.github.com> Date: Wed, 2 Aug 2023 17:21:46 +0200 Subject: [PATCH] [bugfix] Rework MultiError to wrap + unwrap errors properly (#2057) * rework multierror a bit * test multierror --- .../action/admin/media/prune/common.go | 6 +- .../api/activitypub/users/inboxpost_test.go | 7 +- .../api/client/accounts/accountupdate_test.go | 6 +- internal/api/client/accounts/lists_test.go | 7 +- internal/api/client/accounts/search_test.go | 7 +- .../api/client/admin/reportresolve_test.go | 7 +- internal/api/client/admin/reportsget_test.go | 7 +- .../api/client/lists/listaccounts_test.go | 7 +- .../api/client/reports/reportcreate_test.go | 7 +- internal/api/client/reports/reportget_test.go | 7 +- internal/api/client/search/searchget_test.go | 7 +- .../api/client/statuses/statuspin_test.go | 15 ++-- .../api/client/statuses/statusunpin_test.go | 15 ++-- internal/cleaner/cleaner.go | 10 ++- internal/db/bundb/account.go | 15 ++-- internal/db/bundb/instance.go | 12 ++- internal/db/bundb/list.go | 12 ++- internal/db/bundb/relationship_follow.go | 12 ++- internal/db/bundb/status.go | 29 ++++---- internal/db/bundb/statusfave.go | 14 ++-- internal/gtserror/multi.go | 55 ++++++++++---- internal/gtserror/multi_test.go | 64 ++++++++++++++++ internal/processing/fromcommon.go | 73 ++++++++++--------- internal/timeline/manager.go | 16 ++-- 24 files changed, 263 insertions(+), 154 deletions(-) create mode 100644 internal/gtserror/multi_test.go diff --git a/cmd/gotosocial/action/admin/media/prune/common.go b/cmd/gotosocial/action/admin/media/prune/common.go index 0db1a3462..ad721675e 100644 --- a/cmd/gotosocial/action/admin/media/prune/common.go +++ b/cmd/gotosocial/action/admin/media/prune/common.go @@ -75,14 +75,14 @@ func setupPrune(ctx context.Context) (*prune, error) { } func (p *prune) shutdown(ctx context.Context) error { - var errs gtserror.MultiError + errs := gtserror.NewMultiError(2) if err := p.storage.Close(); err != nil { - errs.Appendf("error closing storage backend: %v", err) + errs.Appendf("error closing storage backend: %w", err) } if err := p.dbService.Stop(ctx); err != nil { - errs.Appendf("error stopping database: %v", err) + errs.Appendf("error stopping database: %w", err) } p.state.Workers.Stop() diff --git a/internal/api/activitypub/users/inboxpost_test.go b/internal/api/activitypub/users/inboxpost_test.go index c5027f342..d26dae513 100644 --- a/internal/api/activitypub/users/inboxpost_test.go +++ b/internal/api/activitypub/users/inboxpost_test.go @@ -22,7 +22,6 @@ import ( "context" "encoding/json" "errors" - "fmt" "io" "net/http" "net/http/httptest" @@ -105,16 +104,16 @@ func (suite *InboxPostTestSuite) inboxPost( suite.FailNow(err.Error()) } - errs := gtserror.MultiError{} + errs := gtserror.NewMultiError(2) // Check expected code + body. if resultCode := recorder.Code; expectedHTTPStatus != resultCode { - errs = append(errs, fmt.Sprintf("expected %d got %d", expectedHTTPStatus, resultCode)) + errs.Appendf("expected %d got %d", expectedHTTPStatus, resultCode) } // If we got an expected body, return early. if expectedBody != "" && string(b) != expectedBody { - errs = append(errs, fmt.Sprintf("expected %s got %s", expectedBody, string(b))) + errs.Appendf("expected %s got %s", expectedBody, string(b)) } if err := errs.Combine(); err != nil { diff --git a/internal/api/client/accounts/accountupdate_test.go b/internal/api/client/accounts/accountupdate_test.go index 01d12ab27..835989037 100644 --- a/internal/api/client/accounts/accountupdate_test.go +++ b/internal/api/client/accounts/accountupdate_test.go @@ -90,16 +90,16 @@ func (suite *AccountUpdateTestSuite) updateAccount( return nil, err } - errs := gtserror.MultiError{} + errs := gtserror.NewMultiError(2) // Check expected code + body. if resultCode := recorder.Code; expectedHTTPStatus != resultCode { - errs = append(errs, fmt.Sprintf("expected %d got %d", expectedHTTPStatus, resultCode)) + errs.Appendf("expected %d got %d", expectedHTTPStatus, resultCode) } // If we got an expected body, return early. if expectedBody != "" && string(b) != expectedBody { - errs = append(errs, fmt.Sprintf("expected %s got %s", expectedBody, string(b))) + errs.Appendf("expected %s got %s", expectedBody, string(b)) } if err := errs.Combine(); err != nil { diff --git a/internal/api/client/accounts/lists_test.go b/internal/api/client/accounts/lists_test.go index 6984d6ef8..637babc35 100644 --- a/internal/api/client/accounts/lists_test.go +++ b/internal/api/client/accounts/lists_test.go @@ -19,7 +19,6 @@ package accounts_test import ( "encoding/json" - "fmt" "io" "net/http" "net/http/httptest" @@ -63,16 +62,16 @@ func (suite *ListsTestSuite) getLists(targetAccountID string, expectedHTTPStatus suite.FailNow(err.Error()) } - errs := gtserror.MultiError{} + errs := gtserror.NewMultiError(2) // Check expected code + body. if resultCode := recorder.Code; expectedHTTPStatus != resultCode { - errs = append(errs, fmt.Sprintf("expected %d got %d", expectedHTTPStatus, resultCode)) + errs.Appendf("expected %d got %d", expectedHTTPStatus, resultCode) } // If we got an expected body, return early. if expectedBody != "" && string(b) != expectedBody { - errs = append(errs, fmt.Sprintf("expected %s got %s", expectedBody, string(b))) + errs.Appendf("expected %s got %s", expectedBody, string(b)) } if err := errs.Combine(); err != nil { diff --git a/internal/api/client/accounts/search_test.go b/internal/api/client/accounts/search_test.go index 7d778f090..119900331 100644 --- a/internal/api/client/accounts/search_test.go +++ b/internal/api/client/accounts/search_test.go @@ -19,7 +19,6 @@ package accounts_test import ( "encoding/json" - "fmt" "io" "net/http" "net/http/httptest" @@ -99,16 +98,16 @@ func (suite *AccountSearchTestSuite) getSearch( suite.FailNow(err.Error()) } - errs := gtserror.MultiError{} + errs := gtserror.NewMultiError(2) // Check expected code + body. if resultCode := recorder.Code; expectedHTTPStatus != resultCode { - errs = append(errs, fmt.Sprintf("expected %d got %d", expectedHTTPStatus, resultCode)) + errs.Appendf("expected %d got %d", expectedHTTPStatus, resultCode) } // If we got an expected body, return early. if expectedBody != "" && string(b) != expectedBody { - errs = append(errs, fmt.Sprintf("expected %s got %s", expectedBody, string(b))) + errs.Appendf("expected %s got %s", expectedBody, string(b)) } if err := errs.Combine(); err != nil { diff --git a/internal/api/client/admin/reportresolve_test.go b/internal/api/client/admin/reportresolve_test.go index 691ba1f38..754dbb443 100644 --- a/internal/api/client/admin/reportresolve_test.go +++ b/internal/api/client/admin/reportresolve_test.go @@ -19,7 +19,6 @@ package admin_test import ( "encoding/json" - "fmt" "io/ioutil" "net/http" "net/http/httptest" @@ -84,16 +83,16 @@ func (suite *ReportResolveTestSuite) resolveReport( return nil, err } - errs := gtserror.MultiError{} + errs := gtserror.NewMultiError(2) if resultCode := recorder.Code; expectedHTTPStatus != resultCode { - errs = append(errs, fmt.Sprintf("expected %d got %d", expectedHTTPStatus, resultCode)) + errs.Appendf("expected %d got %d", expectedHTTPStatus, resultCode) } // if we got an expected body, return early if expectedBody != "" { if string(b) != expectedBody { - errs = append(errs, fmt.Sprintf("expected %s got %s", expectedBody, string(b))) + errs.Appendf("expected %s got %s", expectedBody, string(b)) } return nil, errs.Combine() } diff --git a/internal/api/client/admin/reportsget_test.go b/internal/api/client/admin/reportsget_test.go index fae21dc07..4f29aa872 100644 --- a/internal/api/client/admin/reportsget_test.go +++ b/internal/api/client/admin/reportsget_test.go @@ -19,7 +19,6 @@ package admin_test import ( "encoding/json" - "fmt" "io/ioutil" "net/http" "net/http/httptest" @@ -101,16 +100,16 @@ func (suite *ReportsGetTestSuite) getReports( return nil, "", err } - errs := gtserror.MultiError{} + errs := gtserror.NewMultiError(2) if resultCode := recorder.Code; expectedHTTPStatus != resultCode { - errs = append(errs, fmt.Sprintf("expected %d got %d", expectedHTTPStatus, resultCode)) + errs.Appendf("expected %d got %d", expectedHTTPStatus, resultCode) } // if we got an expected body, return early if expectedBody != "" { if string(b) != expectedBody { - errs = append(errs, fmt.Sprintf("expected %s got %s", expectedBody, string(b))) + errs.Appendf("expected %s got %s", expectedBody, string(b)) } return nil, "", errs.Combine() } diff --git a/internal/api/client/lists/listaccounts_test.go b/internal/api/client/lists/listaccounts_test.go index 64e9ef768..bbd187f7d 100644 --- a/internal/api/client/lists/listaccounts_test.go +++ b/internal/api/client/lists/listaccounts_test.go @@ -19,7 +19,6 @@ package lists_test import ( "encoding/json" - "fmt" "io/ioutil" "net/http" "net/http/httptest" @@ -103,17 +102,17 @@ func (suite *ListAccountsTestSuite) getListAccounts( return nil, "", err } - errs := gtserror.MultiError{} + errs := gtserror.NewMultiError(2) // check code + body if resultCode := recorder.Code; expectedHTTPStatus != resultCode { - errs = append(errs, fmt.Sprintf("expected %d got %d", expectedHTTPStatus, resultCode)) + errs.Appendf("expected %d got %d", expectedHTTPStatus, resultCode) } // if we got an expected body, return early if expectedBody != "" { if string(b) != expectedBody { - errs = append(errs, fmt.Sprintf("expected %s got %s", expectedBody, string(b))) + errs.Appendf("expected %s got %s", expectedBody, string(b)) } return nil, "", errs.Combine() } diff --git a/internal/api/client/reports/reportcreate_test.go b/internal/api/client/reports/reportcreate_test.go index 672a7a63b..e17695cb9 100644 --- a/internal/api/client/reports/reportcreate_test.go +++ b/internal/api/client/reports/reportcreate_test.go @@ -19,7 +19,6 @@ package reports_test import ( "encoding/json" - "fmt" "io/ioutil" "net/http" "net/http/httptest" @@ -77,17 +76,17 @@ func (suite *ReportCreateTestSuite) createReport(expectedHTTPStatus int, expecte return nil, err } - errs := gtserror.MultiError{} + errs := gtserror.NewMultiError(2) // check code + body if resultCode := recorder.Code; expectedHTTPStatus != resultCode { - errs = append(errs, fmt.Sprintf("expected %d got %d", expectedHTTPStatus, resultCode)) + errs.Appendf("expected %d got %d", expectedHTTPStatus, resultCode) } // if we got an expected body, return early if expectedBody != "" { if string(b) != expectedBody { - errs = append(errs, fmt.Sprintf("expected %s got %s", expectedBody, string(b))) + errs.Appendf("expected %s got %s", expectedBody, string(b)) } return nil, errs.Combine() } diff --git a/internal/api/client/reports/reportget_test.go b/internal/api/client/reports/reportget_test.go index b3de34053..e29836b6a 100644 --- a/internal/api/client/reports/reportget_test.go +++ b/internal/api/client/reports/reportget_test.go @@ -19,7 +19,6 @@ package reports_test import ( "encoding/json" - "fmt" "io/ioutil" "net/http" "net/http/httptest" @@ -64,17 +63,17 @@ func (suite *ReportGetTestSuite) getReport(expectedHTTPStatus int, expectedBody return nil, err } - errs := gtserror.MultiError{} + errs := gtserror.NewMultiError(2) // check code + body if resultCode := recorder.Code; expectedHTTPStatus != resultCode { - errs = append(errs, fmt.Sprintf("expected %d got %d", expectedHTTPStatus, resultCode)) + errs.Appendf("expected %d got %d", expectedHTTPStatus, resultCode) } // if we got an expected body, return early if expectedBody != "" { if string(b) != expectedBody { - errs = append(errs, fmt.Sprintf("expected %s got %s", expectedBody, string(b))) + errs.Appendf("expected %s got %s", expectedBody, string(b)) } return nil, errs.Combine() } diff --git a/internal/api/client/search/searchget_test.go b/internal/api/client/search/searchget_test.go index e811dd329..2a6911430 100644 --- a/internal/api/client/search/searchget_test.go +++ b/internal/api/client/search/searchget_test.go @@ -22,7 +22,6 @@ import ( "crypto/rand" "crypto/rsa" "encoding/json" - "fmt" "io" "net/http" "net/http/httptest" @@ -122,16 +121,16 @@ func (suite *SearchGetTestSuite) getSearch( suite.FailNow(err.Error()) } - errs := gtserror.MultiError{} + errs := gtserror.NewMultiError(2) // Check expected code + body. if resultCode := recorder.Code; expectedHTTPStatus != resultCode { - errs = append(errs, fmt.Sprintf("expected %d got %d: %v", expectedHTTPStatus, resultCode, ctx.Errors.JSON())) + errs.Appendf("expected %d got %d", expectedHTTPStatus, resultCode) } // If we got an expected body, return early. if expectedBody != "" && string(b) != expectedBody { - errs = append(errs, fmt.Sprintf("expected %s got %s", expectedBody, string(b))) + errs.Appendf("expected %s got %s", expectedBody, string(b)) } if err := errs.Combine(); err != nil { diff --git a/internal/api/client/statuses/statuspin_test.go b/internal/api/client/statuses/statuspin_test.go index 66ad1a2ee..c7ed6e95d 100644 --- a/internal/api/client/statuses/statuspin_test.go +++ b/internal/api/client/statuses/statuspin_test.go @@ -20,7 +20,6 @@ package statuses_test import ( "context" "encoding/json" - "fmt" "io/ioutil" "net/http" "net/http/httptest" @@ -74,20 +73,20 @@ func (suite *StatusPinTestSuite) createPin( return nil, err } - errs := gtserror.MultiError{} + errs := gtserror.NewMultiError(2) - // check code + body + // Check expected code + body. if resultCode := recorder.Code; expectedHTTPStatus != resultCode { - errs = append(errs, fmt.Sprintf("expected %d got %d", expectedHTTPStatus, resultCode)) + errs.Appendf("expected %d got %d", expectedHTTPStatus, resultCode) } - // if we got an expected body, return early + // If we got an expected body, return early. if expectedBody != "" && string(b) != expectedBody { - errs = append(errs, fmt.Sprintf("expected %s got %s", expectedBody, string(b))) + errs.Appendf("expected %s got %s", expectedBody, string(b)) } - if len(errs) > 0 { - return nil, errs.Combine() + if err := errs.Combine(); err != nil { + suite.FailNow("", "%v (body %s)", err, string(b)) } resp := &apimodel.Status{} diff --git a/internal/api/client/statuses/statusunpin_test.go b/internal/api/client/statuses/statusunpin_test.go index bc68a2ca3..9f6602b98 100644 --- a/internal/api/client/statuses/statusunpin_test.go +++ b/internal/api/client/statuses/statusunpin_test.go @@ -19,7 +19,6 @@ package statuses_test import ( "encoding/json" - "fmt" "io/ioutil" "net/http" "net/http/httptest" @@ -68,20 +67,20 @@ func (suite *StatusUnpinTestSuite) createUnpin( return nil, err } - errs := gtserror.MultiError{} + errs := gtserror.NewMultiError(2) - // check code + body + // Check expected code + body. if resultCode := recorder.Code; expectedHTTPStatus != resultCode { - errs = append(errs, fmt.Sprintf("expected %d got %d", expectedHTTPStatus, resultCode)) + errs.Appendf("expected %d got %d", expectedHTTPStatus, resultCode) } - // if we got an expected body, return early + // If we got an expected body, return early. if expectedBody != "" && string(b) != expectedBody { - errs = append(errs, fmt.Sprintf("expected %s got %s", expectedBody, string(b))) + errs.Appendf("expected %s got %s", expectedBody, string(b)) } - if len(errs) > 0 { - return nil, errs.Combine() + if err := errs.Combine(); err != nil { + suite.FailNow("", "%v (body %s)", err, string(b)) } resp := &apimodel.Status{} diff --git a/internal/cleaner/cleaner.go b/internal/cleaner/cleaner.go index 70497c10e..31766bae6 100644 --- a/internal/cleaner/cleaner.go +++ b/internal/cleaner/cleaner.go @@ -83,19 +83,23 @@ func (c *Cleaner) removeFiles(ctx context.Context, files ...string) (int, error) return len(files), nil } - var errs gtserror.MultiError + var ( + errs gtserror.MultiError + errCount int + ) for _, path := range files { // Remove each provided storage path. log.Debugf(ctx, "removing file: %s", path) err := c.state.Storage.Delete(ctx, path) if err != nil && !errors.Is(err, storage.ErrNotFound) { - errs.Appendf("error removing %s: %v", path, err) + errs.Appendf("error removing %s: %w", path, err) + errCount++ } } // Calculate no. files removed. - diff := len(files) - len(errs) + diff := len(files) - errCount // Wrap the combined error slice. if err := errs.Combine(); err != nil { diff --git a/internal/db/bundb/account.go b/internal/db/bundb/account.go index 6a47418b7..83b3c13f5 100644 --- a/internal/db/bundb/account.go +++ b/internal/db/bundb/account.go @@ -20,7 +20,6 @@ package bundb import ( "context" "errors" - "fmt" "strings" "time" @@ -255,7 +254,7 @@ func (a *accountDB) getAccount(ctx context.Context, lookup string, dbQuery func( func (a *accountDB) PopulateAccount(ctx context.Context, account *gtsmodel.Account) error { var ( err error - errs = make(gtserror.MultiError, 0, 3) + errs = gtserror.NewMultiError(3) ) if account.AvatarMediaAttachment == nil && account.AvatarMediaAttachmentID != "" { @@ -265,7 +264,7 @@ func (a *accountDB) PopulateAccount(ctx context.Context, account *gtsmodel.Accou account.AvatarMediaAttachmentID, ) if err != nil { - errs.Append(fmt.Errorf("error populating account avatar: %w", err)) + errs.Appendf("error populating account avatar: %w", err) } } @@ -276,7 +275,7 @@ func (a *accountDB) PopulateAccount(ctx context.Context, account *gtsmodel.Accou account.HeaderMediaAttachmentID, ) if err != nil { - errs.Append(fmt.Errorf("error populating account header: %w", err)) + errs.Appendf("error populating account header: %w", err) } } @@ -287,11 +286,15 @@ func (a *accountDB) PopulateAccount(ctx context.Context, account *gtsmodel.Accou account.EmojiIDs, ) if err != nil { - errs.Append(fmt.Errorf("error populating account emojis: %w", err)) + errs.Appendf("error populating account emojis: %w", err) } } - return errs.Combine() + if err := errs.Combine(); err != nil { + return gtserror.Newf("%w", err) + } + + return nil } func (a *accountDB) PutAccount(ctx context.Context, account *gtsmodel.Account) error { diff --git a/internal/db/bundb/instance.go b/internal/db/bundb/instance.go index 48332c731..6657072fd 100644 --- a/internal/db/bundb/instance.go +++ b/internal/db/bundb/instance.go @@ -173,7 +173,7 @@ func (i *instanceDB) getInstance(ctx context.Context, lookup string, dbQuery fun func (i *instanceDB) populateInstance(ctx context.Context, instance *gtsmodel.Instance) error { var ( err error - errs = make(gtserror.MultiError, 0, 2) + errs = gtserror.NewMultiError(2) ) if instance.DomainBlockID != "" && instance.DomainBlock == nil { @@ -183,7 +183,7 @@ func (i *instanceDB) populateInstance(ctx context.Context, instance *gtsmodel.In instance.Domain, ) if err != nil { - errs.Append(gtserror.Newf("error populating instance domain block: %w", err)) + errs.Appendf("error populating instance domain block: %w", err) } } @@ -194,11 +194,15 @@ func (i *instanceDB) populateInstance(ctx context.Context, instance *gtsmodel.In instance.ContactAccountID, ) if err != nil { - errs.Append(gtserror.Newf("error populating instance contact account: %w", err)) + errs.Appendf("error populating instance contact account: %w", err) } } - return errs.Combine() + if err := errs.Combine(); err != nil { + return gtserror.Newf("%w", err) + } + + return nil } func (i *instanceDB) PutInstance(ctx context.Context, instance *gtsmodel.Instance) error { diff --git a/internal/db/bundb/list.go b/internal/db/bundb/list.go index 70faf837a..ad970f539 100644 --- a/internal/db/bundb/list.go +++ b/internal/db/bundb/list.go @@ -117,7 +117,7 @@ func (l *listDB) GetListsForAccountID(ctx context.Context, accountID string) ([] func (l *listDB) PopulateList(ctx context.Context, list *gtsmodel.List) error { var ( err error - errs = make(gtserror.MultiError, 0, 2) + errs = gtserror.NewMultiError(2) ) if list.Account == nil { @@ -127,7 +127,7 @@ func (l *listDB) PopulateList(ctx context.Context, list *gtsmodel.List) error { list.AccountID, ) if err != nil { - errs.Append(fmt.Errorf("error populating list account: %w", err)) + errs.Appendf("error populating list account: %w", err) } } @@ -139,11 +139,15 @@ func (l *listDB) PopulateList(ctx context.Context, list *gtsmodel.List) error { "", "", "", 0, ) if err != nil { - errs.Append(fmt.Errorf("error populating list entries: %w", err)) + errs.Appendf("error populating list entries: %w", err) } } - return errs.Combine() + if err := errs.Combine(); err != nil { + return gtserror.Newf("%w", err) + } + + return nil } func (l *listDB) PutList(ctx context.Context, list *gtsmodel.List) error { diff --git a/internal/db/bundb/relationship_follow.go b/internal/db/bundb/relationship_follow.go index 3b0597612..e22ed30de 100644 --- a/internal/db/bundb/relationship_follow.go +++ b/internal/db/bundb/relationship_follow.go @@ -160,7 +160,7 @@ func (r *relationshipDB) getFollow(ctx context.Context, lookup string, dbQuery f func (r *relationshipDB) PopulateFollow(ctx context.Context, follow *gtsmodel.Follow) error { var ( err error - errs = make(gtserror.MultiError, 0, 2) + errs = gtserror.NewMultiError(2) ) if follow.Account == nil { @@ -170,7 +170,7 @@ func (r *relationshipDB) PopulateFollow(ctx context.Context, follow *gtsmodel.Fo follow.AccountID, ) if err != nil { - errs.Append(fmt.Errorf("error populating follow account: %w", err)) + errs.Appendf("error populating follow account: %w", err) } } @@ -181,11 +181,15 @@ func (r *relationshipDB) PopulateFollow(ctx context.Context, follow *gtsmodel.Fo follow.TargetAccountID, ) if err != nil { - errs.Append(fmt.Errorf("error populating follow target account: %w", err)) + errs.Appendf("error populating follow target account: %w", err) } } - return errs.Combine() + if err := errs.Combine(); err != nil { + return gtserror.Newf("%w", err) + } + + return nil } func (r *relationshipDB) PutFollow(ctx context.Context, follow *gtsmodel.Follow) error { diff --git a/internal/db/bundb/status.go b/internal/db/bundb/status.go index c34074dd6..25b773dfa 100644 --- a/internal/db/bundb/status.go +++ b/internal/db/bundb/status.go @@ -22,7 +22,6 @@ import ( "context" "database/sql" "errors" - "fmt" "time" "github.com/superseriousbusiness/gotosocial/internal/db" @@ -129,7 +128,7 @@ func (s *statusDB) getStatus(ctx context.Context, lookup string, dbQuery func(*g func (s *statusDB) PopulateStatus(ctx context.Context, status *gtsmodel.Status) error { var ( err error - errs = make(gtserror.MultiError, 0, 9) + errs = gtserror.NewMultiError(9) ) if status.Account == nil { @@ -139,7 +138,7 @@ func (s *statusDB) PopulateStatus(ctx context.Context, status *gtsmodel.Status) status.AccountID, ) if err != nil { - errs.Append(fmt.Errorf("error populating status author: %w", err)) + errs.Appendf("error populating status author: %w", err) } } @@ -150,7 +149,7 @@ func (s *statusDB) PopulateStatus(ctx context.Context, status *gtsmodel.Status) status.InReplyToID, ) if err != nil { - errs.Append(fmt.Errorf("error populating status parent: %w", err)) + errs.Appendf("error populating status parent: %w", err) } } @@ -162,7 +161,7 @@ func (s *statusDB) PopulateStatus(ctx context.Context, status *gtsmodel.Status) status.InReplyToID, ) if err != nil { - errs.Append(fmt.Errorf("error populating status parent: %w", err)) + errs.Appendf("error populating status parent: %w", err) } } @@ -173,7 +172,7 @@ func (s *statusDB) PopulateStatus(ctx context.Context, status *gtsmodel.Status) status.InReplyToAccountID, ) if err != nil { - errs.Append(fmt.Errorf("error populating status parent author: %w", err)) + errs.Appendf("error populating status parent author: %w", err) } } } @@ -186,7 +185,7 @@ func (s *statusDB) PopulateStatus(ctx context.Context, status *gtsmodel.Status) status.BoostOfID, ) if err != nil { - errs.Append(fmt.Errorf("error populating status boost: %w", err)) + errs.Appendf("error populating status boost: %w", err) } } @@ -197,7 +196,7 @@ func (s *statusDB) PopulateStatus(ctx context.Context, status *gtsmodel.Status) status.BoostOfAccountID, ) if err != nil { - errs.Append(fmt.Errorf("error populating status boost author: %w", err)) + errs.Appendf("error populating status boost author: %w", err) } } } @@ -209,7 +208,7 @@ func (s *statusDB) PopulateStatus(ctx context.Context, status *gtsmodel.Status) status.AttachmentIDs, ) if err != nil { - errs.Append(fmt.Errorf("error populating status attachments: %w", err)) + errs.Appendf("error populating status attachments: %w", err) } } @@ -220,7 +219,7 @@ func (s *statusDB) PopulateStatus(ctx context.Context, status *gtsmodel.Status) status.TagIDs, ) if err != nil { - errs.Append(fmt.Errorf("error populating status tags: %w", err)) + errs.Appendf("error populating status tags: %w", err) } } @@ -231,7 +230,7 @@ func (s *statusDB) PopulateStatus(ctx context.Context, status *gtsmodel.Status) status.MentionIDs, ) if err != nil { - errs.Append(fmt.Errorf("error populating status mentions: %w", err)) + errs.Appendf("error populating status mentions: %w", err) } } @@ -242,11 +241,15 @@ func (s *statusDB) PopulateStatus(ctx context.Context, status *gtsmodel.Status) status.EmojiIDs, ) if err != nil { - errs.Append(fmt.Errorf("error populating status emojis: %w", err)) + errs.Appendf("error populating status emojis: %w", err) } } - return errs.Combine() + if err := errs.Combine(); err != nil { + return gtserror.Newf("%w", err) + } + + return nil } func (s *statusDB) PutStatus(ctx context.Context, status *gtsmodel.Status) error { diff --git a/internal/db/bundb/statusfave.go b/internal/db/bundb/statusfave.go index a8d1cd0d1..7aff543fd 100644 --- a/internal/db/bundb/statusfave.go +++ b/internal/db/bundb/statusfave.go @@ -149,7 +149,7 @@ func (s *statusFaveDB) GetStatusFavesForStatus(ctx context.Context, statusID str func (s *statusFaveDB) PopulateStatusFave(ctx context.Context, statusFave *gtsmodel.StatusFave) error { var ( err error - errs = make(gtserror.MultiError, 0, 3) + errs = gtserror.NewMultiError(3) ) if statusFave.Account == nil { @@ -159,7 +159,7 @@ func (s *statusFaveDB) PopulateStatusFave(ctx context.Context, statusFave *gtsmo statusFave.AccountID, ) if err != nil { - errs.Append(fmt.Errorf("error populating status fave author: %w", err)) + errs.Appendf("error populating status fave author: %w", err) } } @@ -170,7 +170,7 @@ func (s *statusFaveDB) PopulateStatusFave(ctx context.Context, statusFave *gtsmo statusFave.TargetAccountID, ) if err != nil { - errs.Append(fmt.Errorf("error populating status fave target account: %w", err)) + errs.Appendf("error populating status fave target account: %w", err) } } @@ -181,11 +181,15 @@ func (s *statusFaveDB) PopulateStatusFave(ctx context.Context, statusFave *gtsmo statusFave.StatusID, ) if err != nil { - errs.Append(fmt.Errorf("error populating status fave status: %w", err)) + errs.Appendf("error populating status fave status: %w", err) } } - return errs.Combine() + if err := errs.Combine(); err != nil { + return gtserror.Newf("%w", err) + } + + return nil } func (s *statusFaveDB) PutStatusFave(ctx context.Context, fave *gtsmodel.StatusFave) error { diff --git a/internal/gtserror/multi.go b/internal/gtserror/multi.go index 371eb2f63..1c533b285 100644 --- a/internal/gtserror/multi.go +++ b/internal/gtserror/multi.go @@ -20,25 +20,48 @@ package gtserror import ( "errors" "fmt" - "strings" ) -// MultiError allows encapsulating multiple errors under a singular instance, -// which is useful when you only want to log on errors, not return early / bubble up. -type MultiError []string - -func (e *MultiError) Append(err error) { - *e = append(*e, err.Error()) +// MultiError allows encapsulating multiple +// errors under a singular instance, which +// is useful when you only want to log on +// errors, not return early / bubble up. +type MultiError struct { + e []error } -func (e *MultiError) Appendf(format string, args ...any) { - *e = append(*e, fmt.Sprintf(format, args...)) -} - -// Combine converts this multiError to a singular error instance, returning nil if empty. -func (e MultiError) Combine() error { - if len(e) == 0 { - return nil +// NewMultiError returns a *MultiError with +// the capacity of its underlying error slice +// set to the provided value. +// +// This capacity can be exceeded if necessary, +// but it saves a teeny tiny bit of memory if +// callers set it correctly. +// +// If you don't know in advance what the capacity +// must be, just use new(MultiError) instead. +func NewMultiError(capacity int) *MultiError { + return &MultiError{ + e: make([]error, 0, capacity), } - return errors.New(`"` + strings.Join(e, `","`) + `"`) +} + +// Append the given error to the MultiError. +func (m *MultiError) Append(err error) { + m.e = append(m.e, err) +} + +// Append the given format string to the MultiError. +// +// It is valid to use %w in the format string +// to wrap any other errors. +func (m *MultiError) Appendf(format string, args ...any) { + m.e = append(m.e, fmt.Errorf(format, args...)) +} + +// Combine the MultiError into a single error. +// +// Unwrap will work on the returned error as expected. +func (m MultiError) Combine() error { + return errors.Join(m.e...) } diff --git a/internal/gtserror/multi_test.go b/internal/gtserror/multi_test.go new file mode 100644 index 000000000..9c16c1a53 --- /dev/null +++ b/internal/gtserror/multi_test.go @@ -0,0 +1,64 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package gtserror + +import ( + "errors" + "testing" + + "github.com/superseriousbusiness/gotosocial/internal/db" +) + +func TestMultiError(t *testing.T) { + errs := MultiError{ + e: []error{ + db.ErrNoEntries, + errors.New("oopsie woopsie we did a fucky wucky etc"), + }, + } + errs.Appendf("appended + wrapped error: %w", db.ErrAlreadyExists) + + err := errs.Combine() + + if !errors.Is(err, db.ErrNoEntries) { + t.Error("should be db.ErrNoEntries") + } + + if !errors.Is(err, db.ErrAlreadyExists) { + t.Error("should be db.ErrAlreadyExists") + } + + if errors.Is(err, db.ErrBusyTimeout) { + t.Error("should not be db.ErrBusyTimeout") + } + + errString := err.Error() + expected := `sql: no rows in result set +oopsie woopsie we did a fucky wucky etc +appended + wrapped error: already exists` + if errString != expected { + t.Errorf("errString '%s' should be '%s'", errString, expected) + } +} + +func TestMultiErrorEmpty(t *testing.T) { + err := new(MultiError).Combine() + if err != nil { + t.Errorf("should be nil") + } +} diff --git a/internal/processing/fromcommon.go b/internal/processing/fromcommon.go index 5889da4f7..030ff506c 100644 --- a/internal/processing/fromcommon.go +++ b/internal/processing/fromcommon.go @@ -20,7 +20,6 @@ package processing import ( "context" "errors" - "fmt" "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" @@ -42,13 +41,13 @@ import ( func (p *Processor) timelineAndNotifyStatus(ctx context.Context, status *gtsmodel.Status) error { // Ensure status fully populated; including account, mentions, etc. if err := p.state.DB.PopulateStatus(ctx, status); err != nil { - return fmt.Errorf("timelineAndNotifyStatus: error populating status with id %s: %w", status.ID, err) + return gtserror.Newf("error populating status with id %s: %w", status.ID, err) } // Get local followers of the account that posted the status. follows, err := p.state.DB.GetAccountLocalFollowers(ctx, status.AccountID) if err != nil { - return fmt.Errorf("timelineAndNotifyStatus: error getting local followers for account id %s: %w", status.AccountID, err) + return gtserror.Newf("error getting local followers for account id %s: %w", status.AccountID, err) } // If the poster is also local, add a fake entry for them @@ -66,12 +65,12 @@ func (p *Processor) timelineAndNotifyStatus(ctx context.Context, status *gtsmode // This will also handle notifying any followers with notify // set to true on their follow. if err := p.timelineAndNotifyStatusForFollowers(ctx, status, follows); err != nil { - return fmt.Errorf("timelineAndNotifyStatus: error timelining status %s for followers: %w", status.ID, err) + return gtserror.Newf("error timelining status %s for followers: %w", status.ID, err) } // Notify each local account that's mentioned by this status. if err := p.notifyStatusMentions(ctx, status); err != nil { - return fmt.Errorf("timelineAndNotifyStatus: error notifying status mentions for status %s: %w", status.ID, err) + return gtserror.Newf("error notifying status mentions for status %s: %w", status.ID, err) } return nil @@ -79,7 +78,7 @@ func (p *Processor) timelineAndNotifyStatus(ctx context.Context, status *gtsmode func (p *Processor) timelineAndNotifyStatusForFollowers(ctx context.Context, status *gtsmodel.Status, follows []*gtsmodel.Follow) error { var ( - errs = make(gtserror.MultiError, 0, len(follows)) + errs = gtserror.NewMultiError(len(follows)) boost = status.BoostOfID != "" reply = status.InReplyToURI != "" ) @@ -100,7 +99,7 @@ func (p *Processor) timelineAndNotifyStatusForFollowers(ctx context.Context, sta follow.ID, ) if err != nil && !errors.Is(err, db.ErrNoEntries) { - errs.Append(fmt.Errorf("timelineAndNotifyStatusForFollowers: error list timelining status: %w", err)) + errs.Appendf("error list timelining status: %w", err) continue } @@ -113,7 +112,7 @@ func (p *Processor) timelineAndNotifyStatusForFollowers(ctx context.Context, sta status, stream.TimelineList+":"+listEntry.ListID, // key streamType to this specific list ); err != nil { - errs.Append(fmt.Errorf("timelineAndNotifyStatusForFollowers: error list timelining status: %w", err)) + errs.Appendf("error list timelining status: %w", err) continue } } @@ -128,7 +127,7 @@ func (p *Processor) timelineAndNotifyStatusForFollowers(ctx context.Context, sta status, stream.TimelineHome, ); err != nil { - errs.Append(fmt.Errorf("timelineAndNotifyStatusForFollowers: error home timelining status: %w", err)) + errs.Appendf("error home timelining status: %w", err) continue } else if !timelined { // Status wasn't added to home tomeline, @@ -162,11 +161,15 @@ func (p *Processor) timelineAndNotifyStatusForFollowers(ctx context.Context, sta status.AccountID, status.ID, ); err != nil { - errs.Append(fmt.Errorf("timelineAndNotifyStatusForFollowers: error notifying account %s about new status: %w", follow.AccountID, err)) + errs.Appendf("error notifying account %s about new status: %w", follow.AccountID, err) } } - return errs.Combine() + if err := errs.Combine(); err != nil { + return gtserror.Newf("%w", err) + } + + return nil } // timelineStatus uses the provided ingest function to put the given @@ -185,7 +188,7 @@ func (p *Processor) timelineStatus( // Make sure the status is timelineable. // This works for both home and list timelines. if timelineable, err := p.filter.StatusHomeTimelineable(ctx, account, status); err != nil { - err = fmt.Errorf("timelineStatusForAccount: error getting timelineability for status for timeline with id %s: %w", account.ID, err) + err = gtserror.Newf("error getting timelineability for status for timeline with id %s: %w", account.ID, err) return false, err } else if !timelineable { // Nothing to do. @@ -194,7 +197,7 @@ func (p *Processor) timelineStatus( // Ingest status into given timeline using provided function. if inserted, err := ingest(ctx, timelineID, status); err != nil { - err = fmt.Errorf("timelineStatusForAccount: error ingesting status %s: %w", status.ID, err) + err = gtserror.Newf("error ingesting status %s: %w", status.ID, err) return false, err } else if !inserted { // Nothing more to do. @@ -204,12 +207,12 @@ func (p *Processor) timelineStatus( // The status was inserted so stream it to the user. apiStatus, err := p.tc.StatusToAPIStatus(ctx, status, account) if err != nil { - err = fmt.Errorf("timelineStatusForAccount: error converting status %s to frontend representation: %w", status.ID, err) + err = gtserror.Newf("error converting status %s to frontend representation: %w", status.ID, err) return true, err } if err := p.stream.Update(apiStatus, account, []string{streamType}); err != nil { - err = fmt.Errorf("timelineStatusForAccount: error streaming update for status %s: %w", status.ID, err) + err = gtserror.Newf("error streaming update for status %s: %w", status.ID, err) return true, err } @@ -217,7 +220,7 @@ func (p *Processor) timelineStatus( } func (p *Processor) notifyStatusMentions(ctx context.Context, status *gtsmodel.Status) error { - errs := make(gtserror.MultiError, 0, len(status.Mentions)) + errs := gtserror.NewMultiError(len(status.Mentions)) for _, m := range status.Mentions { if err := p.notify( @@ -231,7 +234,11 @@ func (p *Processor) notifyStatusMentions(ctx context.Context, status *gtsmodel.S } } - return errs.Combine() + if err := errs.Combine(); err != nil { + return gtserror.Newf("%w", err) + } + + return nil } func (p *Processor) notifyFollowRequest(ctx context.Context, followRequest *gtsmodel.FollowRequest) error { @@ -255,13 +262,13 @@ func (p *Processor) notifyFollow(ctx context.Context, follow *gtsmodel.Follow, t ) if err != nil && !errors.Is(err, db.ErrNoEntries) { // Proper error while checking. - return fmt.Errorf("notifyFollow: db error checking for previous follow request notification: %w", err) + return gtserror.Newf("db error checking for previous follow request notification: %w", err) } if prevNotif != nil { // Previous notification existed, delete. if err := p.state.DB.DeleteNotificationByID(ctx, prevNotif.ID); err != nil { - return fmt.Errorf("notifyFollow: db error removing previous follow request notification %s: %w", prevNotif.ID, err) + return gtserror.Newf("db error removing previous follow request notification %s: %w", prevNotif.ID, err) } } @@ -319,7 +326,7 @@ func (p *Processor) notify( ) error { targetAccount, err := p.state.DB.GetAccountByID(ctx, targetAccountID) if err != nil { - return fmt.Errorf("notify: error getting target account %s: %w", targetAccountID, err) + return gtserror.Newf("error getting target account %s: %w", targetAccountID, err) } if !targetAccount.IsLocal() { @@ -340,7 +347,7 @@ func (p *Processor) notify( return nil } else if !errors.Is(err, db.ErrNoEntries) { // Real error. - return fmt.Errorf("notify: error checking existence of notification: %w", err) + return gtserror.Newf("error checking existence of notification: %w", err) } // Notification doesn't yet exist, so @@ -354,17 +361,17 @@ func (p *Processor) notify( } if err := p.state.DB.PutNotification(ctx, notif); err != nil { - return fmt.Errorf("notify: error putting notification in database: %w", err) + return gtserror.Newf("error putting notification in database: %w", err) } // Stream notification to the user. apiNotif, err := p.tc.NotificationToAPINotification(ctx, notif) if err != nil { - return fmt.Errorf("notify: error converting notification to api representation: %w", err) + return gtserror.Newf("error converting notification to api representation: %w", err) } if err := p.stream.Notify(apiNotif, targetAccount); err != nil { - return fmt.Errorf("notify: error streaming notification to account: %w", err) + return gtserror.Newf("error streaming notification to account: %w", err) } return nil @@ -479,7 +486,7 @@ func (p *Processor) invalidateStatusFromTimelines(ctx context.Context, statusID func (p *Processor) emailReport(ctx context.Context, report *gtsmodel.Report) error { instance, err := p.state.DB.GetInstance(ctx, config.GetHost()) if err != nil { - return fmt.Errorf("emailReport: error getting instance: %w", err) + return gtserror.Newf("error getting instance: %w", err) } toAddresses, err := p.state.DB.GetInstanceModeratorAddresses(ctx) @@ -488,20 +495,20 @@ func (p *Processor) emailReport(ctx context.Context, report *gtsmodel.Report) er // No registered moderator addresses. return nil } - return fmt.Errorf("emailReport: error getting instance moderator addresses: %w", err) + return gtserror.Newf("error getting instance moderator addresses: %w", err) } if report.Account == nil { report.Account, err = p.state.DB.GetAccountByID(ctx, report.AccountID) if err != nil { - return fmt.Errorf("emailReport: error getting report account: %w", err) + return gtserror.Newf("error getting report account: %w", err) } } if report.TargetAccount == nil { report.TargetAccount, err = p.state.DB.GetAccountByID(ctx, report.TargetAccountID) if err != nil { - return fmt.Errorf("emailReport: error getting report target account: %w", err) + return gtserror.Newf("error getting report target account: %w", err) } } @@ -514,7 +521,7 @@ func (p *Processor) emailReport(ctx context.Context, report *gtsmodel.Report) er } if err := p.emailSender.SendNewReportEmail(toAddresses, reportData); err != nil { - return fmt.Errorf("emailReport: error emailing instance moderators: %w", err) + return gtserror.Newf("error emailing instance moderators: %w", err) } return nil @@ -523,7 +530,7 @@ func (p *Processor) emailReport(ctx context.Context, report *gtsmodel.Report) er func (p *Processor) emailReportClosed(ctx context.Context, report *gtsmodel.Report) error { user, err := p.state.DB.GetUserByAccountID(ctx, report.Account.ID) if err != nil { - return fmt.Errorf("emailReportClosed: db error getting user: %w", err) + return gtserror.Newf("db error getting user: %w", err) } if user.ConfirmedAt.IsZero() || !*user.Approved || *user.Disabled || user.Email == "" { @@ -537,20 +544,20 @@ func (p *Processor) emailReportClosed(ctx context.Context, report *gtsmodel.Repo instance, err := p.state.DB.GetInstance(ctx, config.GetHost()) if err != nil { - return fmt.Errorf("emailReportClosed: db error getting instance: %w", err) + return gtserror.Newf("db error getting instance: %w", err) } if report.Account == nil { report.Account, err = p.state.DB.GetAccountByID(ctx, report.AccountID) if err != nil { - return fmt.Errorf("emailReportClosed: error getting report account: %w", err) + return gtserror.Newf("error getting report account: %w", err) } } if report.TargetAccount == nil { report.TargetAccount, err = p.state.DB.GetAccountByID(ctx, report.TargetAccountID) if err != nil { - return fmt.Errorf("emailReportClosed: error getting report target account: %w", err) + return gtserror.Newf("error getting report target account: %w", err) } } diff --git a/internal/timeline/manager.go b/internal/timeline/manager.go index a701756bb..23b769c62 100644 --- a/internal/timeline/manager.go +++ b/internal/timeline/manager.go @@ -190,18 +190,18 @@ func (m *manager) GetOldestIndexedID(ctx context.Context, timelineID string) str } func (m *manager) WipeItemFromAllTimelines(ctx context.Context, itemID string) error { - errors := gtserror.MultiError{} + errs := new(gtserror.MultiError) m.timelines.Range(func(_ any, v any) bool { if _, err := v.(Timeline).Remove(ctx, itemID); err != nil { - errors.Append(err) + errs.Append(err) } return true // always continue range }) - if len(errors) > 0 { - return gtserror.Newf("error(s) wiping status %s: %w", itemID, errors.Combine()) + if err := errs.Combine(); err != nil { + return gtserror.Newf("error(s) wiping status %s: %w", itemID, errs.Combine()) } return nil @@ -213,21 +213,21 @@ func (m *manager) WipeItemsFromAccountID(ctx context.Context, timelineID string, } func (m *manager) UnprepareItemFromAllTimelines(ctx context.Context, itemID string) error { - errors := gtserror.MultiError{} + errs := new(gtserror.MultiError) // Work through all timelines held by this // manager, and call Unprepare for each. m.timelines.Range(func(_ any, v any) bool { // nolint:forcetypeassert if err := v.(Timeline).Unprepare(ctx, itemID); err != nil { - errors.Append(err) + errs.Append(err) } return true // always continue range }) - if len(errors) > 0 { - return gtserror.Newf("error(s) unpreparing status %s: %w", itemID, errors.Combine()) + if err := errs.Combine(); err != nil { + return gtserror.Newf("error(s) unpreparing status %s: %w", itemID, errs.Combine()) } return nil