mirror of
				https://github.com/superseriousbusiness/gotosocial
				synced 2025-06-05 21:59:39 +02:00 
			
		
		
		
	[performance] cached oauth database types (#2838)
* update token + client code to use struct caches * add code comments * slight tweak to default mem ratios * fix envparsing * add appropriate invalidate hooks * update the tokenstore sweeping function to rely on caches * update to use PutClient() * add ClientID to list of token struct indices
This commit is contained in:
		| @@ -98,8 +98,8 @@ var Start action.GTSAction = func(ctx context.Context) error { | ||||
| 	testrig.StandardStorageSetup(state.Storage, "./testrig/media") | ||||
|  | ||||
| 	// Initialize workers. | ||||
| 	state.Workers.Start() | ||||
| 	defer state.Workers.Stop() | ||||
| 	testrig.StartNoopWorkers(&state) | ||||
| 	defer testrig.StopWorkers(&state) | ||||
|  | ||||
| 	// build backend handlers | ||||
| 	transportController := testrig.NewTestTransportController(&state, testrig.NewMockHTTPClient(func(req *http.Request) (*http.Response, error) { | ||||
|   | ||||
| @@ -49,7 +49,7 @@ func (m *Module) TokenPOSTHandler(c *gin.Context) { | ||||
|  | ||||
| 	form := &tokenRequestForm{} | ||||
| 	if err := c.ShouldBind(form); err != nil { | ||||
| 		apiutil.OAuthErrorHandler(c, gtserror.NewErrorBadRequest(oauth.InvalidRequest(), err.Error())) | ||||
| 		apiutil.OAuthErrorHandler(c, gtserror.NewErrorBadRequest(oauth.ErrInvalidRequest, err.Error())) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| @@ -98,7 +98,7 @@ func (m *Module) TokenPOSTHandler(c *gin.Context) { | ||||
| 	} | ||||
|  | ||||
| 	if len(help) != 0 { | ||||
| 		apiutil.OAuthErrorHandler(c, gtserror.NewErrorBadRequest(oauth.InvalidRequest(), help...)) | ||||
| 		apiutil.OAuthErrorHandler(c, gtserror.NewErrorBadRequest(oauth.ErrInvalidRequest, help...)) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
|   | ||||
							
								
								
									
										4
									
								
								internal/cache/cache.go
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										4
									
								
								internal/cache/cache.go
									
									
									
									
										vendored
									
									
								
							| @@ -59,6 +59,7 @@ func (c *Caches) Init() { | ||||
| 	c.initBlock() | ||||
| 	c.initBlockIDs() | ||||
| 	c.initBoostOfIDs() | ||||
| 	c.initClient() | ||||
| 	c.initDomainAllow() | ||||
| 	c.initDomainBlock() | ||||
| 	c.initEmoji() | ||||
| @@ -85,9 +86,10 @@ func (c *Caches) Init() { | ||||
| 	c.initReport() | ||||
| 	c.initStatus() | ||||
| 	c.initStatusFave() | ||||
| 	c.initStatusFaveIDs() | ||||
| 	c.initTag() | ||||
| 	c.initThreadMute() | ||||
| 	c.initStatusFaveIDs() | ||||
| 	c.initToken() | ||||
| 	c.initTombstone() | ||||
| 	c.initUser() | ||||
| 	c.initWebfinger() | ||||
|   | ||||
							
								
								
									
										70
									
								
								internal/cache/db.go
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										70
									
								
								internal/cache/db.go
									
									
									
									
										vendored
									
									
								
							| @@ -58,6 +58,9 @@ type GTSCaches struct { | ||||
| 	// BoostOfIDs provides access to the boost of IDs list database cache. | ||||
| 	BoostOfIDs SliceCache[string] | ||||
|  | ||||
| 	// Client provides access to the gtsmodel Client database cache. | ||||
| 	Client StructCache[*gtsmodel.Client] | ||||
|  | ||||
| 	// DomainAllow provides access to the domain allow database cache. | ||||
| 	DomainAllow *domain.Cache | ||||
|  | ||||
| @@ -150,6 +153,9 @@ type GTSCaches struct { | ||||
| 	// Tag provides access to the gtsmodel Tag database cache. | ||||
| 	Tag StructCache[*gtsmodel.Tag] | ||||
|  | ||||
| 	// Token provides access to the gtsmodel Token database cache. | ||||
| 	Token StructCache[*gtsmodel.Token] | ||||
|  | ||||
| 	// Tombstone provides access to the gtsmodel Tombstone database cache. | ||||
| 	Tombstone StructCache[*gtsmodel.Tombstone] | ||||
|  | ||||
| @@ -309,9 +315,10 @@ func (c *Caches) initApplication() { | ||||
| 			{Fields: "ID"}, | ||||
| 			{Fields: "ClientID"}, | ||||
| 		}, | ||||
| 		MaxSize:   cap, | ||||
| 		IgnoreErr: ignoreErrors, | ||||
| 		Copy:      copyF, | ||||
| 		MaxSize:    cap, | ||||
| 		IgnoreErr:  ignoreErrors, | ||||
| 		Copy:       copyF, | ||||
| 		Invalidate: c.OnInvalidateApplication, | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| @@ -374,6 +381,32 @@ func (c *Caches) initBoostOfIDs() { | ||||
| 	c.GTS.BoostOfIDs.Init(0, cap) | ||||
| } | ||||
|  | ||||
| func (c *Caches) initClient() { | ||||
| 	// Calculate maximum cache size. | ||||
| 	cap := calculateResultCacheMax( | ||||
| 		sizeofClient(), // model in-mem size. | ||||
| 		config.GetCacheClientMemRatio(), | ||||
| 	) | ||||
|  | ||||
| 	log.Infof(nil, "cache size = %d", cap) | ||||
|  | ||||
| 	copyF := func(c1 *gtsmodel.Client) *gtsmodel.Client { | ||||
| 		c2 := new(gtsmodel.Client) | ||||
| 		*c2 = *c1 | ||||
| 		return c2 | ||||
| 	} | ||||
|  | ||||
| 	c.GTS.Client.Init(structr.CacheConfig[*gtsmodel.Client]{ | ||||
| 		Indices: []structr.IndexConfig{ | ||||
| 			{Fields: "ID"}, | ||||
| 		}, | ||||
| 		MaxSize:    cap, | ||||
| 		IgnoreErr:  ignoreErrors, | ||||
| 		Copy:       copyF, | ||||
| 		Invalidate: c.OnInvalidateClient, | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| func (c *Caches) initDomainAllow() { | ||||
| 	c.GTS.DomainAllow = new(domain.Cache) | ||||
| } | ||||
| @@ -1135,7 +1168,7 @@ func (c *Caches) initTag() { | ||||
|  | ||||
| func (c *Caches) initThreadMute() { | ||||
| 	cap := calculateResultCacheMax( | ||||
| 		sizeOfThreadMute(), // model in-mem size. | ||||
| 		sizeofThreadMute(), // model in-mem size. | ||||
| 		config.GetCacheThreadMuteMemRatio(), | ||||
| 	) | ||||
|  | ||||
| @@ -1160,6 +1193,35 @@ func (c *Caches) initThreadMute() { | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| func (c *Caches) initToken() { | ||||
| 	// Calculate maximum cache size. | ||||
| 	cap := calculateResultCacheMax( | ||||
| 		sizeofToken(), // model in-mem size. | ||||
| 		config.GetCacheTokenMemRatio(), | ||||
| 	) | ||||
|  | ||||
| 	log.Infof(nil, "cache size = %d", cap) | ||||
|  | ||||
| 	copyF := func(t1 *gtsmodel.Token) *gtsmodel.Token { | ||||
| 		t2 := new(gtsmodel.Token) | ||||
| 		*t2 = *t1 | ||||
| 		return t2 | ||||
| 	} | ||||
|  | ||||
| 	c.GTS.Token.Init(structr.CacheConfig[*gtsmodel.Token]{ | ||||
| 		Indices: []structr.IndexConfig{ | ||||
| 			{Fields: "ID"}, | ||||
| 			{Fields: "Code"}, | ||||
| 			{Fields: "Access"}, | ||||
| 			{Fields: "Refresh"}, | ||||
| 			{Fields: "ClientID", Multiple: true}, | ||||
| 		}, | ||||
| 		MaxSize:   cap, | ||||
| 		IgnoreErr: ignoreErrors, | ||||
| 		Copy:      copyF, | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| func (c *Caches) initTombstone() { | ||||
| 	// Calculate maximum cache size. | ||||
| 	cap := calculateResultCacheMax( | ||||
|   | ||||
							
								
								
									
										10
									
								
								internal/cache/invalidate.go
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										10
									
								
								internal/cache/invalidate.go
									
									
									
									
										vendored
									
									
								
							| @@ -60,6 +60,11 @@ func (c *Caches) OnInvalidateAccount(account *gtsmodel.Account) { | ||||
| 	c.GTS.Move.Invalidate("TargetURI", account.URI) | ||||
| } | ||||
|  | ||||
| func (c *Caches) OnInvalidateApplication(app *gtsmodel.Application) { | ||||
| 	// Invalidate cached client of this application. | ||||
| 	c.GTS.Client.Invalidate("ID", app.ClientID) | ||||
| } | ||||
|  | ||||
| func (c *Caches) OnInvalidateBlock(block *gtsmodel.Block) { | ||||
| 	// Invalidate block origin account ID cached visibility. | ||||
| 	c.Visibility.Invalidate("ItemID", block.AccountID) | ||||
| @@ -73,6 +78,11 @@ func (c *Caches) OnInvalidateBlock(block *gtsmodel.Block) { | ||||
| 	c.GTS.BlockIDs.Invalidate(block.AccountID) | ||||
| } | ||||
|  | ||||
| func (c *Caches) OnInvalidateClient(client *gtsmodel.Client) { | ||||
| 	// Invalidate any tokens under this client. | ||||
| 	c.GTS.Token.Invalidate("ClientID", client.ID) | ||||
| } | ||||
|  | ||||
| func (c *Caches) OnInvalidateEmojiCategory(category *gtsmodel.EmojiCategory) { | ||||
| 	// Invalidate any emoji in this category. | ||||
| 	c.GTS.Emoji.Invalidate("CategoryID", category.ID) | ||||
|   | ||||
							
								
								
									
										38
									
								
								internal/cache/size.go
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										38
									
								
								internal/cache/size.go
									
									
									
									
										vendored
									
									
								
							| @@ -176,6 +176,7 @@ func totalOfRatios() float64 { | ||||
| 		config.GetCacheBlockMemRatio() + | ||||
| 		config.GetCacheBlockIDsMemRatio() + | ||||
| 		config.GetCacheBoostOfIDsMemRatio() + | ||||
| 		config.GetCacheClientMemRatio() + | ||||
| 		config.GetCacheEmojiMemRatio() + | ||||
| 		config.GetCacheEmojiCategoryMemRatio() + | ||||
| 		config.GetCacheFollowMemRatio() + | ||||
| @@ -198,6 +199,7 @@ func totalOfRatios() float64 { | ||||
| 		config.GetCacheStatusFaveIDsMemRatio() + | ||||
| 		config.GetCacheTagMemRatio() + | ||||
| 		config.GetCacheThreadMuteMemRatio() + | ||||
| 		config.GetCacheTokenMemRatio() + | ||||
| 		config.GetCacheTombstoneMemRatio() + | ||||
| 		config.GetCacheUserMemRatio() + | ||||
| 		config.GetCacheWebfingerMemRatio() + | ||||
| @@ -287,6 +289,17 @@ func sizeofBlock() uintptr { | ||||
| 	})) | ||||
| } | ||||
|  | ||||
| func sizeofClient() uintptr { | ||||
| 	return uintptr(size.Of(>smodel.Client{ | ||||
| 		ID:        exampleID, | ||||
| 		CreatedAt: exampleTime, | ||||
| 		UpdatedAt: exampleTime, | ||||
| 		Secret:    exampleID, | ||||
| 		Domain:    exampleURI, | ||||
| 		UserID:    exampleID, | ||||
| 	})) | ||||
| } | ||||
|  | ||||
| func sizeofEmoji() uintptr { | ||||
| 	return uintptr(size.Of(>smodel.Emoji{ | ||||
| 		ID:                     exampleID, | ||||
| @@ -591,7 +604,7 @@ func sizeofTag() uintptr { | ||||
| 	})) | ||||
| } | ||||
|  | ||||
| func sizeOfThreadMute() uintptr { | ||||
| func sizeofThreadMute() uintptr { | ||||
| 	return uintptr(size.Of(>smodel.ThreadMute{ | ||||
| 		ID:        exampleID, | ||||
| 		CreatedAt: exampleTime, | ||||
| @@ -601,6 +614,29 @@ func sizeOfThreadMute() uintptr { | ||||
| 	})) | ||||
| } | ||||
|  | ||||
| func sizeofToken() uintptr { | ||||
| 	return uintptr(size.Of(>smodel.Token{ | ||||
| 		ID:                  exampleID, | ||||
| 		CreatedAt:           exampleTime, | ||||
| 		UpdatedAt:           exampleTime, | ||||
| 		ClientID:            exampleID, | ||||
| 		UserID:              exampleID, | ||||
| 		RedirectURI:         exampleURI, | ||||
| 		Scope:               "r:w", | ||||
| 		Code:                "", // TODO | ||||
| 		CodeChallenge:       "", // TODO | ||||
| 		CodeChallengeMethod: "", // TODO | ||||
| 		CodeCreateAt:        exampleTime, | ||||
| 		CodeExpiresAt:       exampleTime, | ||||
| 		Access:              exampleID + exampleID, | ||||
| 		AccessCreateAt:      exampleTime, | ||||
| 		AccessExpiresAt:     exampleTime, | ||||
| 		Refresh:             "", // TODO: clients don't really support this very well yet | ||||
| 		RefreshCreateAt:     exampleTime, | ||||
| 		RefreshExpiresAt:    exampleTime, | ||||
| 	})) | ||||
| } | ||||
|  | ||||
| func sizeofTombstone() uintptr { | ||||
| 	return uintptr(size.Of(>smodel.Tombstone{ | ||||
| 		ID:        exampleID, | ||||
|   | ||||
| @@ -199,6 +199,7 @@ type CacheConfiguration struct { | ||||
| 	BlockMemRatio            float64       `name:"block-mem-ratio"` | ||||
| 	BlockIDsMemRatio         float64       `name:"block-mem-ratio"` | ||||
| 	BoostOfIDsMemRatio       float64       `name:"boost-of-ids-mem-ratio"` | ||||
| 	ClientMemRatio           float64       `name:"client-mem-ratio"` | ||||
| 	EmojiMemRatio            float64       `name:"emoji-mem-ratio"` | ||||
| 	EmojiCategoryMemRatio    float64       `name:"emoji-category-mem-ratio"` | ||||
| 	FilterMemRatio           float64       `name:"filter-mem-ratio"` | ||||
| @@ -226,6 +227,7 @@ type CacheConfiguration struct { | ||||
| 	StatusFaveIDsMemRatio    float64       `name:"status-fave-ids-mem-ratio"` | ||||
| 	TagMemRatio              float64       `name:"tag-mem-ratio"` | ||||
| 	ThreadMuteMemRatio       float64       `name:"thread-mute-mem-ratio"` | ||||
| 	TokenMemRatio            float64       `name:"token-mem-ratio"` | ||||
| 	TombstoneMemRatio        float64       `name:"tombstone-mem-ratio"` | ||||
| 	UserMemRatio             float64       `name:"user-mem-ratio"` | ||||
| 	WebfingerMemRatio        float64       `name:"webfinger-mem-ratio"` | ||||
|   | ||||
| @@ -163,6 +163,7 @@ var Defaults = Configuration{ | ||||
| 		BlockMemRatio:            2, | ||||
| 		BlockIDsMemRatio:         3, | ||||
| 		BoostOfIDsMemRatio:       3, | ||||
| 		ClientMemRatio:           0.1, | ||||
| 		EmojiMemRatio:            3, | ||||
| 		EmojiCategoryMemRatio:    0.1, | ||||
| 		FilterMemRatio:           0.5, | ||||
| @@ -190,6 +191,7 @@ var Defaults = Configuration{ | ||||
| 		StatusFaveIDsMemRatio:    3, | ||||
| 		TagMemRatio:              2, | ||||
| 		ThreadMuteMemRatio:       0.2, | ||||
| 		TokenMemRatio:            0.75, | ||||
| 		TombstoneMemRatio:        0.5, | ||||
| 		UserMemRatio:             0.25, | ||||
| 		WebfingerMemRatio:        0.1, | ||||
|   | ||||
| @@ -2925,6 +2925,31 @@ func GetCacheBoostOfIDsMemRatio() float64 { return global.GetCacheBoostOfIDsMemR | ||||
| // SetCacheBoostOfIDsMemRatio safely sets the value for global configuration 'Cache.BoostOfIDsMemRatio' field | ||||
| func SetCacheBoostOfIDsMemRatio(v float64) { global.SetCacheBoostOfIDsMemRatio(v) } | ||||
|  | ||||
| // GetCacheClientMemRatio safely fetches the Configuration value for state's 'Cache.ClientMemRatio' field | ||||
| func (st *ConfigState) GetCacheClientMemRatio() (v float64) { | ||||
| 	st.mutex.RLock() | ||||
| 	v = st.config.Cache.ClientMemRatio | ||||
| 	st.mutex.RUnlock() | ||||
| 	return | ||||
| } | ||||
|  | ||||
| // SetCacheClientMemRatio safely sets the Configuration value for state's 'Cache.ClientMemRatio' field | ||||
| func (st *ConfigState) SetCacheClientMemRatio(v float64) { | ||||
| 	st.mutex.Lock() | ||||
| 	defer st.mutex.Unlock() | ||||
| 	st.config.Cache.ClientMemRatio = v | ||||
| 	st.reloadToViper() | ||||
| } | ||||
|  | ||||
| // CacheClientMemRatioFlag returns the flag name for the 'Cache.ClientMemRatio' field | ||||
| func CacheClientMemRatioFlag() string { return "cache-client-mem-ratio" } | ||||
|  | ||||
| // GetCacheClientMemRatio safely fetches the value for global configuration 'Cache.ClientMemRatio' field | ||||
| func GetCacheClientMemRatio() float64 { return global.GetCacheClientMemRatio() } | ||||
|  | ||||
| // SetCacheClientMemRatio safely sets the value for global configuration 'Cache.ClientMemRatio' field | ||||
| func SetCacheClientMemRatio(v float64) { global.SetCacheClientMemRatio(v) } | ||||
|  | ||||
| // GetCacheEmojiMemRatio safely fetches the Configuration value for state's 'Cache.EmojiMemRatio' field | ||||
| func (st *ConfigState) GetCacheEmojiMemRatio() (v float64) { | ||||
| 	st.mutex.RLock() | ||||
| @@ -3600,6 +3625,31 @@ func GetCacheThreadMuteMemRatio() float64 { return global.GetCacheThreadMuteMemR | ||||
| // SetCacheThreadMuteMemRatio safely sets the value for global configuration 'Cache.ThreadMuteMemRatio' field | ||||
| func SetCacheThreadMuteMemRatio(v float64) { global.SetCacheThreadMuteMemRatio(v) } | ||||
|  | ||||
| // GetCacheTokenMemRatio safely fetches the Configuration value for state's 'Cache.TokenMemRatio' field | ||||
| func (st *ConfigState) GetCacheTokenMemRatio() (v float64) { | ||||
| 	st.mutex.RLock() | ||||
| 	v = st.config.Cache.TokenMemRatio | ||||
| 	st.mutex.RUnlock() | ||||
| 	return | ||||
| } | ||||
|  | ||||
| // SetCacheTokenMemRatio safely sets the Configuration value for state's 'Cache.TokenMemRatio' field | ||||
| func (st *ConfigState) SetCacheTokenMemRatio(v float64) { | ||||
| 	st.mutex.Lock() | ||||
| 	defer st.mutex.Unlock() | ||||
| 	st.config.Cache.TokenMemRatio = v | ||||
| 	st.reloadToViper() | ||||
| } | ||||
|  | ||||
| // CacheTokenMemRatioFlag returns the flag name for the 'Cache.TokenMemRatio' field | ||||
| func CacheTokenMemRatioFlag() string { return "cache-token-mem-ratio" } | ||||
|  | ||||
| // GetCacheTokenMemRatio safely fetches the value for global configuration 'Cache.TokenMemRatio' field | ||||
| func GetCacheTokenMemRatio() float64 { return global.GetCacheTokenMemRatio() } | ||||
|  | ||||
| // SetCacheTokenMemRatio safely sets the value for global configuration 'Cache.TokenMemRatio' field | ||||
| func SetCacheTokenMemRatio(v float64) { global.SetCacheTokenMemRatio(v) } | ||||
|  | ||||
| // GetCacheTombstoneMemRatio safely fetches the Configuration value for state's 'Cache.TombstoneMemRatio' field | ||||
| func (st *ConfigState) GetCacheTombstoneMemRatio() (v float64) { | ||||
| 	st.mutex.RLock() | ||||
|   | ||||
| @@ -35,4 +35,40 @@ type Application interface { | ||||
|  | ||||
| 	// DeleteApplicationByClientID deletes the application with corresponding client_id value from the database. | ||||
| 	DeleteApplicationByClientID(ctx context.Context, clientID string) error | ||||
|  | ||||
| 	// GetClientByID ... | ||||
| 	GetClientByID(ctx context.Context, id string) (*gtsmodel.Client, error) | ||||
|  | ||||
| 	// PutClient ... | ||||
| 	PutClient(ctx context.Context, client *gtsmodel.Client) error | ||||
|  | ||||
| 	// DeleteClientByID ... | ||||
| 	DeleteClientByID(ctx context.Context, id string) error | ||||
|  | ||||
| 	// GetAllTokens ... | ||||
| 	GetAllTokens(ctx context.Context) ([]*gtsmodel.Token, error) | ||||
|  | ||||
| 	// GetTokenByCode ... | ||||
| 	GetTokenByCode(ctx context.Context, code string) (*gtsmodel.Token, error) | ||||
|  | ||||
| 	// GetTokenByAccess ... | ||||
| 	GetTokenByAccess(ctx context.Context, access string) (*gtsmodel.Token, error) | ||||
|  | ||||
| 	// GetTokenByRefresh ... | ||||
| 	GetTokenByRefresh(ctx context.Context, refresh string) (*gtsmodel.Token, error) | ||||
|  | ||||
| 	// PutToken ... | ||||
| 	PutToken(ctx context.Context, token *gtsmodel.Token) error | ||||
|  | ||||
| 	// DeleteTokenByID ... | ||||
| 	DeleteTokenByID(ctx context.Context, id string) error | ||||
|  | ||||
| 	// DeleteTokenByCode ... | ||||
| 	DeleteTokenByCode(ctx context.Context, code string) error | ||||
|  | ||||
| 	// DeleteTokenByAccess ... | ||||
| 	DeleteTokenByAccess(ctx context.Context, access string) error | ||||
|  | ||||
| 	// DeleteTokenByRefresh ... | ||||
| 	DeleteTokenByRefresh(ctx context.Context, refresh string) error | ||||
| } | ||||
|   | ||||
| @@ -397,7 +397,7 @@ func (a *adminDB) CreateInstanceApplication(ctx context.Context) error { | ||||
| 	} | ||||
|  | ||||
| 	// Store it. | ||||
| 	return a.state.DB.Put(ctx, oc) | ||||
| 	return a.state.DB.PutClient(ctx, oc) | ||||
| } | ||||
|  | ||||
| func (a *adminDB) GetInstanceApplication(ctx context.Context) (*gtsmodel.Application, error) { | ||||
|   | ||||
| @@ -22,6 +22,7 @@ import ( | ||||
|  | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/state" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/util" | ||||
| 	"github.com/uptrace/bun" | ||||
| ) | ||||
|  | ||||
| @@ -95,3 +96,181 @@ func (a *applicationDB) DeleteApplicationByClientID(ctx context.Context, clientI | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (a *applicationDB) GetClientByID(ctx context.Context, id string) (*gtsmodel.Client, error) { | ||||
| 	return a.state.Caches.GTS.Client.LoadOne("ID", func() (*gtsmodel.Client, error) { | ||||
| 		var client gtsmodel.Client | ||||
|  | ||||
| 		if err := a.db.NewSelect(). | ||||
| 			Model(&client). | ||||
| 			Where("? = ?", bun.Ident("id"), id). | ||||
| 			Scan(ctx); err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
|  | ||||
| 		return &client, nil | ||||
| 	}, id) | ||||
| } | ||||
|  | ||||
| func (a *applicationDB) PutClient(ctx context.Context, client *gtsmodel.Client) error { | ||||
| 	return a.state.Caches.GTS.Client.Store(client, func() error { | ||||
| 		_, err := a.db.NewInsert().Model(client).Exec(ctx) | ||||
| 		return err | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| func (a *applicationDB) DeleteClientByID(ctx context.Context, id string) error { | ||||
| 	_, err := a.db.NewDelete(). | ||||
| 		Table("clients"). | ||||
| 		Where("? = ?", bun.Ident("id"), id). | ||||
| 		Exec(ctx) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	a.state.Caches.GTS.Client.Invalidate("ID", id) | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (a *applicationDB) GetAllTokens(ctx context.Context) ([]*gtsmodel.Token, error) { | ||||
| 	var tokenIDs []string | ||||
|  | ||||
| 	// Select ALL token IDs. | ||||
| 	if err := a.db.NewSelect(). | ||||
| 		Table("tokens"). | ||||
| 		Column("id"). | ||||
| 		Scan(ctx, &tokenIDs); err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	// Load all input token IDs via cache loader callback. | ||||
| 	tokens, err := a.state.Caches.GTS.Token.LoadIDs("ID", | ||||
| 		tokenIDs, | ||||
| 		func(uncached []string) ([]*gtsmodel.Token, error) { | ||||
| 			// Preallocate expected length of uncached tokens. | ||||
| 			tokens := make([]*gtsmodel.Token, 0, len(uncached)) | ||||
|  | ||||
| 			// Perform database query scanning | ||||
| 			// the remaining (uncached) token IDs. | ||||
| 			if err := a.db.NewSelect(). | ||||
| 				Model(tokens). | ||||
| 				Where("? IN (?)", bun.Ident("id"), bun.In(uncached)). | ||||
| 				Scan(ctx); err != nil { | ||||
| 				return nil, err | ||||
| 			} | ||||
|  | ||||
| 			return tokens, nil | ||||
| 		}, | ||||
| 	) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	// Reoroder the tokens by their | ||||
| 	// IDs to ensure in correct order. | ||||
| 	getID := func(t *gtsmodel.Token) string { return t.ID } | ||||
| 	util.OrderBy(tokens, tokenIDs, getID) | ||||
|  | ||||
| 	return tokens, nil | ||||
| } | ||||
|  | ||||
| func (a *applicationDB) GetTokenByCode(ctx context.Context, code string) (*gtsmodel.Token, error) { | ||||
| 	return a.getTokenBy( | ||||
| 		"Code", | ||||
| 		func(t *gtsmodel.Token) error { | ||||
| 			return a.db.NewSelect().Model(t).Where("? = ?", bun.Ident("code"), code).Scan(ctx) | ||||
| 		}, | ||||
| 		code, | ||||
| 	) | ||||
| } | ||||
|  | ||||
| func (a *applicationDB) GetTokenByAccess(ctx context.Context, access string) (*gtsmodel.Token, error) { | ||||
| 	return a.getTokenBy( | ||||
| 		"Access", | ||||
| 		func(t *gtsmodel.Token) error { | ||||
| 			return a.db.NewSelect().Model(t).Where("? = ?", bun.Ident("access"), access).Scan(ctx) | ||||
| 		}, | ||||
| 		access, | ||||
| 	) | ||||
| } | ||||
|  | ||||
| func (a *applicationDB) GetTokenByRefresh(ctx context.Context, refresh string) (*gtsmodel.Token, error) { | ||||
| 	return a.getTokenBy( | ||||
| 		"Refresh", | ||||
| 		func(t *gtsmodel.Token) error { | ||||
| 			return a.db.NewSelect().Model(t).Where("? = ?", bun.Ident("refresh"), refresh).Scan(ctx) | ||||
| 		}, | ||||
| 		refresh, | ||||
| 	) | ||||
| } | ||||
|  | ||||
| func (a *applicationDB) getTokenBy(lookup string, dbQuery func(*gtsmodel.Token) error, keyParts ...any) (*gtsmodel.Token, error) { | ||||
| 	return a.state.Caches.GTS.Token.LoadOne(lookup, func() (*gtsmodel.Token, error) { | ||||
| 		var token gtsmodel.Token | ||||
|  | ||||
| 		if err := dbQuery(&token); err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
|  | ||||
| 		return &token, nil | ||||
| 	}, keyParts...) | ||||
| } | ||||
|  | ||||
| func (a *applicationDB) PutToken(ctx context.Context, token *gtsmodel.Token) error { | ||||
| 	return a.state.Caches.GTS.Token.Store(token, func() error { | ||||
| 		_, err := a.db.NewInsert().Model(token).Exec(ctx) | ||||
| 		return err | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| func (a *applicationDB) DeleteTokenByID(ctx context.Context, id string) error { | ||||
| 	_, err := a.db.NewDelete(). | ||||
| 		Table("tokens"). | ||||
| 		Where("? = ?", bun.Ident("id"), id). | ||||
| 		Exec(ctx) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	a.state.Caches.GTS.Token.Invalidate("ID", id) | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (a *applicationDB) DeleteTokenByCode(ctx context.Context, code string) error { | ||||
| 	_, err := a.db.NewDelete(). | ||||
| 		Table("tokens"). | ||||
| 		Where("? = ?", bun.Ident("code"), code). | ||||
| 		Exec(ctx) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	a.state.Caches.GTS.Token.Invalidate("Code", code) | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (a *applicationDB) DeleteTokenByAccess(ctx context.Context, access string) error { | ||||
| 	_, err := a.db.NewDelete(). | ||||
| 		Table("tokens"). | ||||
| 		Where("? = ?", bun.Ident("access"), access). | ||||
| 		Exec(ctx) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	a.state.Caches.GTS.Token.Invalidate("Access", access) | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (a *applicationDB) DeleteTokenByRefresh(ctx context.Context, refresh string) error { | ||||
| 	_, err := a.db.NewDelete(). | ||||
| 		Table("tokens"). | ||||
| 		Where("? = ?", bun.Ident("refresh"), refresh). | ||||
| 		Exec(ctx) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	a.state.Caches.GTS.Token.Invalidate("Refresh", refresh) | ||||
| 	return nil | ||||
| } | ||||
|   | ||||
| @@ -27,11 +27,11 @@ import ( | ||||
| ) | ||||
|  | ||||
| type clientStore struct { | ||||
| 	db db.Basic | ||||
| 	db db.DB | ||||
| } | ||||
|  | ||||
| // NewClientStore returns an implementation of the oauth2 ClientStore interface, using the given db as a storage backend. | ||||
| func NewClientStore(db db.Basic) oauth2.ClientStore { | ||||
| func NewClientStore(db db.DB) oauth2.ClientStore { | ||||
| 	pts := &clientStore{ | ||||
| 		db: db, | ||||
| 	} | ||||
| @@ -39,26 +39,27 @@ func NewClientStore(db db.Basic) oauth2.ClientStore { | ||||
| } | ||||
|  | ||||
| func (cs *clientStore) GetByID(ctx context.Context, clientID string) (oauth2.ClientInfo, error) { | ||||
| 	poc := >smodel.Client{} | ||||
| 	if err := cs.db.GetByID(ctx, clientID, poc); err != nil { | ||||
| 	client, err := cs.db.GetClientByID(ctx, clientID) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return models.New(poc.ID, poc.Secret, poc.Domain, poc.UserID), nil | ||||
| 	return models.New( | ||||
| 		client.ID, | ||||
| 		client.Secret, | ||||
| 		client.Domain, | ||||
| 		client.UserID, | ||||
| 	), nil | ||||
| } | ||||
|  | ||||
| func (cs *clientStore) Set(ctx context.Context, id string, cli oauth2.ClientInfo) error { | ||||
| 	poc := >smodel.Client{ | ||||
| 	return cs.db.PutClient(ctx, >smodel.Client{ | ||||
| 		ID:     cli.GetID(), | ||||
| 		Secret: cli.GetSecret(), | ||||
| 		Domain: cli.GetDomain(), | ||||
| 		UserID: cli.GetUserID(), | ||||
| 	} | ||||
| 	return cs.db.Put(ctx, poc) | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| func (cs *clientStore) Delete(ctx context.Context, id string) error { | ||||
| 	poc := >smodel.Client{ | ||||
| 		ID: id, | ||||
| 	} | ||||
| 	return cs.db.DeleteByID(ctx, id, poc) | ||||
| 	return cs.db.DeleteClientByID(ctx, id) | ||||
| } | ||||
|   | ||||
| @@ -19,7 +19,5 @@ package oauth | ||||
|  | ||||
| import "github.com/superseriousbusiness/oauth2/v4/errors" | ||||
|  | ||||
| // InvalidRequest returns an oauth spec compliant 'invalid_request' error. | ||||
| func InvalidRequest() error { | ||||
| 	return errors.New("invalid_request") | ||||
| } | ||||
| // ErrInvalidRequest is an oauth spec compliant 'invalid_request' error. | ||||
| var ErrInvalidRequest = errors.New("invalid_request") | ||||
|   | ||||
| @@ -75,7 +75,7 @@ type s struct { | ||||
| } | ||||
|  | ||||
| // New returns a new oauth server that implements the Server interface | ||||
| func New(ctx context.Context, database db.Basic) Server { | ||||
| func New(ctx context.Context, database db.DB) Server { | ||||
| 	ts := newTokenStore(ctx, database) | ||||
| 	cs := NewClientStore(database) | ||||
|  | ||||
|   | ||||
| @@ -20,7 +20,6 @@ package oauth | ||||
| import ( | ||||
| 	"context" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/db" | ||||
| @@ -34,14 +33,14 @@ import ( | ||||
| // tokenStore is an implementation of oauth2.TokenStore, which uses our db interface as a storage backend. | ||||
| type tokenStore struct { | ||||
| 	oauth2.TokenStore | ||||
| 	db db.Basic | ||||
| 	db db.DB | ||||
| } | ||||
|  | ||||
| // newTokenStore returns a token store that satisfies the oauth2.TokenStore interface. | ||||
| // | ||||
| // In order to allow tokens to 'expire', it will also set off a goroutine that iterates through | ||||
| // the tokens in the DB once per minute and deletes any that have expired. | ||||
| func newTokenStore(ctx context.Context, db db.Basic) oauth2.TokenStore { | ||||
| func newTokenStore(ctx context.Context, db db.DB) oauth2.TokenStore { | ||||
| 	ts := &tokenStore{ | ||||
| 		db: db, | ||||
| 	} | ||||
| @@ -69,19 +68,19 @@ func newTokenStore(ctx context.Context, db db.Basic) oauth2.TokenStore { | ||||
| func (ts *tokenStore) sweep(ctx context.Context) error { | ||||
| 	// select *all* tokens from the db | ||||
| 	// todo: if this becomes expensive (ie., there are fucking LOADS of tokens) then figure out a better way. | ||||
| 	tokens := new([]*gtsmodel.Token) | ||||
| 	if err := ts.db.GetAll(ctx, tokens); err != nil { | ||||
| 	tokens, err := ts.db.GetAllTokens(ctx) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	// iterate through and remove expired tokens | ||||
| 	now := time.Now() | ||||
| 	for _, dbt := range *tokens { | ||||
| 	for _, dbt := range tokens { | ||||
| 		// The zero value of a time.Time is 00:00 january 1 1970, which will always be before now. So: | ||||
| 		// we only want to check if a token expired before now if the expiry time is *not zero*; | ||||
| 		// ie., if it's been explicity set. | ||||
| 		if !dbt.CodeExpiresAt.IsZero() && dbt.CodeExpiresAt.Before(now) || !dbt.RefreshExpiresAt.IsZero() && dbt.RefreshExpiresAt.Before(now) || !dbt.AccessExpiresAt.IsZero() && dbt.AccessExpiresAt.Before(now) { | ||||
| 			if err := ts.db.DeleteByID(ctx, dbt.ID, dbt); err != nil { | ||||
| 			if err := ts.db.DeleteTokenByID(ctx, dbt.ID); err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 		} | ||||
| @@ -107,67 +106,49 @@ func (ts *tokenStore) Create(ctx context.Context, info oauth2.TokenInfo) error { | ||||
| 		dbt.ID = dbtID | ||||
| 	} | ||||
|  | ||||
| 	if err := ts.db.Put(ctx, dbt); err != nil { | ||||
| 		return fmt.Errorf("error in tokenstore create: %s", err) | ||||
| 	} | ||||
| 	return nil | ||||
| 	return ts.db.PutToken(ctx, dbt) | ||||
| } | ||||
|  | ||||
| // RemoveByCode deletes a token from the DB based on the Code field | ||||
| func (ts *tokenStore) RemoveByCode(ctx context.Context, code string) error { | ||||
| 	return ts.db.DeleteWhere(ctx, []db.Where{{Key: "code", Value: code}}, >smodel.Token{}) | ||||
| 	return ts.db.DeleteTokenByCode(ctx, code) | ||||
| } | ||||
|  | ||||
| // RemoveByAccess deletes a token from the DB based on the Access field | ||||
| func (ts *tokenStore) RemoveByAccess(ctx context.Context, access string) error { | ||||
| 	return ts.db.DeleteWhere(ctx, []db.Where{{Key: "access", Value: access}}, >smodel.Token{}) | ||||
| 	return ts.db.DeleteTokenByAccess(ctx, access) | ||||
| } | ||||
|  | ||||
| // RemoveByRefresh deletes a token from the DB based on the Refresh field | ||||
| func (ts *tokenStore) RemoveByRefresh(ctx context.Context, refresh string) error { | ||||
| 	return ts.db.DeleteWhere(ctx, []db.Where{{Key: "refresh", Value: refresh}}, >smodel.Token{}) | ||||
| 	return ts.db.DeleteTokenByRefresh(ctx, refresh) | ||||
| } | ||||
|  | ||||
| // GetByCode selects a token from the DB based on the Code field | ||||
| func (ts *tokenStore) GetByCode(ctx context.Context, code string) (oauth2.TokenInfo, error) { | ||||
| 	if code == "" { | ||||
| 		return nil, nil | ||||
| 	} | ||||
| 	dbt := >smodel.Token{ | ||||
| 		Code: code, | ||||
| 	} | ||||
| 	if err := ts.db.GetWhere(ctx, []db.Where{{Key: "code", Value: code}}, dbt); err != nil { | ||||
| 	token, err := ts.db.GetTokenByCode(ctx, code) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return DBTokenToToken(dbt), nil | ||||
| 	return DBTokenToToken(token), nil | ||||
| } | ||||
|  | ||||
| // GetByAccess selects a token from the DB based on the Access field | ||||
| func (ts *tokenStore) GetByAccess(ctx context.Context, access string) (oauth2.TokenInfo, error) { | ||||
| 	if access == "" { | ||||
| 		return nil, nil | ||||
| 	} | ||||
| 	dbt := >smodel.Token{ | ||||
| 		Access: access, | ||||
| 	} | ||||
| 	if err := ts.db.GetWhere(ctx, []db.Where{{Key: "access", Value: access}}, dbt); err != nil { | ||||
| 	token, err := ts.db.GetTokenByAccess(ctx, access) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return DBTokenToToken(dbt), nil | ||||
| 	return DBTokenToToken(token), nil | ||||
| } | ||||
|  | ||||
| // GetByRefresh selects a token from the DB based on the Refresh field | ||||
| func (ts *tokenStore) GetByRefresh(ctx context.Context, refresh string) (oauth2.TokenInfo, error) { | ||||
| 	if refresh == "" { | ||||
| 		return nil, nil | ||||
| 	} | ||||
| 	dbt := >smodel.Token{ | ||||
| 		Refresh: refresh, | ||||
| 	} | ||||
| 	if err := ts.db.GetWhere(ctx, []db.Where{{Key: "refresh", Value: refresh}}, dbt); err != nil { | ||||
| 	token, err := ts.db.GetTokenByRefresh(ctx, refresh) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return DBTokenToToken(dbt), nil | ||||
| 	return DBTokenToToken(token), nil | ||||
| } | ||||
|  | ||||
| /* | ||||
|   | ||||
| @@ -75,7 +75,7 @@ func (p *Processor) AppCreate(ctx context.Context, authed *oauth.Auth, form *api | ||||
| 	} | ||||
|  | ||||
| 	// chuck it in the db | ||||
| 	if err := p.state.DB.Put(ctx, oc); err != nil { | ||||
| 	if err := p.state.DB.PutClient(ctx, oc); err != nil { | ||||
| 		return nil, gtserror.NewErrorInternalError(err) | ||||
| 	} | ||||
|  | ||||
|   | ||||
| @@ -29,6 +29,7 @@ EXPECT=$(cat << "EOF" | ||||
|         "application-mem-ratio": 0.1, | ||||
|         "block-mem-ratio": 3, | ||||
|         "boost-of-ids-mem-ratio": 3, | ||||
|         "client-mem-ratio": 0.1, | ||||
|         "emoji-category-mem-ratio": 0.1, | ||||
|         "emoji-mem-ratio": 3, | ||||
|         "filter-keyword-mem-ratio": 0.5, | ||||
| @@ -57,6 +58,7 @@ EXPECT=$(cat << "EOF" | ||||
|         "status-mem-ratio": 5, | ||||
|         "tag-mem-ratio": 2, | ||||
|         "thread-mute-mem-ratio": 0.2, | ||||
|         "token-mem-ratio": 0.75, | ||||
|         "tombstone-mem-ratio": 0.5, | ||||
|         "user-mem-ratio": 0.25, | ||||
|         "visibility-mem-ratio": 2, | ||||
|   | ||||
		Reference in New Issue
	
	Block a user