Feature complete on MVP slack auth integration. T710

This commit is contained in:
Nick Gerakines 2019-12-28 15:15:47 -05:00
parent 13121cb266
commit 462f87919a
14 changed files with 664 additions and 300 deletions

View File

@ -56,23 +56,18 @@ type (
Port int `ini:"port"` Port int `ini:"port"`
} }
OAuthCfg struct { WriteAsOauthCfg struct {
Enabled bool `ini:"enable"` ClientID string `ini:"client_id"`
ClientSecret string `ini:"client_secret"`
AuthLocation string `ini:"auth_location"`
TokenLocation string `ini:"token_location"`
InspectLocation string `ini:"inspect_location"`
}
// write.as SlackOauthCfg struct {
WriteAsProviderAuthLocation string `ini:"wa_auth_location"` ClientID string `ini:"client_id"`
WriteAsProviderTokenLocation string `ini:"wa_token_location"` ClientSecret string `ini:"client_secret"`
WriteAsProviderInspectLocation string `ini:"wa_inspect_location"` TeamID string `ini:"team_id"`
WriteAsClientCallbackLocation string `ini:"wa_callback_location"`
WriteAsClientID string `ini:"wa_client_id"`
WriteAsClientSecret string `ini:"wa_client_secret"`
WriteAsAuthLocation string
// slack
SlackClientID string `ini:"slack_client_id"`
SlackClientSecret string `ini:"slack_client_secret"`
SlackTeamID string `init:"slack_team_id"`
SlackAuthLocation string
} }
// AppCfg holds values that affect how the application functions // AppCfg holds values that affect how the application functions
@ -113,8 +108,6 @@ type (
// Defaults // Defaults
DefaultVisibility string `ini:"default_visibility"` DefaultVisibility string `ini:"default_visibility"`
OAuth OAuthCfg `ini:"oauth"`
} }
// Config holds the complete configuration for running a writefreely instance // Config holds the complete configuration for running a writefreely instance
@ -122,6 +115,8 @@ type (
Server ServerCfg `ini:"server"` Server ServerCfg `ini:"server"`
Database DatabaseCfg `ini:"database"` Database DatabaseCfg `ini:"database"`
App AppCfg `ini:"app"` App AppCfg `ini:"app"`
SlackOauth SlackOauthCfg `ini:"oauth.slack"`
WriteAsOauth WriteAsOauthCfg `ini:"oauth.writeas"`
} }
) )

View File

