mirror of
				https://github.com/usememos/memos.git
				synced 2025-06-05 22:09:59 +02:00 
			
		
		
		
	chore: migrate auth package
This commit is contained in:
		| @@ -1,160 +0,0 @@ | |||||||
| package auth |  | ||||||
|  |  | ||||||
| import ( |  | ||||||
| 	"fmt" |  | ||||||
| 	"log/slog" |  | ||||||
| 	"net/http" |  | ||||||
| 	"strings" |  | ||||||
| 	"time" |  | ||||||
|  |  | ||||||
| 	"github.com/golang-jwt/jwt/v5" |  | ||||||
| 	"github.com/labstack/echo/v4" |  | ||||||
| 	"github.com/pkg/errors" |  | ||||||
|  |  | ||||||
| 	"github.com/usememos/memos/internal/util" |  | ||||||
| 	storepb "github.com/usememos/memos/proto/gen/store" |  | ||||||
| 	"github.com/usememos/memos/store" |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| const ( |  | ||||||
| 	// UserIDContextKey is the key name used to store user id in the context. |  | ||||||
| 	UserIDContextKey = "user-id" |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| func extractTokenFromHeader(c echo.Context) (string, error) { |  | ||||||
| 	authHeader := c.Request().Header.Get("Authorization") |  | ||||||
| 	if authHeader == "" { |  | ||||||
| 		return "", nil |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	authHeaderParts := strings.Fields(authHeader) |  | ||||||
| 	if len(authHeaderParts) != 2 || strings.ToLower(authHeaderParts[0]) != "bearer" { |  | ||||||
| 		return "", errors.New("Authorization header format must be Bearer {token}") |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	return authHeaderParts[1], nil |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func findAccessToken(c echo.Context) string { |  | ||||||
| 	// Check the HTTP request header first. |  | ||||||
| 	accessToken, _ := extractTokenFromHeader(c) |  | ||||||
| 	if accessToken == "" { |  | ||||||
| 		// Check the cookie. |  | ||||||
| 		cookie, _ := c.Cookie(AccessTokenCookieName) |  | ||||||
| 		if cookie != nil { |  | ||||||
| 			accessToken = cookie.Value |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 	return accessToken |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // JWTMiddleware validates the access token. |  | ||||||
| func JWTMiddleware(storeInstance *store.Store, next echo.HandlerFunc, secret string) echo.HandlerFunc { |  | ||||||
| 	return func(c echo.Context) error { |  | ||||||
| 		ctx := c.Request().Context() |  | ||||||
| 		path := c.Request().URL.Path |  | ||||||
|  |  | ||||||
| 		accessToken := findAccessToken(c) |  | ||||||
| 		if accessToken == "" { |  | ||||||
| 			// Allow the user to access the public endpoints. |  | ||||||
| 			if util.HasPrefixes(path, "/o") { |  | ||||||
| 				return next(c) |  | ||||||
| 			} |  | ||||||
| 			return echo.NewHTTPError(http.StatusUnauthorized, "Missing access token") |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		userID, err := getUserIDFromAccessToken(accessToken, secret) |  | ||||||
| 		if err != nil { |  | ||||||
| 			err = removeAccessTokenAndCookies(c, storeInstance, userID, accessToken) |  | ||||||
| 			if err != nil { |  | ||||||
| 				slog.Warn("fail to remove AccessToken and Cookies", err) |  | ||||||
| 			} |  | ||||||
| 			return echo.NewHTTPError(http.StatusUnauthorized, "Invalid or expired access token") |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		accessTokens, err := storeInstance.GetUserAccessTokens(ctx, userID) |  | ||||||
| 		if err != nil { |  | ||||||
| 			return echo.NewHTTPError(http.StatusInternalServerError, "Failed to get user access tokens.").WithInternal(err) |  | ||||||
| 		} |  | ||||||
| 		if !validateAccessToken(accessToken, accessTokens) { |  | ||||||
| 			err = removeAccessTokenAndCookies(c, storeInstance, userID, accessToken) |  | ||||||
| 			if err != nil { |  | ||||||
| 				slog.Warn("fail to remove AccessToken and Cookies", err) |  | ||||||
| 			} |  | ||||||
| 			return echo.NewHTTPError(http.StatusUnauthorized, "Invalid access token.") |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		// Even if there is no error, we still need to make sure the user still exists. |  | ||||||
| 		user, err := storeInstance.GetUser(ctx, &store.FindUser{ |  | ||||||
| 			ID: &userID, |  | ||||||
| 		}) |  | ||||||
| 		if err != nil { |  | ||||||
| 			return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Server error to find user ID: %d", userID)).SetInternal(err) |  | ||||||
| 		} |  | ||||||
| 		if user == nil { |  | ||||||
| 			return echo.NewHTTPError(http.StatusUnauthorized, fmt.Sprintf("Failed to find user ID: %d", userID)) |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		// Stores userID into context. |  | ||||||
| 		c.Set(UserIDContextKey, userID) |  | ||||||
| 		return next(c) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func getUserIDFromAccessToken(accessToken, secret string) (int32, error) { |  | ||||||
| 	claims := &ClaimsMessage{} |  | ||||||
| 	_, err := jwt.ParseWithClaims(accessToken, claims, func(t *jwt.Token) (any, error) { |  | ||||||
| 		if t.Method.Alg() != jwt.SigningMethodHS256.Name { |  | ||||||
| 			return nil, errors.Errorf("unexpected access token signing method=%v, expect %v", t.Header["alg"], jwt.SigningMethodHS256) |  | ||||||
| 		} |  | ||||||
| 		if kid, ok := t.Header["kid"].(string); ok { |  | ||||||
| 			if kid == "v1" { |  | ||||||
| 				return []byte(secret), nil |  | ||||||
| 			} |  | ||||||
| 		} |  | ||||||
| 		return nil, errors.Errorf("unexpected access token kid=%v", t.Header["kid"]) |  | ||||||
| 	}) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return 0, errors.Wrap(err, "Invalid or expired access token") |  | ||||||
| 	} |  | ||||||
| 	// We either have a valid access token or we will attempt to generate new access token. |  | ||||||
| 	userID, err := util.ConvertStringToInt32(claims.Subject) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return 0, errors.Wrap(err, "Malformed ID in the token") |  | ||||||
| 	} |  | ||||||
| 	return userID, nil |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func validateAccessToken(accessTokenString string, userAccessTokens []*storepb.AccessTokensUserSetting_AccessToken) bool { |  | ||||||
| 	for _, userAccessToken := range userAccessTokens { |  | ||||||
| 		if accessTokenString == userAccessToken.AccessToken { |  | ||||||
| 			return true |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 	return false |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // removeAccessTokenAndCookies removes the jwt token from the cookies. |  | ||||||
| func removeAccessTokenAndCookies(c echo.Context, s *store.Store, userID int32, token string) error { |  | ||||||
| 	err := s.RemoveUserAccessToken(c.Request().Context(), userID, token) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return err |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	cookieExp := time.Now().Add(-1 * time.Hour) |  | ||||||
| 	setTokenCookie(c, AccessTokenCookieName, "", cookieExp) |  | ||||||
| 	return nil |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // setTokenCookie sets the token to the cookie. |  | ||||||
| func setTokenCookie(c echo.Context, name, token string, expiration time.Time) { |  | ||||||
| 	cookie := new(http.Cookie) |  | ||||||
| 	cookie.Name = name |  | ||||||
| 	cookie.Value = token |  | ||||||
| 	cookie.Expires = expiration |  | ||||||
| 	cookie.Path = "/" |  | ||||||
| 	// Http-only helps mitigate the risk of client side script accessing the protected cookie. |  | ||||||
| 	cookie.HttpOnly = true |  | ||||||
| 	cookie.SameSite = http.SameSiteStrictMode |  | ||||||
| 	c.SetCookie(cookie) |  | ||||||
| } |  | ||||||
| @@ -14,7 +14,6 @@ import ( | |||||||
|  |  | ||||||
| 	"github.com/usememos/memos/internal/util" | 	"github.com/usememos/memos/internal/util" | ||||||
| 	storepb "github.com/usememos/memos/proto/gen/store" | 	storepb "github.com/usememos/memos/proto/gen/store" | ||||||
| 	"github.com/usememos/memos/server/route/api/auth" |  | ||||||
| 	"github.com/usememos/memos/store" | 	"github.com/usememos/memos/store" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| @@ -84,7 +83,7 @@ func (in *GRPCAuthInterceptor) authenticate(ctx context.Context, accessToken str | |||||||
| 	if accessToken == "" { | 	if accessToken == "" { | ||||||
| 		return "", status.Errorf(codes.Unauthenticated, "access token not found") | 		return "", status.Errorf(codes.Unauthenticated, "access token not found") | ||||||
| 	} | 	} | ||||||
| 	claims := &auth.ClaimsMessage{} | 	claims := &ClaimsMessage{} | ||||||
| 	_, err := jwt.ParseWithClaims(accessToken, claims, func(t *jwt.Token) (any, error) { | 	_, err := jwt.ParseWithClaims(accessToken, claims, func(t *jwt.Token) (any, error) { | ||||||
| 		if t.Method.Alg() != jwt.SigningMethodHS256.Name { | 		if t.Method.Alg() != jwt.SigningMethodHS256.Name { | ||||||
| 			return nil, status.Errorf(codes.Unauthenticated, "unexpected access token signing method=%v, expect %v", t.Header["alg"], jwt.SigningMethodHS256) | 			return nil, status.Errorf(codes.Unauthenticated, "unexpected access token signing method=%v, expect %v", t.Header["alg"], jwt.SigningMethodHS256) | ||||||
| @@ -145,7 +144,7 @@ func getTokenFromMetadata(md metadata.MD) (string, error) { | |||||||
| 		header := http.Header{} | 		header := http.Header{} | ||||||
| 		header.Add("Cookie", t) | 		header.Add("Cookie", t) | ||||||
| 		request := http.Request{Header: header} | 		request := http.Request{Header: header} | ||||||
| 		if v, _ := request.Cookie(auth.AccessTokenCookieName); v != nil { | 		if v, _ := request.Cookie(AccessTokenCookieName); v != nil { | ||||||
| 			accessToken = v.Value | 			accessToken = v.Value | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|   | |||||||
| @@ -1,4 +1,4 @@ | |||||||
| package auth | package v1 | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"fmt" | 	"fmt" | ||||||
| @@ -20,7 +20,6 @@ import ( | |||||||
| 	"github.com/usememos/memos/plugin/idp/oauth2" | 	"github.com/usememos/memos/plugin/idp/oauth2" | ||||||
| 	v1pb "github.com/usememos/memos/proto/gen/api/v1" | 	v1pb "github.com/usememos/memos/proto/gen/api/v1" | ||||||
| 	storepb "github.com/usememos/memos/proto/gen/store" | 	storepb "github.com/usememos/memos/proto/gen/store" | ||||||
| 	"github.com/usememos/memos/server/route/api/auth" |  | ||||||
| 	"github.com/usememos/memos/store" | 	"github.com/usememos/memos/store" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| @@ -57,7 +56,7 @@ func (s *APIV1Service) SignIn(ctx context.Context, request *v1pb.SignInRequest) | |||||||
| 		return nil, status.Errorf(codes.InvalidArgument, "unmatched email and password") | 		return nil, status.Errorf(codes.InvalidArgument, "unmatched email and password") | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	expireTime := time.Now().Add(auth.AccessTokenDuration) | 	expireTime := time.Now().Add(AccessTokenDuration) | ||||||
| 	if request.NeverExpire { | 	if request.NeverExpire { | ||||||
| 		// Set the expire time to 100 years. | 		// Set the expire time to 100 years. | ||||||
| 		expireTime = time.Now().Add(100 * 365 * 24 * time.Hour) | 		expireTime = time.Now().Add(100 * 365 * 24 * time.Hour) | ||||||
| @@ -138,14 +137,14 @@ func (s *APIV1Service) SignInWithSSO(ctx context.Context, request *v1pb.SignInWi | |||||||
| 		return nil, status.Errorf(codes.PermissionDenied, fmt.Sprintf("user has been archived with username %s", userInfo.Identifier)) | 		return nil, status.Errorf(codes.PermissionDenied, fmt.Sprintf("user has been archived with username %s", userInfo.Identifier)) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if err := s.doSignIn(ctx, user, time.Now().Add(auth.AccessTokenDuration)); err != nil { | 	if err := s.doSignIn(ctx, user, time.Now().Add(AccessTokenDuration)); err != nil { | ||||||
| 		return nil, status.Errorf(codes.Internal, fmt.Sprintf("failed to sign in, err: %s", err)) | 		return nil, status.Errorf(codes.Internal, fmt.Sprintf("failed to sign in, err: %s", err)) | ||||||
| 	} | 	} | ||||||
| 	return convertUserFromStore(user), nil | 	return convertUserFromStore(user), nil | ||||||
| } | } | ||||||
|  |  | ||||||
| func (s *APIV1Service) doSignIn(ctx context.Context, user *store.User, expireTime time.Time) error { | func (s *APIV1Service) doSignIn(ctx context.Context, user *store.User, expireTime time.Time) error { | ||||||
| 	accessToken, err := auth.GenerateAccessToken(user.Email, user.ID, expireTime, []byte(s.Secret)) | 	accessToken, err := GenerateAccessToken(user.Email, user.ID, expireTime, []byte(s.Secret)) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return status.Errorf(codes.Internal, fmt.Sprintf("failed to generate tokens, err: %s", err)) | 		return status.Errorf(codes.Internal, fmt.Sprintf("failed to generate tokens, err: %s", err)) | ||||||
| 	} | 	} | ||||||
| @@ -208,7 +207,7 @@ func (s *APIV1Service) SignUp(ctx context.Context, request *v1pb.SignUpRequest) | |||||||
| 		return nil, status.Errorf(codes.Internal, fmt.Sprintf("failed to create user, err: %s", err)) | 		return nil, status.Errorf(codes.Internal, fmt.Sprintf("failed to create user, err: %s", err)) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if err := s.doSignIn(ctx, user, time.Now().Add(auth.AccessTokenDuration)); err != nil { | 	if err := s.doSignIn(ctx, user, time.Now().Add(AccessTokenDuration)); err != nil { | ||||||
| 		return nil, status.Errorf(codes.Internal, fmt.Sprintf("failed to sign in, err: %s", err)) | 		return nil, status.Errorf(codes.Internal, fmt.Sprintf("failed to sign in, err: %s", err)) | ||||||
| 	} | 	} | ||||||
| 	return convertUserFromStore(user), nil | 	return convertUserFromStore(user), nil | ||||||
| @@ -236,7 +235,7 @@ func (s *APIV1Service) clearAccessTokenCookie(ctx context.Context) error { | |||||||
|  |  | ||||||
| func (*APIV1Service) buildAccessTokenCookie(ctx context.Context, accessToken string, expireTime time.Time) (string, error) { | func (*APIV1Service) buildAccessTokenCookie(ctx context.Context, accessToken string, expireTime time.Time) (string, error) { | ||||||
| 	attrs := []string{ | 	attrs := []string{ | ||||||
| 		fmt.Sprintf("%s=%s", auth.AccessTokenCookieName, accessToken), | 		fmt.Sprintf("%s=%s", AccessTokenCookieName, accessToken), | ||||||
| 		"Path=/", | 		"Path=/", | ||||||
| 		"HttpOnly", | 		"HttpOnly", | ||||||
| 	} | 	} | ||||||
|   | |||||||
| @@ -25,7 +25,6 @@ import ( | |||||||
| 	"github.com/usememos/memos/internal/util" | 	"github.com/usememos/memos/internal/util" | ||||||
| 	v1pb "github.com/usememos/memos/proto/gen/api/v1" | 	v1pb "github.com/usememos/memos/proto/gen/api/v1" | ||||||
| 	storepb "github.com/usememos/memos/proto/gen/store" | 	storepb "github.com/usememos/memos/proto/gen/store" | ||||||
| 	"github.com/usememos/memos/server/route/api/auth" |  | ||||||
| 	"github.com/usememos/memos/store" | 	"github.com/usememos/memos/store" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| @@ -363,7 +362,7 @@ func (s *APIV1Service) ListUserAccessTokens(ctx context.Context, _ *v1pb.ListUse | |||||||
|  |  | ||||||
| 	accessTokens := []*v1pb.UserAccessToken{} | 	accessTokens := []*v1pb.UserAccessToken{} | ||||||
| 	for _, userAccessToken := range userAccessTokens { | 	for _, userAccessToken := range userAccessTokens { | ||||||
| 		claims := &auth.ClaimsMessage{} | 		claims := &ClaimsMessage{} | ||||||
| 		_, err := jwt.ParseWithClaims(userAccessToken.AccessToken, claims, func(t *jwt.Token) (any, error) { | 		_, err := jwt.ParseWithClaims(userAccessToken.AccessToken, claims, func(t *jwt.Token) (any, error) { | ||||||
| 			if t.Method.Alg() != jwt.SigningMethodHS256.Name { | 			if t.Method.Alg() != jwt.SigningMethodHS256.Name { | ||||||
| 				return nil, errors.Errorf("unexpected access token signing method=%v, expect %v", t.Header["alg"], jwt.SigningMethodHS256) | 				return nil, errors.Errorf("unexpected access token signing method=%v, expect %v", t.Header["alg"], jwt.SigningMethodHS256) | ||||||
| @@ -412,12 +411,12 @@ func (s *APIV1Service) CreateUserAccessToken(ctx context.Context, request *v1pb. | |||||||
| 		expiresAt = request.ExpiresAt.AsTime() | 		expiresAt = request.ExpiresAt.AsTime() | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	accessToken, err := auth.GenerateAccessToken(user.Username, user.ID, expiresAt, []byte(s.Secret)) | 	accessToken, err := GenerateAccessToken(user.Username, user.ID, expiresAt, []byte(s.Secret)) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, status.Errorf(codes.Internal, "failed to generate access token: %v", err) | 		return nil, status.Errorf(codes.Internal, "failed to generate access token: %v", err) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	claims := &auth.ClaimsMessage{} | 	claims := &ClaimsMessage{} | ||||||
| 	_, err = jwt.ParseWithClaims(accessToken, claims, func(t *jwt.Token) (any, error) { | 	_, err = jwt.ParseWithClaims(accessToken, claims, func(t *jwt.Token) (any, error) { | ||||||
| 		if t.Method.Alg() != jwt.SigningMethodHS256.Name { | 		if t.Method.Alg() != jwt.SigningMethodHS256.Name { | ||||||
| 			return nil, errors.Errorf("unexpected access token signing method=%v, expect %v", t.Header["alg"], jwt.SigningMethodHS256) | 			return nil, errors.Errorf("unexpected access token signing method=%v, expect %v", t.Header["alg"], jwt.SigningMethodHS256) | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user