145 lines
4.2 KiB
Go
145 lines
4.2 KiB
Go
/*
|
|
* Copyright © 2020-2021 Musing Studio LLC and respective authors.
|
|
*
|
|
* This file is part of WriteFreely.
|
|
*
|
|
* WriteFreely is free software: you can redistribute it and/or modify
|
|
* it under the terms of the GNU Affero General Public License, included
|
|
* in the LICENSE file in this source code package.
|
|
*/
|
|
|
|
package writefreely
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"github.com/writeas/web-core/log"
|
|
"net/http"
|
|
"net/url"
|
|
"strings"
|
|
)
|
|
|
|
type genericOauthClient struct {
|
|
ClientID string
|
|
ClientSecret string
|
|
AuthLocation string
|
|
ExchangeLocation string
|
|
InspectLocation string
|
|
CallbackLocation string
|
|
Scope string
|
|
MapUserID string
|
|
MapUsername string
|
|
MapDisplayName string
|
|
MapEmail string
|
|
HttpClient HttpClient
|
|
}
|
|
|
|
var _ oauthClient = genericOauthClient{}
|
|
|
|
const (
|
|
genericOauthDisplayName = "OAuth"
|
|
)
|
|
|
|
func (c genericOauthClient) GetProvider() string {
|
|
return "generic"
|
|
}
|
|
|
|
func (c genericOauthClient) GetClientID() string {
|
|
return c.ClientID
|
|
}
|
|
|
|
func (c genericOauthClient) GetCallbackLocation() string {
|
|
return c.CallbackLocation
|
|
}
|
|
|
|
func (c genericOauthClient) buildLoginURL(state string) (string, error) {
|
|
u, err := url.Parse(c.AuthLocation)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
q := u.Query()
|
|
q.Set("client_id", c.ClientID)
|
|
q.Set("redirect_uri", c.CallbackLocation)
|
|
q.Set("response_type", "code")
|
|
q.Set("state", state)
|
|
q.Set("scope", c.Scope)
|
|
u.RawQuery = q.Encode()
|
|
return u.String(), nil
|
|
}
|
|
|
|
func (c genericOauthClient) exchangeOauthCode(ctx context.Context, code string) (*TokenResponse, error) {
|
|
form := url.Values{}
|
|
form.Add("client_id", c.ClientID)
|
|
form.Add("client_secret", c.ClientSecret)
|
|
form.Add("grant_type", "authorization_code")
|
|
form.Add("redirect_uri", c.CallbackLocation)
|
|
form.Add("scope", c.Scope)
|
|
form.Add("code", code)
|
|
req, err := http.NewRequest("POST", c.ExchangeLocation, strings.NewReader(form.Encode()))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
req.WithContext(ctx)
|
|
req.Header.Set("User-Agent", ServerUserAgent(""))
|
|
req.Header.Set("Accept", "application/json")
|
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
req.SetBasicAuth(c.ClientID, c.ClientSecret)
|
|
|
|
resp, err := c.HttpClient.Do(req)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if resp.StatusCode != http.StatusOK {
|
|
return nil, errors.New("unable to exchange code for access token")
|
|
}
|
|
|
|
var tokenResponse TokenResponse
|
|
if err := limitedJsonUnmarshal(resp.Body, tokenRequestMaxLen, &tokenResponse); err != nil {
|
|
return nil, err
|
|
}
|
|
if tokenResponse.Error != "" {
|
|
return nil, errors.New(tokenResponse.Error)
|
|
}
|
|
return &tokenResponse, nil
|
|
}
|
|
|
|
func (c genericOauthClient) inspectOauthAccessToken(ctx context.Context, accessToken string) (*InspectResponse, error) {
|
|
req, err := http.NewRequest("GET", c.InspectLocation, nil)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
req.WithContext(ctx)
|
|
req.Header.Set("User-Agent", ServerUserAgent(""))
|
|
req.Header.Set("Accept", "application/json")
|
|
req.Header.Set("Authorization", "Bearer "+accessToken)
|
|
|
|
resp, err := c.HttpClient.Do(req)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if resp.StatusCode != http.StatusOK {
|
|
return nil, errors.New("unable to inspect access token")
|
|
}
|
|
|
|
// since we don't know what the JSON from the server will look like, we create a
|
|
// generic interface and then map manually to values set in the config
|
|
var genericInterface map[string]interface{}
|
|
if err := limitedJsonUnmarshal(resp.Body, infoRequestMaxLen, &genericInterface); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// map each relevant field in inspectResponse to the mapped field from the config
|
|
var inspectResponse InspectResponse
|
|
inspectResponse.UserID, _ = genericInterface[c.MapUserID].(string)
|
|
if inspectResponse.UserID == "" {
|
|
log.Error("[CONFIGURATION ERROR] Generic OAuth provider returned empty UserID value (`%s`).\n Do you need to configure a different `map_user_id` value for this provider?", c.MapUserID)
|
|
return nil, fmt.Errorf("no UserID (`%s`) value returned", c.MapUserID)
|
|
}
|
|
inspectResponse.Username, _ = genericInterface[c.MapUsername].(string)
|
|
inspectResponse.DisplayName, _ = genericInterface[c.MapDisplayName].(string)
|
|
inspectResponse.Email, _ = genericInterface[c.MapEmail].(string)
|
|
|
|
return &inspectResponse, nil
|
|
}
|