chore: add cookie builder

This commit is contained in:
Steven 2024-02-05 23:28:29 +08:00
parent 46ea16ef7e
commit 434ef44f8c
2 changed files with 54 additions and 25 deletions

View File

@ -4,6 +4,7 @@ import (
"context" "context"
"fmt" "fmt"
"regexp" "regexp"
"strings"
"time" "time"
"github.com/pkg/errors" "github.com/pkg/errors"
@ -12,14 +13,12 @@ import (
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata" "google.golang.org/grpc/metadata"
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
"google.golang.org/protobuf/proto"
"github.com/usememos/memos/api/auth" "github.com/usememos/memos/api/auth"
"github.com/usememos/memos/internal/util" "github.com/usememos/memos/internal/util"
"github.com/usememos/memos/plugin/idp" "github.com/usememos/memos/plugin/idp"
"github.com/usememos/memos/plugin/idp/oauth2" "github.com/usememos/memos/plugin/idp/oauth2"
apiv2pb "github.com/usememos/memos/proto/gen/api/v2" apiv2pb "github.com/usememos/memos/proto/gen/api/v2"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/server/service/metric" "github.com/usememos/memos/server/service/metric"
"github.com/usememos/memos/store" "github.com/usememos/memos/store"
) )
@ -31,7 +30,7 @@ func (s *APIV2Service) GetAuthStatus(ctx context.Context, _ *apiv2pb.GetAuthStat
} }
if user == nil { if user == nil {
// Set the cookie header to expire access token. // Set the cookie header to expire access token.
if err := clearAccessTokenCookie(ctx); err != nil { if err := s.clearAccessTokenCookie(ctx); err != nil {
return nil, status.Errorf(codes.Internal, "failed to set grpc header") return nil, status.Errorf(codes.Internal, "failed to set grpc header")
} }
return nil, status.Errorf(codes.Unauthenticated, "user not found") return nil, status.Errorf(codes.Unauthenticated, "user not found")
@ -61,8 +60,8 @@ func (s *APIV2Service) SignIn(ctx context.Context, request *apiv2pb.SignInReques
expireTime := time.Now().Add(auth.AccessTokenDuration) expireTime := time.Now().Add(auth.AccessTokenDuration)
if request.NeverExpire { if request.NeverExpire {
// Zero time means never expire. // Set the expire time to 100 years.
expireTime = time.Time{} expireTime = time.Now().Add(100 * 365 * 24 * time.Hour)
} }
if err := s.doSignIn(ctx, user, expireTime); err != nil { if err := s.doSignIn(ctx, user, expireTime); err != nil {
return nil, status.Errorf(codes.Internal, fmt.Sprintf("failed to sign in, err: %s", err)) return nil, status.Errorf(codes.Internal, fmt.Sprintf("failed to sign in, err: %s", err))
@ -159,13 +158,12 @@ func (s *APIV2Service) doSignIn(ctx context.Context, user *store.User, expireTim
return status.Errorf(codes.Internal, fmt.Sprintf("failed to upsert access token to store, err: %s", err)) return status.Errorf(codes.Internal, fmt.Sprintf("failed to upsert access token to store, err: %s", err))
} }
cookieExpires := time.Now().Add(auth.CookieExpDuration) cookie, err := s.buildAccessTokenCookie(ctx, accessToken, expireTime)
if expireTime.IsZero() { if err != nil {
// Set cookie expires to 100 years. return status.Errorf(codes.Internal, fmt.Sprintf("failed to build access token cookie, err: %s", err))
cookieExpires = time.Now().AddDate(100, 0, 0)
} }
if err := grpc.SetHeader(ctx, metadata.New(map[string]string{ if err := grpc.SetHeader(ctx, metadata.New(map[string]string{
"Set-Cookie": fmt.Sprintf("%s=%s; Path=/; Expires=%s; HttpOnly; SameSite=Strict", auth.AccessTokenCookieName, accessToken, cookieExpires.Format(time.RFC1123)), "Set-Cookie": cookie,
})); err != nil { })); err != nil {
return status.Errorf(codes.Internal, "failed to set grpc header, error: %v", err) return status.Errorf(codes.Internal, "failed to set grpc header, error: %v", err)
} }
@ -222,34 +220,46 @@ func (s *APIV2Service) SignUp(ctx context.Context, request *apiv2pb.SignUpReques
}, nil }, nil
} }
func (*APIV2Service) SignOut(ctx context.Context, _ *apiv2pb.SignOutRequest) (*apiv2pb.SignOutResponse, error) { func (s *APIV2Service) SignOut(ctx context.Context, _ *apiv2pb.SignOutRequest) (*apiv2pb.SignOutResponse, error) {
if err := clearAccessTokenCookie(ctx); err != nil { if err := s.clearAccessTokenCookie(ctx); err != nil {
return nil, status.Errorf(codes.Internal, "failed to set grpc header, error: %v", err) return nil, status.Errorf(codes.Internal, "failed to set grpc header, error: %v", err)
} }
return &apiv2pb.SignOutResponse{}, nil return &apiv2pb.SignOutResponse{}, nil
} }
func clearAccessTokenCookie(ctx context.Context) error { func (s *APIV2Service) clearAccessTokenCookie(ctx context.Context) error {
cookie, err := s.buildAccessTokenCookie(ctx, "", time.Time{})
if err != nil {
return errors.Wrap(err, "failed to build access token cookie")
}
if err := grpc.SetHeader(ctx, metadata.New(map[string]string{ if err := grpc.SetHeader(ctx, metadata.New(map[string]string{
"Set-Cookie": fmt.Sprintf("%s=; Path=/; Expires=Thu, 01 Jan 1970 00:00:00 GMT; HttpOnly; SameSite=Strict", auth.AccessTokenCookieName), "Set-Cookie": cookie,
})); err != nil { })); err != nil {
return errors.Wrap(err, "failed to set grpc header") return errors.Wrap(err, "failed to set grpc header")
} }
return nil return nil
} }
func (s *APIV2Service) GetWorkspaceGeneralSetting(ctx context.Context) (*storepb.WorkspaceGeneralSetting, error) { func (s *APIV2Service) buildAccessTokenCookie(ctx context.Context, accessToken string, expireTime time.Time) (string, error) {
workspaceSetting, err := s.Store.GetWorkspaceSetting(ctx, &store.FindWorkspaceSetting{ attrs := []string{
Name: storepb.WorkspaceSettingKey_WORKSPACE_SETTING_GENERAL.String(), fmt.Sprintf("%s=%s", auth.AccessTokenCookieName, accessToken),
}) "Path=/",
"HttpOnly",
}
if expireTime.IsZero() {
attrs = append(attrs, "Expires=Thu, 01 Jan 1970 00:00:00 GMT")
} else {
attrs = append(attrs, "Expires="+expireTime.Format(time.RFC1123))
}
workspaceGeneralSetting, err := s.GetWorkspaceGeneralSetting(ctx)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "failed to get workspace setting") return "", errors.Wrap(err, "failed to get workspace setting")
} }
workspaceGeneralSetting := &storepb.WorkspaceGeneralSetting{} if workspaceGeneralSetting.InstanceUrl != "" && strings.HasPrefix(workspaceGeneralSetting.InstanceUrl, "https://") {
if workspaceSetting != nil { attrs = append(attrs, "SameSite=None")
if err := proto.Unmarshal([]byte(workspaceSetting.Value), workspaceGeneralSetting); err != nil { attrs = append(attrs, "Secure")
return nil, errors.Wrap(err, "failed to unmarshal workspace setting") } else {
} attrs = append(attrs, "SameSite=Strict")
} }
return workspaceGeneralSetting, nil return strings.Join(attrs, "; "), nil
} }

