Merged T710-oauth-slack into oauth-wrapper.

This commit is contained in:
Nick Gerakines 2020-01-02 16:19:26 -05:00
commit 2aea9560bc
21 changed files with 860 additions and 240 deletions

View File

@ -47,6 +47,12 @@ build-arm7: deps
fi fi
xgo --targets=linux/arm-7, -dest build/ $(LDFLAGS) -tags='sqlite' -out writefreely ./cmd/writefreely xgo --targets=linux/arm-7, -dest build/ $(LDFLAGS) -tags='sqlite' -out writefreely ./cmd/writefreely
build-arm64: deps
@hash xgo > /dev/null 2>&1; if [ $$? -ne 0 ]; then \
$(GOGET) -u github.com/karalabe/xgo; \
fi
xgo --targets=linux/arm64, -dest build/ $(LDFLAGS) -tags='sqlite' -out writefreely ./cmd/writefreely
build-docker : build-docker :
$(DOCKERCMD) build -t $(IMAGE_NAME):latest -t $(IMAGE_NAME):$(GITREV) . $(DOCKERCMD) build -t $(IMAGE_NAME):latest -t $(IMAGE_NAME):$(GITREV) .
@ -83,6 +89,10 @@ release : clean ui assets
mv build/$(BINARY_NAME)-linux-arm-7 $(BUILDPATH)/$(BINARY_NAME) mv build/$(BINARY_NAME)-linux-arm-7 $(BUILDPATH)/$(BINARY_NAME)
tar -cvzf $(BINARY_NAME)_$(GITREV)_linux_arm7.tar.gz -C build $(BINARY_NAME) tar -cvzf $(BINARY_NAME)_$(GITREV)_linux_arm7.tar.gz -C build $(BINARY_NAME)
rm $(BUILDPATH)/$(BINARY_NAME) rm $(BUILDPATH)/$(BINARY_NAME)
$(MAKE) build-arm64
mv build/$(BINARY_NAME)-linux-arm64 $(BUILDPATH)/$(BINARY_NAME)
tar -cvzf $(BINARY_NAME)_$(GITREV)_linux_arm64.tar.gz -C build $(BINARY_NAME)
rm $(BUILDPATH)/$(BINARY_NAME)
$(MAKE) build-darwin $(MAKE) build-darwin
mv build/$(BINARY_NAME)-darwin-10.6-amd64 $(BUILDPATH)/$(BINARY_NAME) mv build/$(BINARY_NAME)-darwin-10.6-amd64 $(BUILDPATH)/$(BINARY_NAME)
tar -cvzf $(BINARY_NAME)_$(GITREV)_macos_amd64.tar.gz -C build $(BINARY_NAME) tar -cvzf $(BINARY_NAME)_$(GITREV)_macos_amd64.tar.gz -C build $(BINARY_NAME)

View File

@ -56,6 +56,20 @@ type (
Port int `ini:"port"` Port int `ini:"port"`
} }
WriteAsOauthCfg struct {
ClientID string `ini:"client_id"`
ClientSecret string `ini:"client_secret"`
AuthLocation string `ini:"auth_location"`
TokenLocation string `ini:"token_location"`
InspectLocation string `ini:"inspect_location"`
}
SlackOauthCfg struct {
ClientID string `ini:"client_id"`
ClientSecret string `ini:"client_secret"`
TeamID string `ini:"team_id"`
}
// AppCfg holds values that affect how the application functions // AppCfg holds values that affect how the application functions
AppCfg struct { AppCfg struct {
SiteName string `ini:"site_name"` SiteName string `ini:"site_name"`
@ -92,24 +106,17 @@ type (
LocalTimeline bool `ini:"local_timeline"` LocalTimeline bool `ini:"local_timeline"`
UserInvites string `ini:"user_invites"` UserInvites string `ini:"user_invites"`
// OAuth
EnableOAuth bool `ini:"enable_oauth"`
OAuthProviderAuthLocation string `ini:"oauth_auth_location"`
OAuthProviderTokenLocation string `ini:"oauth_token_location"`
OAuthProviderInspectLocation string `ini:"oauth_inspect_location"`
OAuthClientCallbackLocation string `ini:"oauth_callback_location"`
OAuthClientID string `ini:"oauth_client_id"`
OAuthClientSecret string `ini:"oauth_client_secret"`
// Defaults // Defaults
DefaultVisibility string `ini:"default_visibility"` DefaultVisibility string `ini:"default_visibility"`
} }
// Config holds the complete configuration for running a writefreely instance // Config holds the complete configuration for running a writefreely instance
Config struct { Config struct {
Server ServerCfg `ini:"server"` Server ServerCfg `ini:"server"`
Database DatabaseCfg `ini:"database"` Database DatabaseCfg `ini:"database"`
App AppCfg `ini:"app"` App AppCfg `ini:"app"`
SlackOauth SlackOauthCfg `ini:"oauth.slack"`
WriteAsOauth WriteAsOauthCfg `ini:"oauth.writeas"`
} }
) )

View File

@ -11,7 +11,9 @@
package config package config
import ( import (
"net/http"
"strings" "strings"
"time"
) )
// FriendlyHost returns the app's Host sans any schema // FriendlyHost returns the app's Host sans any schema
@ -25,3 +27,16 @@ func (ac AppCfg) CanCreateBlogs(currentlyUsed uint64) bool {
} }
return int(currentlyUsed) < ac.MaxBlogs return int(currentlyUsed) < ac.MaxBlogs
} }
// OrDefaultString returns input or a default value if input is empty.
func OrDefaultString(input, defaultValue string) string {
if len(input) == 0 {
return defaultValue
}
return input
}
// DefaultHTTPClient returns a sane default HTTP client.
func DefaultHTTPClient() *http.Client {
return &http.Client{Timeout: 10 * time.Second}
}

View File

