[feature] Refactor tokens, allow multiple app redirect_uris (#3849)

* [feature] Refactor tokens, allow multiple app redirect_uris

* move + tweak handlers a bit

* return error for unset oauth2.ClientStore funcs

* wrap UpdateToken with cache

* panic handling

* cheeky little time optimization

* unlock on error
This commit is contained in:
tobi
2025-03-03 16:03:36 +01:00
committed by GitHub
parent c80810eae8
commit 1b37944f8b
77 changed files with 963 additions and 594 deletions

View File

@@ -30,7 +30,10 @@ import (
"codeberg.org/superseriousbusiness/oauth2/v4/server"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/util"
)
const (
@@ -60,7 +63,8 @@ const (
HelpfulAdviceGrant = "If you arrived at this error during a sign in/oauth flow, your client is trying to use an unsupported OAuth grant type. Supported grant types are: authorization_code, client_credentials; please reach out to developer of your client"
)
// Server wraps some oauth2 server functions in an interface, exposing only what is needed
// Server wraps some oauth2 server functions
// in an interface, exposing only what is needed.
type Server interface {
HandleTokenRequest(r *http.Request) (map[string]interface{}, gtserror.WithCode)
HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) gtserror.WithCode
@@ -69,66 +73,76 @@ type Server interface {
LoadAccessToken(ctx context.Context, access string) (accessToken oauth2.TokenInfo, err error)
}
// s fulfils the Server interface using the underlying oauth2 server
// s fulfils the Server interface
// using the underlying oauth2 server.
type s struct {
server *server.Server
}
// New returns a new oauth server that implements the Server interface
func New(ctx context.Context, database db.DB) Server {
ts := newTokenStore(ctx, database)
cs := NewClientStore(database)
func New(
ctx context.Context,
state *state.State,
validateURIHandler manage.ValidateURIHandler,
clientScopeHandler server.ClientScopeHandler,
authorizeScopeHandler server.AuthorizeScopeHandler,
internalErrorHandler server.InternalErrorHandler,
responseErrorHandler server.ResponseErrorHandler,
userAuthorizationHandler server.UserAuthorizationHandler,
) Server {
ts := newTokenStore(ctx, state)
cs := NewClientStore(state)
// Set up OAuth2 manager.
manager := manage.NewDefaultManager()
manager.SetValidateURIHandler(validateURIHandler)
manager.MapTokenStorage(ts)
manager.MapClientStorage(cs)
manager.SetAuthorizeCodeTokenCfg(&manage.Config{
AccessTokenExp: 0, // access tokens don't expire -- they must be revoked
IsGenerateRefresh: false, // don't use refresh tokens
})
sc := &server.Config{
TokenType: "Bearer",
// Must follow the spec.
AllowGetAccessRequest: false,
// Support only the non-implicit flow.
AllowedResponseTypes: []oauth2.ResponseType{oauth2.Code},
// Allow:
// - Authorization Code (for first & third parties)
// - Client Credentials (for applications)
AllowedGrantTypes: []oauth2.GrantType{
oauth2.AuthorizationCode,
oauth2.ClientCredentials,
manager.SetAuthorizeCodeTokenCfg(
&manage.Config{
// Following the Mastodon API,
// access tokens don't expire.
AccessTokenExp: 0,
// Don't use refresh tokens.
IsGenerateRefresh: false,
},
AllowedCodeChallengeMethods: []oauth2.CodeChallengeMethod{
oauth2.CodeChallengePlain,
oauth2.CodeChallengeS256,
)
// Set up OAuth2 server.
srv := server.NewServer(
&server.Config{
TokenType: "Bearer",
// Must follow the spec.
AllowGetAccessRequest: false,
// Support only the non-implicit flow.
AllowedResponseTypes: []oauth2.ResponseType{oauth2.Code},
// Allow:
// - Authorization Code (for first & third parties)
// - Client Credentials (for applications)
AllowedGrantTypes: []oauth2.GrantType{
oauth2.AuthorizationCode,
oauth2.ClientCredentials,
},
AllowedCodeChallengeMethods: []oauth2.CodeChallengeMethod{
oauth2.CodeChallengePlain,
oauth2.CodeChallengeS256,
},
},
}
srv := server.NewServer(sc, manager)
srv.SetInternalErrorHandler(func(err error) *oautherr.Response {
log.Errorf(nil, "internal oauth error: %s", err)
return nil
})
srv.SetResponseErrorHandler(func(re *oautherr.Response) {
log.Errorf(nil, "internal response error: %s", re.Error)
})
srv.SetUserAuthorizationHandler(func(w http.ResponseWriter, r *http.Request) (string, error) {
userID := r.FormValue("userid")
if userID == "" {
return "", errors.New("userid was empty")
}
return userID, nil
})
manager,
)
srv.SetAuthorizeScopeHandler(authorizeScopeHandler)
srv.SetClientScopeHandler(clientScopeHandler)
srv.SetInternalErrorHandler(internalErrorHandler)
srv.SetResponseErrorHandler(responseErrorHandler)
srv.SetUserAuthorizationHandler(userAuthorizationHandler)
srv.SetClientInfoHandler(server.ClientFormHandler)
return &s{
server: srv,
}
return &s{srv}
}
// HandleTokenRequest wraps the oauth2 library's HandleTokenRequest function
// HandleTokenRequest wraps the oauth2 library's HandleTokenRequest function,
// providing some custom error handling (with more informative messages),
// and a slightly different token serialization format.
func (s *s) HandleTokenRequest(r *http.Request) (map[string]interface{}, gtserror.WithCode) {
ctx := r.Context()
@@ -142,32 +156,43 @@ func (s *s) HandleTokenRequest(r *http.Request) (map[string]interface{}, gtserro
return nil, gtserror.NewErrorBadRequest(err, help, adv)
}
// Get access token + do our own nicer error handling.
ti, err := s.server.GetAccessToken(ctx, gt, tgr)
if err != nil {
help := fmt.Sprintf("could not get access token: %s", err)
switch {
case err == nil:
// No problem.
break
case errors.Is(err, oautherr.ErrInvalidScope):
help := fmt.Sprintf("requested scope %s was not covered by client scope", tgr.Scope)
return nil, gtserror.NewErrorForbidden(err, help, HelpfulAdvice)
case errors.Is(err, oautherr.ErrInvalidRedirectURI):
help := fmt.Sprintf("requested redirect URI %s was not covered by client redirect URIs", tgr.RedirectURI)
return nil, gtserror.NewErrorForbidden(err, help, HelpfulAdvice)
default:
help := fmt.Sprintf("could not get access token: %v", err)
return nil, gtserror.NewErrorBadRequest(err, help, HelpfulAdvice)
}
// Wrangle data a bit.
data := s.server.GetTokenData(ti)
// Add created_at for Mastodon API compatibility.
data["created_at"] = ti.GetAccessCreateAt().Unix()
// If expires_in is 0 or less, omit it
// from serialization so that clients don't
// interpret the token as already expired.
if expiresInI, ok := data["expires_in"]; ok {
switch expiresIn := expiresInI.(type) {
case int64:
// remove this key from the returned map
// if the value is 0 or less, so that clients
// don't interpret the token as already expired
if expiresIn <= 0 {
delete(data, "expires_in")
}
default:
err := errors.New("expires_in was set on token response, but was not an int64")
return nil, gtserror.NewErrorInternalError(err, HelpfulAdvice)
// This will panic if expiresIn is
// not an int64, which is what we want.
if expiresInI.(int64) <= 0 {
delete(data, "expires_in")
}
}
// add this for mastodon api compatibility
data["created_at"] = ti.GetAccessCreateAt().Unix()
return data, nil
}
@@ -207,7 +232,7 @@ func (s *s) HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) gtser
}
req.UserID = userID
// specify the scope of authorization
// Specify the scope of authorization.
if fn := s.server.AuthorizeScopeHandler; fn != nil {
scope, err := fn(w, r)
if err != nil {
@@ -217,7 +242,7 @@ func (s *s) HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) gtser
}
}
// specify the expiration time of access token
// Specify the expiration time of access token.
if fn := s.server.AccessTokenExpHandler; fn != nil {
exp, err := fn(w, r)
if err != nil {
@@ -231,13 +256,24 @@ func (s *s) HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) gtser
return s.errorOrRedirect(err, w, req)
}
// If the redirect URI is empty, the default domain provided by the client is used.
// If the redirect URI is empty, use the
// first of the client's redirect URIs.
if req.RedirectURI == "" {
client, err := s.server.Manager.GetClient(ctx, req.ClientID)
if err != nil {
if err != nil && !errors.Is(err, db.ErrNoEntries) {
// Real error.
err := gtserror.Newf("db error getting application with client id %s: %w", req.ClientID, err)
return gtserror.NewErrorInternalError(err)
}
if util.IsNil(client) {
// Application just not found.
return gtserror.NewErrorUnauthorized(err, HelpfulAdvice)
}
req.RedirectURI = client.GetDomain()
// This will panic if client is not a
// *gtsmodel.Application, which is what we want.
req.RedirectURI = client.(*gtsmodel.Application).RedirectURIs[0]
}
uri, err := s.server.GetRedirectURI(req, s.server.GetAuthorizeData(req.ResponseType, ti))