diff --git a/internal/cache/cache.go b/internal/cache/cache.go index 913d6eca7..c1f419d22 100644 --- a/internal/cache/cache.go +++ b/internal/cache/cache.go @@ -102,6 +102,9 @@ func (c *Caches) setuphooks() { // Invalidate follow request target account ID cached visibility. c.Visibility.Invalidate("ItemID", followReq.TargetAccountID) c.Visibility.Invalidate("RequesterID", followReq.TargetAccountID) + + // Invalidate any cached follow corresponding to this request. + c.GTS.Follow().Invalidate("AccountID.TargetAccountID", followReq.AccountID, followReq.TargetAccountID) }) c.GTS.Status().SetInvalidateCallback(func(status *gtsmodel.Status) { diff --git a/internal/db/bundb/relationship_follow_req.go b/internal/db/bundb/relationship_follow_req.go index 11200338d..ae398bf3b 100644 --- a/internal/db/bundb/relationship_follow_req.go +++ b/internal/db/bundb/relationship_follow_req.go @@ -204,7 +204,8 @@ func (r *relationshipDB) AcceptFollowRequest(ctx context.Context, sourceAccountI return nil, r.conn.ProcessError(err) } - // Invalidate follow request from cache lookups. + // Invalidate follow request from cache lookups; this will + // invalidate the follow as well via the invalidate hook. r.state.Caches.GTS.FollowRequest().Invalidate("ID", followReq.ID) // Delete original follow request notification @@ -225,12 +226,8 @@ func (r *relationshipDB) RejectFollowRequest(ctx context.Context, sourceAccountI } // Delete original follow request. - if _, err := r.conn. - NewDelete(). - Table("follow_requests"). - Where("? = ?", bun.Ident("id"), followReq.ID). - Exec(ctx); err != nil { - return r.conn.ProcessError(err) + if err := r.DeleteFollowRequestByID(ctx, followReq.ID); err != nil { + return err } // Delete original follow request notification diff --git a/internal/db/bundb/relationship_test.go b/internal/db/bundb/relationship_test.go index 00583d175..9e5a71d60 100644 --- a/internal/db/bundb/relationship_test.go +++ b/internal/db/bundb/relationship_test.go @@ -568,6 +568,14 @@ func (suite *RelationshipTestSuite) TestAcceptFollowRequestOK() { account := suite.testAccounts["admin_account"] targetAccount := suite.testAccounts["local_account_2"] + // Fetch relationship before follow request. + relationship, err := suite.db.GetRelationship(ctx, account.ID, targetAccount.ID) + if err != nil { + suite.FailNow(err.Error()) + } + suite.False(relationship.Following) + suite.False(relationship.Requested) + followRequest := >smodel.FollowRequest{ ID: "01GEF753FWHCHRDWR0QEHBXM8W", URI: "http://localhost:8080/weeeeeeeeeeeeeeeee", @@ -575,10 +583,18 @@ func (suite *RelationshipTestSuite) TestAcceptFollowRequestOK() { TargetAccountID: targetAccount.ID, } - if err := suite.db.Put(ctx, followRequest); err != nil { + if err := suite.db.PutFollowRequest(ctx, followRequest); err != nil { suite.FailNow(err.Error()) } + // Fetch relationship while follow requested. + relationship, err = suite.db.GetRelationship(ctx, account.ID, targetAccount.ID) + if err != nil { + suite.FailNow(err.Error()) + } + suite.False(relationship.Following) + suite.True(relationship.Requested) + followRequestNotification := >smodel.Notification{ ID: "01GV8MY1Q9KX2ZSWN4FAQ3V1PB", OriginAccountID: account.ID, @@ -586,7 +602,7 @@ func (suite *RelationshipTestSuite) TestAcceptFollowRequestOK() { NotificationType: gtsmodel.NotificationFollowRequest, } - if err := suite.db.Put(ctx, followRequestNotification); err != nil { + if err := suite.db.PutNotification(ctx, followRequestNotification); err != nil { suite.FailNow(err.Error()) } @@ -599,6 +615,14 @@ func (suite *RelationshipTestSuite) TestAcceptFollowRequestOK() { notification, err := suite.db.GetNotificationByID(ctx, followRequestNotification.ID) suite.ErrorIs(err, db.ErrNoEntries) suite.Nil(notification) + + // Fetch relationship while followed. + relationship, err = suite.db.GetRelationship(ctx, account.ID, targetAccount.ID) + if err != nil { + suite.FailNow(err.Error()) + } + suite.True(relationship.Following) + suite.False(relationship.Requested) } func (suite *RelationshipTestSuite) TestAcceptFollowRequestNoNotification() {