Merging T705-oauth into T710-oauth-slack. T705,T710

This commit is contained in:
Nick Gerakines 2019-12-27 13:40:11 -05:00
parent 4266154749
commit 13121cb266
7 changed files with 144 additions and 93 deletions

View File

@ -56,6 +56,25 @@ type (
Port int `ini:"port"` Port int `ini:"port"`
} }
OAuthCfg struct {
Enabled bool `ini:"enable"`
// write.as
WriteAsProviderAuthLocation string `ini:"wa_auth_location"`
WriteAsProviderTokenLocation string `ini:"wa_token_location"`
WriteAsProviderInspectLocation string `ini:"wa_inspect_location"`
WriteAsClientCallbackLocation string `ini:"wa_callback_location"`
WriteAsClientID string `ini:"wa_client_id"`
WriteAsClientSecret string `ini:"wa_client_secret"`
WriteAsAuthLocation string
// slack
SlackClientID string `ini:"slack_client_id"`
SlackClientSecret string `ini:"slack_client_secret"`
SlackTeamID string `init:"slack_team_id"`
SlackAuthLocation string
}
// AppCfg holds values that affect how the application functions // AppCfg holds values that affect how the application functions
AppCfg struct { AppCfg struct {
SiteName string `ini:"site_name"` SiteName string `ini:"site_name"`
@ -92,17 +111,10 @@ type (
LocalTimeline bool `ini:"local_timeline"` LocalTimeline bool `ini:"local_timeline"`
UserInvites string `ini:"user_invites"` UserInvites string `ini:"user_invites"`
// OAuth
EnableOAuth bool `ini:"enable_oauth"`
OAuthProviderAuthLocation string `ini:"oauth_auth_location"`
OAuthProviderTokenLocation string `ini:"oauth_token_location"`
OAuthProviderInspectLocation string `ini:"oauth_inspect_location"`
OAuthClientCallbackLocation string `ini:"oauth_callback_location"`
OAuthClientID string `ini:"oauth_client_id"`
OAuthClientSecret string `ini:"oauth_client_secret"`
// Defaults // Defaults
DefaultVisibility string `ini:"default_visibility"` DefaultVisibility string `ini:"default_visibility"`
OAuth OAuthCfg `ini:"oauth"`
} }
// Config holds the complete configuration for running a writefreely instance // Config holds the complete configuration for running a writefreely instance

View File

@ -125,10 +125,10 @@ type writestore interface {
GetUserLastPostTime(id int64) (*time.Time, error) GetUserLastPostTime(id int64) (*time.Time, error)
GetCollectionLastPostTime(id int64) (*time.Time, error) GetCollectionLastPostTime(id int64) (*time.Time, error)
GetIDForRemoteUser(ctx context.Context, remoteUserID int64) (int64, error) GetIDForRemoteUser(context.Context, int64) (int64, error)
RecordRemoteUserID(ctx context.Context, localUserID, remoteUserID int64) error RecordRemoteUserID(context.Context, int64, int64) error
ValidateOAuthState(ctx context.Context, state string) error ValidateOAuthState(context.Context, string, string, string) error
GenerateOAuthState(ctx context.Context) (string, error) GenerateOAuthState(context.Context, string, string) (string, error)
DatabaseInitialized() bool DatabaseInitialized() bool
} }
@ -138,6 +138,8 @@ type datastore struct {
driverName string driverName string
} }
var _ writestore = &datastore{}
func (db *datastore) now() string { func (db *datastore) now() string {
if db.driverName == driverSQLite { if db.driverName == driverSQLite {
return "strftime('%Y-%m-%d %H:%M:%S','now')" return "strftime('%Y-%m-%d %H:%M:%S','now')"
@ -2459,17 +2461,17 @@ func (db *datastore) GetCollectionLastPostTime(id int64) (*time.Time, error) {
return &t, nil return &t, nil
} }
func (db *datastore) GenerateOAuthState(ctx context.Context) (string, error) { func (db *datastore) GenerateOAuthState(ctx context.Context, provider, clientID string) (string, error) {
state := store.Generate62RandomString(24) state := store.Generate62RandomString(24)
_, err := db.ExecContext(ctx, "INSERT INTO oauth_client_state (state, used, created_at) VALUES (?, FALSE, NOW())", state) _, err := db.ExecContext(ctx, "INSERT INTO oauth_client_state (state, provider, client_id, used, created_at) VALUES (?, ?, ?, FALSE, NOW())", state, provider, clientID)
if err != nil { if err != nil {
return "", fmt.Errorf("unable to record oauth client state: %w", err) return "", fmt.Errorf("unable to record oauth client state: %w", err)
} }
return state, nil return state, nil
} }
func (db *datastore) ValidateOAuthState(ctx context.Context, state string) error { func (db *datastore) ValidateOAuthState(ctx context.Context, state, provider, clientID string) error {
res, err := db.ExecContext(ctx, "UPDATE oauth_client_state SET used = TRUE WHERE state = ?", state) res, err := db.ExecContext(ctx, "UPDATE oauth_client_state SET used = TRUE WHERE state = ? AND provider = ? AND client_id = ?", state, provider, clientID)
if err != nil { if err != nil {
return err return err
} }

View File

@ -18,7 +18,7 @@ func TestOAuthDatastore(t *testing.T) {
driverName: "", driverName: "",
} }
state, err := ds.GenerateOAuthState(ctx) state, err := ds.GenerateOAuthState(ctx, "", "")
assert.NoError(t, err) assert.NoError(t, err)
assert.Len(t, state, 24) assert.Len(t, state, 24)

1
go.mod
View File

@ -19,6 +19,7 @@ require (
github.com/guregu/null v3.4.0+incompatible github.com/guregu/null v3.4.0+incompatible
github.com/ikeikeikeike/go-sitemap-generator/v2 v2.0.2 github.com/ikeikeikeike/go-sitemap-generator/v2 v2.0.2
github.com/jtolds/gls v4.2.1+incompatible // indirect github.com/jtolds/gls v4.2.1+incompatible // indirect
github.com/kr/pretty v0.1.0
github.com/kylemcc/twitter-text-go v0.0.0-20180726194232-7f582f6736ec github.com/kylemcc/twitter-text-go v0.0.0-20180726194232-7f582f6736ec
github.com/lunixbochs/vtclean v1.0.0 // indirect github.com/lunixbochs/vtclean v1.0.0 // indirect
github.com/manifoldco/promptui v0.3.2 github.com/manifoldco/promptui v0.3.2

1
go.sum
View File

@ -64,6 +64,7 @@ github.com/jtolds/gls v4.2.1+incompatible h1:fSuqC+Gmlu6l/ZYAoZzx2pyucC8Xza35fpR
github.com/jtolds/gls v4.2.1+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= github.com/jtolds/gls v4.2.1+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU=
github.com/juju/ansiterm v0.0.0-20180109212912-720a0952cc2a h1:FaWFmfWdAUKbSCtOU2QjDaorUexogfaMgbipgYATUMU= github.com/juju/ansiterm v0.0.0-20180109212912-720a0952cc2a h1:FaWFmfWdAUKbSCtOU2QjDaorUexogfaMgbipgYATUMU=
github.com/juju/ansiterm v0.0.0-20180109212912-720a0952cc2a/go.mod h1:UJSiEoRfvx3hP73CvoARgeLjaIOjybY9vj8PUPPFGeU= github.com/juju/ansiterm v0.0.0-20180109212912-720a0952cc2a/go.mod h1:UJSiEoRfvx3hP73CvoARgeLjaIOjybY9vj8PUPPFGeU=
github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=

181
oauth.go
View File

@ -2,14 +2,17 @@ package writefreely
import ( import (
"context" "context"
"encoding/hex"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/gorilla/mux"
"github.com/gorilla/sessions" "github.com/gorilla/sessions"
"github.com/guregu/null/zero" "github.com/guregu/null/zero"
"github.com/writeas/nerds/store" "github.com/writeas/nerds/store"
"github.com/writeas/web-core/auth" "github.com/writeas/web-core/auth"
"github.com/writeas/web-core/log" "github.com/writeas/web-core/log"
"github.com/writeas/writefreely/config" "github.com/writeas/writefreely/config"
"hash/fnv"
"io" "io"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
@ -55,11 +58,12 @@ type OAuthDatastoreProvider interface {
// OAuthDatastore provides a minimal interface of data store methods used in // OAuthDatastore provides a minimal interface of data store methods used in
// oauth functionality. // oauth functionality.
type OAuthDatastore interface { type OAuthDatastore interface {
GenerateOAuthState(context.Context) (string, error)
ValidateOAuthState(context.Context, string) error
GetIDForRemoteUser(context.Context, int64) (int64, error) GetIDForRemoteUser(context.Context, int64) (int64, error)
CreateUser(*config.Config, *User, string) error
RecordRemoteUserID(context.Context, int64, int64) error RecordRemoteUserID(context.Context, int64, int64) error
ValidateOAuthState(context.Context, string, string, string) error
GenerateOAuthState(context.Context, string, string) (string, error)
CreateUser(*config.Config, *User, string) error
GetUserForAuthByID(int64) (*User, error) GetUserForAuthByID(int64) (*User, error)
} }
@ -75,8 +79,8 @@ type oauthHandler struct {
} }
// buildAuthURL returns a URL used to initiate authentication. // buildAuthURL returns a URL used to initiate authentication.
func buildAuthURL(db OAuthDatastore, ctx context.Context, clientID, authLocation, callbackURL string) (string, error) { func buildAuthURL(db OAuthDatastore, ctx context.Context, provider, clientID, authLocation, callbackURL string) (string, error) {
state, err := db.GenerateOAuthState(ctx) state, err := db.GenerateOAuthState(ctx, provider, clientID)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -95,9 +99,8 @@ func buildAuthURL(db OAuthDatastore, ctx context.Context, clientID, authLocation
return u.String(), nil return u.String(), nil
} }
// app *App, w http.ResponseWriter, r *http.Request func (h oauthHandler) viewOauthInitWriteAs(w http.ResponseWriter, r *http.Request) {
func (h oauthHandler) viewOauthInit(w http.ResponseWriter, r *http.Request) { location, err := buildAuthURL(h.DB, r.Context(), "write.as", h.Config.App.OAuth.WriteAsClientID, h.Config.App.OAuth.WriteAsProviderAuthLocation, h.Config.App.OAuth.WriteAsClientCallbackLocation)
location, err := buildAuthURL(h.DB, r.Context(), h.Config.App.OAuthClientID, h.Config.App.OAuthProviderAuthLocation, h.Config.App.OAuthClientCallbackLocation)
if err != nil { if err != nil {
failOAuthRequest(w, http.StatusInternalServerError, "could not prepare oauth redirect url") failOAuthRequest(w, http.StatusInternalServerError, "could not prepare oauth redirect url")
return return
@ -105,94 +108,128 @@ func (h oauthHandler) viewOauthInit(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, location, http.StatusTemporaryRedirect) http.Redirect(w, r, location, http.StatusTemporaryRedirect)
} }
func (h oauthHandler) viewOauthCallback(w http.ResponseWriter, r *http.Request) { func (h oauthHandler) viewOauthInitSlack(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() location, err := buildAuthURL(h.DB, r.Context(), "slack", h.Config.App.OAuth.WriteAsClientID, h.Config.App.OAuth.WriteAsProviderAuthLocation, h.Config.App.OAuth.WriteAsClientCallbackLocation)
code := r.FormValue("code")
state := r.FormValue("state")
err := h.DB.ValidateOAuthState(ctx, state)
if err != nil { if err != nil {
failOAuthRequest(w, http.StatusInternalServerError, err.Error()) failOAuthRequest(w, http.StatusInternalServerError, "could not prepare oauth redirect url")
return return
} }
http.Redirect(w, r, location, http.StatusTemporaryRedirect)
}
tokenResponse, err := h.exchangeOauthCode(ctx, code) func (h oauthHandler) configureRoutes(r *mux.Router) {
if err != nil { if h.Config.App.OAuth.Enabled {
failOAuthRequest(w, http.StatusInternalServerError, err.Error()) if h.Config.App.OAuth.WriteAsClientID != "" {
return callbackHash := oauthProviderHash("write.as", h.Config.App.OAuth.WriteAsClientID)
} log.InfoLog.Println("write.as oauth callback URL", "/oauth/callback/"+callbackHash)
r.HandleFunc("/oauth/write.as", h.viewOauthInitWriteAs).Methods("GET")
// Now that we have the access token, let's use it real quick to make sur r.HandleFunc("/oauth/callback/"+callbackHash, h.viewOauthCallback("write.as", h.Config.App.OAuth.WriteAsClientID)).Methods("GET")
// it really really works.
tokenInfo, err := h.inspectOauthAccessToken(ctx, tokenResponse.AccessToken)
if err != nil {
failOAuthRequest(w, http.StatusInternalServerError, err.Error())
return
}
localUserID, err := h.DB.GetIDForRemoteUser(ctx, tokenInfo.UserID)
if err != nil {
failOAuthRequest(w, http.StatusInternalServerError, err.Error())
return
}
fmt.Println("local user id", localUserID)
if localUserID == -1 {
// We don't have, nor do we want, the password from the origin, so we
//create a random string. If the user needs to set a password, they
//can do so through the settings page or through the password reset
//flow.
randPass := store.Generate62RandomString(14)
hashedPass, err := auth.HashPass([]byte(randPass))
if err != nil {
log.ErrorLog.Println(err)
failOAuthRequest(w, http.StatusInternalServerError, "unable to create password hash")
return
} }
newUser := &User{ if h.Config.App.OAuth.SlackClientID != "" {
Username: tokenInfo.Username, callbackHash := oauthProviderHash("slack", h.Config.App.OAuth.SlackClientID)
HashedPass: hashedPass, log.InfoLog.Println("slack oauth callback URL", "/oauth/callback/"+callbackHash)
HasPass: true, r.HandleFunc("/oauth/slack", h.viewOauthInitSlack).Methods("GET")
Email: zero.NewString("", tokenInfo.Email != ""), r.HandleFunc("/oauth/callback/"+callbackHash, h.viewOauthCallback("slack", h.Config.App.OAuth.SlackClientID)).Methods("GET")
Created: time.Now().Truncate(time.Second).UTC(),
} }
}
err = h.DB.CreateUser(h.Config, newUser, newUser.Username) }
func oauthProviderHash(provider, clientID string) string {
hasher := fnv.New32()
return hex.EncodeToString(hasher.Sum([]byte(provider + clientID)))
}
func (h oauthHandler) viewOauthCallback(provider, clientID string) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
code := r.FormValue("code")
state := r.FormValue("state")
err := h.DB.ValidateOAuthState(ctx, state, provider, clientID)
if err != nil { if err != nil {
failOAuthRequest(w, http.StatusInternalServerError, err.Error()) failOAuthRequest(w, http.StatusInternalServerError, err.Error())
return return
} }
err = h.DB.RecordRemoteUserID(ctx, newUser.ID, tokenInfo.UserID) tokenResponse, err := h.exchangeOauthCode(ctx, code)
if err != nil { if err != nil {
failOAuthRequest(w, http.StatusInternalServerError, err.Error()) failOAuthRequest(w, http.StatusInternalServerError, err.Error())
return return
} }
if err := loginOrFail(h.Store, w, r, newUser); err != nil { // Now that we have the access token, let's use it real quick to make sur
// it really really works.
tokenInfo, err := h.inspectOauthAccessToken(ctx, tokenResponse.AccessToken)
if err != nil {
failOAuthRequest(w, http.StatusInternalServerError, err.Error())
return
}
localUserID, err := h.DB.GetIDForRemoteUser(ctx, tokenInfo.UserID)
if err != nil {
failOAuthRequest(w, http.StatusInternalServerError, err.Error())
return
}
fmt.Println("local user id", localUserID)
if localUserID == -1 {
// We don't have, nor do we want, the password from the origin, so we
//create a random string. If the user needs to set a password, they
//can do so through the settings page or through the password reset
//flow.
randPass := store.Generate62RandomString(14)
hashedPass, err := auth.HashPass([]byte(randPass))
if err != nil {
log.ErrorLog.Println(err)
failOAuthRequest(w, http.StatusInternalServerError, "unable to create password hash")
return
}
newUser := &User{
Username: tokenInfo.Username,
HashedPass: hashedPass,
HasPass: true,
Email: zero.NewString("", tokenInfo.Email != ""),
Created: time.Now().Truncate(time.Second).UTC(),
}
err = h.DB.CreateUser(h.Config, newUser, newUser.Username)
if err != nil {
failOAuthRequest(w, http.StatusInternalServerError, err.Error())
return
}
err = h.DB.RecordRemoteUserID(ctx, newUser.ID, tokenInfo.UserID)
if err != nil {
failOAuthRequest(w, http.StatusInternalServerError, err.Error())
return
}
if err := loginOrFail(h.Store, w, r, newUser); err != nil {
failOAuthRequest(w, http.StatusInternalServerError, err.Error())
}
return
}
user, err := h.DB.GetUserForAuthByID(localUserID)
if err != nil {
failOAuthRequest(w, http.StatusInternalServerError, err.Error())
return
}
if err = loginOrFail(h.Store, w, r, user); err != nil {
failOAuthRequest(w, http.StatusInternalServerError, err.Error()) failOAuthRequest(w, http.StatusInternalServerError, err.Error())
} }
return
}
user, err := h.DB.GetUserForAuthByID(localUserID)
if err != nil {
failOAuthRequest(w, http.StatusInternalServerError, err.Error())
return
}
if err = loginOrFail(h.Store, w, r, user); err != nil {
failOAuthRequest(w, http.StatusInternalServerError, err.Error())
} }
} }
func (h oauthHandler) exchangeOauthCode(ctx context.Context, code string) (*TokenResponse, error) { func (h oauthHandler) exchangeOauthCode(ctx context.Context, code string) (*TokenResponse, error) {
form := url.Values{} form := url.Values{}
form.Add("grant_type", "authorization_code") form.Add("grant_type", "authorization_code")
form.Add("redirect_uri", h.Config.App.OAuthClientCallbackLocation) form.Add("redirect_uri", h.Config.App.OAuth.WriteAsClientCallbackLocation)
form.Add("code", code) form.Add("code", code)
req, err := http.NewRequest("POST", h.Config.App.OAuthProviderTokenLocation, strings.NewReader(form.Encode())) req, err := http.NewRequest("POST", h.Config.App.OAuth.WriteAsProviderTokenLocation, strings.NewReader(form.Encode()))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -200,7 +237,7 @@ func (h oauthHandler) exchangeOauthCode(ctx context.Context, code string) (*Toke
req.Header.Set("User-Agent", "writefreely") req.Header.Set("User-Agent", "writefreely")
req.Header.Set("Accept", "application/json") req.Header.Set("Accept", "application/json")
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.SetBasicAuth(h.Config.App.OAuthClientID, h.Config.App.OAuthClientSecret) req.SetBasicAuth(h.Config.App.OAuth.WriteAsClientID, h.Config.App.OAuth.WriteAsClientSecret)
resp, err := h.HttpClient.Do(req) resp, err := h.HttpClient.Do(req)
if err != nil { if err != nil {
@ -224,7 +261,7 @@ func (h oauthHandler) exchangeOauthCode(ctx context.Context, code string) (*Toke
} }
func (h oauthHandler) inspectOauthAccessToken(ctx context.Context, accessToken string) (*InspectResponse, error) { func (h oauthHandler) inspectOauthAccessToken(ctx context.Context, accessToken string) (*InspectResponse, error) {
req, err := http.NewRequest("GET", h.Config.App.OAuthProviderInspectLocation, nil) req, err := http.NewRequest("GET", h.Config.App.OAuth.WriteAsProviderInspectLocation, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -86,9 +86,7 @@ func InitRoutes(apper Apper, r *mux.Router) *mux.Router {
DB: apper.App().DB(), DB: apper.App().DB(),
Store: apper.App().SessionStore(), Store: apper.App().SessionStore(),
} }
oauthHandler.configureRoutes(write)
write.HandleFunc("/oauth/write.as", oauthHandler.viewOauthInit).Methods("GET")
write.HandleFunc("/oauth/callback", oauthHandler.viewOauthCallback).Methods("GET")
// Handle logged in user sections // Handle logged in user sections
me := write.PathPrefix("/me").Subrouter() me := write.PathPrefix("/me").Subrouter()