chore: add origin flag to config cors

This commit is contained in:
Steven
2024-04-07 22:15:15 +08:00
parent b5893aa60b
commit 8101a5e0b1
3 changed files with 32 additions and 12 deletions

View File

@@ -39,10 +39,11 @@ var (
driver string driver string
dsn string dsn string
serveFrontend bool serveFrontend bool
allowedOrigins []string
rootCmd = &cobra.Command{ rootCmd = &cobra.Command{
Use: "memos", Use: "memos",
Short: `An open-source, self-hosted memo hub with knowledge management and social networking.`, Short: `An open source, lightweight note-taking service. Easily capture and share your great thoughts.`,
Run: func(_cmd *cobra.Command, _args []string) { Run: func(_cmd *cobra.Command, _args []string) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
dbDriver, err := db.NewDBDriver(profile) dbDriver, err := db.NewDBDriver(profile)
@@ -114,6 +115,7 @@ func init() {
rootCmd.PersistentFlags().StringVarP(&driver, "driver", "", "", "database driver") rootCmd.PersistentFlags().StringVarP(&driver, "driver", "", "", "database driver")
rootCmd.PersistentFlags().StringVarP(&dsn, "dsn", "", "", "database source name(aka. DSN)") rootCmd.PersistentFlags().StringVarP(&dsn, "dsn", "", "", "database source name(aka. DSN)")
rootCmd.PersistentFlags().BoolVarP(&serveFrontend, "frontend", "", true, "serve frontend files") rootCmd.PersistentFlags().BoolVarP(&serveFrontend, "frontend", "", true, "serve frontend files")
rootCmd.PersistentFlags().StringArrayVarP(&allowedOrigins, "origins", "", []string{}, "CORS allowed domain origins")
err := viper.BindPFlag("mode", rootCmd.PersistentFlags().Lookup("mode")) err := viper.BindPFlag("mode", rootCmd.PersistentFlags().Lookup("mode"))
if err != nil { if err != nil {
@@ -143,12 +145,17 @@ func init() {
if err != nil { if err != nil {
panic(err) panic(err)
} }
err = viper.BindPFlag("origins", rootCmd.PersistentFlags().Lookup("origins"))
if err != nil {
panic(err)
}
viper.SetDefault("mode", "demo") viper.SetDefault("mode", "demo")
viper.SetDefault("driver", "sqlite") viper.SetDefault("driver", "sqlite")
viper.SetDefault("addr", "") viper.SetDefault("addr", "")
viper.SetDefault("port", 8081) viper.SetDefault("port", 8081)
viper.SetDefault("frontend", true) viper.SetDefault("frontend", true)
viper.SetDefault("origins", []string{})
viper.SetEnvPrefix("memos") viper.SetEnvPrefix("memos")
} }

View File

@@ -32,6 +32,8 @@ type Profile struct {
Version string `json:"version"` Version string `json:"version"`
// Frontend indicate the frontend is enabled or not // Frontend indicate the frontend is enabled or not
Frontend bool `json:"-"` Frontend bool `json:"-"`
// Origins is the list of allowed origins
Origins []string `json:"-"`
} }
func (p *Profile) IsDev() bool { func (p *Profile) IsDev() bool {

View File

@@ -49,7 +49,7 @@ func NewServer(ctx context.Context, profile *profile.Profile, store *store.Store
} }
// Register CORS middleware. // Register CORS middleware.
e.Use(CORSMiddleware()) e.Use(CORSMiddleware(s.Profile.Origins))
serverID, err := s.getSystemServerID(ctx) serverID, err := s.getSystemServerID(ctx)
if err != nil { if err != nil {
@@ -160,7 +160,7 @@ func grpcRequestSkipper(c echo.Context) bool {
return strings.HasPrefix(c.Request().URL.Path, "/memos.api.v2.") return strings.HasPrefix(c.Request().URL.Path, "/memos.api.v2.")
} }
func CORSMiddleware() echo.MiddlewareFunc { func CORSMiddleware(origins []string) echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc { return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error { return func(c echo.Context) error {
if grpcRequestSkipper(c) { if grpcRequestSkipper(c) {
@@ -170,7 +170,18 @@ func CORSMiddleware() echo.MiddlewareFunc {
r := c.Request() r := c.Request()
w := c.Response().Writer w := c.Response().Writer
w.Header().Set("Access-Control-Allow-Origin", r.Header.Get("Origin")) requestOrigin := r.Header.Get("Origin")
if len(origins) == 0 {
w.Header().Set("Access-Control-Allow-Origin", requestOrigin)
} else {
for _, origin := range origins {
if origin == requestOrigin {
w.Header().Set("Access-Control-Allow-Origin", origin)
break
}
}
}
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, PATCH, OPTIONS") w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, PATCH, OPTIONS")
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization") w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization")
w.Header().Set("Access-Control-Allow-Credentials", "true") w.Header().Set("Access-Control-Allow-Credentials", "true")