mirror of
https://github.com/superseriousbusiness/gotosocial
synced 2025-06-05 21:59:39 +02:00
[chore] move client/federator workerpools to Workers{} (#1575)
* replace concurrency worker pools with base models in State.Workers, update code and tests accordingly * improve code comment * change back testrig default log level * un-comment-out TestAnnounceTwice() and fix --------- Signed-off-by: kim <grufwub@gmail.com> Reviewed-by: tobi
This commit is contained in:
@@ -19,13 +19,11 @@
|
||||
package account
|
||||
|
||||
import (
|
||||
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/federation"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/media"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/messages"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/oauth"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/state"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/text"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/visibility"
|
||||
@@ -35,35 +33,32 @@ import (
|
||||
//
|
||||
// It also contains logic for actions towards accounts such as following, blocking, seeing follows, etc.
|
||||
type Processor struct {
|
||||
state *state.State
|
||||
tc typeutils.TypeConverter
|
||||
mediaManager media.Manager
|
||||
clientWorker *concurrency.WorkerPool[messages.FromClientAPI]
|
||||
oauthServer oauth.Server
|
||||
filter visibility.Filter
|
||||
formatter text.Formatter
|
||||
db db.DB
|
||||
federator federation.Federator
|
||||
parseMention gtsmodel.ParseMentionFunc
|
||||
}
|
||||
|
||||
// New returns a new account processor.
|
||||
func New(
|
||||
db db.DB,
|
||||
state *state.State,
|
||||
tc typeutils.TypeConverter,
|
||||
mediaManager media.Manager,
|
||||
oauthServer oauth.Server,
|
||||
clientWorker *concurrency.WorkerPool[messages.FromClientAPI],
|
||||
federator federation.Federator,
|
||||
parseMention gtsmodel.ParseMentionFunc,
|
||||
) Processor {
|
||||
return Processor{
|
||||
state: state,
|
||||
tc: tc,
|
||||
mediaManager: mediaManager,
|
||||
clientWorker: clientWorker,
|
||||
oauthServer: oauthServer,
|
||||
filter: visibility.NewFilter(db),
|
||||
formatter: text.NewFormatter(db),
|
||||
db: db,
|
||||
filter: visibility.NewFilter(state.DB),
|
||||
formatter: text.NewFormatter(state.DB),
|
||||
federator: federator,
|
||||
parseMention: parseMention,
|
||||
}
|
||||
|
@@ -22,7 +22,6 @@ import (
|
||||
"context"
|
||||
|
||||
"github.com/stretchr/testify/suite"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/email"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/federation"
|
||||
@@ -32,6 +31,7 @@ import (
|
||||
"github.com/superseriousbusiness/gotosocial/internal/oauth"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/processing"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/processing/account"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/state"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/storage"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/transport"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
|
||||
@@ -44,6 +44,7 @@ type AccountStandardTestSuite struct {
|
||||
db db.DB
|
||||
tc typeutils.TypeConverter
|
||||
storage *storage.Driver
|
||||
state state.State
|
||||
mediaManager media.Manager
|
||||
oauthServer oauth.Server
|
||||
fromClientAPIChan chan messages.FromClientAPI
|
||||
@@ -76,30 +77,30 @@ func (suite *AccountStandardTestSuite) SetupSuite() {
|
||||
}
|
||||
|
||||
func (suite *AccountStandardTestSuite) SetupTest() {
|
||||
suite.state.Caches.Init()
|
||||
testrig.StartWorkers(&suite.state)
|
||||
|
||||
testrig.InitTestLog()
|
||||
testrig.InitTestConfig()
|
||||
|
||||
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
|
||||
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
|
||||
clientWorker.SetProcessor(func(_ context.Context, msg messages.FromClientAPI) error {
|
||||
suite.fromClientAPIChan <- msg
|
||||
return nil
|
||||
})
|
||||
|
||||
_ = fedWorker.Start()
|
||||
_ = clientWorker.Start()
|
||||
|
||||
suite.db = testrig.NewTestDB()
|
||||
suite.db = testrig.NewTestDB(&suite.state)
|
||||
suite.state.DB = suite.db
|
||||
suite.tc = testrig.NewTestTypeConverter(suite.db)
|
||||
suite.storage = testrig.NewInMemoryStorage()
|
||||
suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage)
|
||||
suite.state.Storage = suite.storage
|
||||
suite.mediaManager = testrig.NewTestMediaManager(&suite.state)
|
||||
suite.oauthServer = testrig.NewTestOauthServer(suite.db)
|
||||
|
||||
suite.fromClientAPIChan = make(chan messages.FromClientAPI, 100)
|
||||
suite.transportController = testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../testrig/media"), suite.db, fedWorker)
|
||||
suite.federator = testrig.NewTestFederator(suite.db, suite.transportController, suite.storage, suite.mediaManager, fedWorker)
|
||||
suite.state.Workers.EnqueueClientAPI = func(ctx context.Context, msg messages.FromClientAPI) {
|
||||
suite.fromClientAPIChan <- msg
|
||||
}
|
||||
|
||||
suite.transportController = testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../testrig/media"))
|
||||
suite.federator = testrig.NewTestFederator(&suite.state, suite.transportController, suite.mediaManager)
|
||||
suite.sentEmails = make(map[string]string)
|
||||
suite.emailSender = testrig.NewEmailSender("../../../web/template/", suite.sentEmails)
|
||||
suite.accountProcessor = account.New(suite.db, suite.tc, suite.mediaManager, suite.oauthServer, clientWorker, suite.federator, processing.GetParseMentionFunc(suite.db, suite.federator))
|
||||
suite.accountProcessor = account.New(&suite.state, suite.tc, suite.mediaManager, suite.oauthServer, suite.federator, processing.GetParseMentionFunc(suite.db, suite.federator))
|
||||
testrig.StandardDBSetup(suite.db, nil)
|
||||
testrig.StandardStorageSetup(suite.storage, "../../../testrig/media")
|
||||
}
|
||||
@@ -107,4 +108,5 @@ func (suite *AccountStandardTestSuite) SetupTest() {
|
||||
func (suite *AccountStandardTestSuite) TearDownTest() {
|
||||
testrig.StandardDBTeardown(suite.db)
|
||||
testrig.StandardStorageTeardown(suite.storage)
|
||||
testrig.StopWorkers(&suite.state)
|
||||
}
|
||||
|
@@ -36,13 +36,13 @@ import (
|
||||
// BlockCreate handles the creation of a block from requestingAccount to targetAccountID, either remote or local.
|
||||
func (p *Processor) BlockCreate(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) {
|
||||
// make sure the target account actually exists in our db
|
||||
targetAccount, err := p.db.GetAccountByID(ctx, targetAccountID)
|
||||
targetAccount, err := p.state.DB.GetAccountByID(ctx, targetAccountID)
|
||||
if err != nil {
|
||||
return nil, gtserror.NewErrorNotFound(fmt.Errorf("BlockCreate: error getting account %s from the db: %s", targetAccountID, err))
|
||||
}
|
||||
|
||||
// if requestingAccount already blocks target account, we don't need to do anything
|
||||
if blocked, err := p.db.IsBlocked(ctx, requestingAccount.ID, targetAccountID, false); err != nil {
|
||||
if blocked, err := p.state.DB.IsBlocked(ctx, requestingAccount.ID, targetAccountID, false); err != nil {
|
||||
return nil, gtserror.NewErrorInternalError(fmt.Errorf("BlockCreate: error checking existence of block: %s", err))
|
||||
} else if blocked {
|
||||
return p.RelationshipGet(ctx, requestingAccount, targetAccountID)
|
||||
@@ -64,18 +64,18 @@ func (p *Processor) BlockCreate(ctx context.Context, requestingAccount *gtsmodel
|
||||
block.URI = uris.GenerateURIForBlock(requestingAccount.Username, newBlockID)
|
||||
|
||||
// whack it in the database
|
||||
if err := p.db.PutBlock(ctx, block); err != nil {
|
||||
if err := p.state.DB.PutBlock(ctx, block); err != nil {
|
||||
return nil, gtserror.NewErrorInternalError(fmt.Errorf("BlockCreate: error creating block in db: %s", err))
|
||||
}
|
||||
|
||||
// clear any follows or follow requests from the blocked account to the target account -- this is a simple delete
|
||||
if err := p.db.DeleteWhere(ctx, []db.Where{
|
||||
if err := p.state.DB.DeleteWhere(ctx, []db.Where{
|
||||
{Key: "account_id", Value: targetAccountID},
|
||||
{Key: "target_account_id", Value: requestingAccount.ID},
|
||||
}, >smodel.Follow{}); err != nil {
|
||||
return nil, gtserror.NewErrorInternalError(fmt.Errorf("BlockCreate: error removing follow in db: %s", err))
|
||||
}
|
||||
if err := p.db.DeleteWhere(ctx, []db.Where{
|
||||
if err := p.state.DB.DeleteWhere(ctx, []db.Where{
|
||||
{Key: "account_id", Value: targetAccountID},
|
||||
{Key: "target_account_id", Value: requestingAccount.ID},
|
||||
}, >smodel.FollowRequest{}); err != nil {
|
||||
@@ -89,12 +89,12 @@ func (p *Processor) BlockCreate(ctx context.Context, requestingAccount *gtsmodel
|
||||
var frChanged bool
|
||||
var frURI string
|
||||
fr := >smodel.FollowRequest{}
|
||||
if err := p.db.GetWhere(ctx, []db.Where{
|
||||
if err := p.state.DB.GetWhere(ctx, []db.Where{
|
||||
{Key: "account_id", Value: requestingAccount.ID},
|
||||
{Key: "target_account_id", Value: targetAccountID},
|
||||
}, fr); err == nil {
|
||||
frURI = fr.URI
|
||||
if err := p.db.DeleteByID(ctx, fr.ID, fr); err != nil {
|
||||
if err := p.state.DB.DeleteByID(ctx, fr.ID, fr); err != nil {
|
||||
return nil, gtserror.NewErrorInternalError(fmt.Errorf("BlockCreate: error removing follow request from db: %s", err))
|
||||
}
|
||||
frChanged = true
|
||||
@@ -104,12 +104,12 @@ func (p *Processor) BlockCreate(ctx context.Context, requestingAccount *gtsmodel
|
||||
var fChanged bool
|
||||
var fURI string
|
||||
f := >smodel.Follow{}
|
||||
if err := p.db.GetWhere(ctx, []db.Where{
|
||||
if err := p.state.DB.GetWhere(ctx, []db.Where{
|
||||
{Key: "account_id", Value: requestingAccount.ID},
|
||||
{Key: "target_account_id", Value: targetAccountID},
|
||||
}, f); err == nil {
|
||||
fURI = f.URI
|
||||
if err := p.db.DeleteByID(ctx, f.ID, f); err != nil {
|
||||
if err := p.state.DB.DeleteByID(ctx, f.ID, f); err != nil {
|
||||
return nil, gtserror.NewErrorInternalError(fmt.Errorf("BlockCreate: error removing follow from db: %s", err))
|
||||
}
|
||||
fChanged = true
|
||||
@@ -117,7 +117,7 @@ func (p *Processor) BlockCreate(ctx context.Context, requestingAccount *gtsmodel
|
||||
|
||||
// follow request status changed so send the UNDO activity to the channel for async processing
|
||||
if frChanged {
|
||||
p.clientWorker.Queue(messages.FromClientAPI{
|
||||
p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{
|
||||
APObjectType: ap.ActivityFollow,
|
||||
APActivityType: ap.ActivityUndo,
|
||||
GTSModel: >smodel.Follow{
|
||||
@@ -132,7 +132,7 @@ func (p *Processor) BlockCreate(ctx context.Context, requestingAccount *gtsmodel
|
||||
|
||||
// follow status changed so send the UNDO activity to the channel for async processing
|
||||
if fChanged {
|
||||
p.clientWorker.Queue(messages.FromClientAPI{
|
||||
p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{
|
||||
APObjectType: ap.ActivityFollow,
|
||||
APActivityType: ap.ActivityUndo,
|
||||
GTSModel: >smodel.Follow{
|
||||
@@ -146,7 +146,7 @@ func (p *Processor) BlockCreate(ctx context.Context, requestingAccount *gtsmodel
|
||||
}
|
||||
|
||||
// handle the rest of the block process asynchronously
|
||||
p.clientWorker.Queue(messages.FromClientAPI{
|
||||
p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{
|
||||
APObjectType: ap.ActivityBlock,
|
||||
APActivityType: ap.ActivityCreate,
|
||||
GTSModel: block,
|
||||
@@ -160,23 +160,23 @@ func (p *Processor) BlockCreate(ctx context.Context, requestingAccount *gtsmodel
|
||||
// BlockRemove handles the removal of a block from requestingAccount to targetAccountID, either remote or local.
|
||||
func (p *Processor) BlockRemove(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) {
|
||||
// make sure the target account actually exists in our db
|
||||
targetAccount, err := p.db.GetAccountByID(ctx, targetAccountID)
|
||||
targetAccount, err := p.state.DB.GetAccountByID(ctx, targetAccountID)
|
||||
if err != nil {
|
||||
return nil, gtserror.NewErrorNotFound(fmt.Errorf("BlockCreate: error getting account %s from the db: %s", targetAccountID, err))
|
||||
}
|
||||
|
||||
// check if a block exists, and remove it if it does
|
||||
block, err := p.db.GetBlock(ctx, requestingAccount.ID, targetAccountID)
|
||||
block, err := p.state.DB.GetBlock(ctx, requestingAccount.ID, targetAccountID)
|
||||
if err == nil {
|
||||
// we got a block, remove it
|
||||
block.Account = requestingAccount
|
||||
block.TargetAccount = targetAccount
|
||||
if err := p.db.DeleteBlockByID(ctx, block.ID); err != nil {
|
||||
if err := p.state.DB.DeleteBlockByID(ctx, block.ID); err != nil {
|
||||
return nil, gtserror.NewErrorInternalError(fmt.Errorf("BlockRemove: error removing block from db: %s", err))
|
||||
}
|
||||
|
||||
// send the UNDO activity to the client worker for async processing
|
||||
p.clientWorker.Queue(messages.FromClientAPI{
|
||||
p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{
|
||||
APObjectType: ap.ActivityBlock,
|
||||
APActivityType: ap.ActivityUndo,
|
||||
GTSModel: block,
|
||||
|
@@ -34,7 +34,7 @@ import (
|
||||
// BookmarksGet returns a pageable response of statuses that are bookmarked by requestingAccount.
|
||||
// Paging for this response is done based on bookmark ID rather than status ID.
|
||||
func (p *Processor) BookmarksGet(ctx context.Context, requestingAccount *gtsmodel.Account, limit int, maxID string, minID string) (*apimodel.PageableResponse, gtserror.WithCode) {
|
||||
bookmarks, err := p.db.GetBookmarks(ctx, requestingAccount.ID, limit, maxID, minID)
|
||||
bookmarks, err := p.state.DB.GetBookmarks(ctx, requestingAccount.ID, limit, maxID, minID)
|
||||
if err != nil {
|
||||
return nil, gtserror.NewErrorInternalError(err)
|
||||
}
|
||||
@@ -47,7 +47,7 @@ func (p *Processor) BookmarksGet(ctx context.Context, requestingAccount *gtsmode
|
||||
)
|
||||
|
||||
for _, bookmark := range bookmarks {
|
||||
status, err := p.db.GetStatusByID(ctx, bookmark.StatusID)
|
||||
status, err := p.state.DB.GetStatusByID(ctx, bookmark.StatusID)
|
||||
if err != nil {
|
||||
if errors.Is(err, db.ErrNoEntries) {
|
||||
// We just don't have the status for some reason.
|
||||
|
@@ -35,7 +35,7 @@ import (
|
||||
|
||||
// Create processes the given form for creating a new account, returning an oauth token for that account if successful.
|
||||
func (p *Processor) Create(ctx context.Context, applicationToken oauth2.TokenInfo, application *gtsmodel.Application, form *apimodel.AccountCreateRequest) (*apimodel.Token, gtserror.WithCode) {
|
||||
emailAvailable, err := p.db.IsEmailAvailable(ctx, form.Email)
|
||||
emailAvailable, err := p.state.DB.IsEmailAvailable(ctx, form.Email)
|
||||
if err != nil {
|
||||
return nil, gtserror.NewErrorBadRequest(err)
|
||||
}
|
||||
@@ -43,7 +43,7 @@ func (p *Processor) Create(ctx context.Context, applicationToken oauth2.TokenInf
|
||||
return nil, gtserror.NewErrorConflict(fmt.Errorf("email address %s is not available", form.Email))
|
||||
}
|
||||
|
||||
usernameAvailable, err := p.db.IsUsernameAvailable(ctx, form.Username)
|
||||
usernameAvailable, err := p.state.DB.IsUsernameAvailable(ctx, form.Username)
|
||||
if err != nil {
|
||||
return nil, gtserror.NewErrorBadRequest(err)
|
||||
}
|
||||
@@ -61,7 +61,7 @@ func (p *Processor) Create(ctx context.Context, applicationToken oauth2.TokenInf
|
||||
}
|
||||
|
||||
log.Trace(ctx, "creating new username and account")
|
||||
user, err := p.db.NewSignup(ctx, form.Username, text.SanitizePlaintext(reason), approvalRequired, form.Email, form.Password, form.IP, form.Locale, application.ID, false, "", false)
|
||||
user, err := p.state.DB.NewSignup(ctx, form.Username, text.SanitizePlaintext(reason), approvalRequired, form.Email, form.Password, form.IP, form.Locale, application.ID, false, "", false)
|
||||
if err != nil {
|
||||
return nil, gtserror.NewErrorInternalError(fmt.Errorf("error creating new signup in the database: %s", err))
|
||||
}
|
||||
@@ -73,7 +73,7 @@ func (p *Processor) Create(ctx context.Context, applicationToken oauth2.TokenInf
|
||||
}
|
||||
|
||||
if user.Account == nil {
|
||||
a, err := p.db.GetAccountByID(ctx, user.AccountID)
|
||||
a, err := p.state.DB.GetAccountByID(ctx, user.AccountID)
|
||||
if err != nil {
|
||||
return nil, gtserror.NewErrorInternalError(fmt.Errorf("error getting new account from the database: %s", err))
|
||||
}
|
||||
@@ -82,7 +82,7 @@ func (p *Processor) Create(ctx context.Context, applicationToken oauth2.TokenInf
|
||||
|
||||
// there are side effects for creating a new account (sending confirmation emails etc)
|
||||
// so pass a message to the processor so that it can do it asynchronously
|
||||
p.clientWorker.Queue(messages.FromClientAPI{
|
||||
p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{
|
||||
APObjectType: ap.ObjectProfile,
|
||||
APActivityType: ap.ActivityCreate,
|
||||
GTSModel: user.Account,
|
||||
|
@@ -54,22 +54,22 @@ func (p *Processor) Delete(ctx context.Context, account *gtsmodel.Account, origi
|
||||
if account.Domain == "" {
|
||||
// see if we can get a user for this account
|
||||
var err error
|
||||
if user, err = p.db.GetUserByAccountID(ctx, account.ID); err == nil {
|
||||
if user, err = p.state.DB.GetUserByAccountID(ctx, account.ID); err == nil {
|
||||
// we got one! select all tokens with the user's ID
|
||||
tokens := []*gtsmodel.Token{}
|
||||
if err := p.db.GetWhere(ctx, []db.Where{{Key: "user_id", Value: user.ID}}, &tokens); err == nil {
|
||||
if err := p.state.DB.GetWhere(ctx, []db.Where{{Key: "user_id", Value: user.ID}}, &tokens); err == nil {
|
||||
// we have some tokens to delete
|
||||
for _, t := range tokens {
|
||||
// delete client(s) associated with this token
|
||||
if err := p.db.DeleteByID(ctx, t.ClientID, >smodel.Client{}); err != nil {
|
||||
if err := p.state.DB.DeleteByID(ctx, t.ClientID, >smodel.Client{}); err != nil {
|
||||
l.Errorf("error deleting oauth client: %s", err)
|
||||
}
|
||||
// delete application(s) associated with this token
|
||||
if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "client_id", Value: t.ClientID}}, >smodel.Application{}); err != nil {
|
||||
if err := p.state.DB.DeleteWhere(ctx, []db.Where{{Key: "client_id", Value: t.ClientID}}, >smodel.Application{}); err != nil {
|
||||
l.Errorf("error deleting application: %s", err)
|
||||
}
|
||||
// delete the token itself
|
||||
if err := p.db.DeleteByID(ctx, t.ID, t); err != nil {
|
||||
if err := p.state.DB.DeleteByID(ctx, t.ID, t); err != nil {
|
||||
l.Errorf("error deleting oauth token: %s", err)
|
||||
}
|
||||
}
|
||||
@@ -80,12 +80,12 @@ func (p *Processor) Delete(ctx context.Context, account *gtsmodel.Account, origi
|
||||
// 2. Delete account's blocks
|
||||
l.Trace("deleting account blocks")
|
||||
// first delete any blocks that this account created
|
||||
if err := p.db.DeleteBlocksByOriginAccountID(ctx, account.ID); err != nil {
|
||||
if err := p.state.DB.DeleteBlocksByOriginAccountID(ctx, account.ID); err != nil {
|
||||
l.Errorf("error deleting blocks created by account: %s", err)
|
||||
}
|
||||
|
||||
// now delete any blocks that target this account
|
||||
if err := p.db.DeleteBlocksByTargetAccountID(ctx, account.ID); err != nil {
|
||||
if err := p.state.DB.DeleteBlocksByTargetAccountID(ctx, account.ID); err != nil {
|
||||
l.Errorf("error deleting blocks targeting account: %s", err)
|
||||
}
|
||||
|
||||
@@ -96,12 +96,12 @@ func (p *Processor) Delete(ctx context.Context, account *gtsmodel.Account, origi
|
||||
// TODO: federate these if necessary
|
||||
l.Trace("deleting account follow requests")
|
||||
// first delete any follow requests that this account created
|
||||
if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "account_id", Value: account.ID}}, &[]*gtsmodel.FollowRequest{}); err != nil {
|
||||
if err := p.state.DB.DeleteWhere(ctx, []db.Where{{Key: "account_id", Value: account.ID}}, &[]*gtsmodel.FollowRequest{}); err != nil {
|
||||
l.Errorf("error deleting follow requests created by account: %s", err)
|
||||
}
|
||||
|
||||
// now delete any follow requests that target this account
|
||||
if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "target_account_id", Value: account.ID}}, &[]*gtsmodel.FollowRequest{}); err != nil {
|
||||
if err := p.state.DB.DeleteWhere(ctx, []db.Where{{Key: "target_account_id", Value: account.ID}}, &[]*gtsmodel.FollowRequest{}); err != nil {
|
||||
l.Errorf("error deleting follow requests targeting account: %s", err)
|
||||
}
|
||||
|
||||
@@ -109,12 +109,12 @@ func (p *Processor) Delete(ctx context.Context, account *gtsmodel.Account, origi
|
||||
// TODO: federate these if necessary
|
||||
l.Trace("deleting account follows")
|
||||
// first delete any follows that this account created
|
||||
if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "account_id", Value: account.ID}}, &[]*gtsmodel.Follow{}); err != nil {
|
||||
if err := p.state.DB.DeleteWhere(ctx, []db.Where{{Key: "account_id", Value: account.ID}}, &[]*gtsmodel.Follow{}); err != nil {
|
||||
l.Errorf("error deleting follows created by account: %s", err)
|
||||
}
|
||||
|
||||
// now delete any follows that target this account
|
||||
if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "target_account_id", Value: account.ID}}, &[]*gtsmodel.Follow{}); err != nil {
|
||||
if err := p.state.DB.DeleteWhere(ctx, []db.Where{{Key: "target_account_id", Value: account.ID}}, &[]*gtsmodel.Follow{}); err != nil {
|
||||
l.Errorf("error deleting follows targeting account: %s", err)
|
||||
}
|
||||
|
||||
@@ -129,7 +129,7 @@ func (p *Processor) Delete(ctx context.Context, account *gtsmodel.Account, origi
|
||||
|
||||
for {
|
||||
// Fetch next block of account statuses from database
|
||||
statuses, err := p.db.GetAccountStatuses(ctx, account.ID, 20, false, false, maxID, "", false, false)
|
||||
statuses, err := p.state.DB.GetAccountStatuses(ctx, account.ID, 20, false, false, maxID, "", false, false)
|
||||
if err != nil {
|
||||
if !errors.Is(err, db.ErrNoEntries) {
|
||||
// an actual error has occurred
|
||||
@@ -149,7 +149,7 @@ func (p *Processor) Delete(ctx context.Context, account *gtsmodel.Account, origi
|
||||
l.Tracef("queue client API status delete: %s", status.ID)
|
||||
|
||||
// pass the status delete through the client api channel for processing
|
||||
p.clientWorker.Queue(messages.FromClientAPI{
|
||||
p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{
|
||||
APObjectType: ap.ObjectNote,
|
||||
APActivityType: ap.ActivityDelete,
|
||||
GTSModel: status,
|
||||
@@ -158,7 +158,7 @@ func (p *Processor) Delete(ctx context.Context, account *gtsmodel.Account, origi
|
||||
})
|
||||
|
||||
// Look for any boosts of this status in DB
|
||||
boosts, err := p.db.GetStatusReblogs(ctx, status)
|
||||
boosts, err := p.state.DB.GetStatusReblogs(ctx, status)
|
||||
if err != nil && !errors.Is(err, db.ErrNoEntries) {
|
||||
l.Errorf("error fetching status reblogs for %q: %v", status.ID, err)
|
||||
continue
|
||||
@@ -167,7 +167,7 @@ func (p *Processor) Delete(ctx context.Context, account *gtsmodel.Account, origi
|
||||
for _, boost := range boosts {
|
||||
if boost.Account == nil {
|
||||
// Fetch the relevant account for this status boost
|
||||
boostAcc, err := p.db.GetAccountByID(ctx, boost.AccountID)
|
||||
boostAcc, err := p.state.DB.GetAccountByID(ctx, boost.AccountID)
|
||||
if err != nil {
|
||||
l.Errorf("error fetching boosted status account for %q: %v", boost.AccountID, err)
|
||||
continue
|
||||
@@ -180,7 +180,7 @@ func (p *Processor) Delete(ctx context.Context, account *gtsmodel.Account, origi
|
||||
l.Tracef("queue client API boost delete: %s", status.ID)
|
||||
|
||||
// pass the boost delete through the client api channel for processing
|
||||
p.clientWorker.Queue(messages.FromClientAPI{
|
||||
p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{
|
||||
APObjectType: ap.ActivityAnnounce,
|
||||
APActivityType: ap.ActivityUndo,
|
||||
GTSModel: status,
|
||||
@@ -197,31 +197,31 @@ func (p *Processor) Delete(ctx context.Context, account *gtsmodel.Account, origi
|
||||
// 10. Delete account's notifications
|
||||
l.Trace("deleting account notifications")
|
||||
// first notifications created by account
|
||||
if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "origin_account_id", Value: account.ID}}, &[]*gtsmodel.Notification{}); err != nil {
|
||||
if err := p.state.DB.DeleteWhere(ctx, []db.Where{{Key: "origin_account_id", Value: account.ID}}, &[]*gtsmodel.Notification{}); err != nil {
|
||||
l.Errorf("error deleting notifications created by account: %s", err)
|
||||
}
|
||||
|
||||
// now notifications targeting account
|
||||
if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "target_account_id", Value: account.ID}}, &[]*gtsmodel.Notification{}); err != nil {
|
||||
if err := p.state.DB.DeleteWhere(ctx, []db.Where{{Key: "target_account_id", Value: account.ID}}, &[]*gtsmodel.Notification{}); err != nil {
|
||||
l.Errorf("error deleting notifications targeting account: %s", err)
|
||||
}
|
||||
|
||||
// 11. Delete account's bookmarks
|
||||
l.Trace("deleting account bookmarks")
|
||||
if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "account_id", Value: account.ID}}, &[]*gtsmodel.StatusBookmark{}); err != nil {
|
||||
if err := p.state.DB.DeleteWhere(ctx, []db.Where{{Key: "account_id", Value: account.ID}}, &[]*gtsmodel.StatusBookmark{}); err != nil {
|
||||
l.Errorf("error deleting bookmarks created by account: %s", err)
|
||||
}
|
||||
|
||||
// 12. Delete account's faves
|
||||
// TODO: federate these if necessary
|
||||
l.Trace("deleting account faves")
|
||||
if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "account_id", Value: account.ID}}, &[]*gtsmodel.StatusFave{}); err != nil {
|
||||
if err := p.state.DB.DeleteWhere(ctx, []db.Where{{Key: "account_id", Value: account.ID}}, &[]*gtsmodel.StatusFave{}); err != nil {
|
||||
l.Errorf("error deleting faves created by account: %s", err)
|
||||
}
|
||||
|
||||
// 13. Delete account's mutes
|
||||
l.Trace("deleting account mutes")
|
||||
if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "account_id", Value: account.ID}}, &[]*gtsmodel.StatusMute{}); err != nil {
|
||||
if err := p.state.DB.DeleteWhere(ctx, []db.Where{{Key: "account_id", Value: account.ID}}, &[]*gtsmodel.StatusMute{}); err != nil {
|
||||
l.Errorf("error deleting status mutes created by account: %s", err)
|
||||
}
|
||||
|
||||
@@ -234,7 +234,7 @@ func (p *Processor) Delete(ctx context.Context, account *gtsmodel.Account, origi
|
||||
// 16. Delete account's user
|
||||
if user != nil {
|
||||
l.Trace("deleting account user")
|
||||
if err := p.db.DeleteUserByID(ctx, user.ID); err != nil {
|
||||
if err := p.state.DB.DeleteUserByID(ctx, user.ID); err != nil {
|
||||
return gtserror.NewErrorInternalError(err)
|
||||
}
|
||||
}
|
||||
@@ -261,7 +261,7 @@ func (p *Processor) Delete(ctx context.Context, account *gtsmodel.Account, origi
|
||||
account.Discoverable = &discoverable
|
||||
account.SuspendedAt = time.Now()
|
||||
account.SuspensionOrigin = origin
|
||||
err := p.db.UpdateAccount(ctx, account)
|
||||
err := p.state.DB.UpdateAccount(ctx, account)
|
||||
if err != nil {
|
||||
return gtserror.NewErrorInternalError(err)
|
||||
}
|
||||
@@ -281,7 +281,7 @@ func (p *Processor) DeleteLocal(ctx context.Context, account *gtsmodel.Account,
|
||||
|
||||
if form.DeleteOriginID == account.ID {
|
||||
// the account owner themself has requested deletion via the API, get their user from the db
|
||||
user, err := p.db.GetUserByAccountID(ctx, account.ID)
|
||||
user, err := p.state.DB.GetUserByAccountID(ctx, account.ID)
|
||||
if err != nil {
|
||||
return gtserror.NewErrorInternalError(err)
|
||||
}
|
||||
@@ -301,7 +301,7 @@ func (p *Processor) DeleteLocal(ctx context.Context, account *gtsmodel.Account,
|
||||
} else {
|
||||
// the delete has been requested by some other account, grab it;
|
||||
// if we've reached this point we know it has permission already
|
||||
requestingAccount, err := p.db.GetAccountByID(ctx, form.DeleteOriginID)
|
||||
requestingAccount, err := p.state.DB.GetAccountByID(ctx, form.DeleteOriginID)
|
||||
if err != nil {
|
||||
return gtserror.NewErrorInternalError(err)
|
||||
}
|
||||
@@ -310,7 +310,7 @@ func (p *Processor) DeleteLocal(ctx context.Context, account *gtsmodel.Account,
|
||||
}
|
||||
|
||||
// put the delete in the processor queue to handle the rest of it asynchronously
|
||||
p.clientWorker.Queue(fromClientAPIMessage)
|
||||
p.state.Workers.EnqueueClientAPI(ctx, fromClientAPIMessage)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@@ -35,14 +35,14 @@ import (
|
||||
// FollowCreate handles a follow request to an account, either remote or local.
|
||||
func (p *Processor) FollowCreate(ctx context.Context, requestingAccount *gtsmodel.Account, form *apimodel.AccountFollowRequest) (*apimodel.Relationship, gtserror.WithCode) {
|
||||
// if there's a block between the accounts we shouldn't create the request ofc
|
||||
if blocked, err := p.db.IsBlocked(ctx, requestingAccount.ID, form.ID, true); err != nil {
|
||||
if blocked, err := p.state.DB.IsBlocked(ctx, requestingAccount.ID, form.ID, true); err != nil {
|
||||
return nil, gtserror.NewErrorInternalError(err)
|
||||
} else if blocked {
|
||||
return nil, gtserror.NewErrorNotFound(fmt.Errorf("block exists between accounts"))
|
||||
}
|
||||
|
||||
// make sure the target account actually exists in our db
|
||||
targetAcct, err := p.db.GetAccountByID(ctx, form.ID)
|
||||
targetAcct, err := p.state.DB.GetAccountByID(ctx, form.ID)
|
||||
if err != nil {
|
||||
if err == db.ErrNoEntries {
|
||||
return nil, gtserror.NewErrorNotFound(fmt.Errorf("accountfollowcreate: account %s not found in the db: %s", form.ID, err))
|
||||
@@ -51,7 +51,7 @@ func (p *Processor) FollowCreate(ctx context.Context, requestingAccount *gtsmode
|
||||
}
|
||||
|
||||
// check if a follow exists already
|
||||
if follows, err := p.db.IsFollowing(ctx, requestingAccount, targetAcct); err != nil {
|
||||
if follows, err := p.state.DB.IsFollowing(ctx, requestingAccount, targetAcct); err != nil {
|
||||
return nil, gtserror.NewErrorInternalError(fmt.Errorf("accountfollowcreate: error checking follow in db: %s", err))
|
||||
} else if follows {
|
||||
// already follows so just return the relationship
|
||||
@@ -59,7 +59,7 @@ func (p *Processor) FollowCreate(ctx context.Context, requestingAccount *gtsmode
|
||||
}
|
||||
|
||||
// check if a follow request exists already
|
||||
if followRequested, err := p.db.IsFollowRequested(ctx, requestingAccount, targetAcct); err != nil {
|
||||
if followRequested, err := p.state.DB.IsFollowRequested(ctx, requestingAccount, targetAcct); err != nil {
|
||||
return nil, gtserror.NewErrorInternalError(fmt.Errorf("accountfollowcreate: error checking follow request in db: %s", err))
|
||||
} else if followRequested {
|
||||
// already follow requested so just return the relationship
|
||||
@@ -95,13 +95,13 @@ func (p *Processor) FollowCreate(ctx context.Context, requestingAccount *gtsmode
|
||||
}
|
||||
|
||||
// whack it in the database
|
||||
if err := p.db.Put(ctx, fr); err != nil {
|
||||
if err := p.state.DB.Put(ctx, fr); err != nil {
|
||||
return nil, gtserror.NewErrorInternalError(fmt.Errorf("accountfollowcreate: error creating follow request in db: %s", err))
|
||||
}
|
||||
|
||||
// if it's a local account that's not locked we can just straight up accept the follow request
|
||||
if !*targetAcct.Locked && targetAcct.Domain == "" {
|
||||
if _, err := p.db.AcceptFollowRequest(ctx, requestingAccount.ID, form.ID); err != nil {
|
||||
if _, err := p.state.DB.AcceptFollowRequest(ctx, requestingAccount.ID, form.ID); err != nil {
|
||||
return nil, gtserror.NewErrorInternalError(fmt.Errorf("accountfollowcreate: error accepting folow request for local unlocked account: %s", err))
|
||||
}
|
||||
// return the new relationship
|
||||
@@ -109,7 +109,7 @@ func (p *Processor) FollowCreate(ctx context.Context, requestingAccount *gtsmode
|
||||
}
|
||||
|
||||
// otherwise we leave the follow request as it is and we handle the rest of the process asynchronously
|
||||
p.clientWorker.Queue(messages.FromClientAPI{
|
||||
p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{
|
||||
APObjectType: ap.ActivityFollow,
|
||||
APActivityType: ap.ActivityCreate,
|
||||
GTSModel: fr,
|
||||
@@ -124,7 +124,7 @@ func (p *Processor) FollowCreate(ctx context.Context, requestingAccount *gtsmode
|
||||
// FollowRemove handles the removal of a follow/follow request to an account, either remote or local.
|
||||
func (p *Processor) FollowRemove(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) {
|
||||
// if there's a block between the accounts we shouldn't do anything
|
||||
blocked, err := p.db.IsBlocked(ctx, requestingAccount.ID, targetAccountID, true)
|
||||
blocked, err := p.state.DB.IsBlocked(ctx, requestingAccount.ID, targetAccountID, true)
|
||||
if err != nil {
|
||||
return nil, gtserror.NewErrorInternalError(err)
|
||||
}
|
||||
@@ -133,7 +133,7 @@ func (p *Processor) FollowRemove(ctx context.Context, requestingAccount *gtsmode
|
||||
}
|
||||
|
||||
// make sure the target account actually exists in our db
|
||||
targetAcct, err := p.db.GetAccountByID(ctx, targetAccountID)
|
||||
targetAcct, err := p.state.DB.GetAccountByID(ctx, targetAccountID)
|
||||
if err != nil {
|
||||
if err == db.ErrNoEntries {
|
||||
return nil, gtserror.NewErrorNotFound(fmt.Errorf("AccountFollowRemove: account %s not found in the db: %s", targetAccountID, err))
|
||||
@@ -144,12 +144,12 @@ func (p *Processor) FollowRemove(ctx context.Context, requestingAccount *gtsmode
|
||||
var frChanged bool
|
||||
var frURI string
|
||||
fr := >smodel.FollowRequest{}
|
||||
if err := p.db.GetWhere(ctx, []db.Where{
|
||||
if err := p.state.DB.GetWhere(ctx, []db.Where{
|
||||
{Key: "account_id", Value: requestingAccount.ID},
|
||||
{Key: "target_account_id", Value: targetAccountID},
|
||||
}, fr); err == nil {
|
||||
frURI = fr.URI
|
||||
if err := p.db.DeleteByID(ctx, fr.ID, fr); err != nil {
|
||||
if err := p.state.DB.DeleteByID(ctx, fr.ID, fr); err != nil {
|
||||
return nil, gtserror.NewErrorInternalError(fmt.Errorf("AccountFollowRemove: error removing follow request from db: %s", err))
|
||||
}
|
||||
frChanged = true
|
||||
@@ -159,12 +159,12 @@ func (p *Processor) FollowRemove(ctx context.Context, requestingAccount *gtsmode
|
||||
var fChanged bool
|
||||
var fURI string
|
||||
f := >smodel.Follow{}
|
||||
if err := p.db.GetWhere(ctx, []db.Where{
|
||||
if err := p.state.DB.GetWhere(ctx, []db.Where{
|
||||
{Key: "account_id", Value: requestingAccount.ID},
|
||||
{Key: "target_account_id", Value: targetAccountID},
|
||||
}, f); err == nil {
|
||||
fURI = f.URI
|
||||
if err := p.db.DeleteByID(ctx, f.ID, f); err != nil {
|
||||
if err := p.state.DB.DeleteByID(ctx, f.ID, f); err != nil {
|
||||
return nil, gtserror.NewErrorInternalError(fmt.Errorf("AccountFollowRemove: error removing follow from db: %s", err))
|
||||
}
|
||||
fChanged = true
|
||||
@@ -172,7 +172,7 @@ func (p *Processor) FollowRemove(ctx context.Context, requestingAccount *gtsmode
|
||||
|
||||
// follow request status changed so send the UNDO activity to the channel for async processing
|
||||
if frChanged {
|
||||
p.clientWorker.Queue(messages.FromClientAPI{
|
||||
p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{
|
||||
APObjectType: ap.ActivityFollow,
|
||||
APActivityType: ap.ActivityUndo,
|
||||
GTSModel: >smodel.Follow{
|
||||
@@ -187,7 +187,7 @@ func (p *Processor) FollowRemove(ctx context.Context, requestingAccount *gtsmode
|
||||
|
||||
// follow status changed so send the UNDO activity to the channel for async processing
|
||||
if fChanged {
|
||||
p.clientWorker.Queue(messages.FromClientAPI{
|
||||
p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{
|
||||
APObjectType: ap.ActivityFollow,
|
||||
APActivityType: ap.ActivityUndo,
|
||||
GTSModel: >smodel.Follow{
|
||||
|
@@ -33,7 +33,7 @@ import (
|
||||
|
||||
// Get processes the given request for account information.
|
||||
func (p *Processor) Get(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Account, gtserror.WithCode) {
|
||||
targetAccount, err := p.db.GetAccountByID(ctx, targetAccountID)
|
||||
targetAccount, err := p.state.DB.GetAccountByID(ctx, targetAccountID)
|
||||
if err != nil {
|
||||
if err == db.ErrNoEntries {
|
||||
return nil, gtserror.NewErrorNotFound(errors.New("account not found"))
|
||||
@@ -46,7 +46,7 @@ func (p *Processor) Get(ctx context.Context, requestingAccount *gtsmodel.Account
|
||||
|
||||
// GetLocalByUsername processes the given request for account information targeting a local account by username.
|
||||
func (p *Processor) GetLocalByUsername(ctx context.Context, requestingAccount *gtsmodel.Account, username string) (*apimodel.Account, gtserror.WithCode) {
|
||||
targetAccount, err := p.db.GetAccountByUsernameDomain(ctx, username, "")
|
||||
targetAccount, err := p.state.DB.GetAccountByUsernameDomain(ctx, username, "")
|
||||
if err != nil {
|
||||
if err == db.ErrNoEntries {
|
||||
return nil, gtserror.NewErrorNotFound(errors.New("account not found"))
|
||||
@@ -59,7 +59,7 @@ func (p *Processor) GetLocalByUsername(ctx context.Context, requestingAccount *g
|
||||
|
||||
// GetCustomCSSForUsername returns custom css for the given local username.
|
||||
func (p *Processor) GetCustomCSSForUsername(ctx context.Context, username string) (string, gtserror.WithCode) {
|
||||
customCSS, err := p.db.GetAccountCustomCSSByUsername(ctx, username)
|
||||
customCSS, err := p.state.DB.GetAccountCustomCSSByUsername(ctx, username)
|
||||
if err != nil {
|
||||
if err == db.ErrNoEntries {
|
||||
return "", gtserror.NewErrorNotFound(errors.New("account not found"))
|
||||
@@ -74,7 +74,7 @@ func (p *Processor) getFor(ctx context.Context, requestingAccount *gtsmodel.Acco
|
||||
var blocked bool
|
||||
var err error
|
||||
if requestingAccount != nil {
|
||||
blocked, err = p.db.IsBlocked(ctx, requestingAccount.ID, targetAccount.ID, true)
|
||||
blocked, err = p.state.DB.IsBlocked(ctx, requestingAccount.ID, targetAccount.ID, true)
|
||||
if err != nil {
|
||||
return nil, gtserror.NewErrorInternalError(fmt.Errorf("error checking account block: %s", err))
|
||||
}
|
||||
|
@@ -31,14 +31,14 @@ import (
|
||||
|
||||
// FollowersGet fetches a list of the target account's followers.
|
||||
func (p *Processor) FollowersGet(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) ([]apimodel.Account, gtserror.WithCode) {
|
||||
if blocked, err := p.db.IsBlocked(ctx, requestingAccount.ID, targetAccountID, true); err != nil {
|
||||
if blocked, err := p.state.DB.IsBlocked(ctx, requestingAccount.ID, targetAccountID, true); err != nil {
|
||||
return nil, gtserror.NewErrorInternalError(err)
|
||||
} else if blocked {
|
||||
return nil, gtserror.NewErrorNotFound(fmt.Errorf("block exists between accounts"))
|
||||
}
|
||||
|
||||
accounts := []apimodel.Account{}
|
||||
follows, err := p.db.GetAccountFollowedBy(ctx, targetAccountID, false)
|
||||
follows, err := p.state.DB.GetAccountFollowedBy(ctx, targetAccountID, false)
|
||||
if err != nil {
|
||||
if err == db.ErrNoEntries {
|
||||
return accounts, nil
|
||||
@@ -47,7 +47,7 @@ func (p *Processor) FollowersGet(ctx context.Context, requestingAccount *gtsmode
|
||||
}
|
||||
|
||||
for _, f := range follows {
|
||||
blocked, err := p.db.IsBlocked(ctx, requestingAccount.ID, f.AccountID, true)
|
||||
blocked, err := p.state.DB.IsBlocked(ctx, requestingAccount.ID, f.AccountID, true)
|
||||
if err != nil {
|
||||
return nil, gtserror.NewErrorInternalError(err)
|
||||
}
|
||||
@@ -56,7 +56,7 @@ func (p *Processor) FollowersGet(ctx context.Context, requestingAccount *gtsmode
|
||||
}
|
||||
|
||||
if f.Account == nil {
|
||||
a, err := p.db.GetAccountByID(ctx, f.AccountID)
|
||||
a, err := p.state.DB.GetAccountByID(ctx, f.AccountID)
|
||||
if err != nil {
|
||||
if err == db.ErrNoEntries {
|
||||
continue
|
||||
@@ -77,14 +77,14 @@ func (p *Processor) FollowersGet(ctx context.Context, requestingAccount *gtsmode
|
||||
|
||||
// FollowingGet fetches a list of the accounts that target account is following.
|
||||
func (p *Processor) FollowingGet(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) ([]apimodel.Account, gtserror.WithCode) {
|
||||
if blocked, err := p.db.IsBlocked(ctx, requestingAccount.ID, targetAccountID, true); err != nil {
|
||||
if blocked, err := p.state.DB.IsBlocked(ctx, requestingAccount.ID, targetAccountID, true); err != nil {
|
||||
return nil, gtserror.NewErrorInternalError(err)
|
||||
} else if blocked {
|
||||
return nil, gtserror.NewErrorNotFound(fmt.Errorf("block exists between accounts"))
|
||||
}
|
||||
|
||||
accounts := []apimodel.Account{}
|
||||
follows, err := p.db.GetAccountFollows(ctx, targetAccountID)
|
||||
follows, err := p.state.DB.GetAccountFollows(ctx, targetAccountID)
|
||||
if err != nil {
|
||||
if err == db.ErrNoEntries {
|
||||
return accounts, nil
|
||||
@@ -93,7 +93,7 @@ func (p *Processor) FollowingGet(ctx context.Context, requestingAccount *gtsmode
|
||||
}
|
||||
|
||||
for _, f := range follows {
|
||||
blocked, err := p.db.IsBlocked(ctx, requestingAccount.ID, f.AccountID, true)
|
||||
blocked, err := p.state.DB.IsBlocked(ctx, requestingAccount.ID, f.AccountID, true)
|
||||
if err != nil {
|
||||
return nil, gtserror.NewErrorInternalError(err)
|
||||
}
|
||||
@@ -102,7 +102,7 @@ func (p *Processor) FollowingGet(ctx context.Context, requestingAccount *gtsmode
|
||||
}
|
||||
|
||||
if f.TargetAccount == nil {
|
||||
a, err := p.db.GetAccountByID(ctx, f.TargetAccountID)
|
||||
a, err := p.state.DB.GetAccountByID(ctx, f.TargetAccountID)
|
||||
if err != nil {
|
||||
if err == db.ErrNoEntries {
|
||||
continue
|
||||
@@ -127,7 +127,7 @@ func (p *Processor) RelationshipGet(ctx context.Context, requestingAccount *gtsm
|
||||
return nil, gtserror.NewErrorForbidden(errors.New("not authed"))
|
||||
}
|
||||
|
||||
gtsR, err := p.db.GetRelationship(ctx, requestingAccount.ID, targetAccountID)
|
||||
gtsR, err := p.state.DB.GetRelationship(ctx, requestingAccount.ID, targetAccountID)
|
||||
if err != nil {
|
||||
return nil, gtserror.NewErrorInternalError(fmt.Errorf("error getting relationship: %s", err))
|
||||
}
|
||||
|
@@ -34,7 +34,7 @@ const rssFeedLength = 20
|
||||
|
||||
// GetRSSFeedForUsername returns RSS feed for the given local username.
|
||||
func (p *Processor) GetRSSFeedForUsername(ctx context.Context, username string) (func() (string, gtserror.WithCode), time.Time, gtserror.WithCode) {
|
||||
account, err := p.db.GetAccountByUsernameDomain(ctx, username, "")
|
||||
account, err := p.state.DB.GetAccountByUsernameDomain(ctx, username, "")
|
||||
if err != nil {
|
||||
if err == db.ErrNoEntries {
|
||||
return nil, time.Time{}, gtserror.NewErrorNotFound(errors.New("GetRSSFeedForUsername: account not found"))
|
||||
@@ -46,13 +46,13 @@ func (p *Processor) GetRSSFeedForUsername(ctx context.Context, username string)
|
||||
return nil, time.Time{}, gtserror.NewErrorNotFound(errors.New("GetRSSFeedForUsername: account RSS feed not enabled"))
|
||||
}
|
||||
|
||||
lastModified, err := p.db.GetAccountLastPosted(ctx, account.ID, true)
|
||||
lastModified, err := p.state.DB.GetAccountLastPosted(ctx, account.ID, true)
|
||||
if err != nil {
|
||||
return nil, time.Time{}, gtserror.NewErrorInternalError(fmt.Errorf("GetRSSFeedForUsername: db error: %s", err))
|
||||
}
|
||||
|
||||
return func() (string, gtserror.WithCode) {
|
||||
statuses, err := p.db.GetAccountWebStatuses(ctx, account.ID, rssFeedLength, "")
|
||||
statuses, err := p.state.DB.GetAccountWebStatuses(ctx, account.ID, rssFeedLength, "")
|
||||
if err != nil && err != db.ErrNoEntries {
|
||||
return "", gtserror.NewErrorInternalError(fmt.Errorf("GetRSSFeedForUsername: db error: %s", err))
|
||||
}
|
||||
@@ -65,7 +65,7 @@ func (p *Processor) GetRSSFeedForUsername(ctx context.Context, username string)
|
||||
var image *feeds.Image
|
||||
if account.AvatarMediaAttachmentID != "" {
|
||||
if account.AvatarMediaAttachment == nil {
|
||||
avatar, err := p.db.GetAttachmentByID(ctx, account.AvatarMediaAttachmentID)
|
||||
avatar, err := p.state.DB.GetAttachmentByID(ctx, account.AvatarMediaAttachmentID)
|
||||
if err != nil {
|
||||
return "", gtserror.NewErrorInternalError(fmt.Errorf("GetRSSFeedForUsername: db error fetching avatar attachment: %s", err))
|
||||
}
|
||||
|
@@ -33,7 +33,7 @@ import (
|
||||
// the account given in authed.
|
||||
func (p *Processor) StatusesGet(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string, limit int, excludeReplies bool, excludeReblogs bool, maxID string, minID string, pinned bool, mediaOnly bool, publicOnly bool) (*apimodel.PageableResponse, gtserror.WithCode) {
|
||||
if requestingAccount != nil {
|
||||
if blocked, err := p.db.IsBlocked(ctx, requestingAccount.ID, targetAccountID, true); err != nil {
|
||||
if blocked, err := p.state.DB.IsBlocked(ctx, requestingAccount.ID, targetAccountID, true); err != nil {
|
||||
return nil, gtserror.NewErrorInternalError(err)
|
||||
} else if blocked {
|
||||
return nil, gtserror.NewErrorNotFound(fmt.Errorf("block exists between accounts"))
|
||||
@@ -46,10 +46,10 @@ func (p *Processor) StatusesGet(ctx context.Context, requestingAccount *gtsmodel
|
||||
)
|
||||
if pinned {
|
||||
// Get *ONLY* pinned statuses.
|
||||
statuses, err = p.db.GetAccountPinnedStatuses(ctx, targetAccountID)
|
||||
statuses, err = p.state.DB.GetAccountPinnedStatuses(ctx, targetAccountID)
|
||||
} else {
|
||||
// Get account statuses which *may* include pinned ones.
|
||||
statuses, err = p.db.GetAccountStatuses(ctx, targetAccountID, limit, excludeReplies, excludeReblogs, maxID, minID, mediaOnly, publicOnly)
|
||||
statuses, err = p.state.DB.GetAccountStatuses(ctx, targetAccountID, limit, excludeReplies, excludeReblogs, maxID, minID, mediaOnly, publicOnly)
|
||||
}
|
||||
if err != nil {
|
||||
if err == db.ErrNoEntries {
|
||||
@@ -120,7 +120,7 @@ func (p *Processor) StatusesGet(ctx context.Context, requestingAccount *gtsmodel
|
||||
// WebStatusesGet fetches a number of statuses (in descending order) from the given account. It selects only
|
||||
// statuses which are suitable for showing on the public web profile of an account.
|
||||
func (p *Processor) WebStatusesGet(ctx context.Context, targetAccountID string, maxID string) (*apimodel.PageableResponse, gtserror.WithCode) {
|
||||
acct, err := p.db.GetAccountByID(ctx, targetAccountID)
|
||||
acct, err := p.state.DB.GetAccountByID(ctx, targetAccountID)
|
||||
if err != nil {
|
||||
if err == db.ErrNoEntries {
|
||||
err := fmt.Errorf("account %s not found in the db, not getting web statuses for it", targetAccountID)
|
||||
@@ -134,7 +134,7 @@ func (p *Processor) WebStatusesGet(ctx context.Context, targetAccountID string,
|
||||
return nil, gtserror.NewErrorNotFound(err)
|
||||
}
|
||||
|
||||
statuses, err := p.db.GetAccountWebStatuses(ctx, targetAccountID, 10, maxID)
|
||||
statuses, err := p.state.DB.GetAccountWebStatuses(ctx, targetAccountID, 10, maxID)
|
||||
if err != nil {
|
||||
if err == db.ErrNoEntries {
|
||||
return util.EmptyPageableResponse(), nil
|
||||
|
@@ -165,12 +165,12 @@ func (p *Processor) Update(ctx context.Context, account *gtsmodel.Account, form
|
||||
account.EnableRSS = form.EnableRSS
|
||||
}
|
||||
|
||||
err := p.db.UpdateAccount(ctx, account)
|
||||
err := p.state.DB.UpdateAccount(ctx, account)
|
||||
if err != nil {
|
||||
return nil, gtserror.NewErrorInternalError(fmt.Errorf("could not update account %s: %s", account.ID, err))
|
||||
}
|
||||
|
||||
p.clientWorker.Queue(messages.FromClientAPI{
|
||||
p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{
|
||||
APObjectType: ap.ObjectProfile,
|
||||
APActivityType: ap.ActivityUpdate,
|
||||
GTSModel: account,
|
||||
|
@@ -31,7 +31,7 @@ import (
|
||||
)
|
||||
|
||||
func (p *Processor) AccountAction(ctx context.Context, account *gtsmodel.Account, form *apimodel.AdminAccountActionRequest) gtserror.WithCode {
|
||||
targetAccount, err := p.db.GetAccountByID(ctx, form.TargetAccountID)
|
||||
targetAccount, err := p.state.DB.GetAccountByID(ctx, form.TargetAccountID)
|
||||
if err != nil {
|
||||
return gtserror.NewErrorInternalError(err)
|
||||
}
|
||||
@@ -47,7 +47,7 @@ func (p *Processor) AccountAction(ctx context.Context, account *gtsmodel.Account
|
||||
case string(gtsmodel.AdminActionSuspend):
|
||||
adminAction.Type = gtsmodel.AdminActionSuspend
|
||||
// pass the account delete through the client api channel for processing
|
||||
p.clientWorker.Queue(messages.FromClientAPI{
|
||||
p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{
|
||||
APObjectType: ap.ActorPerson,
|
||||
APActivityType: ap.ActivityDelete,
|
||||
OriginAccount: account,
|
||||
@@ -57,7 +57,7 @@ func (p *Processor) AccountAction(ctx context.Context, account *gtsmodel.Account
|
||||
return gtserror.NewErrorBadRequest(fmt.Errorf("admin action type %s is not supported for this endpoint", form.Type))
|
||||
}
|
||||
|
||||
if err := p.db.Put(ctx, adminAction); err != nil {
|
||||
if err := p.state.DB.Put(ctx, adminAction); err != nil {
|
||||
return gtserror.NewErrorInternalError(err)
|
||||
}
|
||||
|
||||
|
@@ -19,32 +19,25 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/media"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/messages"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/storage"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/state"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/transport"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
|
||||
)
|
||||
|
||||
type Processor struct {
|
||||
state *state.State
|
||||
tc typeutils.TypeConverter
|
||||
mediaManager media.Manager
|
||||
transportController transport.Controller
|
||||
storage *storage.Driver
|
||||
clientWorker *concurrency.WorkerPool[messages.FromClientAPI]
|
||||
db db.DB
|
||||
}
|
||||
|
||||
// New returns a new admin processor.
|
||||
func New(db db.DB, tc typeutils.TypeConverter, mediaManager media.Manager, transportController transport.Controller, storage *storage.Driver, clientWorker *concurrency.WorkerPool[messages.FromClientAPI]) Processor {
|
||||
func New(state *state.State, tc typeutils.TypeConverter, mediaManager media.Manager, transportController transport.Controller) Processor {
|
||||
return Processor{
|
||||
state: state,
|
||||
tc: tc,
|
||||
mediaManager: mediaManager,
|
||||
transportController: transportController,
|
||||
storage: storage,
|
||||
clientWorker: clientWorker,
|
||||
db: db,
|
||||
}
|
||||
}
|
||||
|
@@ -28,7 +28,7 @@ func (p *Processor) DomainBlockCreate(ctx context.Context, account *gtsmodel.Acc
|
||||
domain = strings.ToLower(domain)
|
||||
|
||||
// first check if we already have a block -- if err == nil we already had a block so we can skip a whole lot of work
|
||||
block, err := p.db.GetDomainBlock(ctx, domain)
|
||||
block, err := p.state.DB.GetDomainBlock(ctx, domain)
|
||||
if err != nil {
|
||||
if !errors.Is(err, db.ErrNoEntries) {
|
||||
// something went wrong in the DB
|
||||
@@ -47,7 +47,7 @@ func (p *Processor) DomainBlockCreate(ctx context.Context, account *gtsmodel.Acc
|
||||
}
|
||||
|
||||
// Insert the new block into the database
|
||||
if err := p.db.CreateDomainBlock(ctx, newBlock); err != nil {
|
||||
if err := p.state.DB.CreateDomainBlock(ctx, newBlock); err != nil {
|
||||
return nil, gtserror.NewErrorInternalError(fmt.Errorf("db error putting new domain block %s: %s", domain, err))
|
||||
}
|
||||
|
||||
@@ -80,7 +80,7 @@ func (p *Processor) initiateDomainBlockSideEffects(ctx context.Context, account
|
||||
|
||||
// if we have an instance entry for this domain, update it with the new block ID and clear all fields
|
||||
instance := >smodel.Instance{}
|
||||
if err := p.db.GetWhere(ctx, []db.Where{{Key: "domain", Value: block.Domain}}, instance); err == nil {
|
||||
if err := p.state.DB.GetWhere(ctx, []db.Where{{Key: "domain", Value: block.Domain}}, instance); err == nil {
|
||||
updatingColumns := []string{
|
||||
"title",
|
||||
"updated_at",
|
||||
@@ -105,15 +105,15 @@ func (p *Processor) initiateDomainBlockSideEffects(ctx context.Context, account
|
||||
instance.ContactAccountUsername = ""
|
||||
instance.ContactAccountID = ""
|
||||
instance.Version = ""
|
||||
if err := p.db.UpdateByID(ctx, instance, instance.ID, updatingColumns...); err != nil {
|
||||
if err := p.state.DB.UpdateByID(ctx, instance, instance.ID, updatingColumns...); err != nil {
|
||||
l.Errorf("domainBlockProcessSideEffects: db error updating instance: %s", err)
|
||||
}
|
||||
l.Debug("domainBlockProcessSideEffects: instance entry updated")
|
||||
}
|
||||
|
||||
// if we have an instance account for this instance, delete it
|
||||
if instanceAccount, err := p.db.GetAccountByUsernameDomain(ctx, block.Domain, block.Domain); err == nil {
|
||||
if err := p.db.DeleteAccount(ctx, instanceAccount.ID); err != nil {
|
||||
if instanceAccount, err := p.state.DB.GetAccountByUsernameDomain(ctx, block.Domain, block.Domain); err == nil {
|
||||
if err := p.state.DB.DeleteAccount(ctx, instanceAccount.ID); err != nil {
|
||||
l.Errorf("domainBlockProcessSideEffects: db error deleting instance account: %s", err)
|
||||
}
|
||||
}
|
||||
@@ -125,7 +125,7 @@ func (p *Processor) initiateDomainBlockSideEffects(ctx context.Context, account
|
||||
|
||||
selectAccountsLoop:
|
||||
for {
|
||||
accounts, err := p.db.GetInstanceAccounts(ctx, block.Domain, maxID, limit)
|
||||
accounts, err := p.state.DB.GetInstanceAccounts(ctx, block.Domain, maxID, limit)
|
||||
if err != nil {
|
||||
if err == db.ErrNoEntries {
|
||||
// no accounts left for this instance so we're done
|
||||
@@ -141,7 +141,7 @@ selectAccountsLoop:
|
||||
l.Debugf("putting delete for account %s in the clientAPI channel", a.Username)
|
||||
|
||||
// pass the account delete through the client api channel for processing
|
||||
p.clientWorker.Queue(messages.FromClientAPI{
|
||||
p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{
|
||||
APObjectType: ap.ActorPerson,
|
||||
APActivityType: ap.ActivityDelete,
|
||||
GTSModel: block,
|
||||
@@ -195,7 +195,7 @@ func (p *Processor) DomainBlocksImport(ctx context.Context, account *gtsmodel.Ac
|
||||
func (p *Processor) DomainBlocksGet(ctx context.Context, account *gtsmodel.Account, export bool) ([]*apimodel.DomainBlock, gtserror.WithCode) {
|
||||
domainBlocks := []*gtsmodel.DomainBlock{}
|
||||
|
||||
if err := p.db.GetAll(ctx, &domainBlocks); err != nil {
|
||||
if err := p.state.DB.GetAll(ctx, &domainBlocks); err != nil {
|
||||
if !errors.Is(err, db.ErrNoEntries) {
|
||||
// something has gone really wrong
|
||||
return nil, gtserror.NewErrorInternalError(err)
|
||||
@@ -219,7 +219,7 @@ func (p *Processor) DomainBlocksGet(ctx context.Context, account *gtsmodel.Accou
|
||||
func (p *Processor) DomainBlockGet(ctx context.Context, account *gtsmodel.Account, id string, export bool) (*apimodel.DomainBlock, gtserror.WithCode) {
|
||||
domainBlock := >smodel.DomainBlock{}
|
||||
|
||||
if err := p.db.GetByID(ctx, id, domainBlock); err != nil {
|
||||
if err := p.state.DB.GetByID(ctx, id, domainBlock); err != nil {
|
||||
if !errors.Is(err, db.ErrNoEntries) {
|
||||
// something has gone really wrong
|
||||
return nil, gtserror.NewErrorInternalError(err)
|
||||
@@ -240,7 +240,7 @@ func (p *Processor) DomainBlockGet(ctx context.Context, account *gtsmodel.Accoun
|
||||
func (p *Processor) DomainBlockDelete(ctx context.Context, account *gtsmodel.Account, id string) (*apimodel.DomainBlock, gtserror.WithCode) {
|
||||
domainBlock := >smodel.DomainBlock{}
|
||||
|
||||
if err := p.db.GetByID(ctx, id, domainBlock); err != nil {
|
||||
if err := p.state.DB.GetByID(ctx, id, domainBlock); err != nil {
|
||||
if !errors.Is(err, db.ErrNoEntries) {
|
||||
// something has gone really wrong
|
||||
return nil, gtserror.NewErrorInternalError(err)
|
||||
@@ -256,13 +256,13 @@ func (p *Processor) DomainBlockDelete(ctx context.Context, account *gtsmodel.Acc
|
||||
}
|
||||
|
||||
// Delete the domain block
|
||||
if err := p.db.DeleteDomainBlock(ctx, domainBlock.Domain); err != nil {
|
||||
if err := p.state.DB.DeleteDomainBlock(ctx, domainBlock.Domain); err != nil {
|
||||
return nil, gtserror.NewErrorInternalError(err)
|
||||
}
|
||||
|
||||
// remove the domain block reference from the instance, if we have an entry for it
|
||||
i := >smodel.Instance{}
|
||||
if err := p.db.GetWhere(ctx, []db.Where{
|
||||
if err := p.state.DB.GetWhere(ctx, []db.Where{
|
||||
{Key: "domain", Value: domainBlock.Domain},
|
||||
{Key: "domain_block_id", Value: id},
|
||||
}, i); err == nil {
|
||||
@@ -270,21 +270,21 @@ func (p *Processor) DomainBlockDelete(ctx context.Context, account *gtsmodel.Acc
|
||||
i.SuspendedAt = time.Time{}
|
||||
i.DomainBlockID = ""
|
||||
i.UpdatedAt = time.Now()
|
||||
if err := p.db.UpdateByID(ctx, i, i.ID, updatingColumns...); err != nil {
|
||||
if err := p.state.DB.UpdateByID(ctx, i, i.ID, updatingColumns...); err != nil {
|
||||
return nil, gtserror.NewErrorInternalError(fmt.Errorf("couldn't update database entry for instance %s: %s", domainBlock.Domain, err))
|
||||
}
|
||||
}
|
||||
|
||||
// unsuspend all accounts whose suspension origin was this domain block
|
||||
// 1. remove the 'suspended_at' entry from their accounts
|
||||
if err := p.db.UpdateWhere(ctx, []db.Where{
|
||||
if err := p.state.DB.UpdateWhere(ctx, []db.Where{
|
||||
{Key: "suspension_origin", Value: domainBlock.ID},
|
||||
}, "suspended_at", nil, &[]*gtsmodel.Account{}); err != nil {
|
||||
return nil, gtserror.NewErrorInternalError(fmt.Errorf("database error removing suspended_at from accounts: %s", err))
|
||||
}
|
||||
|
||||
// 2. remove the 'suspension_origin' entry from their accounts
|
||||
if err := p.db.UpdateWhere(ctx, []db.Where{
|
||||
if err := p.state.DB.UpdateWhere(ctx, []db.Where{
|
||||
{Key: "suspension_origin", Value: domainBlock.ID},
|
||||
}, "suspension_origin", nil, &[]*gtsmodel.Account{}); err != nil {
|
||||
return nil, gtserror.NewErrorInternalError(fmt.Errorf("database error removing suspension_origin from accounts: %s", err))
|
||||
|
@@ -42,7 +42,7 @@ func (p *Processor) EmojiCreate(ctx context.Context, account *gtsmodel.Account,
|
||||
return nil, gtserror.NewErrorUnauthorized(fmt.Errorf("user %s not an admin", user.ID), "user is not an admin")
|
||||
}
|
||||
|
||||
maybeExisting, err := p.db.GetEmojiByShortcodeDomain(ctx, form.Shortcode, "")
|
||||
maybeExisting, err := p.state.DB.GetEmojiByShortcodeDomain(ctx, form.Shortcode, "")
|
||||
if maybeExisting != nil {
|
||||
return nil, gtserror.NewErrorConflict(fmt.Errorf("emoji with shortcode %s already exists", form.Shortcode), fmt.Sprintf("emoji with shortcode %s already exists", form.Shortcode))
|
||||
}
|
||||
@@ -110,7 +110,7 @@ func (p *Processor) EmojisGet(
|
||||
return nil, gtserror.NewErrorUnauthorized(fmt.Errorf("user %s not an admin", user.ID), "user is not an admin")
|
||||
}
|
||||
|
||||
emojis, err := p.db.GetEmojis(ctx, domain, includeDisabled, includeEnabled, shortcode, maxShortcodeDomain, minShortcodeDomain, limit)
|
||||
emojis, err := p.state.DB.GetEmojis(ctx, domain, includeDisabled, includeEnabled, shortcode, maxShortcodeDomain, minShortcodeDomain, limit)
|
||||
if err != nil && !errors.Is(err, db.ErrNoEntries) {
|
||||
err := fmt.Errorf("EmojisGet: db error: %s", err)
|
||||
return nil, gtserror.NewErrorInternalError(err)
|
||||
@@ -176,7 +176,7 @@ func (p *Processor) EmojiGet(ctx context.Context, account *gtsmodel.Account, use
|
||||
return nil, gtserror.NewErrorUnauthorized(fmt.Errorf("user %s not an admin", user.ID), "user is not an admin")
|
||||
}
|
||||
|
||||
emoji, err := p.db.GetEmojiByID(ctx, id)
|
||||
emoji, err := p.state.DB.GetEmojiByID(ctx, id)
|
||||
if err != nil {
|
||||
if errors.Is(err, db.ErrNoEntries) {
|
||||
err = fmt.Errorf("EmojiGet: no emoji with id %s found in the db", id)
|
||||
@@ -197,7 +197,7 @@ func (p *Processor) EmojiGet(ctx context.Context, account *gtsmodel.Account, use
|
||||
|
||||
// EmojiDelete deletes one emoji from the database, with the given id.
|
||||
func (p *Processor) EmojiDelete(ctx context.Context, id string) (*apimodel.AdminEmoji, gtserror.WithCode) {
|
||||
emoji, err := p.db.GetEmojiByID(ctx, id)
|
||||
emoji, err := p.state.DB.GetEmojiByID(ctx, id)
|
||||
if err != nil {
|
||||
if errors.Is(err, db.ErrNoEntries) {
|
||||
err = fmt.Errorf("EmojiDelete: no emoji with id %s found in the db", id)
|
||||
@@ -218,7 +218,7 @@ func (p *Processor) EmojiDelete(ctx context.Context, id string) (*apimodel.Admin
|
||||
return nil, gtserror.NewErrorInternalError(err)
|
||||
}
|
||||
|
||||
if err := p.db.DeleteEmojiByID(ctx, id); err != nil {
|
||||
if err := p.state.DB.DeleteEmojiByID(ctx, id); err != nil {
|
||||
err := fmt.Errorf("EmojiDelete: db error: %s", err)
|
||||
return nil, gtserror.NewErrorInternalError(err)
|
||||
}
|
||||
@@ -228,7 +228,7 @@ func (p *Processor) EmojiDelete(ctx context.Context, id string) (*apimodel.Admin
|
||||
|
||||
// EmojiUpdate updates one emoji with the given id, using the provided form parameters.
|
||||
func (p *Processor) EmojiUpdate(ctx context.Context, id string, form *apimodel.EmojiUpdateRequest) (*apimodel.AdminEmoji, gtserror.WithCode) {
|
||||
emoji, err := p.db.GetEmojiByID(ctx, id)
|
||||
emoji, err := p.state.DB.GetEmojiByID(ctx, id)
|
||||
if err != nil {
|
||||
if errors.Is(err, db.ErrNoEntries) {
|
||||
err = fmt.Errorf("EmojiUpdate: no emoji with id %s found in the db", id)
|
||||
@@ -253,7 +253,7 @@ func (p *Processor) EmojiUpdate(ctx context.Context, id string, form *apimodel.E
|
||||
|
||||
// EmojiCategoriesGet returns all custom emoji categories that exist on this instance.
|
||||
func (p *Processor) EmojiCategoriesGet(ctx context.Context) ([]*apimodel.EmojiCategory, gtserror.WithCode) {
|
||||
categories, err := p.db.GetEmojiCategories(ctx)
|
||||
categories, err := p.state.DB.GetEmojiCategories(ctx)
|
||||
if err != nil {
|
||||
err := fmt.Errorf("EmojiCategoriesGet: db error: %s", err)
|
||||
return nil, gtserror.NewErrorInternalError(err)
|
||||
@@ -277,7 +277,7 @@ func (p *Processor) EmojiCategoriesGet(ctx context.Context) ([]*apimodel.EmojiCa
|
||||
*/
|
||||
|
||||
func (p *Processor) getOrCreateEmojiCategory(ctx context.Context, name string) (*gtsmodel.EmojiCategory, error) {
|
||||
category, err := p.db.GetEmojiCategoryByName(ctx, name)
|
||||
category, err := p.state.DB.GetEmojiCategoryByName(ctx, name)
|
||||
if err == nil {
|
||||
return category, nil
|
||||
}
|
||||
@@ -299,7 +299,7 @@ func (p *Processor) getOrCreateEmojiCategory(ctx context.Context, name string) (
|
||||
Name: name,
|
||||
}
|
||||
|
||||
if err := p.db.PutEmojiCategory(ctx, category); err != nil {
|
||||
if err := p.state.DB.PutEmojiCategory(ctx, category); err != nil {
|
||||
err = fmt.Errorf("GetOrCreateEmojiCategory: error putting new emoji category in the database: %s", err)
|
||||
return nil, err
|
||||
}
|
||||
@@ -319,7 +319,7 @@ func (p *Processor) emojiUpdateCopy(ctx context.Context, emoji *gtsmodel.Emoji,
|
||||
return nil, gtserror.NewErrorBadRequest(err, err.Error())
|
||||
}
|
||||
|
||||
maybeExisting, err := p.db.GetEmojiByShortcodeDomain(ctx, *shortcode, "")
|
||||
maybeExisting, err := p.state.DB.GetEmojiByShortcodeDomain(ctx, *shortcode, "")
|
||||
if maybeExisting != nil {
|
||||
err := fmt.Errorf("emojiUpdateCopy: emoji %s could not be copied, emoji with shortcode %s already exists on this instance", emoji.ID, *shortcode)
|
||||
return nil, gtserror.NewErrorConflict(err, err.Error())
|
||||
@@ -339,7 +339,7 @@ func (p *Processor) emojiUpdateCopy(ctx context.Context, emoji *gtsmodel.Emoji,
|
||||
newEmojiURI := uris.GenerateURIForEmoji(newEmojiID)
|
||||
|
||||
data := func(ctx context.Context) (reader io.ReadCloser, fileSize int64, err error) {
|
||||
rc, err := p.storage.GetStream(ctx, emoji.ImagePath)
|
||||
rc, err := p.state.Storage.GetStream(ctx, emoji.ImagePath)
|
||||
return rc, int64(emoji.ImageFileSize), err
|
||||
}
|
||||
|
||||
@@ -386,7 +386,7 @@ func (p *Processor) emojiUpdateDisable(ctx context.Context, emoji *gtsmodel.Emoj
|
||||
|
||||
emojiDisabled := true
|
||||
emoji.Disabled = &emojiDisabled
|
||||
updatedEmoji, err := p.db.UpdateEmoji(ctx, emoji, "updated_at", "disabled")
|
||||
updatedEmoji, err := p.state.DB.UpdateEmoji(ctx, emoji, "updated_at", "disabled")
|
||||
if err != nil {
|
||||
err = fmt.Errorf("emojiUpdateDisable: error updating emoji %s: %s", emoji.ID, err)
|
||||
return nil, gtserror.NewErrorInternalError(err)
|
||||
@@ -443,7 +443,7 @@ func (p *Processor) emojiUpdateModify(ctx context.Context, emoji *gtsmodel.Emoji
|
||||
}
|
||||
|
||||
var err error
|
||||
updatedEmoji, err = p.db.UpdateEmoji(ctx, emoji, columns...)
|
||||
updatedEmoji, err = p.state.DB.UpdateEmoji(ctx, emoji, columns...)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("emojiUpdateModify: error updating emoji %s: %s", emoji.ID, err)
|
||||
return nil, gtserror.NewErrorInternalError(err)
|
||||
|
@@ -43,7 +43,7 @@ func (p *Processor) ReportsGet(
|
||||
minID string,
|
||||
limit int,
|
||||
) (*apimodel.PageableResponse, gtserror.WithCode) {
|
||||
reports, err := p.db.GetReports(ctx, resolved, accountID, targetAccountID, maxID, sinceID, minID, limit)
|
||||
reports, err := p.state.DB.GetReports(ctx, resolved, accountID, targetAccountID, maxID, sinceID, minID, limit)
|
||||
if err != nil {
|
||||
if err == db.ErrNoEntries {
|
||||
return util.EmptyPageableResponse(), nil
|
||||
@@ -95,7 +95,7 @@ func (p *Processor) ReportsGet(
|
||||
|
||||
// ReportGet returns one report, with the given ID.
|
||||
func (p *Processor) ReportGet(ctx context.Context, account *gtsmodel.Account, id string) (*apimodel.AdminReport, gtserror.WithCode) {
|
||||
report, err := p.db.GetReportByID(ctx, id)
|
||||
report, err := p.state.DB.GetReportByID(ctx, id)
|
||||
if err != nil {
|
||||
if err == db.ErrNoEntries {
|
||||
return nil, gtserror.NewErrorNotFound(err)
|
||||
@@ -113,7 +113,7 @@ func (p *Processor) ReportGet(ctx context.Context, account *gtsmodel.Account, id
|
||||
|
||||
// ReportResolve marks a report with the given id as resolved, and stores the provided actionTakenComment (if not null).
|
||||
func (p *Processor) ReportResolve(ctx context.Context, account *gtsmodel.Account, id string, actionTakenComment *string) (*apimodel.AdminReport, gtserror.WithCode) {
|
||||
report, err := p.db.GetReportByID(ctx, id)
|
||||
report, err := p.state.DB.GetReportByID(ctx, id)
|
||||
if err != nil {
|
||||
if err == db.ErrNoEntries {
|
||||
return nil, gtserror.NewErrorNotFound(err)
|
||||
@@ -134,7 +134,7 @@ func (p *Processor) ReportResolve(ctx context.Context, account *gtsmodel.Account
|
||||
columns = append(columns, "action_taken")
|
||||
}
|
||||
|
||||
updatedReport, err := p.db.UpdateReport(ctx, report, columns...)
|
||||
updatedReport, err := p.state.DB.UpdateReport(ctx, report, columns...)
|
||||
if err != nil {
|
||||
return nil, gtserror.NewErrorInternalError(err)
|
||||
}
|
||||
|
@@ -62,7 +62,7 @@ func (p *Processor) AppCreate(ctx context.Context, authed *oauth.Auth, form *api
|
||||
}
|
||||
|
||||
// chuck it in the db
|
||||
if err := p.db.Put(ctx, app); err != nil {
|
||||
if err := p.state.DB.Put(ctx, app); err != nil {
|
||||
return nil, gtserror.NewErrorInternalError(err)
|
||||
}
|
||||
|
||||
@@ -76,7 +76,7 @@ func (p *Processor) AppCreate(ctx context.Context, authed *oauth.Auth, form *api
|
||||
}
|
||||
|
||||
// chuck it in the db
|
||||
if err := p.db.Put(ctx, oc); err != nil {
|
||||
if err := p.state.DB.Put(ctx, oc); err != nil {
|
||||
return nil, gtserror.NewErrorInternalError(err)
|
||||
}
|
||||
|
||||
|
@@ -31,7 +31,7 @@ import (
|
||||
)
|
||||
|
||||
func (p *Processor) BlocksGet(ctx context.Context, authed *oauth.Auth, maxID string, sinceID string, limit int) (*apimodel.BlocksResponse, gtserror.WithCode) {
|
||||
accounts, nextMaxID, prevMinID, err := p.db.GetAccountBlocks(ctx, authed.Account.ID, maxID, sinceID, limit)
|
||||
accounts, nextMaxID, prevMinID, err := p.state.DB.GetAccountBlocks(ctx, authed.Account.ID, maxID, sinceID, limit)
|
||||
if err != nil {
|
||||
if err == db.ErrNoEntries {
|
||||
// there are just no entries
|
||||
|
@@ -84,8 +84,8 @@ func (p *Processor) OutboxGet(ctx context.Context, requestedUsername string, pag
|
||||
|
||||
// scenario 2 -- get the requested page
|
||||
// limit pages to 30 entries per page
|
||||
publicStatuses, err := p.db.GetAccountStatuses(ctx, requestedAccount.ID, 30, true, true, maxID, minID, false, true)
|
||||
if err != nil && err != db.ErrNoEntries {
|
||||
publicStatuses, err := p.state.DB.GetAccountStatuses(ctx, requestedAccount.ID, 30, true, true, maxID, minID, false, true)
|
||||
if err != nil && !errors.Is(err, db.ErrNoEntries) {
|
||||
return nil, gtserror.NewErrorInternalError(err)
|
||||
}
|
||||
|
||||
@@ -161,7 +161,7 @@ func (p *Processor) FeaturedCollectionGet(ctx context.Context, requestedUsername
|
||||
return nil, errWithCode
|
||||
}
|
||||
|
||||
statuses, err := p.db.GetAccountPinnedStatuses(ctx, requestedAccount.ID)
|
||||
statuses, err := p.state.DB.GetAccountPinnedStatuses(ctx, requestedAccount.ID)
|
||||
if err != nil {
|
||||
if !errors.Is(err, db.ErrNoEntries) {
|
||||
return nil, gtserror.NewErrorInternalError(err)
|
||||
|
@@ -29,7 +29,7 @@ import (
|
||||
)
|
||||
|
||||
func (p *Processor) authenticate(ctx context.Context, requestedUsername string) (requestedAccount, requestingAccount *gtsmodel.Account, errWithCode gtserror.WithCode) {
|
||||
requestedAccount, err := p.db.GetAccountByUsernameDomain(ctx, requestedUsername, "")
|
||||
requestedAccount, err := p.state.DB.GetAccountByUsernameDomain(ctx, requestedUsername, "")
|
||||
if err != nil {
|
||||
errWithCode = gtserror.NewErrorNotFound(fmt.Errorf("database error getting account with username %s: %s", requestedUsername, err))
|
||||
return
|
||||
@@ -46,7 +46,7 @@ func (p *Processor) authenticate(ctx context.Context, requestedUsername string)
|
||||
return
|
||||
}
|
||||
|
||||
blocked, err := p.db.IsBlocked(ctx, requestedAccount.ID, requestingAccount.ID, true)
|
||||
blocked, err := p.state.DB.IsBlocked(ctx, requestedAccount.ID, requestingAccount.ID, true)
|
||||
if err != nil {
|
||||
errWithCode = gtserror.NewErrorInternalError(err)
|
||||
return
|
||||
|
@@ -32,7 +32,7 @@ func (p *Processor) EmojiGet(ctx context.Context, requestedEmojiID string) (inte
|
||||
return nil, errWithCode
|
||||
}
|
||||
|
||||
requestedEmoji, err := p.db.GetEmojiByID(ctx, requestedEmojiID)
|
||||
requestedEmoji, err := p.state.DB.GetEmojiByID(ctx, requestedEmojiID)
|
||||
if err != nil {
|
||||
return nil, gtserror.NewErrorNotFound(fmt.Errorf("database error getting emoji with id %s: %s", requestedEmojiID, err))
|
||||
}
|
||||
|
@@ -19,25 +19,25 @@
|
||||
package fedi
|
||||
|
||||
import (
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/federation"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/state"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/visibility"
|
||||
)
|
||||
|
||||
type Processor struct {
|
||||
db db.DB
|
||||
state *state.State
|
||||
federator federation.Federator
|
||||
tc typeutils.TypeConverter
|
||||
filter visibility.Filter
|
||||
}
|
||||
|
||||
// New returns a new fedi processor.
|
||||
func New(db db.DB, tc typeutils.TypeConverter, federator federation.Federator) Processor {
|
||||
func New(state *state.State, tc typeutils.TypeConverter, federator federation.Federator) Processor {
|
||||
return Processor{
|
||||
db: db,
|
||||
state: state,
|
||||
federator: federator,
|
||||
tc: tc,
|
||||
filter: visibility.NewFilter(db),
|
||||
filter: visibility.NewFilter(state.DB),
|
||||
}
|
||||
}
|
||||
|
@@ -36,7 +36,7 @@ func (p *Processor) StatusGet(ctx context.Context, requestedUsername string, req
|
||||
return nil, errWithCode
|
||||
}
|
||||
|
||||
status, err := p.db.GetStatusByID(ctx, requestedStatusID)
|
||||
status, err := p.state.DB.GetStatusByID(ctx, requestedStatusID)
|
||||
if err != nil {
|
||||
return nil, gtserror.NewErrorNotFound(err)
|
||||
}
|
||||
@@ -74,7 +74,7 @@ func (p *Processor) StatusRepliesGet(ctx context.Context, requestedUsername stri
|
||||
return nil, errWithCode
|
||||
}
|
||||
|
||||
status, err := p.db.GetStatusByID(ctx, requestedStatusID)
|
||||
status, err := p.state.DB.GetStatusByID(ctx, requestedStatusID)
|
||||
if err != nil {
|
||||
return nil, gtserror.NewErrorNotFound(err)
|
||||
}
|
||||
@@ -125,7 +125,7 @@ func (p *Processor) StatusRepliesGet(ctx context.Context, requestedUsername stri
|
||||
default:
|
||||
// scenario 3
|
||||
// get immediate children
|
||||
replies, err := p.db.GetStatusChildren(ctx, status, true, minID)
|
||||
replies, err := p.state.DB.GetStatusChildren(ctx, status, true, minID)
|
||||
if err != nil {
|
||||
return nil, gtserror.NewErrorInternalError(err)
|
||||
}
|
||||
|
@@ -34,7 +34,7 @@ import (
|
||||
// before returning a JSON serializable interface to the caller.
|
||||
func (p *Processor) UserGet(ctx context.Context, requestedUsername string, requestURL *url.URL) (interface{}, gtserror.WithCode) {
|
||||
// Get the instance-local account the request is referring to.
|
||||
requestedAccount, err := p.db.GetAccountByUsernameDomain(ctx, requestedUsername, "")
|
||||
requestedAccount, err := p.state.DB.GetAccountByUsernameDomain(ctx, requestedUsername, "")
|
||||
if err != nil {
|
||||
return nil, gtserror.NewErrorNotFound(fmt.Errorf("database error getting account with username %s: %s", requestedUsername, err))
|
||||
}
|
||||
@@ -63,7 +63,7 @@ func (p *Processor) UserGet(ctx context.Context, requestedUsername string, reque
|
||||
return nil, gtserror.NewErrorUnauthorized(err)
|
||||
}
|
||||
|
||||
blocked, err := p.db.IsBlocked(ctx, requestedAccount.ID, requestingAccount.ID, true)
|
||||
blocked, err := p.state.DB.IsBlocked(ctx, requestedAccount.ID, requestingAccount.ID, true)
|
||||
if err != nil {
|
||||
return nil, gtserror.NewErrorInternalError(err)
|
||||
}
|
||||
|
@@ -64,12 +64,12 @@ func (p *Processor) NodeInfoRelGet(ctx context.Context) (*apimodel.WellKnownResp
|
||||
func (p *Processor) NodeInfoGet(ctx context.Context) (*apimodel.Nodeinfo, gtserror.WithCode) {
|
||||
host := config.GetHost()
|
||||
|
||||
userCount, err := p.db.CountInstanceUsers(ctx, host)
|
||||
userCount, err := p.state.DB.CountInstanceUsers(ctx, host)
|
||||
if err != nil {
|
||||
return nil, gtserror.NewErrorInternalError(err)
|
||||
}
|
||||
|
||||
postCount, err := p.db.CountInstanceStatuses(ctx, host)
|
||||
postCount, err := p.state.DB.CountInstanceStatuses(ctx, host)
|
||||
if err != nil {
|
||||
return nil, gtserror.NewErrorInternalError(err)
|
||||
}
|
||||
@@ -99,7 +99,7 @@ func (p *Processor) NodeInfoGet(ctx context.Context) (*apimodel.Nodeinfo, gtserr
|
||||
// WebfingerGet handles the GET for a webfinger resource. Most commonly, it will be used for returning account lookups.
|
||||
func (p *Processor) WebfingerGet(ctx context.Context, requestedUsername string) (*apimodel.WellKnownResponse, gtserror.WithCode) {
|
||||
// Get the local account the request is referring to.
|
||||
requestedAccount, err := p.db.GetAccountByUsernameDomain(ctx, requestedUsername, "")
|
||||
requestedAccount, err := p.state.DB.GetAccountByUsernameDomain(ctx, requestedUsername, "")
|
||||
if err != nil {
|
||||
return nil, gtserror.NewErrorNotFound(fmt.Errorf("database error getting account with username %s: %s", requestedUsername, err))
|
||||
}
|
||||
|
@@ -30,7 +30,7 @@ import (
|
||||
)
|
||||
|
||||
func (p *Processor) FollowRequestsGet(ctx context.Context, auth *oauth.Auth) ([]apimodel.Account, gtserror.WithCode) {
|
||||
frs, err := p.db.GetAccountFollowRequests(ctx, auth.Account.ID)
|
||||
frs, err := p.state.DB.GetAccountFollowRequests(ctx, auth.Account.ID)
|
||||
if err != nil {
|
||||
if err != db.ErrNoEntries {
|
||||
return nil, gtserror.NewErrorInternalError(err)
|
||||
@@ -40,7 +40,7 @@ func (p *Processor) FollowRequestsGet(ctx context.Context, auth *oauth.Auth) ([]
|
||||
accts := []apimodel.Account{}
|
||||
for _, fr := range frs {
|
||||
if fr.Account == nil {
|
||||
frAcct, err := p.db.GetAccountByID(ctx, fr.AccountID)
|
||||
frAcct, err := p.state.DB.GetAccountByID(ctx, fr.AccountID)
|
||||
if err != nil {
|
||||
return nil, gtserror.NewErrorInternalError(err)
|
||||
}
|
||||
@@ -57,13 +57,13 @@ func (p *Processor) FollowRequestsGet(ctx context.Context, auth *oauth.Auth) ([]
|
||||
}
|
||||
|
||||
func (p *Processor) FollowRequestAccept(ctx context.Context, auth *oauth.Auth, accountID string) (*apimodel.Relationship, gtserror.WithCode) {
|
||||
follow, err := p.db.AcceptFollowRequest(ctx, accountID, auth.Account.ID)
|
||||
follow, err := p.state.DB.AcceptFollowRequest(ctx, accountID, auth.Account.ID)
|
||||
if err != nil {
|
||||
return nil, gtserror.NewErrorNotFound(err)
|
||||
}
|
||||
|
||||
if follow.Account == nil {
|
||||
followAccount, err := p.db.GetAccountByID(ctx, follow.AccountID)
|
||||
followAccount, err := p.state.DB.GetAccountByID(ctx, follow.AccountID)
|
||||
if err != nil {
|
||||
return nil, gtserror.NewErrorInternalError(err)
|
||||
}
|
||||
@@ -71,14 +71,14 @@ func (p *Processor) FollowRequestAccept(ctx context.Context, auth *oauth.Auth, a
|
||||
}
|
||||
|
||||
if follow.TargetAccount == nil {
|
||||
followTargetAccount, err := p.db.GetAccountByID(ctx, follow.TargetAccountID)
|
||||
followTargetAccount, err := p.state.DB.GetAccountByID(ctx, follow.TargetAccountID)
|
||||
if err != nil {
|
||||
return nil, gtserror.NewErrorInternalError(err)
|
||||
}
|
||||
follow.TargetAccount = followTargetAccount
|
||||
}
|
||||
|
||||
p.clientWorker.Queue(messages.FromClientAPI{
|
||||
p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{
|
||||
APObjectType: ap.ActivityFollow,
|
||||
APActivityType: ap.ActivityAccept,
|
||||
GTSModel: follow,
|
||||
@@ -86,7 +86,7 @@ func (p *Processor) FollowRequestAccept(ctx context.Context, auth *oauth.Auth, a
|
||||
TargetAccount: follow.TargetAccount,
|
||||
})
|
||||
|
||||
gtsR, err := p.db.GetRelationship(ctx, auth.Account.ID, accountID)
|
||||
gtsR, err := p.state.DB.GetRelationship(ctx, auth.Account.ID, accountID)
|
||||
if err != nil {
|
||||
return nil, gtserror.NewErrorInternalError(err)
|
||||
}
|
||||
@@ -100,13 +100,13 @@ func (p *Processor) FollowRequestAccept(ctx context.Context, auth *oauth.Auth, a
|
||||
}
|
||||
|
||||
func (p *Processor) FollowRequestReject(ctx context.Context, auth *oauth.Auth, accountID string) (*apimodel.Relationship, gtserror.WithCode) {
|
||||
followRequest, err := p.db.RejectFollowRequest(ctx, accountID, auth.Account.ID)
|
||||
followRequest, err := p.state.DB.RejectFollowRequest(ctx, accountID, auth.Account.ID)
|
||||
if err != nil {
|
||||
return nil, gtserror.NewErrorNotFound(err)
|
||||
}
|
||||
|
||||
if followRequest.Account == nil {
|
||||
a, err := p.db.GetAccountByID(ctx, followRequest.AccountID)
|
||||
a, err := p.state.DB.GetAccountByID(ctx, followRequest.AccountID)
|
||||
if err != nil {
|
||||
return nil, gtserror.NewErrorInternalError(err)
|
||||
}
|
||||
@@ -114,14 +114,14 @@ func (p *Processor) FollowRequestReject(ctx context.Context, auth *oauth.Auth, a
|
||||
}
|
||||
|
||||
if followRequest.TargetAccount == nil {
|
||||
a, err := p.db.GetAccountByID(ctx, followRequest.TargetAccountID)
|
||||
a, err := p.state.DB.GetAccountByID(ctx, followRequest.TargetAccountID)
|
||||
if err != nil {
|
||||
return nil, gtserror.NewErrorInternalError(err)
|
||||
}
|
||||
followRequest.TargetAccount = a
|
||||
}
|
||||
|
||||
p.clientWorker.Queue(messages.FromClientAPI{
|
||||
p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{
|
||||
APObjectType: ap.ActivityFollow,
|
||||
APActivityType: ap.ActivityReject,
|
||||
GTSModel: followRequest,
|
||||
@@ -129,7 +129,7 @@ func (p *Processor) FollowRequestReject(ctx context.Context, auth *oauth.Auth, a
|
||||
TargetAccount: followRequest.TargetAccount,
|
||||
})
|
||||
|
||||
gtsR, err := p.db.GetRelationship(ctx, auth.Account.ID, accountID)
|
||||
gtsR, err := p.state.DB.GetRelationship(ctx, auth.Account.ID, accountID)
|
||||
if err != nil {
|
||||
return nil, gtserror.NewErrorInternalError(err)
|
||||
}
|
||||
|
@@ -143,7 +143,7 @@ func (p *Processor) processCreateAccountFromClientAPI(ctx context.Context, clien
|
||||
}
|
||||
|
||||
// get the user this account belongs to
|
||||
user, err := p.db.GetUserByAccountID(ctx, account.ID)
|
||||
user, err := p.state.DB.GetUserByAccountID(ctx, account.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -293,7 +293,7 @@ func (p *Processor) processUndoAnnounceFromClientAPI(ctx context.Context, client
|
||||
return errors.New("undo was not parseable as *gtsmodel.Status")
|
||||
}
|
||||
|
||||
if err := p.db.DeleteStatusByID(ctx, boost.ID); err != nil {
|
||||
if err := p.state.DB.DeleteStatusByID(ctx, boost.ID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -422,7 +422,7 @@ func (p *Processor) federateStatus(ctx context.Context, status *gtsmodel.Status)
|
||||
}
|
||||
|
||||
if status.Account == nil {
|
||||
statusAccount, err := p.db.GetAccountByID(ctx, status.AccountID)
|
||||
statusAccount, err := p.state.DB.GetAccountByID(ctx, status.AccountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("federateStatus: error fetching status author account: %s", err)
|
||||
}
|
||||
@@ -455,7 +455,7 @@ func (p *Processor) federateStatus(ctx context.Context, status *gtsmodel.Status)
|
||||
|
||||
func (p *Processor) federateStatusDelete(ctx context.Context, status *gtsmodel.Status) error {
|
||||
if status.Account == nil {
|
||||
statusAccount, err := p.db.GetAccountByID(ctx, status.AccountID)
|
||||
statusAccount, err := p.state.DB.GetAccountByID(ctx, status.AccountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("federateStatusDelete: error fetching status author account: %s", err)
|
||||
}
|
||||
@@ -642,7 +642,7 @@ func (p *Processor) federateUnannounce(ctx context.Context, boost *gtsmodel.Stat
|
||||
|
||||
func (p *Processor) federateAcceptFollowRequest(ctx context.Context, follow *gtsmodel.Follow) error {
|
||||
if follow.Account == nil {
|
||||
a, err := p.db.GetAccountByID(ctx, follow.AccountID)
|
||||
a, err := p.state.DB.GetAccountByID(ctx, follow.AccountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -651,7 +651,7 @@ func (p *Processor) federateAcceptFollowRequest(ctx context.Context, follow *gts
|
||||
originAccount := follow.Account
|
||||
|
||||
if follow.TargetAccount == nil {
|
||||
a, err := p.db.GetAccountByID(ctx, follow.TargetAccountID)
|
||||
a, err := p.state.DB.GetAccountByID(ctx, follow.TargetAccountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -715,7 +715,7 @@ func (p *Processor) federateAcceptFollowRequest(ctx context.Context, follow *gts
|
||||
|
||||
func (p *Processor) federateRejectFollowRequest(ctx context.Context, followRequest *gtsmodel.FollowRequest) error {
|
||||
if followRequest.Account == nil {
|
||||
a, err := p.db.GetAccountByID(ctx, followRequest.AccountID)
|
||||
a, err := p.state.DB.GetAccountByID(ctx, followRequest.AccountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -724,7 +724,7 @@ func (p *Processor) federateRejectFollowRequest(ctx context.Context, followReque
|
||||
originAccount := followRequest.Account
|
||||
|
||||
if followRequest.TargetAccount == nil {
|
||||
a, err := p.db.GetAccountByID(ctx, followRequest.TargetAccountID)
|
||||
a, err := p.state.DB.GetAccountByID(ctx, followRequest.TargetAccountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -844,7 +844,7 @@ func (p *Processor) federateAccountUpdate(ctx context.Context, updatedAccount *g
|
||||
|
||||
func (p *Processor) federateBlock(ctx context.Context, block *gtsmodel.Block) error {
|
||||
if block.Account == nil {
|
||||
blockAccount, err := p.db.GetAccountByID(ctx, block.AccountID)
|
||||
blockAccount, err := p.state.DB.GetAccountByID(ctx, block.AccountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("federateBlock: error getting block account from database: %s", err)
|
||||
}
|
||||
@@ -852,7 +852,7 @@ func (p *Processor) federateBlock(ctx context.Context, block *gtsmodel.Block) er
|
||||
}
|
||||
|
||||
if block.TargetAccount == nil {
|
||||
blockTargetAccount, err := p.db.GetAccountByID(ctx, block.TargetAccountID)
|
||||
blockTargetAccount, err := p.state.DB.GetAccountByID(ctx, block.TargetAccountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("federateBlock: error getting block target account from database: %s", err)
|
||||
}
|
||||
@@ -880,7 +880,7 @@ func (p *Processor) federateBlock(ctx context.Context, block *gtsmodel.Block) er
|
||||
|
||||
func (p *Processor) federateUnblock(ctx context.Context, block *gtsmodel.Block) error {
|
||||
if block.Account == nil {
|
||||
blockAccount, err := p.db.GetAccountByID(ctx, block.AccountID)
|
||||
blockAccount, err := p.state.DB.GetAccountByID(ctx, block.AccountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("federateUnblock: error getting block account from database: %s", err)
|
||||
}
|
||||
@@ -888,7 +888,7 @@ func (p *Processor) federateUnblock(ctx context.Context, block *gtsmodel.Block)
|
||||
}
|
||||
|
||||
if block.TargetAccount == nil {
|
||||
blockTargetAccount, err := p.db.GetAccountByID(ctx, block.TargetAccountID)
|
||||
blockTargetAccount, err := p.state.DB.GetAccountByID(ctx, block.TargetAccountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("federateUnblock: error getting block target account from database: %s", err)
|
||||
}
|
||||
@@ -934,7 +934,7 @@ func (p *Processor) federateUnblock(ctx context.Context, block *gtsmodel.Block)
|
||||
|
||||
func (p *Processor) federateReport(ctx context.Context, report *gtsmodel.Report) error {
|
||||
if report.TargetAccount == nil {
|
||||
reportTargetAccount, err := p.db.GetAccountByID(ctx, report.TargetAccountID)
|
||||
reportTargetAccount, err := p.state.DB.GetAccountByID(ctx, report.TargetAccountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("federateReport: error getting report target account from database: %w", err)
|
||||
}
|
||||
@@ -942,7 +942,7 @@ func (p *Processor) federateReport(ctx context.Context, report *gtsmodel.Report)
|
||||
}
|
||||
|
||||
if len(report.StatusIDs) > 0 && len(report.Statuses) == 0 {
|
||||
statuses, err := p.db.GetStatuses(ctx, report.StatusIDs)
|
||||
statuses, err := p.state.DB.GetStatuses(ctx, report.StatusIDs)
|
||||
if err != nil {
|
||||
return fmt.Errorf("federateReport: error getting report statuses from database: %w", err)
|
||||
}
|
||||
@@ -966,7 +966,7 @@ func (p *Processor) federateReport(ctx context.Context, report *gtsmodel.Report)
|
||||
|
||||
// deliver the flag using the outbox of the
|
||||
// instance account to anonymize the report
|
||||
instanceAccount, err := p.db.GetInstanceAccount(ctx, "")
|
||||
instanceAccount, err := p.state.DB.GetInstanceAccount(ctx, "")
|
||||
if err != nil {
|
||||
return fmt.Errorf("federateReport: error getting instance account: %w", err)
|
||||
}
|
||||
|
@@ -38,7 +38,7 @@ func (p *Processor) notifyStatus(ctx context.Context, status *gtsmodel.Status) e
|
||||
|
||||
if status.Mentions == nil {
|
||||
// there are mentions but they're not fully populated on the status yet so do this
|
||||
menchies, err := p.db.GetMentions(ctx, status.MentionIDs)
|
||||
menchies, err := p.state.DB.GetMentions(ctx, status.MentionIDs)
|
||||
if err != nil {
|
||||
return fmt.Errorf("notifyStatus: error getting mentions for status %s from the db: %s", status.ID, err)
|
||||
}
|
||||
@@ -49,7 +49,7 @@ func (p *Processor) notifyStatus(ctx context.Context, status *gtsmodel.Status) e
|
||||
for _, m := range status.Mentions {
|
||||
// make sure this is a local account, otherwise we don't need to create a notification for it
|
||||
if m.TargetAccount == nil {
|
||||
a, err := p.db.GetAccountByID(ctx, m.TargetAccountID)
|
||||
a, err := p.state.DB.GetAccountByID(ctx, m.TargetAccountID)
|
||||
if err != nil {
|
||||
// we don't have the account or there's been an error
|
||||
return fmt.Errorf("notifyStatus: error getting account with id %s from the db: %s", m.TargetAccountID, err)
|
||||
@@ -62,7 +62,7 @@ func (p *Processor) notifyStatus(ctx context.Context, status *gtsmodel.Status) e
|
||||
}
|
||||
|
||||
// make sure a notif doesn't already exist for this mention
|
||||
if err := p.db.GetWhere(ctx, []db.Where{
|
||||
if err := p.state.DB.GetWhere(ctx, []db.Where{
|
||||
{Key: "notification_type", Value: gtsmodel.NotificationMention},
|
||||
{Key: "target_account_id", Value: m.TargetAccountID},
|
||||
{Key: "origin_account_id", Value: m.OriginAccountID},
|
||||
@@ -87,7 +87,7 @@ func (p *Processor) notifyStatus(ctx context.Context, status *gtsmodel.Status) e
|
||||
Status: status,
|
||||
}
|
||||
|
||||
if err := p.db.Put(ctx, notif); err != nil {
|
||||
if err := p.state.DB.Put(ctx, notif); err != nil {
|
||||
return fmt.Errorf("notifyStatus: error putting notification in database: %s", err)
|
||||
}
|
||||
|
||||
@@ -108,7 +108,7 @@ func (p *Processor) notifyStatus(ctx context.Context, status *gtsmodel.Status) e
|
||||
func (p *Processor) notifyFollowRequest(ctx context.Context, followRequest *gtsmodel.FollowRequest) error {
|
||||
// make sure we have the target account pinned on the follow request
|
||||
if followRequest.TargetAccount == nil {
|
||||
a, err := p.db.GetAccountByID(ctx, followRequest.TargetAccountID)
|
||||
a, err := p.state.DB.GetAccountByID(ctx, followRequest.TargetAccountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -129,7 +129,7 @@ func (p *Processor) notifyFollowRequest(ctx context.Context, followRequest *gtsm
|
||||
OriginAccountID: followRequest.AccountID,
|
||||
}
|
||||
|
||||
if err := p.db.Put(ctx, notif); err != nil {
|
||||
if err := p.state.DB.Put(ctx, notif); err != nil {
|
||||
return fmt.Errorf("notifyFollowRequest: error putting notification in database: %s", err)
|
||||
}
|
||||
|
||||
@@ -153,7 +153,7 @@ func (p *Processor) notifyFollow(ctx context.Context, follow *gtsmodel.Follow, t
|
||||
}
|
||||
|
||||
// first remove the follow request notification
|
||||
if err := p.db.DeleteWhere(ctx, []db.Where{
|
||||
if err := p.state.DB.DeleteWhere(ctx, []db.Where{
|
||||
{Key: "notification_type", Value: gtsmodel.NotificationFollowRequest},
|
||||
{Key: "target_account_id", Value: follow.TargetAccountID},
|
||||
{Key: "origin_account_id", Value: follow.AccountID},
|
||||
@@ -170,7 +170,7 @@ func (p *Processor) notifyFollow(ctx context.Context, follow *gtsmodel.Follow, t
|
||||
OriginAccountID: follow.AccountID,
|
||||
OriginAccount: follow.Account,
|
||||
}
|
||||
if err := p.db.Put(ctx, notif); err != nil {
|
||||
if err := p.state.DB.Put(ctx, notif); err != nil {
|
||||
return fmt.Errorf("notifyFollow: error putting notification in database: %s", err)
|
||||
}
|
||||
|
||||
@@ -194,7 +194,7 @@ func (p *Processor) notifyFave(ctx context.Context, fave *gtsmodel.StatusFave) e
|
||||
}
|
||||
|
||||
if fave.TargetAccount == nil {
|
||||
a, err := p.db.GetAccountByID(ctx, fave.TargetAccountID)
|
||||
a, err := p.state.DB.GetAccountByID(ctx, fave.TargetAccountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -218,7 +218,7 @@ func (p *Processor) notifyFave(ctx context.Context, fave *gtsmodel.StatusFave) e
|
||||
Status: fave.Status,
|
||||
}
|
||||
|
||||
if err := p.db.Put(ctx, notif); err != nil {
|
||||
if err := p.state.DB.Put(ctx, notif); err != nil {
|
||||
return fmt.Errorf("notifyFave: error putting notification in database: %s", err)
|
||||
}
|
||||
|
||||
@@ -242,7 +242,7 @@ func (p *Processor) notifyAnnounce(ctx context.Context, status *gtsmodel.Status)
|
||||
}
|
||||
|
||||
if status.BoostOf == nil {
|
||||
boostedStatus, err := p.db.GetStatusByID(ctx, status.BoostOfID)
|
||||
boostedStatus, err := p.state.DB.GetStatusByID(ctx, status.BoostOfID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("notifyAnnounce: error getting status with id %s: %s", status.BoostOfID, err)
|
||||
}
|
||||
@@ -250,7 +250,7 @@ func (p *Processor) notifyAnnounce(ctx context.Context, status *gtsmodel.Status)
|
||||
}
|
||||
|
||||
if status.BoostOfAccount == nil {
|
||||
boostedAcct, err := p.db.GetAccountByID(ctx, status.BoostOfAccountID)
|
||||
boostedAcct, err := p.state.DB.GetAccountByID(ctx, status.BoostOfAccountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("notifyAnnounce: error getting account with id %s: %s", status.BoostOfAccountID, err)
|
||||
}
|
||||
@@ -269,7 +269,7 @@ func (p *Processor) notifyAnnounce(ctx context.Context, status *gtsmodel.Status)
|
||||
}
|
||||
|
||||
// make sure a notif doesn't already exist for this announce
|
||||
err := p.db.GetWhere(ctx, []db.Where{
|
||||
err := p.state.DB.GetWhere(ctx, []db.Where{
|
||||
{Key: "notification_type", Value: gtsmodel.NotificationReblog},
|
||||
{Key: "target_account_id", Value: status.BoostOfAccountID},
|
||||
{Key: "origin_account_id", Value: status.AccountID},
|
||||
@@ -292,7 +292,7 @@ func (p *Processor) notifyAnnounce(ctx context.Context, status *gtsmodel.Status)
|
||||
Status: status,
|
||||
}
|
||||
|
||||
if err := p.db.Put(ctx, notif); err != nil {
|
||||
if err := p.state.DB.Put(ctx, notif); err != nil {
|
||||
return fmt.Errorf("notifyAnnounce: error putting notification in database: %s", err)
|
||||
}
|
||||
|
||||
@@ -314,7 +314,7 @@ func (p *Processor) notifyAnnounce(ctx context.Context, status *gtsmodel.Status)
|
||||
func (p *Processor) timelineStatus(ctx context.Context, status *gtsmodel.Status) error {
|
||||
// make sure the author account is pinned onto the status
|
||||
if status.Account == nil {
|
||||
a, err := p.db.GetAccountByID(ctx, status.AccountID)
|
||||
a, err := p.state.DB.GetAccountByID(ctx, status.AccountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("timelineStatus: error getting author account with id %s: %s", status.AccountID, err)
|
||||
}
|
||||
@@ -322,7 +322,7 @@ func (p *Processor) timelineStatus(ctx context.Context, status *gtsmodel.Status)
|
||||
}
|
||||
|
||||
// get local followers of the account that posted the status
|
||||
follows, err := p.db.GetAccountFollowedBy(ctx, status.AccountID, true)
|
||||
follows, err := p.state.DB.GetAccountFollowedBy(ctx, status.AccountID, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("timelineStatus: error getting followers for account id %s: %s", status.AccountID, err)
|
||||
}
|
||||
@@ -374,7 +374,7 @@ func (p *Processor) timelineStatusForAccount(ctx context.Context, status *gtsmod
|
||||
defer wg.Done()
|
||||
|
||||
// get the timeline owner account
|
||||
timelineAccount, err := p.db.GetAccountByID(ctx, accountID)
|
||||
timelineAccount, err := p.state.DB.GetAccountByID(ctx, accountID)
|
||||
if err != nil {
|
||||
errors <- fmt.Errorf("timelineStatusForAccount: error getting account for timeline with id %s: %s", accountID, err)
|
||||
return
|
||||
@@ -446,28 +446,28 @@ func (p *Processor) wipeStatus(ctx context.Context, statusToDelete *gtsmodel.Sta
|
||||
|
||||
// delete all mention entries generated by this status
|
||||
for _, m := range statusToDelete.MentionIDs {
|
||||
if err := p.db.DeleteByID(ctx, m, >smodel.Mention{}); err != nil {
|
||||
if err := p.state.DB.DeleteByID(ctx, m, >smodel.Mention{}); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// delete all notification entries generated by this status
|
||||
if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "status_id", Value: statusToDelete.ID}}, &[]*gtsmodel.Notification{}); err != nil {
|
||||
if err := p.state.DB.DeleteWhere(ctx, []db.Where{{Key: "status_id", Value: statusToDelete.ID}}, &[]*gtsmodel.Notification{}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// delete all bookmarks that point to this status
|
||||
if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "status_id", Value: statusToDelete.ID}}, &[]*gtsmodel.StatusBookmark{}); err != nil {
|
||||
if err := p.state.DB.DeleteWhere(ctx, []db.Where{{Key: "status_id", Value: statusToDelete.ID}}, &[]*gtsmodel.StatusBookmark{}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// delete all boosts for this status + remove them from timelines
|
||||
if boosts, err := p.db.GetStatusReblogs(ctx, statusToDelete); err == nil {
|
||||
if boosts, err := p.state.DB.GetStatusReblogs(ctx, statusToDelete); err == nil {
|
||||
for _, b := range boosts {
|
||||
if err := p.deleteStatusFromTimelines(ctx, b); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := p.db.DeleteStatusByID(ctx, b.ID); err != nil {
|
||||
if err := p.state.DB.DeleteStatusByID(ctx, b.ID); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -479,7 +479,7 @@ func (p *Processor) wipeStatus(ctx context.Context, statusToDelete *gtsmodel.Sta
|
||||
}
|
||||
|
||||
// delete the status itself
|
||||
if err := p.db.DeleteStatusByID(ctx, statusToDelete.ID); err != nil {
|
||||
if err := p.state.DB.DeleteStatusByID(ctx, statusToDelete.ID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
@@ -139,7 +139,7 @@ func (p *Processor) processCreateStatusFromFederator(ctx context.Context, federa
|
||||
|
||||
// make sure the account is pinned
|
||||
if status.Account == nil {
|
||||
a, err := p.db.GetAccountByID(ctx, status.AccountID)
|
||||
a, err := p.state.DB.GetAccountByID(ctx, status.AccountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -185,7 +185,7 @@ func (p *Processor) processCreateFaveFromFederator(ctx context.Context, federato
|
||||
|
||||
// make sure the account is pinned
|
||||
if incomingFave.Account == nil {
|
||||
a, err := p.db.GetAccountByID(ctx, incomingFave.AccountID)
|
||||
a, err := p.state.DB.GetAccountByID(ctx, incomingFave.AccountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -227,7 +227,7 @@ func (p *Processor) processCreateFollowRequestFromFederator(ctx context.Context,
|
||||
|
||||
// make sure the account is pinned
|
||||
if followRequest.Account == nil {
|
||||
a, err := p.db.GetAccountByID(ctx, followRequest.AccountID)
|
||||
a, err := p.state.DB.GetAccountByID(ctx, followRequest.AccountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -254,7 +254,7 @@ func (p *Processor) processCreateFollowRequestFromFederator(ctx context.Context,
|
||||
}
|
||||
|
||||
if followRequest.TargetAccount == nil {
|
||||
a, err := p.db.GetAccountByID(ctx, followRequest.TargetAccountID)
|
||||
a, err := p.state.DB.GetAccountByID(ctx, followRequest.TargetAccountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -267,7 +267,7 @@ func (p *Processor) processCreateFollowRequestFromFederator(ctx context.Context,
|
||||
}
|
||||
|
||||
// if the target account isn't locked, we should already accept the follow and notify about the new follower instead
|
||||
follow, err := p.db.AcceptFollowRequest(ctx, followRequest.AccountID, followRequest.TargetAccountID)
|
||||
follow, err := p.state.DB.AcceptFollowRequest(ctx, followRequest.AccountID, followRequest.TargetAccountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -288,7 +288,7 @@ func (p *Processor) processCreateAnnounceFromFederator(ctx context.Context, fede
|
||||
|
||||
// make sure the account is pinned
|
||||
if incomingAnnounce.Account == nil {
|
||||
a, err := p.db.GetAccountByID(ctx, incomingAnnounce.AccountID)
|
||||
a, err := p.state.DB.GetAccountByID(ctx, incomingAnnounce.AccountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -324,7 +324,7 @@ func (p *Processor) processCreateAnnounceFromFederator(ctx context.Context, fede
|
||||
}
|
||||
incomingAnnounce.ID = incomingAnnounceID
|
||||
|
||||
if err := p.db.PutStatus(ctx, incomingAnnounce); err != nil {
|
||||
if err := p.state.DB.PutStatus(ctx, incomingAnnounce); err != nil {
|
||||
return fmt.Errorf("error adding dereferenced announce to the db: %s", err)
|
||||
}
|
||||
|
||||
|
@@ -344,7 +344,6 @@ func (suite *FromFederatorTestSuite) TestProcessAccountDelete() {
|
||||
suite.NoError(err)
|
||||
|
||||
// now they are mufos!
|
||||
|
||||
err = suite.processor.ProcessFromFederator(ctx, messages.FromFederator{
|
||||
APObjectType: ap.ObjectProfile,
|
||||
APActivityType: ap.ActivityDelete,
|
||||
|
@@ -35,7 +35,7 @@ import (
|
||||
|
||||
func (p *Processor) getThisInstance(ctx context.Context) (*gtsmodel.Instance, error) {
|
||||
i := >smodel.Instance{}
|
||||
if err := p.db.GetWhere(ctx, []db.Where{{Key: "domain", Value: config.GetHost()}}, i); err != nil {
|
||||
if err := p.state.DB.GetWhere(ctx, []db.Where{{Key: "domain", Value: config.GetHost()}}, i); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return i, nil
|
||||
@@ -73,7 +73,7 @@ func (p *Processor) InstancePeersGet(ctx context.Context, includeSuspended bool,
|
||||
domains := []*apimodel.Domain{}
|
||||
|
||||
if includeOpen {
|
||||
instances, err := p.db.GetInstancePeers(ctx, false)
|
||||
instances, err := p.state.DB.GetInstancePeers(ctx, false)
|
||||
if err != nil && err != db.ErrNoEntries {
|
||||
err = fmt.Errorf("error selecting instance peers: %s", err)
|
||||
return nil, gtserror.NewErrorInternalError(err)
|
||||
@@ -87,7 +87,7 @@ func (p *Processor) InstancePeersGet(ctx context.Context, includeSuspended bool,
|
||||
|
||||
if includeSuspended {
|
||||
domainBlocks := []*gtsmodel.DomainBlock{}
|
||||
if err := p.db.GetAll(ctx, &domainBlocks); err != nil && err != db.ErrNoEntries {
|
||||
if err := p.state.DB.GetAll(ctx, &domainBlocks); err != nil && err != db.ErrNoEntries {
|
||||
return nil, gtserror.NewErrorInternalError(err)
|
||||
}
|
||||
|
||||
@@ -124,12 +124,12 @@ func (p *Processor) InstancePatch(ctx context.Context, form *apimodel.InstanceSe
|
||||
// fetch the instance entry from the db for processing
|
||||
i := >smodel.Instance{}
|
||||
host := config.GetHost()
|
||||
if err := p.db.GetWhere(ctx, []db.Where{{Key: "domain", Value: host}}, i); err != nil {
|
||||
if err := p.state.DB.GetWhere(ctx, []db.Where{{Key: "domain", Value: host}}, i); err != nil {
|
||||
return nil, gtserror.NewErrorInternalError(fmt.Errorf("db error fetching instance %s: %s", host, err))
|
||||
}
|
||||
|
||||
// fetch the instance account from the db for processing
|
||||
ia, err := p.db.GetInstanceAccount(ctx, "")
|
||||
ia, err := p.state.DB.GetInstanceAccount(ctx, "")
|
||||
if err != nil {
|
||||
return nil, gtserror.NewErrorInternalError(fmt.Errorf("db error fetching instance account %s: %s", host, err))
|
||||
}
|
||||
@@ -148,12 +148,12 @@ func (p *Processor) InstancePatch(ctx context.Context, form *apimodel.InstanceSe
|
||||
// validate & update site contact account if it's set on the form
|
||||
if form.ContactUsername != nil {
|
||||
// make sure the account with the given username exists in the db
|
||||
contactAccount, err := p.db.GetAccountByUsernameDomain(ctx, *form.ContactUsername, "")
|
||||
contactAccount, err := p.state.DB.GetAccountByUsernameDomain(ctx, *form.ContactUsername, "")
|
||||
if err != nil {
|
||||
return nil, gtserror.NewErrorBadRequest(err, fmt.Sprintf("account with username %s not retrievable", *form.ContactUsername))
|
||||
}
|
||||
// make sure it has a user associated with it
|
||||
contactUser, err := p.db.GetUserByAccountID(ctx, contactAccount.ID)
|
||||
contactUser, err := p.state.DB.GetUserByAccountID(ctx, contactAccount.ID)
|
||||
if err != nil {
|
||||
return nil, gtserror.NewErrorBadRequest(err, fmt.Sprintf("user for account with username %s not retrievable", *form.ContactUsername))
|
||||
}
|
||||
@@ -233,7 +233,7 @@ func (p *Processor) InstancePatch(ctx context.Context, form *apimodel.InstanceSe
|
||||
} else if form.AvatarDescription != nil && ia.AvatarMediaAttachment != nil {
|
||||
// process just the description for the existing avatar
|
||||
ia.AvatarMediaAttachment.Description = *form.AvatarDescription
|
||||
if err := p.db.UpdateByID(ctx, ia.AvatarMediaAttachment, ia.AvatarMediaAttachmentID, "description"); err != nil {
|
||||
if err := p.state.DB.UpdateByID(ctx, ia.AvatarMediaAttachment, ia.AvatarMediaAttachmentID, "description"); err != nil {
|
||||
return nil, gtserror.NewErrorInternalError(fmt.Errorf("db error updating instance avatar description: %s", err))
|
||||
}
|
||||
}
|
||||
@@ -252,13 +252,13 @@ func (p *Processor) InstancePatch(ctx context.Context, form *apimodel.InstanceSe
|
||||
if updateInstanceAccount {
|
||||
// if either avatar or header is updated, we need
|
||||
// to update the instance account that stores them
|
||||
if err := p.db.UpdateAccount(ctx, ia); err != nil {
|
||||
if err := p.state.DB.UpdateAccount(ctx, ia); err != nil {
|
||||
return nil, gtserror.NewErrorInternalError(fmt.Errorf("db error updating instance account: %s", err))
|
||||
}
|
||||
}
|
||||
|
||||
if len(updatingColumns) != 0 {
|
||||
if err := p.db.UpdateByID(ctx, i, i.ID, updatingColumns...); err != nil {
|
||||
if err := p.state.DB.UpdateByID(ctx, i, i.ID, updatingColumns...); err != nil {
|
||||
return nil, gtserror.NewErrorInternalError(fmt.Errorf("db error updating instance %s: %s", host, err))
|
||||
}
|
||||
}
|
||||
|
@@ -13,7 +13,7 @@ import (
|
||||
|
||||
// Delete deletes the media attachment with the given ID, including all files pertaining to that attachment.
|
||||
func (p *Processor) Delete(ctx context.Context, mediaAttachmentID string) gtserror.WithCode {
|
||||
attachment, err := p.db.GetAttachmentByID(ctx, mediaAttachmentID)
|
||||
attachment, err := p.state.DB.GetAttachmentByID(ctx, mediaAttachmentID)
|
||||
if err != nil {
|
||||
if err == db.ErrNoEntries {
|
||||
// attachment already gone
|
||||
@@ -27,20 +27,20 @@ func (p *Processor) Delete(ctx context.Context, mediaAttachmentID string) gtserr
|
||||
|
||||
// delete the thumbnail from storage
|
||||
if attachment.Thumbnail.Path != "" {
|
||||
if err := p.storage.Delete(ctx, attachment.Thumbnail.Path); err != nil && !errors.Is(err, storage.ErrNotFound) {
|
||||
if err := p.state.Storage.Delete(ctx, attachment.Thumbnail.Path); err != nil && !errors.Is(err, storage.ErrNotFound) {
|
||||
errs = append(errs, fmt.Sprintf("remove thumbnail at path %s: %s", attachment.Thumbnail.Path, err))
|
||||
}
|
||||
}
|
||||
|
||||
// delete the file from storage
|
||||
if attachment.File.Path != "" {
|
||||
if err := p.storage.Delete(ctx, attachment.File.Path); err != nil && !errors.Is(err, storage.ErrNotFound) {
|
||||
if err := p.state.Storage.Delete(ctx, attachment.File.Path); err != nil && !errors.Is(err, storage.ErrNotFound) {
|
||||
errs = append(errs, fmt.Sprintf("remove file at path %s: %s", attachment.File.Path, err))
|
||||
}
|
||||
}
|
||||
|
||||
// delete the attachment
|
||||
if err := p.db.DeleteByID(ctx, mediaAttachmentID, attachment); err != nil && !errors.Is(err, db.ErrNoEntries) {
|
||||
if err := p.state.DB.DeleteByID(ctx, mediaAttachmentID, attachment); err != nil && !errors.Is(err, db.ErrNoEntries) {
|
||||
errs = append(errs, fmt.Sprintf("remove attachment: %s", err))
|
||||
}
|
||||
|
||||
|
@@ -31,7 +31,7 @@ import (
|
||||
// GetCustomEmojis returns a list of all useable local custom emojis stored on this instance.
|
||||
// 'useable' in this context means visible and picker, and not disabled.
|
||||
func (p *Processor) GetCustomEmojis(ctx context.Context) ([]*apimodel.Emoji, gtserror.WithCode) {
|
||||
emojis, err := p.db.GetUseableEmojis(ctx)
|
||||
emojis, err := p.state.DB.GetUseableEmojis(ctx)
|
||||
if err != nil {
|
||||
if err != db.ErrNoEntries {
|
||||
return nil, gtserror.NewErrorNotFound(fmt.Errorf("db error retrieving custom emojis: %s", err))
|
||||
|
@@ -54,7 +54,7 @@ func (p *Processor) GetFile(ctx context.Context, requestingAccount *gtsmodel.Acc
|
||||
owningAccountID := form.AccountID
|
||||
|
||||
// get the account that owns the media and make sure it's not suspended
|
||||
owningAccount, err := p.db.GetAccountByID(ctx, owningAccountID)
|
||||
owningAccount, err := p.state.DB.GetAccountByID(ctx, owningAccountID)
|
||||
if err != nil {
|
||||
return nil, gtserror.NewErrorNotFound(fmt.Errorf("account with id %s could not be selected from the db: %s", owningAccountID, err))
|
||||
}
|
||||
@@ -64,7 +64,7 @@ func (p *Processor) GetFile(ctx context.Context, requestingAccount *gtsmodel.Acc
|
||||
|
||||
// make sure the requesting account and the media account don't block each other
|
||||
if requestingAccount != nil {
|
||||
blocked, err := p.db.IsBlocked(ctx, requestingAccount.ID, owningAccountID, true)
|
||||
blocked, err := p.state.DB.IsBlocked(ctx, requestingAccount.ID, owningAccountID, true)
|
||||
if err != nil {
|
||||
return nil, gtserror.NewErrorNotFound(fmt.Errorf("block status could not be established between accounts %s and %s: %s", owningAccountID, requestingAccount.ID, err))
|
||||
}
|
||||
@@ -117,7 +117,7 @@ func parseSize(s string) (media.Size, error) {
|
||||
|
||||
func (p *Processor) getAttachmentContent(ctx context.Context, requestingAccount *gtsmodel.Account, wantedMediaID string, owningAccountID string, mediaSize media.Size) (*apimodel.Content, gtserror.WithCode) {
|
||||
// retrieve attachment from the database and do basic checks on it
|
||||
a, err := p.db.GetAttachmentByID(ctx, wantedMediaID)
|
||||
a, err := p.state.DB.GetAttachmentByID(ctx, wantedMediaID)
|
||||
if err != nil {
|
||||
return nil, gtserror.NewErrorNotFound(fmt.Errorf("attachment %s could not be taken from the db: %s", wantedMediaID, err))
|
||||
}
|
||||
@@ -209,7 +209,7 @@ func (p *Processor) getEmojiContent(ctx context.Context, fileName string, owning
|
||||
// so this is more reliable than using full size url
|
||||
imageStaticURL := uris.GenerateURIForAttachment(owningAccountID, string(media.TypeEmoji), string(media.SizeStatic), fileName, "png")
|
||||
|
||||
e, err := p.db.GetEmojiByStaticURL(ctx, imageStaticURL)
|
||||
e, err := p.state.DB.GetEmojiByStaticURL(ctx, imageStaticURL)
|
||||
if err != nil {
|
||||
return nil, gtserror.NewErrorNotFound(fmt.Errorf("emoji %s could not be taken from the db: %s", fileName, err))
|
||||
}
|
||||
@@ -237,12 +237,12 @@ func (p *Processor) getEmojiContent(ctx context.Context, fileName string, owning
|
||||
func (p *Processor) retrieveFromStorage(ctx context.Context, storagePath string, content *apimodel.Content) (*apimodel.Content, gtserror.WithCode) {
|
||||
// If running on S3 storage with proxying disabled then
|
||||
// just fetch a pre-signed URL instead of serving the content.
|
||||
if url := p.storage.URL(ctx, storagePath); url != nil {
|
||||
if url := p.state.Storage.URL(ctx, storagePath); url != nil {
|
||||
content.URL = url
|
||||
return content, nil
|
||||
}
|
||||
|
||||
reader, err := p.storage.GetStream(ctx, storagePath)
|
||||
reader, err := p.state.Storage.GetStream(ctx, storagePath)
|
||||
if err != nil {
|
||||
return nil, gtserror.NewErrorNotFound(fmt.Errorf("error retrieving from storage: %s", err))
|
||||
}
|
||||
|
@@ -30,7 +30,7 @@ import (
|
||||
)
|
||||
|
||||
func (p *Processor) Get(ctx context.Context, account *gtsmodel.Account, mediaAttachmentID string) (*apimodel.Attachment, gtserror.WithCode) {
|
||||
attachment, err := p.db.GetAttachmentByID(ctx, mediaAttachmentID)
|
||||
attachment, err := p.state.DB.GetAttachmentByID(ctx, mediaAttachmentID)
|
||||
if err != nil {
|
||||
if err == db.ErrNoEntries {
|
||||
// attachment doesn't exist
|
||||
|
@@ -19,28 +19,25 @@
|
||||
package media
|
||||
|
||||
import (
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/media"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/storage"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/state"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/transport"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
|
||||
)
|
||||
|
||||
type Processor struct {
|
||||
state *state.State
|
||||
tc typeutils.TypeConverter
|
||||
mediaManager media.Manager
|
||||
transportController transport.Controller
|
||||
storage *storage.Driver
|
||||
db db.DB
|
||||
}
|
||||
|
||||
// New returns a new media processor.
|
||||
func New(db db.DB, tc typeutils.TypeConverter, mediaManager media.Manager, transportController transport.Controller, storage *storage.Driver) Processor {
|
||||
func New(state *state.State, tc typeutils.TypeConverter, mediaManager media.Manager, transportController transport.Controller) Processor {
|
||||
return Processor{
|
||||
state: state,
|
||||
tc: tc,
|
||||
mediaManager: mediaManager,
|
||||
transportController: transportController,
|
||||
storage: storage,
|
||||
db: db,
|
||||
}
|
||||
}
|
||||
|
@@ -20,12 +20,11 @@ package media_test
|
||||
|
||||
import (
|
||||
"github.com/stretchr/testify/suite"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/media"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/messages"
|
||||
mediaprocessing "github.com/superseriousbusiness/gotosocial/internal/processing/media"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/state"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/storage"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/transport"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
|
||||
@@ -38,6 +37,7 @@ type MediaStandardTestSuite struct {
|
||||
db db.DB
|
||||
tc typeutils.TypeConverter
|
||||
storage *storage.Driver
|
||||
state state.State
|
||||
mediaManager media.Manager
|
||||
transportController transport.Controller
|
||||
|
||||
@@ -67,15 +67,19 @@ func (suite *MediaStandardTestSuite) SetupSuite() {
|
||||
}
|
||||
|
||||
func (suite *MediaStandardTestSuite) SetupTest() {
|
||||
suite.state.Caches.Init()
|
||||
|
||||
testrig.InitTestConfig()
|
||||
testrig.InitTestLog()
|
||||
|
||||
suite.db = testrig.NewTestDB()
|
||||
suite.db = testrig.NewTestDB(&suite.state)
|
||||
suite.state.DB = suite.db
|
||||
suite.tc = testrig.NewTestTypeConverter(suite.db)
|
||||
suite.storage = testrig.NewInMemoryStorage()
|
||||
suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage)
|
||||
suite.transportController = testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../testrig/media"), suite.db, concurrency.NewWorkerPool[messages.FromFederator](-1, -1))
|
||||
suite.mediaProcessor = mediaprocessing.New(suite.db, suite.tc, suite.mediaManager, suite.transportController, suite.storage)
|
||||
suite.state.Storage = suite.storage
|
||||
suite.mediaManager = testrig.NewTestMediaManager(&suite.state)
|
||||
suite.transportController = testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../testrig/media"))
|
||||
suite.mediaProcessor = mediaprocessing.New(&suite.state, suite.tc, suite.mediaManager, suite.transportController)
|
||||
testrig.StandardDBSetup(suite.db, nil)
|
||||
testrig.StandardStorageSetup(suite.storage, "../../../testrig/media")
|
||||
}
|
||||
|
@@ -33,7 +33,7 @@ import (
|
||||
// Unattach unattaches the media attachment with the given ID from any statuses it was attached to, making it available
|
||||
// for reattachment again.
|
||||
func (p *Processor) Unattach(ctx context.Context, account *gtsmodel.Account, mediaAttachmentID string) (*apimodel.Attachment, gtserror.WithCode) {
|
||||
attachment, err := p.db.GetAttachmentByID(ctx, mediaAttachmentID)
|
||||
attachment, err := p.state.DB.GetAttachmentByID(ctx, mediaAttachmentID)
|
||||
if err != nil {
|
||||
if err == db.ErrNoEntries {
|
||||
return nil, gtserror.NewErrorNotFound(errors.New("attachment doesn't exist in the db"))
|
||||
@@ -49,7 +49,7 @@ func (p *Processor) Unattach(ctx context.Context, account *gtsmodel.Account, med
|
||||
attachment.UpdatedAt = time.Now()
|
||||
attachment.StatusID = ""
|
||||
|
||||
if err := p.db.UpdateByID(ctx, attachment, attachment.ID, updatingColumns...); err != nil {
|
||||
if err := p.state.DB.UpdateByID(ctx, attachment, attachment.ID, updatingColumns...); err != nil {
|
||||
return nil, gtserror.NewErrorNotFound(fmt.Errorf("db error updating attachment: %s", err))
|
||||
}
|
||||
|
||||
|
@@ -32,7 +32,7 @@ import (
|
||||
|
||||
// Update updates a media attachment with the given id, using the provided form parameters.
|
||||
func (p *Processor) Update(ctx context.Context, account *gtsmodel.Account, mediaAttachmentID string, form *apimodel.AttachmentUpdateRequest) (*apimodel.Attachment, gtserror.WithCode) {
|
||||
attachment, err := p.db.GetAttachmentByID(ctx, mediaAttachmentID)
|
||||
attachment, err := p.state.DB.GetAttachmentByID(ctx, mediaAttachmentID)
|
||||
if err != nil {
|
||||
if err == db.ErrNoEntries {
|
||||
// attachment doesn't exist
|
||||
@@ -62,7 +62,7 @@ func (p *Processor) Update(ctx context.Context, account *gtsmodel.Account, media
|
||||
updatingColumns = append(updatingColumns, "focus_x", "focus_y")
|
||||
}
|
||||
|
||||
if err := p.db.UpdateByID(ctx, attachment, attachment.ID, updatingColumns...); err != nil {
|
||||
if err := p.state.DB.UpdateByID(ctx, attachment, attachment.ID, updatingColumns...); err != nil {
|
||||
return nil, gtserror.NewErrorInternalError(fmt.Errorf("database error updating media: %s", err))
|
||||
}
|
||||
|
||||
|
@@ -29,7 +29,7 @@ import (
|
||||
)
|
||||
|
||||
func (p *Processor) NotificationsGet(ctx context.Context, authed *oauth.Auth, excludeTypes []string, limit int, maxID string, sinceID string) (*apimodel.PageableResponse, gtserror.WithCode) {
|
||||
notifs, err := p.db.GetNotifications(ctx, authed.Account.ID, excludeTypes, limit, maxID, sinceID)
|
||||
notifs, err := p.state.DB.GetNotifications(ctx, authed.Account.ID, excludeTypes, limit, maxID, sinceID)
|
||||
if err != nil {
|
||||
return nil, gtserror.NewErrorInternalError(err)
|
||||
}
|
||||
@@ -72,7 +72,7 @@ func (p *Processor) NotificationsGet(ctx context.Context, authed *oauth.Auth, ex
|
||||
}
|
||||
|
||||
func (p *Processor) NotificationsClear(ctx context.Context, authed *oauth.Auth) gtserror.WithCode {
|
||||
err := p.db.ClearNotifications(ctx, authed.Account.ID)
|
||||
err := p.state.DB.ClearNotifications(ctx, authed.Account.ID)
|
||||
if err != nil {
|
||||
return gtserror.NewErrorInternalError(err)
|
||||
}
|
||||
|
@@ -19,10 +19,11 @@
|
||||
package processing
|
||||
|
||||
import (
|
||||
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"context"
|
||||
|
||||
"github.com/superseriousbusiness/gotosocial/internal/email"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/federation"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/log"
|
||||
mm "github.com/superseriousbusiness/gotosocial/internal/media"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/messages"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/oauth"
|
||||
@@ -34,23 +35,19 @@ import (
|
||||
"github.com/superseriousbusiness/gotosocial/internal/processing/status"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/processing/stream"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/processing/user"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/storage"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/state"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/timeline"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/visibility"
|
||||
)
|
||||
|
||||
type Processor struct {
|
||||
clientWorker *concurrency.WorkerPool[messages.FromClientAPI]
|
||||
fedWorker *concurrency.WorkerPool[messages.FromFederator]
|
||||
|
||||
federator federation.Federator
|
||||
tc typeutils.TypeConverter
|
||||
oauthServer oauth.Server
|
||||
mediaManager mm.Manager
|
||||
storage *storage.Driver
|
||||
statusTimelines timeline.Manager
|
||||
db db.DB
|
||||
state *state.State
|
||||
filter visibility.Filter
|
||||
|
||||
/*
|
||||
@@ -105,76 +102,65 @@ func NewProcessor(
|
||||
federator federation.Federator,
|
||||
oauthServer oauth.Server,
|
||||
mediaManager mm.Manager,
|
||||
storage *storage.Driver,
|
||||
db db.DB,
|
||||
state *state.State,
|
||||
emailSender email.Sender,
|
||||
clientWorker *concurrency.WorkerPool[messages.FromClientAPI],
|
||||
fedWorker *concurrency.WorkerPool[messages.FromFederator],
|
||||
) *Processor {
|
||||
parseMentionFunc := GetParseMentionFunc(db, federator)
|
||||
parseMentionFunc := GetParseMentionFunc(state.DB, federator)
|
||||
|
||||
filter := visibility.NewFilter(db)
|
||||
filter := visibility.NewFilter(state.DB)
|
||||
|
||||
return &Processor{
|
||||
clientWorker: clientWorker,
|
||||
fedWorker: fedWorker,
|
||||
|
||||
federator: federator,
|
||||
tc: tc,
|
||||
oauthServer: oauthServer,
|
||||
mediaManager: mediaManager,
|
||||
storage: storage,
|
||||
statusTimelines: timeline.NewManager(StatusGrabFunction(db), StatusFilterFunction(db, filter), StatusPrepareFunction(db, tc), StatusSkipInsertFunction()),
|
||||
db: db,
|
||||
filter: filter,
|
||||
|
||||
// sub processors
|
||||
account: account.New(db, tc, mediaManager, oauthServer, clientWorker, federator, parseMentionFunc),
|
||||
admin: admin.New(db, tc, mediaManager, federator.TransportController(), storage, clientWorker),
|
||||
fedi: fedi.New(db, tc, federator),
|
||||
media: media.New(db, tc, mediaManager, federator.TransportController(), storage),
|
||||
report: report.New(db, tc, clientWorker),
|
||||
status: status.New(db, tc, clientWorker, parseMentionFunc),
|
||||
stream: stream.New(db, oauthServer),
|
||||
user: user.New(db, emailSender),
|
||||
processor := &Processor{
|
||||
federator: federator,
|
||||
tc: tc,
|
||||
oauthServer: oauthServer,
|
||||
mediaManager: mediaManager,
|
||||
statusTimelines: timeline.NewManager(
|
||||
StatusGrabFunction(state.DB),
|
||||
StatusFilterFunction(state.DB, filter),
|
||||
StatusPrepareFunction(state.DB, tc),
|
||||
StatusSkipInsertFunction(),
|
||||
),
|
||||
state: state,
|
||||
filter: filter,
|
||||
}
|
||||
|
||||
// sub processors
|
||||
processor.account = account.New(state, tc, mediaManager, oauthServer, federator, parseMentionFunc)
|
||||
processor.admin = admin.New(state, tc, mediaManager, federator.TransportController())
|
||||
processor.fedi = fedi.New(state, tc, federator)
|
||||
processor.media = media.New(state, tc, mediaManager, federator.TransportController())
|
||||
processor.report = report.New(state, tc)
|
||||
processor.status = status.New(state, tc, parseMentionFunc)
|
||||
processor.stream = stream.New(state, oauthServer)
|
||||
processor.user = user.New(state, emailSender)
|
||||
|
||||
return processor
|
||||
}
|
||||
|
||||
// Start starts the Processor, reading from its channels and passing messages back and forth.
|
||||
func (p *Processor) EnqueueClientAPI(ctx context.Context, msg messages.FromClientAPI) {
|
||||
log.WithContext(ctx).WithField("msg", msg).Trace("enqueuing client API")
|
||||
_ = p.state.Workers.ClientAPI.MustEnqueueCtx(ctx, func(ctx context.Context) {
|
||||
if err := p.ProcessFromClientAPI(ctx, msg); err != nil {
|
||||
log.Errorf(ctx, "error processing client API message: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (p *Processor) EnqueueFederator(ctx context.Context, msg messages.FromFederator) {
|
||||
log.WithContext(ctx).WithField("msg", msg).Trace("enqueuing federator")
|
||||
_ = p.state.Workers.Federator.MustEnqueueCtx(ctx, func(ctx context.Context) {
|
||||
if err := p.ProcessFromFederator(ctx, msg); err != nil {
|
||||
log.Errorf(ctx, "error processing federator message: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Start starts the Processor.
|
||||
func (p *Processor) Start() error {
|
||||
// Setup and start the client API worker pool
|
||||
p.clientWorker.SetProcessor(p.ProcessFromClientAPI)
|
||||
if err := p.clientWorker.Start(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Setup and start the federator worker pool
|
||||
p.fedWorker.SetProcessor(p.ProcessFromFederator)
|
||||
if err := p.fedWorker.Start(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Start status timelines
|
||||
if err := p.statusTimelines.Start(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
return p.statusTimelines.Start()
|
||||
}
|
||||
|
||||
// Stop stops the processor cleanly, finishing handling any remaining messages before closing down.
|
||||
// Stop stops the processor cleanly.
|
||||
func (p *Processor) Stop() error {
|
||||
if err := p.clientWorker.Stop(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := p.fedWorker.Stop(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := p.statusTimelines.Stop(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
return p.statusTimelines.Stop()
|
||||
}
|
||||
|
@@ -20,15 +20,14 @@ package processing_test
|
||||
|
||||
import (
|
||||
"github.com/stretchr/testify/suite"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/email"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/federation"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/media"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/messages"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/oauth"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/processing"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/state"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/storage"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/transport"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
|
||||
@@ -40,6 +39,7 @@ type ProcessingStandardTestSuite struct {
|
||||
suite.Suite
|
||||
db db.DB
|
||||
storage *storage.Driver
|
||||
state state.State
|
||||
mediaManager media.Manager
|
||||
typeconverter typeutils.TypeConverter
|
||||
httpClient *testrig.MockHTTPClient
|
||||
@@ -86,25 +86,29 @@ func (suite *ProcessingStandardTestSuite) SetupSuite() {
|
||||
}
|
||||
|
||||
func (suite *ProcessingStandardTestSuite) SetupTest() {
|
||||
suite.state.Caches.Init()
|
||||
testrig.StartWorkers(&suite.state)
|
||||
|
||||
testrig.InitTestConfig()
|
||||
testrig.InitTestLog()
|
||||
|
||||
suite.db = testrig.NewTestDB()
|
||||
suite.db = testrig.NewTestDB(&suite.state)
|
||||
suite.state.DB = suite.db
|
||||
suite.testActivities = testrig.NewTestActivities(suite.testAccounts)
|
||||
suite.storage = testrig.NewInMemoryStorage()
|
||||
suite.state.Storage = suite.storage
|
||||
suite.typeconverter = testrig.NewTestTypeConverter(suite.db)
|
||||
suite.httpClient = testrig.NewMockHTTPClient(nil, "../../testrig/media")
|
||||
|
||||
clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
|
||||
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
|
||||
|
||||
suite.transportController = testrig.NewTestTransportController(suite.httpClient, suite.db, fedWorker)
|
||||
suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage)
|
||||
suite.federator = testrig.NewTestFederator(suite.db, suite.transportController, suite.storage, suite.mediaManager, fedWorker)
|
||||
suite.transportController = testrig.NewTestTransportController(&suite.state, suite.httpClient)
|
||||
suite.mediaManager = testrig.NewTestMediaManager(&suite.state)
|
||||
suite.federator = testrig.NewTestFederator(&suite.state, suite.transportController, suite.mediaManager)
|
||||
suite.oauthServer = testrig.NewTestOauthServer(suite.db)
|
||||
suite.emailSender = testrig.NewEmailSender("../../web/template/", nil)
|
||||
|
||||
suite.processor = processing.NewProcessor(suite.typeconverter, suite.federator, suite.oauthServer, suite.mediaManager, suite.storage, suite.db, suite.emailSender, clientWorker, fedWorker)
|
||||
suite.processor = processing.NewProcessor(suite.typeconverter, suite.federator, suite.oauthServer, suite.mediaManager, &suite.state, suite.emailSender)
|
||||
suite.state.Workers.EnqueueClientAPI = suite.processor.EnqueueClientAPI
|
||||
suite.state.Workers.EnqueueFederator = suite.processor.EnqueueFederator
|
||||
|
||||
testrig.StandardDBSetup(suite.db, suite.testAccounts)
|
||||
testrig.StandardStorageSetup(suite.storage, "../../testrig/media")
|
||||
@@ -119,4 +123,5 @@ func (suite *ProcessingStandardTestSuite) TearDownTest() {
|
||||
if err := suite.processor.Stop(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
testrig.StopWorkers(&suite.state)
|
||||
}
|
||||
|
@@ -41,7 +41,7 @@ func (p *Processor) Create(ctx context.Context, account *gtsmodel.Account, form
|
||||
}
|
||||
|
||||
// validate + fetch target account
|
||||
targetAccount, err := p.db.GetAccountByID(ctx, form.AccountID)
|
||||
targetAccount, err := p.state.DB.GetAccountByID(ctx, form.AccountID)
|
||||
if err != nil {
|
||||
if errors.Is(err, db.ErrNoEntries) {
|
||||
err = fmt.Errorf("account with ID %s does not exist", form.AccountID)
|
||||
@@ -52,7 +52,7 @@ func (p *Processor) Create(ctx context.Context, account *gtsmodel.Account, form
|
||||
}
|
||||
|
||||
// fetch statuses by IDs given in the report form (noop if no statuses given)
|
||||
statuses, err := p.db.GetStatuses(ctx, form.StatusIDs)
|
||||
statuses, err := p.state.DB.GetStatuses(ctx, form.StatusIDs)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("db error fetching report target statuses: %w", err)
|
||||
return nil, gtserror.NewErrorInternalError(err)
|
||||
@@ -79,11 +79,11 @@ func (p *Processor) Create(ctx context.Context, account *gtsmodel.Account, form
|
||||
Forwarded: &form.Forward,
|
||||
}
|
||||
|
||||
if err := p.db.PutReport(ctx, report); err != nil {
|
||||
if err := p.state.DB.PutReport(ctx, report); err != nil {
|
||||
return nil, gtserror.NewErrorInternalError(err)
|
||||
}
|
||||
|
||||
p.clientWorker.Queue(messages.FromClientAPI{
|
||||
p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{
|
||||
APObjectType: ap.ObjectProfile,
|
||||
APActivityType: ap.ActivityFlag,
|
||||
GTSModel: report,
|
||||
|
@@ -32,7 +32,7 @@ import (
|
||||
|
||||
// Get returns the user view of a moderation report, with the given id.
|
||||
func (p *Processor) Get(ctx context.Context, account *gtsmodel.Account, id string) (*apimodel.Report, gtserror.WithCode) {
|
||||
report, err := p.db.GetReportByID(ctx, id)
|
||||
report, err := p.state.DB.GetReportByID(ctx, id)
|
||||
if err != nil {
|
||||
if err == db.ErrNoEntries {
|
||||
return nil, gtserror.NewErrorNotFound(err)
|
||||
@@ -64,7 +64,7 @@ func (p *Processor) GetMultiple(
|
||||
minID string,
|
||||
limit int,
|
||||
) (*apimodel.PageableResponse, gtserror.WithCode) {
|
||||
reports, err := p.db.GetReports(ctx, resolved, account.ID, targetAccountID, maxID, sinceID, minID, limit)
|
||||
reports, err := p.state.DB.GetReports(ctx, resolved, account.ID, targetAccountID, maxID, sinceID, minID, limit)
|
||||
if err != nil {
|
||||
if err == db.ErrNoEntries {
|
||||
return util.EmptyPageableResponse(), nil
|
||||
|
@@ -19,22 +19,18 @@
|
||||
package report
|
||||
|
||||
import (
|
||||
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/messages"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/state"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
|
||||
)
|
||||
|
||||
type Processor struct {
|
||||
db db.DB
|
||||
tc typeutils.TypeConverter
|
||||
clientWorker *concurrency.WorkerPool[messages.FromClientAPI]
|
||||
state *state.State
|
||||
tc typeutils.TypeConverter
|
||||
}
|
||||
|
||||
func New(db db.DB, tc typeutils.TypeConverter, clientWorker *concurrency.WorkerPool[messages.FromClientAPI]) Processor {
|
||||
func New(state *state.State, tc typeutils.TypeConverter) Processor {
|
||||
return Processor{
|
||||
tc: tc,
|
||||
db: db,
|
||||
clientWorker: clientWorker,
|
||||
state: state,
|
||||
tc: tc,
|
||||
}
|
||||
}
|
||||
|
@@ -88,7 +88,7 @@ func (p *Processor) SearchGet(ctx context.Context, authed *oauth.Auth, search *a
|
||||
|
||||
if username, domain, err := util.ExtractNamestringParts(maybeNamestring); err == nil {
|
||||
l.Trace("search term is a mention, looking it up...")
|
||||
blocked, err := p.db.IsDomainBlocked(ctx, domain)
|
||||
blocked, err := p.state.DB.IsDomainBlocked(ctx, domain)
|
||||
if err != nil {
|
||||
return nil, gtserror.NewErrorInternalError(fmt.Errorf("error checking domain block: %w", err))
|
||||
}
|
||||
@@ -120,7 +120,7 @@ func (p *Processor) SearchGet(ctx context.Context, authed *oauth.Auth, search *a
|
||||
if uri, err := url.Parse(query); err == nil {
|
||||
if uri.Scheme == "https" || uri.Scheme == "http" {
|
||||
l.Trace("search term is a uri, looking it up...")
|
||||
blocked, err := p.db.IsURIBlocked(ctx, uri)
|
||||
blocked, err := p.state.DB.IsURIBlocked(ctx, uri)
|
||||
if err != nil {
|
||||
return nil, gtserror.NewErrorInternalError(fmt.Errorf("error checking domain block: %w", err))
|
||||
}
|
||||
@@ -178,7 +178,7 @@ func (p *Processor) SearchGet(ctx context.Context, authed *oauth.Auth, search *a
|
||||
*/
|
||||
for _, foundAccount := range foundAccounts {
|
||||
// make sure there's no block in either direction between the account and the requester
|
||||
blocked, err := p.db.IsBlocked(ctx, authed.Account.ID, foundAccount.ID, true)
|
||||
blocked, err := p.state.DB.IsBlocked(ctx, authed.Account.ID, foundAccount.ID, true)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("SearchGet: error checking block between %s and %s: %s", authed.Account.ID, foundAccount.ID, err)
|
||||
return nil, gtserror.NewErrorInternalError(err)
|
||||
@@ -246,14 +246,14 @@ func (p *Processor) searchAccountByURI(ctx context.Context, authed *oauth.Auth,
|
||||
)
|
||||
|
||||
// Search the database for existing account with ID URI.
|
||||
account, err = p.db.GetAccountByURI(ctx, uriStr)
|
||||
account, err = p.state.DB.GetAccountByURI(ctx, uriStr)
|
||||
if err != nil && !errors.Is(err, db.ErrNoEntries) {
|
||||
return nil, fmt.Errorf("searchAccountByURI: error checking database for account %s: %w", uriStr, err)
|
||||
}
|
||||
|
||||
if account == nil {
|
||||
// Else, search the database for existing by ID URL.
|
||||
account, err = p.db.GetAccountByURL(ctx, uriStr)
|
||||
account, err = p.state.DB.GetAccountByURL(ctx, uriStr)
|
||||
if err != nil {
|
||||
if !errors.Is(err, db.ErrNoEntries) {
|
||||
return nil, fmt.Errorf("searchAccountByURI: error checking database for account %s: %w", uriStr, err)
|
||||
@@ -281,7 +281,7 @@ func (p *Processor) searchAccountByUsernameDomain(ctx context.Context, authed *o
|
||||
}
|
||||
|
||||
// Search the database for existing account with USERNAME@DOMAIN
|
||||
account, err := p.db.GetAccountByUsernameDomain(ctx, username, domain)
|
||||
account, err := p.state.DB.GetAccountByUsernameDomain(ctx, username, domain)
|
||||
if err != nil {
|
||||
if !errors.Is(err, db.ErrNoEntries) {
|
||||
return nil, fmt.Errorf("searchAccountByUsernameDomain: error checking database for account %s@%s: %w", username, domain, err)
|
||||
|
@@ -32,7 +32,7 @@ import (
|
||||
|
||||
// BookmarkCreate adds a bookmark for the requestingAccount, targeting the given status (no-op if bookmark already exists).
|
||||
func (p *Processor) BookmarkCreate(ctx context.Context, requestingAccount *gtsmodel.Account, targetStatusID string) (*apimodel.Status, gtserror.WithCode) {
|
||||
targetStatus, err := p.db.GetStatusByID(ctx, targetStatusID)
|
||||
targetStatus, err := p.state.DB.GetStatusByID(ctx, targetStatusID)
|
||||
if err != nil {
|
||||
return nil, gtserror.NewErrorNotFound(fmt.Errorf("error fetching status %s: %s", targetStatusID, err))
|
||||
}
|
||||
@@ -50,7 +50,7 @@ func (p *Processor) BookmarkCreate(ctx context.Context, requestingAccount *gtsmo
|
||||
// first check if the status is already bookmarked, if so we don't need to do anything
|
||||
newBookmark := true
|
||||
gtsBookmark := >smodel.StatusBookmark{}
|
||||
if err := p.db.GetWhere(ctx, []db.Where{{Key: "status_id", Value: targetStatus.ID}, {Key: "account_id", Value: requestingAccount.ID}}, gtsBookmark); err == nil {
|
||||
if err := p.state.DB.GetWhere(ctx, []db.Where{{Key: "status_id", Value: targetStatus.ID}, {Key: "account_id", Value: requestingAccount.ID}}, gtsBookmark); err == nil {
|
||||
// we already have a bookmark for this status
|
||||
newBookmark = false
|
||||
}
|
||||
@@ -67,7 +67,7 @@ func (p *Processor) BookmarkCreate(ctx context.Context, requestingAccount *gtsmo
|
||||
Status: targetStatus,
|
||||
}
|
||||
|
||||
if err := p.db.Put(ctx, gtsBookmark); err != nil {
|
||||
if err := p.state.DB.Put(ctx, gtsBookmark); err != nil {
|
||||
return nil, gtserror.NewErrorInternalError(fmt.Errorf("error putting bookmark in database: %s", err))
|
||||
}
|
||||
}
|
||||
@@ -83,7 +83,7 @@ func (p *Processor) BookmarkCreate(ctx context.Context, requestingAccount *gtsmo
|
||||
|
||||
// BookmarkRemove removes a bookmark for the requesting account, targeting the given status (no-op if bookmark doesn't exist).
|
||||
func (p *Processor) BookmarkRemove(ctx context.Context, requestingAccount *gtsmodel.Account, targetStatusID string) (*apimodel.Status, gtserror.WithCode) {
|
||||
targetStatus, err := p.db.GetStatusByID(ctx, targetStatusID)
|
||||
targetStatus, err := p.state.DB.GetStatusByID(ctx, targetStatusID)
|
||||
if err != nil {
|
||||
return nil, gtserror.NewErrorNotFound(fmt.Errorf("error fetching status %s: %s", targetStatusID, err))
|
||||
}
|
||||
@@ -101,13 +101,13 @@ func (p *Processor) BookmarkRemove(ctx context.Context, requestingAccount *gtsmo
|
||||
// first check if the status is actually bookmarked
|
||||
toUnbookmark := false
|
||||
gtsBookmark := >smodel.StatusBookmark{}
|
||||
if err := p.db.GetWhere(ctx, []db.Where{{Key: "status_id", Value: targetStatus.ID}, {Key: "account_id", Value: requestingAccount.ID}}, gtsBookmark); err == nil {
|
||||
if err := p.state.DB.GetWhere(ctx, []db.Where{{Key: "status_id", Value: targetStatus.ID}, {Key: "account_id", Value: requestingAccount.ID}}, gtsBookmark); err == nil {
|
||||
// we have a bookmark for this status
|
||||
toUnbookmark = true
|
||||
}
|
||||
|
||||
if toUnbookmark {
|
||||
if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "status_id", Value: targetStatus.ID}, {Key: "account_id", Value: requestingAccount.ID}}, gtsBookmark); err != nil {
|
||||
if err := p.state.DB.DeleteWhere(ctx, []db.Where{{Key: "status_id", Value: targetStatus.ID}, {Key: "account_id", Value: requestingAccount.ID}}, gtsBookmark); err != nil {
|
||||
return nil, gtserror.NewErrorInternalError(fmt.Errorf("error unfaveing status: %s", err))
|
||||
}
|
||||
}
|
||||
|
@@ -33,7 +33,7 @@ import (
|
||||
|
||||
// BoostCreate processes the boost/reblog of a given status, returning the newly-created boost if all is well.
|
||||
func (p *Processor) BoostCreate(ctx context.Context, requestingAccount *gtsmodel.Account, application *gtsmodel.Application, targetStatusID string) (*apimodel.Status, gtserror.WithCode) {
|
||||
targetStatus, err := p.db.GetStatusByID(ctx, targetStatusID)
|
||||
targetStatus, err := p.state.DB.GetStatusByID(ctx, targetStatusID)
|
||||
if err != nil {
|
||||
return nil, gtserror.NewErrorNotFound(fmt.Errorf("error fetching status %s: %s", targetStatusID, err))
|
||||
}
|
||||
@@ -47,7 +47,7 @@ func (p *Processor) BoostCreate(ctx context.Context, requestingAccount *gtsmodel
|
||||
// boost boosts, and it looks absolutely bizarre in the UI
|
||||
if targetStatus.BoostOfID != "" {
|
||||
if targetStatus.BoostOf == nil {
|
||||
b, err := p.db.GetStatusByID(ctx, targetStatus.BoostOfID)
|
||||
b, err := p.state.DB.GetStatusByID(ctx, targetStatus.BoostOfID)
|
||||
if err != nil {
|
||||
return nil, gtserror.NewErrorNotFound(fmt.Errorf("couldn't fetch boosted status %s", targetStatus.BoostOfID))
|
||||
}
|
||||
@@ -74,12 +74,12 @@ func (p *Processor) BoostCreate(ctx context.Context, requestingAccount *gtsmodel
|
||||
boostWrapperStatus.BoostOfAccount = targetStatus.Account
|
||||
|
||||
// put the boost in the database
|
||||
if err := p.db.PutStatus(ctx, boostWrapperStatus); err != nil {
|
||||
if err := p.state.DB.PutStatus(ctx, boostWrapperStatus); err != nil {
|
||||
return nil, gtserror.NewErrorInternalError(err)
|
||||
}
|
||||
|
||||
// send it back to the processor for async processing
|
||||
p.clientWorker.Queue(messages.FromClientAPI{
|
||||
p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{
|
||||
APObjectType: ap.ActivityAnnounce,
|
||||
APActivityType: ap.ActivityCreate,
|
||||
GTSModel: boostWrapperStatus,
|
||||
@@ -98,7 +98,7 @@ func (p *Processor) BoostCreate(ctx context.Context, requestingAccount *gtsmodel
|
||||
|
||||
// BoostRemove processes the unboost/unreblog of a given status, returning the status if all is well.
|
||||
func (p *Processor) BoostRemove(ctx context.Context, requestingAccount *gtsmodel.Account, application *gtsmodel.Application, targetStatusID string) (*apimodel.Status, gtserror.WithCode) {
|
||||
targetStatus, err := p.db.GetStatusByID(ctx, targetStatusID)
|
||||
targetStatus, err := p.state.DB.GetStatusByID(ctx, targetStatusID)
|
||||
if err != nil {
|
||||
return nil, gtserror.NewErrorNotFound(fmt.Errorf("error fetching status %s: %s", targetStatusID, err))
|
||||
}
|
||||
@@ -128,7 +128,7 @@ func (p *Processor) BoostRemove(ctx context.Context, requestingAccount *gtsmodel
|
||||
Value: requestingAccount.ID,
|
||||
},
|
||||
}
|
||||
err = p.db.GetWhere(ctx, where, gtsBoost)
|
||||
err = p.state.DB.GetWhere(ctx, where, gtsBoost)
|
||||
if err == nil {
|
||||
// we have a boost
|
||||
toUnboost = true
|
||||
@@ -151,7 +151,7 @@ func (p *Processor) BoostRemove(ctx context.Context, requestingAccount *gtsmodel
|
||||
gtsBoost.BoostOf.Account = targetStatus.Account
|
||||
|
||||
// send it back to the processor for async processing
|
||||
p.clientWorker.Queue(messages.FromClientAPI{
|
||||
p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{
|
||||
APObjectType: ap.ActivityAnnounce,
|
||||
APActivityType: ap.ActivityUndo,
|
||||
GTSModel: gtsBoost,
|
||||
@@ -170,7 +170,7 @@ func (p *Processor) BoostRemove(ctx context.Context, requestingAccount *gtsmodel
|
||||
|
||||
// StatusBoostedBy returns a slice of accounts that have boosted the given status, filtered according to privacy settings.
|
||||
func (p *Processor) StatusBoostedBy(ctx context.Context, requestingAccount *gtsmodel.Account, targetStatusID string) ([]*apimodel.Account, gtserror.WithCode) {
|
||||
targetStatus, err := p.db.GetStatusByID(ctx, targetStatusID)
|
||||
targetStatus, err := p.state.DB.GetStatusByID(ctx, targetStatusID)
|
||||
if err != nil {
|
||||
wrapped := fmt.Errorf("BoostedBy: error fetching status %s: %s", targetStatusID, err)
|
||||
if !errors.Is(err, db.ErrNoEntries) {
|
||||
@@ -181,7 +181,7 @@ func (p *Processor) StatusBoostedBy(ctx context.Context, requestingAccount *gtsm
|
||||
|
||||
if boostOfID := targetStatus.BoostOfID; boostOfID != "" {
|
||||
// the target status is a boost wrapper, redirect this request to the status it boosts
|
||||
boostedStatus, err := p.db.GetStatusByID(ctx, boostOfID)
|
||||
boostedStatus, err := p.state.DB.GetStatusByID(ctx, boostOfID)
|
||||
if err != nil {
|
||||
wrapped := fmt.Errorf("BoostedBy: error fetching status %s: %s", boostOfID, err)
|
||||
if !errors.Is(err, db.ErrNoEntries) {
|
||||
@@ -202,7 +202,7 @@ func (p *Processor) StatusBoostedBy(ctx context.Context, requestingAccount *gtsm
|
||||
return nil, gtserror.NewErrorNotFound(err)
|
||||
}
|
||||
|
||||
statusReblogs, err := p.db.GetStatusReblogs(ctx, targetStatus)
|
||||
statusReblogs, err := p.state.DB.GetStatusReblogs(ctx, targetStatus)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("BoostedBy: error seeing who boosted status: %s", err)
|
||||
return nil, gtserror.NewErrorNotFound(err)
|
||||
@@ -211,7 +211,7 @@ func (p *Processor) StatusBoostedBy(ctx context.Context, requestingAccount *gtsm
|
||||
// filter account IDs so the user doesn't see accounts they blocked or which blocked them
|
||||
accountIDs := make([]string, 0, len(statusReblogs))
|
||||
for _, s := range statusReblogs {
|
||||
blocked, err := p.db.IsBlocked(ctx, requestingAccount.ID, s.AccountID, true)
|
||||
blocked, err := p.state.DB.IsBlocked(ctx, requestingAccount.ID, s.AccountID, true)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("BoostedBy: error checking blocks: %s", err)
|
||||
return nil, gtserror.NewErrorNotFound(err)
|
||||
@@ -226,7 +226,7 @@ func (p *Processor) StatusBoostedBy(ctx context.Context, requestingAccount *gtsm
|
||||
// fetch accounts + create their API representations
|
||||
apiAccounts := make([]*apimodel.Account, 0, len(accountIDs))
|
||||
for _, accountID := range accountIDs {
|
||||
account, err := p.db.GetAccountByID(ctx, accountID)
|
||||
account, err := p.state.DB.GetAccountByID(ctx, accountID)
|
||||
if err != nil {
|
||||
wrapped := fmt.Errorf("BoostedBy: error fetching account %s: %s", accountID, err)
|
||||
if !errors.Is(err, db.ErrNoEntries) {
|
||||
|
@@ -61,11 +61,11 @@ func (p *Processor) Create(ctx context.Context, account *gtsmodel.Account, appli
|
||||
Text: form.Status,
|
||||
}
|
||||
|
||||
if errWithCode := processReplyToID(ctx, p.db, form, account.ID, newStatus); errWithCode != nil {
|
||||
if errWithCode := processReplyToID(ctx, p.state.DB, form, account.ID, newStatus); errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
|
||||
if errWithCode := processMediaIDs(ctx, p.db, form, account.ID, newStatus); errWithCode != nil {
|
||||
if errWithCode := processMediaIDs(ctx, p.state.DB, form, account.ID, newStatus); errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
|
||||
@@ -77,17 +77,17 @@ func (p *Processor) Create(ctx context.Context, account *gtsmodel.Account, appli
|
||||
return nil, gtserror.NewErrorInternalError(err)
|
||||
}
|
||||
|
||||
if err := processContent(ctx, p.db, p.formatter, p.parseMention, form, account.ID, newStatus); err != nil {
|
||||
if err := processContent(ctx, p.state.DB, p.formatter, p.parseMention, form, account.ID, newStatus); err != nil {
|
||||
return nil, gtserror.NewErrorInternalError(err)
|
||||
}
|
||||
|
||||
// put the new status in the database
|
||||
if err := p.db.PutStatus(ctx, newStatus); err != nil {
|
||||
if err := p.state.DB.PutStatus(ctx, newStatus); err != nil {
|
||||
return nil, gtserror.NewErrorInternalError(err)
|
||||
}
|
||||
|
||||
// send it back to the processor for async processing
|
||||
p.clientWorker.Queue(messages.FromClientAPI{
|
||||
p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{
|
||||
APObjectType: ap.ObjectNote,
|
||||
APActivityType: ap.ActivityCreate,
|
||||
GTSModel: newStatus,
|
||||
|
@@ -32,7 +32,7 @@ import (
|
||||
|
||||
// Delete processes the delete of a given status, returning the deleted status if the delete goes through.
|
||||
func (p *Processor) Delete(ctx context.Context, requestingAccount *gtsmodel.Account, targetStatusID string) (*apimodel.Status, gtserror.WithCode) {
|
||||
targetStatus, err := p.db.GetStatusByID(ctx, targetStatusID)
|
||||
targetStatus, err := p.state.DB.GetStatusByID(ctx, targetStatusID)
|
||||
if err != nil {
|
||||
return nil, gtserror.NewErrorNotFound(fmt.Errorf("error fetching status %s: %s", targetStatusID, err))
|
||||
}
|
||||
@@ -50,7 +50,7 @@ func (p *Processor) Delete(ctx context.Context, requestingAccount *gtsmodel.Acco
|
||||
}
|
||||
|
||||
// send the status back to the processor for async processing
|
||||
p.clientWorker.Queue(messages.FromClientAPI{
|
||||
p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{
|
||||
APObjectType: ap.ObjectNote,
|
||||
APActivityType: ap.ActivityDelete,
|
||||
GTSModel: targetStatus,
|
||||
|
@@ -35,7 +35,7 @@ import (
|
||||
|
||||
// FaveCreate processes the faving of a given status, returning the updated status if the fave goes through.
|
||||
func (p *Processor) FaveCreate(ctx context.Context, requestingAccount *gtsmodel.Account, targetStatusID string) (*apimodel.Status, gtserror.WithCode) {
|
||||
targetStatus, err := p.db.GetStatusByID(ctx, targetStatusID)
|
||||
targetStatus, err := p.state.DB.GetStatusByID(ctx, targetStatusID)
|
||||
if err != nil {
|
||||
return nil, gtserror.NewErrorNotFound(fmt.Errorf("error fetching status %s: %s", targetStatusID, err))
|
||||
}
|
||||
@@ -57,7 +57,7 @@ func (p *Processor) FaveCreate(ctx context.Context, requestingAccount *gtsmodel.
|
||||
// first check if the status is already faved, if so we don't need to do anything
|
||||
newFave := true
|
||||
gtsFave := >smodel.StatusFave{}
|
||||
if err := p.db.GetWhere(ctx, []db.Where{{Key: "status_id", Value: targetStatus.ID}, {Key: "account_id", Value: requestingAccount.ID}}, gtsFave); err == nil {
|
||||
if err := p.state.DB.GetWhere(ctx, []db.Where{{Key: "status_id", Value: targetStatus.ID}, {Key: "account_id", Value: requestingAccount.ID}}, gtsFave); err == nil {
|
||||
// we already have a fave for this status
|
||||
newFave = false
|
||||
}
|
||||
@@ -77,12 +77,12 @@ func (p *Processor) FaveCreate(ctx context.Context, requestingAccount *gtsmodel.
|
||||
URI: uris.GenerateURIForLike(requestingAccount.Username, thisFaveID),
|
||||
}
|
||||
|
||||
if err := p.db.Put(ctx, gtsFave); err != nil {
|
||||
if err := p.state.DB.Put(ctx, gtsFave); err != nil {
|
||||
return nil, gtserror.NewErrorInternalError(fmt.Errorf("error putting fave in database: %s", err))
|
||||
}
|
||||
|
||||
// send it back to the processor for async processing
|
||||
p.clientWorker.Queue(messages.FromClientAPI{
|
||||
p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{
|
||||
APObjectType: ap.ActivityLike,
|
||||
APActivityType: ap.ActivityCreate,
|
||||
GTSModel: gtsFave,
|
||||
@@ -102,7 +102,7 @@ func (p *Processor) FaveCreate(ctx context.Context, requestingAccount *gtsmodel.
|
||||
|
||||
// FaveRemove processes the unfaving of a given status, returning the updated status if the fave goes through.
|
||||
func (p *Processor) FaveRemove(ctx context.Context, requestingAccount *gtsmodel.Account, targetStatusID string) (*apimodel.Status, gtserror.WithCode) {
|
||||
targetStatus, err := p.db.GetStatusByID(ctx, targetStatusID)
|
||||
targetStatus, err := p.state.DB.GetStatusByID(ctx, targetStatusID)
|
||||
if err != nil {
|
||||
return nil, gtserror.NewErrorNotFound(fmt.Errorf("error fetching status %s: %s", targetStatusID, err))
|
||||
}
|
||||
@@ -122,7 +122,7 @@ func (p *Processor) FaveRemove(ctx context.Context, requestingAccount *gtsmodel.
|
||||
var toUnfave bool
|
||||
|
||||
gtsFave := >smodel.StatusFave{}
|
||||
err = p.db.GetWhere(ctx, []db.Where{{Key: "status_id", Value: targetStatus.ID}, {Key: "account_id", Value: requestingAccount.ID}}, gtsFave)
|
||||
err = p.state.DB.GetWhere(ctx, []db.Where{{Key: "status_id", Value: targetStatus.ID}, {Key: "account_id", Value: requestingAccount.ID}}, gtsFave)
|
||||
if err == nil {
|
||||
// we have a fave
|
||||
toUnfave = true
|
||||
@@ -138,12 +138,12 @@ func (p *Processor) FaveRemove(ctx context.Context, requestingAccount *gtsmodel.
|
||||
|
||||
if toUnfave {
|
||||
// we had a fave, so take some action to get rid of it
|
||||
if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "status_id", Value: targetStatus.ID}, {Key: "account_id", Value: requestingAccount.ID}}, gtsFave); err != nil {
|
||||
if err := p.state.DB.DeleteWhere(ctx, []db.Where{{Key: "status_id", Value: targetStatus.ID}, {Key: "account_id", Value: requestingAccount.ID}}, gtsFave); err != nil {
|
||||
return nil, gtserror.NewErrorInternalError(fmt.Errorf("error unfaveing status: %s", err))
|
||||
}
|
||||
|
||||
// send it back to the processor for async processing
|
||||
p.clientWorker.Queue(messages.FromClientAPI{
|
||||
p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{
|
||||
APObjectType: ap.ActivityLike,
|
||||
APActivityType: ap.ActivityUndo,
|
||||
GTSModel: gtsFave,
|
||||
@@ -162,7 +162,7 @@ func (p *Processor) FaveRemove(ctx context.Context, requestingAccount *gtsmodel.
|
||||
|
||||
// FavedBy returns a slice of accounts that have liked the given status, filtered according to privacy settings.
|
||||
func (p *Processor) FavedBy(ctx context.Context, requestingAccount *gtsmodel.Account, targetStatusID string) ([]*apimodel.Account, gtserror.WithCode) {
|
||||
targetStatus, err := p.db.GetStatusByID(ctx, targetStatusID)
|
||||
targetStatus, err := p.state.DB.GetStatusByID(ctx, targetStatusID)
|
||||
if err != nil {
|
||||
return nil, gtserror.NewErrorNotFound(fmt.Errorf("error fetching status %s: %s", targetStatusID, err))
|
||||
}
|
||||
@@ -178,7 +178,7 @@ func (p *Processor) FavedBy(ctx context.Context, requestingAccount *gtsmodel.Acc
|
||||
return nil, gtserror.NewErrorNotFound(errors.New("status is not visible"))
|
||||
}
|
||||
|
||||
statusFaves, err := p.db.GetStatusFaves(ctx, targetStatus)
|
||||
statusFaves, err := p.state.DB.GetStatusFaves(ctx, targetStatus)
|
||||
if err != nil {
|
||||
return nil, gtserror.NewErrorNotFound(fmt.Errorf("error seeing who faved status: %s", err))
|
||||
}
|
||||
@@ -186,7 +186,7 @@ func (p *Processor) FavedBy(ctx context.Context, requestingAccount *gtsmodel.Acc
|
||||
// filter the list so the user doesn't see accounts they blocked or which blocked them
|
||||
filteredAccounts := []*gtsmodel.Account{}
|
||||
for _, fave := range statusFaves {
|
||||
blocked, err := p.db.IsBlocked(ctx, requestingAccount.ID, fave.AccountID, true)
|
||||
blocked, err := p.state.DB.IsBlocked(ctx, requestingAccount.ID, fave.AccountID, true)
|
||||
if err != nil {
|
||||
return nil, gtserror.NewErrorInternalError(fmt.Errorf("error checking blocks: %s", err))
|
||||
}
|
||||
|
@@ -31,7 +31,7 @@ import (
|
||||
|
||||
// Get gets the given status, taking account of privacy settings and blocks etc.
|
||||
func (p *Processor) Get(ctx context.Context, requestingAccount *gtsmodel.Account, targetStatusID string) (*apimodel.Status, gtserror.WithCode) {
|
||||
targetStatus, err := p.db.GetStatusByID(ctx, targetStatusID)
|
||||
targetStatus, err := p.state.DB.GetStatusByID(ctx, targetStatusID)
|
||||
if err != nil {
|
||||
return nil, gtserror.NewErrorNotFound(fmt.Errorf("error fetching status %s: %s", targetStatusID, err))
|
||||
}
|
||||
@@ -57,7 +57,7 @@ func (p *Processor) Get(ctx context.Context, requestingAccount *gtsmodel.Account
|
||||
|
||||
// ContextGet returns the context (previous and following posts) from the given status ID.
|
||||
func (p *Processor) ContextGet(ctx context.Context, requestingAccount *gtsmodel.Account, targetStatusID string) (*apimodel.Context, gtserror.WithCode) {
|
||||
targetStatus, err := p.db.GetStatusByID(ctx, targetStatusID)
|
||||
targetStatus, err := p.state.DB.GetStatusByID(ctx, targetStatusID)
|
||||
if err != nil {
|
||||
return nil, gtserror.NewErrorNotFound(fmt.Errorf("error fetching status %s: %s", targetStatusID, err))
|
||||
}
|
||||
@@ -78,7 +78,7 @@ func (p *Processor) ContextGet(ctx context.Context, requestingAccount *gtsmodel.
|
||||
Descendants: []apimodel.Status{},
|
||||
}
|
||||
|
||||
parents, err := p.db.GetStatusParents(ctx, targetStatus, false)
|
||||
parents, err := p.state.DB.GetStatusParents(ctx, targetStatus, false)
|
||||
if err != nil {
|
||||
return nil, gtserror.NewErrorInternalError(err)
|
||||
}
|
||||
@@ -96,7 +96,7 @@ func (p *Processor) ContextGet(ctx context.Context, requestingAccount *gtsmodel.
|
||||
return context.Ancestors[i].ID < context.Ancestors[j].ID
|
||||
})
|
||||
|
||||
children, err := p.db.GetStatusChildren(ctx, targetStatus, false, "")
|
||||
children, err := p.state.DB.GetStatusChildren(ctx, targetStatus, false, "")
|
||||
if err != nil {
|
||||
return nil, gtserror.NewErrorInternalError(err)
|
||||
}
|
||||
|
@@ -39,7 +39,7 @@ const allowedPinnedCount = 10
|
||||
// - Status is public, unlisted, or followers-only.
|
||||
// - Status is not a boost.
|
||||
func (p *Processor) getPinnableStatus(ctx context.Context, targetStatusID string, requestingAccountID string) (*gtsmodel.Status, gtserror.WithCode) {
|
||||
targetStatus, err := p.db.GetStatusByID(ctx, targetStatusID)
|
||||
targetStatus, err := p.state.DB.GetStatusByID(ctx, targetStatusID)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("error fetching status %s: %w", targetStatusID, err)
|
||||
return nil, gtserror.NewErrorNotFound(err)
|
||||
@@ -84,7 +84,7 @@ func (p *Processor) PinCreate(ctx context.Context, requestingAccount *gtsmodel.A
|
||||
return nil, gtserror.NewErrorUnprocessableEntity(err, err.Error())
|
||||
}
|
||||
|
||||
pinnedCount, err := p.db.CountAccountPinned(ctx, requestingAccount.ID)
|
||||
pinnedCount, err := p.state.DB.CountAccountPinned(ctx, requestingAccount.ID)
|
||||
if err != nil {
|
||||
return nil, gtserror.NewErrorInternalError(fmt.Errorf("error checking number of pinned statuses: %w", err))
|
||||
}
|
||||
@@ -95,7 +95,7 @@ func (p *Processor) PinCreate(ctx context.Context, requestingAccount *gtsmodel.A
|
||||
}
|
||||
|
||||
targetStatus.PinnedAt = time.Now()
|
||||
if err := p.db.UpdateStatus(ctx, targetStatus, "pinned_at"); err != nil {
|
||||
if err := p.state.DB.UpdateStatus(ctx, targetStatus, "pinned_at"); err != nil {
|
||||
return nil, gtserror.NewErrorInternalError(fmt.Errorf("db error pinning status: %w", err))
|
||||
}
|
||||
|
||||
@@ -126,7 +126,7 @@ func (p *Processor) PinRemove(ctx context.Context, requestingAccount *gtsmodel.A
|
||||
|
||||
if targetStatus.PinnedAt.IsZero() {
|
||||
targetStatus.PinnedAt = time.Time{}
|
||||
if err := p.db.UpdateStatus(ctx, targetStatus, "pinned_at"); err != nil {
|
||||
if err := p.state.DB.UpdateStatus(ctx, targetStatus, "pinned_at"); err != nil {
|
||||
return nil, gtserror.NewErrorInternalError(fmt.Errorf("db error unpinning status: %w", err))
|
||||
}
|
||||
}
|
||||
|
@@ -19,32 +19,28 @@
|
||||
package status
|
||||
|
||||
import (
|
||||
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/messages"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/state"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/text"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/visibility"
|
||||
)
|
||||
|
||||
type Processor struct {
|
||||
state *state.State
|
||||
tc typeutils.TypeConverter
|
||||
db db.DB
|
||||
filter visibility.Filter
|
||||
formatter text.Formatter
|
||||
clientWorker *concurrency.WorkerPool[messages.FromClientAPI]
|
||||
parseMention gtsmodel.ParseMentionFunc
|
||||
}
|
||||
|
||||
// New returns a new status processor.
|
||||
func New(db db.DB, tc typeutils.TypeConverter, clientWorker *concurrency.WorkerPool[messages.FromClientAPI], parseMention gtsmodel.ParseMentionFunc) Processor {
|
||||
func New(state *state.State, tc typeutils.TypeConverter, parseMention gtsmodel.ParseMentionFunc) Processor {
|
||||
return Processor{
|
||||
state: state,
|
||||
tc: tc,
|
||||
db: db,
|
||||
filter: visibility.NewFilter(db),
|
||||
formatter: text.NewFormatter(db),
|
||||
clientWorker: clientWorker,
|
||||
filter: visibility.NewFilter(state.DB),
|
||||
formatter: text.NewFormatter(state.DB),
|
||||
parseMention: parseMention,
|
||||
}
|
||||
}
|
||||
|
@@ -19,17 +19,14 @@
|
||||
package status_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/stretchr/testify/suite"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/concurrency"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/federation"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/media"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/messages"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/processing"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/processing/status"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/state"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/storage"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/transport"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
|
||||
@@ -42,9 +39,9 @@ type StatusStandardTestSuite struct {
|
||||
typeConverter typeutils.TypeConverter
|
||||
tc transport.Controller
|
||||
storage *storage.Driver
|
||||
state state.State
|
||||
mediaManager media.Manager
|
||||
federator federation.Federator
|
||||
clientWorker *concurrency.WorkerPool[messages.FromClientAPI]
|
||||
|
||||
// standard suite models
|
||||
testTokens map[string]*gtsmodel.Token
|
||||
@@ -74,21 +71,22 @@ func (suite *StatusStandardTestSuite) SetupSuite() {
|
||||
}
|
||||
|
||||
func (suite *StatusStandardTestSuite) SetupTest() {
|
||||
suite.state.Caches.Init()
|
||||
testrig.StartWorkers(&suite.state)
|
||||
|
||||
testrig.InitTestConfig()
|
||||
testrig.InitTestLog()
|
||||
|
||||
fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
|
||||
|
||||
suite.db = testrig.NewTestDB()
|
||||
suite.db = testrig.NewTestDB(&suite.state)
|
||||
suite.typeConverter = testrig.NewTestTypeConverter(suite.db)
|
||||
suite.clientWorker = concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
|
||||
suite.tc = testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../testrig/media"), suite.db, fedWorker)
|
||||
suite.state.DB = suite.db
|
||||
|
||||
suite.tc = testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../testrig/media"))
|
||||
suite.storage = testrig.NewInMemoryStorage()
|
||||
suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage)
|
||||
suite.federator = testrig.NewTestFederator(suite.db, suite.tc, suite.storage, suite.mediaManager, fedWorker)
|
||||
suite.status = status.New(suite.db, suite.typeConverter, suite.clientWorker, processing.GetParseMentionFunc(suite.db, suite.federator))
|
||||
suite.clientWorker.SetProcessor(func(ctx context.Context, msg messages.FromClientAPI) error { return nil })
|
||||
suite.NoError(suite.clientWorker.Start())
|
||||
suite.state.Storage = suite.storage
|
||||
suite.mediaManager = testrig.NewTestMediaManager(&suite.state)
|
||||
suite.federator = testrig.NewTestFederator(&suite.state, suite.tc, suite.mediaManager)
|
||||
suite.status = status.New(&suite.state, suite.typeConverter, processing.GetParseMentionFunc(suite.db, suite.federator))
|
||||
|
||||
testrig.StandardDBSetup(suite.db, suite.testAccounts)
|
||||
testrig.StandardStorageSetup(suite.storage, "../../../testrig/media")
|
||||
@@ -97,4 +95,5 @@ func (suite *StatusStandardTestSuite) SetupTest() {
|
||||
func (suite *StatusStandardTestSuite) TearDownTest() {
|
||||
testrig.StandardDBTeardown(suite.db)
|
||||
testrig.StandardStorageTeardown(suite.storage)
|
||||
testrig.StopWorkers(&suite.state)
|
||||
}
|
||||
|
@@ -173,7 +173,7 @@ func (p *Processor) HomeTimelineGet(ctx context.Context, authed *oauth.Auth, max
|
||||
}
|
||||
|
||||
func (p *Processor) PublicTimelineGet(ctx context.Context, authed *oauth.Auth, maxID string, sinceID string, minID string, limit int, local bool) (*apimodel.PageableResponse, gtserror.WithCode) {
|
||||
statuses, err := p.db.GetPublicTimeline(ctx, maxID, sinceID, minID, limit, local)
|
||||
statuses, err := p.state.DB.GetPublicTimeline(ctx, maxID, sinceID, minID, limit, local)
|
||||
if err != nil {
|
||||
if err == db.ErrNoEntries {
|
||||
// there are just no entries left
|
||||
@@ -218,7 +218,7 @@ func (p *Processor) PublicTimelineGet(ctx context.Context, authed *oauth.Auth, m
|
||||
}
|
||||
|
||||
func (p *Processor) FavedTimelineGet(ctx context.Context, authed *oauth.Auth, maxID string, minID string, limit int) (*apimodel.PageableResponse, gtserror.WithCode) {
|
||||
statuses, nextMaxID, prevMinID, err := p.db.GetFavedTimeline(ctx, authed.Account.ID, maxID, minID, limit)
|
||||
statuses, nextMaxID, prevMinID, err := p.state.DB.GetFavedTimeline(ctx, authed.Account.ID, maxID, minID, limit)
|
||||
if err != nil {
|
||||
if err == db.ErrNoEntries {
|
||||
// there are just no entries left
|
||||
@@ -255,7 +255,7 @@ func (p *Processor) filterPublicStatuses(ctx context.Context, authed *oauth.Auth
|
||||
apiStatuses := []*apimodel.Status{}
|
||||
for _, s := range statuses {
|
||||
targetAccount := >smodel.Account{}
|
||||
if err := p.db.GetByID(ctx, s.AccountID, targetAccount); err != nil {
|
||||
if err := p.state.DB.GetByID(ctx, s.AccountID, targetAccount); err != nil {
|
||||
if err == db.ErrNoEntries {
|
||||
log.Debugf(ctx, "skipping status %s because account %s can't be found in the db", s.ID, s.AccountID)
|
||||
continue
|
||||
@@ -288,7 +288,7 @@ func (p *Processor) filterFavedStatuses(ctx context.Context, authed *oauth.Auth,
|
||||
apiStatuses := []*apimodel.Status{}
|
||||
for _, s := range statuses {
|
||||
targetAccount := >smodel.Account{}
|
||||
if err := p.db.GetByID(ctx, s.AccountID, targetAccount); err != nil {
|
||||
if err := p.state.DB.GetByID(ctx, s.AccountID, targetAccount); err != nil {
|
||||
if err == db.ErrNoEntries {
|
||||
log.Debugf(ctx, "skipping status %s because account %s can't be found in the db", s.ID, s.AccountID)
|
||||
continue
|
||||
|
@@ -41,7 +41,7 @@ func (p *Processor) Authorize(ctx context.Context, accessToken string) (*gtsmode
|
||||
return nil, gtserror.NewErrorUnauthorized(err)
|
||||
}
|
||||
|
||||
user, err := p.db.GetUserByID(ctx, uid)
|
||||
user, err := p.state.DB.GetUserByID(ctx, uid)
|
||||
if err != nil {
|
||||
if err == db.ErrNoEntries {
|
||||
err := fmt.Errorf("no user found for validated uid %s", uid)
|
||||
@@ -50,7 +50,7 @@ func (p *Processor) Authorize(ctx context.Context, accessToken string) (*gtsmode
|
||||
return nil, gtserror.NewErrorInternalError(err)
|
||||
}
|
||||
|
||||
acct, err := p.db.GetAccountByID(ctx, user.AccountID)
|
||||
acct, err := p.state.DB.GetAccountByID(ctx, user.AccountID)
|
||||
if err != nil {
|
||||
if err == db.ErrNoEntries {
|
||||
err := fmt.Errorf("no account found for validated uid %s", uid)
|
||||
|
@@ -22,22 +22,21 @@ import (
|
||||
"errors"
|
||||
"sync"
|
||||
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/oauth"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/state"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/stream"
|
||||
)
|
||||
|
||||
type Processor struct {
|
||||
db db.DB
|
||||
state *state.State
|
||||
oauthServer oauth.Server
|
||||
streamMap *sync.Map
|
||||
streamMap sync.Map
|
||||
}
|
||||
|
||||
func New(db db.DB, oauthServer oauth.Server) Processor {
|
||||
func New(state *state.State, oauthServer oauth.Server) Processor {
|
||||
return Processor{
|
||||
db: db,
|
||||
state: state,
|
||||
oauthServer: oauthServer,
|
||||
streamMap: &sync.Map{},
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -24,6 +24,7 @@ import (
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/oauth"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/processing/stream"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/state"
|
||||
"github.com/superseriousbusiness/gotosocial/testrig"
|
||||
)
|
||||
|
||||
@@ -33,19 +34,23 @@ type StreamTestSuite struct {
|
||||
testTokens map[string]*gtsmodel.Token
|
||||
db db.DB
|
||||
oauthServer oauth.Server
|
||||
state state.State
|
||||
|
||||
streamProcessor stream.Processor
|
||||
}
|
||||
|
||||
func (suite *StreamTestSuite) SetupTest() {
|
||||
suite.state.Caches.Init()
|
||||
|
||||
testrig.InitTestLog()
|
||||
testrig.InitTestConfig()
|
||||
|
||||
suite.testAccounts = testrig.NewTestAccounts()
|
||||
suite.testTokens = testrig.NewTestTokens()
|
||||
suite.db = testrig.NewTestDB()
|
||||
suite.db = testrig.NewTestDB(&suite.state)
|
||||
suite.state.DB = suite.db
|
||||
suite.oauthServer = testrig.NewTestOauthServer(suite.db)
|
||||
suite.streamProcessor = stream.New(suite.db, suite.oauthServer)
|
||||
suite.streamProcessor = stream.New(&suite.state, suite.oauthServer)
|
||||
|
||||
testrig.StandardDBSetup(suite.db, suite.testAccounts)
|
||||
}
|
||||
|
@@ -56,7 +56,7 @@ func (p *Processor) EmailSendConfirmation(ctx context.Context, user *gtsmodel.Us
|
||||
// pull our instance entry from the database so we can greet the user nicely in the email
|
||||
instance := >smodel.Instance{}
|
||||
host := config.GetHost()
|
||||
if err := p.db.GetWhere(ctx, []db.Where{{Key: "domain", Value: host}}, instance); err != nil {
|
||||
if err := p.state.DB.GetWhere(ctx, []db.Where{{Key: "domain", Value: host}}, instance); err != nil {
|
||||
return fmt.Errorf("SendConfirmEmail: error getting instance: %s", err)
|
||||
}
|
||||
|
||||
@@ -78,7 +78,7 @@ func (p *Processor) EmailSendConfirmation(ctx context.Context, user *gtsmodel.Us
|
||||
user.LastEmailedAt = time.Now()
|
||||
user.UpdatedAt = time.Now()
|
||||
|
||||
if err := p.db.UpdateByID(ctx, user, user.ID, updatingColumns...); err != nil {
|
||||
if err := p.state.DB.UpdateByID(ctx, user, user.ID, updatingColumns...); err != nil {
|
||||
return fmt.Errorf("SendConfirmEmail: error updating user entry after email sent: %s", err)
|
||||
}
|
||||
|
||||
@@ -92,7 +92,7 @@ func (p *Processor) EmailConfirm(ctx context.Context, token string) (*gtsmodel.U
|
||||
return nil, gtserror.NewErrorNotFound(errors.New("no token provided"))
|
||||
}
|
||||
|
||||
user, err := p.db.GetUserByConfirmationToken(ctx, token)
|
||||
user, err := p.state.DB.GetUserByConfirmationToken(ctx, token)
|
||||
if err != nil {
|
||||
if err == db.ErrNoEntries {
|
||||
return nil, gtserror.NewErrorNotFound(err)
|
||||
@@ -101,7 +101,7 @@ func (p *Processor) EmailConfirm(ctx context.Context, token string) (*gtsmodel.U
|
||||
}
|
||||
|
||||
if user.Account == nil {
|
||||
a, err := p.db.GetAccountByID(ctx, user.AccountID)
|
||||
a, err := p.state.DB.GetAccountByID(ctx, user.AccountID)
|
||||
if err != nil {
|
||||
return nil, gtserror.NewErrorNotFound(err)
|
||||
}
|
||||
@@ -129,7 +129,7 @@ func (p *Processor) EmailConfirm(ctx context.Context, token string) (*gtsmodel.U
|
||||
user.ConfirmationToken = ""
|
||||
user.UpdatedAt = time.Now()
|
||||
|
||||
if err := p.db.UpdateByID(ctx, user, user.ID, updatingColumns...); err != nil {
|
||||
if err := p.state.DB.UpdateByID(ctx, user, user.ID, updatingColumns...); err != nil {
|
||||
return nil, gtserror.NewErrorInternalError(err)
|
||||
}
|
||||
|
||||
|
@@ -44,7 +44,7 @@ func (p *Processor) PasswordChange(ctx context.Context, user *gtsmodel.User, old
|
||||
|
||||
user.EncryptedPassword = string(newPasswordHash)
|
||||
|
||||
if err := p.db.UpdateUser(ctx, user, "encrypted_password"); err != nil {
|
||||
if err := p.state.DB.UpdateUser(ctx, user, "encrypted_password"); err != nil {
|
||||
return gtserror.NewErrorInternalError(err)
|
||||
}
|
||||
|
||||
|
@@ -19,19 +19,19 @@
|
||||
package user
|
||||
|
||||
import (
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/email"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/state"
|
||||
)
|
||||
|
||||
type Processor struct {
|
||||
state *state.State
|
||||
emailSender email.Sender
|
||||
db db.DB
|
||||
}
|
||||
|
||||
// New returns a new user processor
|
||||
func New(db db.DB, emailSender email.Sender) Processor {
|
||||
func New(state *state.State, emailSender email.Sender) Processor {
|
||||
return Processor{
|
||||
state: state,
|
||||
emailSender: emailSender,
|
||||
db: db,
|
||||
}
|
||||
}
|
||||
|
@@ -24,6 +24,7 @@ import (
|
||||
"github.com/superseriousbusiness/gotosocial/internal/email"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/processing/user"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/state"
|
||||
"github.com/superseriousbusiness/gotosocial/testrig"
|
||||
)
|
||||
|
||||
@@ -31,6 +32,7 @@ type UserStandardTestSuite struct {
|
||||
suite.Suite
|
||||
emailSender email.Sender
|
||||
db db.DB
|
||||
state state.State
|
||||
|
||||
testUsers map[string]*gtsmodel.User
|
||||
|
||||
@@ -40,15 +42,19 @@ type UserStandardTestSuite struct {
|
||||
}
|
||||
|
||||
func (suite *UserStandardTestSuite) SetupTest() {
|
||||
suite.state.Caches.Init()
|
||||
|
||||
testrig.InitTestConfig()
|
||||
testrig.InitTestLog()
|
||||
|
||||
suite.db = testrig.NewTestDB()
|
||||
suite.db = testrig.NewTestDB(&suite.state)
|
||||
suite.state.DB = suite.db
|
||||
|
||||
suite.sentEmails = make(map[string]string)
|
||||
suite.emailSender = testrig.NewEmailSender("../../../web/template/", suite.sentEmails)
|
||||
suite.testUsers = testrig.NewTestUsers()
|
||||
|
||||
suite.user = user.New(suite.db, suite.emailSender)
|
||||
suite.user = user.New(&suite.state, suite.emailSender)
|
||||
|
||||
testrig.StandardDBSetup(suite.db, nil)
|
||||
}
|
||||
|
Reference in New Issue
Block a user