@ -14,6 +14,7 @@ import (
"context" "context"
"database/sql" "database/sql"
"fmt" "fmt"
wf_db "github.com/writeas/writefreely/db"
"net/http" "net/http"
"strings" "strings"
"time" "time"
@ -125,9 +126,9 @@ type writestore interface {
GetUserLastPostTime(id int64) (*time.Time, error) GetUserLastPostTime(id int64) (*time.Time, error)
GetCollectionLastPostTime(id int64) (*time.Time, error) GetCollectionLastPostTime(id int64) (*time.Time, error)
GetIDForRemoteUser(context.Context, int64) (int64, error) GetIDForRemoteUser(context.Context, string) (int64, error)
RecordRemoteUserID(context.Context, int64, int64) error RecordRemoteUserID(context.Context, int64, string) error
ValidateOAuthState(context.Context, string, string, string) error ValidateOAuthState(context.Context, string) (string, string, error)
GenerateOAuthState(context.Context, string, string) (string, error) GenerateOAuthState(context.Context, string, string) (string, error)
DatabaseInitialized() bool DatabaseInitialized() bool
@ -2470,8 +2471,16 @@ func (db *datastore) GenerateOAuthState(ctx context.Context, provider, clientID
return state, nil return state, nil
} }
func (db *datastore) ValidateOAuthState(ctx context.Context, state, provider, clientID string) error { func (db *datastore) ValidateOAuthState(ctx context.Context, state string) (string, string, error) {
res, err := db.ExecContext(ctx, "UPDATE oauth_client_state SET used = TRUE WHERE state = ? AND provider = ? AND client_id = ?", state, provider, clientID) var provider string
var clientID string
err := wf_db.RunTransactionWithOptions(ctx, db.DB, &sql.TxOptions{}, func(ctx context.Context, tx *sql.Tx) error {
err := tx.QueryRow("SELECT provider, client_id FROM oauth_client_state WHERE state = ? AND used = FALSE", state).Scan(&provider, &clientID)
if err != nil {
return err
}
res, err := tx.ExecContext(ctx, "UPDATE oauth_client_state SET used = TRUE WHERE state = ?", state)
if err != nil { if err != nil {
return err return err
} }
@ -2483,9 +2492,14 @@ func (db *datastore) ValidateOAuthState(ctx context.Context, state, provider, cl
return fmt.Errorf("state not found") return fmt.Errorf("state not found")
} }
return nil return nil
})
if err != nil {
return "", "", nil
}
return provider, clientID, nil
} }
func (db *datastore) RecordRemoteUserID(ctx context.Context, localUserID, remoteUserID int64) error { func (db *datastore) RecordRemoteUserID(ctx context.Context, localUserID int64, remoteUserID string) error {
var err error var err error
if db.driverName == driverSQLite { if db.driverName == driverSQLite {
_, err = db.ExecContext(ctx, "INSERT OR REPLACE INTO users_oauth (user_id, remote_user_id) VALUES (?, ?)", localUserID, remoteUserID) _, err = db.ExecContext(ctx, "INSERT OR REPLACE INTO users_oauth (user_id, remote_user_id) VALUES (?, ?)", localUserID, remoteUserID)
@ -2499,7 +2513,7 @@ func (db *datastore) RecordRemoteUserID(ctx context.Context, localUserID, remote
} }
// GetIDForRemoteUser returns a user ID associated with a remote user ID. // GetIDForRemoteUser returns a user ID associated with a remote user ID.
func (db *datastore) GetIDForRemoteUser(ctx context.Context, remoteUserID int64) (int64, error) { func (db *datastore) GetIDForRemoteUser(ctx context.Context, remoteUserID string) (int64, error) {
var userID int64 = -1 var userID int64 = -1
err := db. err := db.
QueryRowContext(ctx, "SELECT user_id FROM users_oauth WHERE remote_user_id = ?", remoteUserID). QueryRowContext(ctx, "SELECT user_id FROM users_oauth WHERE remote_user_id = ?", remoteUserID).

View File

@ -18,19 +18,19 @@ func TestOAuthDatastore(t *testing.T) {
driverName: "", driverName: "",
} }
state, err := ds.GenerateOAuthState(ctx, "", "") state, err := ds.GenerateOAuthState(ctx, "test", "development")
assert.NoError(t, err) assert.NoError(t, err)
assert.Len(t, state, 24) assert.Len(t, state, 24)
countRows(t, ctx, db, 1, "SELECT COUNT(*) FROM `oauth_client_state` WHERE `state` = ? AND `used` = false", state) countRows(t, ctx, db, 1, "SELECT COUNT(*) FROM `oauth_client_state` WHERE `state` = ? AND `used` = false", state)
err = ds.ValidateOAuthState(ctx, state) _, _, err = ds.ValidateOAuthState(ctx, state)
assert.NoError(t, err) assert.NoError(t, err)
countRows(t, ctx, db, 1, "SELECT COUNT(*) FROM `oauth_client_state` WHERE `state` = ? AND `used` = true", state) countRows(t, ctx, db, 1, "SELECT COUNT(*) FROM `oauth_client_state` WHERE `state` = ? AND `used` = true", state)
var localUserID int64 = 99 var localUserID int64 = 99
var remoteUserID int64 = 100 var remoteUserID = "100"
err = ds.RecordRemoteUserID(ctx, localUserID, remoteUserID) err = ds.RecordRemoteUserID(ctx, localUserID, remoteUserID)
assert.NoError(t, err) assert.NoError(t, err)

40
db/alter.go Normal file
View File

@ -0,0 +1,40 @@
package db
import (
"fmt"
"strings"
)
type AlterTableSqlBuilder struct {
Dialect DialectType
Name string
Changes []string
}
func (b *AlterTableSqlBuilder) AddColumn(col *Column) *AlterTableSqlBuilder {
if colVal, err := col.String(); err == nil {
b.Changes = append(b.Changes, fmt.Sprintf("ADD COLUMN %s", colVal))
}
return b
}
func (b *AlterTableSqlBuilder) ToSQL() (string, error) {
var str strings.Builder
str.WriteString("ALTER TABLE ")
str.WriteString(b.Name)
str.WriteString(" ")
if len(b.Changes) == 0 {
return "", fmt.Errorf("no changes provide for table: %s", b.Name)
}
changeCount := len(b.Changes)
for i, thing := range b.Changes {
str.WriteString(thing)
if i < changeCount-1 {
str.WriteString(", ")
}
}
return str.String(), nil
}

56
db/alter_test.go Normal file
View File

@ -0,0 +1,56 @@
package db
import "testing"
func TestAlterTableSqlBuilder_ToSQL(t *testing.T) {
type fields struct {
Dialect DialectType
Name string
Changes []string
}
tests := []struct {
name string
builder *AlterTableSqlBuilder
want string
wantErr bool
}{
{
name: "MySQL add int",
builder: DialectMySQL.
AlterTable("the_table").
AddColumn(DialectMySQL.Column("the_col", ColumnTypeInteger, UnsetSize)),
want: "ALTER TABLE the_table ADD COLUMN the_col INT NOT NULL",
wantErr: false,
},
{
name: "MySQL add string",
builder: DialectMySQL.
AlterTable("the_table").
AddColumn(DialectMySQL.Column("the_col", ColumnTypeVarChar, OptionalInt{true, 128})),
want: "ALTER TABLE the_table ADD COLUMN the_col VARCHAR(128) NOT NULL",
wantErr: false,
},
{
name: "MySQL add int and string",
builder: DialectMySQL.
AlterTable("the_table").
AddColumn(DialectMySQL.Column("first_col", ColumnTypeInteger, UnsetSize)).
AddColumn(DialectMySQL.Column("second_col", ColumnTypeVarChar, OptionalInt{true, 128})),
want: "ALTER TABLE the_table ADD COLUMN first_col INT NOT NULL, ADD COLUMN second_col VARCHAR(128) NOT NULL",
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := tt.builder.ToSQL()
if (err != nil) != tt.wantErr {
t.Errorf("ToSQL() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != tt.want {
t.Errorf("ToSQL() got = %v, want %v", got, tt.want)
}
})
}
}

View File

@ -5,7 +5,6 @@ import (
"strings" "strings"
) )
type DialectType int
type ColumnType int type ColumnType int
type OptionalInt struct { type OptionalInt struct {
@ -41,11 +40,6 @@ type CreateTableSqlBuilder struct {
Constraints []string Constraints []string
} }
const (
DialectSQLite DialectType = iota
DialectMySQL DialectType = iota
)
const ( const (
ColumnTypeBool ColumnType = iota ColumnTypeBool ColumnType = iota
ColumnTypeSmallInt ColumnType = iota ColumnTypeSmallInt ColumnType = iota
@ -61,28 +55,6 @@ var _ SQLBuilder = &CreateTableSqlBuilder{}
var UnsetSize OptionalInt = OptionalInt{Set: false, Value: 0} var UnsetSize OptionalInt = OptionalInt{Set: false, Value: 0}
var UnsetDefault OptionalString = OptionalString{Set: false, Value: ""} var UnsetDefault OptionalString = OptionalString{Set: false, Value: ""}
func (d DialectType) Column(name string, t ColumnType, size OptionalInt) *Column {
switch d {
case DialectSQLite:
return &Column{Dialect: DialectSQLite, Name: name, Type: t, Size: size}
case DialectMySQL:
return &Column{Dialect: DialectMySQL, Name: name, Type: t, Size: size}
default:
panic(fmt.Sprintf("unexpected dialect: %d", d))
}
}
func (d DialectType) Table(name string) *CreateTableSqlBuilder {
switch d {
case DialectSQLite:
return &CreateTableSqlBuilder{Dialect: DialectSQLite, Name: name}
case DialectMySQL:
return &CreateTableSqlBuilder{Dialect: DialectMySQL, Name: name}
default:
panic(fmt.Sprintf("unexpected dialect: %d", d))
}
}
func (d ColumnType) Format(dialect DialectType, size OptionalInt) (string, error) { func (d ColumnType) Format(dialect DialectType, size OptionalInt) (string, error) {
if dialect != DialectMySQL && dialect != DialectSQLite { if dialect != DialectMySQL && dialect != DialectSQLite {
return "", fmt.Errorf("unsupported column type %d for dialect %d and size %v", d, dialect, size) return "", fmt.Errorf("unsupported column type %d for dialect %d and size %v", d, dialect, size)
@ -269,3 +241,4 @@ func (b *CreateTableSqlBuilder) ToSQL() (string, error) {
return str.String(), nil return str.String(), nil
} }

43
db/dialect.go Normal file
View File

@ -0,0 +1,43 @@
package db
import "fmt"
type DialectType int
const (
DialectSQLite DialectType = iota
DialectMySQL DialectType = iota
)
func (d DialectType) Column(name string, t ColumnType, size OptionalInt) *Column {
switch d {
case DialectSQLite:
return &Column{Dialect: DialectSQLite, Name: name, Type: t, Size: size}
case DialectMySQL:
return &Column{Dialect: DialectMySQL, Name: name, Type: t, Size: size}
default:
panic(fmt.Sprintf("unexpected dialect: %d", d))
}
}
func (d DialectType) Table(name string) *CreateTableSqlBuilder {
switch d {
case DialectSQLite:
return &CreateTableSqlBuilder{Dialect: DialectSQLite, Name: name}
case DialectMySQL:
return &CreateTableSqlBuilder{Dialect: DialectMySQL, Name: name}
default:
panic(fmt.Sprintf("unexpected dialect: %d", d))
}
}
func (d DialectType) AlterTable(name string) *AlterTableSqlBuilder {
switch d {
case DialectSQLite:
return &AlterTableSqlBuilder{Dialect: DialectSQLite, Name: name}
case DialectMySQL:
return &AlterTableSqlBuilder{Dialect: DialectMySQL, Name: name}
default:
panic(fmt.Sprintf("unexpected dialect: %d", d))
}
}

View File

@ -60,6 +60,7 @@ var migrations = []Migration{
New("support dynamic instance pages", supportInstancePages), // V1 -> V2 (v0.9.0) New("support dynamic instance pages", supportInstancePages), // V1 -> V2 (v0.9.0)
New("support users suspension", supportUserStatus), // V2 -> V3 (v0.11.0) New("support users suspension", supportUserStatus), // V2 -> V3 (v0.11.0)
New("support oauth", oauth), // V3 -> V4 New("support oauth", oauth), // V3 -> V4
New("support slack oauth", oauth_slack), // V4 -> v5
} }
// CurrentVer returns the current migration version the application is on // CurrentVer returns the current migration version the application is on

41
migrations/v5.go Normal file
View File

@ -0,0 +1,41 @@
package migrations
import (
"context"
"database/sql"
wf_db "github.com/writeas/writefreely/db"
)
func oauth_slack(db *datastore) error {
dialect := wf_db.DialectMySQL
if db.driverName == driverSQLite {
dialect = wf_db.DialectSQLite
}
return wf_db.RunTransactionWithOptions(context.Background(), db.DB, &sql.TxOptions{}, func(ctx context.Context, tx *sql.Tx) error {
builders := []wf_db.SQLBuilder{
dialect.
AlterTable("oauth_client_state").
AddColumn(dialect.
Column(
"provider",
wf_db.ColumnTypeVarChar,
wf_db.OptionalInt{Set: true, Value: 24,})).
AddColumn(dialect.
Column(
"client_id",
wf_db.ColumnTypeVarChar,
wf_db.OptionalInt{Set: true, Value: 128,})),
}
for _, builder := range builders {
query, err := builder.ToSQL()
if err != nil {
return err
}
if _, err := tx.ExecContext(ctx, query); err != nil {
return err
}
}
return nil
})
}

186
oauth.go
View File

@ -2,7 +2,6 @@ package writefreely
import ( import (
"context" "context"
"encoding/hex"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/gorilla/mux" "github.com/gorilla/mux"
@ -10,14 +9,10 @@ import (
"github.com/guregu/null/zero" "github.com/guregu/null/zero"
"github.com/writeas/nerds/store" "github.com/writeas/nerds/store"
"github.com/writeas/web-core/auth" "github.com/writeas/web-core/auth"
"github.com/writeas/web-core/log"
"github.com/writeas/writefreely/config" "github.com/writeas/writefreely/config"
"hash/fnv"
"io" "io"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"net/url"
"strings"
"time" "time"
) )
@ -33,7 +28,7 @@ type TokenResponse struct {
// InspectResponse contains data returned when an access token is inspected. // InspectResponse contains data returned when an access token is inspected.
type InspectResponse struct { type InspectResponse struct {
ClientID string `json:"client_id"` ClientID string `json:"client_id"`
UserID int64 `json:"user_id"` UserID string `json:"user_id"`
ExpiresAt time.Time `json:"expires_at"` ExpiresAt time.Time `json:"expires_at"`
Username string `json:"username"` Username string `json:"username"`
Email string `json:"email"` Email string `json:"email"`
@ -58,9 +53,9 @@ type OAuthDatastoreProvider interface {
// OAuthDatastore provides a minimal interface of data store methods used in // OAuthDatastore provides a minimal interface of data store methods used in
// oauth functionality. // oauth functionality.
type OAuthDatastore interface { type OAuthDatastore interface {
GetIDForRemoteUser(context.Context, int64) (int64, error) GetIDForRemoteUser(context.Context, string) (int64, error)
RecordRemoteUserID(context.Context, int64, int64) error RecordRemoteUserID(context.Context, int64, string) error
ValidateOAuthState(context.Context, string, string, string) error ValidateOAuthState(context.Context, string) (string, string, error)
GenerateOAuthState(context.Context, string, string) (string, error) GenerateOAuthState(context.Context, string, string) (string, error)
CreateUser(*config.Config, *User, string) error CreateUser(*config.Config, *User, string) error
@ -71,36 +66,28 @@ type HttpClient interface {
Do(req *http.Request) (*http.Response, error) Do(req *http.Request) (*http.Response, error)
} }
type oauthClient interface {
GetProvider() string
GetClientID() string
buildLoginURL(state string) (string, error)
exchangeOauthCode(ctx context.Context, code string) (*TokenResponse, error)
inspectOauthAccessToken(ctx context.Context, accessToken string) (*InspectResponse, error)
}
type oauthHandler struct { type oauthHandler struct {
Config *config.Config Config *config.Config
DB OAuthDatastore DB OAuthDatastore
Store sessions.Store Store sessions.Store
HttpClient HttpClient oauthClient oauthClient
} }
// buildAuthURL returns a URL used to initiate authentication. func (h oauthHandler) viewOauthInit(w http.ResponseWriter, r *http.Request) {
func buildAuthURL(db OAuthDatastore, ctx context.Context, provider, clientID, authLocation, callbackURL string) (string, error) { ctx := r.Context()
state, err := db.GenerateOAuthState(ctx, provider, clientID) state, err := h.DB.GenerateOAuthState(ctx, h.oauthClient.GetProvider(), h.oauthClient.GetClientID())
if err != nil { if err != nil {
return "", err failOAuthRequest(w, http.StatusInternalServerError, "could not prepare oauth redirect url")
} }
location, err := h.oauthClient.buildLoginURL(state)
u, err := url.Parse(authLocation)
if err != nil {
return "", err
}
q := u.Query()
q.Set("client_id", clientID)
q.Set("redirect_uri", callbackURL)
q.Set("response_type", "code")
q.Set("state", state)
u.RawQuery = q.Encode()
return u.String(), nil
}
func (h oauthHandler) viewOauthInitWriteAs(w http.ResponseWriter, r *http.Request) {
location, err := buildAuthURL(h.DB, r.Context(), "write.as", h.Config.App.OAuth.WriteAsClientID, h.Config.App.OAuth.WriteAsProviderAuthLocation, h.Config.App.OAuth.WriteAsClientCallbackLocation)
if err != nil { if err != nil {
failOAuthRequest(w, http.StatusInternalServerError, "could not prepare oauth redirect url") failOAuthRequest(w, http.StatusInternalServerError, "could not prepare oauth redirect url")
return return
@ -108,52 +95,58 @@ func (h oauthHandler) viewOauthInitWriteAs(w http.ResponseWriter, r *http.Reques
http.Redirect(w, r, location, http.StatusTemporaryRedirect) http.Redirect(w, r, location, http.StatusTemporaryRedirect)
} }
func (h oauthHandler) viewOauthInitSlack(w http.ResponseWriter, r *http.Request) { func configureSlackOauth(r *mux.Router, app *App) {
location, err := buildAuthURL(h.DB, r.Context(), "slack", h.Config.App.OAuth.WriteAsClientID, h.Config.App.OAuth.WriteAsProviderAuthLocation, h.Config.App.OAuth.WriteAsClientCallbackLocation) if app.Config().SlackOauth.ClientID != "" {
if err != nil { oauthClient := slackOauthClient{
failOAuthRequest(w, http.StatusInternalServerError, "could not prepare oauth redirect url") ClientID: app.Config().SlackOauth.ClientID,
return ClientSecret: app.Config().SlackOauth.ClientSecret,
TeamID: app.Config().SlackOauth.TeamID,
CallbackLocation: app.Config().App.Host + "/oauth/callback",
HttpClient: &http.Client{Timeout: 10 * time.Second},
}
configureOauthRoutes(r, app, oauthClient)
} }
http.Redirect(w, r, location, http.StatusTemporaryRedirect)
} }
func (h oauthHandler) configureRoutes(r *mux.Router) { func configureWriteAsOauth(r *mux.Router, app *App) {
if h.Config.App.OAuth.Enabled { if app.Config().WriteAsOauth.ClientID != "" {
if h.Config.App.OAuth.WriteAsClientID != "" { oauthClient := writeAsOauthClient{
callbackHash := oauthProviderHash("write.as", h.Config.App.OAuth.WriteAsClientID) ClientID: app.Config().WriteAsOauth.ClientID,
log.InfoLog.Println("write.as oauth callback URL", "/oauth/callback/"+callbackHash) ClientSecret: app.Config().WriteAsOauth.ClientSecret,
r.HandleFunc("/oauth/write.as", h.viewOauthInitWriteAs).Methods("GET") ExchangeLocation: app.Config().WriteAsOauth.TokenLocation,
r.HandleFunc("/oauth/callback/"+callbackHash, h.viewOauthCallback("write.as", h.Config.App.OAuth.WriteAsClientID)).Methods("GET") InspectLocation: app.Config().WriteAsOauth.InspectLocation,
AuthLocation: app.Config().WriteAsOauth.AuthLocation,
HttpClient: &http.Client{Timeout: 10 * time.Second},
CallbackLocation: app.Config().App.Host + "/oauth/callback",
} }
if h.Config.App.OAuth.SlackClientID != "" { configureOauthRoutes(r, app, oauthClient)
callbackHash := oauthProviderHash("slack", h.Config.App.OAuth.SlackClientID)
log.InfoLog.Println("slack oauth callback URL", "/oauth/callback/"+callbackHash)
r.HandleFunc("/oauth/slack", h.viewOauthInitSlack).Methods("GET")
r.HandleFunc("/oauth/callback/"+callbackHash, h.viewOauthCallback("slack", h.Config.App.OAuth.SlackClientID)).Methods("GET")
} }
}
} }
func oauthProviderHash(provider, clientID string) string { func configureOauthRoutes(r *mux.Router, app *App, oauthClient oauthClient) {
hasher := fnv.New32() handler := &oauthHandler{
return hex.EncodeToString(hasher.Sum([]byte(provider + clientID))) Config: app.Config(),
DB: app.DB(),
Store: app.SessionStore(),
oauthClient: oauthClient,
}
r.HandleFunc("/oauth/"+oauthClient.GetProvider(), handler.viewOauthInit).Methods("GET")
r.HandleFunc("/oauth/callback", handler.viewOauthCallback).Methods("GET")
} }
func (h oauthHandler) viewOauthCallback(provider, clientID string) http.HandlerFunc { func (h oauthHandler) viewOauthCallback(w http.ResponseWriter, r *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
code := r.FormValue("code") code := r.FormValue("code")
state := r.FormValue("state") state := r.FormValue("state")
err := h.DB.ValidateOAuthState(ctx, state, provider, clientID) _, _, err := h.DB.ValidateOAuthState(ctx, state)
if err != nil { if err != nil {
failOAuthRequest(w, http.StatusInternalServerError, err.Error()) failOAuthRequest(w, http.StatusInternalServerError, err.Error())
return return
} }
tokenResponse, err := h.exchangeOauthCode(ctx, code) tokenResponse, err := h.oauthClient.exchangeOauthCode(ctx, code)
if err != nil { if err != nil {
failOAuthRequest(w, http.StatusInternalServerError, err.Error()) failOAuthRequest(w, http.StatusInternalServerError, err.Error())
return return
@ -161,7 +154,7 @@ func (h oauthHandler) viewOauthCallback(provider, clientID string) http.HandlerF
// Now that we have the access token, let's use it real quick to make sur // Now that we have the access token, let's use it real quick to make sur
// it really really works. // it really really works.
tokenInfo, err := h.inspectOauthAccessToken(ctx, tokenResponse.AccessToken) tokenInfo, err := h.oauthClient.inspectOauthAccessToken(ctx, tokenResponse.AccessToken)
if err != nil { if err != nil {
failOAuthRequest(w, http.StatusInternalServerError, err.Error()) failOAuthRequest(w, http.StatusInternalServerError, err.Error())
return return
@ -173,8 +166,6 @@ func (h oauthHandler) viewOauthCallback(provider, clientID string) http.HandlerF
return return
} }
fmt.Println("local user id", localUserID)
if localUserID == -1 { if localUserID == -1 {
// We don't have, nor do we want, the password from the origin, so we // We don't have, nor do we want, the password from the origin, so we
//create a random string. If the user needs to set a password, they //create a random string. If the user needs to set a password, they
@ -183,7 +174,6 @@ func (h oauthHandler) viewOauthCallback(provider, clientID string) http.HandlerF
randPass := store.Generate62RandomString(14) randPass := store.Generate62RandomString(14)
hashedPass, err := auth.HashPass([]byte(randPass)) hashedPass, err := auth.HashPass([]byte(randPass))
if err != nil { if err != nil {
log.ErrorLog.Println(err)
failOAuthRequest(w, http.StatusInternalServerError, "unable to create password hash") failOAuthRequest(w, http.StatusInternalServerError, "unable to create password hash")
return return
} }
@ -191,7 +181,7 @@ func (h oauthHandler) viewOauthCallback(provider, clientID string) http.HandlerF
Username: tokenInfo.Username, Username: tokenInfo.Username,
HashedPass: hashedPass, HashedPass: hashedPass,
HasPass: true, HasPass: true,
Email: zero.NewString("", tokenInfo.Email != ""), Email: zero.NewString(tokenInfo.Email, tokenInfo.Email != ""),
Created: time.Now().Truncate(time.Second).UTC(), Created: time.Now().Truncate(time.Second).UTC(),
} }
@ -221,74 +211,18 @@ func (h oauthHandler) viewOauthCallback(provider, clientID string) http.HandlerF
if err = loginOrFail(h.Store, w, r, user); err != nil { if err = loginOrFail(h.Store, w, r, user); err != nil {
failOAuthRequest(w, http.StatusInternalServerError, err.Error()) failOAuthRequest(w, http.StatusInternalServerError, err.Error())
} }
}
} }
func (h oauthHandler) exchangeOauthCode(ctx context.Context, code string) (*TokenResponse, error) { func limitedJsonUnmarshal(body io.ReadCloser, n int, thing interface{}) error {
form := url.Values{} lr := io.LimitReader(body, int64(n+1))
form.Add("grant_type", "authorization_code") data, err := ioutil.ReadAll(lr)
form.Add("redirect_uri", h.Config.App.OAuth.WriteAsClientCallbackLocation)
form.Add("code", code)
req, err := http.NewRequest("POST", h.Config.App.OAuth.WriteAsProviderTokenLocation, strings.NewReader(form.Encode()))
if err != nil { if err != nil {
return nil, err return err
} }
req.WithContext(ctx) if len(data) == n+1 {
req.Header.Set("User-Agent", "writefreely") return fmt.Errorf("content larger than max read allowance: %d", n)
req.Header.Set("Accept", "application/json")
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.SetBasicAuth(h.Config.App.OAuth.WriteAsClientID, h.Config.App.OAuth.WriteAsClientSecret)
resp, err := h.HttpClient.Do(req)
if err != nil {
return nil, err
} }
return json.Unmarshal(data, thing)
// Nick: I like using limited readers to reduce the risk of an endpoint
// being broken or compromised.
lr := io.LimitReader(resp.Body, tokenRequestMaxLen)
body, err := ioutil.ReadAll(lr)
if err != nil {
return nil, err
}
var tokenResponse TokenResponse
err = json.Unmarshal(body, &tokenResponse)
if err != nil {
return nil, err
}
return &tokenResponse, nil
}
func (h oauthHandler) inspectOauthAccessToken(ctx context.Context, accessToken string) (*InspectResponse, error) {
req, err := http.NewRequest("GET", h.Config.App.OAuth.WriteAsProviderInspectLocation, nil)
if err != nil {
return nil, err
}
req.WithContext(ctx)
req.Header.Set("User-Agent", "writefreely")
req.Header.Set("Accept", "application/json")
req.Header.Set("Authorization", "Bearer "+accessToken)
resp, err := h.HttpClient.Do(req)
if err != nil {
return nil, err
}
// Nick: I like using limited readers to reduce the risk of an endpoint
// being broken or compromised.
lr := io.LimitReader(resp.Body, infoRequestMaxLen)
body, err := ioutil.ReadAll(lr)
if err != nil {
return nil, err
}
var inspectResponse InspectResponse
err = json.Unmarshal(body, &inspectResponse)
if err != nil {
return nil, err
}
return &inspectResponse, nil
} }
func loginOrFail(store sessions.Store, w http.ResponseWriter, r *http.Request, user *User) error { func loginOrFail(store sessions.Store, w http.ResponseWriter, r *http.Request, user *User) error {

148
oauth_slack.go Normal file
View File

@ -0,0 +1,148 @@
package writefreely
import (
"context"
"github.com/writeas/slug"
"net/http"
"net/url"
"strings"
)
type slackOauthClient struct {
ClientID string
ClientSecret string
TeamID string
CallbackLocation string
HttpClient HttpClient
}
type slackExchangeResponse struct {
AccessToken string `json:"access_token"`
Scope string `json:"scope"`
TeamName string `json:"team_name"`
TeamID string `json:"team_id"`
}
type slackIdentity struct {
Name string `json:"name"`
ID string `json:"id"`
Email string `json:"email"`
}
type slackTeam struct {
Name string `json:"name"`
ID string `json:"id"`
}
type slackUserIdentityResponse struct {
OK bool `json:"ok"`
User slackIdentity `json:"user"`
Team slackTeam `json:"team"`
Error string `json:"error"`
}
const (
slackAuthLocation = "https://slack.com/oauth/authorize"
slackExchangeLocation = "https://slack.com/api/oauth.access"
slackIdentityLocation = "https://slack.com/api/users.identity"
)
var _ oauthClient = slackOauthClient{}
func (c slackOauthClient) GetProvider() string {
return "slack"
}
func (c slackOauthClient) GetClientID() string {
return c.ClientID
}
func (c slackOauthClient) buildLoginURL(state string) (string, error) {
u, err := url.Parse(slackAuthLocation)
if err != nil {
return "", err
}
q := u.Query()
q.Set("client_id", c.ClientID)
q.Set("scope", "identity.basic identity.email identity.team")
q.Set("redirect_uri", c.CallbackLocation)
q.Set("state", state)
// If this param is not set, the user can select which team they
// authenticate through and then we'd have to match the configured team
// against the profile get. That is extra work in the post-auth phase
// that we don't want to do.
q.Set("team", c.TeamID)
// The Slack OAuth docs don't explicitly list this one, but it is part of
// the spec, so we include it anyway.
q.Set("response_type", "code")
u.RawQuery = q.Encode()
return u.String(), nil
}
func (c slackOauthClient) exchangeOauthCode(ctx context.Context, code string) (*TokenResponse, error) {
form := url.Values{}
// The oauth.access documentation doesn't explicitly mention this
// parameter, but it is part of the spec, so we include it anyway.
// https://api.slack.com/methods/oauth.access
form.Add("grant_type", "authorization_code")
form.Add("redirect_uri", c.CallbackLocation)
form.Add("code", code)
req, err := http.NewRequest("POST", slackExchangeLocation, strings.NewReader(form.Encode()))
if err != nil {
return nil, err
}
req.WithContext(ctx)
req.Header.Set("User-Agent", "writefreely")
req.Header.Set("Accept", "application/json")
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.SetBasicAuth(c.ClientID, c.ClientSecret)
resp, err := c.HttpClient.Do(req)
if err != nil {
return nil, err
}
var tokenResponse slackExchangeResponse
if err := limitedJsonUnmarshal(resp.Body, tokenRequestMaxLen, &tokenResponse); err != nil {
return nil, err
}
return tokenResponse.TokenResponse(), nil
}
func (c slackOauthClient) inspectOauthAccessToken(ctx context.Context, accessToken string) (*InspectResponse, error) {
req, err := http.NewRequest("GET", slackIdentityLocation, nil)
if err != nil {
return nil, err
}
req.WithContext(ctx)
req.Header.Set("User-Agent", "writefreely")
req.Header.Set("Accept", "application/json")
req.Header.Set("Authorization", "Bearer "+accessToken)
resp, err := c.HttpClient.Do(req)
if err != nil {
return nil, err
}
var inspectResponse slackUserIdentityResponse
if err := limitedJsonUnmarshal(resp.Body, infoRequestMaxLen, &inspectResponse); err != nil {
return nil, err
}
return inspectResponse.InspectResponse(), nil
}
func (resp slackUserIdentityResponse) InspectResponse() *InspectResponse {
return &InspectResponse{
UserID: resp.User.ID,
Username: slug.Make(resp.User.Name),
Email: resp.User.Email,
}
}
func (resp slackExchangeResponse) TokenResponse() *TokenResponse {
return &TokenResponse{
AccessToken: resp.AccessToken,
}
}

View File

@ -21,14 +21,16 @@ type MockOAuthDatastoreProvider struct {
} }
type MockOAuthDatastore struct { type MockOAuthDatastore struct {
DoGenerateOAuthState func(ctx context.Context) (string, error) DoGenerateOAuthState func(context.Context, string, string) (string, error)
DoValidateOAuthState func(context.Context, string) error DoValidateOAuthState func(context.Context, string) (string, string, error)
DoGetIDForRemoteUser func(context.Context, int64) (int64, error) DoGetIDForRemoteUser func(context.Context, string) (int64, error)
DoCreateUser func(*config.Config, *User, string) error DoCreateUser func(*config.Config, *User, string) error
DoRecordRemoteUserID func(context.Context, int64, int64) error DoRecordRemoteUserID func(context.Context, int64, string) error
DoGetUserForAuthByID func(int64) (*User, error) DoGetUserForAuthByID func(int64) (*User, error)
} }
var _ OAuthDatastore = &MockOAuthDatastore{}
type StringReadCloser struct { type StringReadCloser struct {
*strings.Reader *strings.Reader
} }
@ -68,24 +70,29 @@ func (m *MockOAuthDatastoreProvider) Config() *config.Config {
} }
cfg := config.New() cfg := config.New()
cfg.UseSQLite(true) cfg.UseSQLite(true)
cfg.App.EnableOAuth = true cfg.WriteAsOauth = config.WriteAsOauthCfg{
cfg.App.OAuthProviderAuthLocation = "https://write.as/oauth/login" ClientID: "development",
cfg.App.OAuthProviderTokenLocation = "https://write.as/oauth/token" ClientSecret: "development",
cfg.App.OAuthProviderInspectLocation = "https://write.as/oauth/inspect" AuthLocation: "https://write.as/oauth/login",
cfg.App.OAuthClientCallbackLocation = "http://localhost/oauth/callback" TokenLocation: "https://write.as/oauth/token",
cfg.App.OAuthClientID = "development" InspectLocation: "https://write.as/oauth/inspect",
cfg.App.OAuthClientSecret = "development" }
cfg.SlackOauth = config.SlackOauthCfg{
ClientID: "development",
ClientSecret: "development",
TeamID: "development",
}
return cfg return cfg
} }
func (m *MockOAuthDatastore) ValidateOAuthState(ctx context.Context, state string) error { func (m *MockOAuthDatastore) ValidateOAuthState(ctx context.Context, state string) (string, string, error) {
if m.DoValidateOAuthState != nil { if m.DoValidateOAuthState != nil {
return m.DoValidateOAuthState(ctx, state) return m.DoValidateOAuthState(ctx, state)
} }
return nil return "", "", nil
} }
func (m *MockOAuthDatastore) GetIDForRemoteUser(ctx context.Context, remoteUserID int64) (int64, error) { func (m *MockOAuthDatastore) GetIDForRemoteUser(ctx context.Context, remoteUserID string) (int64, error) {
if m.DoGetIDForRemoteUser != nil { if m.DoGetIDForRemoteUser != nil {
return m.DoGetIDForRemoteUser(ctx, remoteUserID) return m.DoGetIDForRemoteUser(ctx, remoteUserID)
} }
@ -100,7 +107,7 @@ func (m *MockOAuthDatastore) CreateUser(cfg *config.Config, u *User, username st
return nil return nil
} }
func (m *MockOAuthDatastore) RecordRemoteUserID(ctx context.Context, localUserID int64, remoteUserID int64) error { func (m *MockOAuthDatastore) RecordRemoteUserID(ctx context.Context, localUserID int64, remoteUserID string) error {
if m.DoRecordRemoteUserID != nil { if m.DoRecordRemoteUserID != nil {
return m.DoRecordRemoteUserID(ctx, localUserID, remoteUserID) return m.DoRecordRemoteUserID(ctx, localUserID, remoteUserID)
} }
@ -117,9 +124,9 @@ func (m *MockOAuthDatastore) GetUserForAuthByID(userID int64) (*User, error) {
return user, nil return user, nil
} }
func (m *MockOAuthDatastore) GenerateOAuthState(ctx context.Context) (string, error) { func (m *MockOAuthDatastore) GenerateOAuthState(ctx context.Context, provider string, clientID string) (string, error) {
if m.DoGenerateOAuthState != nil { if m.DoGenerateOAuthState != nil {
return m.DoGenerateOAuthState(ctx) return m.DoGenerateOAuthState(ctx, provider, clientID)
} }
return store.Generate62RandomString(14), nil return store.Generate62RandomString(14), nil
} }
@ -132,6 +139,15 @@ func TestViewOauthInit(t *testing.T) {
Config: app.Config(), Config: app.Config(),
DB: app.DB(), DB: app.DB(),
Store: app.SessionStore(), Store: app.SessionStore(),
oauthClient:writeAsOauthClient{
ClientID: app.Config().WriteAsOauth.ClientID,
ClientSecret: app.Config().WriteAsOauth.ClientSecret,
ExchangeLocation: app.Config().WriteAsOauth.TokenLocation,
InspectLocation: app.Config().WriteAsOauth.InspectLocation,
AuthLocation: app.Config().WriteAsOauth.AuthLocation,
CallbackLocation: "http://localhost/oauth/callback",
HttpClient: nil,
},
} }
req, err := http.NewRequest("GET", "/oauth/client", nil) req, err := http.NewRequest("GET", "/oauth/client", nil)
assert.NoError(t, err) assert.NoError(t, err)
@ -151,7 +167,7 @@ func TestViewOauthInit(t *testing.T) {
app := &MockOAuthDatastoreProvider{ app := &MockOAuthDatastoreProvider{
DoDB: func() OAuthDatastore { DoDB: func() OAuthDatastore {
return &MockOAuthDatastore{ return &MockOAuthDatastore{
DoGenerateOAuthState: func(ctx context.Context) (string, error) { DoGenerateOAuthState: func(ctx context.Context, provider, clientID string) (string, error) {
return "", fmt.Errorf("pretend unable to write state error") return "", fmt.Errorf("pretend unable to write state error")
}, },
} }
@ -161,6 +177,15 @@ func TestViewOauthInit(t *testing.T) {
Config: app.Config(), Config: app.Config(),
DB: app.DB(), DB: app.DB(),
Store: app.SessionStore(), Store: app.SessionStore(),
oauthClient:writeAsOauthClient{
ClientID: app.Config().WriteAsOauth.ClientID,
ClientSecret: app.Config().WriteAsOauth.ClientSecret,
ExchangeLocation: app.Config().WriteAsOauth.TokenLocation,
InspectLocation: app.Config().WriteAsOauth.InspectLocation,
AuthLocation: app.Config().WriteAsOauth.AuthLocation,
CallbackLocation: "http://localhost/oauth/callback",
HttpClient: nil,
},
} }
req, err := http.NewRequest("GET", "/oauth/client", nil) req, err := http.NewRequest("GET", "/oauth/client", nil)
assert.NoError(t, err) assert.NoError(t, err)
@ -179,6 +204,13 @@ func TestViewOauthCallback(t *testing.T) {
Config: app.Config(), Config: app.Config(),
DB: app.DB(), DB: app.DB(),
Store: app.SessionStore(), Store: app.SessionStore(),
oauthClient: writeAsOauthClient{
ClientID: app.Config().WriteAsOauth.ClientID,
ClientSecret: app.Config().WriteAsOauth.ClientSecret,
ExchangeLocation: app.Config().WriteAsOauth.TokenLocation,
InspectLocation: app.Config().WriteAsOauth.InspectLocation,
AuthLocation: app.Config().WriteAsOauth.AuthLocation,
CallbackLocation: "http://localhost/oauth/callback",
HttpClient: &MockHTTPClient{ HttpClient: &MockHTTPClient{
DoDo: func(req *http.Request) (*http.Response, error) { DoDo: func(req *http.Request) (*http.Response, error) {
switch req.URL.String() { switch req.URL.String() {
@ -190,7 +222,7 @@ func TestViewOauthCallback(t *testing.T) {
case "https://write.as/oauth/inspect": case "https://write.as/oauth/inspect":
return &http.Response{ return &http.Response{
StatusCode: 200, StatusCode: 200,
Body: &StringReadCloser{strings.NewReader(`{"client_id": "development", "user_id": 1, "expires_at": "2019-12-19T11:42:01Z", "username": "nick", "email": "nick@testing.write.as"}`)}, Body: &StringReadCloser{strings.NewReader(`{"client_id": "development", "user_id": "1", "expires_at": "2019-12-19T11:42:01Z", "username": "nick", "email": "nick@testing.write.as"}`)},
}, nil }, nil
} }
@ -199,6 +231,7 @@ func TestViewOauthCallback(t *testing.T) {
}, nil }, nil
}, },
}, },
},
} }
req, err := http.NewRequest("GET", "/oauth/callback", nil) req, err := http.NewRequest("GET", "/oauth/callback", nil)
assert.NoError(t, err) assert.NoError(t, err)
@ -206,6 +239,5 @@ func TestViewOauthCallback(t *testing.T) {
h.viewOauthCallback(rr, req) h.viewOauthCallback(rr, req)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, http.StatusTemporaryRedirect, rr.Code) assert.Equal(t, http.StatusTemporaryRedirect, rr.Code)
}) })
} }

91
oauth_writeas.go Normal file
View File

@ -0,0 +1,91 @@
package writefreely
import (
"context"
"net/http"
"net/url"
"strings"
)
type writeAsOauthClient struct {
ClientID string
ClientSecret string
AuthLocation string
ExchangeLocation string
InspectLocation string
CallbackLocation string
HttpClient HttpClient
}
var _ oauthClient = writeAsOauthClient{}
func (c writeAsOauthClient) GetProvider() string {
return "write.as"
}
func (c writeAsOauthClient) GetClientID() string {
return c.ClientID
}
func (c writeAsOauthClient) buildLoginURL(state string) (string, error) {
u, err := url.Parse(c.AuthLocation)
if err != nil {
return "", err
}
q := u.Query()
q.Set("client_id", c.ClientID)
q.Set("redirect_uri", c.CallbackLocation)
q.Set("response_type", "code")
q.Set("state", state)
u.RawQuery = q.Encode()
return u.String(), nil
}
func (c writeAsOauthClient) exchangeOauthCode(ctx context.Context, code string) (*TokenResponse, error) {
form := url.Values{}
form.Add("grant_type", "authorization_code")
form.Add("redirect_uri", c.CallbackLocation)
form.Add("code", code)
req, err := http.NewRequest("POST", c.ExchangeLocation, strings.NewReader(form.Encode()))
if err != nil {
return nil, err
}
req.WithContext(ctx)
req.Header.Set("User-Agent", "writefreely")
req.Header.Set("Accept", "application/json")
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.SetBasicAuth(c.ClientID, c.ClientSecret)
resp, err := c.HttpClient.Do(req)
if err != nil {
return nil, err
}
var tokenResponse TokenResponse
if err := limitedJsonUnmarshal(resp.Body, tokenRequestMaxLen, &tokenResponse); err != nil {
return nil, err
}
return &tokenResponse, nil
}
func (c writeAsOauthClient) inspectOauthAccessToken(ctx context.Context, accessToken string) (*InspectResponse, error) {
req, err := http.NewRequest("GET", c.InspectLocation, nil)
if err != nil {
return nil, err
}
req.WithContext(ctx)
req.Header.Set("User-Agent", "writefreely")
req.Header.Set("Accept", "application/json")
req.Header.Set("Authorization", "Bearer "+accessToken)
resp, err := c.HttpClient.Do(req)
if err != nil {
return nil, err
}
var inspectResponse InspectResponse
if err := limitedJsonUnmarshal(resp.Body, infoRequestMaxLen, &inspectResponse); err != nil {
return nil, err
}
return &inspectResponse, nil
}

View File

@ -70,6 +70,9 @@ func InitRoutes(apper Apper, r *mux.Router) *mux.Router {
write.HandleFunc(nodeinfo.NodeInfoPath, handler.LogHandlerFunc(http.HandlerFunc(ni.NodeInfoDiscover))) write.HandleFunc(nodeinfo.NodeInfoPath, handler.LogHandlerFunc(http.HandlerFunc(ni.NodeInfoDiscover)))
write.HandleFunc(niCfg.InfoURL, handler.LogHandlerFunc(http.HandlerFunc(ni.NodeInfo))) write.HandleFunc(niCfg.InfoURL, handler.LogHandlerFunc(http.HandlerFunc(ni.NodeInfo)))
configureSlackOauth(write, apper.App())
configureWriteAsOauth(write, apper.App())
// Set up dyamic page handlers // Set up dyamic page handlers
// Handle auth // Handle auth
auth := write.PathPrefix("/api/auth/").Subrouter() auth := write.PathPrefix("/api/auth/").Subrouter()
@ -80,14 +83,6 @@ func InitRoutes(apper Apper, r *mux.Router) *mux.Router {
auth.HandleFunc("/read", handler.WebErrors(handleWebCollectionUnlock, UserLevelNone)).Methods("POST") auth.HandleFunc("/read", handler.WebErrors(handleWebCollectionUnlock, UserLevelNone)).Methods("POST")
auth.HandleFunc("/me", handler.All(handleAPILogout)).Methods("DELETE") auth.HandleFunc("/me", handler.All(handleAPILogout)).Methods("DELETE")
oauthHandler := oauthHandler{
HttpClient: &http.Client{},
Config: apper.App().Config(),
DB: apper.App().DB(),
Store: apper.App().SessionStore(),
}
oauthHandler.configureRoutes(write)
// Handle logged in user sections // Handle logged in user sections
me := write.PathPrefix("/me").Subrouter() me := write.PathPrefix("/me").Subrouter()
me.HandleFunc("/", handler.Redirect("/me", UserLevelUser)) me.HandleFunc("/", handler.Redirect("/me", UserLevelUser))
@ -189,6 +184,7 @@ func InitRoutes(apper Apper, r *mux.Router) *mux.Router {
} }
write.HandleFunc(draftEditPrefix+"/{post}", handler.Web(handleViewPost, UserLevelOptional)) write.HandleFunc(draftEditPrefix+"/{post}", handler.Web(handleViewPost, UserLevelOptional))
write.HandleFunc("/", handler.Web(handleViewHome, UserLevelOptional)) write.HandleFunc("/", handler.Web(handleViewHome, UserLevelOptional))
return r return r
} }