[chore] migrate oauth2 -> codeberg (#3857)

This commit is contained in:
tobi
2025-03-02 16:42:51 +01:00
committed by GitHub
parent 49c12636c6
commit 8488ac9286
65 changed files with 1677 additions and 1221 deletions

View File

@@ -0,0 +1,33 @@
# Compiled Object files, Static and Dynamic libs (Shared Objects)
*.o
*.a
*.so
# Folders
_obj
_test
# Architecture specific extensions/prefixes
*.[568vq]
[568vq].out
*.cgo1.go
*.cgo2.c
_cgo_defun.c
_cgo_gotypes.go
_cgo_export.*
_testmain.go
*.exe
*.test
*.prof
coverage.txt
# OSX
*.DS_Store
*.db
*.swp
/example/client/client
/example/server/server

View File

@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2016 Lyric
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@@ -0,0 +1,3 @@
# Golang OAuth 2.0 Server
Forked from [go-oauth2](https://github.com/go-oauth2/oauth2).

View File

@@ -0,0 +1,76 @@
package oauth2
import (
"crypto/sha256"
"encoding/base64"
"strings"
)
// ResponseType the type of authorization request
type ResponseType string
// define the type of authorization request
const (
Code ResponseType = "code"
Token ResponseType = "token"
)
func (rt ResponseType) String() string {
return string(rt)
}
// GrantType authorization model
type GrantType string
// define authorization model
const (
AuthorizationCode GrantType = "authorization_code"
PasswordCredentials GrantType = "password"
ClientCredentials GrantType = "client_credentials"
Refreshing GrantType = "refresh_token"
Implicit GrantType = "__implicit"
)
func (gt GrantType) String() string {
if gt == AuthorizationCode ||
gt == PasswordCredentials ||
gt == ClientCredentials ||
gt == Refreshing {
return string(gt)
}
return ""
}
// CodeChallengeMethod PCKE method
type CodeChallengeMethod string
const (
// CodeChallengePlain PCKE Method
CodeChallengePlain CodeChallengeMethod = "plain"
// CodeChallengeS256 PCKE Method
CodeChallengeS256 CodeChallengeMethod = "S256"
)
func (ccm CodeChallengeMethod) String() string {
if ccm == CodeChallengePlain ||
ccm == CodeChallengeS256 {
return string(ccm)
}
return ""
}
// Validate code challenge
func (ccm CodeChallengeMethod) Validate(cc, ver string) bool {
switch ccm {
case CodeChallengePlain:
return cc == ver
case CodeChallengeS256:
s256 := sha256.Sum256([]byte(ver))
// trim padding
a := strings.TrimRight(base64.URLEncoding.EncodeToString(s256[:]), "=")
b := strings.TrimRight(cc, "=")
return a == b
default:
return false
}
}

View File

@@ -0,0 +1,24 @@
// OAuth 2.0 server library for the Go programming language
//
// package main
// import (
// "net/http"
// "codeberg.org/superseriousbusiness/oauth2/v4/manage"
// "codeberg.org/superseriousbusiness/oauth2/v4/server"
// "codeberg.org/superseriousbusiness/oauth2/v4/store"
// )
// func main() {
// manager := manage.NewDefaultManager()
// manager.MustTokenStorage(store.NewMemoryTokenStore())
// manager.MapClientStorage(store.NewTestClientStore())
// srv := server.NewDefaultServer(manager)
// http.HandleFunc("/authorize", func(w http.ResponseWriter, r *http.Request) {
// srv.HandleAuthorizeRequest(w, r)
// })
// http.HandleFunc("/token", func(w http.ResponseWriter, r *http.Request) {
// srv.HandleTokenRequest(w, r)
// })
// http.ListenAndServe(":9096", nil)
// }
package oauth2

View File

@@ -0,0 +1,19 @@
package errors
import "errors"
// New returns an error that formats as the given text.
var New = errors.New
// known errors
var (
ErrInvalidRedirectURI = errors.New("invalid redirect uri")
ErrInvalidAuthorizeCode = errors.New("invalid authorize code")
ErrInvalidAccessToken = errors.New("invalid access token")
ErrInvalidRefreshToken = errors.New("invalid refresh token")
ErrExpiredAccessToken = errors.New("expired access token")
ErrExpiredRefreshToken = errors.New("expired refresh token")
ErrMissingCodeVerifier = errors.New("missing code verifier")
ErrMissingCodeChallenge = errors.New("missing code challenge")
ErrInvalidCodeChallenge = errors.New("invalid code challenge")
)

View File

@@ -0,0 +1,84 @@
package errors
import (
"errors"
"net/http"
)
// Response error response
type Response struct {
Error error
ErrorCode int
Description string
URI string
StatusCode int
Header http.Header
}
// NewResponse create the response pointer
func NewResponse(err error, statusCode int) *Response {
return &Response{
Error: err,
StatusCode: statusCode,
}
}
// SetHeader sets the header entries associated with key to
// the single element value.
func (r *Response) SetHeader(key, value string) {
if r.Header == nil {
r.Header = make(http.Header)
}
r.Header.Set(key, value)
}
// https://tools.ietf.org/html/rfc6749#section-5.2
var (
ErrInvalidRequest = errors.New("invalid_request")
ErrUnauthorizedClient = errors.New("unauthorized_client")
ErrAccessDenied = errors.New("access_denied")
ErrUnsupportedResponseType = errors.New("unsupported_response_type")
ErrInvalidScope = errors.New("invalid_scope")
ErrServerError = errors.New("server_error")
ErrTemporarilyUnavailable = errors.New("temporarily_unavailable")
ErrInvalidClient = errors.New("invalid_client")
ErrInvalidGrant = errors.New("invalid_grant")
ErrUnsupportedGrantType = errors.New("unsupported_grant_type")
ErrCodeChallengeRquired = errors.New("invalid_request")
ErrUnsupportedCodeChallengeMethod = errors.New("invalid_request")
ErrInvalidCodeChallengeLen = errors.New("invalid_request")
)
// Descriptions error description
var Descriptions = map[error]string{
ErrInvalidRequest: "The request is missing a required parameter, includes an invalid parameter value, includes a parameter more than once, or is otherwise malformed",
ErrUnauthorizedClient: "The client is not authorized to request an authorization code using this method",
ErrAccessDenied: "The resource owner or authorization server denied the request",
ErrUnsupportedResponseType: "The authorization server does not support obtaining an authorization code using this method",
ErrInvalidScope: "The requested scope is invalid, unknown, or malformed",
ErrServerError: "The authorization server encountered an unexpected condition that prevented it from fulfilling the request",
ErrTemporarilyUnavailable: "The authorization server is currently unable to handle the request due to a temporary overloading or maintenance of the server",
ErrInvalidClient: "Client authentication failed",
ErrInvalidGrant: "The provided authorization grant (e.g., authorization code, resource owner credentials) or refresh token is invalid, expired, revoked, does not match the redirection URI used in the authorization request, or was issued to another client",
ErrUnsupportedGrantType: "The authorization grant type is not supported by the authorization server",
ErrCodeChallengeRquired: "PKCE is required. code_challenge is missing",
ErrUnsupportedCodeChallengeMethod: "Selected code_challenge_method not supported",
ErrInvalidCodeChallengeLen: "Code challenge length must be between 43 and 128 charachters long",
}
// StatusCodes response error HTTP status code
var StatusCodes = map[error]int{
ErrInvalidRequest: 400,
ErrUnauthorizedClient: 401,
ErrAccessDenied: 403,
ErrUnsupportedResponseType: 401,
ErrInvalidScope: 400,
ErrServerError: 500,
ErrTemporarilyUnavailable: 503,
ErrInvalidClient: 401,
ErrInvalidGrant: 401,
ErrUnsupportedGrantType: 401,
ErrCodeChallengeRquired: 400,
ErrUnsupportedCodeChallengeMethod: 400,
ErrInvalidCodeChallengeLen: 400,
}

View File

@@ -0,0 +1,28 @@
package oauth2
import (
"context"
"net/http"
"time"
)
type (
// GenerateBasic provide the basis of the generated token data
GenerateBasic struct {
Client ClientInfo
UserID string
CreateAt time.Time
TokenInfo TokenInfo
Request *http.Request
}
// AuthorizeGenerate generate the authorization code interface
AuthorizeGenerate interface {
Token(ctx context.Context, data *GenerateBasic) (code string, err error)
}
// AccessGenerate generate the access and refresh tokens interface
AccessGenerate interface {
Token(ctx context.Context, data *GenerateBasic, isGenRefresh bool) (access, refresh string, err error)
}
)

View File

@@ -0,0 +1,38 @@
package generates
import (
"bytes"
"context"
"encoding/base64"
"strconv"
"strings"
"codeberg.org/superseriousbusiness/oauth2/v4"
"github.com/google/uuid"
)
// NewAccessGenerate create to generate the access token instance
func NewAccessGenerate() *AccessGenerate {
return &AccessGenerate{}
}
// AccessGenerate generate the access token
type AccessGenerate struct {
}
// Token based on the UUID generated token
func (ag *AccessGenerate) Token(ctx context.Context, data *oauth2.GenerateBasic, isGenRefresh bool) (string, string, error) {
buf := bytes.NewBufferString(data.Client.GetID())
buf.WriteString(data.UserID)
buf.WriteString(strconv.FormatInt(data.CreateAt.UnixNano(), 10))
access := base64.URLEncoding.EncodeToString([]byte(uuid.NewMD5(uuid.Must(uuid.NewRandom()), buf.Bytes()).String()))
access = strings.ToUpper(strings.TrimRight(access, "="))
refresh := ""
if isGenRefresh {
refresh = base64.URLEncoding.EncodeToString([]byte(uuid.NewSHA1(uuid.Must(uuid.NewRandom()), buf.Bytes()).String()))
refresh = strings.ToUpper(strings.TrimRight(refresh, "="))
}
return access, refresh, nil
}

View File

@@ -0,0 +1,30 @@
package generates
import (
"bytes"
"context"
"encoding/base64"
"strings"
"codeberg.org/superseriousbusiness/oauth2/v4"
"github.com/google/uuid"
)
// NewAuthorizeGenerate create to generate the authorize code instance
func NewAuthorizeGenerate() *AuthorizeGenerate {
return &AuthorizeGenerate{}
}
// AuthorizeGenerate generate the authorize code
type AuthorizeGenerate struct{}
// Token based on the UUID generated token
func (ag *AuthorizeGenerate) Token(ctx context.Context, data *oauth2.GenerateBasic) (string, error) {
buf := bytes.NewBufferString(data.Client.GetID())
buf.WriteString(data.UserID)
token := uuid.NewMD5(uuid.Must(uuid.NewRandom()), buf.Bytes())
code := base64.URLEncoding.EncodeToString([]byte(token.String()))
code = strings.ToUpper(strings.TrimRight(code, "="))
return code, nil
}

View File

@@ -0,0 +1,104 @@
package generates
import (
"context"
"encoding/base64"
"strings"
"time"
"codeberg.org/superseriousbusiness/oauth2/v4"
"codeberg.org/superseriousbusiness/oauth2/v4/errors"
"github.com/golang-jwt/jwt"
"github.com/google/uuid"
)
// JWTAccessClaims jwt claims
type JWTAccessClaims struct {
jwt.StandardClaims
}
// Valid claims verification
func (a *JWTAccessClaims) Valid() error {
if time.Unix(a.ExpiresAt, 0).Before(time.Now()) {
return errors.ErrInvalidAccessToken
}
return nil
}
// NewJWTAccessGenerate create to generate the jwt access token instance
func NewJWTAccessGenerate(kid string, key []byte, method jwt.SigningMethod) *JWTAccessGenerate {
return &JWTAccessGenerate{
SignedKeyID: kid,
SignedKey: key,
SignedMethod: method,
}
}
// JWTAccessGenerate generate the jwt access token
type JWTAccessGenerate struct {
SignedKeyID string
SignedKey []byte
SignedMethod jwt.SigningMethod
}
// Token based on the UUID generated token
func (a *JWTAccessGenerate) Token(ctx context.Context, data *oauth2.GenerateBasic, isGenRefresh bool) (string, string, error) {
claims := &JWTAccessClaims{
StandardClaims: jwt.StandardClaims{
Audience: data.Client.GetID(),
Subject: data.UserID,
ExpiresAt: data.TokenInfo.GetAccessCreateAt().Add(data.TokenInfo.GetAccessExpiresIn()).Unix(),
},
}
token := jwt.NewWithClaims(a.SignedMethod, claims)
if a.SignedKeyID != "" {
token.Header["kid"] = a.SignedKeyID
}
var key interface{}
if a.isEs() {
v, err := jwt.ParseECPrivateKeyFromPEM(a.SignedKey)
if err != nil {
return "", "", err
}
key = v
} else if a.isRsOrPS() {
v, err := jwt.ParseRSAPrivateKeyFromPEM(a.SignedKey)
if err != nil {
return "", "", err
}
key = v
} else if a.isHs() {
key = a.SignedKey
} else {
return "", "", errors.New("unsupported sign method")
}
access, err := token.SignedString(key)
if err != nil {
return "", "", err
}
refresh := ""
if isGenRefresh {
t := uuid.NewSHA1(uuid.Must(uuid.NewRandom()), []byte(access)).String()
refresh = base64.URLEncoding.EncodeToString([]byte(t))
refresh = strings.ToUpper(strings.TrimRight(refresh, "="))
}
return access, refresh, nil
}
func (a *JWTAccessGenerate) isEs() bool {
return strings.HasPrefix(a.SignedMethod.Alg(), "ES")
}
func (a *JWTAccessGenerate) isRsOrPS() bool {
isRs := strings.HasPrefix(a.SignedMethod.Alg(), "RS")
isPs := strings.HasPrefix(a.SignedMethod.Alg(), "PS")
return isRs || isPs
}
func (a *JWTAccessGenerate) isHs() bool {
return strings.HasPrefix(a.SignedMethod.Alg(), "HS")
}

View File

@@ -0,0 +1,12 @@
#!/usr/bin/env bash
set -e
echo "" > coverage.txt
for d in $(go list ./... | grep -v vendor); do
go test -race -coverprofile=profile.out -covermode=atomic "$d"
if [ -f profile.out ]; then
cat profile.out >> coverage.txt
rm profile.out
fi
done

View File

@@ -0,0 +1,50 @@
package oauth2
import (
"context"
"net/http"
"time"
)
// TokenGenerateRequest provide to generate the token request parameters
type TokenGenerateRequest struct {
ClientID string
ClientSecret string
UserID string
RedirectURI string
Scope string
Code string
CodeChallenge string
CodeChallengeMethod CodeChallengeMethod
Refresh string
CodeVerifier string
AccessTokenExp time.Duration
Request *http.Request
}
// Manager authorization management interface
type Manager interface {
// get the client information
GetClient(ctx context.Context, clientID string) (cli ClientInfo, err error)
// generate the authorization token(code)
GenerateAuthToken(ctx context.Context, rt ResponseType, tgr *TokenGenerateRequest) (authToken TokenInfo, err error)
// generate the access token
GenerateAccessToken(ctx context.Context, rt GrantType, tgr *TokenGenerateRequest) (accessToken TokenInfo, err error)
// refreshing an access token
RefreshAccessToken(ctx context.Context, tgr *TokenGenerateRequest) (accessToken TokenInfo, err error)
// use the access token to delete the token information
RemoveAccessToken(ctx context.Context, access string) (err error)
// use the refresh token to delete the token information
RemoveRefreshToken(ctx context.Context, refresh string) (err error)
// according to the access token for corresponding token information
LoadAccessToken(ctx context.Context, access string) (ti TokenInfo, err error)
// according to the refresh token for corresponding token information
LoadRefreshToken(ctx context.Context, refresh string) (ti TokenInfo, err error)
}

View File

@@ -0,0 +1,39 @@
package manage
import "time"
// Config authorization configuration parameters
type Config struct {
// access token expiration time, 0 means it doesn't expire
AccessTokenExp time.Duration
// refresh token expiration time, 0 means it doesn't expire
RefreshTokenExp time.Duration
// whether to generate the refreshing token
IsGenerateRefresh bool
}
// RefreshingConfig refreshing token config
type RefreshingConfig struct {
// access token expiration time, 0 means it doesn't expire
AccessTokenExp time.Duration
// refresh token expiration time, 0 means it doesn't expire
RefreshTokenExp time.Duration
// whether to generate the refreshing token
IsGenerateRefresh bool
// whether to reset the refreshing create time
IsResetRefreshTime bool
// whether to remove access token
IsRemoveAccess bool
// whether to remove refreshing token
IsRemoveRefreshing bool
}
// default configs
var (
DefaultCodeExp = time.Minute * 10
DefaultAuthorizeCodeTokenCfg = &Config{AccessTokenExp: time.Hour * 2, RefreshTokenExp: time.Hour * 24 * 3, IsGenerateRefresh: true}
DefaultImplicitTokenCfg = &Config{AccessTokenExp: time.Hour * 1}
DefaultPasswordTokenCfg = &Config{AccessTokenExp: time.Hour * 2, RefreshTokenExp: time.Hour * 24 * 7, IsGenerateRefresh: true}
DefaultClientTokenCfg = &Config{AccessTokenExp: time.Hour * 2}
DefaultRefreshTokenCfg = &RefreshingConfig{IsGenerateRefresh: true, IsRemoveAccess: true, IsRemoveRefreshing: true}
)

View File

@@ -0,0 +1,504 @@
package manage
import (
"context"
"time"
"codeberg.org/superseriousbusiness/oauth2/v4"
"codeberg.org/superseriousbusiness/oauth2/v4/errors"
"codeberg.org/superseriousbusiness/oauth2/v4/generates"
"codeberg.org/superseriousbusiness/oauth2/v4/models"
)
// NewDefaultManager create to default authorization management instance
func NewDefaultManager() *Manager {
m := NewManager()
// default implementation
m.MapAuthorizeGenerate(generates.NewAuthorizeGenerate())
m.MapAccessGenerate(generates.NewAccessGenerate())
return m
}
// NewManager create to authorization management instance
func NewManager() *Manager {
return &Manager{
gtcfg: make(map[oauth2.GrantType]*Config),
validateURI: DefaultValidateURI,
}
}
// Manager provide authorization management
type Manager struct {
codeExp time.Duration
gtcfg map[oauth2.GrantType]*Config
rcfg *RefreshingConfig
validateURI ValidateURIHandler
authorizeGenerate oauth2.AuthorizeGenerate
accessGenerate oauth2.AccessGenerate
tokenStore oauth2.TokenStore
clientStore oauth2.ClientStore
}
// get grant type config
func (m *Manager) grantConfig(gt oauth2.GrantType) *Config {
if c, ok := m.gtcfg[gt]; ok && c != nil {
return c
}
switch gt {
case oauth2.AuthorizationCode:
return DefaultAuthorizeCodeTokenCfg
case oauth2.Implicit:
return DefaultImplicitTokenCfg
case oauth2.PasswordCredentials:
return DefaultPasswordTokenCfg
case oauth2.ClientCredentials:
return DefaultClientTokenCfg
}
return &Config{}
}
// SetAuthorizeCodeExp set the authorization code expiration time
func (m *Manager) SetAuthorizeCodeExp(exp time.Duration) {
m.codeExp = exp
}
// SetAuthorizeCodeTokenCfg set the authorization code grant token config
func (m *Manager) SetAuthorizeCodeTokenCfg(cfg *Config) {
m.gtcfg[oauth2.AuthorizationCode] = cfg
}
// SetImplicitTokenCfg set the implicit grant token config
func (m *Manager) SetImplicitTokenCfg(cfg *Config) {
m.gtcfg[oauth2.Implicit] = cfg
}
// SetPasswordTokenCfg set the password grant token config
func (m *Manager) SetPasswordTokenCfg(cfg *Config) {
m.gtcfg[oauth2.PasswordCredentials] = cfg
}
// SetClientTokenCfg set the client grant token config
func (m *Manager) SetClientTokenCfg(cfg *Config) {
m.gtcfg[oauth2.ClientCredentials] = cfg
}
// SetRefreshTokenCfg set the refreshing token config
func (m *Manager) SetRefreshTokenCfg(cfg *RefreshingConfig) {
m.rcfg = cfg
}
// SetValidateURIHandler set the validates that RedirectURI is contained in baseURI
func (m *Manager) SetValidateURIHandler(handler ValidateURIHandler) {
m.validateURI = handler
}
// MapAuthorizeGenerate mapping the authorize code generate interface
func (m *Manager) MapAuthorizeGenerate(gen oauth2.AuthorizeGenerate) {
m.authorizeGenerate = gen
}
// MapAccessGenerate mapping the access token generate interface
func (m *Manager) MapAccessGenerate(gen oauth2.AccessGenerate) {
m.accessGenerate = gen
}
// MapClientStorage mapping the client store interface
func (m *Manager) MapClientStorage(stor oauth2.ClientStore) {
m.clientStore = stor
}
// MustClientStorage mandatory mapping the client store interface
func (m *Manager) MustClientStorage(stor oauth2.ClientStore, err error) {
if err != nil {
panic(err.Error())
}
m.clientStore = stor
}
// MapTokenStorage mapping the token store interface
func (m *Manager) MapTokenStorage(stor oauth2.TokenStore) {
m.tokenStore = stor
}
// MustTokenStorage mandatory mapping the token store interface
func (m *Manager) MustTokenStorage(stor oauth2.TokenStore, err error) {
if err != nil {
panic(err)
}
m.tokenStore = stor
}
// GetClient get the client information
func (m *Manager) GetClient(ctx context.Context, clientID string) (cli oauth2.ClientInfo, err error) {
cli, err = m.clientStore.GetByID(ctx, clientID)
if err != nil {
return
} else if cli == nil {
err = errors.ErrInvalidClient
}
return
}
// GenerateAuthToken generate the authorization token(code)
func (m *Manager) GenerateAuthToken(ctx context.Context, rt oauth2.ResponseType, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo, error) {
cli, err := m.GetClient(ctx, tgr.ClientID)
if err != nil {
return nil, err
} else if tgr.RedirectURI != "" {
if err := m.validateURI(cli.GetDomain(), tgr.RedirectURI); err != nil {
return nil, err
}
}
ti := models.NewToken()
ti.SetClientID(tgr.ClientID)
ti.SetUserID(tgr.UserID)
ti.SetRedirectURI(tgr.RedirectURI)
ti.SetScope(tgr.Scope)
createAt := time.Now()
td := &oauth2.GenerateBasic{
Client: cli,
UserID: tgr.UserID,
CreateAt: createAt,
TokenInfo: ti,
Request: tgr.Request,
}
switch rt {
case oauth2.Code:
codeExp := m.codeExp
if codeExp == 0 {
codeExp = DefaultCodeExp
}
ti.SetCodeCreateAt(createAt)
ti.SetCodeExpiresIn(codeExp)
if exp := tgr.AccessTokenExp; exp > 0 {
ti.SetAccessExpiresIn(exp)
}
if tgr.CodeChallenge != "" {
ti.SetCodeChallenge(tgr.CodeChallenge)
ti.SetCodeChallengeMethod(tgr.CodeChallengeMethod)
}
tv, err := m.authorizeGenerate.Token(ctx, td)
if err != nil {
return nil, err
}
ti.SetCode(tv)
case oauth2.Token:
// set access token expires
icfg := m.grantConfig(oauth2.Implicit)
aexp := icfg.AccessTokenExp
if exp := tgr.AccessTokenExp; exp > 0 {
aexp = exp
}
ti.SetAccessCreateAt(createAt)
ti.SetAccessExpiresIn(aexp)
if icfg.IsGenerateRefresh {
ti.SetRefreshCreateAt(createAt)
ti.SetRefreshExpiresIn(icfg.RefreshTokenExp)
}
tv, rv, err := m.accessGenerate.Token(ctx, td, icfg.IsGenerateRefresh)
if err != nil {
return nil, err
}
ti.SetAccess(tv)
if rv != "" {
ti.SetRefresh(rv)
}
}
err = m.tokenStore.Create(ctx, ti)
if err != nil {
return nil, err
}
return ti, nil
}
// get authorization code data
func (m *Manager) getAuthorizationCode(ctx context.Context, code string) (oauth2.TokenInfo, error) {
ti, err := m.tokenStore.GetByCode(ctx, code)
if err != nil {
return nil, err
} else if ti == nil || ti.GetCode() != code || ti.GetCodeCreateAt().Add(ti.GetCodeExpiresIn()).Before(time.Now()) {
err = errors.ErrInvalidAuthorizeCode
return nil, errors.ErrInvalidAuthorizeCode
}
return ti, nil
}
// delete authorization code data
func (m *Manager) delAuthorizationCode(ctx context.Context, code string) error {
return m.tokenStore.RemoveByCode(ctx, code)
}
// get and delete authorization code data
func (m *Manager) getAndDelAuthorizationCode(ctx context.Context, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo, error) {
code := tgr.Code
ti, err := m.getAuthorizationCode(ctx, code)
if err != nil {
return nil, err
} else if ti.GetClientID() != tgr.ClientID {
return nil, errors.ErrInvalidAuthorizeCode
} else if codeURI := ti.GetRedirectURI(); codeURI != "" && codeURI != tgr.RedirectURI {
return nil, errors.ErrInvalidAuthorizeCode
}
err = m.delAuthorizationCode(ctx, code)
if err != nil {
return nil, err
}
return ti, nil
}
func (m *Manager) validateCodeChallenge(ti oauth2.TokenInfo, ver string) error {
cc := ti.GetCodeChallenge()
// early return
if cc == "" && ver == "" {
return nil
}
if cc == "" {
return errors.ErrMissingCodeVerifier
}
if ver == "" {
return errors.ErrMissingCodeVerifier
}
ccm := ti.GetCodeChallengeMethod()
if ccm.String() == "" {
ccm = oauth2.CodeChallengePlain
}
if !ccm.Validate(cc, ver) {
return errors.ErrInvalidCodeChallenge
}
return nil
}
// GenerateAccessToken generate the access token
func (m *Manager) GenerateAccessToken(ctx context.Context, gt oauth2.GrantType, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo, error) {
cli, err := m.GetClient(ctx, tgr.ClientID)
if err != nil {
return nil, err
}
if cliPass, ok := cli.(oauth2.ClientPasswordVerifier); ok {
if !cliPass.VerifyPassword(tgr.ClientSecret) {
return nil, errors.ErrInvalidClient
}
} else if len(cli.GetSecret()) > 0 && tgr.ClientSecret != cli.GetSecret() {
return nil, errors.ErrInvalidClient
}
if tgr.RedirectURI != "" {
if err := m.validateURI(cli.GetDomain(), tgr.RedirectURI); err != nil {
return nil, err
}
}
if gt == oauth2.AuthorizationCode {
ti, err := m.getAndDelAuthorizationCode(ctx, tgr)
if err != nil {
return nil, err
}
if err := m.validateCodeChallenge(ti, tgr.CodeVerifier); err != nil {
return nil, err
}
tgr.UserID = ti.GetUserID()
tgr.Scope = ti.GetScope()
if exp := ti.GetAccessExpiresIn(); exp > 0 {
tgr.AccessTokenExp = exp
}
}
ti := models.NewToken()
ti.SetClientID(tgr.ClientID)
ti.SetUserID(tgr.UserID)
ti.SetRedirectURI(tgr.RedirectURI)
ti.SetScope(tgr.Scope)
createAt := time.Now()
ti.SetAccessCreateAt(createAt)
// set access token expires
gcfg := m.grantConfig(gt)
aexp := gcfg.AccessTokenExp
if exp := tgr.AccessTokenExp; exp > 0 {
aexp = exp
}
ti.SetAccessExpiresIn(aexp)
if gcfg.IsGenerateRefresh {
ti.SetRefreshCreateAt(createAt)
ti.SetRefreshExpiresIn(gcfg.RefreshTokenExp)
}
td := &oauth2.GenerateBasic{
Client: cli,
UserID: tgr.UserID,
CreateAt: createAt,
TokenInfo: ti,
Request: tgr.Request,
}
av, rv, err := m.accessGenerate.Token(ctx, td, gcfg.IsGenerateRefresh)
if err != nil {
return nil, err
}
ti.SetAccess(av)
if rv != "" {
ti.SetRefresh(rv)
}
err = m.tokenStore.Create(ctx, ti)
if err != nil {
return nil, err
}
return ti, nil
}
// RefreshAccessToken refreshing an access token
func (m *Manager) RefreshAccessToken(ctx context.Context, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo, error) {
cli, err := m.GetClient(ctx, tgr.ClientID)
if err != nil {
return nil, err
} else if cliPass, ok := cli.(oauth2.ClientPasswordVerifier); ok {
if !cliPass.VerifyPassword(tgr.ClientSecret) {
return nil, errors.ErrInvalidClient
}
} else if tgr.ClientSecret != cli.GetSecret() {
return nil, errors.ErrInvalidClient
}
ti, err := m.LoadRefreshToken(ctx, tgr.Refresh)
if err != nil {
return nil, err
} else if ti.GetClientID() != tgr.ClientID {
return nil, errors.ErrInvalidRefreshToken
}
oldAccess, oldRefresh := ti.GetAccess(), ti.GetRefresh()
td := &oauth2.GenerateBasic{
Client: cli,
UserID: ti.GetUserID(),
CreateAt: time.Now(),
TokenInfo: ti,
Request: tgr.Request,
}
rcfg := DefaultRefreshTokenCfg
if v := m.rcfg; v != nil {
rcfg = v
}
ti.SetAccessCreateAt(td.CreateAt)
if v := rcfg.AccessTokenExp; v > 0 {
ti.SetAccessExpiresIn(v)
}
if v := rcfg.RefreshTokenExp; v > 0 {
ti.SetRefreshExpiresIn(v)
}
if rcfg.IsResetRefreshTime {
ti.SetRefreshCreateAt(td.CreateAt)
}
if scope := tgr.Scope; scope != "" {
ti.SetScope(scope)
}
tv, rv, err := m.accessGenerate.Token(ctx, td, rcfg.IsGenerateRefresh)
if err != nil {
return nil, err
}
ti.SetAccess(tv)
if rv != "" {
ti.SetRefresh(rv)
}
if err := m.tokenStore.Create(ctx, ti); err != nil {
return nil, err
}
if rcfg.IsRemoveAccess {
// remove the old access token
if err := m.tokenStore.RemoveByAccess(ctx, oldAccess); err != nil {
return nil, err
}
}
if rcfg.IsRemoveRefreshing && rv != "" {
// remove the old refresh token
if err := m.tokenStore.RemoveByRefresh(ctx, oldRefresh); err != nil {
return nil, err
}
}
if rv == "" {
ti.SetRefresh("")
ti.SetRefreshCreateAt(time.Now())
ti.SetRefreshExpiresIn(0)
}
return ti, nil
}
// RemoveAccessToken use the access token to delete the token information
func (m *Manager) RemoveAccessToken(ctx context.Context, access string) error {
if access == "" {
return errors.ErrInvalidAccessToken
}
return m.tokenStore.RemoveByAccess(ctx, access)
}
// RemoveRefreshToken use the refresh token to delete the token information
func (m *Manager) RemoveRefreshToken(ctx context.Context, refresh string) error {
if refresh == "" {
return errors.ErrInvalidAccessToken
}
return m.tokenStore.RemoveByRefresh(ctx, refresh)
}
// LoadAccessToken according to the access token for corresponding token information
func (m *Manager) LoadAccessToken(ctx context.Context, access string) (oauth2.TokenInfo, error) {
if access == "" {
return nil, errors.ErrInvalidAccessToken
}
ct := time.Now()
ti, err := m.tokenStore.GetByAccess(ctx, access)
if err != nil {
return nil, err
} else if ti == nil || ti.GetAccess() != access {
return nil, errors.ErrInvalidAccessToken
} else if ti.GetRefresh() != "" && ti.GetRefreshExpiresIn() != 0 &&
ti.GetRefreshCreateAt().Add(ti.GetRefreshExpiresIn()).Before(ct) {
return nil, errors.ErrExpiredRefreshToken
} else if ti.GetAccessExpiresIn() != 0 &&
ti.GetAccessCreateAt().Add(ti.GetAccessExpiresIn()).Before(ct) {
return nil, errors.ErrExpiredAccessToken
}
return ti, nil
}
// LoadRefreshToken according to the refresh token for corresponding token information
func (m *Manager) LoadRefreshToken(ctx context.Context, refresh string) (oauth2.TokenInfo, error) {
if refresh == "" {
return nil, errors.ErrInvalidRefreshToken
}
ti, err := m.tokenStore.GetByRefresh(ctx, refresh)
if err != nil {
return nil, err
} else if ti == nil || ti.GetRefresh() != refresh {
return nil, errors.ErrInvalidRefreshToken
} else if ti.GetRefreshExpiresIn() != 0 && // refresh token set to not expire
ti.GetRefreshCreateAt().Add(ti.GetRefreshExpiresIn()).Before(time.Now()) {
return nil, errors.ErrExpiredRefreshToken
}
return ti, nil
}

View File

@@ -0,0 +1,30 @@
package manage
import (
"net/url"
"strings"
"codeberg.org/superseriousbusiness/oauth2/v4/errors"
)
type (
// ValidateURIHandler validates that redirectURI is contained in baseURI
ValidateURIHandler func(baseURI, redirectURI string) error
)
// DefaultValidateURI validates that redirectURI is contained in baseURI
func DefaultValidateURI(baseURI string, redirectURI string) error {
base, err := url.Parse(baseURI)
if err != nil {
return err
}
redirect, err := url.Parse(redirectURI)
if err != nil {
return err
}
if !strings.HasSuffix(redirect.Host, base.Host) {
return errors.ErrInvalidRedirectURI
}
return nil
}

View File

@@ -0,0 +1,59 @@
package oauth2
import (
"time"
)
type (
// ClientInfo the client information model interface
ClientInfo interface {
GetID() string
GetSecret() string
GetDomain() string
GetUserID() string
}
// ClientPasswordVerifier the password handler interface
ClientPasswordVerifier interface {
VerifyPassword(string) bool
}
// TokenInfo the token information model interface
TokenInfo interface {
New() TokenInfo
GetClientID() string
SetClientID(string)
GetUserID() string
SetUserID(string)
GetRedirectURI() string
SetRedirectURI(string)
GetScope() string
SetScope(string)
GetCode() string
SetCode(string)
GetCodeCreateAt() time.Time
SetCodeCreateAt(time.Time)
GetCodeExpiresIn() time.Duration
SetCodeExpiresIn(time.Duration)
GetCodeChallenge() string
SetCodeChallenge(string)
GetCodeChallengeMethod() CodeChallengeMethod
SetCodeChallengeMethod(CodeChallengeMethod)
GetAccess() string
SetAccess(string)
GetAccessCreateAt() time.Time
SetAccessCreateAt(time.Time)
GetAccessExpiresIn() time.Duration
SetAccessExpiresIn(time.Duration)
GetRefresh() string
SetRefresh(string)
GetRefreshCreateAt() time.Time
SetRefreshCreateAt(time.Time)
GetRefreshExpiresIn() time.Duration
SetRefreshExpiresIn(time.Duration)
}
)

View File

@@ -0,0 +1,46 @@
package models
// Client client model
type Client interface {
GetID() string
GetSecret() string
GetDomain() string
GetUserID() string
}
func New(id string, secret string, domain string, userID string) Client {
return &simpleClient{
id: id,
secret: secret,
domain: domain,
userID: userID,
}
}
// simpleClient is a very simple client model that satisfies the Client interface
type simpleClient struct {
id string
secret string
domain string
userID string
}
// GetID client id
func (c *simpleClient) GetID() string {
return c.id
}
// GetSecret client secret
func (c *simpleClient) GetSecret() string {
return c.secret
}
// GetDomain client domain
func (c *simpleClient) GetDomain() string {
return c.domain
}
// GetUserID user id
func (c *simpleClient) GetUserID() string {
return c.userID
}

View File

@@ -0,0 +1,186 @@
package models
import (
"time"
"codeberg.org/superseriousbusiness/oauth2/v4"
)
// NewToken create to token model instance
func NewToken() *Token {
return &Token{}
}
// Token token model
type Token struct {
ClientID string `bson:"ClientID"`
UserID string `bson:"UserID"`
RedirectURI string `bson:"RedirectURI"`
Scope string `bson:"Scope"`
Code string `bson:"Code"`
CodeChallenge string `bson:"CodeChallenge"`
CodeChallengeMethod string `bson:"CodeChallengeMethod"`
CodeCreateAt time.Time `bson:"CodeCreateAt"`
CodeExpiresIn time.Duration `bson:"CodeExpiresIn"`
Access string `bson:"Access"`
AccessCreateAt time.Time `bson:"AccessCreateAt"`
AccessExpiresIn time.Duration `bson:"AccessExpiresIn"`
Refresh string `bson:"Refresh"`
RefreshCreateAt time.Time `bson:"RefreshCreateAt"`
RefreshExpiresIn time.Duration `bson:"RefreshExpiresIn"`
}
// New create to token model instance
func (t *Token) New() oauth2.TokenInfo {
return NewToken()
}
// GetClientID the client id
func (t *Token) GetClientID() string {
return t.ClientID
}
// SetClientID the client id
func (t *Token) SetClientID(clientID string) {
t.ClientID = clientID
}
// GetUserID the user id
func (t *Token) GetUserID() string {
return t.UserID
}
// SetUserID the user id
func (t *Token) SetUserID(userID string) {
t.UserID = userID
}
// GetRedirectURI redirect URI
func (t *Token) GetRedirectURI() string {
return t.RedirectURI
}
// SetRedirectURI redirect URI
func (t *Token) SetRedirectURI(redirectURI string) {
t.RedirectURI = redirectURI
}
// GetScope get scope of authorization
func (t *Token) GetScope() string {
return t.Scope
}
// SetScope get scope of authorization
func (t *Token) SetScope(scope string) {
t.Scope = scope
}
// GetCode authorization code
func (t *Token) GetCode() string {
return t.Code
}
// SetCode authorization code
func (t *Token) SetCode(code string) {
t.Code = code
}
// GetCodeCreateAt create Time
func (t *Token) GetCodeCreateAt() time.Time {
return t.CodeCreateAt
}
// SetCodeCreateAt create Time
func (t *Token) SetCodeCreateAt(createAt time.Time) {
t.CodeCreateAt = createAt
}
// GetCodeExpiresIn the lifetime in seconds of the authorization code
func (t *Token) GetCodeExpiresIn() time.Duration {
return t.CodeExpiresIn
}
// SetCodeExpiresIn the lifetime in seconds of the authorization code
func (t *Token) SetCodeExpiresIn(exp time.Duration) {
t.CodeExpiresIn = exp
}
// GetCodeChallenge challenge code
func (t *Token) GetCodeChallenge() string {
return t.CodeChallenge
}
// SetCodeChallenge challenge code
func (t *Token) SetCodeChallenge(code string) {
t.CodeChallenge = code
}
// GetCodeChallengeMethod challenge method
func (t *Token) GetCodeChallengeMethod() oauth2.CodeChallengeMethod {
return oauth2.CodeChallengeMethod(t.CodeChallengeMethod)
}
// SetCodeChallengeMethod challenge method
func (t *Token) SetCodeChallengeMethod(method oauth2.CodeChallengeMethod) {
t.CodeChallengeMethod = string(method)
}
// GetAccess access Token
func (t *Token) GetAccess() string {
return t.Access
}
// SetAccess access Token
func (t *Token) SetAccess(access string) {
t.Access = access
}
// GetAccessCreateAt create Time
func (t *Token) GetAccessCreateAt() time.Time {
return t.AccessCreateAt
}
// SetAccessCreateAt create Time
func (t *Token) SetAccessCreateAt(createAt time.Time) {
t.AccessCreateAt = createAt
}
// GetAccessExpiresIn the lifetime in seconds of the access token
func (t *Token) GetAccessExpiresIn() time.Duration {
return t.AccessExpiresIn
}
// SetAccessExpiresIn the lifetime in seconds of the access token
func (t *Token) SetAccessExpiresIn(exp time.Duration) {
t.AccessExpiresIn = exp
}
// GetRefresh refresh Token
func (t *Token) GetRefresh() string {
return t.Refresh
}
// SetRefresh refresh Token
func (t *Token) SetRefresh(refresh string) {
t.Refresh = refresh
}
// GetRefreshCreateAt create Time
func (t *Token) GetRefreshCreateAt() time.Time {
return t.RefreshCreateAt
}
// SetRefreshCreateAt create Time
func (t *Token) SetRefreshCreateAt(createAt time.Time) {
t.RefreshCreateAt = createAt
}
// GetRefreshExpiresIn the lifetime in seconds of the refresh token
func (t *Token) GetRefreshExpiresIn() time.Duration {
return t.RefreshExpiresIn
}
// SetRefreshExpiresIn the lifetime in seconds of the refresh token
func (t *Token) SetRefreshExpiresIn(exp time.Duration) {
t.RefreshExpiresIn = exp
}

View File

@@ -0,0 +1,50 @@
package server
import (
"net/http"
"time"
"codeberg.org/superseriousbusiness/oauth2/v4"
)
// Config configuration parameters
type Config struct {
TokenType string // token type
AllowGetAccessRequest bool // to allow GET requests for the token
AllowedResponseTypes []oauth2.ResponseType // allow the authorization type
AllowedGrantTypes []oauth2.GrantType // allow the grant type
AllowedCodeChallengeMethods []oauth2.CodeChallengeMethod
ForcePKCE bool
}
// NewConfig create to configuration instance
func NewConfig() *Config {
return &Config{
TokenType: "Bearer",
AllowedResponseTypes: []oauth2.ResponseType{oauth2.Code, oauth2.Token},
AllowedGrantTypes: []oauth2.GrantType{
oauth2.AuthorizationCode,
oauth2.PasswordCredentials,
oauth2.ClientCredentials,
oauth2.Refreshing,
},
AllowedCodeChallengeMethods: []oauth2.CodeChallengeMethod{
oauth2.CodeChallengePlain,
oauth2.CodeChallengeS256,
},
}
}
// AuthorizeRequest authorization request
type AuthorizeRequest struct {
ResponseType oauth2.ResponseType
ClientID string
Scope string
RedirectURI string
State string
UserID string
CodeChallenge string
CodeChallengeMethod oauth2.CodeChallengeMethod
AccessTokenExp time.Duration
Request *http.Request
}

View File

@@ -0,0 +1,66 @@
package server
import (
"net/http"
"time"
"codeberg.org/superseriousbusiness/oauth2/v4"
"codeberg.org/superseriousbusiness/oauth2/v4/errors"
)
type (
// ClientInfoHandler get client info from request
ClientInfoHandler func(r *http.Request) (clientID, clientSecret string, err error)
// ClientAuthorizedHandler check the client allows to use this authorization grant type
ClientAuthorizedHandler func(clientID string, grant oauth2.GrantType) (allowed bool, err error)
// ClientScopeHandler check the client allows to use scope
ClientScopeHandler func(tgr *oauth2.TokenGenerateRequest) (allowed bool, err error)
// UserAuthorizationHandler get user id from request authorization
UserAuthorizationHandler func(w http.ResponseWriter, r *http.Request) (userID string, err error)
// PasswordAuthorizationHandler get user id from username and password
PasswordAuthorizationHandler func(username, password string) (userID string, err error)
// RefreshingScopeHandler check the scope of the refreshing token
RefreshingScopeHandler func(tgr *oauth2.TokenGenerateRequest, oldScope string) (allowed bool, err error)
// RefreshingValidationHandler check if refresh_token is still valid. eg no revocation or other
RefreshingValidationHandler func(ti oauth2.TokenInfo) (allowed bool, err error)
// ResponseErrorHandler response error handing
ResponseErrorHandler func(re *errors.Response)
// InternalErrorHandler internal error handing
InternalErrorHandler func(err error) (re *errors.Response)
// AuthorizeScopeHandler set the authorized scope
AuthorizeScopeHandler func(w http.ResponseWriter, r *http.Request) (scope string, err error)
// AccessTokenExpHandler set expiration date for the access token
AccessTokenExpHandler func(w http.ResponseWriter, r *http.Request) (exp time.Duration, err error)
// ExtensionFieldsHandler in response to the access token with the extension of the field
ExtensionFieldsHandler func(ti oauth2.TokenInfo) (fieldsValue map[string]interface{})
)
// ClientFormHandler get client data from form
func ClientFormHandler(r *http.Request) (string, string, error) {
clientID := r.Form.Get("client_id")
if clientID == "" {
return "", "", errors.ErrInvalidClient
}
clientSecret := r.Form.Get("client_secret")
return clientID, clientSecret, nil
}
// ClientBasicHandler get client data from basic authorization
func ClientBasicHandler(r *http.Request) (string, string, error) {
username, password, ok := r.BasicAuth()
if !ok {
return "", "", errors.ErrInvalidClient
}
return username, password, nil
}

View File

@@ -0,0 +1,589 @@
package server
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/url"
"strings"
"time"
"codeberg.org/superseriousbusiness/oauth2/v4"
"codeberg.org/superseriousbusiness/oauth2/v4/errors"
)
// NewDefaultServer create a default authorization server
func NewDefaultServer(manager oauth2.Manager) *Server {
return NewServer(NewConfig(), manager)
}
// NewServer create authorization server
func NewServer(cfg *Config, manager oauth2.Manager) *Server {
srv := &Server{
Config: cfg,
Manager: manager,
}
// default handler
srv.ClientInfoHandler = ClientBasicHandler
srv.UserAuthorizationHandler = func(w http.ResponseWriter, r *http.Request) (string, error) {
return "", errors.ErrAccessDenied
}
srv.PasswordAuthorizationHandler = func(username, password string) (string, error) {
return "", errors.ErrAccessDenied
}
return srv
}
// Server Provide authorization server
type Server struct {
Config *Config
Manager oauth2.Manager
ClientInfoHandler ClientInfoHandler
ClientAuthorizedHandler ClientAuthorizedHandler
ClientScopeHandler ClientScopeHandler
UserAuthorizationHandler UserAuthorizationHandler
PasswordAuthorizationHandler PasswordAuthorizationHandler
RefreshingValidationHandler RefreshingValidationHandler
RefreshingScopeHandler RefreshingScopeHandler
ResponseErrorHandler ResponseErrorHandler
InternalErrorHandler InternalErrorHandler
ExtensionFieldsHandler ExtensionFieldsHandler
AccessTokenExpHandler AccessTokenExpHandler
AuthorizeScopeHandler AuthorizeScopeHandler
}
func (s *Server) redirectError(w http.ResponseWriter, req *AuthorizeRequest, err error) error {
if req == nil {
return err
}
data, _, _ := s.GetErrorData(err)
return s.redirect(w, req, data)
}
func (s *Server) redirect(w http.ResponseWriter, req *AuthorizeRequest, data map[string]interface{}) error {
uri, err := s.GetRedirectURI(req, data)
if err != nil {
return err
}
w.Header().Set("Location", uri)
w.WriteHeader(302)
return nil
}
func (s *Server) tokenError(w http.ResponseWriter, err error) error {
data, statusCode, header := s.GetErrorData(err)
return s.token(w, data, header, statusCode)
}
func (s *Server) token(w http.ResponseWriter, data map[string]interface{}, header http.Header, statusCode ...int) error {
w.Header().Set("Content-Type", "application/json;charset=UTF-8")
w.Header().Set("Cache-Control", "no-store")
w.Header().Set("Pragma", "no-cache")
for key := range header {
w.Header().Set(key, header.Get(key))
}
status := http.StatusOK
if len(statusCode) > 0 && statusCode[0] > 0 {
status = statusCode[0]
}
w.WriteHeader(status)
return json.NewEncoder(w).Encode(data)
}
// GetRedirectURI get redirect uri
func (s *Server) GetRedirectURI(req *AuthorizeRequest, data map[string]interface{}) (string, error) {
u, err := url.Parse(req.RedirectURI)
if err != nil {
return "", err
}
q := u.Query()
if req.State != "" {
q.Set("state", req.State)
}
for k, v := range data {
q.Set(k, fmt.Sprint(v))
}
switch req.ResponseType {
case oauth2.Code:
u.RawQuery = q.Encode()
case oauth2.Token:
u.RawQuery = ""
fragment, err := url.QueryUnescape(q.Encode())
if err != nil {
return "", err
}
u.Fragment = fragment
}
return u.String(), nil
}
// CheckResponseType check allows response type
func (s *Server) CheckResponseType(rt oauth2.ResponseType) bool {
for _, art := range s.Config.AllowedResponseTypes {
if art == rt {
return true
}
}
return false
}
// CheckCodeChallengeMethod checks for allowed code challenge method
func (s *Server) CheckCodeChallengeMethod(ccm oauth2.CodeChallengeMethod) bool {
for _, c := range s.Config.AllowedCodeChallengeMethods {
if c == ccm {
return true
}
}
return false
}
// ValidationAuthorizeRequest the authorization request validation
func (s *Server) ValidationAuthorizeRequest(r *http.Request) (*AuthorizeRequest, error) {
redirectURI := r.FormValue("redirect_uri")
clientID := r.FormValue("client_id")
if !(r.Method == "GET" || r.Method == "POST") ||
clientID == "" {
return nil, errors.ErrInvalidRequest
}
resType := oauth2.ResponseType(r.FormValue("response_type"))
if resType.String() == "" {
return nil, errors.ErrUnsupportedResponseType
} else if allowed := s.CheckResponseType(resType); !allowed {
return nil, errors.ErrUnauthorizedClient
}
cc := r.FormValue("code_challenge")
if cc == "" && s.Config.ForcePKCE {
return nil, errors.ErrCodeChallengeRquired
}
if cc != "" && (len(cc) < 43 || len(cc) > 128) {
return nil, errors.ErrInvalidCodeChallengeLen
}
ccm := oauth2.CodeChallengeMethod(r.FormValue("code_challenge_method"))
// set default
if ccm == "" {
ccm = oauth2.CodeChallengePlain
}
if ccm.String() != "" && !s.CheckCodeChallengeMethod(ccm) {
return nil, errors.ErrUnsupportedCodeChallengeMethod
}
req := &AuthorizeRequest{
RedirectURI: redirectURI,
ResponseType: resType,
ClientID: clientID,
State: r.FormValue("state"),
Scope: r.FormValue("scope"),
Request: r,
CodeChallenge: cc,
CodeChallengeMethod: ccm,
}
return req, nil
}
// GetAuthorizeToken get authorization token(code)
func (s *Server) GetAuthorizeToken(ctx context.Context, req *AuthorizeRequest) (oauth2.TokenInfo, error) {
// check the client allows the grant type
if fn := s.ClientAuthorizedHandler; fn != nil {
gt := oauth2.AuthorizationCode
if req.ResponseType == oauth2.Token {
gt = oauth2.Implicit
}
allowed, err := fn(req.ClientID, gt)
if err != nil {
return nil, err
} else if !allowed {
return nil, errors.ErrUnauthorizedClient
}
}
tgr := &oauth2.TokenGenerateRequest{
ClientID: req.ClientID,
UserID: req.UserID,
RedirectURI: req.RedirectURI,
Scope: req.Scope,
AccessTokenExp: req.AccessTokenExp,
Request: req.Request,
}
// check the client allows the authorized scope
if fn := s.ClientScopeHandler; fn != nil {
allowed, err := fn(tgr)
if err != nil {
return nil, err
} else if !allowed {
return nil, errors.ErrInvalidScope
}
}
tgr.CodeChallenge = req.CodeChallenge
tgr.CodeChallengeMethod = req.CodeChallengeMethod
return s.Manager.GenerateAuthToken(ctx, req.ResponseType, tgr)
}
// GetAuthorizeData get authorization response data
func (s *Server) GetAuthorizeData(rt oauth2.ResponseType, ti oauth2.TokenInfo) map[string]interface{} {
if rt == oauth2.Code {
return map[string]interface{}{
"code": ti.GetCode(),
}
}
return s.GetTokenData(ti)
}
// HandleAuthorizeRequest the authorization request handling
func (s *Server) HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) error {
ctx := r.Context()
req, err := s.ValidationAuthorizeRequest(r)
if err != nil {
return s.redirectError(w, req, err)
}
// user authorization
userID, err := s.UserAuthorizationHandler(w, r)
if err != nil {
return s.redirectError(w, req, err)
} else if userID == "" {
return nil
}
req.UserID = userID
// specify the scope of authorization
if fn := s.AuthorizeScopeHandler; fn != nil {
scope, err := fn(w, r)
if err != nil {
return err
} else if scope != "" {
req.Scope = scope
}
}
// specify the expiration time of access token
if fn := s.AccessTokenExpHandler; fn != nil {
exp, err := fn(w, r)
if err != nil {
return err
}
req.AccessTokenExp = exp
}
ti, err := s.GetAuthorizeToken(ctx, req)
if err != nil {
return s.redirectError(w, req, err)
}
// If the redirect URI is empty, the default domain provided by the client is used.
if req.RedirectURI == "" {
client, err := s.Manager.GetClient(ctx, req.ClientID)
if err != nil {
return err
}
req.RedirectURI = client.GetDomain()
}
return s.redirect(w, req, s.GetAuthorizeData(req.ResponseType, ti))
}
// ValidationTokenRequest the token request validation
func (s *Server) ValidationTokenRequest(r *http.Request) (oauth2.GrantType, *oauth2.TokenGenerateRequest, error) {
if v := r.Method; !(v == "POST" ||
(s.Config.AllowGetAccessRequest && v == "GET")) {
return "", nil, errors.ErrInvalidRequest
}
gt := oauth2.GrantType(r.FormValue("grant_type"))
if gt.String() == "" {
return "", nil, errors.ErrUnsupportedGrantType
}
if !s.CheckGrantType(gt) {
return "", nil, errors.ErrUnsupportedGrantType
}
clientID, clientSecret, err := s.ClientInfoHandler(r)
if err != nil {
return "", nil, err
}
tgr := &oauth2.TokenGenerateRequest{
ClientID: clientID,
ClientSecret: clientSecret,
Request: r,
}
switch gt {
case oauth2.AuthorizationCode:
tgr.RedirectURI = r.FormValue("redirect_uri")
tgr.Code = r.FormValue("code")
if tgr.RedirectURI == "" ||
tgr.Code == "" {
return "", nil, errors.ErrInvalidRequest
}
tgr.CodeVerifier = r.FormValue("code_verifier")
if s.Config.ForcePKCE && tgr.CodeVerifier == "" {
return "", nil, errors.ErrInvalidRequest
}
case oauth2.PasswordCredentials:
tgr.Scope = r.FormValue("scope")
username, password := r.FormValue("username"), r.FormValue("password")
if username == "" || password == "" {
return "", nil, errors.ErrInvalidRequest
}
userID, err := s.PasswordAuthorizationHandler(username, password)
if err != nil {
return "", nil, err
} else if userID == "" {
return "", nil, errors.ErrInvalidGrant
}
tgr.UserID = userID
case oauth2.ClientCredentials:
tgr.Scope = r.FormValue("scope")
tgr.RedirectURI = r.FormValue("redirect_uri")
case oauth2.Refreshing:
tgr.Refresh = r.FormValue("refresh_token")
tgr.Scope = r.FormValue("scope")
if tgr.Refresh == "" {
return "", nil, errors.ErrInvalidRequest
}
}
return gt, tgr, nil
}
// CheckGrantType check allows grant type
func (s *Server) CheckGrantType(gt oauth2.GrantType) bool {
for _, agt := range s.Config.AllowedGrantTypes {
if agt == gt {
return true
}
}
return false
}
// GetAccessToken access token
func (s *Server) GetAccessToken(ctx context.Context, gt oauth2.GrantType, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo,
error) {
if allowed := s.CheckGrantType(gt); !allowed {
return nil, errors.ErrUnauthorizedClient
}
if fn := s.ClientAuthorizedHandler; fn != nil {
allowed, err := fn(tgr.ClientID, gt)
if err != nil {
return nil, err
} else if !allowed {
return nil, errors.ErrUnauthorizedClient
}
}
switch gt {
case oauth2.AuthorizationCode:
ti, err := s.Manager.GenerateAccessToken(ctx, gt, tgr)
if err != nil {
switch err {
case errors.ErrInvalidAuthorizeCode, errors.ErrInvalidCodeChallenge, errors.ErrMissingCodeChallenge:
return nil, errors.ErrInvalidGrant
case errors.ErrInvalidClient:
return nil, errors.ErrInvalidClient
default:
return nil, err
}
}
return ti, nil
case oauth2.PasswordCredentials, oauth2.ClientCredentials:
if fn := s.ClientScopeHandler; fn != nil {
allowed, err := fn(tgr)
if err != nil {
return nil, err
} else if !allowed {
return nil, errors.ErrInvalidScope
}
}
return s.Manager.GenerateAccessToken(ctx, gt, tgr)
case oauth2.Refreshing:
// check scope
if scopeFn := s.RefreshingScopeHandler; tgr.Scope != "" && scopeFn != nil {
rti, err := s.Manager.LoadRefreshToken(ctx, tgr.Refresh)
if err != nil {
if err == errors.ErrInvalidRefreshToken || err == errors.ErrExpiredRefreshToken {
return nil, errors.ErrInvalidGrant
}
return nil, err
}
allowed, err := scopeFn(tgr, rti.GetScope())
if err != nil {
return nil, err
} else if !allowed {
return nil, errors.ErrInvalidScope
}
}
if validationFn := s.RefreshingValidationHandler; validationFn != nil {
rti, err := s.Manager.LoadRefreshToken(ctx, tgr.Refresh)
if err != nil {
if err == errors.ErrInvalidRefreshToken || err == errors.ErrExpiredRefreshToken {
return nil, errors.ErrInvalidGrant
}
return nil, err
}
allowed, err := validationFn(rti)
if err != nil {
return nil, err
} else if !allowed {
return nil, errors.ErrInvalidScope
}
}
ti, err := s.Manager.RefreshAccessToken(ctx, tgr)
if err != nil {
if err == errors.ErrInvalidRefreshToken || err == errors.ErrExpiredRefreshToken {
return nil, errors.ErrInvalidGrant
}
return nil, err
}
return ti, nil
}
return nil, errors.ErrUnsupportedGrantType
}
// GetTokenData token data
func (s *Server) GetTokenData(ti oauth2.TokenInfo) map[string]interface{} {
data := map[string]interface{}{
"access_token": ti.GetAccess(),
"token_type": s.Config.TokenType,
"expires_in": int64(ti.GetAccessExpiresIn() / time.Second),
}
if scope := ti.GetScope(); scope != "" {
data["scope"] = scope
}
if refresh := ti.GetRefresh(); refresh != "" {
data["refresh_token"] = refresh
}
if fn := s.ExtensionFieldsHandler; fn != nil {
ext := fn(ti)
for k, v := range ext {
if _, ok := data[k]; ok {
continue
}
data[k] = v
}
}
return data
}
// HandleTokenRequest token request handling
func (s *Server) HandleTokenRequest(w http.ResponseWriter, r *http.Request) error {
ctx := r.Context()
gt, tgr, err := s.ValidationTokenRequest(r)
if err != nil {
return s.tokenError(w, err)
}
ti, err := s.GetAccessToken(ctx, gt, tgr)
if err != nil {
return s.tokenError(w, err)
}
return s.token(w, s.GetTokenData(ti), nil)
}
// GetErrorData get error response data
func (s *Server) GetErrorData(err error) (map[string]interface{}, int, http.Header) {
var re errors.Response
if v, ok := errors.Descriptions[err]; ok {
re.Error = err
re.Description = v
re.StatusCode = errors.StatusCodes[err]
} else {
if fn := s.InternalErrorHandler; fn != nil {
if v := fn(err); v != nil {
re = *v
}
}
if re.Error == nil {
re.Error = errors.ErrServerError
re.Description = errors.Descriptions[errors.ErrServerError]
re.StatusCode = errors.StatusCodes[errors.ErrServerError]
}
}
if fn := s.ResponseErrorHandler; fn != nil {
fn(&re)
}
data := make(map[string]interface{})
if err := re.Error; err != nil {
data["error"] = err.Error()
}
if v := re.ErrorCode; v != 0 {
data["error_code"] = v
}
if v := re.Description; v != "" {
data["error_description"] = v
}
if v := re.URI; v != "" {
data["error_uri"] = v
}
statusCode := http.StatusInternalServerError
if v := re.StatusCode; v > 0 {
statusCode = v
}
return data, statusCode, re.Header
}
// BearerAuth parse bearer token
func (s *Server) BearerAuth(r *http.Request) (string, bool) {
auth := r.Header.Get("Authorization")
prefix := "Bearer "
token := ""
if auth != "" && strings.HasPrefix(auth, prefix) {
token = auth[len(prefix):]
} else {
token = r.FormValue("access_token")
}
return token, token != ""
}
// ValidationBearerToken validation the bearer tokens
// https://tools.ietf.org/html/rfc6750
func (s *Server) ValidationBearerToken(r *http.Request) (oauth2.TokenInfo, error) {
ctx := r.Context()
accessToken, ok := s.BearerAuth(r)
if !ok {
return nil, errors.ErrInvalidAccessToken
}
return s.Manager.LoadAccessToken(ctx, accessToken)
}

