From c6851fee505914f94f245d313df353c16cc05b55 Mon Sep 17 00:00:00 2001 From: Matt Baer Date: Sat, 8 Dec 2018 13:25:20 -0500 Subject: [PATCH] Fix duplicate key checks in SQLite Ref T529 --- database.go | 84 +++++++++++++++++++++++++++-------------------------- 1 file changed, 43 insertions(+), 41 deletions(-) diff --git a/database.go b/database.go index 6677d88..5166a9a 100644 --- a/database.go +++ b/database.go @@ -8,6 +8,8 @@ import ( "time" "github.com/go-sql-driver/mysql" + "github.com/mattn/go-sqlite3" + "github.com/guregu/null" "github.com/guregu/null/zero" uuid "github.com/nu7hatch/gouuid" @@ -130,6 +132,22 @@ func (db *datastore) upsert(indexedCols ...string) string { return "ON DUPLICATE KEY UPDATE" } +func (db *datastore) isDuplicateKeyErr(err error) bool { + if db.driverName == driverSQLite { + if err, ok := err.(sqlite3.Error); ok { + return err.Code == sqlite3.ErrConstraint + } + } else if db.driverName == driverMySQL { + if mysqlErr, ok := err.(*mysql.MySQLError); ok { + return mysqlErr.Number == mySQLErrDuplicateKey + } + } else { + log.Error("isDuplicateKeyErr: failed check for unrecognized driver '%s'", db.driverName) + } + + return false +} + func (db *datastore) CreateUser(u *User, collectionTitle string) error { // New users get a `users` and `collections` row. t, err := db.Begin() @@ -146,10 +164,8 @@ func (db *datastore) CreateUser(u *User, collectionTitle string) error { res, err := t.Exec("INSERT INTO users (username, password, email) VALUES (?, ?, ?)", u.Username, u.HashedPass, u.Email) if err != nil { t.Rollback() - if mysqlErr, ok := err.(*mysql.MySQLError); ok { - if mysqlErr.Number == mySQLErrDuplicateKey { - return impart.HTTPError{http.StatusConflict, "Username is already taken."} - } + if db.isDuplicateKeyErr(err) { + return impart.HTTPError{http.StatusConflict, "Username is already taken."} } log.Error("Rolling back users INSERT: %v\n", err) @@ -169,10 +185,8 @@ func (db *datastore) CreateUser(u *User, collectionTitle string) error { res, err = t.Exec("INSERT INTO collections (alias, title, description, privacy, owner_id, view_count) VALUES (?, ?, ?, ?, ?, ?)", u.Username, collectionTitle, "", CollUnlisted, u.ID, 0) if err != nil { t.Rollback() - if mysqlErr, ok := err.(*mysql.MySQLError); ok { - if mysqlErr.Number == mySQLErrDuplicateKey { - return impart.HTTPError{http.StatusConflict, "Username is already taken."} - } + if db.isDuplicateKeyErr(err) { + return impart.HTTPError{http.StatusConflict, "Username is already taken."} } log.Error("Rolling back collections INSERT: %v\n", err) return err @@ -241,10 +255,8 @@ func (db *datastore) CreateCollection(alias, title string, userID int64) (*Colle // All good, so create new collection res, err := db.Exec("INSERT INTO collections (alias, title, description, privacy, owner_id, view_count) VALUES (?, ?, ?, ?, ?, ?)", alias, title, "", CollUnlisted, userID, 0) if err != nil { - if mysqlErr, ok := err.(*mysql.MySQLError); ok { - if mysqlErr.Number == mySQLErrDuplicateKey { - return nil, impart.HTTPError{http.StatusConflict, "Collection already exists."} - } + if db.isDuplicateKeyErr(err) { + return nil, impart.HTTPError{http.StatusConflict, "Collection already exists."} } log.Error("Couldn't add to collections: %v\n", err) return nil, err @@ -614,17 +626,13 @@ func (db *datastore) CreatePost(userID, collID int64, post *SubmittedPost) (*Pos defer stmt.Close() _, err = stmt.Exec(friendlyID, slug, post.Title, post.Content, appearance, post.Language, post.IsRTL, 0, ownerID, ownerCollID, created, 0) if err != nil { - if mysqlErr, ok := err.(*mysql.MySQLError); ok { - if mysqlErr.Number == mySQLErrDuplicateKey { - // Duplicate entry error; try a new slug - // TODO: make this a little more robust - slug = sql.NullString{id.GenSafeUniqueSlug(slug.String), true} - _, err = stmt.Exec(friendlyID, slug, post.Title, post.Content, appearance, post.Language, post.IsRTL, 0, ownerID, ownerCollID, created, 0) - if err != nil { - return nil, handleFailedPostInsert(fmt.Errorf("Retried slug generation, still failed: %v", err)) - } - } else { - return nil, handleFailedPostInsert(err) + if db.isDuplicateKeyErr(err) { + // Duplicate entry error; try a new slug + // TODO: make this a little more robust + slug = sql.NullString{id.GenSafeUniqueSlug(slug.String), true} + _, err = stmt.Exec(friendlyID, slug, post.Title, post.Content, appearance, post.Language, post.IsRTL, 0, ownerID, ownerCollID, created, 0) + if err != nil { + return nil, handleFailedPostInsert(fmt.Errorf("Retried slug generation, still failed: %v", err)) } } else { return nil, handleFailedPostInsert(err) @@ -1198,17 +1206,15 @@ func (db *datastore) CanCollect(cpr *ClaimPostRequest, userID int64) bool { func (db *datastore) AttemptClaim(p *ClaimPostRequest, query string, params []interface{}, slugIdx int) (sql.Result, error) { qRes, err := db.Exec(query, params...) if err != nil { - if mysqlErr, ok := err.(*mysql.MySQLError); ok { - if mysqlErr.Number == mySQLErrDuplicateKey && slugIdx > -1 { - s := id.GenSafeUniqueSlug(p.Slug) - if s == p.Slug { - // Sanity check to prevent infinite recursion - return qRes, fmt.Errorf("GenSafeUniqueSlug generated nothing unique: %s", s) - } - p.Slug = s - params[slugIdx] = p.Slug - return db.AttemptClaim(p, query, params, slugIdx) + if db.isDuplicateKeyErr(err) && slugIdx > -1 { + s := id.GenSafeUniqueSlug(p.Slug) + if s == p.Slug { + // Sanity check to prevent infinite recursion + return qRes, fmt.Errorf("GenSafeUniqueSlug generated nothing unique: %s", s) } + p.Slug = s + params[slugIdx] = p.Slug + return db.AttemptClaim(p, query, params, slugIdx) } return qRes, fmt.Errorf("attemptClaim: %s", err) } @@ -1779,10 +1785,8 @@ func (db *datastore) ChangeSettings(app *app, u *User, s *userSettings) error { _, err = t.Exec("UPDATE users SET username = ? WHERE id = ?", newUsername, u.ID) if err != nil { t.Rollback() - if mysqlErr, ok := err.(*mysql.MySQLError); ok { - if mysqlErr.Number == mySQLErrDuplicateKey { - return impart.HTTPError{http.StatusConflict, "Username is already taken."} - } + if db.isDuplicateKeyErr(err) { + return impart.HTTPError{http.StatusConflict, "Username is already taken."} } log.Error("Unable to update users table: %v", err) return ErrInternalGeneral @@ -1791,10 +1795,8 @@ func (db *datastore) ChangeSettings(app *app, u *User, s *userSettings) error { _, err = t.Exec("UPDATE collections SET alias = ? WHERE alias = ? AND owner_id = ?", newUsername, u.Username, u.ID) if err != nil { t.Rollback() - if mysqlErr, ok := err.(*mysql.MySQLError); ok { - if mysqlErr.Number == mySQLErrDuplicateKey { - return impart.HTTPError{http.StatusConflict, "Username is already taken."} - } + if db.isDuplicateKeyErr(err) { + return impart.HTTPError{http.StatusConflict, "Username is already taken."} } log.Error("Unable to update collection: %v", err) return ErrInternalGeneral