View File

@ -6,8 +6,11 @@ import (
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
"google.golang.org/protobuf/proto"
"github.com/pkg/errors"
apiv2pb "github.com/usememos/memos/proto/gen/api/v2" apiv2pb "github.com/usememos/memos/proto/gen/api/v2"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/store" "github.com/usememos/memos/store"
) )
@ -88,3 +91,19 @@ func (s *APIV2Service) UpdateWorkspaceProfile(ctx context.Context, request *apiv
WorkspaceProfile: workspaceProfileMessage.WorkspaceProfile, WorkspaceProfile: workspaceProfileMessage.WorkspaceProfile,
}, nil }, nil
} }
func (s *APIV2Service) GetWorkspaceGeneralSetting(ctx context.Context) (*storepb.WorkspaceGeneralSetting, error) {
workspaceSetting, err := s.Store.GetWorkspaceSetting(ctx, &store.FindWorkspaceSetting{
Name: storepb.WorkspaceSettingKey_WORKSPACE_SETTING_GENERAL.String(),
})
if err != nil {
return nil, errors.Wrap(err, "failed to get workspace setting")
}
workspaceGeneralSetting := &storepb.WorkspaceGeneralSetting{}
if workspaceSetting != nil {
if err := proto.Unmarshal([]byte(workspaceSetting.Value), workspaceGeneralSetting); err != nil {
return nil, errors.Wrap(err, "failed to unmarshal workspace setting")
}
}
return workspaceGeneralSetting, nil
}