From b5f716135b9b28cdaff706bae670571ed010b9ac Mon Sep 17 00:00:00 2001 From: Nick Gerakines Date: Tue, 31 Dec 2019 11:28:05 -0500 Subject: [PATCH] Changed oauth table names per PR feedback. T705 --- database.go | 12 ++++++------ database_test.go | 6 +++--- migrations/v4.go | 4 ++-- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/database.go b/database.go index 56035dd..ca62d3c 100644 --- a/database.go +++ b/database.go @@ -2461,7 +2461,7 @@ func (db *datastore) GetCollectionLastPostTime(id int64) (*time.Time, error) { func (db *datastore) GenerateOAuthState(ctx context.Context) (string, error) { state := store.Generate62RandomString(24) - _, err := db.ExecContext(ctx, "INSERT INTO oauth_client_state (state, used, created_at) VALUES (?, FALSE, NOW())", state) + _, err := db.ExecContext(ctx, "INSERT INTO oauth_client_states (state, used, created_at) VALUES (?, FALSE, NOW())", state) if err != nil { return "", fmt.Errorf("unable to record oauth client state: %w", err) } @@ -2469,7 +2469,7 @@ func (db *datastore) GenerateOAuthState(ctx context.Context) (string, error) { } func (db *datastore) ValidateOAuthState(ctx context.Context, state string) error { - res, err := db.ExecContext(ctx, "UPDATE oauth_client_state SET used = TRUE WHERE state = ?", state) + res, err := db.ExecContext(ctx, "UPDATE oauth_client_states SET used = TRUE WHERE state = ?", state) if err != nil { return err } @@ -2486,12 +2486,12 @@ func (db *datastore) ValidateOAuthState(ctx context.Context, state string) error func (db *datastore) RecordRemoteUserID(ctx context.Context, localUserID, remoteUserID int64) error { var err error if db.driverName == driverSQLite { - _, err = db.ExecContext(ctx, "INSERT OR REPLACE INTO users_oauth (user_id, remote_user_id) VALUES (?, ?)", localUserID, remoteUserID) + _, err = db.ExecContext(ctx, "INSERT OR REPLACE INTO oauth_users (user_id, remote_user_id) VALUES (?, ?)", localUserID, remoteUserID) } else { - _, err = db.ExecContext(ctx, "INSERT INTO users_oauth (user_id, remote_user_id) VALUES (?, ?) "+db.upsert("user_id")+" user_id = ?", localUserID, remoteUserID, localUserID) + _, err = db.ExecContext(ctx, "INSERT INTO oauth_users (user_id, remote_user_id) VALUES (?, ?) "+db.upsert("user_id")+" user_id = ?", localUserID, remoteUserID, localUserID) } if err != nil { - log.Error("Unable to INSERT users_oauth for '%d': %v", localUserID, err) + log.Error("Unable to INSERT oauth_users for '%d': %v", localUserID, err) } return err } @@ -2500,7 +2500,7 @@ func (db *datastore) RecordRemoteUserID(ctx context.Context, localUserID, remote func (db *datastore) GetIDForRemoteUser(ctx context.Context, remoteUserID int64) (int64, error) { var userID int64 = -1 err := db. - QueryRowContext(ctx, "SELECT user_id FROM users_oauth WHERE remote_user_id = ?", remoteUserID). + QueryRowContext(ctx, "SELECT user_id FROM oauth_users WHERE remote_user_id = ?", remoteUserID). Scan(&userID) // Not finding a record is OK. if err != nil && err != sql.ErrNoRows { diff --git a/database_test.go b/database_test.go index 4a45dd0..879840e 100644 --- a/database_test.go +++ b/database_test.go @@ -22,19 +22,19 @@ func TestOAuthDatastore(t *testing.T) { assert.NoError(t, err) assert.Len(t, state, 24) - countRows(t, ctx, db, 1, "SELECT COUNT(*) FROM `oauth_client_state` WHERE `state` = ? AND `used` = false", state) + countRows(t, ctx, db, 1, "SELECT COUNT(*) FROM `oauth_client_states` WHERE `state` = ? AND `used` = false", state) err = ds.ValidateOAuthState(ctx, state) assert.NoError(t, err) - countRows(t, ctx, db, 1, "SELECT COUNT(*) FROM `oauth_client_state` WHERE `state` = ? AND `used` = true", state) + countRows(t, ctx, db, 1, "SELECT COUNT(*) FROM `oauth_client_states` WHERE `state` = ? AND `used` = true", state) var localUserID int64 = 99 var remoteUserID int64 = 100 err = ds.RecordRemoteUserID(ctx, localUserID, remoteUserID) assert.NoError(t, err) - countRows(t, ctx, db, 1, "SELECT COUNT(*) FROM `users_oauth` WHERE `user_id` = ? AND `remote_user_id` = ?", localUserID, remoteUserID) + countRows(t, ctx, db, 1, "SELECT COUNT(*) FROM `oauth_users` WHERE `user_id` = ? AND `remote_user_id` = ?", localUserID, remoteUserID) foundUserID, err := ds.GetIDForRemoteUser(ctx, remoteUserID) assert.NoError(t, err) diff --git a/migrations/v4.go b/migrations/v4.go index c123f54..c075dd8 100644 --- a/migrations/v4.go +++ b/migrations/v4.go @@ -14,7 +14,7 @@ func oauth(db *datastore) error { } return wf_db.RunTransactionWithOptions(context.Background(), db.DB, &sql.TxOptions{}, func(ctx context.Context, tx *sql.Tx) error { createTableUsersOauth, err := dialect. - Table("users_oauth"). + Table("oauth_users"). SetIfNotExists(true). Column(dialect.Column("user_id", wf_db.ColumnTypeInteger, wf_db.UnsetSize)). Column(dialect.Column("remote_user_id", wf_db.ColumnTypeInteger, wf_db.UnsetSize)). @@ -25,7 +25,7 @@ func oauth(db *datastore) error { return err } createTableOauthClientState, err := dialect. - Table("oauth_client_state"). + Table("oauth_client_states"). SetIfNotExists(true). Column(dialect.Column("state", wf_db.ColumnTypeVarChar, wf_db.OptionalInt{Set: true, Value: 255})). Column(dialect.Column("used", wf_db.ColumnTypeBool, wf_db.UnsetSize)).