diff --git a/Makefile b/Makefile index 757bcfd..782e680 100644 --- a/Makefile +++ b/Makefile @@ -47,6 +47,12 @@ build-arm7: deps fi xgo --targets=linux/arm-7, -dest build/ $(LDFLAGS) -tags='sqlite' -out writefreely ./cmd/writefreely +build-arm64: deps + @hash xgo > /dev/null 2>&1; if [ $$? -ne 0 ]; then \ + $(GOGET) -u github.com/karalabe/xgo; \ + fi + xgo --targets=linux/arm64, -dest build/ $(LDFLAGS) -tags='sqlite' -out writefreely ./cmd/writefreely + build-docker : $(DOCKERCMD) build -t $(IMAGE_NAME):latest -t $(IMAGE_NAME):$(GITREV) . @@ -83,6 +89,10 @@ release : clean ui assets mv build/$(BINARY_NAME)-linux-arm-7 $(BUILDPATH)/$(BINARY_NAME) tar -cvzf $(BINARY_NAME)_$(GITREV)_linux_arm7.tar.gz -C build $(BINARY_NAME) rm $(BUILDPATH)/$(BINARY_NAME) + $(MAKE) build-arm64 + mv build/$(BINARY_NAME)-linux-arm64 $(BUILDPATH)/$(BINARY_NAME) + tar -cvzf $(BINARY_NAME)_$(GITREV)_linux_arm64.tar.gz -C build $(BINARY_NAME) + rm $(BUILDPATH)/$(BINARY_NAME) $(MAKE) build-darwin mv build/$(BINARY_NAME)-darwin-10.6-amd64 $(BUILDPATH)/$(BINARY_NAME) tar -cvzf $(BINARY_NAME)_$(GITREV)_macos_amd64.tar.gz -C build $(BINARY_NAME) diff --git a/account.go b/account.go index c41f24d..6fb8053 100644 --- a/account.go +++ b/account.go @@ -156,17 +156,9 @@ func signupWithRegistration(app *App, signup userRegistration, w http.ResponseWr Username: signup.Alias, HashedPass: hashedPass, HasPass: createdWithPass, - Email: zero.NewString("", signup.Email != ""), + Email: prepareUserEmail(signup.Email, app.keys.EmailKey), Created: time.Now().Truncate(time.Second).UTC(), } - if signup.Email != "" { - encEmail, err := data.Encrypt(app.keys.EmailKey, signup.Email) - if err != nil { - log.Error("Unable to encrypt email: %s\n", err) - } else { - u.Email.String = string(encEmail) - } - } // Create actual user if err := app.db.CreateUser(app.cfg, u, desiredUsername); err != nil { @@ -1097,3 +1089,16 @@ func getTempInfo(app *App, key string, r *http.Request, w http.ResponseWriter) s // Return value return s } + +func prepareUserEmail(input string, emailKey []byte) zero.String { + email := zero.NewString("", input != "") + if len(input) > 0 { + encEmail, err := data.Encrypt(emailKey, input) + if err != nil { + log.Error("Unable to encrypt email: %s\n", err) + } else { + email.String = string(encEmail) + } + } + return email +} diff --git a/admin.go b/admin.go index ebb4225..0a73a11 100644 --- a/admin.go +++ b/admin.go @@ -260,7 +260,7 @@ func handleAdminToggleUserStatus(app *App, u *User, w http.ResponseWriter, r *ht } if err != nil { log.Error("toggle user suspended: %v", err) - return impart.HTTPError{http.StatusInternalServerError, fmt.Sprintf("Could not toggle user status: %v")} + return impart.HTTPError{http.StatusInternalServerError, fmt.Sprintf("Could not toggle user status: %v", err)} } return impart.HTTPError{http.StatusFound, fmt.Sprintf("/admin/user/%s#status", username)} } diff --git a/app.go b/app.go index 018ce37..d465a3e 100644 --- a/app.go +++ b/app.go @@ -56,7 +56,7 @@ var ( debugging bool // Software version can be set from git env using -ldflags - softwareVer = "0.11.1" + softwareVer = "0.11.2" // DEPRECATED VARS isSingleUser bool @@ -70,7 +70,7 @@ type App struct { cfg *config.Config cfgFile string keys *key.Keychain - sessionStore *sessions.CookieStore + sessionStore sessions.Store formDecoder *schema.Decoder timeline *localTimeline @@ -101,6 +101,14 @@ func (app *App) SetKeys(k *key.Keychain) { app.keys = k } +func (app *App) SessionStore() sessions.Store { + return app.sessionStore +} + +func (app *App) SetSessionStore(s sessions.Store) { + app.sessionStore = s +} + // Apper is the interface for getting data into and out of a WriteFreely // instance (or "App"). // diff --git a/collections.go b/collections.go index b85f0a4..66ad7a0 100644 --- a/collections.go +++ b/collections.go @@ -648,6 +648,16 @@ func processCollectionPermissions(app *App, cr *collectionReq, u *User, w http.R uname = u.Username } + // TODO: move this to all permission checks? + suspended, err := app.db.IsUserSuspended(c.OwnerID) + if err != nil { + log.Error("process protected collection permissions: %v", err) + return nil, err + } + if suspended { + return nil, ErrCollectionNotFound + } + // See if we've authorized this collection authd := isAuthorizedForCollection(app, c.Alias, r) diff --git a/config/config.go b/config/config.go index 84bae86..996c1df 100644 --- a/config/config.go +++ b/config/config.go @@ -56,6 +56,20 @@ type ( Port int `ini:"port"` } + WriteAsOauthCfg struct { + ClientID string `ini:"client_id"` + ClientSecret string `ini:"client_secret"` + AuthLocation string `ini:"auth_location"` + TokenLocation string `ini:"token_location"` + InspectLocation string `ini:"inspect_location"` + } + + SlackOauthCfg struct { + ClientID string `ini:"client_id"` + ClientSecret string `ini:"client_secret"` + TeamID string `ini:"team_id"` + } + // AppCfg holds values that affect how the application functions AppCfg struct { SiteName string `ini:"site_name"` @@ -98,9 +112,11 @@ type ( // Config holds the complete configuration for running a writefreely instance Config struct { - Server ServerCfg `ini:"server"` - Database DatabaseCfg `ini:"database"` - App AppCfg `ini:"app"` + Server ServerCfg `ini:"server"` + Database DatabaseCfg `ini:"database"` + App AppCfg `ini:"app"` + SlackOauth SlackOauthCfg `ini:"oauth.slack"` + WriteAsOauth WriteAsOauthCfg `ini:"oauth.writeas"` } ) diff --git a/config/funcs.go b/config/funcs.go index a9c82ce..9678df0 100644 --- a/config/funcs.go +++ b/config/funcs.go @@ -11,7 +11,9 @@ package config import ( + "net/http" "strings" + "time" ) // FriendlyHost returns the app's Host sans any schema @@ -25,3 +27,16 @@ func (ac AppCfg) CanCreateBlogs(currentlyUsed uint64) bool { } return int(currentlyUsed) < ac.MaxBlogs } + +// OrDefaultString returns input or a default value if input is empty. +func OrDefaultString(input, defaultValue string) string { + if len(input) == 0 { + return defaultValue + } + return input +} + +// DefaultHTTPClient returns a sane default HTTP client. +func DefaultHTTPClient() *http.Client { + return &http.Client{Timeout: 10 * time.Second} +} diff --git a/database.go b/database.go index d78d888..ef52d84 100644 --- a/database.go +++ b/database.go @@ -11,8 +11,10 @@ package writefreely import ( + "context" "database/sql" "fmt" + wf_db "github.com/writeas/writefreely/db" "net/http" "strings" "time" @@ -124,6 +126,11 @@ type writestore interface { GetUserLastPostTime(id int64) (*time.Time, error) GetCollectionLastPostTime(id int64) (*time.Time, error) + GetIDForRemoteUser(context.Context, string, string, string) (int64, error) + RecordRemoteUserID(context.Context, int64, string, string, string, string) error + ValidateOAuthState(context.Context, string) (string, string, error) + GenerateOAuthState(context.Context, string, string) (string, error) + DatabaseInitialized() bool } @@ -132,6 +139,8 @@ type datastore struct { driverName string } +var _ writestore = &datastore{} + func (db *datastore) now() string { if db.driverName == driverSQLite { return "strftime('%Y-%m-%d %H:%M:%S','now')" @@ -2453,6 +2462,69 @@ func (db *datastore) GetCollectionLastPostTime(id int64) (*time.Time, error) { return &t, nil } +func (db *datastore) GenerateOAuthState(ctx context.Context, provider, clientID string) (string, error) { + state := store.Generate62RandomString(24) + _, err := db.ExecContext(ctx, "INSERT INTO oauth_client_states (state, provider, client_id, used, created_at) VALUES (?, ?, ?, FALSE, NOW())", state, provider, clientID) + if err != nil { + return "", fmt.Errorf("unable to record oauth client state: %w", err) + } + return state, nil +} + +func (db *datastore) ValidateOAuthState(ctx context.Context, state string) (string, string, error) { + var provider string + var clientID string + err := wf_db.RunTransactionWithOptions(ctx, db.DB, &sql.TxOptions{}, func(ctx context.Context, tx *sql.Tx) error { + err := tx.QueryRow("SELECT provider, client_id FROM oauth_client_states WHERE state = ? AND used = FALSE", state).Scan(&provider, &clientID) + if err != nil { + return err + } + + res, err := tx.ExecContext(ctx, "UPDATE oauth_client_states SET used = TRUE WHERE state = ?", state) + if err != nil { + return err + } + rowsAffected, err := res.RowsAffected() + if err != nil { + return err + } + if rowsAffected != 1 { + return fmt.Errorf("state not found") + } + return nil + }) + if err != nil { + return "", "", nil + } + return provider, clientID, nil +} + +func (db *datastore) RecordRemoteUserID(ctx context.Context, localUserID int64, remoteUserID, provider, clientID, accessToken string) error { + var err error + if db.driverName == driverSQLite { + _, err = db.ExecContext(ctx, "INSERT OR REPLACE INTO oauth_users (user_id, remote_user_id, provider, client_id, access_token) VALUES (?, ?, ?, ?, ?)", localUserID, remoteUserID, provider, clientID, accessToken) + } else { + _, err = db.ExecContext(ctx, "INSERT INTO oauth_users (user_id, remote_user_id, provider, client_id, access_token) VALUES (?, ?, ?, ?, ?) "+db.upsert("user")+" access_token = ?", localUserID, remoteUserID, provider, clientID, accessToken, accessToken) + } + if err != nil { + log.Error("Unable to INSERT oauth_users for '%d': %v", localUserID, err) + } + return err +} + +// GetIDForRemoteUser returns a user ID associated with a remote user ID. +func (db *datastore) GetIDForRemoteUser(ctx context.Context, remoteUserID, provider, clientID string) (int64, error) { + var userID int64 = -1 + err := db. + QueryRowContext(ctx, "SELECT user_id FROM oauth_users WHERE remote_user_id = ? AND provider = ? AND client_id = ?", remoteUserID, provider, clientID). + Scan(&userID) + // Not finding a record is OK. + if err != nil && err != sql.ErrNoRows { + return -1, err + } + return userID, nil +} + // DatabaseInitialized returns whether or not the current datastore has been // initialized with the correct schema. // Currently, it checks to see if the `users` table exists. diff --git a/database_test.go b/database_test.go new file mode 100644 index 0000000..c4c586a --- /dev/null +++ b/database_test.go @@ -0,0 +1,50 @@ +package writefreely + +import ( + "context" + "database/sql" + "github.com/stretchr/testify/assert" + "testing" +) + +func TestOAuthDatastore(t *testing.T) { + if !runMySQLTests() { + t.Skip("skipping mysql tests") + } + withTestDB(t, func(db *sql.DB) { + ctx := context.Background() + ds := &datastore{ + DB: db, + driverName: "", + } + + state, err := ds.GenerateOAuthState(ctx, "test", "development") + assert.NoError(t, err) + assert.Len(t, state, 24) + + countRows(t, ctx, db, 1, "SELECT COUNT(*) FROM `oauth_client_states` WHERE `state` = ? AND `used` = false", state) + + _, _, err = ds.ValidateOAuthState(ctx, state) + assert.NoError(t, err) + + countRows(t, ctx, db, 1, "SELECT COUNT(*) FROM `oauth_client_states` WHERE `state` = ? AND `used` = true", state) + + var localUserID int64 = 99 + var remoteUserID = "100" + err = ds.RecordRemoteUserID(ctx, localUserID, remoteUserID, "test", "test", "access_token_a") + assert.NoError(t, err) + + countRows(t, ctx, db, 1, "SELECT COUNT(*) FROM `oauth_users` WHERE `user_id` = ? AND `remote_user_id` = ? AND access_token = 'access_token_a'", localUserID, remoteUserID) + + err = ds.RecordRemoteUserID(ctx, localUserID, remoteUserID, "test", "test", "access_token_b") + assert.NoError(t, err) + + countRows(t, ctx, db, 1, "SELECT COUNT(*) FROM `oauth_users` WHERE `user_id` = ? AND `remote_user_id` = ? AND access_token = 'access_token_b'", localUserID, remoteUserID) + + countRows(t, ctx, db, 1, "SELECT COUNT(*) FROM `oauth_users`") + + foundUserID, err := ds.GetIDForRemoteUser(ctx, remoteUserID, "test", "test") + assert.NoError(t, err) + assert.Equal(t, localUserID, foundUserID) + }) +} diff --git a/db/alter.go b/db/alter.go new file mode 100644 index 0000000..0a4ffdd --- /dev/null +++ b/db/alter.go @@ -0,0 +1,52 @@ +package db + +import ( + "fmt" + "strings" +) + +type AlterTableSqlBuilder struct { + Dialect DialectType + Name string + Changes []string +} + +func (b *AlterTableSqlBuilder) AddColumn(col *Column) *AlterTableSqlBuilder { + if colVal, err := col.String(); err == nil { + b.Changes = append(b.Changes, fmt.Sprintf("ADD COLUMN %s", colVal)) + } + return b +} + +func (b *AlterTableSqlBuilder) ChangeColumn(name string, col *Column) *AlterTableSqlBuilder { + if colVal, err := col.String(); err == nil { + b.Changes = append(b.Changes, fmt.Sprintf("CHANGE COLUMN %s %s", name, colVal)) + } + return b +} + +func (b *AlterTableSqlBuilder) AddUniqueConstraint(name string, columns ...string) *AlterTableSqlBuilder { + b.Changes = append(b.Changes, fmt.Sprintf("ADD CONSTRAINT %s UNIQUE (%s)", name, strings.Join(columns, ", "))) + return b +} + +func (b *AlterTableSqlBuilder) ToSQL() (string, error) { + var str strings.Builder + + str.WriteString("ALTER TABLE ") + str.WriteString(b.Name) + str.WriteString(" ") + + if len(b.Changes) == 0 { + return "", fmt.Errorf("no changes provide for table: %s", b.Name) + } + changeCount := len(b.Changes) + for i, thing := range b.Changes { + str.WriteString(thing) + if i < changeCount-1 { + str.WriteString(", ") + } + } + + return str.String(), nil +} diff --git a/db/alter_test.go b/db/alter_test.go new file mode 100644 index 0000000..4bd58ac --- /dev/null +++ b/db/alter_test.go @@ -0,0 +1,56 @@ +package db + +import "testing" + +func TestAlterTableSqlBuilder_ToSQL(t *testing.T) { + type fields struct { + Dialect DialectType + Name string + Changes []string + } + tests := []struct { + name string + builder *AlterTableSqlBuilder + want string + wantErr bool + }{ + { + name: "MySQL add int", + builder: DialectMySQL. + AlterTable("the_table"). + AddColumn(DialectMySQL.Column("the_col", ColumnTypeInteger, UnsetSize)), + want: "ALTER TABLE the_table ADD COLUMN the_col INT NOT NULL", + wantErr: false, + }, + { + name: "MySQL add string", + builder: DialectMySQL. + AlterTable("the_table"). + AddColumn(DialectMySQL.Column("the_col", ColumnTypeVarChar, OptionalInt{true, 128})), + want: "ALTER TABLE the_table ADD COLUMN the_col VARCHAR(128) NOT NULL", + wantErr: false, + }, + + { + name: "MySQL add int and string", + builder: DialectMySQL. + AlterTable("the_table"). + AddColumn(DialectMySQL.Column("first_col", ColumnTypeInteger, UnsetSize)). + AddColumn(DialectMySQL.Column("second_col", ColumnTypeVarChar, OptionalInt{true, 128})), + want: "ALTER TABLE the_table ADD COLUMN first_col INT NOT NULL, ADD COLUMN second_col VARCHAR(128) NOT NULL", + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.builder.ToSQL() + if (err != nil) != tt.wantErr { + t.Errorf("ToSQL() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("ToSQL() got = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/db/create.go b/db/create.go new file mode 100644 index 0000000..c384778 --- /dev/null +++ b/db/create.go @@ -0,0 +1,244 @@ +package db + +import ( + "fmt" + "strings" +) + +type ColumnType int + +type OptionalInt struct { + Set bool + Value int +} + +type OptionalString struct { + Set bool + Value string +} + +type SQLBuilder interface { + ToSQL() (string, error) +} + +type Column struct { + Dialect DialectType + Name string + Nullable bool + Default OptionalString + Type ColumnType + Size OptionalInt + PrimaryKey bool +} + +type CreateTableSqlBuilder struct { + Dialect DialectType + Name string + IfNotExists bool + ColumnOrder []string + Columns map[string]*Column + Constraints []string +} + +const ( + ColumnTypeBool ColumnType = iota + ColumnTypeSmallInt ColumnType = iota + ColumnTypeInteger ColumnType = iota + ColumnTypeChar ColumnType = iota + ColumnTypeVarChar ColumnType = iota + ColumnTypeText ColumnType = iota + ColumnTypeDateTime ColumnType = iota +) + +var _ SQLBuilder = &CreateTableSqlBuilder{} + +var UnsetSize OptionalInt = OptionalInt{Set: false, Value: 0} +var UnsetDefault OptionalString = OptionalString{Set: false, Value: ""} + +func (d ColumnType) Format(dialect DialectType, size OptionalInt) (string, error) { + if dialect != DialectMySQL && dialect != DialectSQLite { + return "", fmt.Errorf("unsupported column type %d for dialect %d and size %v", d, dialect, size) + } + switch d { + case ColumnTypeSmallInt: + { + if dialect == DialectSQLite { + return "INTEGER", nil + } + mod := "" + if size.Set { + mod = fmt.Sprintf("(%d)", size.Value) + } + return "SMALLINT" + mod, nil + } + case ColumnTypeInteger: + { + if dialect == DialectSQLite { + return "INTEGER", nil + } + mod := "" + if size.Set { + mod = fmt.Sprintf("(%d)", size.Value) + } + return "INT" + mod, nil + } + case ColumnTypeChar: + { + if dialect == DialectSQLite { + return "TEXT", nil + } + mod := "" + if size.Set { + mod = fmt.Sprintf("(%d)", size.Value) + } + return "CHAR" + mod, nil + } + case ColumnTypeVarChar: + { + if dialect == DialectSQLite { + return "TEXT", nil + } + mod := "" + if size.Set { + mod = fmt.Sprintf("(%d)", size.Value) + } + return "VARCHAR" + mod, nil + } + case ColumnTypeBool: + { + if dialect == DialectSQLite { + return "INTEGER", nil + } + return "TINYINT(1)", nil + } + case ColumnTypeDateTime: + return "DATETIME", nil + case ColumnTypeText: + return "TEXT", nil + } + return "", fmt.Errorf("unsupported column type %d for dialect %d and size %v", d, dialect, size) +} + +func (c *Column) SetName(name string) *Column { + c.Name = name + return c +} + +func (c *Column) SetNullable(nullable bool) *Column { + c.Nullable = nullable + return c +} + +func (c *Column) SetPrimaryKey(pk bool) *Column { + c.PrimaryKey = pk + return c +} + +func (c *Column) SetDefault(value string) *Column { + c.Default = OptionalString{Set: true, Value: value} + return c +} + +func (c *Column) SetType(t ColumnType) *Column { + c.Type = t + return c +} + +func (c *Column) SetSize(size int) *Column { + c.Size = OptionalInt{Set: true, Value: size} + return c +} + +func (c *Column) String() (string, error) { + var str strings.Builder + + str.WriteString(c.Name) + + str.WriteString(" ") + typeStr, err := c.Type.Format(c.Dialect, c.Size) + if err != nil { + return "", err + } + + str.WriteString(typeStr) + + if !c.Nullable { + str.WriteString(" NOT NULL") + } + + if c.Default.Set { + str.WriteString(" DEFAULT ") + str.WriteString(c.Default.Value) + } + + if c.PrimaryKey { + str.WriteString(" PRIMARY KEY") + } + + return str.String(), nil +} + +func (b *CreateTableSqlBuilder) Column(column *Column) *CreateTableSqlBuilder { + if b.Columns == nil { + b.Columns = make(map[string]*Column) + } + b.Columns[column.Name] = column + b.ColumnOrder = append(b.ColumnOrder, column.Name) + return b +} + +func (b *CreateTableSqlBuilder) UniqueConstraint(columns ...string) *CreateTableSqlBuilder { + for _, column := range columns { + if _, ok := b.Columns[column]; !ok { + // This fails silently. + return b + } + } + b.Constraints = append(b.Constraints, fmt.Sprintf("UNIQUE(%s)", strings.Join(columns, ","))) + return b +} + +func (b *CreateTableSqlBuilder) SetIfNotExists(ine bool) *CreateTableSqlBuilder { + b.IfNotExists = ine + return b +} + +func (b *CreateTableSqlBuilder) ToSQL() (string, error) { + var str strings.Builder + + str.WriteString("CREATE TABLE ") + if b.IfNotExists { + str.WriteString("IF NOT EXISTS ") + } + str.WriteString(b.Name) + + var things []string + for _, columnName := range b.ColumnOrder { + column, ok := b.Columns[columnName] + if !ok { + return "", fmt.Errorf("column not found: %s", columnName) + } + columnStr, err := column.String() + if err != nil { + return "", err + } + things = append(things, columnStr) + } + for _, constraint := range b.Constraints { + things = append(things, constraint) + } + + if thingLen := len(things); thingLen > 0 { + str.WriteString(" ( ") + for i, thing := range things { + str.WriteString(thing) + if i < thingLen-1 { + str.WriteString(", ") + } + } + str.WriteString(" )") + } + + return str.String(), nil +} + diff --git a/db/create_test.go b/db/create_test.go new file mode 100644 index 0000000..369d5c1 --- /dev/null +++ b/db/create_test.go @@ -0,0 +1,146 @@ +package db + +import ( + "github.com/stretchr/testify/assert" + "testing" +) + +func TestDialect_Column(t *testing.T) { + c1 := DialectSQLite.Column("foo", ColumnTypeBool, UnsetSize) + assert.Equal(t, DialectSQLite, c1.Dialect) + c2 := DialectMySQL.Column("foo", ColumnTypeBool, UnsetSize) + assert.Equal(t, DialectMySQL, c2.Dialect) +} + +func TestColumnType_Format(t *testing.T) { + type args struct { + dialect DialectType + size OptionalInt + } + tests := []struct { + name string + d ColumnType + args args + want string + wantErr bool + }{ + {"Sqlite bool", ColumnTypeBool, args{dialect: DialectSQLite}, "INTEGER", false}, + {"Sqlite small int", ColumnTypeSmallInt, args{dialect: DialectSQLite}, "INTEGER", false}, + {"Sqlite int", ColumnTypeInteger, args{dialect: DialectSQLite}, "INTEGER", false}, + {"Sqlite char", ColumnTypeChar, args{dialect: DialectSQLite}, "TEXT", false}, + {"Sqlite varchar", ColumnTypeVarChar, args{dialect: DialectSQLite}, "TEXT", false}, + {"Sqlite text", ColumnTypeText, args{dialect: DialectSQLite}, "TEXT", false}, + {"Sqlite datetime", ColumnTypeDateTime, args{dialect: DialectSQLite}, "DATETIME", false}, + + {"MySQL bool", ColumnTypeBool, args{dialect: DialectMySQL}, "TINYINT(1)", false}, + {"MySQL small int", ColumnTypeSmallInt, args{dialect: DialectMySQL}, "SMALLINT", false}, + {"MySQL small int with param", ColumnTypeSmallInt, args{dialect: DialectMySQL, size: OptionalInt{true, 3}}, "SMALLINT(3)", false}, + {"MySQL int", ColumnTypeInteger, args{dialect: DialectMySQL}, "INT", false}, + {"MySQL int with param", ColumnTypeInteger, args{dialect: DialectMySQL, size: OptionalInt{true, 11}}, "INT(11)", false}, + {"MySQL char", ColumnTypeChar, args{dialect: DialectMySQL}, "CHAR", false}, + {"MySQL char with param", ColumnTypeChar, args{dialect: DialectMySQL, size: OptionalInt{true, 4}}, "CHAR(4)", false}, + {"MySQL varchar", ColumnTypeVarChar, args{dialect: DialectMySQL}, "VARCHAR", false}, + {"MySQL varchar with param", ColumnTypeVarChar, args{dialect: DialectMySQL, size: OptionalInt{true, 25}}, "VARCHAR(25)", false}, + {"MySQL text", ColumnTypeText, args{dialect: DialectMySQL}, "TEXT", false}, + {"MySQL datetime", ColumnTypeDateTime, args{dialect: DialectMySQL}, "DATETIME", false}, + + {"invalid column type", 10000, args{dialect: DialectMySQL}, "", true}, + {"invalid dialect", ColumnTypeBool, args{dialect: 10000}, "", true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.d.Format(tt.args.dialect, tt.args.size) + if (err != nil) != tt.wantErr { + t.Errorf("Format() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("Format() got = %v, want %v", got, tt.want) + } + }) + } +} + +func TestColumn_Build(t *testing.T) { + type fields struct { + Dialect DialectType + Name string + Nullable bool + Default OptionalString + Type ColumnType + Size OptionalInt + PrimaryKey bool + } + tests := []struct { + name string + fields fields + want string + wantErr bool + }{ + {"Sqlite bool", fields{DialectSQLite, "foo", false, UnsetDefault, ColumnTypeBool, UnsetSize, false}, "foo INTEGER NOT NULL", false}, + {"Sqlite bool nullable", fields{DialectSQLite, "foo", true, UnsetDefault, ColumnTypeBool, UnsetSize, false}, "foo INTEGER", false}, + {"Sqlite small int", fields{DialectSQLite, "foo", false, UnsetDefault, ColumnTypeSmallInt, UnsetSize, true}, "foo INTEGER NOT NULL PRIMARY KEY", false}, + {"Sqlite small int nullable", fields{DialectSQLite, "foo", true, UnsetDefault, ColumnTypeSmallInt, UnsetSize, false}, "foo INTEGER", false}, + {"Sqlite int", fields{DialectSQLite, "foo", false, UnsetDefault, ColumnTypeInteger, UnsetSize, false}, "foo INTEGER NOT NULL", false}, + {"Sqlite int nullable", fields{DialectSQLite, "foo", true, UnsetDefault, ColumnTypeInteger, UnsetSize, false}, "foo INTEGER", false}, + {"Sqlite char", fields{DialectSQLite, "foo", false, UnsetDefault, ColumnTypeChar, UnsetSize, false}, "foo TEXT NOT NULL", false}, + {"Sqlite char nullable", fields{DialectSQLite, "foo", true, UnsetDefault, ColumnTypeChar, UnsetSize, false}, "foo TEXT", false}, + {"Sqlite varchar", fields{DialectSQLite, "foo", false, UnsetDefault, ColumnTypeVarChar, UnsetSize, false}, "foo TEXT NOT NULL", false}, + {"Sqlite varchar nullable", fields{DialectSQLite, "foo", true, UnsetDefault, ColumnTypeVarChar, UnsetSize, false}, "foo TEXT", false}, + {"Sqlite text", fields{DialectSQLite, "foo", false, UnsetDefault, ColumnTypeText, UnsetSize, false}, "foo TEXT NOT NULL", false}, + {"Sqlite text nullable", fields{DialectSQLite, "foo", true, UnsetDefault, ColumnTypeText, UnsetSize, false}, "foo TEXT", false}, + {"Sqlite datetime", fields{DialectSQLite, "foo", false, UnsetDefault, ColumnTypeDateTime, UnsetSize, false}, "foo DATETIME NOT NULL", false}, + {"Sqlite datetime nullable", fields{DialectSQLite, "foo", true, UnsetDefault, ColumnTypeDateTime, UnsetSize, false}, "foo DATETIME", false}, + + {"MySQL bool", fields{DialectMySQL, "foo", false, UnsetDefault, ColumnTypeBool, UnsetSize, false}, "foo TINYINT(1) NOT NULL", false}, + {"MySQL bool nullable", fields{DialectMySQL, "foo", true, UnsetDefault, ColumnTypeBool, UnsetSize, false}, "foo TINYINT(1)", false}, + {"MySQL small int", fields{DialectMySQL, "foo", false, UnsetDefault, ColumnTypeSmallInt, UnsetSize, true}, "foo SMALLINT NOT NULL PRIMARY KEY", false}, + {"MySQL small int nullable", fields{DialectMySQL, "foo", true, UnsetDefault, ColumnTypeSmallInt, UnsetSize, false}, "foo SMALLINT", false}, + {"MySQL int", fields{DialectMySQL, "foo", false, UnsetDefault, ColumnTypeInteger, UnsetSize, false}, "foo INT NOT NULL", false}, + {"MySQL int nullable", fields{DialectMySQL, "foo", true, UnsetDefault, ColumnTypeInteger, UnsetSize, false}, "foo INT", false}, + {"MySQL char", fields{DialectMySQL, "foo", false, UnsetDefault, ColumnTypeChar, UnsetSize, false}, "foo CHAR NOT NULL", false}, + {"MySQL char nullable", fields{DialectMySQL, "foo", true, UnsetDefault, ColumnTypeChar, UnsetSize, false}, "foo CHAR", false}, + {"MySQL varchar", fields{DialectMySQL, "foo", false, UnsetDefault, ColumnTypeVarChar, UnsetSize, false}, "foo VARCHAR NOT NULL", false}, + {"MySQL varchar nullable", fields{DialectMySQL, "foo", true, UnsetDefault, ColumnTypeVarChar, UnsetSize, false}, "foo VARCHAR", false}, + {"MySQL text", fields{DialectMySQL, "foo", false, UnsetDefault, ColumnTypeText, UnsetSize, false}, "foo TEXT NOT NULL", false}, + {"MySQL text nullable", fields{DialectMySQL, "foo", true, UnsetDefault, ColumnTypeText, UnsetSize, false}, "foo TEXT", false}, + {"MySQL datetime", fields{DialectMySQL, "foo", false, UnsetDefault, ColumnTypeDateTime, UnsetSize, false}, "foo DATETIME NOT NULL", false}, + {"MySQL datetime nullable", fields{DialectMySQL, "foo", true, UnsetDefault, ColumnTypeDateTime, UnsetSize, false}, "foo DATETIME", false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &Column{ + Dialect: tt.fields.Dialect, + Name: tt.fields.Name, + Nullable: tt.fields.Nullable, + Default: tt.fields.Default, + Type: tt.fields.Type, + Size: tt.fields.Size, + PrimaryKey: tt.fields.PrimaryKey, + } + if got, err := c.String(); got != tt.want { + if (err != nil) != tt.wantErr { + t.Errorf("String() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("String() got = %v, want %v", got, tt.want) + } + } + }) + } +} + +func TestCreateTableSqlBuilder_ToSQL(t *testing.T) { + sql, err := DialectMySQL. + Table("foo"). + SetIfNotExists(true). + Column(DialectMySQL.Column("bar", ColumnTypeInteger, UnsetSize).SetPrimaryKey(true)). + Column(DialectMySQL.Column("baz", ColumnTypeText, UnsetSize)). + Column(DialectMySQL.Column("qux", ColumnTypeDateTime, UnsetSize).SetDefault("NOW()")). + UniqueConstraint("bar"). + UniqueConstraint("bar", "baz"). + ToSQL() + assert.NoError(t, err) + assert.Equal(t, "CREATE TABLE IF NOT EXISTS foo ( bar INT NOT NULL PRIMARY KEY, baz TEXT NOT NULL, qux DATETIME NOT NULL DEFAULT NOW(), UNIQUE(bar), UNIQUE(bar,baz) )", sql) +} diff --git a/db/dialect.go b/db/dialect.go new file mode 100644 index 0000000..4251465 --- /dev/null +++ b/db/dialect.go @@ -0,0 +1,76 @@ +package db + +import "fmt" + +type DialectType int + +const ( + DialectSQLite DialectType = iota + DialectMySQL DialectType = iota +) + +func (d DialectType) Column(name string, t ColumnType, size OptionalInt) *Column { + switch d { + case DialectSQLite: + return &Column{Dialect: DialectSQLite, Name: name, Type: t, Size: size} + case DialectMySQL: + return &Column{Dialect: DialectMySQL, Name: name, Type: t, Size: size} + default: + panic(fmt.Sprintf("unexpected dialect: %d", d)) + } +} + +func (d DialectType) Table(name string) *CreateTableSqlBuilder { + switch d { + case DialectSQLite: + return &CreateTableSqlBuilder{Dialect: DialectSQLite, Name: name} + case DialectMySQL: + return &CreateTableSqlBuilder{Dialect: DialectMySQL, Name: name} + default: + panic(fmt.Sprintf("unexpected dialect: %d", d)) + } +} + +func (d DialectType) AlterTable(name string) *AlterTableSqlBuilder { + switch d { + case DialectSQLite: + return &AlterTableSqlBuilder{Dialect: DialectSQLite, Name: name} + case DialectMySQL: + return &AlterTableSqlBuilder{Dialect: DialectMySQL, Name: name} + default: + panic(fmt.Sprintf("unexpected dialect: %d", d)) + } +} + +func (d DialectType) CreateUniqueIndex(name, table string, columns ...string) *CreateIndexSqlBuilder { + switch d { + case DialectSQLite: + return &CreateIndexSqlBuilder{Dialect: DialectSQLite, Name: name, Table: table, Unique: true, Columns: columns} + case DialectMySQL: + return &CreateIndexSqlBuilder{Dialect: DialectMySQL, Name: name, Table: table, Unique: true, Columns: columns} + default: + panic(fmt.Sprintf("unexpected dialect: %d", d)) + } +} + +func (d DialectType) CreateIndex(name, table string, columns ...string) *CreateIndexSqlBuilder { + switch d { + case DialectSQLite: + return &CreateIndexSqlBuilder{Dialect: DialectSQLite, Name: name, Table: table, Unique: false, Columns: columns} + case DialectMySQL: + return &CreateIndexSqlBuilder{Dialect: DialectMySQL, Name: name, Table: table, Unique: false, Columns: columns} + default: + panic(fmt.Sprintf("unexpected dialect: %d", d)) + } +} + +func (d DialectType) DropIndex(name, table string) *DropIndexSqlBuilder { + switch d { + case DialectSQLite: + return &DropIndexSqlBuilder{Dialect: DialectSQLite, Name: name, Table: table} + case DialectMySQL: + return &DropIndexSqlBuilder{Dialect: DialectMySQL, Name: name, Table: table} + default: + panic(fmt.Sprintf("unexpected dialect: %d", d)) + } +} diff --git a/db/index.go b/db/index.go new file mode 100644 index 0000000..8180224 --- /dev/null +++ b/db/index.go @@ -0,0 +1,53 @@ +package db + +import ( + "fmt" + "strings" +) + +type CreateIndexSqlBuilder struct { + Dialect DialectType + Name string + Table string + Unique bool + Columns []string +} + +type DropIndexSqlBuilder struct { + Dialect DialectType + Name string + Table string +} + +func (b *CreateIndexSqlBuilder) ToSQL() (string, error) { + var str strings.Builder + + str.WriteString("CREATE ") + if b.Unique { + str.WriteString("UNIQUE ") + } + str.WriteString("INDEX ") + str.WriteString(b.Name) + str.WriteString(" on ") + str.WriteString(b.Table) + + if len(b.Columns) == 0 { + return "", fmt.Errorf("columns provided for this index: %s", b.Name) + } + + str.WriteString(" (") + columnCount := len(b.Columns) + for i, thing := range b.Columns { + str.WriteString(thing) + if i < columnCount-1 { + str.WriteString(", ") + } + } + str.WriteString(")") + + return str.String(), nil +} + +func (b *DropIndexSqlBuilder) ToSQL() (string, error) { + return fmt.Sprintf("DROP INDEX %s on %s", b.Name, b.Table), nil +} diff --git a/db/raw.go b/db/raw.go new file mode 100644 index 0000000..d0301c8 --- /dev/null +++ b/db/raw.go @@ -0,0 +1,9 @@ +package db + +type RawSqlBuilder struct { + Query string +} + +func (b *RawSqlBuilder) ToSQL() (string, error) { + return b.Query, nil +} diff --git a/db/tx.go b/db/tx.go new file mode 100644 index 0000000..5c321af --- /dev/null +++ b/db/tx.go @@ -0,0 +1,26 @@ +package db + +import ( + "context" + "database/sql" +) + +// TransactionScopedWork describes code executed within a database transaction. +type TransactionScopedWork func(ctx context.Context, db *sql.Tx) error + +// RunTransactionWithOptions executes a block of code within a database transaction. +func RunTransactionWithOptions(ctx context.Context, db *sql.DB, txOpts *sql.TxOptions, txWork TransactionScopedWork) error { + tx, err := db.BeginTx(ctx, txOpts) + if err != nil { + return err + } + + if err = txWork(ctx, tx); err != nil { + if txErr := tx.Rollback(); txErr != nil { + return txErr + } + return err + } + return tx.Commit() +} + diff --git a/go.mod b/go.mod index 339af45..f6aa8b7 100644 --- a/go.mod +++ b/go.mod @@ -21,6 +21,7 @@ require ( github.com/ikeikeikeike/go-sitemap-generator v1.0.1 github.com/ikeikeikeike/go-sitemap-generator/v2 v2.0.2 github.com/jtolds/gls v4.2.1+incompatible // indirect + github.com/kr/pretty v0.1.0 github.com/kylemcc/twitter-text-go v0.0.0-20180726194232-7f582f6736ec github.com/lunixbochs/vtclean v1.0.0 // indirect github.com/manifoldco/promptui v0.3.2 @@ -31,21 +32,19 @@ require ( github.com/nicksnyder/go-i18n v1.10.0 // indirect github.com/nu7hatch/gouuid v0.0.0-20131221200532-179d4d0c4d8d github.com/pelletier/go-toml v1.2.0 // indirect - github.com/pkg/errors v0.8.1 // indirect + github.com/pkg/errors v0.8.1 github.com/rainycape/unidecode v0.0.0-20150907023854-cb7f23ec59be // indirect - github.com/shurcooL/sanitized_anchor_name v1.0.0 // indirect github.com/smartystreets/assertions v0.0.0-20190116191733-b6c0e53d7304 // indirect github.com/smartystreets/goconvey v0.0.0-20181108003508-044398e4856c // indirect - github.com/stretchr/testify v1.3.0 // indirect + github.com/stretchr/testify v1.3.0 github.com/writeas/activity v0.1.2 github.com/writeas/go-strip-markdown v2.0.1+incompatible github.com/writeas/go-webfinger v0.0.0-20190106002315-85cf805c86d2 github.com/writeas/httpsig v1.0.0 - github.com/writeas/impart v1.1.0 + github.com/writeas/impart v1.1.1-0.20191230230525-d3c45ced010d github.com/writeas/import v0.2.0 github.com/writeas/monday v0.0.0-20181024183321-54a7dd579219 github.com/writeas/nerds v1.0.0 - github.com/writeas/openssl-go v1.0.0 // indirect github.com/writeas/saturday v1.7.1 github.com/writeas/slug v1.2.0 github.com/writeas/web-core v1.2.0 @@ -54,10 +53,11 @@ require ( golang.org/x/lint v0.0.0-20181217174547-8f45f776aaf1 // indirect golang.org/x/net v0.0.0-20190206173232-65e2d4e15006 // indirect golang.org/x/sys v0.0.0-20190209173611-3b5209105503 // indirect - golang.org/x/tools v0.0.0-20190208222737-3744606dbb67 // indirect + golang.org/x/tools v0.0.0-20190208222737-3744606dbb67 google.golang.org/appengine v1.4.0 // indirect gopkg.in/alecthomas/kingpin.v3-unstable v3.0.0-20180810215634-df19058c872c // indirect gopkg.in/ini.v1 v1.41.0 - gopkg.in/yaml.v1 v1.0.0-20140924161607-9f9df34309c0 // indirect gopkg.in/yaml.v2 v2.2.2 // indirect ) + +go 1.13 diff --git a/go.sum b/go.sum index b0a56ab..5b8b88a 100644 --- a/go.sum +++ b/go.sum @@ -135,6 +135,8 @@ github.com/writeas/httpsig v1.0.0 h1:peIAoIA3DmlP8IG8tMNZqI4YD1uEnWBmkcC9OFPjt3A github.com/writeas/httpsig v1.0.0/go.mod h1:7ClMGSrSVXJbmiLa17bZ1LrG1oibGZmUMlh3402flPY= github.com/writeas/impart v1.1.0 h1:nPnoO211VscNkp/gnzir5UwCDEvdHThL5uELU60NFSE= github.com/writeas/impart v1.1.0/go.mod h1:g0MpxdnTOHHrl+Ca/2oMXUHJ0PcRAEWtkCzYCJUXC9Y= +github.com/writeas/impart v1.1.1-0.20191230230525-d3c45ced010d h1:PK7DOj3JE6MGf647esPrKzXEHFjGWX2hl22uX79ixaE= +github.com/writeas/impart v1.1.1-0.20191230230525-d3c45ced010d/go.mod h1:g0MpxdnTOHHrl+Ca/2oMXUHJ0PcRAEWtkCzYCJUXC9Y= github.com/writeas/import v0.0.0-20190815214647-baae8acd8d06 h1:S6oKKP8GhSoyZUvVuhO9UiQ9f+U1aR/x5B4MP7YQHaU= github.com/writeas/import v0.0.0-20190815214647-baae8acd8d06/go.mod h1:f3K8z7YnJwKnPIT4h7980n9C6cQb4DIB2QcxVCTB7lE= github.com/writeas/import v0.0.0-20190815235139-628d10daaa9e h1:31PkvDTWkjzC1nGzWw9uAE92ZfcVyFX/K9L9ejQjnEs= @@ -149,6 +151,7 @@ github.com/writeas/nerds v1.0.0 h1:ZzRcCN+Sr3MWID7o/x1cr1ZbLvdpej9Y1/Ho+JKlqxo= github.com/writeas/nerds v1.0.0/go.mod h1:Gn2bHy1EwRcpXeB7ZhVmuUwiweK0e+JllNf66gvNLdU= github.com/writeas/openssl-go v1.0.0 h1:YXM1tDXeYOlTyJjoMlYLQH1xOloUimSR1WMF8kjFc5o= github.com/writeas/openssl-go v1.0.0/go.mod h1:WsKeK5jYl0B5y8ggOmtVjbmb+3rEGqSD25TppjJnETA= +github.com/writeas/saturday v1.6.0/go.mod h1:ETE1EK6ogxptJpAgUbcJD0prAtX48bSloie80+tvnzQ= github.com/writeas/saturday v1.7.1 h1:lYo1EH6CYyrFObQoA9RNWHVlpZA5iYL5Opxo7PYAnZE= github.com/writeas/saturday v1.7.1/go.mod h1:ETE1EK6ogxptJpAgUbcJD0prAtX48bSloie80+tvnzQ= github.com/writeas/slug v1.2.0 h1:EMQ+cwLiOcA6EtFwUgyw3Ge18x9uflUnOnR6bp/J+/g= @@ -161,6 +164,7 @@ github.com/writefreely/go-nodeinfo v1.2.0 h1:La+YbTCvmpTwFhBSlebWDDL81N88Qf/SCAv github.com/writefreely/go-nodeinfo v1.2.0/go.mod h1:UTvE78KpcjYOlRHupZIiSEFcXHioTXuacCbHU+CAcPg= golang.org/x/crypto v0.0.0-20180527072434-ab813273cd59 h1:hk3yo72LXLapY9EXVttc3Z1rLOxT9IuAPPX3GpY2+jo= golang.org/x/crypto v0.0.0-20180527072434-ab813273cd59/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= +golang.org/x/crypto v0.0.0-20190131182504-b8fe1690c613/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20190208162236-193df9c0f06f h1:ETU2VEl7TnT5bl7IvuKEzTDpplg5wzGYsOCAPhdoEIg= golang.org/x/crypto v0.0.0-20190208162236-193df9c0f06f/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= diff --git a/handle.go b/handle.go index 7e410f5..0fcc483 100644 --- a/handle.go +++ b/handle.go @@ -73,7 +73,7 @@ type ( type Handler struct { errors *ErrorPages - sessionStore *sessions.CookieStore + sessionStore sessions.Store app Apper } @@ -96,7 +96,7 @@ func NewHandler(apper Apper) *Handler { InternalServerError: template.Must(template.New("").Parse("{{define \"base\"}}500

