From fe310f5230690740e2dd6612f214e4659e9ecebf Mon Sep 17 00:00:00 2001 From: tsmethurst Date: Sun, 28 Mar 2021 18:46:52 +0200 Subject: [PATCH] new shortcut --- internal/db/db.go | 5 +++++ internal/db/mock_DB.go | 28 +++++++++++++++++++++------- internal/db/pg.go | 32 +++++++++++++++++++++++++++++++- 3 files changed, 57 insertions(+), 8 deletions(-) diff --git a/internal/db/db.go b/internal/db/db.go index f6c298b97..755aaee64 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -109,6 +109,11 @@ type DB interface { // In case of no entries, a 'no entries' error will be returned GetAccountByUserID(userID string, account *model.Account) error + // GetFollowRequestsForAccountID is a shortcut for the common action of fetching a list of follow requests targeting the given account ID. + // The given slice 'followRequests' will be set to the result of the query, whatever it is. + // In case of no entries, a 'no entries' error will be returned + GetFollowRequestsForAccountID(accountID string, followRequests *[]model.FollowRequest) error + // GetFollowingByAccountID is a shortcut for the common action of fetching a list of accounts that accountID is following. // The given slice 'following' will be set to the result of the query, whatever it is. // In case of no entries, a 'no entries' error will be returned diff --git a/internal/db/mock_DB.go b/internal/db/mock_DB.go index 8bf52e3b3..d4c25bb79 100644 --- a/internal/db/mock_DB.go +++ b/internal/db/mock_DB.go @@ -157,6 +157,20 @@ func (_m *MockDB) GetByID(id string, i interface{}) error { return r0 } +// GetFollowRequestsForAccountID provides a mock function with given fields: accountID, followRequests +func (_m *MockDB) GetFollowRequestsForAccountID(accountID string, followRequests *[]model.FollowRequest) error { + ret := _m.Called(accountID, followRequests) + + var r0 error + if rf, ok := ret.Get(0).(func(string, *[]model.FollowRequest) error); ok { + r0 = rf(accountID, followRequests) + } else { + r0 = ret.Error(0) + } + + return r0 +} + // GetFollowersByAccountID provides a mock function with given fields: accountID, followers func (_m *MockDB) GetFollowersByAccountID(accountID string, followers *[]model.Follow) error { ret := _m.Called(accountID, followers) @@ -283,13 +297,13 @@ func (_m *MockDB) IsUsernameAvailable(username string) error { return r0 } -// NewSignup provides a mock function with given fields: username, reason, requireApproval, email, password, signUpIP, locale -func (_m *MockDB) NewSignup(username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string) (*model.User, error) { - ret := _m.Called(username, reason, requireApproval, email, password, signUpIP, locale) +// NewSignup provides a mock function with given fields: username, reason, requireApproval, email, password, signUpIP, locale, appID +func (_m *MockDB) NewSignup(username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string) (*model.User, error) { + ret := _m.Called(username, reason, requireApproval, email, password, signUpIP, locale, appID) var r0 *model.User - if rf, ok := ret.Get(0).(func(string, string, bool, string, string, net.IP, string) *model.User); ok { - r0 = rf(username, reason, requireApproval, email, password, signUpIP, locale) + if rf, ok := ret.Get(0).(func(string, string, bool, string, string, net.IP, string, string) *model.User); ok { + r0 = rf(username, reason, requireApproval, email, password, signUpIP, locale, appID) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*model.User) @@ -297,8 +311,8 @@ func (_m *MockDB) NewSignup(username string, reason string, requireApproval bool } var r1 error - if rf, ok := ret.Get(1).(func(string, string, bool, string, string, net.IP, string) error); ok { - r1 = rf(username, reason, requireApproval, email, password, signUpIP, locale) + if rf, ok := ret.Get(1).(func(string, string, bool, string, string, net.IP, string, string) error); ok { + r1 = rf(username, reason, requireApproval, email, password, signUpIP, locale, appID) } else { r1 = ret.Error(1) } diff --git a/internal/db/pg.go b/internal/db/pg.go index 679011401..588bc6e27 100644 --- a/internal/db/pg.go +++ b/internal/db/pg.go @@ -317,6 +317,16 @@ func (ps *postgresService) GetAccountByUserID(userID string, account *model.Acco return nil } +func (ps *postgresService) GetFollowRequestsForAccountID(accountID string, followRequests *[]model.FollowRequest) error { + if err := ps.conn.Model(followRequests).Where("target_account_id = ?", accountID).Select(); err != nil { + if err == pg.ErrNoRows { + return ErrNoEntries{} + } + return err + } + return nil +} + func (ps *postgresService) GetFollowingByAccountID(accountID string, following *[]model.Follow) error { if err := ps.conn.Model(following).Where("account_id = ?", accountID).Select(); err != nil { if err == pg.ErrNoRows { @@ -523,6 +533,26 @@ func (ps *postgresService) AccountToMastoSensitive(a *model.Account) (*mastotype lastStatusAt = lastStatus.CreatedAt.Format(time.RFC3339) } + fr := []model.FollowRequest{} + if err := ps.GetFollowRequestsForAccountID(a.ID, &fr); err != nil { + if _, ok := err.(ErrNoEntries); !ok { + return nil, fmt.Errorf("error getting follow requests: %s", err) + } + } + var frc int + if fr != nil { + frc = len(fr) + } + + source := &mastotypes.Source{ + Privacy: a.Privacy, + Sensitive: a.Sensitive, + Language: a.Language, + Note: a.Note, + Fields: fields, + FollowRequestsCount: frc, + } + return &mastotypes.Account{ ID: a.ID, Username: a.Username, @@ -541,7 +571,7 @@ func (ps *postgresService) AccountToMastoSensitive(a *model.Account) (*mastotype FollowingCount: followingCount, StatusesCount: statusesCount, LastStatusAt: lastStatusAt, - Source: nil, + Source: source, Emojis: nil, Fields: fields, }, nil