Implemented oauth attach functionality, oauth detach functionality, and required data migration. T713

This commit is contained in:
Nick Gerakines 2020-01-15 13:16:59 -05:00
parent 75e2b60328
commit c0317b4e93
9 changed files with 173 additions and 30 deletions

View File

@ -1038,18 +1038,30 @@ func viewSettings(app *App, u *User, w http.ResponseWriter, r *http.Request) err
flashes, _ := getSessionFlashes(app, w, r, nil)
oauthAccounts, err := app.db.GetOauthAccounts(r.Context(), u.ID)
if err != nil {
log.Error("Unable to get oauth accounts for settings: %s", err)
return impart.HTTPError{http.StatusInternalServerError, "Unable to retrieve user data. The humans have been alerted."}
}
obj := struct {
*UserPage
Email string
HasPass bool
IsLogOut bool
Suspended bool
Email string
HasPass bool
IsLogOut bool
Suspended bool
OauthAccounts []oauthAccountInfo
OauthSlack bool
OauthWriteAs bool
}{
UserPage: NewUserPage(app, r, u, "Account Settings", flashes),
Email: fullUser.EmailClear(app.keys),
HasPass: passIsSet,
IsLogOut: r.FormValue("logout") == "1",
Suspended: fullUser.IsSilenced(),
UserPage: NewUserPage(app, r, u, "Account Settings", flashes),
Email: fullUser.EmailClear(app.keys),
HasPass: passIsSet,
IsLogOut: r.FormValue("logout") == "1",
Suspended: fullUser.IsSilenced(),
OauthAccounts: oauthAccounts,
OauthSlack: app.Config().SlackOauth.ClientID != "",
OauthWriteAs: app.Config().WriteAsOauth.ClientID != "",
}
showUserPage(w, "settings", obj)
@ -1094,6 +1106,19 @@ func getTempInfo(app *App, key string, r *http.Request, w http.ResponseWriter) s
return s
}
func removeOauth(app *App, u *User, w http.ResponseWriter, r *http.Request) error {
provider := r.FormValue("provider")
clientID := r.FormValue("client_id")
remoteUserID := r.FormValue("remote_user_id")
err := app.db.RemoveOauth(r.Context(), u.ID, provider, clientID, remoteUserID)
if err != nil {
return impart.HTTPError{Status: http.StatusInternalServerError, Message: err.Error()}
}
return impart.HTTPError{Status: http.StatusFound, Message: "/me/settings"}
}
func prepareUserEmail(input string, emailKey []byte) zero.String {
email := zero.NewString("", input != "")
if len(input) > 0 {

View File

@ -128,8 +128,10 @@ type writestore interface {
GetIDForRemoteUser(context.Context, string, string, string) (int64, error)
RecordRemoteUserID(context.Context, int64, string, string, string, string) error
ValidateOAuthState(context.Context, string) (string, string, error)
GenerateOAuthState(context.Context, string, string) (string, error)
ValidateOAuthState(context.Context, string) (string, string, int64, error)
GenerateOAuthState(context.Context, string, string, int64) (string, error)
GetOauthAccounts(ctx context.Context, userID int64) ([]oauthAccountInfo, error)
RemoveOauth(ctx context.Context, userID int64, provider string, clientID string, remoteUserID string) error
DatabaseInitialized() bool
}
@ -2462,20 +2464,23 @@ func (db *datastore) GetCollectionLastPostTime(id int64) (*time.Time, error) {
return &t, nil
}
func (db *datastore) GenerateOAuthState(ctx context.Context, provider, clientID string) (string, error) {
func (db *datastore) GenerateOAuthState(ctx context.Context, provider string, clientID string, attachUser int64) (string, error) {
state := store.Generate62RandomString(24)
_, err := db.ExecContext(ctx, "INSERT INTO oauth_client_states (state, provider, client_id, used, created_at) VALUES (?, ?, ?, FALSE, NOW())", state, provider, clientID)
_, err := db.ExecContext(ctx, "INSERT INTO oauth_client_states (state, provider, client_id, used, created_at, attach_user_id) VALUES (?, ?, ?, FALSE, NOW(), ?)", state, provider, clientID, attachUser)
if err != nil {
return "", fmt.Errorf("unable to record oauth client state: %w", err)
}
return state, nil
}
func (db *datastore) ValidateOAuthState(ctx context.Context, state string) (string, string, error) {
func (db *datastore) ValidateOAuthState(ctx context.Context, state string) (string, string, int64, error) {
var provider string
var clientID string
var attachUserID int64
err := wf_db.RunTransactionWithOptions(ctx, db.DB, &sql.TxOptions{}, func(ctx context.Context, tx *sql.Tx) error {
err := tx.QueryRow("SELECT provider, client_id FROM oauth_client_states WHERE state = ? AND used = FALSE", state).Scan(&provider, &clientID)
err := tx.
QueryRowContext(ctx, "SELECT provider, client_id, attach_user_id FROM oauth_client_states WHERE state = ? AND used = FALSE", state).
Scan(&provider, &clientID, &attachUserID)
if err != nil {
return err
}
@ -2494,9 +2499,9 @@ func (db *datastore) ValidateOAuthState(ctx context.Context, state string) (stri
return nil
})
if err != nil {
return "", "", nil
return "", "", 0, nil
}
return provider, clientID, nil
return provider, clientID, attachUserID, nil
}
func (db *datastore) RecordRemoteUserID(ctx context.Context, localUserID int64, remoteUserID, provider, clientID, accessToken string) error {
@ -2525,6 +2530,33 @@ func (db *datastore) GetIDForRemoteUser(ctx context.Context, remoteUserID, provi
return userID, nil
}
type oauthAccountInfo struct {
Provider string
ClientID string
RemoteUserID string
}
func (db *datastore) GetOauthAccounts(ctx context.Context, userID int64) ([]oauthAccountInfo, error) {
rows, err := db.QueryContext(ctx, "SELECT provider, client_id, remote_user_id FROM oauth_users WHERE user_id = ? ", userID)
if err != nil {
log.Error("Failed selecting from oauth_users: %v", err)
return nil, impart.HTTPError{http.StatusInternalServerError, "Couldn't retrieve user oauth accounts."}
}
defer rows.Close()
var records []oauthAccountInfo
for rows.Next() {
info := oauthAccountInfo{}
err = rows.Scan(&info.Provider, &info.ClientID, &info.RemoteUserID)
if err != nil {
log.Error("Failed scanning GetAllUsers() row: %v", err)
break
}
records = append(records, info)
}
return records, nil
}
// DatabaseInitialized returns whether or not the current datastore has been
// initialized with the correct schema.
// Currently, it checks to see if the `users` table exists.
@ -2547,6 +2579,11 @@ func (db *datastore) DatabaseInitialized() bool {
return true
}
func (db *datastore) RemoveOauth(ctx context.Context, userID int64, provider string, clientID string, remoteUserID string) error {
_, err := db.ExecContext(ctx, `DELETE FROM oauth_users WHERE user_id = ? AND provider = ? AND client_id = ? AND remote_user_id = ?`, userID, provider, clientID, remoteUserID)
return err
}
func stringLogln(log *string, s string, v ...interface{}) {
*log += fmt.Sprintf(s+"\n", v...)
}

View File

@ -18,13 +18,13 @@ func TestOAuthDatastore(t *testing.T) {
driverName: "",
}
state, err := ds.GenerateOAuthState(ctx, "test", "development")
state, err := ds.GenerateOAuthState(ctx, "test", "development", 0)
assert.NoError(t, err)
assert.Len(t, state, 24)
countRows(t, ctx, db, 1, "SELECT COUNT(*) FROM `oauth_client_states` WHERE `state` = ? AND `used` = false", state)
_, _, err = ds.ValidateOAuthState(ctx, state)
_, _, _, err = ds.ValidateOAuthState(ctx, state)
assert.NoError(t, err)
countRows(t, ctx, db, 1, "SELECT COUNT(*) FROM `oauth_client_states` WHERE `state` = ? AND `used` = true", state)

View File

@ -61,6 +61,7 @@ var migrations = []Migration{
New("support users suspension", supportUserStatus), // V2 -> V3 (v0.11.0)
New("support oauth", oauth), // V3 -> V4
New("support slack oauth", oauthSlack), // V4 -> v5
New("support oauth attach", oauthAttach), // V5 -> V6
}
// CurrentVer returns the current migration version the application is on

36
migrations/v6.go Normal file
View File

@ -0,0 +1,36 @@
package migrations
import (
"context"
"database/sql"
wf_db "github.com/writeas/writefreely/db"
)
func oauthAttach(db *datastore) error {
dialect := wf_db.DialectMySQL
if db.driverName == driverSQLite {
dialect = wf_db.DialectSQLite
}
return wf_db.RunTransactionWithOptions(context.Background(), db.DB, &sql.TxOptions{}, func(ctx context.Context, tx *sql.Tx) error {
builders := []wf_db.SQLBuilder{
dialect.
AlterTable("oauth_client_states").
AddColumn(dialect.
Column(
"attach_user_id",
wf_db.ColumnTypeInteger,
wf_db.OptionalInt{Set: true, Value: 24,}).SetNullable(false).SetDefault("0")),
}
for _, builder := range builders {
query, err := builder.ToSQL()
if err != nil {
return err
}
if _, err := tx.ExecContext(ctx, query); err != nil {
return err
}
}
return nil
})
}

View File

@ -59,8 +59,8 @@ type OAuthDatastoreProvider interface {
type OAuthDatastore interface {
GetIDForRemoteUser(context.Context, string, string, string) (int64, error)
RecordRemoteUserID(context.Context, int64, string, string, string, string) error
ValidateOAuthState(context.Context, string) (string, string, error)
GenerateOAuthState(context.Context, string, string) (string, error)
ValidateOAuthState(context.Context, string) (string, string, int64, error)
GenerateOAuthState(context.Context, string, string, int64) (string, error)
CreateUser(*config.Config, *User, string) error
GetUserByID(int64) (*User, error)
@ -96,19 +96,32 @@ type oauthHandler struct {
func (h oauthHandler) viewOauthInit(app *App, w http.ResponseWriter, r *http.Request) error {
ctx := r.Context()
state, err := h.DB.GenerateOAuthState(ctx, h.oauthClient.GetProvider(), h.oauthClient.GetClientID())
var attachUser int64
if attach := r.URL.Query().Get("attach"); attach == "t" {
user, _ := getUserAndSession(app, r)
if user == nil {
return impart.HTTPError{http.StatusInternalServerError, "cannot attach auth to user: user not found in session"}
}
attachUser = user.ID
}
state, err := h.DB.GenerateOAuthState(ctx, h.oauthClient.GetProvider(), h.oauthClient.GetClientID(), attachUser)
if err != nil {
log.Error("viewOauthInit error: %s", err)
return impart.HTTPError{http.StatusInternalServerError, "could not prepare oauth redirect url"}
}
if h.callbackProxy != nil {
if err := h.callbackProxy.register(ctx, state); err != nil {
log.Error("viewOauthInit error: %s", err)
return impart.HTTPError{http.StatusInternalServerError, "could not register state server"}
}
}
location, err := h.oauthClient.buildLoginURL(state)
if err != nil {
log.Error("viewOauthInit error: %s", err)
return impart.HTTPError{http.StatusInternalServerError, "could not prepare oauth redirect url"}
}
return impart.HTTPError{http.StatusTemporaryRedirect, location}
@ -185,7 +198,7 @@ func (h oauthHandler) viewOauthCallback(app *App, w http.ResponseWriter, r *http
code := r.FormValue("code")
state := r.FormValue("state")
provider, clientID, err := h.DB.ValidateOAuthState(ctx, state)
provider, clientID, attachUserID, err := h.DB.ValidateOAuthState(ctx, state)
if err != nil {
log.Error("Unable to ValidateOAuthState: %s", err)
return impart.HTTPError{http.StatusInternalServerError, err.Error()}
@ -223,6 +236,14 @@ func (h oauthHandler) viewOauthCallback(app *App, w http.ResponseWriter, r *http
}
return nil
}
if attachUserID > 0 {
log.Info("attaching to user %d", attachUserID)
err = h.DB.RecordRemoteUserID(r.Context(), attachUserID, tokenInfo.UserID, provider, clientID, tokenResponse.AccessToken)
if err != nil {
return impart.HTTPError{http.StatusInternalServerError, err.Error()}
}
return impart.HTTPError{http.StatusFound, "/me/settings"}
}
displayName := tokenInfo.DisplayName
if len(displayName) == 0 {

View File

@ -22,8 +22,8 @@ type MockOAuthDatastoreProvider struct {
}
type MockOAuthDatastore struct {
DoGenerateOAuthState func(context.Context, string, string) (string, error)
DoValidateOAuthState func(context.Context, string) (string, string, error)
DoGenerateOAuthState func(context.Context, string, string, int64) (string, error)
DoValidateOAuthState func(context.Context, string) (string, string, int64, error)
DoGetIDForRemoteUser func(context.Context, string, string, string) (int64, error)
DoCreateUser func(*config.Config, *User, string) error
DoRecordRemoteUserID func(context.Context, int64, string, string, string, string) error
@ -86,11 +86,11 @@ func (m *MockOAuthDatastoreProvider) Config() *config.Config {
return cfg
}
func (m *MockOAuthDatastore) ValidateOAuthState(ctx context.Context, state string) (string, string, error) {
func (m *MockOAuthDatastore) ValidateOAuthState(ctx context.Context, state string) (string, string, int64, error) {
if m.DoValidateOAuthState != nil {
return m.DoValidateOAuthState(ctx, state)
}
return "", "", nil
return "", "", 0, nil
}
func (m *MockOAuthDatastore) GetIDForRemoteUser(ctx context.Context, remoteUserID, provider, clientID string) (int64, error) {
@ -125,9 +125,9 @@ func (m *MockOAuthDatastore) GetUserByID(userID int64) (*User, error) {
return user, nil
}
func (m *MockOAuthDatastore) GenerateOAuthState(ctx context.Context, provider string, clientID string) (string, error) {
func (m *MockOAuthDatastore) GenerateOAuthState(ctx context.Context, provider string, clientID string, attachUserID int64) (string, error) {
if m.DoGenerateOAuthState != nil {
return m.DoGenerateOAuthState(ctx, provider, clientID)
return m.DoGenerateOAuthState(ctx, provider, clientID, attachUserID)
}
return store.Generate62RandomString(14), nil
}
@ -173,7 +173,7 @@ func TestViewOauthInit(t *testing.T) {
app := &MockOAuthDatastoreProvider{
DoDB: func() OAuthDatastore {
return &MockOAuthDatastore{
DoGenerateOAuthState: func(ctx context.Context, provider, clientID string) (string, error) {
DoGenerateOAuthState: func(ctx context.Context, provider, clientID string, attachUserID int64) (string, error) {
return "", fmt.Errorf("pretend unable to write state error")
},
}

View File

@ -101,6 +101,7 @@ func InitRoutes(apper Apper, r *mux.Router) *mux.Router {
me.HandleFunc("/settings", handler.User(viewSettings)).Methods("GET")
me.HandleFunc("/invites", handler.User(handleViewUserInvites)).Methods("GET")
me.HandleFunc("/logout", handler.Web(viewLogout, UserLevelNone)).Methods("GET")
me.HandleFunc("/oauth/remove", handler.User(removeOauth)).Methods("POST")
write.HandleFunc("/api/me", handler.All(viewMeAPI)).Methods("GET")
apiMe := write.PathPrefix("/api/me/").Subrouter()

View File

@ -66,6 +66,28 @@ h3 { font-weight: normal; }
<input type="submit" value="Save changes" tabindex="4" />
</div>
</form>
{{ if .OauthAccounts }}
{{ range $oauth_account := .OauthAccounts }}
<form method="post" action="/me/oauth/remove" autocomplete="false">
<input type="hidden" name="provider" value="{{ $oauth_account.Provider }}" />
<input type="hidden" name="client_id" value="{{ $oauth_account.ClientID }}" />
<input type="hidden" name="remote_user_id" value="{{ $oauth_account.RemoteUserID }}" />
<div class="option">
<h3>{{ $oauth_account.Provider }} </h3>
<div class="section">
<input type="submit" value="Remove" style="margin-left: 1em;" />
</div>
</div>
</form>
{{ end }}
{{ end }}
{{ if .OauthSlack }}
<a class="loginbtn" href="/oauth/slack?attach=t"><img alt="Sign in with Slack" height="40" width="172" src="/img/sign_in_with_slack.png" srcset="/img/sign_in_with_slack.png 1x, /img/sign_in_with_slack@2x.png 2x" /></a>
{{ end }}
{{ if .OauthWriteAs }}
<a class="btn cta loginbtn" id="writeas-login" href="/oauth/write.as?attach=t">Link your <strong>Write.as</strong> account.</a>
{{ end }}
</div>
<script>