diff --git a/app.go b/app.go index 6fbb8de..59c0554 100644 --- a/app.go +++ b/app.go @@ -230,6 +230,10 @@ func Serve() { connectToDatabase(app) defer shutdown(app) + if !app.db.DatabaseInitialized() { + adminInitDatabase(app) + } + u := &User{ Username: d.User.Username, HashedPass: d.User.HashedPass, @@ -267,39 +271,7 @@ func Serve() { loadConfig(app) connectToDatabase(app) defer shutdown(app) - - schemaFileName := "schema.sql" - if app.cfg.Database.Type == driverSQLite { - schemaFileName = "sqlite.sql" - } - - schema, err := Asset(schemaFileName) - if err != nil { - log.Error("Unable to load schema file: %v", err) - os.Exit(1) - } - - tblReg := regexp.MustCompile("CREATE TABLE (IF NOT EXISTS )?`([a-z_]+)`") - - queries := strings.Split(string(schema), ";\n") - for _, q := range queries { - if strings.TrimSpace(q) == "" { - continue - } - parts := tblReg.FindStringSubmatch(q) - if len(parts) >= 3 { - log.Info("Creating table %s...", parts[2]) - } else { - log.Info("Creating table ??? (Weird query) No match in: %v", parts) - } - _, err = app.db.Exec(q) - if err != nil { - log.Error("%s", err) - } else { - log.Info("Created.") - } - } - os.Exit(0) + adminInitDatabase(app) } else if *createAdmin != "" { adminCreateUser(app, *createAdmin, true) } else if *createUser != "" { @@ -573,3 +545,38 @@ func adminCreateUser(app *app, credStr string, isAdmin bool) { log.Info("Done!") os.Exit(0) } + +func adminInitDatabase(app *app) { + schemaFileName := "schema.sql" + if app.cfg.Database.Type == driverSQLite { + schemaFileName = "sqlite.sql" + } + + schema, err := Asset(schemaFileName) + if err != nil { + log.Error("Unable to load schema file: %v", err) + os.Exit(1) + } + + tblReg := regexp.MustCompile("CREATE TABLE (IF NOT EXISTS )?`([a-z_]+)`") + + queries := strings.Split(string(schema), ";\n") + for _, q := range queries { + if strings.TrimSpace(q) == "" { + continue + } + parts := tblReg.FindStringSubmatch(q) + if len(parts) >= 3 { + log.Info("Creating table %s...", parts[2]) + } else { + log.Info("Creating table ??? (Weird query) No match in: %v", parts) + } + _, err = app.db.Exec(q) + if err != nil { + log.Error("%s", err) + } else { + log.Info("Created.") + } + } + os.Exit(0) +} diff --git a/database.go b/database.go index 47e61eb..651b357 100644 --- a/database.go +++ b/database.go @@ -116,6 +116,8 @@ type writestore interface { GetAllUsersCount() int64 GetUserLastPostTime(id int64) (*time.Time, error) GetCollectionLastPostTime(id int64) (*time.Time, error) + + DatabaseInitialized() bool } type datastore struct { @@ -2293,6 +2295,28 @@ func (db *datastore) GetCollectionLastPostTime(id int64) (*time.Time, error) { return &t, nil } +// DatabaseInitialized returns whether or not the current datastore has been +// initialized with the correct schema. +// Currently, it checks to see if the `users` table exists. +func (db *datastore) DatabaseInitialized() bool { + var dummy string + var err error + if db.driverName == driverSQLite { + err = db.QueryRow("SELECT name FROM sqlite_master WHERE type = 'table' AND name = 'users'").Scan(&dummy) + } else { + err = db.QueryRow("SHOW TABLES LIKE 'users'").Scan(&dummy) + } + switch { + case err == sql.ErrNoRows: + return false + case err != nil: + log.Error("Couldn't SHOW TABLES: %v", err) + return false + } + + return true +} + func stringLogln(log *string, s string, v ...interface{}) { *log += fmt.Sprintf(s+"\n", v...) }