mirror of
https://github.com/usememos/memos.git
synced 2025-06-05 22:09:59 +02:00
chore: add origin flag to config cors
This commit is contained in:
@@ -39,10 +39,11 @@ var (
|
||||
driver string
|
||||
dsn string
|
||||
serveFrontend bool
|
||||
allowedOrigins []string
|
||||
|
||||
rootCmd = &cobra.Command{
|
||||
Use: "memos",
|
||||
Short: `An open-source, self-hosted memo hub with knowledge management and social networking.`,
|
||||
Short: `An open source, lightweight note-taking service. Easily capture and share your great thoughts.`,
|
||||
Run: func(_cmd *cobra.Command, _args []string) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
dbDriver, err := db.NewDBDriver(profile)
|
||||
@@ -114,6 +115,7 @@ func init() {
|
||||
rootCmd.PersistentFlags().StringVarP(&driver, "driver", "", "", "database driver")
|
||||
rootCmd.PersistentFlags().StringVarP(&dsn, "dsn", "", "", "database source name(aka. DSN)")
|
||||
rootCmd.PersistentFlags().BoolVarP(&serveFrontend, "frontend", "", true, "serve frontend files")
|
||||
rootCmd.PersistentFlags().StringArrayVarP(&allowedOrigins, "origins", "", []string{}, "CORS allowed domain origins")
|
||||
|
||||
err := viper.BindPFlag("mode", rootCmd.PersistentFlags().Lookup("mode"))
|
||||
if err != nil {
|
||||
@@ -143,12 +145,17 @@ func init() {
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
err = viper.BindPFlag("origins", rootCmd.PersistentFlags().Lookup("origins"))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
viper.SetDefault("mode", "demo")
|
||||
viper.SetDefault("driver", "sqlite")
|
||||
viper.SetDefault("addr", "")
|
||||
viper.SetDefault("port", 8081)
|
||||
viper.SetDefault("frontend", true)
|
||||
viper.SetDefault("origins", []string{})
|
||||
viper.SetEnvPrefix("memos")
|
||||
}
|
||||
|
||||
|
@@ -32,6 +32,8 @@ type Profile struct {
|
||||
Version string `json:"version"`
|
||||
// Frontend indicate the frontend is enabled or not
|
||||
Frontend bool `json:"-"`
|
||||
// Origins is the list of allowed origins
|
||||
Origins []string `json:"-"`
|
||||
}
|
||||
|
||||
func (p *Profile) IsDev() bool {
|
||||
|
@@ -49,7 +49,7 @@ func NewServer(ctx context.Context, profile *profile.Profile, store *store.Store
|
||||
}
|
||||
|
||||
// Register CORS middleware.
|
||||
e.Use(CORSMiddleware())
|
||||
e.Use(CORSMiddleware(s.Profile.Origins))
|
||||
|
||||
serverID, err := s.getSystemServerID(ctx)
|
||||
if err != nil {
|
||||
@@ -160,7 +160,7 @@ func grpcRequestSkipper(c echo.Context) bool {
|
||||
return strings.HasPrefix(c.Request().URL.Path, "/memos.api.v2.")
|
||||
}
|
||||
|
||||
func CORSMiddleware() echo.MiddlewareFunc {
|
||||
func CORSMiddleware(origins []string) echo.MiddlewareFunc {
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
if grpcRequestSkipper(c) {
|
||||
@@ -170,7 +170,18 @@ func CORSMiddleware() echo.MiddlewareFunc {
|
||||
r := c.Request()
|
||||
w := c.Response().Writer
|
||||
|
||||
w.Header().Set("Access-Control-Allow-Origin", r.Header.Get("Origin"))
|
||||
requestOrigin := r.Header.Get("Origin")
|
||||
if len(origins) == 0 {
|
||||
w.Header().Set("Access-Control-Allow-Origin", requestOrigin)
|
||||
} else {
|
||||
for _, origin := range origins {
|
||||
if origin == requestOrigin {
|
||||
w.Header().Set("Access-Control-Allow-Origin", origin)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, PATCH, OPTIONS")
|
||||
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization")
|
||||
w.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||
|
Reference in New Issue
Block a user