View File

@@ -0,0 +1,85 @@
package server
import (
"codeberg.org/superseriousbusiness/oauth2/v4"
)
// SetTokenType token type
func (s *Server) SetTokenType(tokenType string) {
s.Config.TokenType = tokenType
}
// SetAllowGetAccessRequest to allow GET requests for the token
func (s *Server) SetAllowGetAccessRequest(allow bool) {
s.Config.AllowGetAccessRequest = allow
}
// SetAllowedResponseType allow the authorization types
func (s *Server) SetAllowedResponseType(types ...oauth2.ResponseType) {
s.Config.AllowedResponseTypes = types
}
// SetAllowedGrantType allow the grant types
func (s *Server) SetAllowedGrantType(types ...oauth2.GrantType) {
s.Config.AllowedGrantTypes = types
}
// SetClientInfoHandler get client info from request
func (s *Server) SetClientInfoHandler(handler ClientInfoHandler) {
s.ClientInfoHandler = handler
}
// SetClientAuthorizedHandler check the client allows to use this authorization grant type
func (s *Server) SetClientAuthorizedHandler(handler ClientAuthorizedHandler) {
s.ClientAuthorizedHandler = handler
}
// SetClientScopeHandler check the client allows to use scope
func (s *Server) SetClientScopeHandler(handler ClientScopeHandler) {
s.ClientScopeHandler = handler
}
// SetUserAuthorizationHandler get user id from request authorization
func (s *Server) SetUserAuthorizationHandler(handler UserAuthorizationHandler) {
s.UserAuthorizationHandler = handler
}
// SetPasswordAuthorizationHandler get user id from username and password
func (s *Server) SetPasswordAuthorizationHandler(handler PasswordAuthorizationHandler) {
s.PasswordAuthorizationHandler = handler
}
// SetRefreshingScopeHandler check the scope of the refreshing token
func (s *Server) SetRefreshingScopeHandler(handler RefreshingScopeHandler) {
s.RefreshingScopeHandler = handler
}
// SetRefreshingValidationHandler check if refresh_token is still valid. eg no revocation or other
func (s *Server) SetRefreshingValidationHandler(handler RefreshingValidationHandler) {
s.RefreshingValidationHandler = handler
}
// SetResponseErrorHandler response error handling
func (s *Server) SetResponseErrorHandler(handler ResponseErrorHandler) {
s.ResponseErrorHandler = handler
}
// SetInternalErrorHandler internal error handling
func (s *Server) SetInternalErrorHandler(handler InternalErrorHandler) {
s.InternalErrorHandler = handler
}
// SetExtensionFieldsHandler in response to the access token with the extension of the field
func (s *Server) SetExtensionFieldsHandler(handler ExtensionFieldsHandler) {
s.ExtensionFieldsHandler = handler
}
// SetAccessTokenExpHandler set expiration date for the access token
func (s *Server) SetAccessTokenExpHandler(handler AccessTokenExpHandler) {
s.AccessTokenExpHandler = handler
}
// SetAuthorizeScopeHandler set scope for the access token
func (s *Server) SetAuthorizeScopeHandler(handler AuthorizeScopeHandler) {
s.AuthorizeScopeHandler = handler
}

View File

@@ -0,0 +1,36 @@
package oauth2
import "context"
type (
// ClientStore the client information storage interface
ClientStore interface {
GetByID(ctx context.Context, id string) (ClientInfo, error)
Set(ctx context.Context, id string, cli ClientInfo) error
Delete(ctx context.Context, id string) error
}
// TokenStore the token information storage interface
TokenStore interface {
// create and store the new token information
Create(ctx context.Context, info TokenInfo) error
// delete the authorization code
RemoveByCode(ctx context.Context, code string) error
// use the access token to delete the token information
RemoveByAccess(ctx context.Context, access string) error
// use the refresh token to delete the token information
RemoveByRefresh(ctx context.Context, refresh string) error
// use the authorization code for token information data
GetByCode(ctx context.Context, code string) (TokenInfo, error)
// use the access token for token information data
GetByAccess(ctx context.Context, access string) (TokenInfo, error)
// use the refresh token for token information data
GetByRefresh(ctx context.Context, refresh string) (TokenInfo, error)
}
)