diff --git a/server/route/api/auth/jwt.go b/server/route/api/auth/jwt.go deleted file mode 100644 index 36191491..00000000 --- a/server/route/api/auth/jwt.go +++ /dev/null @@ -1,160 +0,0 @@ -package auth - -import ( - "fmt" - "log/slog" - "net/http" - "strings" - "time" - - "github.com/golang-jwt/jwt/v5" - "github.com/labstack/echo/v4" - "github.com/pkg/errors" - - "github.com/usememos/memos/internal/util" - storepb "github.com/usememos/memos/proto/gen/store" - "github.com/usememos/memos/store" -) - -const ( - // UserIDContextKey is the key name used to store user id in the context. - UserIDContextKey = "user-id" -) - -func extractTokenFromHeader(c echo.Context) (string, error) { - authHeader := c.Request().Header.Get("Authorization") - if authHeader == "" { - return "", nil - } - - authHeaderParts := strings.Fields(authHeader) - if len(authHeaderParts) != 2 || strings.ToLower(authHeaderParts[0]) != "bearer" { - return "", errors.New("Authorization header format must be Bearer {token}") - } - - return authHeaderParts[1], nil -} - -func findAccessToken(c echo.Context) string { - // Check the HTTP request header first. - accessToken, _ := extractTokenFromHeader(c) - if accessToken == "" { - // Check the cookie. - cookie, _ := c.Cookie(AccessTokenCookieName) - if cookie != nil { - accessToken = cookie.Value - } - } - return accessToken -} - -// JWTMiddleware validates the access token. -func JWTMiddleware(storeInstance *store.Store, next echo.HandlerFunc, secret string) echo.HandlerFunc { - return func(c echo.Context) error { - ctx := c.Request().Context() - path := c.Request().URL.Path - - accessToken := findAccessToken(c) - if accessToken == "" { - // Allow the user to access the public endpoints. - if util.HasPrefixes(path, "/o") { - return next(c) - } - return echo.NewHTTPError(http.StatusUnauthorized, "Missing access token") - } - - userID, err := getUserIDFromAccessToken(accessToken, secret) - if err != nil { - err = removeAccessTokenAndCookies(c, storeInstance, userID, accessToken) - if err != nil { - slog.Warn("fail to remove AccessToken and Cookies", err) - } - return echo.NewHTTPError(http.StatusUnauthorized, "Invalid or expired access token") - } - - accessTokens, err := storeInstance.GetUserAccessTokens(ctx, userID) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, "Failed to get user access tokens.").WithInternal(err) - } - if !validateAccessToken(accessToken, accessTokens) { - err = removeAccessTokenAndCookies(c, storeInstance, userID, accessToken) - if err != nil { - slog.Warn("fail to remove AccessToken and Cookies", err) - } - return echo.NewHTTPError(http.StatusUnauthorized, "Invalid access token.") - } - - // Even if there is no error, we still need to make sure the user still exists. - user, err := storeInstance.GetUser(ctx, &store.FindUser{ - ID: &userID, - }) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Server error to find user ID: %d", userID)).SetInternal(err) - } - if user == nil { - return echo.NewHTTPError(http.StatusUnauthorized, fmt.Sprintf("Failed to find user ID: %d", userID)) - } - - // Stores userID into context. - c.Set(UserIDContextKey, userID) - return next(c) - } -} - -func getUserIDFromAccessToken(accessToken, secret string) (int32, error) { - claims := &ClaimsMessage{} - _, err := jwt.ParseWithClaims(accessToken, claims, func(t *jwt.Token) (any, error) { - if t.Method.Alg() != jwt.SigningMethodHS256.Name { - return nil, errors.Errorf("unexpected access token signing method=%v, expect %v", t.Header["alg"], jwt.SigningMethodHS256) - } - if kid, ok := t.Header["kid"].(string); ok { - if kid == "v1" { - return []byte(secret), nil - } - } - return nil, errors.Errorf("unexpected access token kid=%v", t.Header["kid"]) - }) - if err != nil { - return 0, errors.Wrap(err, "Invalid or expired access token") - } - // We either have a valid access token or we will attempt to generate new access token. - userID, err := util.ConvertStringToInt32(claims.Subject) - if err != nil { - return 0, errors.Wrap(err, "Malformed ID in the token") - } - return userID, nil -} - -func validateAccessToken(accessTokenString string, userAccessTokens []*storepb.AccessTokensUserSetting_AccessToken) bool { - for _, userAccessToken := range userAccessTokens { - if accessTokenString == userAccessToken.AccessToken { - return true - } - } - return false -} - -// removeAccessTokenAndCookies removes the jwt token from the cookies. -func removeAccessTokenAndCookies(c echo.Context, s *store.Store, userID int32, token string) error { - err := s.RemoveUserAccessToken(c.Request().Context(), userID, token) - if err != nil { - return err - } - - cookieExp := time.Now().Add(-1 * time.Hour) - setTokenCookie(c, AccessTokenCookieName, "", cookieExp) - return nil -} - -// setTokenCookie sets the token to the cookie. -func setTokenCookie(c echo.Context, name, token string, expiration time.Time) { - cookie := new(http.Cookie) - cookie.Name = name - cookie.Value = token - cookie.Expires = expiration - cookie.Path = "/" - // Http-only helps mitigate the risk of client side script accessing the protected cookie. - cookie.HttpOnly = true - cookie.SameSite = http.SameSiteStrictMode - c.SetCookie(cookie) -} diff --git a/server/route/api/v1/acl.go b/server/route/api/v1/acl.go index ead98266..10bbdaa3 100644 --- a/server/route/api/v1/acl.go +++ b/server/route/api/v1/acl.go @@ -14,7 +14,6 @@ import ( "github.com/usememos/memos/internal/util" storepb "github.com/usememos/memos/proto/gen/store" - "github.com/usememos/memos/server/route/api/auth" "github.com/usememos/memos/store" ) @@ -84,7 +83,7 @@ func (in *GRPCAuthInterceptor) authenticate(ctx context.Context, accessToken str if accessToken == "" { return "", status.Errorf(codes.Unauthenticated, "access token not found") } - claims := &auth.ClaimsMessage{} + claims := &ClaimsMessage{} _, err := jwt.ParseWithClaims(accessToken, claims, func(t *jwt.Token) (any, error) { if t.Method.Alg() != jwt.SigningMethodHS256.Name { return nil, status.Errorf(codes.Unauthenticated, "unexpected access token signing method=%v, expect %v", t.Header["alg"], jwt.SigningMethodHS256) @@ -145,7 +144,7 @@ func getTokenFromMetadata(md metadata.MD) (string, error) { header := http.Header{} header.Add("Cookie", t) request := http.Request{Header: header} - if v, _ := request.Cookie(auth.AccessTokenCookieName); v != nil { + if v, _ := request.Cookie(AccessTokenCookieName); v != nil { accessToken = v.Value } } diff --git a/server/route/api/auth/auth.go b/server/route/api/v1/auth.go similarity index 99% rename from server/route/api/auth/auth.go rename to server/route/api/v1/auth.go index 5a46d010..78868f42 100644 --- a/server/route/api/auth/auth.go +++ b/server/route/api/v1/auth.go @@ -1,4 +1,4 @@ -package auth +package v1 import ( "fmt" diff --git a/server/route/api/v1/auth_service.go b/server/route/api/v1/auth_service.go index b8785e63..ebef2f42 100644 --- a/server/route/api/v1/auth_service.go +++ b/server/route/api/v1/auth_service.go @@ -20,7 +20,6 @@ import ( "github.com/usememos/memos/plugin/idp/oauth2" v1pb "github.com/usememos/memos/proto/gen/api/v1" storepb "github.com/usememos/memos/proto/gen/store" - "github.com/usememos/memos/server/route/api/auth" "github.com/usememos/memos/store" ) @@ -57,7 +56,7 @@ func (s *APIV1Service) SignIn(ctx context.Context, request *v1pb.SignInRequest) return nil, status.Errorf(codes.InvalidArgument, "unmatched email and password") } - expireTime := time.Now().Add(auth.AccessTokenDuration) + expireTime := time.Now().Add(AccessTokenDuration) if request.NeverExpire { // Set the expire time to 100 years. expireTime = time.Now().Add(100 * 365 * 24 * time.Hour) @@ -138,14 +137,14 @@ func (s *APIV1Service) SignInWithSSO(ctx context.Context, request *v1pb.SignInWi return nil, status.Errorf(codes.PermissionDenied, fmt.Sprintf("user has been archived with username %s", userInfo.Identifier)) } - if err := s.doSignIn(ctx, user, time.Now().Add(auth.AccessTokenDuration)); err != nil { + if err := s.doSignIn(ctx, user, time.Now().Add(AccessTokenDuration)); err != nil { return nil, status.Errorf(codes.Internal, fmt.Sprintf("failed to sign in, err: %s", err)) } return convertUserFromStore(user), nil } func (s *APIV1Service) doSignIn(ctx context.Context, user *store.User, expireTime time.Time) error { - accessToken, err := auth.GenerateAccessToken(user.Email, user.ID, expireTime, []byte(s.Secret)) + accessToken, err := GenerateAccessToken(user.Email, user.ID, expireTime, []byte(s.Secret)) if err != nil { return status.Errorf(codes.Internal, fmt.Sprintf("failed to generate tokens, err: %s", err)) } @@ -208,7 +207,7 @@ func (s *APIV1Service) SignUp(ctx context.Context, request *v1pb.SignUpRequest) return nil, status.Errorf(codes.Internal, fmt.Sprintf("failed to create user, err: %s", err)) } - if err := s.doSignIn(ctx, user, time.Now().Add(auth.AccessTokenDuration)); err != nil { + if err := s.doSignIn(ctx, user, time.Now().Add(AccessTokenDuration)); err != nil { return nil, status.Errorf(codes.Internal, fmt.Sprintf("failed to sign in, err: %s", err)) } return convertUserFromStore(user), nil @@ -236,7 +235,7 @@ func (s *APIV1Service) clearAccessTokenCookie(ctx context.Context) error { func (*APIV1Service) buildAccessTokenCookie(ctx context.Context, accessToken string, expireTime time.Time) (string, error) { attrs := []string{ - fmt.Sprintf("%s=%s", auth.AccessTokenCookieName, accessToken), + fmt.Sprintf("%s=%s", AccessTokenCookieName, accessToken), "Path=/", "HttpOnly", } diff --git a/server/route/api/v1/user_service.go b/server/route/api/v1/user_service.go index 8c8fdfff..91bd5ee2 100644 --- a/server/route/api/v1/user_service.go +++ b/server/route/api/v1/user_service.go @@ -25,7 +25,6 @@ import ( "github.com/usememos/memos/internal/util" v1pb "github.com/usememos/memos/proto/gen/api/v1" storepb "github.com/usememos/memos/proto/gen/store" - "github.com/usememos/memos/server/route/api/auth" "github.com/usememos/memos/store" ) @@ -363,7 +362,7 @@ func (s *APIV1Service) ListUserAccessTokens(ctx context.Context, _ *v1pb.ListUse accessTokens := []*v1pb.UserAccessToken{} for _, userAccessToken := range userAccessTokens { - claims := &auth.ClaimsMessage{} + claims := &ClaimsMessage{} _, err := jwt.ParseWithClaims(userAccessToken.AccessToken, claims, func(t *jwt.Token) (any, error) { if t.Method.Alg() != jwt.SigningMethodHS256.Name { return nil, errors.Errorf("unexpected access token signing method=%v, expect %v", t.Header["alg"], jwt.SigningMethodHS256) @@ -412,12 +411,12 @@ func (s *APIV1Service) CreateUserAccessToken(ctx context.Context, request *v1pb. expiresAt = request.ExpiresAt.AsTime() } - accessToken, err := auth.GenerateAccessToken(user.Username, user.ID, expiresAt, []byte(s.Secret)) + accessToken, err := GenerateAccessToken(user.Username, user.ID, expiresAt, []byte(s.Secret)) if err != nil { return nil, status.Errorf(codes.Internal, "failed to generate access token: %v", err) } - claims := &auth.ClaimsMessage{} + claims := &ClaimsMessage{} _, err = jwt.ParseWithClaims(accessToken, claims, func(t *jwt.Token) (any, error) { if t.Method.Alg() != jwt.SigningMethodHS256.Name { return nil, errors.Errorf("unexpected access token signing method=%v, expect %v", t.Header["alg"], jwt.SigningMethodHS256)