mirror of
https://github.com/usememos/memos.git
synced 2025-06-05 22:09:59 +02:00
feat: add setup cmd (#1418)
This command can be used for automatization of initial application's setup
This commit is contained in:
48
cmd/memos.go
48
cmd/memos.go
@ -7,12 +7,16 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
"github.com/spf13/viper"
|
"github.com/spf13/viper"
|
||||||
|
|
||||||
"github.com/usememos/memos/server"
|
"github.com/usememos/memos/server"
|
||||||
_profile "github.com/usememos/memos/server/profile"
|
_profile "github.com/usememos/memos/server/profile"
|
||||||
|
"github.com/usememos/memos/setup"
|
||||||
|
"github.com/usememos/memos/store"
|
||||||
|
"github.com/usememos/memos/store/db"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@ -69,6 +73,40 @@ var (
|
|||||||
<-ctx.Done()
|
<-ctx.Done()
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
setupCmd = &cobra.Command{
|
||||||
|
Use: "setup",
|
||||||
|
Short: "Make initial setup for memos",
|
||||||
|
Run: func(cmd *cobra.Command, _ []string) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
hostUsername, err := cmd.Flags().GetString(setupCmdFlagHostUsername)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("failed to get owner username, error: %+v\n", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
hostPassword, err := cmd.Flags().GetString(setupCmdFlagHostPassword)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("failed to get owner password, error: %+v\n", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
db := db.NewDB(profile)
|
||||||
|
if err := db.Open(ctx); err != nil {
|
||||||
|
fmt.Printf("failed to open db, error: %+v\n", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
st := store.New(db.DBInstance, profile)
|
||||||
|
|
||||||
|
if err := setup.Execute(ctx, st, hostUsername, hostPassword); err != nil {
|
||||||
|
fmt.Printf("failed to setup, error: %+v\n", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
func Execute() error {
|
func Execute() error {
|
||||||
@ -98,6 +136,11 @@ func init() {
|
|||||||
viper.SetDefault("mode", "demo")
|
viper.SetDefault("mode", "demo")
|
||||||
viper.SetDefault("port", 8081)
|
viper.SetDefault("port", 8081)
|
||||||
viper.SetEnvPrefix("memos")
|
viper.SetEnvPrefix("memos")
|
||||||
|
|
||||||
|
setupCmd.Flags().String(setupCmdFlagHostUsername, "", "Owner username")
|
||||||
|
setupCmd.Flags().String(setupCmdFlagHostPassword, "", "Owner password")
|
||||||
|
|
||||||
|
rootCmd.AddCommand(setupCmd)
|
||||||
}
|
}
|
||||||
|
|
||||||
func initConfig() {
|
func initConfig() {
|
||||||
@ -117,3 +160,8 @@ func initConfig() {
|
|||||||
println("version:", profile.Version)
|
println("version:", profile.Version)
|
||||||
println("---")
|
println("---")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
setupCmdFlagHostUsername = "host-username"
|
||||||
|
setupCmdFlagHostPassword = "host-password"
|
||||||
|
)
|
||||||
|
1
go.mod
1
go.mod
@ -62,6 +62,7 @@ require (
|
|||||||
github.com/spf13/cast v1.5.0 // indirect
|
github.com/spf13/cast v1.5.0 // indirect
|
||||||
github.com/spf13/jwalterweatherman v1.1.0 // indirect
|
github.com/spf13/jwalterweatherman v1.1.0 // indirect
|
||||||
github.com/spf13/pflag v1.0.5 // indirect
|
github.com/spf13/pflag v1.0.5 // indirect
|
||||||
|
github.com/stretchr/objx v0.5.0 // indirect
|
||||||
github.com/subosito/gotenv v1.4.2 // indirect
|
github.com/subosito/gotenv v1.4.2 // indirect
|
||||||
github.com/valyala/bytebufferpool v1.0.0 // indirect
|
github.com/valyala/bytebufferpool v1.0.0 // indirect
|
||||||
github.com/valyala/fasttemplate v1.2.1 // indirect
|
github.com/valyala/fasttemplate v1.2.1 // indirect
|
||||||
|
1
go.sum
1
go.sum
@ -244,6 +244,7 @@ github.com/spf13/viper v1.15.0 h1:js3yy885G8xwJa6iOISGFwd+qlUo5AvyXb7CiihdtiU=
|
|||||||
github.com/spf13/viper v1.15.0/go.mod h1:fFcTBJxvhhzSJiZy8n+PeW6t8l+KeT/uTARa0jHOQLA=
|
github.com/spf13/viper v1.15.0/go.mod h1:fFcTBJxvhhzSJiZy8n+PeW6t8l+KeT/uTARa0jHOQLA=
|
||||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||||
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
||||||
|
github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c=
|
||||||
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
||||||
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
|
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
|
||||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||||
|
90
setup/setup.go
Normal file
90
setup/setup.go
Normal file
@ -0,0 +1,90 @@
|
|||||||
|
package setup
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"golang.org/x/crypto/bcrypt"
|
||||||
|
|
||||||
|
"github.com/usememos/memos/api"
|
||||||
|
"github.com/usememos/memos/common"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Execute(
|
||||||
|
ctx context.Context,
|
||||||
|
store store,
|
||||||
|
hostUsername, hostPassword string,
|
||||||
|
) error {
|
||||||
|
s := setupService{store: store}
|
||||||
|
return s.Setup(ctx, hostUsername, hostPassword)
|
||||||
|
}
|
||||||
|
|
||||||
|
type store interface {
|
||||||
|
FindUserList(ctx context.Context, find *api.UserFind) ([]*api.User, error)
|
||||||
|
CreateUser(ctx context.Context, create *api.UserCreate) (*api.User, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type setupService struct {
|
||||||
|
store store
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s setupService) Setup(
|
||||||
|
ctx context.Context,
|
||||||
|
hostUsername, hostPassword string,
|
||||||
|
) error {
|
||||||
|
if err := s.makeSureHostUserNotExists(ctx); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.createUser(ctx, hostUsername, hostPassword); err != nil {
|
||||||
|
return fmt.Errorf("create user: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s setupService) makeSureHostUserNotExists(ctx context.Context) error {
|
||||||
|
hostUserType := api.Host
|
||||||
|
existedHostUsers, err := s.store.FindUserList(ctx, &api.UserFind{
|
||||||
|
Role: &hostUserType,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("find user list: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(existedHostUsers) != 0 {
|
||||||
|
return errors.New("host user already exists")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s setupService) createUser(
|
||||||
|
ctx context.Context,
|
||||||
|
hostUsername, hostPassword string,
|
||||||
|
) error {
|
||||||
|
userCreate := &api.UserCreate{
|
||||||
|
Username: hostUsername,
|
||||||
|
// The new signup user should be normal user by default.
|
||||||
|
Role: api.Host,
|
||||||
|
Nickname: hostUsername,
|
||||||
|
Password: hostPassword,
|
||||||
|
OpenID: common.GenUUID(),
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := userCreate.Validate(); err != nil {
|
||||||
|
return fmt.Errorf("validate: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
passwordHash, err := bcrypt.GenerateFromPassword([]byte(hostPassword), bcrypt.DefaultCost)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("hash password: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
userCreate.PasswordHash = string(passwordHash)
|
||||||
|
if _, err := s.store.CreateUser(ctx, userCreate); err != nil {
|
||||||
|
return fmt.Errorf("create user: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
181
setup/setup_test.go
Normal file
181
setup/setup_test.go
Normal file
@ -0,0 +1,181 @@
|
|||||||
|
package setup
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/mock"
|
||||||
|
|
||||||
|
"github.com/usememos/memos/api"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestSetupService_makeSureHostUserNotExists(t *testing.T) {
|
||||||
|
cc := map[string]struct {
|
||||||
|
setupStore func(*storeMock)
|
||||||
|
expectedErr string
|
||||||
|
}{
|
||||||
|
"failed to get list": {
|
||||||
|
setupStore: func(m *storeMock) {
|
||||||
|
hostUserType := api.Host
|
||||||
|
m.
|
||||||
|
On("FindUserList", mock.Anything, &api.UserFind{
|
||||||
|
Role: &hostUserType,
|
||||||
|
}).
|
||||||
|
Return(nil, errors.New("fake error"))
|
||||||
|
},
|
||||||
|
expectedErr: "find user list: fake error",
|
||||||
|
},
|
||||||
|
"success, not empty": {
|
||||||
|
setupStore: func(m *storeMock) {
|
||||||
|
hostUserType := api.Host
|
||||||
|
m.
|
||||||
|
On("FindUserList", mock.Anything, &api.UserFind{
|
||||||
|
Role: &hostUserType,
|
||||||
|
}).
|
||||||
|
Return([]*api.User{
|
||||||
|
{},
|
||||||
|
}, nil)
|
||||||
|
},
|
||||||
|
expectedErr: "host user already exists",
|
||||||
|
},
|
||||||
|
"success, empty": {
|
||||||
|
setupStore: func(m *storeMock) {
|
||||||
|
hostUserType := api.Host
|
||||||
|
m.
|
||||||
|
On("FindUserList", mock.Anything, &api.UserFind{
|
||||||
|
Role: &hostUserType,
|
||||||
|
}).
|
||||||
|
Return(nil, nil)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for n, c := range cc {
|
||||||
|
c := c
|
||||||
|
t.Run(n, func(t *testing.T) {
|
||||||
|
sm := newStoreMock(t)
|
||||||
|
if c.setupStore != nil {
|
||||||
|
c.setupStore(sm)
|
||||||
|
}
|
||||||
|
|
||||||
|
srv := setupService{store: sm}
|
||||||
|
err := srv.makeSureHostUserNotExists(context.Background())
|
||||||
|
if c.expectedErr == "" {
|
||||||
|
assert.NoError(t, err)
|
||||||
|
} else {
|
||||||
|
assert.EqualError(t, err, c.expectedErr)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetupService_createUser(t *testing.T) {
|
||||||
|
expectedCreated := &api.UserCreate{
|
||||||
|
Username: "demohero",
|
||||||
|
Role: api.Host,
|
||||||
|
Nickname: "demohero",
|
||||||
|
Password: "123456",
|
||||||
|
}
|
||||||
|
|
||||||
|
userCreateMatcher := mock.MatchedBy(func(arg *api.UserCreate) bool {
|
||||||
|
return arg.Username == expectedCreated.Username &&
|
||||||
|
arg.Role == expectedCreated.Role &&
|
||||||
|
arg.Nickname == expectedCreated.Nickname &&
|
||||||
|
arg.Password == expectedCreated.Password &&
|
||||||
|
arg.PasswordHash != ""
|
||||||
|
})
|
||||||
|
|
||||||
|
cc := map[string]struct {
|
||||||
|
setupStore func(*storeMock)
|
||||||
|
hostUsername, hostPassword string
|
||||||
|
expectedErr string
|
||||||
|
}{
|
||||||
|
`username == "", password == ""`: {
|
||||||
|
expectedErr: "validate: username is too short, minimum length is 3",
|
||||||
|
},
|
||||||
|
`username == "", password != ""`: {
|
||||||
|
hostPassword: expectedCreated.Password,
|
||||||
|
expectedErr: "validate: username is too short, minimum length is 3",
|
||||||
|
},
|
||||||
|
`username != "", password == ""`: {
|
||||||
|
hostUsername: expectedCreated.Username,
|
||||||
|
expectedErr: "validate: password is too short, minimum length is 6",
|
||||||
|
},
|
||||||
|
"failed to create": {
|
||||||
|
setupStore: func(m *storeMock) {
|
||||||
|
m.
|
||||||
|
On("CreateUser", mock.Anything, userCreateMatcher).
|
||||||
|
Return(nil, errors.New("fake error"))
|
||||||
|
},
|
||||||
|
hostUsername: expectedCreated.Username,
|
||||||
|
hostPassword: expectedCreated.Password,
|
||||||
|
expectedErr: "create user: fake error",
|
||||||
|
},
|
||||||
|
"success": {
|
||||||
|
setupStore: func(m *storeMock) {
|
||||||
|
m.
|
||||||
|
On("CreateUser", mock.Anything, userCreateMatcher).
|
||||||
|
Return(nil, nil)
|
||||||
|
},
|
||||||
|
hostUsername: expectedCreated.Username,
|
||||||
|
hostPassword: expectedCreated.Password,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for n, c := range cc {
|
||||||
|
c := c
|
||||||
|
t.Run(n, func(t *testing.T) {
|
||||||
|
sm := newStoreMock(t)
|
||||||
|
if c.setupStore != nil {
|
||||||
|
c.setupStore(sm)
|
||||||
|
}
|
||||||
|
|
||||||
|
srv := setupService{store: sm}
|
||||||
|
err := srv.createUser(context.Background(), c.hostUsername, c.hostPassword)
|
||||||
|
if c.expectedErr == "" {
|
||||||
|
assert.NoError(t, err)
|
||||||
|
} else {
|
||||||
|
assert.EqualError(t, err, c.expectedErr)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type storeMock struct {
|
||||||
|
mock.Mock
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *storeMock) FindUserList(ctx context.Context, find *api.UserFind) ([]*api.User, error) {
|
||||||
|
ret := m.Called(ctx, find)
|
||||||
|
|
||||||
|
var uu []*api.User
|
||||||
|
ret1 := ret.Get(0)
|
||||||
|
if ret1 != nil {
|
||||||
|
uu = ret1.([]*api.User)
|
||||||
|
}
|
||||||
|
|
||||||
|
return uu, ret.Error(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *storeMock) CreateUser(ctx context.Context, create *api.UserCreate) (*api.User, error) {
|
||||||
|
ret := m.Called(ctx, create)
|
||||||
|
|
||||||
|
var u *api.User
|
||||||
|
ret1 := ret.Get(0)
|
||||||
|
if ret1 != nil {
|
||||||
|
u = ret1.(*api.User)
|
||||||
|
}
|
||||||
|
|
||||||
|
return u, ret.Error(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newStoreMock(t *testing.T) *storeMock {
|
||||||
|
m := &storeMock{}
|
||||||
|
m.Mock.Test(t)
|
||||||
|
|
||||||
|
t.Cleanup(func() { m.AssertExpectations(t) })
|
||||||
|
|
||||||
|
return m
|
||||||
|
}
|
Reference in New Issue
Block a user