diff --git a/server/version/version_test.go b/server/version/version_test.go index d06d97aa..ce1fde8f 100644 --- a/server/version/version_test.go +++ b/server/version/version_test.go @@ -53,6 +53,11 @@ func TestIsVersionGreaterThan(t *testing.T) { target: "0.8.0", want: true, }, + { + version: "0.23", + target: "0.22", + want: true, + }, { version: "0.8.0", target: "0.10.0", @@ -63,6 +68,11 @@ func TestIsVersionGreaterThan(t *testing.T) { target: "0.9.1", want: false, }, + { + version: "0.22", + target: "0.22", + want: false, + }, } for _, test := range tests { result := IsVersionGreaterThan(test.version, test.target) diff --git a/store/migrator.go b/store/migrator.go index c08bf0db..b965238f 100644 --- a/store/migrator.go +++ b/store/migrator.go @@ -17,6 +17,12 @@ import ( "github.com/usememos/memos/server/version" ) +//go:embed migration +var migrationFS embed.FS + +//go:embed seed +var seedFS embed.FS + const ( // MigrateFileNameSplit is the split character between the patch version and the description in the migration file name. // For example, "1__create_table.sql". @@ -26,12 +32,6 @@ const ( LatestSchemaFileName = "LATEST_SCHEMA.sql" ) -//go:embed migration -var migrationFS embed.FS - -//go:embed seed -var seedFS embed.FS - // Migrate applies the latest schema to the database. func (s *Store) Migrate(ctx context.Context) error { if err := s.preMigrate(ctx); err != nil { @@ -147,6 +147,9 @@ func (s *Store) preMigrate(ctx context.Context) error { return errors.Wrap(err, "failed to upsert migration history") } } + if err := s.normalizedMigrationHistoryList(ctx); err != nil { + return errors.Wrap(err, "failed to normalize migration history list") + } return nil } @@ -237,3 +240,55 @@ func (*Store) execute(ctx context.Context, tx *sql.Tx, stmt string) error { } return nil } + +func (s *Store) normalizedMigrationHistoryList(ctx context.Context) error { + migrationHistoryList, err := s.driver.FindMigrationHistoryList(ctx, &FindMigrationHistory{}) + if err != nil { + return errors.Wrap(err, "failed to find migration history") + } + versions := []string{} + for _, migrationHistory := range migrationHistoryList { + versions = append(versions, migrationHistory.Version) + } + sort.Sort(version.SortVersion(versions)) + latestVersion := versions[len(versions)-1] + latestMinorVersion := version.GetMinorVersion(latestVersion) + // If the latest version is greater than 0.22, return. + // As of 0.22, the migration history is already normalized. + if version.IsVersionGreaterThan(latestMinorVersion, "0.22") { + return nil + } + + schemaVersionMap := map[string]string{} + filePaths, err := fs.Glob(migrationFS, fmt.Sprintf("%s/*/*.sql", s.getMigrationBasePath())) + if err != nil { + return errors.Wrap(err, "failed to read migration files") + } + sort.Strings(filePaths) + for _, filePath := range filePaths { + fileSchemaVersion, err := s.getSchemaVersionOfMigrateScript(filePath) + if err != nil { + return errors.Wrap(err, "failed to get schema version of migrate script") + } + schemaVersionMap[version.GetMinorVersion((fileSchemaVersion))] = fileSchemaVersion + } + + latestSchemaVersion := schemaVersionMap[latestMinorVersion] + if latestSchemaVersion == "" { + return errors.Errorf("latest schema version not found") + } + if version.IsVersionGreaterOrEqualThan(latestVersion, latestSchemaVersion) { + return nil + } + + // Start a transaction to insert the latest schema version to migration_history. + tx, err := s.driver.GetDB().Begin() + if err != nil { + return errors.Wrap(err, "failed to start transaction") + } + defer tx.Rollback() + if err := s.execute(ctx, tx, fmt.Sprintf("INSERT INTO migration_history (version) VALUES ('%s')", latestSchemaVersion)); err != nil { + return errors.Wrap(err, "failed to insert migration history") + } + return tx.Commit() +}