diff --git a/store/driver.go b/store/driver.go index 6cba2187..3d17b362 100644 --- a/store/driver.go +++ b/store/driver.go @@ -7,4 +7,9 @@ type Driver interface { UpsertSystemSetting(ctx context.Context, upsert *SystemSetting) (*SystemSetting, error) ListSystemSettings(ctx context.Context, find *FindSystemSetting) ([]*SystemSetting, error) + + CreateUser(ctx context.Context, create *User) (*User, error) + UpdateUser(ctx context.Context, update *UpdateUser) (*User, error) + ListUsers(ctx context.Context, find *FindUser) ([]*User, error) + DeleteUser(ctx context.Context, delete *DeleteUser) error } diff --git a/store/sqlite3/user.go b/store/sqlite3/user.go new file mode 100644 index 00000000..64e101f2 --- /dev/null +++ b/store/sqlite3/user.go @@ -0,0 +1,172 @@ +package sqlite3 + +import ( + "context" + "strings" + + "github.com/usememos/memos/store" +) + +func (d *Driver) CreateUser(ctx context.Context, create *store.User) (*store.User, error) { + stmt := ` + INSERT INTO user ( + username, + role, + email, + nickname, + password_hash + ) + VALUES (?, ?, ?, ?, ?) + RETURNING id, avatar_url, created_ts, updated_ts, row_status + ` + if err := d.db.QueryRowContext( + ctx, + stmt, + create.Username, + create.Role, + create.Email, + create.Nickname, + create.PasswordHash, + ).Scan( + &create.ID, + &create.AvatarURL, + &create.CreatedTs, + &create.UpdatedTs, + &create.RowStatus, + ); err != nil { + return nil, err + } + + return create, nil +} + +func (d *Driver) UpdateUser(ctx context.Context, update *store.UpdateUser) (*store.User, error) { + set, args := []string{}, []any{} + if v := update.UpdatedTs; v != nil { + set, args = append(set, "updated_ts = ?"), append(args, *v) + } + if v := update.RowStatus; v != nil { + set, args = append(set, "row_status = ?"), append(args, *v) + } + if v := update.Username; v != nil { + set, args = append(set, "username = ?"), append(args, *v) + } + if v := update.Email; v != nil { + set, args = append(set, "email = ?"), append(args, *v) + } + if v := update.Nickname; v != nil { + set, args = append(set, "nickname = ?"), append(args, *v) + } + if v := update.AvatarURL; v != nil { + set, args = append(set, "avatar_url = ?"), append(args, *v) + } + if v := update.PasswordHash; v != nil { + set, args = append(set, "password_hash = ?"), append(args, *v) + } + args = append(args, update.ID) + + query := ` + UPDATE user + SET ` + strings.Join(set, ", ") + ` + WHERE id = ? + RETURNING id, username, role, email, nickname, password_hash, avatar_url, created_ts, updated_ts, row_status + ` + user := &store.User{} + if err := d.db.QueryRowContext(ctx, query, args...).Scan( + &user.ID, + &user.Username, + &user.Role, + &user.Email, + &user.Nickname, + &user.PasswordHash, + &user.AvatarURL, + &user.CreatedTs, + &user.UpdatedTs, + &user.RowStatus, + ); err != nil { + return nil, err + } + + return user, nil +} + +func (d *Driver) ListUsers(ctx context.Context, find *store.FindUser) ([]*store.User, error) { + where, args := []string{"1 = 1"}, []any{} + + if v := find.ID; v != nil { + where, args = append(where, "id = ?"), append(args, *v) + } + if v := find.Username; v != nil { + where, args = append(where, "username = ?"), append(args, *v) + } + if v := find.Role; v != nil { + where, args = append(where, "role = ?"), append(args, *v) + } + if v := find.Email; v != nil { + where, args = append(where, "email = ?"), append(args, *v) + } + if v := find.Nickname; v != nil { + where, args = append(where, "nickname = ?"), append(args, *v) + } + + query := ` + SELECT + id, + username, + role, + email, + nickname, + password_hash, + avatar_url, + created_ts, + updated_ts, + row_status + FROM user + WHERE ` + strings.Join(where, " AND ") + ` + ORDER BY created_ts DESC, row_status DESC + ` + rows, err := d.db.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + list := make([]*store.User, 0) + for rows.Next() { + var user store.User + if err := rows.Scan( + &user.ID, + &user.Username, + &user.Role, + &user.Email, + &user.Nickname, + &user.PasswordHash, + &user.AvatarURL, + &user.CreatedTs, + &user.UpdatedTs, + &user.RowStatus, + ); err != nil { + return nil, err + } + list = append(list, &user) + } + + if err := rows.Err(); err != nil { + return nil, err + } + + return list, nil +} + +func (d *Driver) DeleteUser(ctx context.Context, delete *store.DeleteUser) error { + result, err := d.db.ExecContext(ctx, ` + DELETE FROM user WHERE id = ? + `, delete.ID) + if err != nil { + return err + } + if _, err := result.RowsAffected(); err != nil { + return err + } + return nil +} diff --git a/store/user.go b/store/user.go index 6245b0e9..17948fa7 100644 --- a/store/user.go +++ b/store/user.go @@ -2,7 +2,6 @@ package store import ( "context" - "strings" ) // Role is the type of a role. @@ -74,84 +73,18 @@ type DeleteUser struct { } func (s *Store) CreateUser(ctx context.Context, create *User) (*User, error) { - stmt := ` - INSERT INTO user ( - username, - role, - email, - nickname, - password_hash - ) - VALUES (?, ?, ?, ?, ?) - RETURNING id, avatar_url, created_ts, updated_ts, row_status - ` - if err := s.db.QueryRowContext( - ctx, - stmt, - create.Username, - create.Role, - create.Email, - create.Nickname, - create.PasswordHash, - ).Scan( - &create.ID, - &create.AvatarURL, - &create.CreatedTs, - &create.UpdatedTs, - &create.RowStatus, - ); err != nil { + user, err := s.driver.CreateUser(ctx, create) + if err != nil { return nil, err } - user := create s.userCache.Store(user.ID, user) return user, nil } func (s *Store) UpdateUser(ctx context.Context, update *UpdateUser) (*User, error) { - set, args := []string{}, []any{} - if v := update.UpdatedTs; v != nil { - set, args = append(set, "updated_ts = ?"), append(args, *v) - } - if v := update.RowStatus; v != nil { - set, args = append(set, "row_status = ?"), append(args, *v) - } - if v := update.Username; v != nil { - set, args = append(set, "username = ?"), append(args, *v) - } - if v := update.Email; v != nil { - set, args = append(set, "email = ?"), append(args, *v) - } - if v := update.Nickname; v != nil { - set, args = append(set, "nickname = ?"), append(args, *v) - } - if v := update.AvatarURL; v != nil { - set, args = append(set, "avatar_url = ?"), append(args, *v) - } - if v := update.PasswordHash; v != nil { - set, args = append(set, "password_hash = ?"), append(args, *v) - } - args = append(args, update.ID) - - query := ` - UPDATE user - SET ` + strings.Join(set, ", ") + ` - WHERE id = ? - RETURNING id, username, role, email, nickname, password_hash, avatar_url, created_ts, updated_ts, row_status - ` - user := &User{} - if err := s.db.QueryRowContext(ctx, query, args...).Scan( - &user.ID, - &user.Username, - &user.Role, - &user.Email, - &user.Nickname, - &user.PasswordHash, - &user.AvatarURL, - &user.CreatedTs, - &user.UpdatedTs, - &user.RowStatus, - ); err != nil { + user, err := s.driver.UpdateUser(ctx, update) + if err != nil { return nil, err } @@ -160,69 +93,10 @@ func (s *Store) UpdateUser(ctx context.Context, update *UpdateUser) (*User, erro } func (s *Store) ListUsers(ctx context.Context, find *FindUser) ([]*User, error) { - where, args := []string{"1 = 1"}, []any{} - - if v := find.ID; v != nil { - where, args = append(where, "id = ?"), append(args, *v) - } - if v := find.Username; v != nil { - where, args = append(where, "username = ?"), append(args, *v) - } - if v := find.Role; v != nil { - where, args = append(where, "role = ?"), append(args, *v) - } - if v := find.Email; v != nil { - where, args = append(where, "email = ?"), append(args, *v) - } - if v := find.Nickname; v != nil { - where, args = append(where, "nickname = ?"), append(args, *v) - } - - query := ` - SELECT - id, - username, - role, - email, - nickname, - password_hash, - avatar_url, - created_ts, - updated_ts, - row_status - FROM user - WHERE ` + strings.Join(where, " AND ") + ` - ORDER BY created_ts DESC, row_status DESC - ` - rows, err := s.db.QueryContext(ctx, query, args...) + list, err := s.driver.ListUsers(ctx, find) if err != nil { return nil, err } - defer rows.Close() - - list := make([]*User, 0) - for rows.Next() { - var user User - if err := rows.Scan( - &user.ID, - &user.Username, - &user.Role, - &user.Email, - &user.Nickname, - &user.PasswordHash, - &user.AvatarURL, - &user.CreatedTs, - &user.UpdatedTs, - &user.RowStatus, - ); err != nil { - return nil, err - } - list = append(list, &user) - } - - if err := rows.Err(); err != nil { - return nil, err - } for _, user := range list { s.userCache.Store(user.ID, user) @@ -251,15 +125,11 @@ func (s *Store) GetUser(ctx context.Context, find *FindUser) (*User, error) { } func (s *Store) DeleteUser(ctx context.Context, delete *DeleteUser) error { - result, err := s.db.ExecContext(ctx, ` - DELETE FROM user WHERE id = ? - `, delete.ID) + err := s.driver.DeleteUser(ctx, delete) if err != nil { return err } - if _, err := result.RowsAffected(); err != nil { - return err - } + if err := s.Vacuum(ctx); err != nil { // Prevent linter warning. return err