@ -14,6 +14,7 @@ import (
"context" "context"
"database/sql" "database/sql"
"fmt" "fmt"
wf_db "github.com/writeas/writefreely/db"
"net/http" "net/http"
"strings" "strings"
"time" "time"
@ -125,10 +126,10 @@ type writestore interface {
GetUserLastPostTime(id int64) (*time.Time, error) GetUserLastPostTime(id int64) (*time.Time, error)
GetCollectionLastPostTime(id int64) (*time.Time, error) GetCollectionLastPostTime(id int64) (*time.Time, error)
GetIDForRemoteUser(ctx context.Context, remoteUserID int64) (int64, error) GetIDForRemoteUser(context.Context, string, string, string) (int64, error)
RecordRemoteUserID(ctx context.Context, localUserID, remoteUserID int64) error RecordRemoteUserID(context.Context, int64, string, string, string, string) error
ValidateOAuthState(ctx context.Context, state string) error ValidateOAuthState(context.Context, string) (string, string, error)
GenerateOAuthState(ctx context.Context) (string, error) GenerateOAuthState(context.Context, string, string) (string, error)
DatabaseInitialized() bool DatabaseInitialized() bool
} }
@ -138,6 +139,8 @@ type datastore struct {
driverName string driverName string
} }
var _ writestore = &datastore{}
func (db *datastore) now() string { func (db *datastore) now() string {
if db.driverName == driverSQLite { if db.driverName == driverSQLite {
return "strftime('%Y-%m-%d %H:%M:%S','now')" return "strftime('%Y-%m-%d %H:%M:%S','now')"
@ -2459,48 +2462,61 @@ func (db *datastore) GetCollectionLastPostTime(id int64) (*time.Time, error) {
return &t, nil return &t, nil
} }
func (db *datastore) GenerateOAuthState(ctx context.Context) (string, error) { func (db *datastore) GenerateOAuthState(ctx context.Context, provider, clientID string) (string, error) {
state := store.Generate62RandomString(24) state := store.Generate62RandomString(24)
_, err := db.ExecContext(ctx, "INSERT INTO oauth_client_state (state, used, created_at) VALUES (?, FALSE, NOW())", state) _, err := db.ExecContext(ctx, "INSERT INTO oauth_client_states (state, provider, client_id, used, created_at) VALUES (?, ?, ?, FALSE, NOW())", state, provider, clientID)
if err != nil { if err != nil {
return "", fmt.Errorf("unable to record oauth client state: %w", err) return "", fmt.Errorf("unable to record oauth client state: %w", err)
} }
return state, nil return state, nil
} }
func (db *datastore) ValidateOAuthState(ctx context.Context, state string) error { func (db *datastore) ValidateOAuthState(ctx context.Context, state string) (string, string, error) {
res, err := db.ExecContext(ctx, "UPDATE oauth_client_state SET used = TRUE WHERE state = ?", state) var provider string
var clientID string
err := wf_db.RunTransactionWithOptions(ctx, db.DB, &sql.TxOptions{}, func(ctx context.Context, tx *sql.Tx) error {
err := tx.QueryRow("SELECT provider, client_id FROM oauth_client_states WHERE state = ? AND used = FALSE", state).Scan(&provider, &clientID)
if err != nil {
return err
}
res, err := tx.ExecContext(ctx, "UPDATE oauth_client_states SET used = TRUE WHERE state = ?", state)
if err != nil {
return err
}
rowsAffected, err := res.RowsAffected()
if err != nil {
return err
}
if rowsAffected != 1 {
return fmt.Errorf("state not found")
}
return nil
})
if err != nil { if err != nil {
return err return "", "", nil
} }
rowsAffected, err := res.RowsAffected() return provider, clientID, nil
if err != nil {
return err
}
if rowsAffected != 1 {
return fmt.Errorf("state not found")
}
return nil
} }
func (db *datastore) RecordRemoteUserID(ctx context.Context, localUserID, remoteUserID int64) error { func (db *datastore) RecordRemoteUserID(ctx context.Context, localUserID int64, remoteUserID, provider, clientID, accessToken string) error {
var err error var err error
if db.driverName == driverSQLite { if db.driverName == driverSQLite {
_, err = db.ExecContext(ctx, "INSERT OR REPLACE INTO users_oauth (user_id, remote_user_id) VALUES (?, ?)", localUserID, remoteUserID) _, err = db.ExecContext(ctx, "INSERT OR REPLACE INTO oauth_users (user_id, remote_user_id, provider, client_id, access_token) VALUES (?, ?, ?, ?, ?)", localUserID, remoteUserID, provider, clientID, accessToken)
} else { } else {
_, err = db.ExecContext(ctx, "INSERT INTO users_oauth (user_id, remote_user_id) VALUES (?, ?) "+db.upsert("user_id")+" user_id = ?", localUserID, remoteUserID, localUserID) _, err = db.ExecContext(ctx, "INSERT INTO oauth_users (user_id, remote_user_id, provider, client_id, access_token) VALUES (?, ?, ?, ?, ?) "+db.upsert("user")+" access_token = ?", localUserID, remoteUserID, provider, clientID, accessToken, accessToken)
} }
if err != nil { if err != nil {
log.Error("Unable to INSERT users_oauth for '%d': %v", localUserID, err) log.Error("Unable to INSERT oauth_users for '%d': %v", localUserID, err)
} }
return err return err
} }
// GetIDForRemoteUser returns a user ID associated with a remote user ID. // GetIDForRemoteUser returns a user ID associated with a remote user ID.
func (db *datastore) GetIDForRemoteUser(ctx context.Context, remoteUserID int64) (int64, error) { func (db *datastore) GetIDForRemoteUser(ctx context.Context, remoteUserID, provider, clientID string) (int64, error) {
var userID int64 = -1 var userID int64 = -1
err := db. err := db.
QueryRowContext(ctx, "SELECT user_id FROM users_oauth WHERE remote_user_id = ?", remoteUserID). QueryRowContext(ctx, "SELECT user_id FROM oauth_users WHERE remote_user_id = ? AND provider = ? AND client_id = ?", remoteUserID, provider, clientID).
Scan(&userID) Scan(&userID)
// Not finding a record is OK. // Not finding a record is OK.
if err != nil && err != sql.ErrNoRows { if err != nil && err != sql.ErrNoRows {

View File

@ -18,25 +18,32 @@ func TestOAuthDatastore(t *testing.T) {
driverName: "", driverName: "",
} }
state, err := ds.GenerateOAuthState(ctx) state, err := ds.GenerateOAuthState(ctx, "test", "development")
assert.NoError(t, err) assert.NoError(t, err)
assert.Len(t, state, 24) assert.Len(t, state, 24)
countRows(t, ctx, db, 1, "SELECT COUNT(*) FROM `oauth_client_state` WHERE `state` = ? AND `used` = false", state) countRows(t, ctx, db, 1, "SELECT COUNT(*) FROM `oauth_client_states` WHERE `state` = ? AND `used` = false", state)
err = ds.ValidateOAuthState(ctx, state) _, _, err = ds.ValidateOAuthState(ctx, state)
assert.NoError(t, err) assert.NoError(t, err)
countRows(t, ctx, db, 1, "SELECT COUNT(*) FROM `oauth_client_state` WHERE `state` = ? AND `used` = true", state) countRows(t, ctx, db, 1, "SELECT COUNT(*) FROM `oauth_client_states` WHERE `state` = ? AND `used` = true", state)
var localUserID int64 = 99 var localUserID int64 = 99
var remoteUserID int64 = 100 var remoteUserID = "100"
err = ds.RecordRemoteUserID(ctx, localUserID, remoteUserID) err = ds.RecordRemoteUserID(ctx, localUserID, remoteUserID, "test", "test", "access_token_a")
assert.NoError(t, err) assert.NoError(t, err)
countRows(t, ctx, db, 1, "SELECT COUNT(*) FROM `users_oauth` WHERE `user_id` = ? AND `remote_user_id` = ?", localUserID, remoteUserID) countRows(t, ctx, db, 1, "SELECT COUNT(*) FROM `oauth_users` WHERE `user_id` = ? AND `remote_user_id` = ? AND access_token = 'access_token_a'", localUserID, remoteUserID)
foundUserID, err := ds.GetIDForRemoteUser(ctx, remoteUserID) err = ds.RecordRemoteUserID(ctx, localUserID, remoteUserID, "test", "test", "access_token_b")
assert.NoError(t, err)
countRows(t, ctx, db, 1, "SELECT COUNT(*) FROM `oauth_users` WHERE `user_id` = ? AND `remote_user_id` = ? AND access_token = 'access_token_b'", localUserID, remoteUserID)
countRows(t, ctx, db, 1, "SELECT COUNT(*) FROM `oauth_users`")
foundUserID, err := ds.GetIDForRemoteUser(ctx, remoteUserID, "test", "test")
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, localUserID, foundUserID) assert.Equal(t, localUserID, foundUserID)
}) })

52
db/alter.go Normal file
View File

@ -0,0 +1,52 @@
package db
import (
"fmt"
"strings"
)
type AlterTableSqlBuilder struct {
Dialect DialectType
Name string
Changes []string
}
func (b *AlterTableSqlBuilder) AddColumn(col *Column) *AlterTableSqlBuilder {
if colVal, err := col.String(); err == nil {
b.Changes = append(b.Changes, fmt.Sprintf("ADD COLUMN %s", colVal))
}
return b
}
func (b *AlterTableSqlBuilder) ChangeColumn(name string, col *Column) *AlterTableSqlBuilder {
if colVal, err := col.String(); err == nil {
b.Changes = append(b.Changes, fmt.Sprintf("CHANGE COLUMN %s %s", name, colVal))
}
return b
}
func (b *AlterTableSqlBuilder) AddUniqueConstraint(name string, columns ...string) *AlterTableSqlBuilder {
b.Changes = append(b.Changes, fmt.Sprintf("ADD CONSTRAINT %s UNIQUE (%s)", name, strings.Join(columns, ", ")))
return b
}
func (b *AlterTableSqlBuilder) ToSQL() (string, error) {
var str strings.Builder
str.WriteString("ALTER TABLE ")
str.WriteString(b.Name)
str.WriteString(" ")
if len(b.Changes) == 0 {
return "", fmt.Errorf("no changes provide for table: %s", b.Name)
}
changeCount := len(b.Changes)
for i, thing := range b.Changes {
str.WriteString(thing)
if i < changeCount-1 {
str.WriteString(", ")
}
}
return str.String(), nil
}

56
db/alter_test.go Normal file
View File

@ -0,0 +1,56 @@
package db
import "testing"
func TestAlterTableSqlBuilder_ToSQL(t *testing.T) {
type fields struct {
Dialect DialectType
Name string
Changes []string
}
tests := []struct {
name string
builder *AlterTableSqlBuilder
want string
wantErr bool
}{
{
name: "MySQL add int",
builder: DialectMySQL.
AlterTable("the_table").
AddColumn(DialectMySQL.Column("the_col", ColumnTypeInteger, UnsetSize)),
want: "ALTER TABLE the_table ADD COLUMN the_col INT NOT NULL",
wantErr: false,
},
{
name: "MySQL add string",
builder: DialectMySQL.
AlterTable("the_table").
AddColumn(DialectMySQL.Column("the_col", ColumnTypeVarChar, OptionalInt{true, 128})),
want: "ALTER TABLE the_table ADD COLUMN the_col VARCHAR(128) NOT NULL",
wantErr: false,
},
{
name: "MySQL add int and string",
builder: DialectMySQL.
AlterTable("the_table").
AddColumn(DialectMySQL.Column("first_col", ColumnTypeInteger, UnsetSize)).
AddColumn(DialectMySQL.Column("second_col", ColumnTypeVarChar, OptionalInt{true, 128})),
want: "ALTER TABLE the_table ADD COLUMN first_col INT NOT NULL, ADD COLUMN second_col VARCHAR(128) NOT NULL",
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := tt.builder.ToSQL()
if (err != nil) != tt.wantErr {
t.Errorf("ToSQL() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != tt.want {
t.Errorf("ToSQL() got = %v, want %v", got, tt.want)
}
})
}
}

View File

@ -5,7 +5,6 @@ import (
"strings" "strings"
) )
type DialectType int
type ColumnType int type ColumnType int
type OptionalInt struct { type OptionalInt struct {
@ -41,11 +40,6 @@ type CreateTableSqlBuilder struct {
Constraints []string Constraints []string
} }
const (
DialectSQLite DialectType = iota
DialectMySQL DialectType = iota
)
const ( const (
ColumnTypeBool ColumnType = iota ColumnTypeBool ColumnType = iota
ColumnTypeSmallInt ColumnType = iota ColumnTypeSmallInt ColumnType = iota
@ -61,28 +55,6 @@ var _ SQLBuilder = &CreateTableSqlBuilder{}
var UnsetSize OptionalInt = OptionalInt{Set: false, Value: 0} var UnsetSize OptionalInt = OptionalInt{Set: false, Value: 0}
var UnsetDefault OptionalString = OptionalString{Set: false, Value: ""} var UnsetDefault OptionalString = OptionalString{Set: false, Value: ""}
func (d DialectType) Column(name string, t ColumnType, size OptionalInt) *Column {
switch d {
case DialectSQLite:
return &Column{Dialect: DialectSQLite, Name: name, Type: t, Size: size}
case DialectMySQL:
return &Column{Dialect: DialectMySQL, Name: name, Type: t, Size: size}
default:
panic(fmt.Sprintf("unexpected dialect: %d", d))
}
}
func (d DialectType) Table(name string) *CreateTableSqlBuilder {
switch d {
case DialectSQLite:
return &CreateTableSqlBuilder{Dialect: DialectSQLite, Name: name}
case DialectMySQL:
return &CreateTableSqlBuilder{Dialect: DialectMySQL, Name: name}
default:
panic(fmt.Sprintf("unexpected dialect: %d", d))
}
}
func (d ColumnType) Format(dialect DialectType, size OptionalInt) (string, error) { func (d ColumnType) Format(dialect DialectType, size OptionalInt) (string, error) {
if dialect != DialectMySQL && dialect != DialectSQLite { if dialect != DialectMySQL && dialect != DialectSQLite {
return "", fmt.Errorf("unsupported column type %d for dialect %d and size %v", d, dialect, size) return "", fmt.Errorf("unsupported column type %d for dialect %d and size %v", d, dialect, size)
@ -269,3 +241,4 @@ func (b *CreateTableSqlBuilder) ToSQL() (string, error) {
return str.String(), nil return str.String(), nil
} }

76
db/dialect.go Normal file
View File

@ -0,0 +1,76 @@
package db
import "fmt"
type DialectType int
const (
DialectSQLite DialectType = iota
DialectMySQL DialectType = iota
)
func (d DialectType) Column(name string, t ColumnType, size OptionalInt) *Column {
switch d {
case DialectSQLite:
return &Column{Dialect: DialectSQLite, Name: name, Type: t, Size: size}
case DialectMySQL:
return &Column{Dialect: DialectMySQL, Name: name, Type: t, Size: size}
default:
panic(fmt.Sprintf("unexpected dialect: %d", d))
}
}
func (d DialectType) Table(name string) *CreateTableSqlBuilder {
switch d {
case DialectSQLite:
return &CreateTableSqlBuilder{Dialect: DialectSQLite, Name: name}
case DialectMySQL:
return &CreateTableSqlBuilder{Dialect: DialectMySQL, Name: name}
default:
panic(fmt.Sprintf("unexpected dialect: %d", d))
}
}
func (d DialectType) AlterTable(name string) *AlterTableSqlBuilder {
switch d {
case DialectSQLite:
return &AlterTableSqlBuilder{Dialect: DialectSQLite, Name: name}
case DialectMySQL:
return &AlterTableSqlBuilder{Dialect: DialectMySQL, Name: name}
default:
panic(fmt.Sprintf("unexpected dialect: %d", d))
}
}
func (d DialectType) CreateUniqueIndex(name, table string, columns ...string) *CreateIndexSqlBuilder {
switch d {
case DialectSQLite:
return &CreateIndexSqlBuilder{Dialect: DialectSQLite, Name: name, Table: table, Unique: true, Columns: columns}
case DialectMySQL:
return &CreateIndexSqlBuilder{Dialect: DialectMySQL, Name: name, Table: table, Unique: true, Columns: columns}
default:
panic(fmt.Sprintf("unexpected dialect: %d", d))
}
}
func (d DialectType) CreateIndex(name, table string, columns ...string) *CreateIndexSqlBuilder {
switch d {
case DialectSQLite:
return &CreateIndexSqlBuilder{Dialect: DialectSQLite, Name: name, Table: table, Unique: false, Columns: columns}
case DialectMySQL:
return &CreateIndexSqlBuilder{Dialect: DialectMySQL, Name: name, Table: table, Unique: false, Columns: columns}
default:
panic(fmt.Sprintf("unexpected dialect: %d", d))
}
}
func (d DialectType) DropIndex(name, table string) *DropIndexSqlBuilder {
switch d {
case DialectSQLite:
return &DropIndexSqlBuilder{Dialect: DialectSQLite, Name: name, Table: table}
case DialectMySQL:
return &DropIndexSqlBuilder{Dialect: DialectMySQL, Name: name, Table: table}
default:
panic(fmt.Sprintf("unexpected dialect: %d", d))
}
}

53
db/index.go Normal file
View File

@ -0,0 +1,53 @@
package db
import (
"fmt"
"strings"
)
type CreateIndexSqlBuilder struct {
Dialect DialectType
Name string
Table string
Unique bool
Columns []string
}
type DropIndexSqlBuilder struct {
Dialect DialectType
Name string
Table string
}
func (b *CreateIndexSqlBuilder) ToSQL() (string, error) {
var str strings.Builder
str.WriteString("CREATE ")
if b.Unique {
str.WriteString("UNIQUE ")
}
str.WriteString("INDEX ")
str.WriteString(b.Name)
str.WriteString(" on ")
str.WriteString(b.Table)
if len(b.Columns) == 0 {
return "", fmt.Errorf("columns provided for this index: %s", b.Name)
}
str.WriteString(" (")
columnCount := len(b.Columns)
for i, thing := range b.Columns {
str.WriteString(thing)
if i < columnCount-1 {
str.WriteString(", ")
}
}
str.WriteString(")")
return str.String(), nil
}
func (b *DropIndexSqlBuilder) ToSQL() (string, error) {
return fmt.Sprintf("DROP INDEX %s on %s", b.Name, b.Table), nil
}

9
db/raw.go Normal file
View File

@ -0,0 +1,9 @@
package db
type RawSqlBuilder struct {
Query string
}
func (b *RawSqlBuilder) ToSQL() (string, error) {
return b.Query, nil
}

1
go.mod
View File

@ -19,6 +19,7 @@ require (
github.com/guregu/null v3.4.0+incompatible github.com/guregu/null v3.4.0+incompatible
github.com/ikeikeikeike/go-sitemap-generator/v2 v2.0.2 github.com/ikeikeikeike/go-sitemap-generator/v2 v2.0.2
github.com/jtolds/gls v4.2.1+incompatible // indirect github.com/jtolds/gls v4.2.1+incompatible // indirect
github.com/kr/pretty v0.1.0
github.com/kylemcc/twitter-text-go v0.0.0-20180726194232-7f582f6736ec github.com/kylemcc/twitter-text-go v0.0.0-20180726194232-7f582f6736ec
github.com/lunixbochs/vtclean v1.0.0 // indirect github.com/lunixbochs/vtclean v1.0.0 // indirect
github.com/manifoldco/promptui v0.3.2 github.com/manifoldco/promptui v0.3.2

1
go.sum
View File

@ -64,6 +64,7 @@ github.com/jtolds/gls v4.2.1+incompatible h1:fSuqC+Gmlu6l/ZYAoZzx2pyucC8Xza35fpR
github.com/jtolds/gls v4.2.1+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= github.com/jtolds/gls v4.2.1+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU=
github.com/juju/ansiterm v0.0.0-20180109212912-720a0952cc2a h1:FaWFmfWdAUKbSCtOU2QjDaorUexogfaMgbipgYATUMU= github.com/juju/ansiterm v0.0.0-20180109212912-720a0952cc2a h1:FaWFmfWdAUKbSCtOU2QjDaorUexogfaMgbipgYATUMU=
github.com/juju/ansiterm v0.0.0-20180109212912-720a0952cc2a/go.mod h1:UJSiEoRfvx3hP73CvoARgeLjaIOjybY9vj8PUPPFGeU= github.com/juju/ansiterm v0.0.0-20180109212912-720a0952cc2a/go.mod h1:UJSiEoRfvx3hP73CvoARgeLjaIOjybY9vj8PUPPFGeU=
github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=

View File

@ -59,7 +59,8 @@ var migrations = []Migration{
New("support user invites", supportUserInvites), // -> V1 (v0.8.0) New("support user invites", supportUserInvites), // -> V1 (v0.8.0)
New("support dynamic instance pages", supportInstancePages), // V1 -> V2 (v0.9.0) New("support dynamic instance pages", supportInstancePages), // V1 -> V2 (v0.9.0)
New("support users suspension", supportUserStatus), // V2 -> V3 (v0.11.0) New("support users suspension", supportUserStatus), // V2 -> V3 (v0.11.0)
New("support oauth", oauth), // V3 -> V4 New("support oauth", oauth), // V3 -> V4
New("support slack oauth", oauthSlack), // V4 -> v5
} }
// CurrentVer returns the current migration version the application is on // CurrentVer returns the current migration version the application is on

View File

@ -14,7 +14,7 @@ func oauth(db *datastore) error {
} }
return wf_db.RunTransactionWithOptions(context.Background(), db.DB, &sql.TxOptions{}, func(ctx context.Context, tx *sql.Tx) error { return wf_db.RunTransactionWithOptions(context.Background(), db.DB, &sql.TxOptions{}, func(ctx context.Context, tx *sql.Tx) error {
createTableUsersOauth, err := dialect. createTableUsersOauth, err := dialect.
Table("users_oauth"). Table("oauth_users").
SetIfNotExists(true). SetIfNotExists(true).
Column(dialect.Column("user_id", wf_db.ColumnTypeInteger, wf_db.UnsetSize)). Column(dialect.Column("user_id", wf_db.ColumnTypeInteger, wf_db.UnsetSize)).
Column(dialect.Column("remote_user_id", wf_db.ColumnTypeInteger, wf_db.UnsetSize)). Column(dialect.Column("remote_user_id", wf_db.ColumnTypeInteger, wf_db.UnsetSize)).
@ -25,7 +25,7 @@ func oauth(db *datastore) error {
return err return err
} }
createTableOauthClientState, err := dialect. createTableOauthClientState, err := dialect.
Table("oauth_client_state"). Table("oauth_client_states").
SetIfNotExists(true). SetIfNotExists(true).
Column(dialect.Column("state", wf_db.ColumnTypeVarChar, wf_db.OptionalInt{Set: true, Value: 255})). Column(dialect.Column("state", wf_db.ColumnTypeVarChar, wf_db.OptionalInt{Set: true, Value: 255})).
Column(dialect.Column("used", wf_db.ColumnTypeBool, wf_db.UnsetSize)). Column(dialect.Column("used", wf_db.ColumnTypeBool, wf_db.UnsetSize)).

67
migrations/v5.go Normal file
View File

@ -0,0 +1,67 @@
package migrations
import (
"context"
"database/sql"
wf_db "github.com/writeas/writefreely/db"
)
func oauthSlack(db *datastore) error {
dialect := wf_db.DialectMySQL
if db.driverName == driverSQLite {
dialect = wf_db.DialectSQLite
}
return wf_db.RunTransactionWithOptions(context.Background(), db.DB, &sql.TxOptions{}, func(ctx context.Context, tx *sql.Tx) error {
builders := []wf_db.SQLBuilder{
dialect.
AlterTable("oauth_client_states").
AddColumn(dialect.
Column(
"provider",
wf_db.ColumnTypeVarChar,
wf_db.OptionalInt{Set: true, Value: 24,})).
AddColumn(dialect.
Column(
"client_id",
wf_db.ColumnTypeVarChar,
wf_db.OptionalInt{Set: true, Value: 128,})),
dialect.
AlterTable("oauth_users").
ChangeColumn("remote_user_id",
dialect.
Column(
"remote_user_id",
wf_db.ColumnTypeVarChar,
wf_db.OptionalInt{Set: true, Value: 128,})).
AddColumn(dialect.
Column(
"provider",
wf_db.ColumnTypeVarChar,
wf_db.OptionalInt{Set: true, Value: 24,})).
AddColumn(dialect.
Column(
"client_id",
wf_db.ColumnTypeVarChar,
wf_db.OptionalInt{Set: true, Value: 128,})).
AddColumn(dialect.
Column(
"access_token",
wf_db.ColumnTypeVarChar,
wf_db.OptionalInt{Set: true, Value: 512,})),
dialect.DropIndex("remote_user_id", "oauth_users"),
dialect.DropIndex("user_id", "oauth_users"),
dialect.CreateUniqueIndex("oauth_users", "oauth_users", "user_id", "provider", "client_id"),
}
for _, builder := range builders {
query, err := builder.ToSQL()
if err != nil {
return err
}
if _, err := tx.ExecContext(ctx, query); err != nil {
return err
}
}
return nil
})
}

206
oauth.go
View File

@ -4,6 +4,7 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/gorilla/mux"
"github.com/gorilla/sessions" "github.com/gorilla/sessions"
"github.com/guregu/null/zero" "github.com/guregu/null/zero"
"github.com/writeas/impart" "github.com/writeas/impart"
@ -14,8 +15,6 @@ import (
"io" "io"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"net/url"
"strings"
"time" "time"
) )
@ -31,11 +30,13 @@ type TokenResponse struct {
// InspectResponse contains data returned when an access token is inspected. // InspectResponse contains data returned when an access token is inspected.
type InspectResponse struct { type InspectResponse struct {
ClientID string `json:"client_id"` ClientID string `json:"client_id"`
UserID int64 `json:"user_id"` UserID string `json:"user_id"`
ExpiresAt time.Time `json:"expires_at"` ExpiresAt time.Time `json:"expires_at"`
Username string `json:"username"` Username string `json:"username"`
Email string `json:"email"` DisplayName string `json:"-"`
Email string `json:"email"`
Error string `json:"error"`
} }
// tokenRequestMaxLen is the most bytes that we'll read from the /oauth/token // tokenRequestMaxLen is the most bytes that we'll read from the /oauth/token
@ -57,11 +58,12 @@ type OAuthDatastoreProvider interface {
// OAuthDatastore provides a minimal interface of data store methods used in // OAuthDatastore provides a minimal interface of data store methods used in
// oauth functionality. // oauth functionality.
type OAuthDatastore interface { type OAuthDatastore interface {
GenerateOAuthState(context.Context) (string, error) GetIDForRemoteUser(context.Context, string, string, string) (int64, error)
ValidateOAuthState(context.Context, string) error RecordRemoteUserID(context.Context, int64, string, string, string, string) error
GetIDForRemoteUser(context.Context, int64) (int64, error) ValidateOAuthState(context.Context, string) (string, string, error)
GenerateOAuthState(context.Context, string, string) (string, error)
CreateUser(*config.Config, *User, string) error CreateUser(*config.Config, *User, string) error
RecordRemoteUserID(context.Context, int64, int64) error
GetUserForAuthByID(int64) (*User, error) GetUserForAuthByID(int64) (*User, error)
} }
@ -69,56 +71,89 @@ type HttpClient interface {
Do(req *http.Request) (*http.Response, error) Do(req *http.Request) (*http.Response, error)
} }
type oauthClient interface {
GetProvider() string
GetClientID() string
buildLoginURL(state string) (string, error)
exchangeOauthCode(ctx context.Context, code string) (*TokenResponse, error)
inspectOauthAccessToken(ctx context.Context, accessToken string) (*InspectResponse, error)
}
type oauthHandler struct { type oauthHandler struct {
Config *config.Config Config *config.Config
DB OAuthDatastore DB OAuthDatastore
Store sessions.Store Store sessions.Store
HttpClient HttpClient oauthClient oauthClient
} }
// buildAuthURL returns a URL used to initiate authentication.
func buildAuthURL(db OAuthDatastore, ctx context.Context, clientID, authLocation, callbackURL string) (string, error) {
state, err := db.GenerateOAuthState(ctx)
if err != nil {
return "", err
}
u, err := url.Parse(authLocation)
if err != nil {
return "", err
}
q := u.Query()
q.Set("client_id", clientID)
q.Set("redirect_uri", callbackURL)
q.Set("response_type", "code")
q.Set("state", state)
u.RawQuery = q.Encode()
return u.String(), nil
}
// app *App, w http.ResponseWriter, r *http.Request
func (h oauthHandler) viewOauthInit(app *App, w http.ResponseWriter, r *http.Request) error { func (h oauthHandler) viewOauthInit(app *App, w http.ResponseWriter, r *http.Request) error {
location, err := buildAuthURL(h.DB, r.Context(), h.Config.App.OAuthClientID, h.Config.App.OAuthProviderAuthLocation, h.Config.App.OAuthClientCallbackLocation) ctx := r.Context()
state, err := h.DB.GenerateOAuthState(ctx, h.oauthClient.GetProvider(), h.oauthClient.GetClientID())
if err != nil {
return impart.HTTPError{http.StatusInternalServerError, "could not prepare oauth redirect url"}
}
location, err := h.oauthClient.buildLoginURL(state)
if err != nil { if err != nil {
return impart.HTTPError{http.StatusInternalServerError, "could not prepare oauth redirect url"} return impart.HTTPError{http.StatusInternalServerError, "could not prepare oauth redirect url"}
} }
return impart.HTTPError{http.StatusTemporaryRedirect, location} return impart.HTTPError{http.StatusTemporaryRedirect, location}
} }
func configureSlackOauth(parentHandler *Handler, r *mux.Router, app *App) {
if app.Config().SlackOauth.ClientID != "" {
oauthClient := slackOauthClient{
ClientID: app.Config().SlackOauth.ClientID,
ClientSecret: app.Config().SlackOauth.ClientSecret,
TeamID: app.Config().SlackOauth.TeamID,
CallbackLocation: app.Config().App.Host + "/oauth/callback",
HttpClient: config.DefaultHTTPClient(),
}
configureOauthRoutes(parentHandler, r, app, oauthClient)
}
}
func configureWriteAsOauth(parentHandler *Handler, r *mux.Router, app *App) {
if app.Config().WriteAsOauth.ClientID != "" {
oauthClient := writeAsOauthClient{
ClientID: app.Config().WriteAsOauth.ClientID,
ClientSecret: app.Config().WriteAsOauth.ClientSecret,
ExchangeLocation: config.OrDefaultString(app.Config().WriteAsOauth.TokenLocation, writeAsExchangeLocation),
InspectLocation: config.OrDefaultString(app.Config().WriteAsOauth.InspectLocation, writeAsIdentityLocation),
AuthLocation: config.OrDefaultString(app.Config().WriteAsOauth.AuthLocation, writeAsAuthLocation),
HttpClient: config.DefaultHTTPClient(),
CallbackLocation: app.Config().App.Host + "/oauth/callback",
}
if oauthClient.ExchangeLocation == "" {
}
configureOauthRoutes(parentHandler, r, app, oauthClient)
}
}
func configureOauthRoutes(parentHandler *Handler, r *mux.Router, app *App, oauthClient oauthClient) {
handler := &oauthHandler{
Config: app.Config(),
DB: app.DB(),
Store: app.SessionStore(),
oauthClient: oauthClient,
}
r.HandleFunc("/oauth/"+oauthClient.GetProvider(), parentHandler.OAuth(handler.viewOauthInit)).Methods("GET")
r.HandleFunc("/oauth/callback", parentHandler.OAuth(handler.viewOauthCallback)).Methods("GET")
}
func (h oauthHandler) viewOauthCallback(app *App, w http.ResponseWriter, r *http.Request) error { func (h oauthHandler) viewOauthCallback(app *App, w http.ResponseWriter, r *http.Request) error {
ctx := r.Context() ctx := r.Context()
code := r.FormValue("code") code := r.FormValue("code")
state := r.FormValue("state") state := r.FormValue("state")
err := h.DB.ValidateOAuthState(ctx, state) provider, clientID, err := h.DB.ValidateOAuthState(ctx, state)
if err != nil { if err != nil {
log.Error("Unable to ValidateOAuthState: %s", err) log.Error("Unable to ValidateOAuthState: %s", err)
return impart.HTTPError{http.StatusInternalServerError, err.Error()} return impart.HTTPError{http.StatusInternalServerError, err.Error()}
} }
tokenResponse, err := h.exchangeOauthCode(ctx, code) tokenResponse, err := h.oauthClient.exchangeOauthCode(ctx, code)
if err != nil { if err != nil {
log.Error("Unable to exchangeOauthCode: %s", err) log.Error("Unable to exchangeOauthCode: %s", err)
return impart.HTTPError{http.StatusInternalServerError, err.Error()} return impart.HTTPError{http.StatusInternalServerError, err.Error()}
@ -126,20 +161,18 @@ func (h oauthHandler) viewOauthCallback(app *App, w http.ResponseWriter, r *http
// Now that we have the access token, let's use it real quick to make sur // Now that we have the access token, let's use it real quick to make sur
// it really really works. // it really really works.
tokenInfo, err := h.inspectOauthAccessToken(ctx, tokenResponse.AccessToken) tokenInfo, err := h.oauthClient.inspectOauthAccessToken(ctx, tokenResponse.AccessToken)
if err != nil { if err != nil {
log.Error("Unable to inspectOauthAccessToken: %s", err) log.Error("Unable to inspectOauthAccessToken: %s", err)
return impart.HTTPError{http.StatusInternalServerError, err.Error()} return impart.HTTPError{http.StatusInternalServerError, err.Error()}
} }
localUserID, err := h.DB.GetIDForRemoteUser(ctx, tokenInfo.UserID) localUserID, err := h.DB.GetIDForRemoteUser(ctx, tokenInfo.UserID, provider, clientID)
if err != nil { if err != nil {
log.Error("Unable to GetIDForRemoteUser: %s", err) log.Error("Unable to GetIDForRemoteUser: %s", err)
return impart.HTTPError{http.StatusInternalServerError, err.Error()} return impart.HTTPError{http.StatusInternalServerError, err.Error()}
} }
fmt.Println("local user id", localUserID)
if localUserID == -1 { if localUserID == -1 {
// We don't have, nor do we want, the password from the origin, so we // We don't have, nor do we want, the password from the origin, so we
//create a random string. If the user needs to set a password, they //create a random string. If the user needs to set a password, they
@ -148,23 +181,26 @@ func (h oauthHandler) viewOauthCallback(app *App, w http.ResponseWriter, r *http
randPass := store.Generate62RandomString(14) randPass := store.Generate62RandomString(14)
hashedPass, err := auth.HashPass([]byte(randPass)) hashedPass, err := auth.HashPass([]byte(randPass))
if err != nil { if err != nil {
log.ErrorLog.Println(err)
return impart.HTTPError{http.StatusInternalServerError, "unable to create password hash"} return impart.HTTPError{http.StatusInternalServerError, "unable to create password hash"}
} }
newUser := &User{ newUser := &User{
Username: tokenInfo.Username, Username: tokenInfo.Username,
HashedPass: hashedPass, HashedPass: hashedPass,
HasPass: true, HasPass: true,
Email: zero.NewString("", tokenInfo.Email != ""), Email: zero.NewString(tokenInfo.Email, tokenInfo.Email != ""),
Created: time.Now().Truncate(time.Second).UTC(), Created: time.Now().Truncate(time.Second).UTC(),
} }
displayName := tokenInfo.DisplayName
if len(displayName) == 0 {
displayName = tokenInfo.Username
}
err = h.DB.CreateUser(h.Config, newUser, newUser.Username) err = h.DB.CreateUser(h.Config, newUser, displayName)
if err != nil { if err != nil {
return impart.HTTPError{http.StatusInternalServerError, err.Error()} return impart.HTTPError{http.StatusInternalServerError, err.Error()}
} }
err = h.DB.RecordRemoteUserID(ctx, newUser.ID, tokenInfo.UserID) err = h.DB.RecordRemoteUserID(ctx, newUser.ID, tokenInfo.UserID, provider, clientID, tokenResponse.AccessToken)
if err != nil { if err != nil {
return impart.HTTPError{http.StatusInternalServerError, err.Error()} return impart.HTTPError{http.StatusInternalServerError, err.Error()}
} }
@ -185,76 +221,16 @@ func (h oauthHandler) viewOauthCallback(app *App, w http.ResponseWriter, r *http
return nil return nil
} }
func (h oauthHandler) exchangeOauthCode(ctx context.Context, code string) (*TokenResponse, error) { func limitedJsonUnmarshal(body io.ReadCloser, n int, thing interface{}) error {
form := url.Values{} lr := io.LimitReader(body, int64(n+1))
form.Add("grant_type", "authorization_code") data, err := ioutil.ReadAll(lr)
form.Add("redirect_uri", h.Config.App.OAuthClientCallbackLocation)
form.Add("code", code)
req, err := http.NewRequest("POST", h.Config.App.OAuthProviderTokenLocation, strings.NewReader(form.Encode()))
if err != nil { if err != nil {
return nil, err return err
} }
req.WithContext(ctx) if len(data) == n+1 {
req.Header.Set("User-Agent", "writefreely") return fmt.Errorf("content larger than max read allowance: %d", n)
req.Header.Set("Accept", "application/json")
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.SetBasicAuth(h.Config.App.OAuthClientID, h.Config.App.OAuthClientSecret)
resp, err := h.HttpClient.Do(req)
if err != nil {
return nil, err
} }
return json.Unmarshal(data, thing)
// Nick: I like using limited readers to reduce the risk of an endpoint
// being broken or compromised.
lr := io.LimitReader(resp.Body, tokenRequestMaxLen)
body, err := ioutil.ReadAll(lr)
if err != nil {
return nil, err
}
var tokenResponse TokenResponse
err = json.Unmarshal(body, &tokenResponse)
if err != nil {
return nil, err
}
// Check the response for an error message, and return it if there is one.
if tokenResponse.Error != "" {
return nil, fmt.Errorf(tokenResponse.Error)
}
return &tokenResponse, nil
}
func (h oauthHandler) inspectOauthAccessToken(ctx context.Context, accessToken string) (*InspectResponse, error) {
req, err := http.NewRequest("GET", h.Config.App.OAuthProviderInspectLocation, nil)
if err != nil {
return nil, err
}
req.WithContext(ctx)
req.Header.Set("User-Agent", "writefreely")
req.Header.Set("Accept", "application/json")
req.Header.Set("Authorization", "Bearer "+accessToken)
resp, err := h.HttpClient.Do(req)
if err != nil {
return nil, err
}
// Nick: I like using limited readers to reduce the risk of an endpoint
// being broken or compromised.
lr := io.LimitReader(resp.Body, infoRequestMaxLen)
body, err := ioutil.ReadAll(lr)
if err != nil {
return nil, err
}
var inspectResponse InspectResponse
err = json.Unmarshal(body, &inspectResponse)
if err != nil {
return nil, err
}
return &inspectResponse, nil
} }
func loginOrFail(store sessions.Store, w http.ResponseWriter, r *http.Request, user *User) error { func loginOrFail(store sessions.Store, w http.ResponseWriter, r *http.Request, user *User) error {

164
oauth_slack.go Normal file
View File

@ -0,0 +1,164 @@
package writefreely
import (
"context"
"errors"
"github.com/writeas/slug"
"net/http"
"net/url"
"strings"
)
type slackOauthClient struct {
ClientID string
ClientSecret string
TeamID string
CallbackLocation string
HttpClient HttpClient
}
type slackExchangeResponse struct {
OK bool `json:"ok"`
AccessToken string `json:"access_token"`
Scope string `json:"scope"`
TeamName string `json:"team_name"`
TeamID string `json:"team_id"`
Error string `json:"error"`
}
type slackIdentity struct {
Name string `json:"name"`
ID string `json:"id"`
Email string `json:"email"`
}
type slackTeam struct {
Name string `json:"name"`
ID string `json:"id"`
}
type slackUserIdentityResponse struct {
OK bool `json:"ok"`
User slackIdentity `json:"user"`
Team slackTeam `json:"team"`
Error string `json:"error"`
}
const (
slackAuthLocation = "https://slack.com/oauth/authorize"
slackExchangeLocation = "https://slack.com/api/oauth.access"
slackIdentityLocation = "https://slack.com/api/users.identity"
)
var _ oauthClient = slackOauthClient{}
func (c slackOauthClient) GetProvider() string {
return "slack"
}
func (c slackOauthClient) GetClientID() string {
return c.ClientID
}
func (c slackOauthClient) buildLoginURL(state string) (string, error) {
u, err := url.Parse(slackAuthLocation)
if err != nil {
return "", err
}
q := u.Query()
q.Set("client_id", c.ClientID)
q.Set("scope", "identity.basic identity.email identity.team")
q.Set("redirect_uri", c.CallbackLocation)
q.Set("state", state)
// If this param is not set, the user can select which team they
// authenticate through and then we'd have to match the configured team
// against the profile get. That is extra work in the post-auth phase
// that we don't want to do.
q.Set("team", c.TeamID)
// The Slack OAuth docs don't explicitly list this one, but it is part of
// the spec, so we include it anyway.
q.Set("response_type", "code")
u.RawQuery = q.Encode()
return u.String(), nil
}
func (c slackOauthClient) exchangeOauthCode(ctx context.Context, code string) (*TokenResponse, error) {
form := url.Values{}
// The oauth.access documentation doesn't explicitly mention this
// parameter, but it is part of the spec, so we include it anyway.
// https://api.slack.com/methods/oauth.access
form.Add("grant_type", "authorization_code")
form.Add("redirect_uri", c.CallbackLocation)
form.Add("code", code)
req, err := http.NewRequest("POST", slackExchangeLocation, strings.NewReader(form.Encode()))
if err != nil {
return nil, err
}
req.WithContext(ctx)
req.Header.Set("User-Agent", "writefreely")
req.Header.Set("Accept", "application/json")
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.SetBasicAuth(c.ClientID, c.ClientSecret)
resp, err := c.HttpClient.Do(req)
if err != nil {
return nil, err
}
if resp.StatusCode != http.StatusOK {
return nil, errors.New("unable to exchange code for access token")
}
var tokenResponse slackExchangeResponse
if err := limitedJsonUnmarshal(resp.Body, tokenRequestMaxLen, &tokenResponse); err != nil {
return nil, err
}
if !tokenResponse.OK {
return nil, errors.New(tokenResponse.Error)
}
return tokenResponse.TokenResponse(), nil
}
func (c slackOauthClient) inspectOauthAccessToken(ctx context.Context, accessToken string) (*InspectResponse, error) {
req, err := http.NewRequest("GET", slackIdentityLocation, nil)
if err != nil {
return nil, err
}
req.WithContext(ctx)
req.Header.Set("User-Agent", "writefreely")
req.Header.Set("Accept", "application/json")
req.Header.Set("Authorization", "Bearer "+accessToken)
resp, err := c.HttpClient.Do(req)
if err != nil {
return nil, err
}
if resp.StatusCode != http.StatusOK {
return nil, errors.New("unable to inspect access token")
}
var inspectResponse slackUserIdentityResponse
if err := limitedJsonUnmarshal(resp.Body, infoRequestMaxLen, &inspectResponse); err != nil {
return nil, err
}
if !inspectResponse.OK {
return nil, errors.New(inspectResponse.Error)
}
return inspectResponse.InspectResponse(), nil
}
func (resp slackUserIdentityResponse) InspectResponse() *InspectResponse {
return &InspectResponse{
UserID: resp.User.ID,
Username: slug.Make(resp.User.Name),
DisplayName: resp.User.Name,
Email: resp.User.Email,
}
}
func (resp slackExchangeResponse) TokenResponse() *TokenResponse {
return &TokenResponse{
AccessToken: resp.AccessToken,
}
}

View File

@ -21,14 +21,16 @@ type MockOAuthDatastoreProvider struct {
} }
type MockOAuthDatastore struct { type MockOAuthDatastore struct {
DoGenerateOAuthState func(ctx context.Context) (string, error) DoGenerateOAuthState func(context.Context, string, string) (string, error)
DoValidateOAuthState func(context.Context, string) error DoValidateOAuthState func(context.Context, string) (string, string, error)
DoGetIDForRemoteUser func(context.Context, int64) (int64, error) DoGetIDForRemoteUser func(context.Context, string, string, string) (int64, error)
DoCreateUser func(*config.Config, *User, string) error DoCreateUser func(*config.Config, *User, string) error
DoRecordRemoteUserID func(context.Context, int64, int64) error DoRecordRemoteUserID func(context.Context, int64, string, string, string, string) error
DoGetUserForAuthByID func(int64) (*User, error) DoGetUserForAuthByID func(int64) (*User, error)
} }
var _ OAuthDatastore = &MockOAuthDatastore{}
type StringReadCloser struct { type StringReadCloser struct {
*strings.Reader *strings.Reader
} }
@ -68,26 +70,31 @@ func (m *MockOAuthDatastoreProvider) Config() *config.Config {
} }
cfg := config.New() cfg := config.New()
cfg.UseSQLite(true) cfg.UseSQLite(true)
cfg.App.EnableOAuth = true cfg.WriteAsOauth = config.WriteAsOauthCfg{
cfg.App.OAuthProviderAuthLocation = "https://write.as/oauth/login" ClientID: "development",
cfg.App.OAuthProviderTokenLocation = "https://write.as/oauth/token" ClientSecret: "development",
cfg.App.OAuthProviderInspectLocation = "https://write.as/oauth/inspect" AuthLocation: "https://write.as/oauth/login",
cfg.App.OAuthClientCallbackLocation = "http://localhost/oauth/callback" TokenLocation: "https://write.as/oauth/token",
cfg.App.OAuthClientID = "development" InspectLocation: "https://write.as/oauth/inspect",
cfg.App.OAuthClientSecret = "development" }
cfg.SlackOauth = config.SlackOauthCfg{
ClientID: "development",
ClientSecret: "development",
TeamID: "development",
}
return cfg return cfg
} }
func (m *MockOAuthDatastore) ValidateOAuthState(ctx context.Context, state string) error { func (m *MockOAuthDatastore) ValidateOAuthState(ctx context.Context, state string) (string, string, error) {
if m.DoValidateOAuthState != nil { if m.DoValidateOAuthState != nil {
return m.DoValidateOAuthState(ctx, state) return m.DoValidateOAuthState(ctx, state)
} }
return nil return "", "", nil
} }
func (m *MockOAuthDatastore) GetIDForRemoteUser(ctx context.Context, remoteUserID int64) (int64, error) { func (m *MockOAuthDatastore) GetIDForRemoteUser(ctx context.Context, remoteUserID, provider, clientID string) (int64, error) {
if m.DoGetIDForRemoteUser != nil { if m.DoGetIDForRemoteUser != nil {
return m.DoGetIDForRemoteUser(ctx, remoteUserID) return m.DoGetIDForRemoteUser(ctx, remoteUserID, provider, clientID)
} }
return -1, nil return -1, nil
} }
@ -100,9 +107,9 @@ func (m *MockOAuthDatastore) CreateUser(cfg *config.Config, u *User, username st
return nil return nil
} }
func (m *MockOAuthDatastore) RecordRemoteUserID(ctx context.Context, localUserID int64, remoteUserID int64) error { func (m *MockOAuthDatastore) RecordRemoteUserID(ctx context.Context, localUserID int64, remoteUserID, provider, clientID, accessToken string) error {
if m.DoRecordRemoteUserID != nil { if m.DoRecordRemoteUserID != nil {
return m.DoRecordRemoteUserID(ctx, localUserID, remoteUserID) return m.DoRecordRemoteUserID(ctx, localUserID, remoteUserID, provider, clientID, accessToken)
} }
return nil return nil
} }
@ -117,9 +124,9 @@ func (m *MockOAuthDatastore) GetUserForAuthByID(userID int64) (*User, error) {
return user, nil return user, nil
} }
func (m *MockOAuthDatastore) GenerateOAuthState(ctx context.Context) (string, error) { func (m *MockOAuthDatastore) GenerateOAuthState(ctx context.Context, provider string, clientID string) (string, error) {
if m.DoGenerateOAuthState != nil { if m.DoGenerateOAuthState != nil {
return m.DoGenerateOAuthState(ctx) return m.DoGenerateOAuthState(ctx, provider, clientID)
} }
return store.Generate62RandomString(14), nil return store.Generate62RandomString(14), nil
} }
@ -132,11 +139,20 @@ func TestViewOauthInit(t *testing.T) {
Config: app.Config(), Config: app.Config(),
DB: app.DB(), DB: app.DB(),
Store: app.SessionStore(), Store: app.SessionStore(),
oauthClient:writeAsOauthClient{
ClientID: app.Config().WriteAsOauth.ClientID,
ClientSecret: app.Config().WriteAsOauth.ClientSecret,
ExchangeLocation: app.Config().WriteAsOauth.TokenLocation,
InspectLocation: app.Config().WriteAsOauth.InspectLocation,
AuthLocation: app.Config().WriteAsOauth.AuthLocation,
CallbackLocation: "http://localhost/oauth/callback",
HttpClient: nil,
},
} }
req, err := http.NewRequest("GET", "/oauth/client", nil) req, err := http.NewRequest("GET", "/oauth/client", nil)
assert.NoError(t, err) assert.NoError(t, err)
rr := httptest.NewRecorder() rr := httptest.NewRecorder()
h.viewOauthInit(rr, req) h.viewOauthInit(nil, rr, req)
assert.Equal(t, http.StatusTemporaryRedirect, rr.Code) assert.Equal(t, http.StatusTemporaryRedirect, rr.Code)
locURI, err := url.Parse(rr.Header().Get("Location")) locURI, err := url.Parse(rr.Header().Get("Location"))
assert.NoError(t, err) assert.NoError(t, err)
@ -151,7 +167,7 @@ func TestViewOauthInit(t *testing.T) {
app := &MockOAuthDatastoreProvider{ app := &MockOAuthDatastoreProvider{
DoDB: func() OAuthDatastore { DoDB: func() OAuthDatastore {
return &MockOAuthDatastore{ return &MockOAuthDatastore{
DoGenerateOAuthState: func(ctx context.Context) (string, error) { DoGenerateOAuthState: func(ctx context.Context, provider, clientID string) (string, error) {
return "", fmt.Errorf("pretend unable to write state error") return "", fmt.Errorf("pretend unable to write state error")
}, },
} }
@ -161,11 +177,20 @@ func TestViewOauthInit(t *testing.T) {
Config: app.Config(), Config: app.Config(),
DB: app.DB(), DB: app.DB(),
Store: app.SessionStore(), Store: app.SessionStore(),
oauthClient:writeAsOauthClient{
ClientID: app.Config().WriteAsOauth.ClientID,
ClientSecret: app.Config().WriteAsOauth.ClientSecret,
ExchangeLocation: app.Config().WriteAsOauth.TokenLocation,
InspectLocation: app.Config().WriteAsOauth.InspectLocation,
AuthLocation: app.Config().WriteAsOauth.AuthLocation,
CallbackLocation: "http://localhost/oauth/callback",
HttpClient: nil,
},
} }
req, err := http.NewRequest("GET", "/oauth/client", nil) req, err := http.NewRequest("GET", "/oauth/client", nil)
assert.NoError(t, err) assert.NoError(t, err)
rr := httptest.NewRecorder() rr := httptest.NewRecorder()
h.viewOauthInit(rr, req) h.viewOauthInit(nil, rr, req)
assert.Equal(t, http.StatusInternalServerError, rr.Code) assert.Equal(t, http.StatusInternalServerError, rr.Code)
expected := `{"error":"could not prepare oauth redirect url"}` + "\n" expected := `{"error":"could not prepare oauth redirect url"}` + "\n"
assert.Equal(t, expected, rr.Body.String()) assert.Equal(t, expected, rr.Body.String())
@ -179,33 +204,40 @@ func TestViewOauthCallback(t *testing.T) {
Config: app.Config(), Config: app.Config(),
DB: app.DB(), DB: app.DB(),
Store: app.SessionStore(), Store: app.SessionStore(),
HttpClient: &MockHTTPClient{ oauthClient: writeAsOauthClient{
DoDo: func(req *http.Request) (*http.Response, error) { ClientID: app.Config().WriteAsOauth.ClientID,
switch req.URL.String() { ClientSecret: app.Config().WriteAsOauth.ClientSecret,
case "https://write.as/oauth/token": ExchangeLocation: app.Config().WriteAsOauth.TokenLocation,
return &http.Response{ InspectLocation: app.Config().WriteAsOauth.InspectLocation,
StatusCode: 200, AuthLocation: app.Config().WriteAsOauth.AuthLocation,
Body: &StringReadCloser{strings.NewReader(`{"access_token": "access_token", "expires_in": 1000, "refresh_token": "refresh_token", "token_type": "access"}`)}, CallbackLocation: "http://localhost/oauth/callback",
}, nil HttpClient: &MockHTTPClient{
case "https://write.as/oauth/inspect": DoDo: func(req *http.Request) (*http.Response, error) {
return &http.Response{ switch req.URL.String() {
StatusCode: 200, case "https://write.as/oauth/token":
Body: &StringReadCloser{strings.NewReader(`{"client_id": "development", "user_id": 1, "expires_at": "2019-12-19T11:42:01Z", "username": "nick", "email": "nick@testing.write.as"}`)}, return &http.Response{
}, nil StatusCode: 200,
} Body: &StringReadCloser{strings.NewReader(`{"access_token": "access_token", "expires_in": 1000, "refresh_token": "refresh_token", "token_type": "access"}`)},
}, nil
case "https://write.as/oauth/inspect":
return &http.Response{
StatusCode: 200,
Body: &StringReadCloser{strings.NewReader(`{"client_id": "development", "user_id": "1", "expires_at": "2019-12-19T11:42:01Z", "username": "nick", "email": "nick@testing.write.as"}`)},
}, nil
}
return &http.Response{ return &http.Response{
StatusCode: http.StatusNotFound, StatusCode: http.StatusNotFound,
}, nil }, nil
},
}, },
}, },
} }
req, err := http.NewRequest("GET", "/oauth/callback", nil) req, err := http.NewRequest("GET", "/oauth/callback", nil)
assert.NoError(t, err) assert.NoError(t, err)
rr := httptest.NewRecorder() rr := httptest.NewRecorder()
h.viewOauthCallback(rr, req) h.viewOauthCallback(nil, rr, req)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, http.StatusTemporaryRedirect, rr.Code) assert.Equal(t, http.StatusTemporaryRedirect, rr.Code)
}) })
} }

110
oauth_writeas.go Normal file
View File

@ -0,0 +1,110 @@
package writefreely
import (
"context"
"errors"
"net/http"
"net/url"
"strings"
)
type writeAsOauthClient struct {
ClientID string
ClientSecret string
AuthLocation string
ExchangeLocation string
InspectLocation string
CallbackLocation string
HttpClient HttpClient
}
var _ oauthClient = writeAsOauthClient{}
const (
writeAsAuthLocation = "https://write.as/oauth/login"
writeAsExchangeLocation = "https://write.as/oauth/token"
writeAsIdentityLocation = "https://write.as/oauth/inspect"
)
func (c writeAsOauthClient) GetProvider() string {
return "write.as"
}
func (c writeAsOauthClient) GetClientID() string {
return c.ClientID
}
func (c writeAsOauthClient) buildLoginURL(state string) (string, error) {
u, err := url.Parse(c.AuthLocation)
if err != nil {
return "", err
}
q := u.Query()
q.Set("client_id", c.ClientID)
q.Set("redirect_uri", c.CallbackLocation)
q.Set("response_type", "code")
q.Set("state", state)
u.RawQuery = q.Encode()
return u.String(), nil
}
func (c writeAsOauthClient) exchangeOauthCode(ctx context.Context, code string) (*TokenResponse, error) {
form := url.Values{}
form.Add("grant_type", "authorization_code")
form.Add("redirect_uri", c.CallbackLocation)
form.Add("code", code)
req, err := http.NewRequest("POST", c.ExchangeLocation, strings.NewReader(form.Encode()))
if err != nil {
return nil, err
}
req.WithContext(ctx)
req.Header.Set("User-Agent", "writefreely")
req.Header.Set("Accept", "application/json")
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.SetBasicAuth(c.ClientID, c.ClientSecret)
resp, err := c.HttpClient.Do(req)
if err != nil {
return nil, err
}
if resp.StatusCode != http.StatusOK {
return nil, errors.New("unable to exchange code for access token")
}
var tokenResponse TokenResponse
if err := limitedJsonUnmarshal(resp.Body, tokenRequestMaxLen, &tokenResponse); err != nil {
return nil, err
}
if tokenResponse.Error != "" {
return nil, errors.New(tokenResponse.Error)
}
return &tokenResponse, nil
}
func (c writeAsOauthClient) inspectOauthAccessToken(ctx context.Context, accessToken string) (*InspectResponse, error) {
req, err := http.NewRequest("GET", c.InspectLocation, nil)
if err != nil {
return nil, err
}
req.WithContext(ctx)
req.Header.Set("User-Agent", "writefreely")
req.Header.Set("Accept", "application/json")
req.Header.Set("Authorization", "Bearer "+accessToken)
resp, err := c.HttpClient.Do(req)
if err != nil {
return nil, err
}
if resp.StatusCode != http.StatusOK {
return nil, errors.New("unable to inspect access token")
}
var inspectResponse InspectResponse
if err := limitedJsonUnmarshal(resp.Body, infoRequestMaxLen, &inspectResponse); err != nil {
return nil, err
}
if inspectResponse.Error != "" {
return nil, errors.New(inspectResponse.Error)
}
return &inspectResponse, nil
}

View File

@ -70,6 +70,9 @@ func InitRoutes(apper Apper, r *mux.Router) *mux.Router {
write.HandleFunc(nodeinfo.NodeInfoPath, handler.LogHandlerFunc(http.HandlerFunc(ni.NodeInfoDiscover))) write.HandleFunc(nodeinfo.NodeInfoPath, handler.LogHandlerFunc(http.HandlerFunc(ni.NodeInfoDiscover)))
write.HandleFunc(niCfg.InfoURL, handler.LogHandlerFunc(http.HandlerFunc(ni.NodeInfo))) write.HandleFunc(niCfg.InfoURL, handler.LogHandlerFunc(http.HandlerFunc(ni.NodeInfo)))
configureSlackOauth(handler, write, apper.App())
configureWriteAsOauth(handler, write, apper.App())
// Set up dyamic page handlers // Set up dyamic page handlers
// Handle auth // Handle auth
auth := write.PathPrefix("/api/auth/").Subrouter() auth := write.PathPrefix("/api/auth/").Subrouter()
@ -80,16 +83,6 @@ func InitRoutes(apper Apper, r *mux.Router) *mux.Router {
auth.HandleFunc("/read", handler.WebErrors(handleWebCollectionUnlock, UserLevelNone)).Methods("POST") auth.HandleFunc("/read", handler.WebErrors(handleWebCollectionUnlock, UserLevelNone)).Methods("POST")
auth.HandleFunc("/me", handler.All(handleAPILogout)).Methods("DELETE") auth.HandleFunc("/me", handler.All(handleAPILogout)).Methods("DELETE")
oauthHandler := oauthHandler{
HttpClient: &http.Client{},
Config: apper.App().Config(),
DB: apper.App().DB(),
Store: apper.App().SessionStore(),
}
write.HandleFunc("/oauth/write.as", handler.OAuth(oauthHandler.viewOauthInit)).Methods("GET")
write.HandleFunc("/oauth/callback", handler.OAuth(oauthHandler.viewOauthCallback)).Methods("GET")
// Handle logged in user sections // Handle logged in user sections
me := write.PathPrefix("/me").Subrouter() me := write.PathPrefix("/me").Subrouter()
me.HandleFunc("/", handler.Redirect("/me", UserLevelUser)) me.HandleFunc("/", handler.Redirect("/me", UserLevelUser))
@ -191,6 +184,7 @@ func InitRoutes(apper Apper, r *mux.Router) *mux.Router {
} }
write.HandleFunc(draftEditPrefix+"/{post}", handler.Web(handleViewPost, UserLevelOptional)) write.HandleFunc(draftEditPrefix+"/{post}", handler.Web(handleViewPost, UserLevelOptional))
write.HandleFunc("/", handler.Web(handleViewHome, UserLevelOptional)) write.HandleFunc("/", handler.Web(handleViewHome, UserLevelOptional))
return r return r
} }