From ca4a576c31679599371de51173fe1eb9fddc6257 Mon Sep 17 00:00:00 2001 From: Matt Baer Date: Mon, 20 Apr 2020 18:18:23 -0400 Subject: [PATCH] Support OAuth registration with invite code This adds any OAuth login buttons to the invite signup page, stores the invite code for the flow duration, and associates the new user with it once successfully registered. It enables invite-only instances with OAuth-based registration. --- database.go | 20 ++++++++-------- database_test.go | 4 ++-- invites.go | 16 ++++++++++++- less/app.less | 1 + less/login.less | 45 ++++++++++++++++++++++++++++++++++++ migrations/migrations.go | 3 ++- migrations/v8.go | 45 ++++++++++++++++++++++++++++++++++++ oauth.go | 49 +++++++++++++++++++++++++++++++++++----- oauth_signup.go | 13 +++++++++++ oauth_test.go | 14 ++++++------ pages/login.tmpl | 33 --------------------------- pages/signup-oauth.tmpl | 1 + pages/signup.tmpl | 19 ++++++++++++++++ 13 files changed, 204 insertions(+), 59 deletions(-) create mode 100644 less/login.less create mode 100644 migrations/v8.go diff --git a/database.go b/database.go index 128e436..a1e8642 100644 --- a/database.go +++ b/database.go @@ -132,8 +132,8 @@ type writestore interface { GetIDForRemoteUser(context.Context, string, string, string) (int64, error) RecordRemoteUserID(context.Context, int64, string, string, string, string) error - ValidateOAuthState(context.Context, string) (string, string, int64, error) - GenerateOAuthState(context.Context, string, string, int64) (string, error) + ValidateOAuthState(context.Context, string) (string, string, int64, string, error) + GenerateOAuthState(context.Context, string, string, int64, string) (string, error) GetOauthAccounts(ctx context.Context, userID int64) ([]oauthAccountInfo, error) RemoveOauth(ctx context.Context, userID int64, provider string, clientID string, remoteUserID string) error @@ -2516,24 +2516,26 @@ func (db *datastore) GetCollectionLastPostTime(id int64) (*time.Time, error) { return &t, nil } -func (db *datastore) GenerateOAuthState(ctx context.Context, provider string, clientID string, attachUser int64) (string, error) { +func (db *datastore) GenerateOAuthState(ctx context.Context, provider string, clientID string, attachUser int64, inviteCode string) (string, error) { state := store.Generate62RandomString(24) attachUserVal := sql.NullInt64{Valid: attachUser > 0, Int64: attachUser} - _, err := db.ExecContext(ctx, "INSERT INTO oauth_client_states (state, provider, client_id, used, created_at, attach_user_id) VALUES (?, ?, ?, FALSE, "+db.now()+", ?)", state, provider, clientID, attachUserVal) + inviteCodeVal := sql.NullString{Valid: inviteCode != "", String: inviteCode} + _, err := db.ExecContext(ctx, "INSERT INTO oauth_client_states (state, provider, client_id, used, created_at, attach_user_id, invite_code) VALUES (?, ?, ?, FALSE, "+db.now()+", ?, ?)", state, provider, clientID, attachUserVal, inviteCodeVal) if err != nil { return "", fmt.Errorf("unable to record oauth client state: %w", err) } return state, nil } -func (db *datastore) ValidateOAuthState(ctx context.Context, state string) (string, string, int64, error) { +func (db *datastore) ValidateOAuthState(ctx context.Context, state string) (string, string, int64, string, error) { var provider string var clientID string var attachUserID sql.NullInt64 + var inviteCode sql.NullString err := wf_db.RunTransactionWithOptions(ctx, db.DB, &sql.TxOptions{}, func(ctx context.Context, tx *sql.Tx) error { err := tx. - QueryRowContext(ctx, "SELECT provider, client_id, attach_user_id FROM oauth_client_states WHERE state = ? AND used = FALSE", state). - Scan(&provider, &clientID, &attachUserID) + QueryRowContext(ctx, "SELECT provider, client_id, attach_user_id, invite_code FROM oauth_client_states WHERE state = ? AND used = FALSE", state). + Scan(&provider, &clientID, &attachUserID, &inviteCode) if err != nil { return err } @@ -2552,9 +2554,9 @@ func (db *datastore) ValidateOAuthState(ctx context.Context, state string) (stri return nil }) if err != nil { - return "", "", 0, nil + return "", "", 0, "", nil } - return provider, clientID, attachUserID.Int64, nil + return provider, clientID, attachUserID.Int64, inviteCode.String, nil } func (db *datastore) RecordRemoteUserID(ctx context.Context, localUserID int64, remoteUserID, provider, clientID, accessToken string) error { diff --git a/database_test.go b/database_test.go index 569d020..c114077 100644 --- a/database_test.go +++ b/database_test.go @@ -18,13 +18,13 @@ func TestOAuthDatastore(t *testing.T) { driverName: "", } - state, err := ds.GenerateOAuthState(ctx, "test", "development", 0) + state, err := ds.GenerateOAuthState(ctx, "test", "development", 0, "") assert.NoError(t, err) assert.Len(t, state, 24) countRows(t, ctx, db, 1, "SELECT COUNT(*) FROM `oauth_client_states` WHERE `state` = ? AND `used` = false", state) - _, _, _, err = ds.ValidateOAuthState(ctx, state) + _, _, _, _, err = ds.ValidateOAuthState(ctx, state) assert.NoError(t, err) countRows(t, ctx, db, 1, "SELECT COUNT(*) FROM `oauth_client_states` WHERE `state` = ? AND `used` = true", state) diff --git a/invites.go b/invites.go index c1c7d95..10416b2 100644 --- a/invites.go +++ b/invites.go @@ -1,5 +1,5 @@ /* - * Copyright © 2019 A Bunch Tell LLC. + * Copyright © 2019-2020 A Bunch Tell LLC. * * This file is part of WriteFreely. * @@ -42,6 +42,18 @@ func (i Invite) Expired() bool { return i.Expires != nil && i.Expires.Before(time.Now()) } +func (i Invite) Active(db *datastore) bool { + if i.Expired() { + return false + } + if i.MaxUses.Valid && i.MaxUses.Int64 > 0 { + if c := db.GetUsersInvitedCount(i.ID); c >= i.MaxUses.Int64 { + return false + } + } + return true +} + func (i Invite) ExpiresFriendly() string { return i.Expires.Format("January 2, 2006, 3:04 PM") } @@ -161,9 +173,11 @@ func handleViewInvite(app *App, w http.ResponseWriter, r *http.Request) error { Error string Flashes []template.HTML Invite string + OAuth *OAuthButtons }{ StaticPage: pageForReq(app, r), Invite: inviteCode, + OAuth: NewOAuthButtons(app.cfg), } if expired { diff --git a/less/app.less b/less/app.less index ec3472d..e1cf5ea 100644 --- a/less/app.less +++ b/less/app.less @@ -5,6 +5,7 @@ @import "post-temp"; @import "effects"; @import "admin"; +@import "login"; @import "pages/error"; @import "lib/elements"; @import "lib/material"; diff --git a/less/login.less b/less/login.less new file mode 100644 index 0000000..57efee5 --- /dev/null +++ b/less/login.less @@ -0,0 +1,45 @@ +/* + * Copyright © 2020 A Bunch Tell LLC. + * + * This file is part of WriteFreely. + * + * WriteFreely is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License, included + * in the LICENSE file in this source code package. + */ + +.row.signinbtns { + justify-content: space-evenly; + font-size: 1em; + margin-top: 3em; + margin-bottom: 2em; + + .loginbtn { + height: 40px; + } + + #writeas-login, #gitlab-login { + box-sizing: border-box; + font-size: 17px; + } +} + +.or { + text-align: center; + margin-bottom: 3.5em; + + p { + display: inline-block; + background-color: white; + padding: 0 1em; + } + + hr { + margin-top: -1.6em; + margin-bottom: 0; + } + + hr.short { + max-width: 30rem; + } +} \ No newline at end of file diff --git a/migrations/migrations.go b/migrations/migrations.go index 31ae43c..a0b3f25 100644 --- a/migrations/migrations.go +++ b/migrations/migrations.go @@ -62,7 +62,8 @@ var migrations = []Migration{ New("support oauth", oauth), // V3 -> V4 New("support slack oauth", oauthSlack), // V4 -> v5 New("support ActivityPub mentions", supportActivityPubMentions), // V5 -> V6 - New("support oauth attach", oauthAttach), // V6 -> V7 (v0.12.0) + New("support oauth attach", oauthAttach), // V6 -> V7 + New("support oauth via invite", oauthInvites), // V7 -> V8 (v0.12.0) } // CurrentVer returns the current migration version the application is on diff --git a/migrations/v8.go b/migrations/v8.go new file mode 100644 index 0000000..2318c4e --- /dev/null +++ b/migrations/v8.go @@ -0,0 +1,45 @@ +/* + * Copyright © 2020 A Bunch Tell LLC. + * + * This file is part of WriteFreely. + * + * WriteFreely is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License, included + * in the LICENSE file in this source code package. + */ + +package migrations + +import ( + "context" + "database/sql" + + wf_db "github.com/writeas/writefreely/db" +) + +func oauthInvites(db *datastore) error { + dialect := wf_db.DialectMySQL + if db.driverName == driverSQLite { + dialect = wf_db.DialectSQLite + } + return wf_db.RunTransactionWithOptions(context.Background(), db.DB, &sql.TxOptions{}, func(ctx context.Context, tx *sql.Tx) error { + builders := []wf_db.SQLBuilder{ + dialect. + AlterTable("oauth_client_states"). + AddColumn(dialect.Column("invite_code", wf_db.ColumnTypeChar, wf_db.OptionalInt{ + Set: true, + Value: 6, + }).SetNullable(true)), + } + for _, builder := range builders { + query, err := builder.ToSQL() + if err != nil { + return err + } + if _, err := tx.ExecContext(ctx, query); err != nil { + return err + } + } + return nil + }) +} diff --git a/oauth.go b/oauth.go index d3e31e4..b5c88aa 100644 --- a/oauth.go +++ b/oauth.go @@ -1,3 +1,13 @@ +/* + * Copyright © 2019-2020 A Bunch Tell LLC. + * + * This file is part of WriteFreely. + * + * WriteFreely is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License, included + * in the LICENSE file in this source code package. + */ + package writefreely import ( @@ -15,10 +25,27 @@ import ( "github.com/gorilla/sessions" "github.com/writeas/impart" "github.com/writeas/web-core/log" - "github.com/writeas/writefreely/config" ) +// OAuthButtons holds display information for different OAuth providers we support. +type OAuthButtons struct { + SlackEnabled bool + WriteAsEnabled bool + GitLabEnabled bool + GitLabDisplayName string +} + +// NewOAuthButtons creates a new OAuthButtons struct based on our app configuration. +func NewOAuthButtons(cfg *config.Config) *OAuthButtons { + return &OAuthButtons{ + SlackEnabled: cfg.SlackOauth.ClientID != "", + WriteAsEnabled: cfg.WriteAsOauth.ClientID != "", + GitLabEnabled: cfg.GitlabOauth.ClientID != "", + GitLabDisplayName: config.OrDefaultString(cfg.GitlabOauth.DisplayName, gitlabDisplayName), + } +} + // TokenResponse contains data returned when a token is created either // through a code exchange or using a refresh token. type TokenResponse struct { @@ -61,8 +88,8 @@ type OAuthDatastoreProvider interface { type OAuthDatastore interface { GetIDForRemoteUser(context.Context, string, string, string) (int64, error) RecordRemoteUserID(context.Context, int64, string, string, string, string) error - ValidateOAuthState(context.Context, string) (string, string, int64, error) - GenerateOAuthState(context.Context, string, string, int64) (string, error) + ValidateOAuthState(context.Context, string) (string, string, int64, string, error) + GenerateOAuthState(context.Context, string, string, int64, string) (string, error) CreateUser(*config.Config, *User, string) error GetUserByID(int64) (*User, error) @@ -108,7 +135,7 @@ func (h oauthHandler) viewOauthInit(app *App, w http.ResponseWriter, r *http.Req attachUser = user.ID } - state, err := h.DB.GenerateOAuthState(ctx, h.oauthClient.GetProvider(), h.oauthClient.GetClientID(), attachUser) + state, err := h.DB.GenerateOAuthState(ctx, h.oauthClient.GetProvider(), h.oauthClient.GetClientID(), attachUser, r.FormValue("invite_code")) if err != nil { log.Error("viewOauthInit error: %s", err) return impart.HTTPError{http.StatusInternalServerError, "could not prepare oauth redirect url"} @@ -228,7 +255,7 @@ func (h oauthHandler) viewOauthCallback(app *App, w http.ResponseWriter, r *http code := r.FormValue("code") state := r.FormValue("state") - provider, clientID, attachUserID, err := h.DB.ValidateOAuthState(ctx, state) + provider, clientID, attachUserID, inviteCode, err := h.DB.ValidateOAuthState(ctx, state) if err != nil { log.Error("Unable to ValidateOAuthState: %s", err) return impart.HTTPError{http.StatusInternalServerError, err.Error()} @@ -285,7 +312,16 @@ func (h oauthHandler) viewOauthCallback(app *App, w http.ResponseWriter, r *http // New user registration below. // First, verify that user is allowed to register - if !app.cfg.App.OpenRegistration { + if inviteCode != "" { + // Verify invite code is valid + i, err := app.db.GetUserInvite(inviteCode) + if err != nil { + return impart.HTTPError{http.StatusInternalServerError, err.Error()} + } + if !i.Active(app.db) { + return impart.HTTPError{http.StatusNotFound, "Invite link has expired."} + } + } else if !app.cfg.App.OpenRegistration { addSessionFlash(app, w, r, ErrUserNotFound.Error(), nil) return impart.HTTPError{http.StatusFound, "/login"} } @@ -303,6 +339,7 @@ func (h oauthHandler) viewOauthCallback(app *App, w http.ResponseWriter, r *http TokenRemoteUser: tokenInfo.UserID, Provider: provider, ClientID: clientID, + InviteCode: inviteCode, } tp.TokenHash = tp.HashTokenParams(h.Config.Server.HashSeed) diff --git a/oauth_signup.go b/oauth_signup.go index 220afbd..cbe4f60 100644 --- a/oauth_signup.go +++ b/oauth_signup.go @@ -38,6 +38,7 @@ type viewOauthSignupVars struct { Provider string ClientID string TokenHash string + InviteCode string LoginUsername string Alias string // TODO: rename this to match the data it represents: the collection title @@ -57,6 +58,7 @@ const ( oauthParamAlias = "alias" oauthParamEmail = "email" oauthParamPassword = "password" + oauthParamInviteCode = "invite_code" ) type oauthSignupPageParams struct { @@ -68,6 +70,7 @@ type oauthSignupPageParams struct { ClientID string Provider string TokenHash string + InviteCode string } func (p oauthSignupPageParams) HashTokenParams(key string) string { @@ -92,6 +95,7 @@ func (h oauthHandler) viewOauthSignup(app *App, w http.ResponseWriter, r *http.R TokenRemoteUser: r.FormValue(oauthParamTokenRemoteUserID), ClientID: r.FormValue(oauthParamClientID), Provider: r.FormValue(oauthParamProvider), + InviteCode: r.FormValue(oauthParamInviteCode), } if tp.HashTokenParams(h.Config.Server.HashSeed) != r.FormValue(oauthParamHash) { return impart.HTTPError{Status: http.StatusBadRequest, Message: "Request has been tampered with."} @@ -128,6 +132,14 @@ func (h oauthHandler) viewOauthSignup(app *App, w http.ResponseWriter, r *http.R return h.showOauthSignupPage(app, w, r, tp, err) } + // Log invite if needed + if tp.InviteCode != "" { + err = app.db.CreateInvitedUser(tp.InviteCode, newUser.ID) + if err != nil { + return err + } + } + err = h.DB.RecordRemoteUserID(r.Context(), newUser.ID, r.FormValue(oauthParamTokenRemoteUserID), r.FormValue(oauthParamProvider), r.FormValue(oauthParamClientID), r.FormValue(oauthParamAccessToken)) if err != nil { return h.showOauthSignupPage(app, w, r, tp, err) @@ -195,6 +207,7 @@ func (h oauthHandler) showOauthSignupPage(app *App, w http.ResponseWriter, r *ht Provider: tp.Provider, ClientID: tp.ClientID, TokenHash: tp.TokenHash, + InviteCode: tp.InviteCode, LoginUsername: username, Alias: collTitle, diff --git a/oauth_test.go b/oauth_test.go index c23eadd..1bd86af 100644 --- a/oauth_test.go +++ b/oauth_test.go @@ -22,8 +22,8 @@ type MockOAuthDatastoreProvider struct { } type MockOAuthDatastore struct { - DoGenerateOAuthState func(context.Context, string, string, int64) (string, error) - DoValidateOAuthState func(context.Context, string) (string, string, int64, error) + DoGenerateOAuthState func(context.Context, string, string, int64, string) (string, error) + DoValidateOAuthState func(context.Context, string) (string, string, int64, string, error) DoGetIDForRemoteUser func(context.Context, string, string, string) (int64, error) DoCreateUser func(*config.Config, *User, string) error DoRecordRemoteUserID func(context.Context, int64, string, string, string, string) error @@ -86,11 +86,11 @@ func (m *MockOAuthDatastoreProvider) Config() *config.Config { return cfg } -func (m *MockOAuthDatastore) ValidateOAuthState(ctx context.Context, state string) (string, string, int64, error) { +func (m *MockOAuthDatastore) ValidateOAuthState(ctx context.Context, state string) (string, string, int64, string, error) { if m.DoValidateOAuthState != nil { return m.DoValidateOAuthState(ctx, state) } - return "", "", 0, nil + return "", "", 0, "", nil } func (m *MockOAuthDatastore) GetIDForRemoteUser(ctx context.Context, remoteUserID, provider, clientID string) (int64, error) { @@ -125,9 +125,9 @@ func (m *MockOAuthDatastore) GetUserByID(userID int64) (*User, error) { return user, nil } -func (m *MockOAuthDatastore) GenerateOAuthState(ctx context.Context, provider string, clientID string, attachUserID int64) (string, error) { +func (m *MockOAuthDatastore) GenerateOAuthState(ctx context.Context, provider string, clientID string, attachUserID int64, inviteCode string) (string, error) { if m.DoGenerateOAuthState != nil { - return m.DoGenerateOAuthState(ctx, provider, clientID, attachUserID) + return m.DoGenerateOAuthState(ctx, provider, clientID, attachUserID, inviteCode) } return store.Generate62RandomString(14), nil } @@ -173,7 +173,7 @@ func TestViewOauthInit(t *testing.T) { app := &MockOAuthDatastoreProvider{ DoDB: func() OAuthDatastore { return &MockOAuthDatastore{ - DoGenerateOAuthState: func(ctx context.Context, provider, clientID string, attachUserID int64) (string, error) { + DoGenerateOAuthState: func(ctx context.Context, provider, clientID string, attachUserID int64, inviteCode string) (string, error) { return "", fmt.Errorf("pretend unable to write state error") }, } diff --git a/pages/login.tmpl b/pages/login.tmpl index 94087f3..5a338a4 100644 --- a/pages/login.tmpl +++ b/pages/login.tmpl @@ -3,39 +3,6 @@ {{end}} {{define "content"}} diff --git a/pages/signup-oauth.tmpl b/pages/signup-oauth.tmpl index ecf5db0..e02b89d 100644 --- a/pages/signup-oauth.tmpl +++ b/pages/signup-oauth.tmpl @@ -74,6 +74,7 @@ form dd { + {{if .InviteCode}}{{end}}