diff --git a/internal/subscriptions/domainperms.go b/internal/subscriptions/domainperms.go index b94f284bf..c9f569f94 100644 --- a/internal/subscriptions/domainperms.go +++ b/internal/subscriptions/domainperms.go @@ -19,6 +19,7 @@ package subscriptions import ( "bufio" + "cmp" "context" "encoding/csv" "encoding/json" @@ -869,10 +870,13 @@ func (s *Subscriptions) adoptPerm( perm.SetCreatedByAccount(permSub.CreatedByAccount) // Set new metadata on the perm. - perm.SetObfuscate(obfuscate) perm.SetPrivateComment(privateComment) perm.SetPublicComment(publicComment) + // Avoid trying to blat nil into the db directly by + // defaulting to false if not set on wanted perm. + perm.SetObfuscate(cmp.Or(obfuscate, util.Ptr(false))) + // Update the perm in the db. var err error switch p := perm.(type) { diff --git a/internal/subscriptions/subscriptions_test.go b/internal/subscriptions/subscriptions_test.go index d86d98691..133db4b7c 100644 --- a/internal/subscriptions/subscriptions_test.go +++ b/internal/subscriptions/subscriptions_test.go @@ -24,6 +24,7 @@ import ( "time" "github.com/stretchr/testify/suite" + "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/subscriptions" @@ -814,6 +815,141 @@ func (suite *SubscriptionsTestSuite) TestAdoption() { suite.Equal(testSubscription.ID, existingBlock3.SubscriptionID) } +func (suite *SubscriptionsTestSuite) TestDomainAllowsAndBlocks() { + var ( + ctx = context.Background() + testStructs = testrig.SetupTestStructs(rMediaPath, rTemplatePath) + testAccount = suite.testAccounts["admin_account"] + subscriptions = subscriptions.New( + testStructs.State, + testStructs.TransportController, + testStructs.TypeConverter, + ) + + // Create a subscription for a CSV list of goodies. + // This one adopts orphans. + testAllowSubscription = >smodel.DomainPermissionSubscription{ + ID: "01JGE681TQSBPAV59GZXPKE62H", + Priority: 255, + Title: "goodies!", + PermissionType: gtsmodel.DomainPermissionAllow, + AsDraft: util.Ptr(false), + AdoptOrphans: util.Ptr(true), + CreatedByAccountID: testAccount.ID, + CreatedByAccount: testAccount, + URI: "https://lists.example.org/goodies", + ContentType: gtsmodel.DomainPermSubContentTypePlain, + } + + existingAllow = >smodel.DomainAllow{ + ID: "01JHX2V5WN250TKB6FQ1M3QE1H", + Domain: "people.we.like.com", + CreatedByAccount: testAccount, + CreatedByAccountID: testAccount.ID, + } + + testBlockSubscription = >smodel.DomainPermissionSubscription{ + ID: "01JPMVY19TKZND838Z7Y6S4EG8", + Priority: 255, + Title: "baddies!", + PermissionType: gtsmodel.DomainPermissionBlock, + AsDraft: util.Ptr(false), + AdoptOrphans: util.Ptr(false), + CreatedByAccountID: testAccount.ID, + CreatedByAccount: testAccount, + URI: "https://lists.example.org/baddies.csv", + ContentType: gtsmodel.DomainPermSubContentTypeCSV, + } + ) + defer testrig.TearDownTestStructs(testStructs) + + // Store test subscriptions. + if err := testStructs.State.DB.PutDomainPermissionSubscription( + ctx, testAllowSubscription, + ); err != nil { + suite.FailNow(err.Error()) + } + if err := testStructs.State.DB.PutDomainPermissionSubscription( + ctx, testBlockSubscription, + ); err != nil { + suite.FailNow(err.Error()) + } + + // Store existing allow. + if err := testStructs.State.DB.CreateDomainAllow(ctx, existingAllow); err != nil { + suite.FailNow(err.Error()) + } + + // Put the instance in allowlist mode. + config.SetInstanceFederationMode("allowlist") + + // Fetch + process subscribed perms in order. + var order [2]gtsmodel.DomainPermissionType + if config.GetInstanceFederationMode() == config.InstanceFederationModeBlocklist { + order = [2]gtsmodel.DomainPermissionType{ + gtsmodel.DomainPermissionAllow, + gtsmodel.DomainPermissionBlock, + } + } else { + order = [2]gtsmodel.DomainPermissionType{ + gtsmodel.DomainPermissionBlock, + gtsmodel.DomainPermissionAllow, + } + } + for _, permType := range order { + subscriptions.ProcessDomainPermissionSubscriptions(ctx, permType) + } + + // We should now have allows for each + // domain on the subscribed allow list. + for _, domain := range []string{ + "people.we.like.com", + "goodeggs.org", + "allowthesefolks.church", + } { + var ( + perm gtsmodel.DomainPermission + err error + ) + if !testrig.WaitFor(func() bool { + perm, err = testStructs.State.DB.GetDomainAllow(ctx, domain) + return err == nil + }) { + suite.FailNowf("", "timed out waiting for domain %s", domain) + } + + suite.Equal(testAllowSubscription.ID, perm.GetSubscriptionID()) + } + + // And blocks for for each domain + // on the subscribed block list. + for _, domain := range []string{ + "bumfaces.net", + "peepee.poopoo", + "nothanks.com", + } { + var ( + perm gtsmodel.DomainPermission + err error + ) + if !testrig.WaitFor(func() bool { + perm, err = testStructs.State.DB.GetDomainBlock(ctx, domain) + return err == nil + }) { + suite.FailNowf("", "timed out waiting for domain %s", domain) + } + + suite.Equal(testBlockSubscription.ID, perm.GetSubscriptionID()) + } + + var err error + existingAllow, err = testStructs.State.DB.GetDomainAllow(ctx, "people.we.like.com") + if err != nil { + suite.FailNow(err.Error()) + } + suite.Equal(existingAllow.SubscriptionID, testAllowSubscription.ID) +} + func TestSubscriptionTestSuite(t *testing.T) { suite.Run(t, new(SubscriptionsTestSuite)) } diff --git a/testrig/transportcontroller.go b/testrig/transportcontroller.go index a6b0dd801..bbcb3901d 100644 --- a/testrig/transportcontroller.go +++ b/testrig/transportcontroller.go @@ -640,6 +640,10 @@ nothanks.com` } ]` jsonRespETag = "\"don't modify me daddy\"" + allowsResp = `people.we.like.com +goodeggs.org +allowthesefolks.church` + allowsRespETag = "\"never change\"" ) switch req.URL.String() { @@ -720,6 +724,36 @@ nothanks.com` } responseContentLength = len(responseBytes) + case "https://lists.example.org/goodies.csv": + extraHeaders = map[string]string{ + "Last-Modified": lastModified, + "ETag": allowsRespETag, + } + if req.Header.Get("If-None-Match") == allowsRespETag { + // Cached. + responseCode = http.StatusNotModified + } else { + responseBytes = []byte(allowsResp) + responseContentType = textCSV + responseCode = http.StatusOK + } + responseContentLength = len(responseBytes) + + case "https://lists.example.org/goodies": + extraHeaders = map[string]string{ + "Last-Modified": lastModified, + "ETag": allowsRespETag, + } + if req.Header.Get("If-None-Match") == allowsRespETag { + // Cached. + responseCode = http.StatusNotModified + } else { + responseBytes = []byte(allowsResp) + responseContentType = textPlain + responseCode = http.StatusOK + } + responseContentLength = len(responseBytes) + default: responseCode = http.StatusNotFound responseBytes = []byte(`{"error":"not found"}`)