From 062ae0e16a56f6da68b1d69f83daebfd5101f616 Mon Sep 17 00:00:00 2001 From: Matt Baer Date: Sun, 13 Jan 2019 09:08:47 -0500 Subject: [PATCH] Initialize db on single-user instance config This fixes the --config step so that when setting up a single-user instance for the first time (and creating the admin user as part of the process), the database is automatically initialized before creating that user. This removes the need for the --init-db command after --config when setting up single-user instances. It fixes #59: "no such table: users" error during the --config step on single-user instances that haven't previously run --init-db. --- app.go | 73 +++++++++++++++++++++++++++++------------------------ database.go | 24 ++++++++++++++++++ 2 files changed, 64 insertions(+), 33 deletions(-) diff --git a/app.go b/app.go index 6fbb8de..59c0554 100644 --- a/app.go +++ b/app.go @@ -230,6 +230,10 @@ func Serve() { connectToDatabase(app) defer shutdown(app) + if !app.db.DatabaseInitialized() { + adminInitDatabase(app) + } + u := &User{ Username: d.User.Username, HashedPass: d.User.HashedPass, @@ -267,39 +271,7 @@ func Serve() { loadConfig(app) connectToDatabase(app) defer shutdown(app) - - schemaFileName := "schema.sql" - if app.cfg.Database.Type == driverSQLite { - schemaFileName = "sqlite.sql" - } - - schema, err := Asset(schemaFileName) - if err != nil { - log.Error("Unable to load schema file: %v", err) - os.Exit(1) - } - - tblReg := regexp.MustCompile("CREATE TABLE (IF NOT EXISTS )?`([a-z_]+)`") - - queries := strings.Split(string(schema), ";\n") - for _, q := range queries { - if strings.TrimSpace(q) == "" { - continue - } - parts := tblReg.FindStringSubmatch(q) - if len(parts) >= 3 { - log.Info("Creating table %s...", parts[2]) - } else { - log.Info("Creating table ??? (Weird query) No match in: %v", parts) - } - _, err = app.db.Exec(q) - if err != nil { - log.Error("%s", err) - } else { - log.Info("Created.") - } - } - os.Exit(0) + adminInitDatabase(app) } else if *createAdmin != "" { adminCreateUser(app, *createAdmin, true) } else if *createUser != "" { @@ -573,3 +545,38 @@ func adminCreateUser(app *app, credStr string, isAdmin bool) { log.Info("Done!") os.Exit(0) } + +func adminInitDatabase(app *app) { + schemaFileName := "schema.sql" + if app.cfg.Database.Type == driverSQLite { + schemaFileName = "sqlite.sql" + } + + schema, err := Asset(schemaFileName) + if err != nil { + log.Error("Unable to load schema file: %v", err) + os.Exit(1) + } + + tblReg := regexp.MustCompile("CREATE TABLE (IF NOT EXISTS )?`([a-z_]+)`") + + queries := strings.Split(string(schema), ";\n") + for _, q := range queries { + if strings.TrimSpace(q) == "" { + continue + } + parts := tblReg.FindStringSubmatch(q) + if len(parts) >= 3 { + log.Info("Creating table %s...", parts[2]) + } else { + log.Info("Creating table ??? (Weird query) No match in: %v", parts) + } + _, err = app.db.Exec(q) + if err != nil { + log.Error("%s", err) + } else { + log.Info("Created.") + } + } + os.Exit(0) +} diff --git a/database.go b/database.go index 47e61eb..651b357 100644 --- a/database.go +++ b/database.go @@ -116,6 +116,8 @@ type writestore interface { GetAllUsersCount() int64 GetUserLastPostTime(id int64) (*time.Time, error) GetCollectionLastPostTime(id int64) (*time.Time, error) + + DatabaseInitialized() bool } type datastore struct { @@ -2293,6 +2295,28 @@ func (db *datastore) GetCollectionLastPostTime(id int64) (*time.Time, error) { return &t, nil } +// DatabaseInitialized returns whether or not the current datastore has been +// initialized with the correct schema. +// Currently, it checks to see if the `users` table exists. +func (db *datastore) DatabaseInitialized() bool { + var dummy string + var err error + if db.driverName == driverSQLite { + err = db.QueryRow("SELECT name FROM sqlite_master WHERE type = 'table' AND name = 'users'").Scan(&dummy) + } else { + err = db.QueryRow("SHOW TABLES LIKE 'users'").Scan(&dummy) + } + switch { + case err == sql.ErrNoRows: + return false + case err != nil: + log.Error("Couldn't SHOW TABLES: %v", err) + return false + } + + return true +} + func stringLogln(log *string, s string, v ...interface{}) { *log += fmt.Sprintf(s+"\n", v...) }