Internal server error.

{{end}}")), Blank: template.Must(template.New("").Parse("{{define \"base\"}}{{.Title}}

{{.Content}}

{{end}}")), }, - sessionStore: apper.App().sessionStore, + sessionStore: apper.App().SessionStore(), app: apper, } @@ -549,6 +549,37 @@ func (h *Handler) All(f handlerFunc) http.HandlerFunc { } } +func (h *Handler) OAuth(f handlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + h.handleOAuthError(w, r, func() error { + // TODO: return correct "success" status + status := 200 + start := time.Now() + + defer func() { + if e := recover(); e != nil { + log.Error("%s:\n%s", e, debug.Stack()) + impart.WriteError(w, impart.HTTPError{http.StatusInternalServerError, "Something didn't work quite right."}) + status = 500 + } + + log.Info(h.app.ReqLog(r, status, time.Since(start))) + }() + + err := f(h.app.App(), w, r) + if err != nil { + if err, ok := err.(impart.HTTPError); ok { + status = err.Status + } else { + status = 500 + } + } + + return err + }()) + } +} + func (h *Handler) AllReader(f handlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { h.handleError(w, r, func() error { @@ -779,6 +810,25 @@ func (h *Handler) handleError(w http.ResponseWriter, r *http.Request, err error) h.errors.InternalServerError.ExecuteTemplate(w, "base", pageForReq(h.app.App(), r)) } +func (h *Handler) handleOAuthError(w http.ResponseWriter, r *http.Request, err error) { + if err == nil { + return + } + + if err, ok := err.(impart.HTTPError); ok { + if err.Status >= 300 && err.Status < 400 { + sendRedirect(w, err.Status, err.Message) + return + } + + impart.WriteOAuthError(w, err) + return + } + + impart.WriteOAuthError(w, impart.HTTPError{http.StatusInternalServerError, "This is an unhelpful error message for a miscellaneous internal error."}) + return +} + func correctPageFromLoginAttempt(r *http.Request) string { to := r.FormValue("to") if to == "" { diff --git a/main_test.go b/main_test.go new file mode 100644 index 0000000..d600d83 --- /dev/null +++ b/main_test.go @@ -0,0 +1,153 @@ +package writefreely + +import ( + "context" + "database/sql" + "encoding/gob" + "errors" + "fmt" + uuid "github.com/nu7hatch/gouuid" + "github.com/stretchr/testify/assert" + "math/rand" + "os" + "strings" + "testing" + "time" +) + +var testDB *sql.DB + +type ScopedTestBody func(*sql.DB) + +// TestMain provides testing infrastructure within this package. +func TestMain(m *testing.M) { + rand.Seed(time.Now().UTC().UnixNano()) + gob.Register(&User{}) + + if runMySQLTests() { + var err error + + testDB, err = initMySQL(os.Getenv("WF_USER"), os.Getenv("WF_PASSWORD"), os.Getenv("WF_DB"), os.Getenv("WF_HOST")) + if err != nil { + fmt.Println(err) + return + } + } + + code := m.Run() + if runMySQLTests() { + if closeErr := testDB.Close(); closeErr != nil { + fmt.Println(closeErr) + } + } + os.Exit(code) +} + +func runMySQLTests() bool { + return len(os.Getenv("TEST_MYSQL")) > 0 +} + +func initMySQL(dbUser, dbPassword, dbName, dbHost string) (*sql.DB, error) { + if dbUser == "" || dbPassword == "" { + return nil, errors.New("database user or password not set") + } + if dbHost == "" { + dbHost = "localhost" + } + if dbName == "" { + dbName = "writefreely" + } + + dsn := fmt.Sprintf("%s:%s@tcp(%s:3306)/%s?charset=utf8mb4&parseTime=true", dbUser, dbPassword, dbHost, dbName) + db, err := sql.Open("mysql", dsn) + if err != nil { + return nil, err + } + if err := ensureMySQL(db); err != nil { + return nil, err + } + return db, nil +} + +func ensureMySQL(db *sql.DB) error { + if err := db.Ping(); err != nil { + return err + } + db.SetMaxOpenConns(250) + return nil +} + +// withTestDB provides a scoped database connection. +func withTestDB(t *testing.T, testBody ScopedTestBody) { + db, cleanup, err := newTestDatabase(testDB, + os.Getenv("WF_USER"), + os.Getenv("WF_PASSWORD"), + os.Getenv("WF_DB"), + os.Getenv("WF_HOST"), + ) + assert.NoError(t, err) + defer func() { + assert.NoError(t, cleanup()) + }() + + testBody(db) +} + +// newTestDatabase creates a new temporary test database. When a test +// database connection is returned, it will have created a new database and +// initialized it with tables from a reference database. +func newTestDatabase(base *sql.DB, dbUser, dbPassword, dbName, dbHost string) (*sql.DB, func() error, error) { + var err error + var baseName = dbName + + if baseName == "" { + row := base.QueryRow("SELECT DATABASE()") + err := row.Scan(&baseName) + if err != nil { + return nil, nil, err + } + } + tUUID, _ := uuid.NewV4() + suffix := strings.Replace(tUUID.String(), "-", "_", -1) + newDBName := baseName + suffix + _, err = base.Exec("CREATE DATABASE " + newDBName) + if err != nil { + return nil, nil, err + } + newDB, err := initMySQL(dbUser, dbPassword, newDBName, dbHost) + if err != nil { + return nil, nil, err + } + + rows, err := base.Query("SHOW TABLES IN " + baseName) + if err != nil { + return nil, nil, err + } + for rows.Next() { + var tableName string + if err := rows.Scan(&tableName); err != nil { + return nil, nil, err + } + query := fmt.Sprintf("CREATE TABLE %s LIKE %s.%s", tableName, baseName, tableName) + if _, err := newDB.Exec(query); err != nil { + return nil, nil, err + } + } + + cleanup := func() error { + if closeErr := newDB.Close(); closeErr != nil { + fmt.Println(closeErr) + } + + _, err = base.Exec("DROP DATABASE " + newDBName) + return err + } + return newDB, cleanup, nil +} + +func countRows(t *testing.T, ctx context.Context, db *sql.DB, count int, query string, args ...interface{}) { + var returned int + err := db.QueryRowContext(ctx, query, args...).Scan(&returned) + assert.NoError(t, err, "error executing query %s and args %s", query, args) + assert.Equal(t, count, returned, "unexpected return count %d, expected %d from %s and args %s", returned, count, query, args) +} \ No newline at end of file diff --git a/migrations/migrations.go b/migrations/migrations.go index 0799f8e..917d912 100644 --- a/migrations/migrations.go +++ b/migrations/migrations.go @@ -59,6 +59,8 @@ var migrations = []Migration{ New("support user invites", supportUserInvites), // -> V1 (v0.8.0) New("support dynamic instance pages", supportInstancePages), // V1 -> V2 (v0.9.0) New("support users suspension", supportUserStatus), // V2 -> V3 (v0.11.0) + New("support oauth", oauth), // V3 -> V4 + New("support slack oauth", oauthSlack), // V4 -> v5 } // CurrentVer returns the current migration version the application is on diff --git a/migrations/v4.go b/migrations/v4.go new file mode 100644 index 0000000..c075dd8 --- /dev/null +++ b/migrations/v4.go @@ -0,0 +1,46 @@ +package migrations + +import ( + "context" + "database/sql" + + wf_db "github.com/writeas/writefreely/db" +) + +func oauth(db *datastore) error { + dialect := wf_db.DialectMySQL + if db.driverName == driverSQLite { + dialect = wf_db.DialectSQLite + } + return wf_db.RunTransactionWithOptions(context.Background(), db.DB, &sql.TxOptions{}, func(ctx context.Context, tx *sql.Tx) error { + createTableUsersOauth, err := dialect. + Table("oauth_users"). + SetIfNotExists(true). + Column(dialect.Column("user_id", wf_db.ColumnTypeInteger, wf_db.UnsetSize)). + Column(dialect.Column("remote_user_id", wf_db.ColumnTypeInteger, wf_db.UnsetSize)). + UniqueConstraint("user_id"). + UniqueConstraint("remote_user_id"). + ToSQL() + if err != nil { + return err + } + createTableOauthClientState, err := dialect. + Table("oauth_client_states"). + SetIfNotExists(true). + Column(dialect.Column("state", wf_db.ColumnTypeVarChar, wf_db.OptionalInt{Set: true, Value: 255})). + Column(dialect.Column("used", wf_db.ColumnTypeBool, wf_db.UnsetSize)). + Column(dialect.Column("created_at", wf_db.ColumnTypeDateTime, wf_db.UnsetSize).SetDefault("NOW()")). + UniqueConstraint("state"). + ToSQL() + if err != nil { + return err + } + + for _, table := range []string{createTableUsersOauth, createTableOauthClientState} { + if _, err := tx.ExecContext(ctx, table); err != nil { + return err + } + } + return nil + }) +} diff --git a/migrations/v5.go b/migrations/v5.go new file mode 100644 index 0000000..94e3944 --- /dev/null +++ b/migrations/v5.go @@ -0,0 +1,67 @@ +package migrations + +import ( + "context" + "database/sql" + + wf_db "github.com/writeas/writefreely/db" +) + +func oauthSlack(db *datastore) error { + dialect := wf_db.DialectMySQL + if db.driverName == driverSQLite { + dialect = wf_db.DialectSQLite + } + return wf_db.RunTransactionWithOptions(context.Background(), db.DB, &sql.TxOptions{}, func(ctx context.Context, tx *sql.Tx) error { + builders := []wf_db.SQLBuilder{ + dialect. + AlterTable("oauth_client_states"). + AddColumn(dialect. + Column( + "provider", + wf_db.ColumnTypeVarChar, + wf_db.OptionalInt{Set: true, Value: 24,})). + AddColumn(dialect. + Column( + "client_id", + wf_db.ColumnTypeVarChar, + wf_db.OptionalInt{Set: true, Value: 128,})), + dialect. + AlterTable("oauth_users"). + ChangeColumn("remote_user_id", + dialect. + Column( + "remote_user_id", + wf_db.ColumnTypeVarChar, + wf_db.OptionalInt{Set: true, Value: 128,})). + AddColumn(dialect. + Column( + "provider", + wf_db.ColumnTypeVarChar, + wf_db.OptionalInt{Set: true, Value: 24,})). + AddColumn(dialect. + Column( + "client_id", + wf_db.ColumnTypeVarChar, + wf_db.OptionalInt{Set: true, Value: 128,})). + AddColumn(dialect. + Column( + "access_token", + wf_db.ColumnTypeVarChar, + wf_db.OptionalInt{Set: true, Value: 512,})), + dialect.DropIndex("remote_user_id", "oauth_users"), + dialect.DropIndex("user_id", "oauth_users"), + dialect.CreateUniqueIndex("oauth_users", "oauth_users", "user_id", "provider", "client_id"), + } + for _, builder := range builders { + query, err := builder.ToSQL() + if err != nil { + return err + } + if _, err := tx.ExecContext(ctx, query); err != nil { + return err + } + } + return nil + }) +} diff --git a/oauth.go b/oauth.go new file mode 100644 index 0000000..4758e0f --- /dev/null +++ b/oauth.go @@ -0,0 +1,244 @@ +package writefreely + +import ( + "context" + "encoding/json" + "fmt" + "github.com/gorilla/mux" + "github.com/gorilla/sessions" + "github.com/writeas/impart" + "github.com/writeas/nerds/store" + "github.com/writeas/web-core/auth" + "github.com/writeas/web-core/log" + "github.com/writeas/writefreely/config" + "io" + "io/ioutil" + "net/http" + "time" +) + +// TokenResponse contains data returned when a token is created either +// through a code exchange or using a refresh token. +type TokenResponse struct { + AccessToken string `json:"access_token"` + ExpiresIn int `json:"expires_in"` + RefreshToken string `json:"refresh_token"` + TokenType string `json:"token_type"` + Error string `json:"error"` +} + +// InspectResponse contains data returned when an access token is inspected. +type InspectResponse struct { + ClientID string `json:"client_id"` + UserID string `json:"user_id"` + ExpiresAt time.Time `json:"expires_at"` + Username string `json:"username"` + DisplayName string `json:"-"` + Email string `json:"email"` + Error string `json:"error"` +} + +// tokenRequestMaxLen is the most bytes that we'll read from the /oauth/token +// endpoint. One megabyte is plenty. +const tokenRequestMaxLen = 1000000 + +// infoRequestMaxLen is the most bytes that we'll read from the +// /oauth/inspect endpoint. +const infoRequestMaxLen = 1000000 + +// OAuthDatastoreProvider provides a minimal interface of data store, config, +// and session store for use with the oauth handlers. +type OAuthDatastoreProvider interface { + DB() OAuthDatastore + Config() *config.Config + SessionStore() sessions.Store +} + +// OAuthDatastore provides a minimal interface of data store methods used in +// oauth functionality. +type OAuthDatastore interface { + GetIDForRemoteUser(context.Context, string, string, string) (int64, error) + RecordRemoteUserID(context.Context, int64, string, string, string, string) error + ValidateOAuthState(context.Context, string) (string, string, error) + GenerateOAuthState(context.Context, string, string) (string, error) + + CreateUser(*config.Config, *User, string) error + GetUserByID(int64) (*User, error) +} + +type HttpClient interface { + Do(req *http.Request) (*http.Response, error) +} + +type oauthClient interface { + GetProvider() string + GetClientID() string + buildLoginURL(state string) (string, error) + exchangeOauthCode(ctx context.Context, code string) (*TokenResponse, error) + inspectOauthAccessToken(ctx context.Context, accessToken string) (*InspectResponse, error) +} + +type oauthHandler struct { + Config *config.Config + DB OAuthDatastore + Store sessions.Store + EmailKey []byte + oauthClient oauthClient +} + +func (h oauthHandler) viewOauthInit(app *App, w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + state, err := h.DB.GenerateOAuthState(ctx, h.oauthClient.GetProvider(), h.oauthClient.GetClientID()) + if err != nil { + return impart.HTTPError{http.StatusInternalServerError, "could not prepare oauth redirect url"} + } + location, err := h.oauthClient.buildLoginURL(state) + if err != nil { + return impart.HTTPError{http.StatusInternalServerError, "could not prepare oauth redirect url"} + } + return impart.HTTPError{http.StatusTemporaryRedirect, location} +} + +func configureSlackOauth(parentHandler *Handler, r *mux.Router, app *App) { + if app.Config().SlackOauth.ClientID != "" { + oauthClient := slackOauthClient{ + ClientID: app.Config().SlackOauth.ClientID, + ClientSecret: app.Config().SlackOauth.ClientSecret, + TeamID: app.Config().SlackOauth.TeamID, + CallbackLocation: app.Config().App.Host + "/oauth/callback", + HttpClient: config.DefaultHTTPClient(), + } + configureOauthRoutes(parentHandler, r, app, oauthClient) + } +} + +func configureWriteAsOauth(parentHandler *Handler, r *mux.Router, app *App) { + if app.Config().WriteAsOauth.ClientID != "" { + oauthClient := writeAsOauthClient{ + ClientID: app.Config().WriteAsOauth.ClientID, + ClientSecret: app.Config().WriteAsOauth.ClientSecret, + ExchangeLocation: config.OrDefaultString(app.Config().WriteAsOauth.TokenLocation, writeAsExchangeLocation), + InspectLocation: config.OrDefaultString(app.Config().WriteAsOauth.InspectLocation, writeAsIdentityLocation), + AuthLocation: config.OrDefaultString(app.Config().WriteAsOauth.AuthLocation, writeAsAuthLocation), + HttpClient: config.DefaultHTTPClient(), + CallbackLocation: app.Config().App.Host + "/oauth/callback", + } + configureOauthRoutes(parentHandler, r, app, oauthClient) + } +} + +func configureOauthRoutes(parentHandler *Handler, r *mux.Router, app *App, oauthClient oauthClient) { + handler := &oauthHandler{ + Config: app.Config(), + DB: app.DB(), + Store: app.SessionStore(), + oauthClient: oauthClient, + EmailKey: app.keys.EmailKey, + } + r.HandleFunc("/oauth/"+oauthClient.GetProvider(), parentHandler.OAuth(handler.viewOauthInit)).Methods("GET") + r.HandleFunc("/oauth/callback", parentHandler.OAuth(handler.viewOauthCallback)).Methods("GET") +} + +func (h oauthHandler) viewOauthCallback(app *App, w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + + code := r.FormValue("code") + state := r.FormValue("state") + + provider, clientID, err := h.DB.ValidateOAuthState(ctx, state) + if err != nil { + log.Error("Unable to ValidateOAuthState: %s", err) + return impart.HTTPError{http.StatusInternalServerError, err.Error()} + } + + tokenResponse, err := h.oauthClient.exchangeOauthCode(ctx, code) + if err != nil { + log.Error("Unable to exchangeOauthCode: %s", err) + return impart.HTTPError{http.StatusInternalServerError, err.Error()} + } + + // Now that we have the access token, let's use it real quick to make sur + // it really really works. + tokenInfo, err := h.oauthClient.inspectOauthAccessToken(ctx, tokenResponse.AccessToken) + if err != nil { + log.Error("Unable to inspectOauthAccessToken: %s", err) + return impart.HTTPError{http.StatusInternalServerError, err.Error()} + } + + localUserID, err := h.DB.GetIDForRemoteUser(ctx, tokenInfo.UserID, provider, clientID) + if err != nil { + log.Error("Unable to GetIDForRemoteUser: %s", err) + return impart.HTTPError{http.StatusInternalServerError, err.Error()} + } + + if localUserID == -1 { + // We don't have, nor do we want, the password from the origin, so we + //create a random string. If the user needs to set a password, they + //can do so through the settings page or through the password reset + //flow. + randPass := store.Generate62RandomString(14) + hashedPass, err := auth.HashPass([]byte(randPass)) + if err != nil { + return impart.HTTPError{http.StatusInternalServerError, "unable to create password hash"} + } + newUser := &User{ + Username: tokenInfo.Username, + HashedPass: hashedPass, + HasPass: true, + Email: prepareUserEmail(tokenInfo.Email, h.EmailKey), + Created: time.Now().Truncate(time.Second).UTC(), + } + displayName := tokenInfo.DisplayName + if len(displayName) == 0 { + displayName = tokenInfo.Username + } + + err = h.DB.CreateUser(h.Config, newUser, displayName) + if err != nil { + return impart.HTTPError{http.StatusInternalServerError, err.Error()} + } + + err = h.DB.RecordRemoteUserID(ctx, newUser.ID, tokenInfo.UserID, provider, clientID, tokenResponse.AccessToken) + if err != nil { + return impart.HTTPError{http.StatusInternalServerError, err.Error()} + } + + if err := loginOrFail(h.Store, w, r, newUser); err != nil { + return impart.HTTPError{http.StatusInternalServerError, err.Error()} + } + return nil + } + + user, err := h.DB.GetUserByID(localUserID) + if err != nil { + return impart.HTTPError{http.StatusInternalServerError, err.Error()} + } + if err = loginOrFail(h.Store, w, r, user); err != nil { + return impart.HTTPError{http.StatusInternalServerError, err.Error()} + } + return nil +} + +func limitedJsonUnmarshal(body io.ReadCloser, n int, thing interface{}) error { + lr := io.LimitReader(body, int64(n+1)) + data, err := ioutil.ReadAll(lr) + if err != nil { + return err + } + if len(data) == n+1 { + return fmt.Errorf("content larger than max read allowance: %d", n) + } + return json.Unmarshal(data, thing) +} + +func loginOrFail(store sessions.Store, w http.ResponseWriter, r *http.Request, user *User) error { + // An error may be returned, but a valid session should always be returned. + session, _ := store.Get(r, cookieName) + session.Values[cookieUserVal] = user.Cookie() + if err := session.Save(r, w); err != nil { + fmt.Println("error saving session", err) + return err + } + http.Redirect(w, r, "/", http.StatusTemporaryRedirect) + return nil +} diff --git a/oauth/state.go b/oauth/state.go new file mode 100644 index 0000000..e8dd154 --- /dev/null +++ b/oauth/state.go @@ -0,0 +1,10 @@ +package oauth + +import "context" + +// ClientStateStore provides state management used by the OAuth client. +type ClientStateStore interface { + Generate(ctx context.Context) (string, error) + Validate(ctx context.Context, state string) error +} + diff --git a/oauth_slack.go b/oauth_slack.go new file mode 100644 index 0000000..066aa18 --- /dev/null +++ b/oauth_slack.go @@ -0,0 +1,164 @@ +package writefreely + +import ( + "context" + "errors" + "github.com/writeas/slug" + "net/http" + "net/url" + "strings" +) + +type slackOauthClient struct { + ClientID string + ClientSecret string + TeamID string + CallbackLocation string + HttpClient HttpClient +} + +type slackExchangeResponse struct { + OK bool `json:"ok"` + AccessToken string `json:"access_token"` + Scope string `json:"scope"` + TeamName string `json:"team_name"` + TeamID string `json:"team_id"` + Error string `json:"error"` +} + +type slackIdentity struct { + Name string `json:"name"` + ID string `json:"id"` + Email string `json:"email"` +} + +type slackTeam struct { + Name string `json:"name"` + ID string `json:"id"` +} + +type slackUserIdentityResponse struct { + OK bool `json:"ok"` + User slackIdentity `json:"user"` + Team slackTeam `json:"team"` + Error string `json:"error"` +} + +const ( + slackAuthLocation = "https://slack.com/oauth/authorize" + slackExchangeLocation = "https://slack.com/api/oauth.access" + slackIdentityLocation = "https://slack.com/api/users.identity" +) + +var _ oauthClient = slackOauthClient{} + +func (c slackOauthClient) GetProvider() string { + return "slack" +} + +func (c slackOauthClient) GetClientID() string { + return c.ClientID +} + +func (c slackOauthClient) buildLoginURL(state string) (string, error) { + u, err := url.Parse(slackAuthLocation) + if err != nil { + return "", err + } + q := u.Query() + q.Set("client_id", c.ClientID) + q.Set("scope", "identity.basic identity.email identity.team") + q.Set("redirect_uri", c.CallbackLocation) + q.Set("state", state) + + // If this param is not set, the user can select which team they + // authenticate through and then we'd have to match the configured team + // against the profile get. That is extra work in the post-auth phase + // that we don't want to do. + q.Set("team", c.TeamID) + + // The Slack OAuth docs don't explicitly list this one, but it is part of + // the spec, so we include it anyway. + q.Set("response_type", "code") + u.RawQuery = q.Encode() + return u.String(), nil +} + +func (c slackOauthClient) exchangeOauthCode(ctx context.Context, code string) (*TokenResponse, error) { + form := url.Values{} + // The oauth.access documentation doesn't explicitly mention this + // parameter, but it is part of the spec, so we include it anyway. + // https://api.slack.com/methods/oauth.access + form.Add("grant_type", "authorization_code") + form.Add("redirect_uri", c.CallbackLocation) + form.Add("code", code) + req, err := http.NewRequest("POST", slackExchangeLocation, strings.NewReader(form.Encode())) + if err != nil { + return nil, err + } + req.WithContext(ctx) + req.Header.Set("User-Agent", "writefreely") + req.Header.Set("Accept", "application/json") + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.SetBasicAuth(c.ClientID, c.ClientSecret) + + resp, err := c.HttpClient.Do(req) + if err != nil { + return nil, err + } + if resp.StatusCode != http.StatusOK { + return nil, errors.New("unable to exchange code for access token") + } + + var tokenResponse slackExchangeResponse + if err := limitedJsonUnmarshal(resp.Body, tokenRequestMaxLen, &tokenResponse); err != nil { + return nil, err + } + if !tokenResponse.OK { + return nil, errors.New(tokenResponse.Error) + } + return tokenResponse.TokenResponse(), nil +} + +func (c slackOauthClient) inspectOauthAccessToken(ctx context.Context, accessToken string) (*InspectResponse, error) { + req, err := http.NewRequest("GET", slackIdentityLocation, nil) + if err != nil { + return nil, err + } + req.WithContext(ctx) + req.Header.Set("User-Agent", "writefreely") + req.Header.Set("Accept", "application/json") + req.Header.Set("Authorization", "Bearer "+accessToken) + + resp, err := c.HttpClient.Do(req) + if err != nil { + return nil, err + } + if resp.StatusCode != http.StatusOK { + return nil, errors.New("unable to inspect access token") + } + + var inspectResponse slackUserIdentityResponse + if err := limitedJsonUnmarshal(resp.Body, infoRequestMaxLen, &inspectResponse); err != nil { + return nil, err + } + if !inspectResponse.OK { + return nil, errors.New(inspectResponse.Error) + } + return inspectResponse.InspectResponse(), nil +} + +func (resp slackUserIdentityResponse) InspectResponse() *InspectResponse { + return &InspectResponse{ + UserID: resp.User.ID, + Username: slug.Make(resp.User.Name), + DisplayName: resp.User.Name, + Email: resp.User.Email, + } +} + +func (resp slackExchangeResponse) TokenResponse() *TokenResponse { + return &TokenResponse{ + AccessToken: resp.AccessToken, + } +} diff --git a/oauth_test.go b/oauth_test.go new file mode 100644 index 0000000..2e293e7 --- /dev/null +++ b/oauth_test.go @@ -0,0 +1,253 @@ +package writefreely + +import ( + "context" + "fmt" + "github.com/gorilla/sessions" + "github.com/stretchr/testify/assert" + "github.com/writeas/impart" + "github.com/writeas/nerds/store" + "github.com/writeas/writefreely/config" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" +) + +type MockOAuthDatastoreProvider struct { + DoDB func() OAuthDatastore + DoConfig func() *config.Config + DoSessionStore func() sessions.Store +} + +type MockOAuthDatastore struct { + DoGenerateOAuthState func(context.Context, string, string) (string, error) + DoValidateOAuthState func(context.Context, string) (string, string, error) + DoGetIDForRemoteUser func(context.Context, string, string, string) (int64, error) + DoCreateUser func(*config.Config, *User, string) error + DoRecordRemoteUserID func(context.Context, int64, string, string, string, string) error + DoGetUserByID func(int64) (*User, error) +} + +var _ OAuthDatastore = &MockOAuthDatastore{} + +type StringReadCloser struct { + *strings.Reader +} + +func (src *StringReadCloser) Close() error { + return nil +} + +type MockHTTPClient struct { + DoDo func(req *http.Request) (*http.Response, error) +} + +func (m *MockHTTPClient) Do(req *http.Request) (*http.Response, error) { + if m.DoDo != nil { + return m.DoDo(req) + } + return &http.Response{}, nil +} + +func (m *MockOAuthDatastoreProvider) SessionStore() sessions.Store { + if m.DoSessionStore != nil { + return m.DoSessionStore() + } + return sessions.NewCookieStore([]byte("secret-key")) +} + +func (m *MockOAuthDatastoreProvider) DB() OAuthDatastore { + if m.DoDB != nil { + return m.DoDB() + } + return &MockOAuthDatastore{} +} + +func (m *MockOAuthDatastoreProvider) Config() *config.Config { + if m.DoConfig != nil { + return m.DoConfig() + } + cfg := config.New() + cfg.UseSQLite(true) + cfg.WriteAsOauth = config.WriteAsOauthCfg{ + ClientID: "development", + ClientSecret: "development", + AuthLocation: "https://write.as/oauth/login", + TokenLocation: "https://write.as/oauth/token", + InspectLocation: "https://write.as/oauth/inspect", + } + cfg.SlackOauth = config.SlackOauthCfg{ + ClientID: "development", + ClientSecret: "development", + TeamID: "development", + } + return cfg +} + +func (m *MockOAuthDatastore) ValidateOAuthState(ctx context.Context, state string) (string, string, error) { + if m.DoValidateOAuthState != nil { + return m.DoValidateOAuthState(ctx, state) + } + return "", "", nil +} + +func (m *MockOAuthDatastore) GetIDForRemoteUser(ctx context.Context, remoteUserID, provider, clientID string) (int64, error) { + if m.DoGetIDForRemoteUser != nil { + return m.DoGetIDForRemoteUser(ctx, remoteUserID, provider, clientID) + } + return -1, nil +} + +func (m *MockOAuthDatastore) CreateUser(cfg *config.Config, u *User, username string) error { + if m.DoCreateUser != nil { + return m.DoCreateUser(cfg, u, username) + } + u.ID = 1 + return nil +} + +func (m *MockOAuthDatastore) RecordRemoteUserID(ctx context.Context, localUserID int64, remoteUserID, provider, clientID, accessToken string) error { + if m.DoRecordRemoteUserID != nil { + return m.DoRecordRemoteUserID(ctx, localUserID, remoteUserID, provider, clientID, accessToken) + } + return nil +} + +func (m *MockOAuthDatastore) GetUserByID(userID int64) (*User, error) { + if m.DoGetUserByID != nil { + return m.DoGetUserByID(userID) + } + user := &User{ + + } + return user, nil +} + +func (m *MockOAuthDatastore) GenerateOAuthState(ctx context.Context, provider string, clientID string) (string, error) { + if m.DoGenerateOAuthState != nil { + return m.DoGenerateOAuthState(ctx, provider, clientID) + } + return store.Generate62RandomString(14), nil +} + +func TestViewOauthInit(t *testing.T) { + + t.Run("success", func(t *testing.T) { + app := &MockOAuthDatastoreProvider{} + h := oauthHandler{ + Config: app.Config(), + DB: app.DB(), + Store: app.SessionStore(), + EmailKey: []byte{0xd, 0xe, 0xc, 0xa, 0xf, 0xf, 0xb, 0xa, 0xd}, + oauthClient: writeAsOauthClient{ + ClientID: app.Config().WriteAsOauth.ClientID, + ClientSecret: app.Config().WriteAsOauth.ClientSecret, + ExchangeLocation: app.Config().WriteAsOauth.TokenLocation, + InspectLocation: app.Config().WriteAsOauth.InspectLocation, + AuthLocation: app.Config().WriteAsOauth.AuthLocation, + CallbackLocation: "http://localhost/oauth/callback", + HttpClient: nil, + }, + } + req, err := http.NewRequest("GET", "/oauth/client", nil) + assert.NoError(t, err) + rr := httptest.NewRecorder() + err = h.viewOauthInit(nil, rr, req) + assert.NotNil(t, err) + httpErr, ok := err.(impart.HTTPError) + assert.True(t, ok) + assert.Equal(t, http.StatusTemporaryRedirect, httpErr.Status) + assert.NotEmpty(t, httpErr.Message) + locURI, err := url.Parse(httpErr.Message) + assert.NoError(t, err) + assert.Equal(t, "/oauth/login", locURI.Path) + assert.Equal(t, "development", locURI.Query().Get("client_id")) + assert.Equal(t, "http://localhost/oauth/callback", locURI.Query().Get("redirect_uri")) + assert.Equal(t, "code", locURI.Query().Get("response_type")) + assert.NotEmpty(t, locURI.Query().Get("state")) + }) + + t.Run("state failure", func(t *testing.T) { + app := &MockOAuthDatastoreProvider{ + DoDB: func() OAuthDatastore { + return &MockOAuthDatastore{ + DoGenerateOAuthState: func(ctx context.Context, provider, clientID string) (string, error) { + return "", fmt.Errorf("pretend unable to write state error") + }, + } + }, + } + h := oauthHandler{ + Config: app.Config(), + DB: app.DB(), + Store: app.SessionStore(), + EmailKey: []byte{0xd, 0xe, 0xc, 0xa, 0xf, 0xf, 0xb, 0xa, 0xd}, + oauthClient: writeAsOauthClient{ + ClientID: app.Config().WriteAsOauth.ClientID, + ClientSecret: app.Config().WriteAsOauth.ClientSecret, + ExchangeLocation: app.Config().WriteAsOauth.TokenLocation, + InspectLocation: app.Config().WriteAsOauth.InspectLocation, + AuthLocation: app.Config().WriteAsOauth.AuthLocation, + CallbackLocation: "http://localhost/oauth/callback", + HttpClient: nil, + }, + } + req, err := http.NewRequest("GET", "/oauth/client", nil) + assert.NoError(t, err) + rr := httptest.NewRecorder() + err = h.viewOauthInit(nil, rr, req) + httpErr, ok := err.(impart.HTTPError) + assert.True(t, ok) + assert.NotEmpty(t, httpErr.Message) + assert.Equal(t, http.StatusInternalServerError, httpErr.Status) + assert.Equal(t, "could not prepare oauth redirect url", httpErr.Message) + }) +} + +func TestViewOauthCallback(t *testing.T) { + t.Run("success", func(t *testing.T) { + app := &MockOAuthDatastoreProvider{} + h := oauthHandler{ + Config: app.Config(), + DB: app.DB(), + Store: app.SessionStore(), + EmailKey: []byte{0xd, 0xe, 0xc, 0xa, 0xf, 0xf, 0xb, 0xa, 0xd}, + oauthClient: writeAsOauthClient{ + ClientID: app.Config().WriteAsOauth.ClientID, + ClientSecret: app.Config().WriteAsOauth.ClientSecret, + ExchangeLocation: app.Config().WriteAsOauth.TokenLocation, + InspectLocation: app.Config().WriteAsOauth.InspectLocation, + AuthLocation: app.Config().WriteAsOauth.AuthLocation, + CallbackLocation: "http://localhost/oauth/callback", + HttpClient: &MockHTTPClient{ + DoDo: func(req *http.Request) (*http.Response, error) { + switch req.URL.String() { + case "https://write.as/oauth/token": + return &http.Response{ + StatusCode: 200, + Body: &StringReadCloser{strings.NewReader(`{"access_token": "access_token", "expires_in": 1000, "refresh_token": "refresh_token", "token_type": "access"}`)}, + }, nil + case "https://write.as/oauth/inspect": + return &http.Response{ + StatusCode: 200, + Body: &StringReadCloser{strings.NewReader(`{"client_id": "development", "user_id": "1", "expires_at": "2019-12-19T11:42:01Z", "username": "nick", "email": "nick@testing.write.as"}`)}, + }, nil + } + + return &http.Response{ + StatusCode: http.StatusNotFound, + }, nil + }, + }, + }, + } + req, err := http.NewRequest("GET", "/oauth/callback", nil) + assert.NoError(t, err) + rr := httptest.NewRecorder() + err = h.viewOauthCallback(nil, rr, req) + assert.NoError(t, err) + assert.Equal(t, http.StatusTemporaryRedirect, rr.Code) + }) +} diff --git a/oauth_writeas.go b/oauth_writeas.go new file mode 100644 index 0000000..eb12f64 --- /dev/null +++ b/oauth_writeas.go @@ -0,0 +1,110 @@ +package writefreely + +import ( + "context" + "errors" + "net/http" + "net/url" + "strings" +) + +type writeAsOauthClient struct { + ClientID string + ClientSecret string + AuthLocation string + ExchangeLocation string + InspectLocation string + CallbackLocation string + HttpClient HttpClient +} + +var _ oauthClient = writeAsOauthClient{} + +const ( + writeAsAuthLocation = "https://write.as/oauth/login" + writeAsExchangeLocation = "https://write.as/oauth/token" + writeAsIdentityLocation = "https://write.as/oauth/inspect" +) + +func (c writeAsOauthClient) GetProvider() string { + return "write.as" +} + +func (c writeAsOauthClient) GetClientID() string { + return c.ClientID +} + +func (c writeAsOauthClient) buildLoginURL(state string) (string, error) { + u, err := url.Parse(c.AuthLocation) + if err != nil { + return "", err + } + q := u.Query() + q.Set("client_id", c.ClientID) + q.Set("redirect_uri", c.CallbackLocation) + q.Set("response_type", "code") + q.Set("state", state) + u.RawQuery = q.Encode() + return u.String(), nil +} + +func (c writeAsOauthClient) exchangeOauthCode(ctx context.Context, code string) (*TokenResponse, error) { + form := url.Values{} + form.Add("grant_type", "authorization_code") + form.Add("redirect_uri", c.CallbackLocation) + form.Add("code", code) + req, err := http.NewRequest("POST", c.ExchangeLocation, strings.NewReader(form.Encode())) + if err != nil { + return nil, err + } + req.WithContext(ctx) + req.Header.Set("User-Agent", "writefreely") + req.Header.Set("Accept", "application/json") + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.SetBasicAuth(c.ClientID, c.ClientSecret) + + resp, err := c.HttpClient.Do(req) + if err != nil { + return nil, err + } + if resp.StatusCode != http.StatusOK { + return nil, errors.New("unable to exchange code for access token") + } + + var tokenResponse TokenResponse + if err := limitedJsonUnmarshal(resp.Body, tokenRequestMaxLen, &tokenResponse); err != nil { + return nil, err + } + if tokenResponse.Error != "" { + return nil, errors.New(tokenResponse.Error) + } + return &tokenResponse, nil +} + +func (c writeAsOauthClient) inspectOauthAccessToken(ctx context.Context, accessToken string) (*InspectResponse, error) { + req, err := http.NewRequest("GET", c.InspectLocation, nil) + if err != nil { + return nil, err + } + req.WithContext(ctx) + req.Header.Set("User-Agent", "writefreely") + req.Header.Set("Accept", "application/json") + req.Header.Set("Authorization", "Bearer "+accessToken) + + resp, err := c.HttpClient.Do(req) + if err != nil { + return nil, err + } + if resp.StatusCode != http.StatusOK { + return nil, errors.New("unable to inspect access token") + } + + var inspectResponse InspectResponse + if err := limitedJsonUnmarshal(resp.Body, infoRequestMaxLen, &inspectResponse); err != nil { + return nil, err + } + if inspectResponse.Error != "" { + return nil, errors.New(inspectResponse.Error) + } + return &inspectResponse, nil +} diff --git a/pad.go b/pad.go index 37d1c9b..3b0f1c2 100644 --- a/pad.go +++ b/pad.go @@ -92,6 +92,7 @@ func handleViewPad(app *App, w http.ResponseWriter, r *http.Request) error { if err != nil { return err } + appData.EditCollection.hostName = app.cfg.App.Host } else { // Editing a floating article appData.Post = getRawPost(app, action) @@ -161,6 +162,7 @@ func handleViewMeta(app *App, w http.ResponseWriter, r *http.Request) error { if err != nil { return err } + appData.EditCollection.hostName = app.cfg.App.Host } else { // Editing a floating article appData.Post = getRawPost(app, action) diff --git a/postrender.go b/postrender.go index 83fb5ad..312de58 100644 --- a/postrender.go +++ b/postrender.go @@ -1,5 +1,5 @@ /* - * Copyright © 2018 A Bunch Tell LLC. + * Copyright © 2018-2020 A Bunch Tell LLC. * * This file is part of WriteFreely. * @@ -11,9 +11,11 @@ package writefreely import ( + "encoding/json" "fmt" "html" "html/template" + "net/http" "regexp" "strings" "unicode" @@ -21,7 +23,9 @@ import ( "github.com/microcosm-cc/bluemonday" stripmd "github.com/writeas/go-strip-markdown" + "github.com/writeas/impart" blackfriday "github.com/writeas/saturday" + "github.com/writeas/web-core/log" "github.com/writeas/web-core/stringmanip" "github.com/writeas/writefreely/config" "github.com/writeas/writefreely/parse" @@ -234,3 +238,29 @@ func shortPostDescription(content string) string { } return strings.TrimSpace(fmt.Sprintf(fmtStr, strings.Replace(stringmanip.Substring(content, 0, maxLen-truncation), "\n", " ", -1))) } + +func handleRenderMarkdown(app *App, w http.ResponseWriter, r *http.Request) error { + if !IsJSON(r) { + return impart.HTTPError{Status: http.StatusUnsupportedMediaType, Message: "Markdown API only supports JSON requests"} + } + + in := struct { + CollectionURL string `json:"collection_url"` + RawBody string `json:"raw_body"` + }{} + + decoder := json.NewDecoder(r.Body) + err := decoder.Decode(&in) + if err != nil { + log.Error("Couldn't parse markdown JSON request: %v", err) + return ErrBadJSON + } + + out := struct { + Body string `json:"body"` + }{ + Body: applyMarkdown([]byte(in.RawBody), in.CollectionURL, app.cfg), + } + + return impart.WriteSuccess(w, out, http.StatusOK) +} diff --git a/posts.go b/posts.go index 6410735..21ed1a1 100644 --- a/posts.go +++ b/posts.go @@ -381,10 +381,12 @@ func handleViewPost(app *App, w http.ResponseWriter, r *http.Request) error { } } - suspended, err := app.db.IsUserSuspended(ownerID.Int64) - if err != nil { - log.Error("view post: %v", err) - return ErrInternalGeneral + var suspended bool + if found { + suspended, err = app.db.IsUserSuspended(ownerID.Int64) + if err != nil { + log.Error("view post: %v", err) + } } // Check if post has been unpublished @@ -511,7 +513,6 @@ func newPost(app *App, w http.ResponseWriter, r *http.Request) error { suspended, err := app.db.IsUserSuspended(userID) if err != nil { log.Error("new post: %v", err) - return ErrInternalGeneral } if suspended { return ErrUserSuspended @@ -685,7 +686,6 @@ func existingPost(app *App, w http.ResponseWriter, r *http.Request) error { suspended, err := app.db.IsUserSuspended(userID) if err != nil { log.Error("existing post: %v", err) - return ErrInternalGeneral } if suspended { return ErrUserSuspended @@ -888,7 +888,6 @@ func addPost(app *App, w http.ResponseWriter, r *http.Request) error { suspended, err := app.db.IsUserSuspended(ownerID) if err != nil { log.Error("add post: %v", err) - return ErrInternalGeneral } if suspended { return ErrUserSuspended @@ -991,7 +990,6 @@ func pinPost(app *App, w http.ResponseWriter, r *http.Request) error { suspended, err := app.db.IsUserSuspended(userID) if err != nil { log.Error("pin post: %v", err) - return ErrInternalGeneral } if suspended { return ErrUserSuspended @@ -1039,7 +1037,6 @@ func pinPost(app *App, w http.ResponseWriter, r *http.Request) error { func fetchPost(app *App, w http.ResponseWriter, r *http.Request) error { var collID int64 - var ownerID int64 var coll *Collection var err error vars := mux.Vars(r) @@ -1049,25 +1046,32 @@ func fetchPost(app *App, w http.ResponseWriter, r *http.Request) error { if err != nil { return err } - coll.hostName = app.cfg.App.Host - _, err = apiCheckCollectionPermissions(app, r, coll) - if err != nil { - return err - } collID = coll.ID - ownerID = coll.OwnerID } p, err := app.db.GetPost(vars["post"], collID) if err != nil { return err } - suspended, err := app.db.IsUserSuspended(ownerID) - if err != nil { - log.Error("fetch post: %v", err) - return ErrInternalGeneral + if coll == nil && p.CollectionID.Valid { + // Collection post is getting fetched by post ID, not coll alias + post slug, so get coll info now. + coll, err = app.db.GetCollectionByID(p.CollectionID.Int64) + if err != nil { + return err + } + } + if coll != nil { + coll.hostName = app.cfg.App.Host + _, err = apiCheckCollectionPermissions(app, r, coll) + if err != nil { + return err + } } + suspended, err := app.db.IsUserSuspended(p.OwnerID.Int64) + if err != nil { + log.Error("fetch post: %v", err) + } if suspended { return ErrPostNotFound } @@ -1076,13 +1080,6 @@ func fetchPost(app *App, w http.ResponseWriter, r *http.Request) error { accept := r.Header.Get("Accept") if strings.Contains(accept, "application/activity+json") { - // Fetch information about the collection this belongs to - if coll == nil && p.CollectionID.Valid { - coll, err = app.db.GetCollectionByID(p.CollectionID.Int64) - if err != nil { - return err - } - } if coll == nil { // This is a draft post; 404 for now // TODO: return ActivityObject @@ -1335,15 +1332,18 @@ func viewCollectionPost(app *App, w http.ResponseWriter, r *http.Request) error suspended, err := app.db.IsUserSuspended(c.OwnerID) if err != nil { log.Error("view collection post: %v", err) - return ErrInternalGeneral } // Check collection permissions if c.IsPrivate() && (u == nil || u.ID != c.OwnerID) { return ErrPostNotFound } - if c.IsProtected() && ((u == nil || u.ID != c.OwnerID) && !isAuthorizedForCollection(app, c.Alias, r)) { - return impart.HTTPError{http.StatusFound, c.CanonicalURL() + "/?g=" + slug} + if c.IsProtected() && (u == nil || u.ID != c.OwnerID) { + if suspended { + return ErrPostNotFound + } else if !isAuthorizedForCollection(app, c.Alias, r) { + return impart.HTTPError{http.StatusFound, c.CanonicalURL() + "/?g=" + slug} + } } cr.isCollOwner = u != nil && c.OwnerID == u.ID diff --git a/routes.go b/routes.go index eb5422a..7784d71 100644 --- a/routes.go +++ b/routes.go @@ -70,6 +70,9 @@ func InitRoutes(apper Apper, r *mux.Router) *mux.Router { write.HandleFunc(nodeinfo.NodeInfoPath, handler.LogHandlerFunc(http.HandlerFunc(ni.NodeInfoDiscover))) write.HandleFunc(niCfg.InfoURL, handler.LogHandlerFunc(http.HandlerFunc(ni.NodeInfo))) + configureSlackOauth(handler, write, apper.App()) + configureWriteAsOauth(handler, write, apper.App()) + // Set up dyamic page handlers // Handle auth auth := write.PathPrefix("/api/auth/").Subrouter() @@ -112,6 +115,8 @@ func InitRoutes(apper Apper, r *mux.Router) *mux.Router { // Sign up validation write.HandleFunc("/api/alias", handler.All(handleUsernameCheck)).Methods("POST") + write.HandleFunc("/api/markdown", handler.All(handleRenderMarkdown)).Methods("POST") + // Handle collections write.HandleFunc("/api/collections", handler.All(newCollection)).Methods("POST") apiColls := write.PathPrefix("/api/collections/").Subrouter() @@ -183,6 +188,7 @@ func InitRoutes(apper Apper, r *mux.Router) *mux.Router { } write.HandleFunc(draftEditPrefix+"/{post}", handler.Web(handleViewPost, UserLevelOptional)) write.HandleFunc("/", handler.Web(handleViewHome, UserLevelOptional)) + return r } diff --git a/templates/edit-meta.tmpl b/templates/edit-meta.tmpl index 6707e68..49c7781 100644 --- a/templates/edit-meta.tmpl +++ b/templates/edit-meta.tmpl @@ -270,7 +270,7 @@