diff --git a/go.mod b/go.mod index 298afad8..b699c101 100644 --- a/go.mod +++ b/go.mod @@ -90,6 +90,7 @@ require ( github.com/mitchellh/mapstructure v1.5.0 // indirect github.com/pelletier/go-toml/v2 v2.2.1 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect + github.com/soheilhy/cmux v0.1.5 github.com/spf13/afero v1.11.0 // indirect github.com/spf13/cast v1.6.0 // indirect github.com/spf13/pflag v1.0.5 // indirect diff --git a/go.sum b/go.sum index f501c3be..820ac3f6 100644 --- a/go.sum +++ b/go.sum @@ -395,6 +395,8 @@ github.com/sirupsen/logrus v1.7.0/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc= github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA= github.com/soheilhy/cmux v0.1.4/go.mod h1:IM3LyeVVIOuxMH7sFAkER9+bJ4dT7Ms6E4xg4kGIyLM= +github.com/soheilhy/cmux v0.1.5 h1:jjzc5WVemNEDTLwv9tlmemhC73tI08BNOIGwBOo10Js= +github.com/soheilhy/cmux v0.1.5/go.mod h1:T7TcVDs9LWfQgPlPsdngu6I6QIoyIFZDDC6sNE1GqG0= github.com/sony/gobreaker v0.4.1/go.mod h1:ZKptC7FHNvhBz7dN2LGjPVBz2sZJmc0/PkyDJOjmxWY= github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo= github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0= @@ -507,6 +509,7 @@ golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLL golang.org/x/net v0.0.0-20190813141303-74dc4d7220e7/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200421231249-e086a090c8fd/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= +golang.org/x/net v0.0.0-20201202161906-c7110b5ffcbb/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20210805182204-aaa1db679c0d/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.24.0 h1:1PcaxkF854Fu3+lvBIx5SYn9wRlBzzcnHZSiaFFAb0w= golang.org/x/net v0.24.0/go.mod h1:2Q7sJY5mzlzWjKtYUEXSlBWCdyaioyXzRB2RtU8KVE8= @@ -546,6 +549,7 @@ golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20200420163511-1957bb5e6d1f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200615200032-f1bc736245b1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200625212154-ddb9806d33ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -556,6 +560,7 @@ golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= diff --git a/server/route/api/v2/v2.go b/server/route/api/v2/v2.go index fa307148..fbfc1942 100644 --- a/server/route/api/v2/v2.go +++ b/server/route/api/v2/v2.go @@ -3,13 +3,10 @@ package v2 import ( "context" "fmt" - "log/slog" - "net" "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" "github.com/improbable-eng/grpc-web/go/grpcweb" "github.com/labstack/echo/v4" - "github.com/pkg/errors" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/reflection" @@ -38,27 +35,17 @@ type APIV2Service struct { Profile *profile.Profile Store *store.Store - grpcServer *grpc.Server - grpcServerPort int + grpcServer *grpc.Server } -func NewAPIV2Service(secret string, profile *profile.Profile, store *store.Store, grpcServerPort int) *APIV2Service { +func NewAPIV2Service(secret string, profile *profile.Profile, store *store.Store, grpcServer *grpc.Server) *APIV2Service { grpc.EnableTracing = true - authProvider := NewGRPCAuthInterceptor(store, secret) - grpcServer := grpc.NewServer( - grpc.ChainUnaryInterceptor( - NewLoggerInterceptor().LoggerInterceptor, - authProvider.AuthenticationInterceptor, - ), - ) apiv2Service := &APIV2Service{ - Secret: secret, - Profile: profile, - Store: store, - grpcServer: grpcServer, - grpcServerPort: grpcServerPort, + Secret: secret, + Profile: profile, + Store: store, + grpcServer: grpcServer, } - apiv2pb.RegisterWorkspaceServiceServer(grpcServer, apiv2Service) apiv2pb.RegisterWorkspaceSettingServiceServer(grpcServer, apiv2Service) apiv2pb.RegisterAuthServiceServer(grpcServer, apiv2Service) @@ -73,21 +60,16 @@ func NewAPIV2Service(secret string, profile *profile.Profile, store *store.Store apiv2pb.RegisterStorageServiceServer(grpcServer, apiv2Service) apiv2pb.RegisterIdentityProviderServiceServer(grpcServer, apiv2Service) reflection.Register(grpcServer) - return apiv2Service } -func (s *APIV2Service) GetGRPCServer() *grpc.Server { - return s.grpcServer -} - // RegisterGateway registers the gRPC-Gateway with the given Echo instance. -func (s *APIV2Service) RegisterGateway(ctx context.Context, e *echo.Echo) error { +func (s *APIV2Service) RegisterGateway(ctx context.Context, echoServer *echo.Echo) error { // Create a client connection to the gRPC Server we just started. // This is where the gRPC-Gateway proxies the requests. conn, err := grpc.DialContext( ctx, - fmt.Sprintf(":%d", s.grpcServerPort), + fmt.Sprintf(":%d", s.Profile.Port), grpc.WithTransportCredentials(insecure.NewCredentials()), ) if err != nil { @@ -134,7 +116,7 @@ func (s *APIV2Service) RegisterGateway(ctx context.Context, e *echo.Echo) error if err := apiv2pb.RegisterIdentityProviderServiceHandler(context.Background(), gwMux, conn); err != nil { return err } - e.Any("/api/v2/*", echo.WrapHandler(gwMux)) + echoServer.Any("/api/v2/*", echo.WrapHandler(gwMux)) // GRPC web proxy. options := []grpcweb.Option{ @@ -144,18 +126,7 @@ func (s *APIV2Service) RegisterGateway(ctx context.Context, e *echo.Echo) error }), } wrappedGrpc := grpcweb.WrapServer(s.grpcServer, options...) - e.Any("/memos.api.v2.*", echo.WrapHandler(wrappedGrpc)) - - // Start gRPC server. - listen, err := net.Listen("tcp", fmt.Sprintf("%s:%d", s.Profile.Addr, s.grpcServerPort)) - if err != nil { - return errors.Wrap(err, "failed to start gRPC server") - } - go func() { - if err := s.grpcServer.Serve(listen); err != nil { - slog.Error("failed to start gRPC server", err) - } - }() + echoServer.Any("/memos.api.v2.*", echo.WrapHandler(wrappedGrpc)) return nil } diff --git a/server/server.go b/server/server.go index 38e79b69..0256492d 100644 --- a/server/server.go +++ b/server/server.go @@ -3,6 +3,8 @@ package server import ( "context" "fmt" + "log/slog" + "net" "net/http" "strings" "time" @@ -10,6 +12,8 @@ import ( "github.com/google/uuid" "github.com/labstack/echo/v4" "github.com/pkg/errors" + "github.com/soheilhy/cmux" + "google.golang.org/grpc" storepb "github.com/usememos/memos/proto/gen/store" "github.com/usememos/memos/server/profile" @@ -23,28 +27,29 @@ import ( ) type Server struct { - e *echo.Echo - ID string Secret string Profile *profile.Profile Store *store.Store + + echoServer *echo.Echo + grpcServer *grpc.Server } func NewServer(ctx context.Context, profile *profile.Profile, store *store.Store) (*Server, error) { - e := echo.New() - e.Debug = true - e.HideBanner = true - e.HidePort = true - s := &Server{ - e: e, Store: store, Profile: profile, } + echoServer := echo.New() + echoServer.Debug = true + echoServer.HideBanner = true + echoServer.HidePort = true + s.echoServer = echoServer + // Register CORS middleware. - e.Use(CORSMiddleware(s.Profile.Origins)) + echoServer.Use(CORSMiddleware(s.Profile.Origins)) workspaceBasicSetting, err := s.getOrUpsertWorkspaceBasicSetting(ctx) if err != nil { @@ -59,17 +64,17 @@ func NewServer(ctx context.Context, profile *profile.Profile, store *store.Store s.Secret = secret // Register healthz endpoint. - e.GET("/healthz", func(c echo.Context) error { + echoServer.GET("/healthz", func(c echo.Context) error { return c.String(http.StatusOK, "Service ready.") }) // Only serve frontend when it's enabled. if profile.Frontend { frontendService := frontend.NewFrontendService(profile, store) - frontendService.Serve(ctx, e) + frontendService.Serve(ctx, echoServer) } - rootGroup := e.Group("") + rootGroup := echoServer.Group("") // Register public routes. publicGroup := rootGroup.Group("/o") @@ -83,9 +88,15 @@ func NewServer(ctx context.Context, profile *profile.Profile, store *store.Store // Create and register RSS routes. rss.NewRSSService(s.Profile, s.Store).RegisterRoutes(rootGroup) - apiV2Service := apiv2.NewAPIV2Service(s.Secret, profile, store, s.Profile.Port+1) + grpcServer := grpc.NewServer(grpc.ChainUnaryInterceptor( + apiv2.NewLoggerInterceptor().LoggerInterceptor, + apiv2.NewGRPCAuthInterceptor(store, secret).AuthenticationInterceptor, + )) + s.grpcServer = grpcServer + + apiV2Service := apiv2.NewAPIV2Service(s.Secret, profile, store, grpcServer) // Register gRPC gateway as api v2. - if err := apiV2Service.RegisterGateway(ctx, e); err != nil { + if err := apiV2Service.RegisterGateway(ctx, echoServer); err != nil { return nil, errors.Wrap(err, "failed to register gRPC gateway") } @@ -93,8 +104,28 @@ func NewServer(ctx context.Context, profile *profile.Profile, store *store.Store } func (s *Server) Start(ctx context.Context) error { - go versionchecker.NewVersionChecker(s.Store, s.Profile).Start(ctx) - return s.e.Start(fmt.Sprintf("%s:%d", s.Profile.Addr, s.Profile.Port)) + address := fmt.Sprintf(":%d", s.Profile.Port) + listener, err := net.Listen("tcp", address) + if err != nil { + return errors.Wrap(err, "failed to listen") + } + + muxServer := cmux.New(listener) + go func() { + grpcListener := muxServer.Match(cmux.HTTP2HeaderField("content-type", "application/grpc")) + if err := s.grpcServer.Serve(grpcListener); err != nil { + slog.Error("failed to serve gRPC", err) + } + }() + go func() { + httpListener := muxServer.Match(cmux.HTTP1Fast(), cmux.Any()) + s.echoServer.Listener = httpListener + if err := s.echoServer.Start(address); err != nil { + slog.Error("failed to start echo server", err) + } + }() + + return muxServer.Serve() } func (s *Server) Shutdown(ctx context.Context) { @@ -102,7 +133,7 @@ func (s *Server) Shutdown(ctx context.Context) { defer cancel() // Shutdown echo server - if err := s.e.Shutdown(ctx); err != nil { + if err := s.echoServer.Shutdown(ctx); err != nil { fmt.Printf("failed to shutdown server, error: %v\n", err) } @@ -114,8 +145,8 @@ func (s *Server) Shutdown(ctx context.Context) { fmt.Printf("memos stopped properly\n") } -func (s *Server) GetEcho() *echo.Echo { - return s.e +func (s *Server) StartRunners(ctx context.Context) { + go versionchecker.NewVersionChecker(s.Store, s.Profile).Start(ctx) } func (s *Server) getOrUpsertWorkspaceBasicSetting(ctx context.Context) (*storepb.WorkspaceBasicSetting, error) {