diff --git a/cmd/memos.go b/cmd/memos.go index 8cebc959..c14386ae 100644 --- a/cmd/memos.go +++ b/cmd/memos.go @@ -16,6 +16,7 @@ import ( "github.com/usememos/memos/server" _profile "github.com/usememos/memos/server/profile" "github.com/usememos/memos/store" + "github.com/usememos/memos/store/mysql" "github.com/usememos/memos/store/sqlite" ) @@ -36,13 +37,27 @@ var ( addr string port int data string + driver string + dsn string rootCmd = &cobra.Command{ Use: "memos", Short: `An open-source, self-hosted memo hub with knowledge management and social networking.`, Run: func(_cmd *cobra.Command, _args []string) { ctx, cancel := context.WithCancel(context.Background()) - driver, err := sqlite.NewDriver(profile) + + var err error + var driver store.Driver + switch profile.Driver { + case "sqlite": + driver, err = sqlite.NewDriver(profile) + case "mysql": + driver, err = mysql.NewDriver(profile) + default: + cancel() + log.Error("unknown db driver", zap.String("driver", profile.Driver)) + return + } if err != nil { cancel() log.Error("failed to create db driver", zap.Error(err)) @@ -101,6 +116,8 @@ func init() { rootCmd.PersistentFlags().StringVarP(&addr, "addr", "a", "", "address of server") rootCmd.PersistentFlags().IntVarP(&port, "port", "p", 8081, "port of server") rootCmd.PersistentFlags().StringVarP(&data, "data", "d", "", "data directory") + rootCmd.PersistentFlags().StringVarP(&driver, "driver", "", "", "database driver") + rootCmd.PersistentFlags().StringVarP(&dsn, "dsn", "", "", "database source name(aka. DSN)") err := viper.BindPFlag("mode", rootCmd.PersistentFlags().Lookup("mode")) if err != nil { @@ -118,8 +135,17 @@ func init() { if err != nil { panic(err) } + err = viper.BindPFlag("driver", rootCmd.PersistentFlags().Lookup("driver")) + if err != nil { + panic(err) + } + err = viper.BindPFlag("dsn", rootCmd.PersistentFlags().Lookup("dsn")) + if err != nil { + panic(err) + } viper.SetDefault("mode", "demo") + viper.SetDefault("driver", "sqlite") viper.SetDefault("addr", "") viper.SetDefault("port", 8081) viper.SetEnvPrefix("memos") @@ -140,6 +166,7 @@ func initConfig() { println("addr:", profile.Addr) println("port:", profile.Port) println("mode:", profile.Mode) + println("driver:", profile.Driver) println("version:", profile.Version) println("---") } diff --git a/docker-compose.dev.yaml b/docker-compose.dev.yaml index cfda115a..2d80a7a5 100644 --- a/docker-compose.dev.yaml +++ b/docker-compose.dev.yaml @@ -9,10 +9,17 @@ # docker compose logs -f # services: + db: + image: mysql + volumes: + - ./.air/mysql:/var/lib/mysql api: image: golang:1.21-alpine working_dir: /work command: air -c ./scripts/.air.toml + environment: + - "MEMOS_DSN=root@tcp(db)/memos" + - "MEMOS_DRIVER=mysql" volumes: - $HOME/go/pkg/:/go/pkg/ # Cache for go mod shared with the host - ./.air/bin/:/go/bin/ # Cache for binary used only in container, such as *air* diff --git a/go.mod b/go.mod index 71c19ec0..a19439fe 100644 --- a/go.mod +++ b/go.mod @@ -9,6 +9,7 @@ require ( github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.11.51 github.com/aws/aws-sdk-go-v2/service/s3 v1.30.3 github.com/disintegration/imaging v1.6.2 + github.com/go-sql-driver/mysql v1.7.1 github.com/google/cel-go v0.17.1 github.com/google/uuid v1.3.0 github.com/gorilla/feeds v1.1.1 diff --git a/go.sum b/go.sum index 3ff35cd2..8029ea9d 100644 --- a/go.sum +++ b/go.sum @@ -199,6 +199,8 @@ github.com/go-playground/universal-translator v0.17.0/go.mod h1:UkSxE5sNxxRwHyU+ github.com/go-playground/validator/v10 v10.2.0 h1:KgJ0snyC2R9VXYN2rneOtQcw5aHQB1Vv0sFl1UcHBOY= github.com/go-playground/validator/v10 v10.2.0/go.mod h1:uOYAAleCW8F/7oMFd6aG0GOhaH6EGOAJShg8Id5JGkI= github.com/go-sql-driver/mysql v1.4.0/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w= +github.com/go-sql-driver/mysql v1.7.1 h1:lUIinVbN1DY0xBg0eMOzmmtGoHwWBbvnWubQUrtU8EI= +github.com/go-sql-driver/mysql v1.7.1/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee h1:s+21KNqlpePfkah2I+gwHF8xmJWRjooY+5248k6m4A0= github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee/go.mod h1:L0fX3K22YWvt/FAX9NnzrNzcI4wNYi9Yku4O0LKYflo= diff --git a/main.go b/main.go index 2533268e..8d53abe0 100644 --- a/main.go +++ b/main.go @@ -1,6 +1,7 @@ package main import ( + _ "github.com/go-sql-driver/mysql" _ "modernc.org/sqlite" "github.com/usememos/memos/cmd" diff --git a/server/profile/profile.go b/server/profile/profile.go index 5da57c3d..f4e9ab66 100644 --- a/server/profile/profile.go +++ b/server/profile/profile.go @@ -24,7 +24,9 @@ type Profile struct { // Data is the data directory Data string `json:"-"` // DSN points to where Memos stores its own data - DSN string `json:"-"` + DSN string `json:"dsn"` + // Driver is the database driver + Driver string `json:"driver"` // Version is the current version of server Version string `json:"version"` } @@ -33,7 +35,7 @@ func (p *Profile) IsDev() bool { return p.Mode != "prod" } -func checkDSN(dataDir string) (string, error) { +func checkDataDir(dataDir string) (string, error) { // Convert to absolute path if relative path is supplied. if !filepath.IsAbs(dataDir) { relativeDir := filepath.Join(filepath.Dir(os.Args[0]), dataDir) @@ -81,15 +83,17 @@ func GetProfile() (*Profile, error) { } } - dataDir, err := checkDSN(profile.Data) + dataDir, err := checkDataDir(profile.Data) if err != nil { fmt.Printf("Failed to check dsn: %s, err: %+v\n", dataDir, err) return nil, err } profile.Data = dataDir - dbFile := fmt.Sprintf("memos_%s.db", profile.Mode) - profile.DSN = filepath.Join(dataDir, dbFile) + if profile.Driver == "sqlite" && profile.DSN == "" { + dbFile := fmt.Sprintf("memos_%s.db", profile.Mode) + profile.DSN = filepath.Join(dataDir, dbFile) + } profile.Version = version.GetCurrentVersion(profile.Mode) return &profile, nil diff --git a/store/mysql/activity.go b/store/mysql/activity.go new file mode 100644 index 00000000..013de47f --- /dev/null +++ b/store/mysql/activity.go @@ -0,0 +1,64 @@ +package mysql + +import ( + "context" + + "github.com/pkg/errors" + + "github.com/usememos/memos/store" +) + +func (d *Driver) CreateActivity(ctx context.Context, create *store.Activity) (*store.Activity, error) { + stmt := ` + INSERT INTO activity ( + creator_id, + type, + level, + payload + ) + VALUES (?, ?, ?, ?) + ` + result, err := d.db.ExecContext(ctx, stmt, + create.CreatorID, + create.Type, + create.Level, + create.Payload, + ) + if err != nil { + return nil, errors.Wrap(err, "failed to db.Exec") + } + + id, err := result.LastInsertId() + if err != nil { + return nil, errors.Wrap(err, "failed to db.LastInsertId") + } + + return d.FindActivity(ctx, id) +} + +func (d *Driver) FindActivity(ctx context.Context, id int64) (*store.Activity, error) { + var activity store.Activity + stmt := ` + SELECT + id, + creator_id, + type, + level, + payload, + UNIX_TIMESTAMP(created_ts) + FROM activity + WHERE id = ? + ` + if err := d.db.QueryRowContext(ctx, stmt, id).Scan( + &activity.ID, + &activity.CreatorID, + &activity.Type, + &activity.Level, + &activity.Payload, + &activity.CreatedTs, + ); err != nil { + return nil, errors.Wrap(err, "failed to db.QueryRow") + } + + return &activity, nil +} diff --git a/store/mysql/idp.go b/store/mysql/idp.go new file mode 100644 index 00000000..fc78f8ad --- /dev/null +++ b/store/mysql/idp.go @@ -0,0 +1,199 @@ +package mysql + +import ( + "context" + "encoding/json" + "fmt" + "strings" + + "github.com/pkg/errors" + + "github.com/usememos/memos/store" +) + +func (d *Driver) CreateIdentityProvider(ctx context.Context, create *store.IdentityProvider) (*store.IdentityProvider, error) { + var configBytes []byte + if create.Type == store.IdentityProviderOAuth2Type { + bytes, err := json.Marshal(create.Config.OAuth2Config) + if err != nil { + return nil, err + } + configBytes = bytes + } else { + return nil, errors.Errorf("unsupported idp type %s", string(create.Type)) + } + + stmt := ` + INSERT INTO idp ( + name, + type, + identifier_filter, + config + ) + VALUES (?, ?, ?, ?) + ` + result, err := d.db.ExecContext( + ctx, + stmt, + create.Name, + create.Type, + create.IdentifierFilter, + string(configBytes), + ) + if err != nil { + return nil, err + } + + id, err := result.LastInsertId() + if err != nil { + return nil, err + } + + create.ID = int32(id) + return create, nil +} + +func (d *Driver) ListIdentityProviders(ctx context.Context, find *store.FindIdentityProvider) ([]*store.IdentityProvider, error) { + where, args := []string{"1 = 1"}, []any{} + if v := find.ID; v != nil { + where, args = append(where, fmt.Sprintf("id = $%d", len(args)+1)), append(args, *v) + } + + rows, err := d.db.QueryContext(ctx, ` + SELECT + id, + name, + type, + identifier_filter, + config + FROM idp + WHERE `+strings.Join(where, " AND ")+` ORDER BY id ASC`, + args..., + ) + if err != nil { + return nil, err + } + defer rows.Close() + + var identityProviders []*store.IdentityProvider + for rows.Next() { + var identityProvider store.IdentityProvider + var identityProviderConfig string + if err := rows.Scan( + &identityProvider.ID, + &identityProvider.Name, + &identityProvider.Type, + &identityProvider.IdentifierFilter, + &identityProviderConfig, + ); err != nil { + return nil, err + } + + if identityProvider.Type == store.IdentityProviderOAuth2Type { + oauth2Config := &store.IdentityProviderOAuth2Config{} + if err := json.Unmarshal([]byte(identityProviderConfig), oauth2Config); err != nil { + return nil, err + } + identityProvider.Config = &store.IdentityProviderConfig{ + OAuth2Config: oauth2Config, + } + } else { + return nil, errors.Errorf("unsupported idp type %s", string(identityProvider.Type)) + } + identityProviders = append(identityProviders, &identityProvider) + } + + if err := rows.Err(); err != nil { + return nil, err + } + + return identityProviders, nil +} + +func (d *Driver) GetIdentityProvider(ctx context.Context, find *store.FindIdentityProvider) (*store.IdentityProvider, error) { + list, err := d.ListIdentityProviders(ctx, find) + if err != nil { + return nil, err + } + if len(list) == 0 { + return nil, nil + } + + identityProvider := list[0] + return identityProvider, nil +} + +func (d *Driver) UpdateIdentityProvider(ctx context.Context, update *store.UpdateIdentityProvider) (*store.IdentityProvider, error) { + set, args := []string{}, []any{} + if v := update.Name; v != nil { + set, args = append(set, "name = ?"), append(args, *v) + } + if v := update.IdentifierFilter; v != nil { + set, args = append(set, "identifier_filter = ?"), append(args, *v) + } + if v := update.Config; v != nil { + var configBytes []byte + if update.Type == store.IdentityProviderOAuth2Type { + bytes, err := json.Marshal(update.Config.OAuth2Config) + if err != nil { + return nil, err + } + configBytes = bytes + } else { + return nil, errors.Errorf("unsupported idp type %s", string(update.Type)) + } + set, args = append(set, "config = ?"), append(args, string(configBytes)) + } + args = append(args, update.ID) + + stmt := ` + UPDATE idp + SET ` + strings.Join(set, ", ") + ` + WHERE id = ? + RETURNING id, name, type, identifier_filter, config + ` + _, err := d.db.ExecContext(ctx, stmt, args...) + if err != nil { + return nil, err + } + + var identityProvider store.IdentityProvider + var identityProviderConfig string + stmt = `SELECT id, name, type, identifier_filter, config FROM idp WHERE id = ?` + if err := d.db.QueryRowContext(ctx, stmt, update.ID).Scan( + &identityProvider.ID, + &identityProvider.Name, + &identityProvider.Type, + &identityProvider.IdentifierFilter, + &identityProviderConfig, + ); err != nil { + return nil, err + } + + if identityProvider.Type == store.IdentityProviderOAuth2Type { + oauth2Config := &store.IdentityProviderOAuth2Config{} + if err := json.Unmarshal([]byte(identityProviderConfig), oauth2Config); err != nil { + return nil, err + } + identityProvider.Config = &store.IdentityProviderConfig{ + OAuth2Config: oauth2Config, + } + } else { + return nil, errors.Errorf("unsupported idp type %s", string(identityProvider.Type)) + } + + return &identityProvider, nil +} + +func (d *Driver) DeleteIdentityProvider(ctx context.Context, delete *store.DeleteIdentityProvider) error { + where, args := []string{"id = ?"}, []any{delete.ID} + stmt := `DELETE FROM idp WHERE ` + strings.Join(where, " AND ") + result, err := d.db.ExecContext(ctx, stmt, args...) + if err != nil { + return err + } + if _, err = result.RowsAffected(); err != nil { + return err + } + return nil +} diff --git a/store/mysql/memo.go b/store/mysql/memo.go new file mode 100644 index 00000000..6cd9b138 --- /dev/null +++ b/store/mysql/memo.go @@ -0,0 +1,311 @@ +package mysql + +import ( + "context" + "database/sql" + "fmt" + "strings" + + "github.com/pkg/errors" + + "github.com/usememos/memos/common/util" + "github.com/usememos/memos/store" +) + +func (d *Driver) CreateMemo(ctx context.Context, create *store.Memo) (*store.Memo, error) { + stmt := ` + INSERT INTO memo ( + creator_id, + content, + visibility + ) + VALUES (?, ?, ?) + ` + result, err := d.db.ExecContext( + ctx, + stmt, + create.CreatorID, + create.Content, + create.Visibility, + ) + if err != nil { + return nil, err + } + + id, err := result.LastInsertId() + if err != nil { + return nil, err + } + + var memo store.Memo + stmt = ` + SELECT + id, + creator_id, + content, + visibility, + UNIX_TIMESTAMP(created_ts), + UNIX_TIMESTAMP(updated_ts), + row_status + FROM memo + WHERE id = ? + ` + if err := d.db.QueryRowContext(ctx, stmt, id).Scan( + &memo.ID, + &memo.CreatorID, + &memo.Content, + &memo.Visibility, + &memo.UpdatedTs, + &memo.CreatedTs, + &memo.RowStatus, + ); err != nil { + return nil, err + } + + return &memo, nil +} + +func (d *Driver) ListMemos(ctx context.Context, find *store.FindMemo) ([]*store.Memo, error) { + where, args := []string{"1 = 1"}, []any{} + + if v := find.ID; v != nil { + where, args = append(where, "memo.id = ?"), append(args, *v) + } + if v := find.CreatorID; v != nil { + where, args = append(where, "memo.creator_id = ?"), append(args, *v) + } + if v := find.RowStatus; v != nil { + where, args = append(where, "memo.row_status = ?"), append(args, *v) + } + if v := find.CreatedTsBefore; v != nil { + where, args = append(where, "UNIX_TIMESTAMP(memo.created_ts) < ?"), append(args, *v) + } + if v := find.CreatedTsAfter; v != nil { + where, args = append(where, "UNIX_TIMESTAMP(memo.created_ts) > ?"), append(args, *v) + } + if v := find.Pinned; v != nil { + where = append(where, "memo_organizer.pinned = 1") + } + if v := find.ContentSearch; len(v) != 0 { + for _, s := range v { + where, args = append(where, "memo.content LIKE ?"), append(args, "%"+s+"%") + } + } + if v := find.VisibilityList; len(v) != 0 { + list := []string{} + for _, visibility := range v { + list = append(list, "?") + args = append(args, visibility) + } + where = append(where, fmt.Sprintf("memo.visibility in (%s)", strings.Join(list, ","))) + } + orders := []string{"pinned DESC"} + if find.OrderByUpdatedTs { + orders = append(orders, "updated_ts DESC") + } else { + orders = append(orders, "created_ts DESC") + } + orders = append(orders, "id DESC") + + query := ` + SELECT + memo.id AS id, + memo.creator_id AS creator_id, + UNIX_TIMESTAMP(memo.created_ts) AS created_ts, + UNIX_TIMESTAMP(memo.updated_ts) AS updated_ts, + memo.row_status AS row_status, + memo.content AS content, + memo.visibility AS visibility, + MAX(CASE WHEN memo_organizer.pinned = 1 THEN 1 ELSE 0 END) AS pinned, + GROUP_CONCAT(resource.id) AS resource_id_list, + ( + SELECT + GROUP_CONCAT(related_memo_id,':',type) + FROM + memo_relation + WHERE + memo_relation.memo_id = memo.id + GROUP BY + memo_relation.memo_id + ) AS relation_list + FROM + memo + LEFT JOIN + memo_organizer ON memo.id = memo_organizer.memo_id + LEFT JOIN + resource ON memo.id = resource.memo_id + WHERE ` + strings.Join(where, " AND ") + ` + GROUP BY memo.id + ORDER BY ` + strings.Join(orders, ", ") + ` + ` + if find.Limit != nil { + query = fmt.Sprintf("%s LIMIT %d", query, *find.Limit) + if find.Offset != nil { + query = fmt.Sprintf("%s OFFSET %d", query, *find.Offset) + } + } + + rows, err := d.db.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + list := make([]*store.Memo, 0) + for rows.Next() { + var memo store.Memo + var memoResourceIDList sql.NullString + var memoRelationList sql.NullString + if err := rows.Scan( + &memo.ID, + &memo.CreatorID, + &memo.CreatedTs, + &memo.UpdatedTs, + &memo.RowStatus, + &memo.Content, + &memo.Visibility, + &memo.Pinned, + &memoResourceIDList, + &memoRelationList, + ); err != nil { + return nil, err + } + + if memoResourceIDList.Valid { + idStringList := strings.Split(memoResourceIDList.String, ",") + memo.ResourceIDList = make([]int32, 0, len(idStringList)) + for _, idString := range idStringList { + id, err := util.ConvertStringToInt32(idString) + if err != nil { + return nil, err + } + memo.ResourceIDList = append(memo.ResourceIDList, id) + } + } + if memoRelationList.Valid { + memo.RelationList = make([]*store.MemoRelation, 0) + relatedMemoTypeList := strings.Split(memoRelationList.String, ",") + for _, relatedMemoType := range relatedMemoTypeList { + relatedMemoTypeList := strings.Split(relatedMemoType, ":") + if len(relatedMemoTypeList) != 2 { + return nil, errors.Errorf("invalid relation format") + } + relatedMemoID, err := util.ConvertStringToInt32(relatedMemoTypeList[0]) + if err != nil { + return nil, err + } + memo.RelationList = append(memo.RelationList, &store.MemoRelation{ + MemoID: memo.ID, + RelatedMemoID: relatedMemoID, + Type: store.MemoRelationType(relatedMemoTypeList[1]), + }) + } + } + list = append(list, &memo) + } + + if err := rows.Err(); err != nil { + return nil, err + } + + return list, nil +} + +func (d *Driver) UpdateMemo(ctx context.Context, update *store.UpdateMemo) error { + set, args := []string{}, []any{} + if v := update.CreatedTs; v != nil { + set, args = append(set, "created_ts = FROM_UNIXTIME(?)"), append(args, *v) + } + if v := update.UpdatedTs; v != nil { + set, args = append(set, "updated_ts = FROM_UNIXTIME(?)"), append(args, *v) + } + if v := update.RowStatus; v != nil { + set, args = append(set, "row_status = ?"), append(args, *v) + } + if v := update.Content; v != nil { + set, args = append(set, "content = ?"), append(args, *v) + } + if v := update.Visibility; v != nil { + set, args = append(set, "visibility = ?"), append(args, *v) + } + args = append(args, update.ID) + + stmt := ` + UPDATE memo + SET ` + strings.Join(set, ", ") + ` + WHERE id = ? + ` + if _, err := d.db.ExecContext(ctx, stmt, args...); err != nil { + return err + } + return nil +} + +func (d *Driver) DeleteMemo(ctx context.Context, delete *store.DeleteMemo) error { + where, args := []string{"id = ?"}, []any{delete.ID} + stmt := `DELETE FROM memo WHERE ` + strings.Join(where, " AND ") + result, err := d.db.ExecContext(ctx, stmt, args...) + if err != nil { + return err + } + if _, err := result.RowsAffected(); err != nil { + return err + } + + if err := d.Vacuum(ctx); err != nil { + // Prevent linter warning. + return err + } + return nil +} + +func (d *Driver) FindMemosVisibilityList(ctx context.Context, memoIDs []int32) ([]store.Visibility, error) { + args := make([]any, 0, len(memoIDs)) + list := make([]string, 0, len(memoIDs)) + for _, memoID := range memoIDs { + args = append(args, memoID) + list = append(list, "?") + } + + where := fmt.Sprintf("id in (%s)", strings.Join(list, ",")) + query := `SELECT DISTINCT(visibility) FROM memo WHERE ` + where + rows, err := d.db.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + visibilityList := make([]store.Visibility, 0) + for rows.Next() { + var visibility store.Visibility + if err := rows.Scan(&visibility); err != nil { + return nil, err + } + visibilityList = append(visibilityList, visibility) + } + + if err := rows.Err(); err != nil { + return nil, err + } + + return visibilityList, nil +} + +func vacuumMemo(ctx context.Context, tx *sql.Tx) error { + stmt := ` + DELETE FROM + memo + WHERE + creator_id NOT IN ( + SELECT + id + FROM + user + )` + _, err := tx.ExecContext(ctx, stmt) + if err != nil { + return err + } + + return nil +} diff --git a/store/mysql/memo_organizer.go b/store/mysql/memo_organizer.go new file mode 100644 index 00000000..e3a765fe --- /dev/null +++ b/store/mysql/memo_organizer.go @@ -0,0 +1,106 @@ +package mysql + +import ( + "context" + "database/sql" + "fmt" + "strings" + + "github.com/usememos/memos/store" +) + +func (d *Driver) UpsertMemoOrganizer(ctx context.Context, upsert *store.MemoOrganizer) (*store.MemoOrganizer, error) { + stmt := ` + INSERT INTO memo_organizer ( + memo_id, + user_id, + pinned + ) + VALUES (?, ?, ?) + ON DUPLICATE KEY UPDATE pinned = ? + ` + if _, err := d.db.ExecContext(ctx, stmt, upsert.MemoID, upsert.UserID, upsert.Pinned, upsert.Pinned); err != nil { + return nil, err + } + + return upsert, nil +} + +func (d *Driver) GetMemoOrganizer(ctx context.Context, find *store.FindMemoOrganizer) (*store.MemoOrganizer, error) { + where, args := []string{}, []any{} + if find.MemoID != 0 { + where = append(where, "memo_id = ?") + args = append(args, find.MemoID) + } + if find.UserID != 0 { + where = append(where, "user_id = ?") + args = append(args, find.UserID) + } + + query := fmt.Sprintf(` + SELECT + memo_id, + user_id, + pinned + FROM memo_organizer + WHERE %s + `, strings.Join(where, " AND ")) + row := d.db.QueryRowContext(ctx, query, args...) + if err := row.Err(); err != nil { + return nil, err + } + if row == nil { + return nil, nil + } + + memoOrganizer := &store.MemoOrganizer{} + if err := row.Scan( + &memoOrganizer.MemoID, + &memoOrganizer.UserID, + &memoOrganizer.Pinned, + ); err != nil { + return nil, err + } + + return memoOrganizer, nil +} + +func (d *Driver) DeleteMemoOrganizer(ctx context.Context, delete *store.DeleteMemoOrganizer) error { + where, args := []string{}, []any{} + if v := delete.MemoID; v != nil { + where, args = append(where, "memo_id = ?"), append(args, *v) + } + if v := delete.UserID; v != nil { + where, args = append(where, "user_id = ?"), append(args, *v) + } + stmt := `DELETE FROM memo_organizer WHERE ` + strings.Join(where, " AND ") + if _, err := d.db.ExecContext(ctx, stmt, args...); err != nil { + return err + } + return nil +} + +func vacuumMemoOrganizer(ctx context.Context, tx *sql.Tx) error { + stmt := ` + DELETE FROM + memo_organizer + WHERE + memo_id NOT IN ( + SELECT + id + FROM + memo + ) + OR user_id NOT IN ( + SELECT + id + FROM + user + )` + _, err := tx.ExecContext(ctx, stmt) + if err != nil { + return err + } + + return nil +} diff --git a/store/mysql/memo_relation.go b/store/mysql/memo_relation.go new file mode 100644 index 00000000..906616bc --- /dev/null +++ b/store/mysql/memo_relation.go @@ -0,0 +1,118 @@ +package mysql + +import ( + "context" + "database/sql" + "strings" + + "github.com/usememos/memos/store" +) + +func (d *Driver) UpsertMemoRelation(ctx context.Context, create *store.MemoRelation) (*store.MemoRelation, error) { + stmt := ` + INSERT INTO memo_relation ( + memo_id, + related_memo_id, + type + ) + VALUES (?, ?, ?) + ON DUPLICATE KEY UPDATE type = ? + ` + _, err := d.db.ExecContext( + ctx, + stmt, + create.MemoID, + create.RelatedMemoID, + create.Type, + create.Type, + ) + if err != nil { + return nil, err + } + + memoRelation := store.MemoRelation{ + MemoID: create.MemoID, + RelatedMemoID: create.RelatedMemoID, + Type: create.Type, + } + + return &memoRelation, nil +} + +func (d *Driver) ListMemoRelations(ctx context.Context, find *store.FindMemoRelation) ([]*store.MemoRelation, error) { + where, args := []string{"TRUE"}, []any{} + if find.MemoID != nil { + where, args = append(where, "memo_id = ?"), append(args, find.MemoID) + } + if find.RelatedMemoID != nil { + where, args = append(where, "related_memo_id = ?"), append(args, find.RelatedMemoID) + } + if find.Type != nil { + where, args = append(where, "type = ?"), append(args, find.Type) + } + + rows, err := d.db.QueryContext(ctx, ` + SELECT + memo_id, + related_memo_id, + type + FROM memo_relation + WHERE `+strings.Join(where, " AND "), args...) + if err != nil { + return nil, err + } + defer rows.Close() + + list := []*store.MemoRelation{} + for rows.Next() { + memoRelation := &store.MemoRelation{} + if err := rows.Scan( + &memoRelation.MemoID, + &memoRelation.RelatedMemoID, + &memoRelation.Type, + ); err != nil { + return nil, err + } + list = append(list, memoRelation) + } + + if err := rows.Err(); err != nil { + return nil, err + } + + return list, nil +} + +func (d *Driver) DeleteMemoRelation(ctx context.Context, delete *store.DeleteMemoRelation) error { + where, args := []string{"TRUE"}, []any{} + if delete.MemoID != nil { + where, args = append(where, "memo_id = ?"), append(args, delete.MemoID) + } + if delete.RelatedMemoID != nil { + where, args = append(where, "related_memo_id = ?"), append(args, delete.RelatedMemoID) + } + if delete.Type != nil { + where, args = append(where, "type = ?"), append(args, delete.Type) + } + stmt := ` + DELETE FROM memo_relation + WHERE ` + strings.Join(where, " AND ") + result, err := d.db.ExecContext(ctx, stmt, args...) + if err != nil { + return err + } + if _, err = result.RowsAffected(); err != nil { + return err + } + return nil +} + +func vacuumMemoRelations(ctx context.Context, tx *sql.Tx) error { + if _, err := tx.ExecContext(ctx, ` + DELETE FROM memo_relation + WHERE memo_id NOT IN (SELECT id FROM memo) OR related_memo_id NOT IN (SELECT id FROM memo) + `); err != nil { + return err + } + return nil +} diff --git a/store/mysql/migrate.go b/store/mysql/migrate.go new file mode 100644 index 00000000..861ab2ba --- /dev/null +++ b/store/mysql/migrate.go @@ -0,0 +1,182 @@ +package mysql + +import ( + "context" + "embed" + "fmt" + "io/fs" + "regexp" + "sort" + "strings" + + "github.com/pkg/errors" + + "github.com/usememos/memos/server/version" +) + +const ( + latestSchemaFileName = "LATEST__SCHEMA.sql" +) + +//go:embed migration +var migrationFS embed.FS + +func (d *Driver) Migrate(ctx context.Context) error { + if d.profile.IsDev() { + return d.nonProdMigrate(ctx) + } + + return d.prodMigrate(ctx) +} + +func (d *Driver) nonProdMigrate(ctx context.Context) error { + buf, err := migrationFS.ReadFile("migration/dev/" + latestSchemaFileName) + if err != nil { + return errors.Errorf("failed to read latest schema file: %s", err) + } + + for _, stmt := range strings.Split(string(buf), ";") { + stmt = strings.TrimSpace(stmt) + if stmt == "" { + continue + } + _, err := d.db.ExecContext(ctx, stmt) + if err != nil { + return errors.Errorf("failed to exec SQL %s: %s", stmt, err) + } + } + + // In demo mode, we should seed the database. + if d.profile.Mode == "demo" { + if err := d.seed(ctx); err != nil { + return errors.Wrap(err, "failed to seed") + } + } + return nil +} +func (d *Driver) prodMigrate(ctx context.Context) error { + currentVersion := version.GetCurrentVersion(d.profile.Mode) + migrationHistoryList, err := d.FindMigrationHistoryList(ctx, &MigrationHistoryFind{}) + if err != nil { + return errors.Wrap(err, "failed to find migration history") + } + if len(migrationHistoryList) == 0 { + _, err := d.UpsertMigrationHistory(ctx, &MigrationHistoryUpsert{ + Version: currentVersion, + }) + if err != nil { + return errors.Wrap(err, "failed to upsert migration history") + } + return nil + } + + migrationHistoryVersionList := []string{} + for _, migrationHistory := range migrationHistoryList { + migrationHistoryVersionList = append(migrationHistoryVersionList, migrationHistory.Version) + } + sort.Sort(version.SortVersion(migrationHistoryVersionList)) + latestMigrationHistoryVersion := migrationHistoryVersionList[len(migrationHistoryVersionList)-1] + + if !version.IsVersionGreaterThan(version.GetSchemaVersion(currentVersion), latestMigrationHistoryVersion) { + return nil + } + + println("start migrate") + for _, minorVersion := range getMinorVersionList() { + normalizedVersion := minorVersion + ".0" + if version.IsVersionGreaterThan(normalizedVersion, latestMigrationHistoryVersion) && version.IsVersionGreaterOrEqualThan(currentVersion, normalizedVersion) { + println("applying migration for", normalizedVersion) + if err := d.applyMigrationForMinorVersion(ctx, minorVersion); err != nil { + return errors.Wrap(err, "failed to apply minor version migration") + } + } + } + println("end migrate") + return nil +} + +func (d *Driver) applyMigrationForMinorVersion(ctx context.Context, minorVersion string) error { + filenames, err := fs.Glob(migrationFS, fmt.Sprintf("%s/%s/*.sql", "migration/prod", minorVersion)) + if err != nil { + return errors.Wrap(err, "failed to read ddl files") + } + + sort.Strings(filenames) + // Loop over all migration files and execute them in order. + for _, filename := range filenames { + buf, err := migrationFS.ReadFile(filename) + if err != nil { + return errors.Wrapf(err, "failed to read minor version migration file, filename=%s", filename) + } + for _, stmt := range strings.Split(string(buf), ";") { + if strings.TrimSpace(stmt) == "" { + continue + } + if _, err := d.db.ExecContext(ctx, stmt); err != nil { + return errors.Wrapf(err, "migrate error: %s", stmt) + } + } + } + + // Upsert the newest version to migration_history. + version := minorVersion + ".0" + if _, err = d.UpsertMigrationHistory(ctx, &MigrationHistoryUpsert{Version: version}); err != nil { + return errors.Wrapf(err, "failed to upsert migration history with version: %s", version) + } + + return nil +} + +//go:embed seed +var seedFS embed.FS + +func (d *Driver) seed(ctx context.Context) error { + filenames, err := fs.Glob(seedFS, fmt.Sprintf("%s/*.sql", "seed")) + if err != nil { + return errors.Wrap(err, "failed to read seed files") + } + + sort.Strings(filenames) + + // Loop over all seed files and execute them in order. + for _, filename := range filenames { + buf, err := seedFS.ReadFile(filename) + if err != nil { + return errors.Wrapf(err, "failed to read seed file, filename=%s", filename) + } + + for _, stmt := range strings.Split(string(buf), ";") { + if strings.TrimSpace(stmt) == "" { + continue + } + if _, err := d.db.ExecContext(ctx, stmt); err != nil { + return errors.Wrapf(err, "seed error: %s", stmt) + } + } + } + return nil +} + +// minorDirRegexp is a regular expression for minor version directory. +var minorDirRegexp = regexp.MustCompile(`^migration/prod/[0-9]+\.[0-9]+$`) + +func getMinorVersionList() []string { + minorVersionList := []string{} + + if err := fs.WalkDir(migrationFS, "migration", func(path string, file fs.DirEntry, err error) error { + if err != nil { + return err + } + if file.IsDir() && minorDirRegexp.MatchString(path) { + minorVersionList = append(minorVersionList, file.Name()) + } + + return nil + }); err != nil { + panic(err) + } + + sort.Sort(version.SortVersion(minorVersionList)) + + return minorVersionList +} diff --git a/store/mysql/migration/dev/LATEST__SCHEMA.sql b/store/mysql/migration/dev/LATEST__SCHEMA.sql new file mode 100644 index 00000000..13cbabd4 --- /dev/null +++ b/store/mysql/migration/dev/LATEST__SCHEMA.sql @@ -0,0 +1,131 @@ +-- activity +CREATE TABLE IF NOT EXISTS `activity` ( + `id` int NOT NULL AUTO_INCREMENT, + `creator_id` int NOT NULL, + `created_ts` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + `type` varchar(255) NOT NULL DEFAULT '', + `level` varchar(255) NOT NULL DEFAULT 'INFO', + `payload` text NOT NULL, + PRIMARY KEY (`id`), + CONSTRAINT `activity_chk_1` CHECK ((`level` in (_utf8mb4'INFO',_utf8mb4'WARN',_utf8mb4'ERROR'))) +); + +-- idp +CREATE TABLE IF NOT EXISTS `idp` ( + `id` int NOT NULL AUTO_INCREMENT, + `name` text NOT NULL, + `type` text NOT NULL, + `identifier_filter` varchar(256) NOT NULL DEFAULT '', + `config` text NOT NULL, + PRIMARY KEY (`id`) +); + +-- memo +CREATE TABLE IF NOT EXISTS `memo` ( + `id` int NOT NULL AUTO_INCREMENT, + `creator_id` int NOT NULL, + `created_ts` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + `updated_ts` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + `row_status` varchar(255) NOT NULL DEFAULT 'NORMAL', + `content` text NOT NULL, + `visibility` varchar(255) NOT NULL DEFAULT 'PRIVATE', + PRIMARY KEY (`id`), + KEY `creator_id` (`creator_id`), + KEY `visibility` (`visibility`), + CONSTRAINT `memo_chk_1` CHECK ((`row_status` in (_utf8mb4'NORMAL',_utf8mb4'ARCHIVED'))), + CONSTRAINT `memo_chk_2` CHECK ((`visibility` in (_utf8mb4'PUBLIC',_utf8mb4'PROTECTED',_utf8mb4'PRIVATE'))) +); + +-- memo_organizer +CREATE TABLE IF NOT EXISTS `memo_organizer` ( + `memo_id` int NOT NULL, + `user_id` int NOT NULL, + `pinned` int NOT NULL DEFAULT '0', + UNIQUE KEY `memo_id` (`memo_id`,`user_id`), + CONSTRAINT `memo_organizer_chk_1` CHECK ((`pinned` in (0,1))) +); + +-- memo_relation +CREATE TABLE IF NOT EXISTS `memo_relation` ( + `memo_id` int NOT NULL, + `related_memo_id` int NOT NULL, + `type` varchar(256) NOT NULL, + UNIQUE KEY `memo_id` (`memo_id`,`related_memo_id`,`type`) +); + +-- migration_history +CREATE TABLE IF NOT EXISTS `migration_history` ( + `version` varchar(255) NOT NULL, + `created_ts` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + PRIMARY KEY (`version`) +); + +-- resource +CREATE TABLE IF NOT EXISTS `resource` ( + `id` int NOT NULL AUTO_INCREMENT, + `creator_id` int NOT NULL, + `created_ts` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + `updated_ts` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + `filename` text NOT NULL, + `blob` blob, + `external_link` text NOT NULL, + `type` varchar(255) NOT NULL DEFAULT '', + `size` int NOT NULL DEFAULT '0', + `internal_path` varchar(255) NOT NULL DEFAULT '', + `memo_id` int DEFAULT NULL, + PRIMARY KEY (`id`), + KEY `creator_id` (`creator_id`), + KEY `memo_id` (`memo_id`) +); + +-- storage +CREATE TABLE IF NOT EXISTS `storage` ( + `id` int NOT NULL AUTO_INCREMENT, + `name` varchar(256) NOT NULL, + `type` varchar(256) NOT NULL, + `config` text NOT NULL, + PRIMARY KEY (`id`) +); + +-- system_setting +CREATE TABLE IF NOT EXISTS `system_setting` ( + `name` varchar(255) NOT NULL, + `value` text NOT NULL, + `description` text NOT NULL, + PRIMARY KEY (`name`) +); + +-- tag +CREATE TABLE IF NOT EXISTS `tag` ( + `name` varchar(255) NOT NULL, + `creator_id` int NOT NULL, + UNIQUE KEY `name` (`name`,`creator_id`) +); + +-- user +CREATE TABLE IF NOT EXISTS `user` ( + `id` int NOT NULL AUTO_INCREMENT, + `created_ts` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + `updated_ts` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + `row_status` varchar(255) NOT NULL DEFAULT 'NORMAL', + `username` varchar(255) NOT NULL, + `role` varchar(255) NOT NULL DEFAULT 'USER', + `email` varchar(255) NOT NULL DEFAULT '', + `nickname` varchar(255) NOT NULL DEFAULT '', + `password_hash` varchar(255) NOT NULL, + `avatar_url` text NOT NULL, + PRIMARY KEY (`id`), + UNIQUE KEY `username` (`username`), + CONSTRAINT `user_chk_1` CHECK ((`row_status` in (_utf8mb4'NORMAL',_utf8mb4'ARCHIVED'))), + CONSTRAINT `user_chk_2` CHECK ((`role` in (_utf8mb4'HOST',_utf8mb4'ADMIN',_utf8mb4'USER'))) +); + +-- user_setting +CREATE TABLE IF NOT EXISTS `user_setting` ( + `user_id` int NOT NULL, + `key` varchar(255) NOT NULL, + `value` text NOT NULL, + UNIQUE KEY `user_id` (`user_id`,`key`) +); + + diff --git a/store/mysql/migration_history.go b/store/mysql/migration_history.go new file mode 100644 index 00000000..c9211845 --- /dev/null +++ b/store/mysql/migration_history.go @@ -0,0 +1,84 @@ +package mysql + +import ( + "context" + "strings" +) + +type MigrationHistory struct { + Version string + CreatedTs int64 +} + +type MigrationHistoryUpsert struct { + Version string +} + +type MigrationHistoryFind struct { + Version *string +} + +func (d *Driver) FindMigrationHistoryList(ctx context.Context, find *MigrationHistoryFind) ([]*MigrationHistory, error) { + where, args := []string{"1 = 1"}, []any{} + + if v := find.Version; v != nil { + where, args = append(where, "version = ?"), append(args, *v) + } + + query := ` + SELECT version, UNIX_TIMESTAMP(created_ts) + FROM migration_history + WHERE ` + strings.Join(where, " AND ") + ` + ORDER BY created_ts DESC + ` + rows, err := d.db.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + list := make([]*MigrationHistory, 0) + for rows.Next() { + var migrationHistory MigrationHistory + if err := rows.Scan( + &migrationHistory.Version, + &migrationHistory.CreatedTs, + ); err != nil { + return nil, err + } + + list = append(list, &migrationHistory) + } + + if err := rows.Err(); err != nil { + return nil, err + } + + return list, nil +} + +func (d *Driver) UpsertMigrationHistory(ctx context.Context, upsert *MigrationHistoryUpsert) (*MigrationHistory, error) { + stmt := ` + INSERT INTO migration_history (version) VALUES (?) + ON DUPLICATE KEY UPDATE version = ? + ` + _, err := d.db.ExecContext(ctx, stmt, upsert.Version, upsert.Version) + if err != nil { + return nil, err + } + + var migrationHistory MigrationHistory + stmt = ` + SELECT version, UNIX_TIMESTAMP(created_ts) + FROM migration_history + WHERE version = ? + ` + if err := d.db.QueryRowContext(ctx, stmt, upsert.Version).Scan( + &migrationHistory.Version, + &migrationHistory.CreatedTs, + ); err != nil { + return nil, err + } + + return &migrationHistory, nil +} diff --git a/store/mysql/mysql.go b/store/mysql/mysql.go new file mode 100644 index 00000000..f989bec7 --- /dev/null +++ b/store/mysql/mysql.go @@ -0,0 +1,64 @@ +package mysql + +import ( + "context" + "database/sql" + + "github.com/pkg/errors" + + "github.com/usememos/memos/server/profile" + "github.com/usememos/memos/store" +) + +type Driver struct { + db *sql.DB + profile *profile.Profile +} + +func NewDriver(profile *profile.Profile) (store.Driver, error) { + db, err := sql.Open("mysql", profile.DSN) + if err != nil { + return nil, err + } + + driver := Driver{db: db, profile: profile} + return &driver, nil +} + +func (d *Driver) Vacuum(ctx context.Context) error { + tx, err := d.db.BeginTx(ctx, nil) + if err != nil { + return err + } + defer tx.Rollback() + + if err := vacuumMemo(ctx, tx); err != nil { + return err + } + if err := vacuumResource(ctx, tx); err != nil { + return err + } + if err := vacuumUserSetting(ctx, tx); err != nil { + return err + } + if err := vacuumMemoOrganizer(ctx, tx); err != nil { + return err + } + if err := vacuumMemoRelations(ctx, tx); err != nil { + return err + } + if err := vacuumTag(ctx, tx); err != nil { + // Prevent revive warning. + return err + } + + return tx.Commit() +} + +func (*Driver) BackupTo(context.Context, string) error { + return errors.New("Please use mysqldump to backup") +} + +func (d *Driver) Close() error { + return d.db.Close() +} diff --git a/store/mysql/resource.go b/store/mysql/resource.go new file mode 100644 index 00000000..cc21f41b --- /dev/null +++ b/store/mysql/resource.go @@ -0,0 +1,217 @@ +package mysql + +import ( + "context" + "database/sql" + "fmt" + "strings" + + "github.com/pkg/errors" + + "github.com/usememos/memos/store" +) + +func (d *Driver) CreateResource(ctx context.Context, create *store.Resource) (*store.Resource, error) { + stmt := ` + INSERT INTO resource ( + filename, + resource.blob, + external_link, + type, + size, + creator_id, + internal_path + ) + VALUES (?, ?, ?, ?, ?, ?, ?) + ` + result, err := d.db.ExecContext( + ctx, + stmt, + create.Filename, + create.Blob, + create.ExternalLink, + create.Type, + create.Size, + create.CreatorID, + create.InternalPath, + ) + if err != nil { + return nil, err + } + + id, err := result.LastInsertId() + if err != nil { + return nil, err + } + + id32 := int32(id) + list, err := d.ListResources(ctx, &store.FindResource{ID: &id32}) + if err != nil { + return nil, err + } + if len(list) != 1 { + return nil, errors.Wrapf(nil, "unexpected resource count: %d", len(list)) + } + + return list[0], nil +} + +func (d *Driver) ListResources(ctx context.Context, find *store.FindResource) ([]*store.Resource, error) { + where, args := []string{"1 = 1"}, []any{} + + if v := find.ID; v != nil { + where, args = append(where, "id = ?"), append(args, *v) + } + if v := find.CreatorID; v != nil { + where, args = append(where, "creator_id = ?"), append(args, *v) + } + if v := find.Filename; v != nil { + where, args = append(where, "filename = ?"), append(args, *v) + } + if v := find.MemoID; v != nil { + where, args = append(where, "memo_id = ?"), append(args, *v) + } + if find.HasRelatedMemo { + where = append(where, "memo_id IS NOT NULL") + } + + fields := []string{"id", "filename", "external_link", "type", "size", "creator_id", "UNIX_TIMESTAMP(created_ts)", "UNIX_TIMESTAMP(updated_ts)", "internal_path", "memo_id"} + if find.GetBlob { + fields = append(fields, "resource.blob") + } + + query := fmt.Sprintf(` + SELECT + %s + FROM resource + WHERE %s + GROUP BY id + ORDER BY created_ts DESC + `, strings.Join(fields, ", "), strings.Join(where, " AND ")) + if find.Limit != nil { + query = fmt.Sprintf("%s LIMIT %d", query, *find.Limit) + if find.Offset != nil { + query = fmt.Sprintf("%s OFFSET %d", query, *find.Offset) + } + } + + rows, err := d.db.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + list := make([]*store.Resource, 0) + for rows.Next() { + resource := store.Resource{} + var memoID sql.NullInt32 + dests := []any{ + &resource.ID, + &resource.Filename, + &resource.ExternalLink, + &resource.Type, + &resource.Size, + &resource.CreatorID, + &resource.CreatedTs, + &resource.UpdatedTs, + &resource.InternalPath, + &memoID, + } + if find.GetBlob { + dests = append(dests, &resource.Blob) + } + if err := rows.Scan(dests...); err != nil { + return nil, err + } + if memoID.Valid { + resource.MemoID = &memoID.Int32 + } + list = append(list, &resource) + } + + if err := rows.Err(); err != nil { + return nil, err + } + + return list, nil +} + +func (d *Driver) UpdateResource(ctx context.Context, update *store.UpdateResource) (*store.Resource, error) { + set, args := []string{}, []any{} + + if v := update.UpdatedTs; v != nil { + set, args = append(set, "updated_ts = ?"), append(args, *v) + } + if v := update.Filename; v != nil { + set, args = append(set, "filename = ?"), append(args, *v) + } + if v := update.InternalPath; v != nil { + set, args = append(set, "internal_path = ?"), append(args, *v) + } + if v := update.MemoID; v != nil { + set, args = append(set, "memo_id = ?"), append(args, *v) + } + if update.UnbindMemo { + set = append(set, "memo_id = NULL") + } + if v := update.Blob; v != nil { + set, args = append(set, "resource.blob = ?"), append(args, v) + } + + args = append(args, update.ID) + stmt := ` + UPDATE resource + SET ` + strings.Join(set, ", ") + ` + WHERE id = ? + ` + if _, err := d.db.ExecContext(ctx, stmt, args...); err != nil { + return nil, err + } + + list, err := d.ListResources(ctx, &store.FindResource{ID: &update.ID}) + if err != nil { + return nil, err + } + if len(list) != 1 { + return nil, errors.Wrapf(nil, "unexpected resource count: %d", len(list)) + } + + return list[0], nil +} + +func (d *Driver) DeleteResource(ctx context.Context, delete *store.DeleteResource) error { + stmt := `DELETE FROM resource WHERE id = ?` + result, err := d.db.ExecContext(ctx, stmt, delete.ID) + if err != nil { + return err + } + if _, err := result.RowsAffected(); err != nil { + return err + } + + if err := d.Vacuum(ctx); err != nil { + // Prevent linter warning. + return err + } + + return nil +} + +func vacuumResource(ctx context.Context, tx *sql.Tx) error { + stmt := ` + DELETE FROM + resource + WHERE + creator_id NOT IN ( + SELECT + id + FROM + user + )` + _, err := tx.ExecContext(ctx, stmt) + if err != nil { + return err + } + + return nil +} diff --git a/store/mysql/seed/10000__reset.sql b/store/mysql/seed/10000__reset.sql new file mode 100644 index 00000000..de4e97c9 --- /dev/null +++ b/store/mysql/seed/10000__reset.sql @@ -0,0 +1,4 @@ +TRUNCATE TABLE memo_organizer; +TRUNCATE TABLE resource; +TRUNCATE TABLE memo; +TRUNCATE TABLE user; diff --git a/store/mysql/seed/10001__user.sql b/store/mysql/seed/10001__user.sql new file mode 100644 index 00000000..4bd60e84 --- /dev/null +++ b/store/mysql/seed/10001__user.sql @@ -0,0 +1,45 @@ +INSERT INTO + user ( + `id`, + `username`, + `role`, + `email`, + `nickname`, + `row_status`, + `avatar_url`, + `password_hash` + ) +VALUES + ( + 101, + 'memos-demo', + 'HOST', + 'demo@usememos.com', + 'Derobot', + 'NORMAL', + '', + -- raw password: secret + '$2a$14$ajq8Q7fbtFRQvXpdCq7Jcuy.Rx1h/L4J60Otx.gyNLbAYctGMJ9tK' + ), + ( + 102, + 'jack', + 'USER', + 'jack@usememos.com', + 'Jack', + 'NORMAL', + '', + -- raw password: secret + '$2a$14$ajq8Q7fbtFRQvXpdCq7Jcuy.Rx1h/L4J60Otx.gyNLbAYctGMJ9tK' + ), + ( + 103, + 'bob', + 'USER', + 'bob@usememos.com', + 'Bob', + 'ARCHIVED', + '', + -- raw password: secret + '$2a$14$ajq8Q7fbtFRQvXpdCq7Jcuy.Rx1h/L4J60Otx.gyNLbAYctGMJ9tK' + ); diff --git a/store/mysql/seed/10002__memo.sql b/store/mysql/seed/10002__memo.sql new file mode 100644 index 00000000..31e3d1da --- /dev/null +++ b/store/mysql/seed/10002__memo.sql @@ -0,0 +1,54 @@ +INSERT INTO + memo (`id`, `content`, `creator_id`) +VALUES + ( + 1, + "#Hello πŸ‘‹ Welcome to memos.", + 101 + ); + +INSERT INTO + memo ( + `id`, + `content`, + `creator_id`, + `visibility` + ) +VALUES + ( + 2, + '#TODO +- [x] Take more photos about **πŸŒ„ sunset** +- [x] Clean the room +- [ ] Read *πŸ“– The Little Prince* +(πŸ‘† click to toggle status)', + 101, + 'PROTECTED' + ), + ( + 3, + "**[Slash](https://github.com/boojack/slash)**: A bookmarking and url shortener, save and share your links very easily. +![](https://github.com/boojack/slash/raw/main/resources/demo.gif) + +**[SQL Chat](https://www.sqlchat.ai)**: Chat-based SQL Client +![](https://www.sqlchat.ai/chat-logo-and-text.webp)", + 101, + 'PUBLIC' + ), + ( + 4, + '#TODO +- [x] Take more photos about **πŸŒ„ sunset** +- [ ] Clean the classroom +- [ ] Watch *πŸ‘¦ The Boys* +(πŸ‘† click to toggle status) +', + 102, + 'PROTECTED' + ), + ( + 5, + 'δΈ‰δΊΊθ‘ŒοΌŒεΏ…ζœ‰ζˆ‘εΈˆη„‰οΌπŸ‘¨β€πŸ«', + 102, + 'PUBLIC' + ); diff --git a/store/mysql/seed/10003__memo_organizer.sql b/store/mysql/seed/10003__memo_organizer.sql new file mode 100644 index 00000000..e1a2c406 --- /dev/null +++ b/store/mysql/seed/10003__memo_organizer.sql @@ -0,0 +1,5 @@ +INSERT INTO + memo_organizer (`memo_id`, `user_id`, `pinned`) +VALUES + (1, 101, 1), + (3, 101, 1); diff --git a/store/mysql/seed/10004__tag.sql b/store/mysql/seed/10004__tag.sql new file mode 100644 index 00000000..40a7c774 --- /dev/null +++ b/store/mysql/seed/10004__tag.sql @@ -0,0 +1,6 @@ +INSERT INTO + tag (`name`, `creator_id`) +VALUES + ('Hello', 101), + ('TODO', 101), + ('TODO', 102); diff --git a/store/mysql/storage.go b/store/mysql/storage.go new file mode 100644 index 00000000..3d47943b --- /dev/null +++ b/store/mysql/storage.go @@ -0,0 +1,137 @@ +package mysql + +import ( + "context" + "strings" + + "github.com/usememos/memos/store" +) + +func (d *Driver) CreateStorage(ctx context.Context, create *store.Storage) (*store.Storage, error) { + stmt := ` + INSERT INTO storage ( + name, + type, + config + ) + VALUES (?, ?, ?) + ` + result, err := d.db.ExecContext(ctx, stmt, create.Name, create.Type, create.Config) + if err != nil { + return nil, err + } + + id, err := result.LastInsertId() + if err != nil { + return nil, err + } + + create.ID = int32(id) + return create, nil +} + +func (d *Driver) ListStorages(ctx context.Context, find *store.FindStorage) ([]*store.Storage, error) { + where, args := []string{"1 = 1"}, []any{} + if find.ID != nil { + where, args = append(where, "id = ?"), append(args, *find.ID) + } + + rows, err := d.db.QueryContext(ctx, ` + SELECT + id, + name, + type, + config + FROM storage + WHERE `+strings.Join(where, " AND ")+` + ORDER BY id DESC`, + args..., + ) + if err != nil { + return nil, err + } + defer rows.Close() + + list := []*store.Storage{} + for rows.Next() { + storage := &store.Storage{} + if err := rows.Scan( + &storage.ID, + &storage.Name, + &storage.Type, + &storage.Config, + ); err != nil { + return nil, err + } + list = append(list, storage) + } + + if err := rows.Err(); err != nil { + return nil, err + } + + return list, nil +} + +func (d *Driver) GetStorage(ctx context.Context, find *store.FindStorage) (*store.Storage, error) { + list, err := d.ListStorages(ctx, find) + if err != nil { + return nil, err + } + if len(list) == 0 { + return nil, nil + } + + return list[0], nil +} + +func (d *Driver) UpdateStorage(ctx context.Context, update *store.UpdateStorage) (*store.Storage, error) { + set, args := []string{}, []any{} + if update.Name != nil { + set = append(set, "name = ?") + args = append(args, *update.Name) + } + if update.Config != nil { + set = append(set, "config = ?") + args = append(args, *update.Config) + } + args = append(args, update.ID) + + stmt := ` + UPDATE storage + SET ` + strings.Join(set, ", ") + ` + WHERE id = ? + ` + _, err := d.db.ExecContext(ctx, stmt, args...) + if err != nil { + return nil, err + } + + storage := &store.Storage{} + stmt = `SELECT id,name,type,config FROM storage WHERE id = ?` + if err := d.db.QueryRowContext(ctx, stmt, update.ID).Scan( + &storage.ID, + &storage.Name, + &storage.Type, + &storage.Config, + ); err != nil { + return nil, err + } + + return storage, nil +} + +func (d *Driver) DeleteStorage(ctx context.Context, delete *store.DeleteStorage) error { + stmt := ` + DELETE FROM storage + WHERE id = ? + ` + result, err := d.db.ExecContext(ctx, stmt, delete.ID) + if err != nil { + return err + } + if _, err := result.RowsAffected(); err != nil { + return err + } + return nil +} diff --git a/store/mysql/system_setting.go b/store/mysql/system_setting.go new file mode 100644 index 00000000..6f683ce3 --- /dev/null +++ b/store/mysql/system_setting.go @@ -0,0 +1,72 @@ +package mysql + +import ( + "context" + "strings" + + "github.com/usememos/memos/store" +) + +func (d *Driver) UpsertSystemSetting(ctx context.Context, upsert *store.SystemSetting) (*store.SystemSetting, error) { + stmt := ` + INSERT INTO system_setting ( + name, value, description + ) + VALUES (?, ?, ?) + ON DUPLICATE KEY UPDATE value = ?, description = ? + ` + _, err := d.db.ExecContext( + ctx, + stmt, + upsert.Name, + upsert.Value, + upsert.Description, + upsert.Value, + upsert.Description, + ) + if err != nil { + return nil, err + } + + return upsert, nil +} + +func (d *Driver) ListSystemSettings(ctx context.Context, find *store.FindSystemSetting) ([]*store.SystemSetting, error) { + where, args := []string{"1 = 1"}, []any{} + if find.Name != "" { + where, args = append(where, "name = ?"), append(args, find.Name) + } + + query := ` + SELECT + name, + value, + description + FROM system_setting + WHERE ` + strings.Join(where, " AND ") + + rows, err := d.db.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + list := []*store.SystemSetting{} + for rows.Next() { + systemSettingMessage := &store.SystemSetting{} + if err := rows.Scan( + &systemSettingMessage.Name, + &systemSettingMessage.Value, + &systemSettingMessage.Description, + ); err != nil { + return nil, err + } + list = append(list, systemSettingMessage) + } + + if err := rows.Err(); err != nil { + return nil, err + } + + return list, nil +} diff --git a/store/mysql/tag.go b/store/mysql/tag.go new file mode 100644 index 00000000..7ca95c04 --- /dev/null +++ b/store/mysql/tag.go @@ -0,0 +1,90 @@ +package mysql + +import ( + "context" + "database/sql" + "strings" + + "github.com/usememos/memos/store" +) + +func (d *Driver) UpsertTag(ctx context.Context, upsert *store.Tag) (*store.Tag, error) { + stmt := ` + INSERT INTO tag (name, creator_id) + VALUES (?, ?) + ON DUPLICATE KEY UPDATE name = ? + ` + if _, err := d.db.ExecContext(ctx, stmt, upsert.Name, upsert.CreatorID, upsert.Name); err != nil { + return nil, err + } + + return upsert, nil +} + +func (d *Driver) ListTags(ctx context.Context, find *store.FindTag) ([]*store.Tag, error) { + where, args := []string{"creator_id = ?"}, []any{find.CreatorID} + query := ` + SELECT + name, + creator_id + FROM tag + WHERE ` + strings.Join(where, " AND ") + ` + ORDER BY name ASC + ` + rows, err := d.db.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + list := []*store.Tag{} + for rows.Next() { + tag := &store.Tag{} + if err := rows.Scan( + &tag.Name, + &tag.CreatorID, + ); err != nil { + return nil, err + } + + list = append(list, tag) + } + + if err := rows.Err(); err != nil { + return nil, err + } + + return list, nil +} + +func (d *Driver) DeleteTag(ctx context.Context, delete *store.DeleteTag) error { + where, args := []string{"name = ?", "creator_id = ?"}, []any{delete.Name, delete.CreatorID} + stmt := `DELETE FROM tag WHERE ` + strings.Join(where, " AND ") + result, err := d.db.ExecContext(ctx, stmt, args...) + if err != nil { + return err + } + if _, err = result.RowsAffected(); err != nil { + return err + } + return nil +} + +func vacuumTag(ctx context.Context, tx *sql.Tx) error { + stmt := ` + DELETE FROM + tag + WHERE + creator_id NOT IN ( + SELECT + id + FROM + user + )` + _, err := tx.ExecContext(ctx, stmt) + if err != nil { + return err + } + + return nil +} diff --git a/store/mysql/user.go b/store/mysql/user.go new file mode 100644 index 00000000..b5e2277d --- /dev/null +++ b/store/mysql/user.go @@ -0,0 +1,205 @@ +package mysql + +import ( + "context" + "strings" + + "github.com/pkg/errors" + + "github.com/usememos/memos/store" +) + +func (d *Driver) CreateUser(ctx context.Context, create *store.User) (*store.User, error) { + stmt := ` + INSERT INTO user ( + username, + role, + email, + nickname, + password_hash, + avatar_url + ) + VALUES (?, ?, ?, ?, ?, ?) + ` + result, err := d.db.ExecContext(ctx, stmt, + create.Username, + create.Role, + create.Email, + create.Nickname, + create.PasswordHash, + create.AvatarURL, + ) + if err != nil { + return nil, err + } + + id, err := result.LastInsertId() + if err != nil { + return nil, err + } + + id64 := int32(id) + list, err := d.ListUsers(ctx, &store.FindUser{ID: &id64}) + if err != nil { + return nil, err + } + if len(list) != 1 { + return nil, errors.Wrapf(nil, "unexpected user count: %d", len(list)) + } + + return list[0], nil +} + +func (d *Driver) UpdateUser(ctx context.Context, update *store.UpdateUser) (*store.User, error) { + set, args := []string{}, []any{} + if v := update.UpdatedTs; v != nil { + set, args = append(set, "updated_ts = ?"), append(args, *v) + } + if v := update.RowStatus; v != nil { + set, args = append(set, "row_status = ?"), append(args, *v) + } + if v := update.Username; v != nil { + set, args = append(set, "username = ?"), append(args, *v) + } + if v := update.Email; v != nil { + set, args = append(set, "email = ?"), append(args, *v) + } + if v := update.Nickname; v != nil { + set, args = append(set, "nickname = ?"), append(args, *v) + } + if v := update.AvatarURL; v != nil { + set, args = append(set, "avatar_url = ?"), append(args, *v) + } + if v := update.PasswordHash; v != nil { + set, args = append(set, "password_hash = ?"), append(args, *v) + } + args = append(args, update.ID) + + query := ` + UPDATE user + SET ` + strings.Join(set, ", ") + ` + WHERE id = ? + ` + if _, err := d.db.ExecContext(ctx, query, args...); err != nil { + return nil, err + } + + user := &store.User{} + query = ` + SELECT + id, + username, + role, + email, + nickname, + password_hash, + avatar_url, + UNIX_TIMESTAMP(created_ts), + UNIX_TIMESTAMP(updated_ts), + row_status + FROM user WHERE id = ? + ` + if err := d.db.QueryRowContext(ctx, query, update.ID).Scan( + &user.ID, + &user.Username, + &user.Role, + &user.Email, + &user.Nickname, + &user.PasswordHash, + &user.AvatarURL, + &user.CreatedTs, + &user.UpdatedTs, + &user.RowStatus, + ); err != nil { + return nil, err + } + + return user, nil +} + +func (d *Driver) ListUsers(ctx context.Context, find *store.FindUser) ([]*store.User, error) { + where, args := []string{"1 = 1"}, []any{} + + if v := find.ID; v != nil { + where, args = append(where, "id = ?"), append(args, *v) + } + if v := find.Username; v != nil { + where, args = append(where, "username = ?"), append(args, *v) + } + if v := find.Role; v != nil { + where, args = append(where, "role = ?"), append(args, *v) + } + if v := find.Email; v != nil { + where, args = append(where, "email = ?"), append(args, *v) + } + if v := find.Nickname; v != nil { + where, args = append(where, "nickname = ?"), append(args, *v) + } + + query := ` + SELECT + id, + username, + role, + email, + nickname, + password_hash, + avatar_url, + UNIX_TIMESTAMP(created_ts), + UNIX_TIMESTAMP(updated_ts), + row_status + FROM user + WHERE ` + strings.Join(where, " AND ") + ` + ORDER BY created_ts DESC, row_status DESC + ` + rows, err := d.db.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + list := make([]*store.User, 0) + for rows.Next() { + var user store.User + if err := rows.Scan( + &user.ID, + &user.Username, + &user.Role, + &user.Email, + &user.Nickname, + &user.PasswordHash, + &user.AvatarURL, + &user.CreatedTs, + &user.UpdatedTs, + &user.RowStatus, + ); err != nil { + return nil, err + } + list = append(list, &user) + } + + if err := rows.Err(); err != nil { + return nil, err + } + + return list, nil +} + +func (d *Driver) DeleteUser(ctx context.Context, delete *store.DeleteUser) error { + result, err := d.db.ExecContext(ctx, ` + DELETE FROM user WHERE id = ? + `, delete.ID) + if err != nil { + return err + } + if _, err := result.RowsAffected(); err != nil { + return err + } + + if err := d.Vacuum(ctx); err != nil { + // Prevent linter warning. + return err + } + + return nil +} diff --git a/store/mysql/user_setting.go b/store/mysql/user_setting.go new file mode 100644 index 00000000..7de503f4 --- /dev/null +++ b/store/mysql/user_setting.go @@ -0,0 +1,169 @@ +package mysql + +import ( + "context" + "database/sql" + "strings" + + "github.com/pkg/errors" + "google.golang.org/protobuf/encoding/protojson" + + storepb "github.com/usememos/memos/proto/gen/store" + "github.com/usememos/memos/store" +) + +func (d *Driver) UpsertUserSetting(ctx context.Context, upsert *store.UserSetting) (*store.UserSetting, error) { + stmt := ` + INSERT INTO user_setting (user_id,user_setting.key,value) + VALUES (?, ?, ?) + ON DUPLICATE KEY UPDATE value = ? + ` + if _, err := d.db.ExecContext(ctx, stmt, upsert.UserID, upsert.Key, upsert.Value, upsert.Value); err != nil { + return nil, err + } + + return upsert, nil +} + +func (d *Driver) ListUserSettings(ctx context.Context, find *store.FindUserSetting) ([]*store.UserSetting, error) { + where, args := []string{"1 = 1"}, []any{} + + if v := find.Key; v != "" { + where, args = append(where, "user_setting.key = ?"), append(args, v) + } + if v := find.UserID; v != nil { + where, args = append(where, "user_id = ?"), append(args, *find.UserID) + } + + query := ` + SELECT + user_id, + user_setting.key, + value + FROM user_setting + WHERE ` + strings.Join(where, " AND ") + rows, err := d.db.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + userSettingList := make([]*store.UserSetting, 0) + for rows.Next() { + var userSetting store.UserSetting + if err := rows.Scan( + &userSetting.UserID, + &userSetting.Key, + &userSetting.Value, + ); err != nil { + return nil, err + } + userSettingList = append(userSettingList, &userSetting) + } + + if err := rows.Err(); err != nil { + return nil, err + } + + return userSettingList, nil +} + +func (d *Driver) UpsertUserSettingV1(ctx context.Context, upsert *storepb.UserSetting) (*storepb.UserSetting, error) { + stmt := ` + INSERT INTO user_setting (user_id, user_setting.key, value) + VALUES (?, ?, ?) + ON DUPLICATE KEY UPDATE value = ? + ` + var valueString string + if upsert.Key == storepb.UserSettingKey_USER_SETTING_ACCESS_TOKENS { + valueBytes, err := protojson.Marshal(upsert.GetAccessTokens()) + if err != nil { + return nil, err + } + valueString = string(valueBytes) + } else { + return nil, errors.New("invalid user setting key") + } + + if _, err := d.db.ExecContext(ctx, stmt, upsert.UserId, upsert.Key.String(), valueString, valueString); err != nil { + return nil, err + } + + return upsert, nil +} + +func (d *Driver) ListUserSettingsV1(ctx context.Context, find *store.FindUserSettingV1) ([]*storepb.UserSetting, error) { + where, args := []string{"1 = 1"}, []any{} + + if v := find.Key; v != storepb.UserSettingKey_USER_SETTING_KEY_UNSPECIFIED { + where, args = append(where, "user_setting.key = ?"), append(args, v.String()) + } + if v := find.UserID; v != nil { + where, args = append(where, "user_id = ?"), append(args, *find.UserID) + } + + query := ` + SELECT + user_id, + user_setting.key, + value + FROM user_setting + WHERE ` + strings.Join(where, " AND ") + rows, err := d.db.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + userSettingList := make([]*storepb.UserSetting, 0) + for rows.Next() { + userSetting := &storepb.UserSetting{} + var keyString, valueString string + if err := rows.Scan( + &userSetting.UserId, + &keyString, + &valueString, + ); err != nil { + return nil, err + } + userSetting.Key = storepb.UserSettingKey(storepb.UserSettingKey_value[keyString]) + if userSetting.Key == storepb.UserSettingKey_USER_SETTING_ACCESS_TOKENS { + accessTokensUserSetting := &storepb.AccessTokensUserSetting{} + if err := protojson.Unmarshal([]byte(valueString), accessTokensUserSetting); err != nil { + return nil, err + } + userSetting.Value = &storepb.UserSetting_AccessTokens{ + AccessTokens: accessTokensUserSetting, + } + } else { + // Skip unknown user setting v1 key. + continue + } + userSettingList = append(userSettingList, userSetting) + } + + if err := rows.Err(); err != nil { + return nil, err + } + + return userSettingList, nil +} + +func vacuumUserSetting(ctx context.Context, tx *sql.Tx) error { + stmt := ` + DELETE FROM + user_setting + WHERE + user_id NOT IN ( + SELECT + id + FROM + user + )` + _, err := tx.ExecContext(ctx, stmt) + if err != nil { + return err + } + + return nil +}