mirror of
https://github.com/writeas/writefreely
synced 2025-02-09 13:08:41 +01:00
Implemented oauth attach functionality, oauth detach functionality, and required data migration. T713
This commit is contained in:
parent
75e2b60328
commit
c0317b4e93
43
account.go
43
account.go
@ -1038,18 +1038,30 @@ func viewSettings(app *App, u *User, w http.ResponseWriter, r *http.Request) err
|
||||
|
||||
flashes, _ := getSessionFlashes(app, w, r, nil)
|
||||
|
||||
oauthAccounts, err := app.db.GetOauthAccounts(r.Context(), u.ID)
|
||||
if err != nil {
|
||||
log.Error("Unable to get oauth accounts for settings: %s", err)
|
||||
return impart.HTTPError{http.StatusInternalServerError, "Unable to retrieve user data. The humans have been alerted."}
|
||||
}
|
||||
|
||||
obj := struct {
|
||||
*UserPage
|
||||
Email string
|
||||
HasPass bool
|
||||
IsLogOut bool
|
||||
Suspended bool
|
||||
Email string
|
||||
HasPass bool
|
||||
IsLogOut bool
|
||||
Suspended bool
|
||||
OauthAccounts []oauthAccountInfo
|
||||
OauthSlack bool
|
||||
OauthWriteAs bool
|
||||
}{
|
||||
UserPage: NewUserPage(app, r, u, "Account Settings", flashes),
|
||||
Email: fullUser.EmailClear(app.keys),
|
||||
HasPass: passIsSet,
|
||||
IsLogOut: r.FormValue("logout") == "1",
|
||||
Suspended: fullUser.IsSilenced(),
|
||||
UserPage: NewUserPage(app, r, u, "Account Settings", flashes),
|
||||
Email: fullUser.EmailClear(app.keys),
|
||||
HasPass: passIsSet,
|
||||
IsLogOut: r.FormValue("logout") == "1",
|
||||
Suspended: fullUser.IsSilenced(),
|
||||
OauthAccounts: oauthAccounts,
|
||||
OauthSlack: app.Config().SlackOauth.ClientID != "",
|
||||
OauthWriteAs: app.Config().WriteAsOauth.ClientID != "",
|
||||
}
|
||||
|
||||
showUserPage(w, "settings", obj)
|
||||
@ -1094,6 +1106,19 @@ func getTempInfo(app *App, key string, r *http.Request, w http.ResponseWriter) s
|
||||
return s
|
||||
}
|
||||
|
||||
func removeOauth(app *App, u *User, w http.ResponseWriter, r *http.Request) error {
|
||||
provider := r.FormValue("provider")
|
||||
clientID := r.FormValue("client_id")
|
||||
remoteUserID := r.FormValue("remote_user_id")
|
||||
|
||||
err := app.db.RemoveOauth(r.Context(), u.ID, provider, clientID, remoteUserID)
|
||||
if err != nil {
|
||||
return impart.HTTPError{Status: http.StatusInternalServerError, Message: err.Error()}
|
||||
}
|
||||
|
||||
return impart.HTTPError{Status: http.StatusFound, Message: "/me/settings"}
|
||||
}
|
||||
|
||||
func prepareUserEmail(input string, emailKey []byte) zero.String {
|
||||
email := zero.NewString("", input != "")
|
||||
if len(input) > 0 {
|
||||
|
53
database.go
53
database.go
@ -128,8 +128,10 @@ type writestore interface {
|
||||
|
||||
GetIDForRemoteUser(context.Context, string, string, string) (int64, error)
|
||||
RecordRemoteUserID(context.Context, int64, string, string, string, string) error
|
||||
ValidateOAuthState(context.Context, string) (string, string, error)
|
||||
GenerateOAuthState(context.Context, string, string) (string, error)
|
||||
ValidateOAuthState(context.Context, string) (string, string, int64, error)
|
||||
GenerateOAuthState(context.Context, string, string, int64) (string, error)
|
||||
GetOauthAccounts(ctx context.Context, userID int64) ([]oauthAccountInfo, error)
|
||||
RemoveOauth(ctx context.Context, userID int64, provider string, clientID string, remoteUserID string) error
|
||||
|
||||
DatabaseInitialized() bool
|
||||
}
|
||||
@ -2462,20 +2464,23 @@ func (db *datastore) GetCollectionLastPostTime(id int64) (*time.Time, error) {
|
||||
return &t, nil
|
||||
}
|
||||
|
||||
func (db *datastore) GenerateOAuthState(ctx context.Context, provider, clientID string) (string, error) {
|
||||
func (db *datastore) GenerateOAuthState(ctx context.Context, provider string, clientID string, attachUser int64) (string, error) {
|
||||
state := store.Generate62RandomString(24)
|
||||
_, err := db.ExecContext(ctx, "INSERT INTO oauth_client_states (state, provider, client_id, used, created_at) VALUES (?, ?, ?, FALSE, NOW())", state, provider, clientID)
|
||||
_, err := db.ExecContext(ctx, "INSERT INTO oauth_client_states (state, provider, client_id, used, created_at, attach_user_id) VALUES (?, ?, ?, FALSE, NOW(), ?)", state, provider, clientID, attachUser)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("unable to record oauth client state: %w", err)
|
||||
}
|
||||
return state, nil
|
||||
}
|
||||
|
||||
func (db *datastore) ValidateOAuthState(ctx context.Context, state string) (string, string, error) {
|
||||
func (db *datastore) ValidateOAuthState(ctx context.Context, state string) (string, string, int64, error) {
|
||||
var provider string
|
||||
var clientID string
|
||||
var attachUserID int64
|
||||
err := wf_db.RunTransactionWithOptions(ctx, db.DB, &sql.TxOptions{}, func(ctx context.Context, tx *sql.Tx) error {
|
||||
err := tx.QueryRow("SELECT provider, client_id FROM oauth_client_states WHERE state = ? AND used = FALSE", state).Scan(&provider, &clientID)
|
||||
err := tx.
|
||||
QueryRowContext(ctx, "SELECT provider, client_id, attach_user_id FROM oauth_client_states WHERE state = ? AND used = FALSE", state).
|
||||
Scan(&provider, &clientID, &attachUserID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -2494,9 +2499,9 @@ func (db *datastore) ValidateOAuthState(ctx context.Context, state string) (stri
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return "", "", nil
|
||||
return "", "", 0, nil
|
||||
}
|
||||
return provider, clientID, nil
|
||||
return provider, clientID, attachUserID, nil
|
||||
}
|
||||
|
||||
func (db *datastore) RecordRemoteUserID(ctx context.Context, localUserID int64, remoteUserID, provider, clientID, accessToken string) error {
|
||||
@ -2525,6 +2530,33 @@ func (db *datastore) GetIDForRemoteUser(ctx context.Context, remoteUserID, provi
|
||||
return userID, nil
|
||||
}
|
||||
|
||||
type oauthAccountInfo struct {
|
||||
Provider string
|
||||
ClientID string
|
||||
RemoteUserID string
|
||||
}
|
||||
|
||||
func (db *datastore) GetOauthAccounts(ctx context.Context, userID int64) ([]oauthAccountInfo, error) {
|
||||
rows, err := db.QueryContext(ctx, "SELECT provider, client_id, remote_user_id FROM oauth_users WHERE user_id = ? ", userID)
|
||||
if err != nil {
|
||||
log.Error("Failed selecting from oauth_users: %v", err)
|
||||
return nil, impart.HTTPError{http.StatusInternalServerError, "Couldn't retrieve user oauth accounts."}
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var records []oauthAccountInfo
|
||||
for rows.Next() {
|
||||
info := oauthAccountInfo{}
|
||||
err = rows.Scan(&info.Provider, &info.ClientID, &info.RemoteUserID)
|
||||
if err != nil {
|
||||
log.Error("Failed scanning GetAllUsers() row: %v", err)
|
||||
break
|
||||
}
|
||||
records = append(records, info)
|
||||
}
|
||||
return records, nil
|
||||
}
|
||||
|
||||
// DatabaseInitialized returns whether or not the current datastore has been
|
||||
// initialized with the correct schema.
|
||||
// Currently, it checks to see if the `users` table exists.
|
||||
@ -2547,6 +2579,11 @@ func (db *datastore) DatabaseInitialized() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (db *datastore) RemoveOauth(ctx context.Context, userID int64, provider string, clientID string, remoteUserID string) error {
|
||||
_, err := db.ExecContext(ctx, `DELETE FROM oauth_users WHERE user_id = ? AND provider = ? AND client_id = ? AND remote_user_id = ?`, userID, provider, clientID, remoteUserID)
|
||||
return err
|
||||
}
|
||||
|
||||
func stringLogln(log *string, s string, v ...interface{}) {
|
||||
*log += fmt.Sprintf(s+"\n", v...)
|
||||
}
|
||||
|
@ -18,13 +18,13 @@ func TestOAuthDatastore(t *testing.T) {
|
||||
driverName: "",
|
||||
}
|
||||
|
||||
state, err := ds.GenerateOAuthState(ctx, "test", "development")
|
||||
state, err := ds.GenerateOAuthState(ctx, "test", "development", 0)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, state, 24)
|
||||
|
||||
countRows(t, ctx, db, 1, "SELECT COUNT(*) FROM `oauth_client_states` WHERE `state` = ? AND `used` = false", state)
|
||||
|
||||
_, _, err = ds.ValidateOAuthState(ctx, state)
|
||||
_, _, _, err = ds.ValidateOAuthState(ctx, state)
|
||||
assert.NoError(t, err)
|
||||
|
||||
countRows(t, ctx, db, 1, "SELECT COUNT(*) FROM `oauth_client_states` WHERE `state` = ? AND `used` = true", state)
|
||||
|
@ -61,6 +61,7 @@ var migrations = []Migration{
|
||||
New("support users suspension", supportUserStatus), // V2 -> V3 (v0.11.0)
|
||||
New("support oauth", oauth), // V3 -> V4
|
||||
New("support slack oauth", oauthSlack), // V4 -> v5
|
||||
New("support oauth attach", oauthAttach), // V5 -> V6
|
||||
}
|
||||
|
||||
// CurrentVer returns the current migration version the application is on
|
||||
|
36
migrations/v6.go
Normal file
36
migrations/v6.go
Normal file
@ -0,0 +1,36 @@
|
||||
package migrations
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
wf_db "github.com/writeas/writefreely/db"
|
||||
)
|
||||
|
||||
func oauthAttach(db *datastore) error {
|
||||
dialect := wf_db.DialectMySQL
|
||||
if db.driverName == driverSQLite {
|
||||
dialect = wf_db.DialectSQLite
|
||||
}
|
||||
return wf_db.RunTransactionWithOptions(context.Background(), db.DB, &sql.TxOptions{}, func(ctx context.Context, tx *sql.Tx) error {
|
||||
builders := []wf_db.SQLBuilder{
|
||||
dialect.
|
||||
AlterTable("oauth_client_states").
|
||||
AddColumn(dialect.
|
||||
Column(
|
||||
"attach_user_id",
|
||||
wf_db.ColumnTypeInteger,
|
||||
wf_db.OptionalInt{Set: true, Value: 24,}).SetNullable(false).SetDefault("0")),
|
||||
}
|
||||
for _, builder := range builders {
|
||||
query, err := builder.ToSQL()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.ExecContext(ctx, query); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
29
oauth.go
29
oauth.go
@ -59,8 +59,8 @@ type OAuthDatastoreProvider interface {
|
||||
type OAuthDatastore interface {
|
||||
GetIDForRemoteUser(context.Context, string, string, string) (int64, error)
|
||||
RecordRemoteUserID(context.Context, int64, string, string, string, string) error
|
||||
ValidateOAuthState(context.Context, string) (string, string, error)
|
||||
GenerateOAuthState(context.Context, string, string) (string, error)
|
||||
ValidateOAuthState(context.Context, string) (string, string, int64, error)
|
||||
GenerateOAuthState(context.Context, string, string, int64) (string, error)
|
||||
|
||||
CreateUser(*config.Config, *User, string) error
|
||||
GetUserByID(int64) (*User, error)
|
||||
@ -96,19 +96,32 @@ type oauthHandler struct {
|
||||
|
||||
func (h oauthHandler) viewOauthInit(app *App, w http.ResponseWriter, r *http.Request) error {
|
||||
ctx := r.Context()
|
||||
state, err := h.DB.GenerateOAuthState(ctx, h.oauthClient.GetProvider(), h.oauthClient.GetClientID())
|
||||
|
||||
var attachUser int64
|
||||
if attach := r.URL.Query().Get("attach"); attach == "t" {
|
||||
user, _ := getUserAndSession(app, r)
|
||||
if user == nil {
|
||||
return impart.HTTPError{http.StatusInternalServerError, "cannot attach auth to user: user not found in session"}
|
||||
}
|
||||
attachUser = user.ID
|
||||
}
|
||||
|
||||
state, err := h.DB.GenerateOAuthState(ctx, h.oauthClient.GetProvider(), h.oauthClient.GetClientID(), attachUser)
|
||||
if err != nil {
|
||||
log.Error("viewOauthInit error: %s", err)
|
||||
return impart.HTTPError{http.StatusInternalServerError, "could not prepare oauth redirect url"}
|
||||
}
|
||||
|
||||
if h.callbackProxy != nil {
|
||||
if err := h.callbackProxy.register(ctx, state); err != nil {
|
||||
log.Error("viewOauthInit error: %s", err)
|
||||
return impart.HTTPError{http.StatusInternalServerError, "could not register state server"}
|
||||
}
|
||||
}
|
||||
|
||||
location, err := h.oauthClient.buildLoginURL(state)
|
||||
if err != nil {
|
||||
log.Error("viewOauthInit error: %s", err)
|
||||
return impart.HTTPError{http.StatusInternalServerError, "could not prepare oauth redirect url"}
|
||||
}
|
||||
return impart.HTTPError{http.StatusTemporaryRedirect, location}
|
||||
@ -185,7 +198,7 @@ func (h oauthHandler) viewOauthCallback(app *App, w http.ResponseWriter, r *http
|
||||
code := r.FormValue("code")
|
||||
state := r.FormValue("state")
|
||||
|
||||
provider, clientID, err := h.DB.ValidateOAuthState(ctx, state)
|
||||
provider, clientID, attachUserID, err := h.DB.ValidateOAuthState(ctx, state)
|
||||
if err != nil {
|
||||
log.Error("Unable to ValidateOAuthState: %s", err)
|
||||
return impart.HTTPError{http.StatusInternalServerError, err.Error()}
|
||||
@ -223,6 +236,14 @@ func (h oauthHandler) viewOauthCallback(app *App, w http.ResponseWriter, r *http
|
||||
}
|
||||
return nil
|
||||
}
|
||||
if attachUserID > 0 {
|
||||
log.Info("attaching to user %d", attachUserID)
|
||||
err = h.DB.RecordRemoteUserID(r.Context(), attachUserID, tokenInfo.UserID, provider, clientID, tokenResponse.AccessToken)
|
||||
if err != nil {
|
||||
return impart.HTTPError{http.StatusInternalServerError, err.Error()}
|
||||
}
|
||||
return impart.HTTPError{http.StatusFound, "/me/settings"}
|
||||
}
|
||||
|
||||
displayName := tokenInfo.DisplayName
|
||||
if len(displayName) == 0 {
|
||||
|
@ -22,8 +22,8 @@ type MockOAuthDatastoreProvider struct {
|
||||
}
|
||||
|
||||
type MockOAuthDatastore struct {
|
||||
DoGenerateOAuthState func(context.Context, string, string) (string, error)
|
||||
DoValidateOAuthState func(context.Context, string) (string, string, error)
|
||||
DoGenerateOAuthState func(context.Context, string, string, int64) (string, error)
|
||||
DoValidateOAuthState func(context.Context, string) (string, string, int64, error)
|
||||
DoGetIDForRemoteUser func(context.Context, string, string, string) (int64, error)
|
||||
DoCreateUser func(*config.Config, *User, string) error
|
||||
DoRecordRemoteUserID func(context.Context, int64, string, string, string, string) error
|
||||
@ -86,11 +86,11 @@ func (m *MockOAuthDatastoreProvider) Config() *config.Config {
|
||||
return cfg
|
||||
}
|
||||
|
||||
func (m *MockOAuthDatastore) ValidateOAuthState(ctx context.Context, state string) (string, string, error) {
|
||||
func (m *MockOAuthDatastore) ValidateOAuthState(ctx context.Context, state string) (string, string, int64, error) {
|
||||
if m.DoValidateOAuthState != nil {
|
||||
return m.DoValidateOAuthState(ctx, state)
|
||||
}
|
||||
return "", "", nil
|
||||
return "", "", 0, nil
|
||||
}
|
||||
|
||||
func (m *MockOAuthDatastore) GetIDForRemoteUser(ctx context.Context, remoteUserID, provider, clientID string) (int64, error) {
|
||||
@ -125,9 +125,9 @@ func (m *MockOAuthDatastore) GetUserByID(userID int64) (*User, error) {
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (m *MockOAuthDatastore) GenerateOAuthState(ctx context.Context, provider string, clientID string) (string, error) {
|
||||
func (m *MockOAuthDatastore) GenerateOAuthState(ctx context.Context, provider string, clientID string, attachUserID int64) (string, error) {
|
||||
if m.DoGenerateOAuthState != nil {
|
||||
return m.DoGenerateOAuthState(ctx, provider, clientID)
|
||||
return m.DoGenerateOAuthState(ctx, provider, clientID, attachUserID)
|
||||
}
|
||||
return store.Generate62RandomString(14), nil
|
||||
}
|
||||
@ -173,7 +173,7 @@ func TestViewOauthInit(t *testing.T) {
|
||||
app := &MockOAuthDatastoreProvider{
|
||||
DoDB: func() OAuthDatastore {
|
||||
return &MockOAuthDatastore{
|
||||
DoGenerateOAuthState: func(ctx context.Context, provider, clientID string) (string, error) {
|
||||
DoGenerateOAuthState: func(ctx context.Context, provider, clientID string, attachUserID int64) (string, error) {
|
||||
return "", fmt.Errorf("pretend unable to write state error")
|
||||
},
|
||||
}
|
||||
|
@ -101,6 +101,7 @@ func InitRoutes(apper Apper, r *mux.Router) *mux.Router {
|
||||
me.HandleFunc("/settings", handler.User(viewSettings)).Methods("GET")
|
||||
me.HandleFunc("/invites", handler.User(handleViewUserInvites)).Methods("GET")
|
||||
me.HandleFunc("/logout", handler.Web(viewLogout, UserLevelNone)).Methods("GET")
|
||||
me.HandleFunc("/oauth/remove", handler.User(removeOauth)).Methods("POST")
|
||||
|
||||
write.HandleFunc("/api/me", handler.All(viewMeAPI)).Methods("GET")
|
||||
apiMe := write.PathPrefix("/api/me/").Subrouter()
|
||||
|
@ -66,6 +66,28 @@ h3 { font-weight: normal; }
|
||||
<input type="submit" value="Save changes" tabindex="4" />
|
||||
</div>
|
||||
</form>
|
||||
|
||||
{{ if .OauthAccounts }}
|
||||
{{ range $oauth_account := .OauthAccounts }}
|
||||
<form method="post" action="/me/oauth/remove" autocomplete="false">
|
||||
<input type="hidden" name="provider" value="{{ $oauth_account.Provider }}" />
|
||||
<input type="hidden" name="client_id" value="{{ $oauth_account.ClientID }}" />
|
||||
<input type="hidden" name="remote_user_id" value="{{ $oauth_account.RemoteUserID }}" />
|
||||
<div class="option">
|
||||
<h3>{{ $oauth_account.Provider }} </h3>
|
||||
<div class="section">
|
||||
<input type="submit" value="Remove" style="margin-left: 1em;" />
|
||||
</div>
|
||||
</div>
|
||||
</form>
|
||||
{{ end }}
|
||||
{{ end }}
|
||||
{{ if .OauthSlack }}
|
||||
<a class="loginbtn" href="/oauth/slack?attach=t"><img alt="Sign in with Slack" height="40" width="172" src="/img/sign_in_with_slack.png" srcset="/img/sign_in_with_slack.png 1x, /img/sign_in_with_slack@2x.png 2x" /></a>
|
||||
{{ end }}
|
||||
{{ if .OauthWriteAs }}
|
||||
<a class="btn cta loginbtn" id="writeas-login" href="/oauth/write.as?attach=t">Link your <strong>Write.as</strong> account.</a>
|
||||
{{ end }}
|
||||
</div>
|
||||
|
||||
<script>
|
||||
|
Loading…
x
Reference in New Issue
Block a user