mirror of
https://github.com/usememos/memos.git
synced 2025-06-05 22:09:59 +02:00
chore: add origin flag to config cors
This commit is contained in:
@@ -39,10 +39,11 @@ var (
|
|||||||
driver string
|
driver string
|
||||||
dsn string
|
dsn string
|
||||||
serveFrontend bool
|
serveFrontend bool
|
||||||
|
allowedOrigins []string
|
||||||
|
|
||||||
rootCmd = &cobra.Command{
|
rootCmd = &cobra.Command{
|
||||||
Use: "memos",
|
Use: "memos",
|
||||||
Short: `An open-source, self-hosted memo hub with knowledge management and social networking.`,
|
Short: `An open source, lightweight note-taking service. Easily capture and share your great thoughts.`,
|
||||||
Run: func(_cmd *cobra.Command, _args []string) {
|
Run: func(_cmd *cobra.Command, _args []string) {
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
dbDriver, err := db.NewDBDriver(profile)
|
dbDriver, err := db.NewDBDriver(profile)
|
||||||
@@ -114,6 +115,7 @@ func init() {
|
|||||||
rootCmd.PersistentFlags().StringVarP(&driver, "driver", "", "", "database driver")
|
rootCmd.PersistentFlags().StringVarP(&driver, "driver", "", "", "database driver")
|
||||||
rootCmd.PersistentFlags().StringVarP(&dsn, "dsn", "", "", "database source name(aka. DSN)")
|
rootCmd.PersistentFlags().StringVarP(&dsn, "dsn", "", "", "database source name(aka. DSN)")
|
||||||
rootCmd.PersistentFlags().BoolVarP(&serveFrontend, "frontend", "", true, "serve frontend files")
|
rootCmd.PersistentFlags().BoolVarP(&serveFrontend, "frontend", "", true, "serve frontend files")
|
||||||
|
rootCmd.PersistentFlags().StringArrayVarP(&allowedOrigins, "origins", "", []string{}, "CORS allowed domain origins")
|
||||||
|
|
||||||
err := viper.BindPFlag("mode", rootCmd.PersistentFlags().Lookup("mode"))
|
err := viper.BindPFlag("mode", rootCmd.PersistentFlags().Lookup("mode"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -143,12 +145,17 @@ func init() {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
err = viper.BindPFlag("origins", rootCmd.PersistentFlags().Lookup("origins"))
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
viper.SetDefault("mode", "demo")
|
viper.SetDefault("mode", "demo")
|
||||||
viper.SetDefault("driver", "sqlite")
|
viper.SetDefault("driver", "sqlite")
|
||||||
viper.SetDefault("addr", "")
|
viper.SetDefault("addr", "")
|
||||||
viper.SetDefault("port", 8081)
|
viper.SetDefault("port", 8081)
|
||||||
viper.SetDefault("frontend", true)
|
viper.SetDefault("frontend", true)
|
||||||
|
viper.SetDefault("origins", []string{})
|
||||||
viper.SetEnvPrefix("memos")
|
viper.SetEnvPrefix("memos")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -32,6 +32,8 @@ type Profile struct {
|
|||||||
Version string `json:"version"`
|
Version string `json:"version"`
|
||||||
// Frontend indicate the frontend is enabled or not
|
// Frontend indicate the frontend is enabled or not
|
||||||
Frontend bool `json:"-"`
|
Frontend bool `json:"-"`
|
||||||
|
// Origins is the list of allowed origins
|
||||||
|
Origins []string `json:"-"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Profile) IsDev() bool {
|
func (p *Profile) IsDev() bool {
|
||||||
|
@@ -49,7 +49,7 @@ func NewServer(ctx context.Context, profile *profile.Profile, store *store.Store
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Register CORS middleware.
|
// Register CORS middleware.
|
||||||
e.Use(CORSMiddleware())
|
e.Use(CORSMiddleware(s.Profile.Origins))
|
||||||
|
|
||||||
serverID, err := s.getSystemServerID(ctx)
|
serverID, err := s.getSystemServerID(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -160,7 +160,7 @@ func grpcRequestSkipper(c echo.Context) bool {
|
|||||||
return strings.HasPrefix(c.Request().URL.Path, "/memos.api.v2.")
|
return strings.HasPrefix(c.Request().URL.Path, "/memos.api.v2.")
|
||||||
}
|
}
|
||||||
|
|
||||||
func CORSMiddleware() echo.MiddlewareFunc {
|
func CORSMiddleware(origins []string) echo.MiddlewareFunc {
|
||||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||||
return func(c echo.Context) error {
|
return func(c echo.Context) error {
|
||||||
if grpcRequestSkipper(c) {
|
if grpcRequestSkipper(c) {
|
||||||
@@ -170,7 +170,18 @@ func CORSMiddleware() echo.MiddlewareFunc {
|
|||||||
r := c.Request()
|
r := c.Request()
|
||||||
w := c.Response().Writer
|
w := c.Response().Writer
|
||||||
|
|
||||||
w.Header().Set("Access-Control-Allow-Origin", r.Header.Get("Origin"))
|
requestOrigin := r.Header.Get("Origin")
|
||||||
|
if len(origins) == 0 {
|
||||||
|
w.Header().Set("Access-Control-Allow-Origin", requestOrigin)
|
||||||
|
} else {
|
||||||
|
for _, origin := range origins {
|
||||||
|
if origin == requestOrigin {
|
||||||
|
w.Header().Set("Access-Control-Allow-Origin", origin)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, PATCH, OPTIONS")
|
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, PATCH, OPTIONS")
|
||||||
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization")
|
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization")
|
||||||
w.Header().Set("Access-Control-Allow-Credentials", "true")
|
w.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||||
|
Reference in New Issue
Block a user