Implemented oauth attach functionality, oauth detach functionality, and required data migration. T713

This commit is contained in:
Nick Gerakines 2020-01-15 13:16:59 -05:00
parent 75e2b60328
commit c0317b4e93
9 changed files with 173 additions and 30 deletions

View File

@ -1038,18 +1038,30 @@ func viewSettings(app *App, u *User, w http.ResponseWriter, r *http.Request) err
flashes, _ := getSessionFlashes(app, w, r, nil) flashes, _ := getSessionFlashes(app, w, r, nil)
oauthAccounts, err := app.db.GetOauthAccounts(r.Context(), u.ID)
if err != nil {
log.Error("Unable to get oauth accounts for settings: %s", err)
return impart.HTTPError{http.StatusInternalServerError, "Unable to retrieve user data. The humans have been alerted."}
}
obj := struct { obj := struct {
*UserPage *UserPage
Email string Email string
HasPass bool HasPass bool
IsLogOut bool IsLogOut bool
Suspended bool Suspended bool
OauthAccounts []oauthAccountInfo
OauthSlack bool
OauthWriteAs bool
}{ }{
UserPage: NewUserPage(app, r, u, "Account Settings", flashes), UserPage: NewUserPage(app, r, u, "Account Settings", flashes),
Email: fullUser.EmailClear(app.keys), Email: fullUser.EmailClear(app.keys),
HasPass: passIsSet, HasPass: passIsSet,
IsLogOut: r.FormValue("logout") == "1", IsLogOut: r.FormValue("logout") == "1",
Suspended: fullUser.IsSilenced(), Suspended: fullUser.IsSilenced(),
OauthAccounts: oauthAccounts,
OauthSlack: app.Config().SlackOauth.ClientID != "",
OauthWriteAs: app.Config().WriteAsOauth.ClientID != "",
} }
showUserPage(w, "settings", obj) showUserPage(w, "settings", obj)
@ -1094,6 +1106,19 @@ func getTempInfo(app *App, key string, r *http.Request, w http.ResponseWriter) s
return s return s
} }
func removeOauth(app *App, u *User, w http.ResponseWriter, r *http.Request) error {
provider := r.FormValue("provider")
clientID := r.FormValue("client_id")
remoteUserID := r.FormValue("remote_user_id")
err := app.db.RemoveOauth(r.Context(), u.ID, provider, clientID, remoteUserID)
if err != nil {
return impart.HTTPError{Status: http.StatusInternalServerError, Message: err.Error()}
}
return impart.HTTPError{Status: http.StatusFound, Message: "/me/settings"}
}
func prepareUserEmail(input string, emailKey []byte) zero.String { func prepareUserEmail(input string, emailKey []byte) zero.String {
email := zero.NewString("", input != "") email := zero.NewString("", input != "")
if len(input) > 0 { if len(input) > 0 {

View File

@ -128,8 +128,10 @@ type writestore interface {
GetIDForRemoteUser(context.Context, string, string, string) (int64, error) GetIDForRemoteUser(context.Context, string, string, string) (int64, error)
RecordRemoteUserID(context.Context, int64, string, string, string, string) error RecordRemoteUserID(context.Context, int64, string, string, string, string) error
ValidateOAuthState(context.Context, string) (string, string, error) ValidateOAuthState(context.Context, string) (string, string, int64, error)
GenerateOAuthState(context.Context, string, string) (string, error) GenerateOAuthState(context.Context, string, string, int64) (string, error)
GetOauthAccounts(ctx context.Context, userID int64) ([]oauthAccountInfo, error)
RemoveOauth(ctx context.Context, userID int64, provider string, clientID string, remoteUserID string) error
DatabaseInitialized() bool DatabaseInitialized() bool
} }
@ -2462,20 +2464,23 @@ func (db *datastore) GetCollectionLastPostTime(id int64) (*time.Time, error) {
return &t, nil return &t, nil
} }
func (db *datastore) GenerateOAuthState(ctx context.Context, provider, clientID string) (string, error) { func (db *datastore) GenerateOAuthState(ctx context.Context, provider string, clientID string, attachUser int64) (string, error) {
state := store.Generate62RandomString(24) state := store.Generate62RandomString(24)
_, err := db.ExecContext(ctx, "INSERT INTO oauth_client_states (state, provider, client_id, used, created_at) VALUES (?, ?, ?, FALSE, NOW())", state, provider, clientID) _, err := db.ExecContext(ctx, "INSERT INTO oauth_client_states (state, provider, client_id, used, created_at, attach_user_id) VALUES (?, ?, ?, FALSE, NOW(), ?)", state, provider, clientID, attachUser)
if err != nil { if err != nil {
return "", fmt.Errorf("unable to record oauth client state: %w", err) return "", fmt.Errorf("unable to record oauth client state: %w", err)
} }
return state, nil return state, nil
} }
func (db *datastore) ValidateOAuthState(ctx context.Context, state string) (string, string, error) { func (db *datastore) ValidateOAuthState(ctx context.Context, state string) (string, string, int64, error) {
var provider string var provider string
var clientID string var clientID string
var attachUserID int64
err := wf_db.RunTransactionWithOptions(ctx, db.DB, &sql.TxOptions{}, func(ctx context.Context, tx *sql.Tx) error { err := wf_db.RunTransactionWithOptions(ctx, db.DB, &sql.TxOptions{}, func(ctx context.Context, tx *sql.Tx) error {
err := tx.QueryRow("SELECT provider, client_id FROM oauth_client_states WHERE state = ? AND used = FALSE", state).Scan(&provider, &clientID) err := tx.
QueryRowContext(ctx, "SELECT provider, client_id, attach_user_id FROM oauth_client_states WHERE state = ? AND used = FALSE", state).
Scan(&provider, &clientID, &attachUserID)
if err != nil { if err != nil {
return err return err
} }
@ -2494,9 +2499,9 @@ func (db *datastore) ValidateOAuthState(ctx context.Context, state string) (stri
return nil return nil
}) })
if err != nil { if err != nil {
return "", "", nil return "", "", 0, nil
} }
return provider, clientID, nil return provider, clientID, attachUserID, nil
} }
func (db *datastore) RecordRemoteUserID(ctx context.Context, localUserID int64, remoteUserID, provider, clientID, accessToken string) error { func (db *datastore) RecordRemoteUserID(ctx context.Context, localUserID int64, remoteUserID, provider, clientID, accessToken string) error {
@ -2525,6 +2530,33 @@ func (db *datastore) GetIDForRemoteUser(ctx context.Context, remoteUserID, provi
return userID, nil return userID, nil
} }
type oauthAccountInfo struct {
Provider string
ClientID string
RemoteUserID string
}
func (db *datastore) GetOauthAccounts(ctx context.Context, userID int64) ([]oauthAccountInfo, error) {
rows, err := db.QueryContext(ctx, "SELECT provider, client_id, remote_user_id FROM oauth_users WHERE user_id = ? ", userID)
if err != nil {
log.Error("Failed selecting from oauth_users: %v", err)
return nil, impart.HTTPError{http.StatusInternalServerError, "Couldn't retrieve user oauth accounts."}
}
defer rows.Close()
var records []oauthAccountInfo
for rows.Next() {
info := oauthAccountInfo{}
err = rows.Scan(&info.Provider, &info.ClientID, &info.RemoteUserID)
if err != nil {
log.Error("Failed scanning GetAllUsers() row: %v", err)
break
}
records = append(records, info)
}
return records, nil
}
// DatabaseInitialized returns whether or not the current datastore has been // DatabaseInitialized returns whether or not the current datastore has been
// initialized with the correct schema. // initialized with the correct schema.
// Currently, it checks to see if the `users` table exists. // Currently, it checks to see if the `users` table exists.
@ -2547,6 +2579,11 @@ func (db *datastore) DatabaseInitialized() bool {
return true return true
} }
func (db *datastore) RemoveOauth(ctx context.Context, userID int64, provider string, clientID string, remoteUserID string) error {
_, err := db.ExecContext(ctx, `DELETE FROM oauth_users WHERE user_id = ? AND provider = ? AND client_id = ? AND remote_user_id = ?`, userID, provider, clientID, remoteUserID)
return err
}
func stringLogln(log *string, s string, v ...interface{}) { func stringLogln(log *string, s string, v ...interface{}) {
*log += fmt.Sprintf(s+"\n", v...) *log += fmt.Sprintf(s+"\n", v...)
} }

View File

@ -18,13 +18,13 @@ func TestOAuthDatastore(t *testing.T) {
driverName: "", driverName: "",
} }
state, err := ds.GenerateOAuthState(ctx, "test", "development") state, err := ds.GenerateOAuthState(ctx, "test", "development", 0)
assert.NoError(t, err) assert.NoError(t, err)
assert.Len(t, state, 24) assert.Len(t, state, 24)
countRows(t, ctx, db, 1, "SELECT COUNT(*) FROM `oauth_client_states` WHERE `state` = ? AND `used` = false", state) countRows(t, ctx, db, 1, "SELECT COUNT(*) FROM `oauth_client_states` WHERE `state` = ? AND `used` = false", state)
_, _, err = ds.ValidateOAuthState(ctx, state) _, _, _, err = ds.ValidateOAuthState(ctx, state)
assert.NoError(t, err) assert.NoError(t, err)
countRows(t, ctx, db, 1, "SELECT COUNT(*) FROM `oauth_client_states` WHERE `state` = ? AND `used` = true", state) countRows(t, ctx, db, 1, "SELECT COUNT(*) FROM `oauth_client_states` WHERE `state` = ? AND `used` = true", state)

View File

@ -61,6 +61,7 @@ var migrations = []Migration{
New("support users suspension", supportUserStatus), // V2 -> V3 (v0.11.0) New("support users suspension", supportUserStatus), // V2 -> V3 (v0.11.0)
New("support oauth", oauth), // V3 -> V4 New("support oauth", oauth), // V3 -> V4
New("support slack oauth", oauthSlack), // V4 -> v5 New("support slack oauth", oauthSlack), // V4 -> v5
New("support oauth attach", oauthAttach), // V5 -> V6
} }
// CurrentVer returns the current migration version the application is on // CurrentVer returns the current migration version the application is on

36
migrations/v6.go Normal file
View File

@ -0,0 +1,36 @@
package migrations
import (
"context"
"database/sql"
wf_db "github.com/writeas/writefreely/db"
)
func oauthAttach(db *datastore) error {
dialect := wf_db.DialectMySQL
if db.driverName == driverSQLite {
dialect = wf_db.DialectSQLite
}
return wf_db.RunTransactionWithOptions(context.Background(), db.DB, &sql.TxOptions{}, func(ctx context.Context, tx *sql.Tx) error {
builders := []wf_db.SQLBuilder{
dialect.
AlterTable("oauth_client_states").
AddColumn(dialect.
Column(
"attach_user_id",
wf_db.ColumnTypeInteger,
wf_db.OptionalInt{Set: true, Value: 24,}).SetNullable(false).SetDefault("0")),
}
for _, builder := range builders {
query, err := builder.ToSQL()
if err != nil {
return err
}
if _, err := tx.ExecContext(ctx, query); err != nil {
return err
}
}
return nil
})
}

View File

@ -59,8 +59,8 @@ type OAuthDatastoreProvider interface {
type OAuthDatastore interface { type OAuthDatastore interface {
GetIDForRemoteUser(context.Context, string, string, string) (int64, error) GetIDForRemoteUser(context.Context, string, string, string) (int64, error)
RecordRemoteUserID(context.Context, int64, string, string, string, string) error RecordRemoteUserID(context.Context, int64, string, string, string, string) error
ValidateOAuthState(context.Context, string) (string, string, error) ValidateOAuthState(context.Context, string) (string, string, int64, error)
GenerateOAuthState(context.Context, string, string) (string, error) GenerateOAuthState(context.Context, string, string, int64) (string, error)
CreateUser(*config.Config, *User, string) error CreateUser(*config.Config, *User, string) error
GetUserByID(int64) (*User, error) GetUserByID(int64) (*User, error)
@ -96,19 +96,32 @@ type oauthHandler struct {
func (h oauthHandler) viewOauthInit(app *App, w http.ResponseWriter, r *http.Request) error { func (h oauthHandler) viewOauthInit(app *App, w http.ResponseWriter, r *http.Request) error {
ctx := r.Context() ctx := r.Context()
state, err := h.DB.GenerateOAuthState(ctx, h.oauthClient.GetProvider(), h.oauthClient.GetClientID())
var attachUser int64
if attach := r.URL.Query().Get("attach"); attach == "t" {
user, _ := getUserAndSession(app, r)
if user == nil {
return impart.HTTPError{http.StatusInternalServerError, "cannot attach auth to user: user not found in session"}
}
attachUser = user.ID
}
state, err := h.DB.GenerateOAuthState(ctx, h.oauthClient.GetProvider(), h.oauthClient.GetClientID(), attachUser)
if err != nil { if err != nil {
log.Error("viewOauthInit error: %s", err)
return impart.HTTPError{http.StatusInternalServerError, "could not prepare oauth redirect url"} return impart.HTTPError{http.StatusInternalServerError, "could not prepare oauth redirect url"}
} }
if h.callbackProxy != nil { if h.callbackProxy != nil {
if err := h.callbackProxy.register(ctx, state); err != nil { if err := h.callbackProxy.register(ctx, state); err != nil {
log.Error("viewOauthInit error: %s", err)
return impart.HTTPError{http.StatusInternalServerError, "could not register state server"} return impart.HTTPError{http.StatusInternalServerError, "could not register state server"}
} }
} }
location, err := h.oauthClient.buildLoginURL(state) location, err := h.oauthClient.buildLoginURL(state)
if err != nil { if err != nil {
log.Error("viewOauthInit error: %s", err)
return impart.HTTPError{http.StatusInternalServerError, "could not prepare oauth redirect url"} return impart.HTTPError{http.StatusInternalServerError, "could not prepare oauth redirect url"}
} }
return impart.HTTPError{http.StatusTemporaryRedirect, location} return impart.HTTPError{http.StatusTemporaryRedirect, location}
@ -185,7 +198,7 @@ func (h oauthHandler) viewOauthCallback(app *App, w http.ResponseWriter, r *http
code := r.FormValue("code") code := r.FormValue("code")
state := r.FormValue("state") state := r.FormValue("state")
provider, clientID, err := h.DB.ValidateOAuthState(ctx, state) provider, clientID, attachUserID, err := h.DB.ValidateOAuthState(ctx, state)
if err != nil { if err != nil {
log.Error("Unable to ValidateOAuthState: %s", err) log.Error("Unable to ValidateOAuthState: %s", err)
return impart.HTTPError{http.StatusInternalServerError, err.Error()} return impart.HTTPError{http.StatusInternalServerError, err.Error()}
@ -223,6 +236,14 @@ func (h oauthHandler) viewOauthCallback(app *App, w http.ResponseWriter, r *http
} }
return nil return nil
} }
if attachUserID > 0 {
log.Info("attaching to user %d", attachUserID)
err = h.DB.RecordRemoteUserID(r.Context(), attachUserID, tokenInfo.UserID, provider, clientID, tokenResponse.AccessToken)
if err != nil {
return impart.HTTPError{http.StatusInternalServerError, err.Error()}
}
return impart.HTTPError{http.StatusFound, "/me/settings"}
}
displayName := tokenInfo.DisplayName displayName := tokenInfo.DisplayName
if len(displayName) == 0 { if len(displayName) == 0 {

View File

@ -22,8 +22,8 @@ type MockOAuthDatastoreProvider struct {
} }
type MockOAuthDatastore struct { type MockOAuthDatastore struct {
DoGenerateOAuthState func(context.Context, string, string) (string, error) DoGenerateOAuthState func(context.Context, string, string, int64) (string, error)
DoValidateOAuthState func(context.Context, string) (string, string, error) DoValidateOAuthState func(context.Context, string) (string, string, int64, error)
DoGetIDForRemoteUser func(context.Context, string, string, string) (int64, error) DoGetIDForRemoteUser func(context.Context, string, string, string) (int64, error)
DoCreateUser func(*config.Config, *User, string) error DoCreateUser func(*config.Config, *User, string) error
DoRecordRemoteUserID func(context.Context, int64, string, string, string, string) error DoRecordRemoteUserID func(context.Context, int64, string, string, string, string) error
@ -86,11 +86,11 @@ func (m *MockOAuthDatastoreProvider) Config() *config.Config {
return cfg return cfg
} }
func (m *MockOAuthDatastore) ValidateOAuthState(ctx context.Context, state string) (string, string, error) { func (m *MockOAuthDatastore) ValidateOAuthState(ctx context.Context, state string) (string, string, int64, error) {
if m.DoValidateOAuthState != nil { if m.DoValidateOAuthState != nil {
return m.DoValidateOAuthState(ctx, state) return m.DoValidateOAuthState(ctx, state)
} }
return "", "", nil return "", "", 0, nil
} }
func (m *MockOAuthDatastore) GetIDForRemoteUser(ctx context.Context, remoteUserID, provider, clientID string) (int64, error) { func (m *MockOAuthDatastore) GetIDForRemoteUser(ctx context.Context, remoteUserID, provider, clientID string) (int64, error) {
@ -125,9 +125,9 @@ func (m *MockOAuthDatastore) GetUserByID(userID int64) (*User, error) {
return user, nil return user, nil
} }
func (m *MockOAuthDatastore) GenerateOAuthState(ctx context.Context, provider string, clientID string) (string, error) { func (m *MockOAuthDatastore) GenerateOAuthState(ctx context.Context, provider string, clientID string, attachUserID int64) (string, error) {
if m.DoGenerateOAuthState != nil { if m.DoGenerateOAuthState != nil {
return m.DoGenerateOAuthState(ctx, provider, clientID) return m.DoGenerateOAuthState(ctx, provider, clientID, attachUserID)
} }
return store.Generate62RandomString(14), nil return store.Generate62RandomString(14), nil
} }
@ -173,7 +173,7 @@ func TestViewOauthInit(t *testing.T) {
app := &MockOAuthDatastoreProvider{ app := &MockOAuthDatastoreProvider{
DoDB: func() OAuthDatastore { DoDB: func() OAuthDatastore {
return &MockOAuthDatastore{ return &MockOAuthDatastore{
DoGenerateOAuthState: func(ctx context.Context, provider, clientID string) (string, error) { DoGenerateOAuthState: func(ctx context.Context, provider, clientID string, attachUserID int64) (string, error) {
return "", fmt.Errorf("pretend unable to write state error") return "", fmt.Errorf("pretend unable to write state error")
}, },
} }

View File

@ -101,6 +101,7 @@ func InitRoutes(apper Apper, r *mux.Router) *mux.Router {
me.HandleFunc("/settings", handler.User(viewSettings)).Methods("GET") me.HandleFunc("/settings", handler.User(viewSettings)).Methods("GET")
me.HandleFunc("/invites", handler.User(handleViewUserInvites)).Methods("GET") me.HandleFunc("/invites", handler.User(handleViewUserInvites)).Methods("GET")
me.HandleFunc("/logout", handler.Web(viewLogout, UserLevelNone)).Methods("GET") me.HandleFunc("/logout", handler.Web(viewLogout, UserLevelNone)).Methods("GET")
me.HandleFunc("/oauth/remove", handler.User(removeOauth)).Methods("POST")
write.HandleFunc("/api/me", handler.All(viewMeAPI)).Methods("GET") write.HandleFunc("/api/me", handler.All(viewMeAPI)).Methods("GET")
apiMe := write.PathPrefix("/api/me/").Subrouter() apiMe := write.PathPrefix("/api/me/").Subrouter()

View File

@ -66,6 +66,28 @@ h3 { font-weight: normal; }
<input type="submit" value="Save changes" tabindex="4" /> <input type="submit" value="Save changes" tabindex="4" />
</div> </div>
</form> </form>
{{ if .OauthAccounts }}
{{ range $oauth_account := .OauthAccounts }}
<form method="post" action="/me/oauth/remove" autocomplete="false">
<input type="hidden" name="provider" value="{{ $oauth_account.Provider }}" />
<input type="hidden" name="client_id" value="{{ $oauth_account.ClientID }}" />
<input type="hidden" name="remote_user_id" value="{{ $oauth_account.RemoteUserID }}" />
<div class="option">
<h3>{{ $oauth_account.Provider }} </h3>
<div class="section">
<input type="submit" value="Remove" style="margin-left: 1em;" />
</div>
</div>
</form>
{{ end }}
{{ end }}
{{ if .OauthSlack }}
<a class="loginbtn" href="/oauth/slack?attach=t"><img alt="Sign in with Slack" height="40" width="172" src="/img/sign_in_with_slack.png" srcset="/img/sign_in_with_slack.png 1x, /img/sign_in_with_slack@2x.png 2x" /></a>
{{ end }}
{{ if .OauthWriteAs }}
<a class="btn cta loginbtn" id="writeas-login" href="/oauth/write.as?attach=t">Link your <strong>Write.as</strong> account.</a>
{{ end }}
</div> </div>
<script> <script>