mirror of
				https://github.com/usememos/memos.git
				synced 2025-06-05 22:09:59 +02:00 
			
		
		
		
	chore: migrate auth package
This commit is contained in:
		| @@ -1,160 +0,0 @@ | ||||
| package auth | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"log/slog" | ||||
| 	"net/http" | ||||
| 	"strings" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/golang-jwt/jwt/v5" | ||||
| 	"github.com/labstack/echo/v4" | ||||
| 	"github.com/pkg/errors" | ||||
|  | ||||
| 	"github.com/usememos/memos/internal/util" | ||||
| 	storepb "github.com/usememos/memos/proto/gen/store" | ||||
| 	"github.com/usememos/memos/store" | ||||
| ) | ||||
|  | ||||
| const ( | ||||
| 	// UserIDContextKey is the key name used to store user id in the context. | ||||
| 	UserIDContextKey = "user-id" | ||||
| ) | ||||
|  | ||||
| func extractTokenFromHeader(c echo.Context) (string, error) { | ||||
| 	authHeader := c.Request().Header.Get("Authorization") | ||||
| 	if authHeader == "" { | ||||
| 		return "", nil | ||||
| 	} | ||||
|  | ||||
| 	authHeaderParts := strings.Fields(authHeader) | ||||
| 	if len(authHeaderParts) != 2 || strings.ToLower(authHeaderParts[0]) != "bearer" { | ||||
| 		return "", errors.New("Authorization header format must be Bearer {token}") | ||||
| 	} | ||||
|  | ||||
| 	return authHeaderParts[1], nil | ||||
| } | ||||
|  | ||||
| func findAccessToken(c echo.Context) string { | ||||
| 	// Check the HTTP request header first. | ||||
| 	accessToken, _ := extractTokenFromHeader(c) | ||||
| 	if accessToken == "" { | ||||
| 		// Check the cookie. | ||||
| 		cookie, _ := c.Cookie(AccessTokenCookieName) | ||||
| 		if cookie != nil { | ||||
| 			accessToken = cookie.Value | ||||
| 		} | ||||
| 	} | ||||
| 	return accessToken | ||||
| } | ||||
|  | ||||
| // JWTMiddleware validates the access token. | ||||
| func JWTMiddleware(storeInstance *store.Store, next echo.HandlerFunc, secret string) echo.HandlerFunc { | ||||
| 	return func(c echo.Context) error { | ||||
| 		ctx := c.Request().Context() | ||||
| 		path := c.Request().URL.Path | ||||
|  | ||||
| 		accessToken := findAccessToken(c) | ||||
| 		if accessToken == "" { | ||||
| 			// Allow the user to access the public endpoints. | ||||
| 			if util.HasPrefixes(path, "/o") { | ||||
| 				return next(c) | ||||
| 			} | ||||
| 			return echo.NewHTTPError(http.StatusUnauthorized, "Missing access token") | ||||
| 		} | ||||
|  | ||||
| 		userID, err := getUserIDFromAccessToken(accessToken, secret) | ||||
| 		if err != nil { | ||||
| 			err = removeAccessTokenAndCookies(c, storeInstance, userID, accessToken) | ||||
| 			if err != nil { | ||||
| 				slog.Warn("fail to remove AccessToken and Cookies", err) | ||||
| 			} | ||||
| 			return echo.NewHTTPError(http.StatusUnauthorized, "Invalid or expired access token") | ||||
| 		} | ||||
|  | ||||
| 		accessTokens, err := storeInstance.GetUserAccessTokens(ctx, userID) | ||||
| 		if err != nil { | ||||
| 			return echo.NewHTTPError(http.StatusInternalServerError, "Failed to get user access tokens.").WithInternal(err) | ||||
| 		} | ||||
| 		if !validateAccessToken(accessToken, accessTokens) { | ||||
| 			err = removeAccessTokenAndCookies(c, storeInstance, userID, accessToken) | ||||
| 			if err != nil { | ||||
| 				slog.Warn("fail to remove AccessToken and Cookies", err) | ||||
| 			} | ||||
| 			return echo.NewHTTPError(http.StatusUnauthorized, "Invalid access token.") | ||||
| 		} | ||||
|  | ||||
| 		// Even if there is no error, we still need to make sure the user still exists. | ||||
| 		user, err := storeInstance.GetUser(ctx, &store.FindUser{ | ||||
| 			ID: &userID, | ||||
| 		}) | ||||
| 		if err != nil { | ||||
| 			return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Server error to find user ID: %d", userID)).SetInternal(err) | ||||
| 		} | ||||
| 		if user == nil { | ||||
| 			return echo.NewHTTPError(http.StatusUnauthorized, fmt.Sprintf("Failed to find user ID: %d", userID)) | ||||
| 		} | ||||
|  | ||||
| 		// Stores userID into context. | ||||
| 		c.Set(UserIDContextKey, userID) | ||||
| 		return next(c) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func getUserIDFromAccessToken(accessToken, secret string) (int32, error) { | ||||
| 	claims := &ClaimsMessage{} | ||||
| 	_, err := jwt.ParseWithClaims(accessToken, claims, func(t *jwt.Token) (any, error) { | ||||
| 		if t.Method.Alg() != jwt.SigningMethodHS256.Name { | ||||
| 			return nil, errors.Errorf("unexpected access token signing method=%v, expect %v", t.Header["alg"], jwt.SigningMethodHS256) | ||||
| 		} | ||||
| 		if kid, ok := t.Header["kid"].(string); ok { | ||||
| 			if kid == "v1" { | ||||
| 				return []byte(secret), nil | ||||
| 			} | ||||
| 		} | ||||
| 		return nil, errors.Errorf("unexpected access token kid=%v", t.Header["kid"]) | ||||
| 	}) | ||||
| 	if err != nil { | ||||
| 		return 0, errors.Wrap(err, "Invalid or expired access token") | ||||
| 	} | ||||
| 	// We either have a valid access token or we will attempt to generate new access token. | ||||
| 	userID, err := util.ConvertStringToInt32(claims.Subject) | ||||
| 	if err != nil { | ||||
| 		return 0, errors.Wrap(err, "Malformed ID in the token") | ||||
| 	} | ||||
| 	return userID, nil | ||||
| } | ||||
|  | ||||
| func validateAccessToken(accessTokenString string, userAccessTokens []*storepb.AccessTokensUserSetting_AccessToken) bool { | ||||
| 	for _, userAccessToken := range userAccessTokens { | ||||
| 		if accessTokenString == userAccessToken.AccessToken { | ||||
| 			return true | ||||
| 		} | ||||
| 	} | ||||
| 	return false | ||||
| } | ||||
|  | ||||
| // removeAccessTokenAndCookies removes the jwt token from the cookies. | ||||
| func removeAccessTokenAndCookies(c echo.Context, s *store.Store, userID int32, token string) error { | ||||
| 	err := s.RemoveUserAccessToken(c.Request().Context(), userID, token) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	cookieExp := time.Now().Add(-1 * time.Hour) | ||||
| 	setTokenCookie(c, AccessTokenCookieName, "", cookieExp) | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // setTokenCookie sets the token to the cookie. | ||||
| func setTokenCookie(c echo.Context, name, token string, expiration time.Time) { | ||||
| 	cookie := new(http.Cookie) | ||||
| 	cookie.Name = name | ||||
| 	cookie.Value = token | ||||
| 	cookie.Expires = expiration | ||||
| 	cookie.Path = "/" | ||||
| 	// Http-only helps mitigate the risk of client side script accessing the protected cookie. | ||||
| 	cookie.HttpOnly = true | ||||
| 	cookie.SameSite = http.SameSiteStrictMode | ||||
| 	c.SetCookie(cookie) | ||||
| } | ||||
| @@ -14,7 +14,6 @@ import ( | ||||
|  | ||||
| 	"github.com/usememos/memos/internal/util" | ||||
| 	storepb "github.com/usememos/memos/proto/gen/store" | ||||
| 	"github.com/usememos/memos/server/route/api/auth" | ||||
| 	"github.com/usememos/memos/store" | ||||
| ) | ||||
|  | ||||
| @@ -84,7 +83,7 @@ func (in *GRPCAuthInterceptor) authenticate(ctx context.Context, accessToken str | ||||
| 	if accessToken == "" { | ||||
| 		return "", status.Errorf(codes.Unauthenticated, "access token not found") | ||||
| 	} | ||||
| 	claims := &auth.ClaimsMessage{} | ||||
| 	claims := &ClaimsMessage{} | ||||
| 	_, err := jwt.ParseWithClaims(accessToken, claims, func(t *jwt.Token) (any, error) { | ||||
| 		if t.Method.Alg() != jwt.SigningMethodHS256.Name { | ||||
| 			return nil, status.Errorf(codes.Unauthenticated, "unexpected access token signing method=%v, expect %v", t.Header["alg"], jwt.SigningMethodHS256) | ||||
| @@ -145,7 +144,7 @@ func getTokenFromMetadata(md metadata.MD) (string, error) { | ||||
| 		header := http.Header{} | ||||
| 		header.Add("Cookie", t) | ||||
| 		request := http.Request{Header: header} | ||||
| 		if v, _ := request.Cookie(auth.AccessTokenCookieName); v != nil { | ||||
| 		if v, _ := request.Cookie(AccessTokenCookieName); v != nil { | ||||
| 			accessToken = v.Value | ||||
| 		} | ||||
| 	} | ||||
|   | ||||
| @@ -1,4 +1,4 @@ | ||||
| package auth | ||||
| package v1 | ||||
| 
 | ||||
| import ( | ||||
| 	"fmt" | ||||
| @@ -20,7 +20,6 @@ import ( | ||||
| 	"github.com/usememos/memos/plugin/idp/oauth2" | ||||
| 	v1pb "github.com/usememos/memos/proto/gen/api/v1" | ||||
| 	storepb "github.com/usememos/memos/proto/gen/store" | ||||
| 	"github.com/usememos/memos/server/route/api/auth" | ||||
| 	"github.com/usememos/memos/store" | ||||
| ) | ||||
|  | ||||
| @@ -57,7 +56,7 @@ func (s *APIV1Service) SignIn(ctx context.Context, request *v1pb.SignInRequest) | ||||
| 		return nil, status.Errorf(codes.InvalidArgument, "unmatched email and password") | ||||
| 	} | ||||
|  | ||||
| 	expireTime := time.Now().Add(auth.AccessTokenDuration) | ||||
| 	expireTime := time.Now().Add(AccessTokenDuration) | ||||
| 	if request.NeverExpire { | ||||
| 		// Set the expire time to 100 years. | ||||
| 		expireTime = time.Now().Add(100 * 365 * 24 * time.Hour) | ||||
| @@ -138,14 +137,14 @@ func (s *APIV1Service) SignInWithSSO(ctx context.Context, request *v1pb.SignInWi | ||||
| 		return nil, status.Errorf(codes.PermissionDenied, fmt.Sprintf("user has been archived with username %s", userInfo.Identifier)) | ||||
| 	} | ||||
|  | ||||
| 	if err := s.doSignIn(ctx, user, time.Now().Add(auth.AccessTokenDuration)); err != nil { | ||||
| 	if err := s.doSignIn(ctx, user, time.Now().Add(AccessTokenDuration)); err != nil { | ||||
| 		return nil, status.Errorf(codes.Internal, fmt.Sprintf("failed to sign in, err: %s", err)) | ||||
| 	} | ||||
| 	return convertUserFromStore(user), nil | ||||
| } | ||||
|  | ||||
| func (s *APIV1Service) doSignIn(ctx context.Context, user *store.User, expireTime time.Time) error { | ||||
| 	accessToken, err := auth.GenerateAccessToken(user.Email, user.ID, expireTime, []byte(s.Secret)) | ||||
| 	accessToken, err := GenerateAccessToken(user.Email, user.ID, expireTime, []byte(s.Secret)) | ||||
| 	if err != nil { | ||||
| 		return status.Errorf(codes.Internal, fmt.Sprintf("failed to generate tokens, err: %s", err)) | ||||
| 	} | ||||
| @@ -208,7 +207,7 @@ func (s *APIV1Service) SignUp(ctx context.Context, request *v1pb.SignUpRequest) | ||||
| 		return nil, status.Errorf(codes.Internal, fmt.Sprintf("failed to create user, err: %s", err)) | ||||
| 	} | ||||
|  | ||||
| 	if err := s.doSignIn(ctx, user, time.Now().Add(auth.AccessTokenDuration)); err != nil { | ||||
| 	if err := s.doSignIn(ctx, user, time.Now().Add(AccessTokenDuration)); err != nil { | ||||
| 		return nil, status.Errorf(codes.Internal, fmt.Sprintf("failed to sign in, err: %s", err)) | ||||
| 	} | ||||
| 	return convertUserFromStore(user), nil | ||||
| @@ -236,7 +235,7 @@ func (s *APIV1Service) clearAccessTokenCookie(ctx context.Context) error { | ||||
|  | ||||
| func (*APIV1Service) buildAccessTokenCookie(ctx context.Context, accessToken string, expireTime time.Time) (string, error) { | ||||
| 	attrs := []string{ | ||||
| 		fmt.Sprintf("%s=%s", auth.AccessTokenCookieName, accessToken), | ||||
| 		fmt.Sprintf("%s=%s", AccessTokenCookieName, accessToken), | ||||
| 		"Path=/", | ||||
| 		"HttpOnly", | ||||
| 	} | ||||
|   | ||||
| @@ -25,7 +25,6 @@ import ( | ||||
| 	"github.com/usememos/memos/internal/util" | ||||
| 	v1pb "github.com/usememos/memos/proto/gen/api/v1" | ||||
| 	storepb "github.com/usememos/memos/proto/gen/store" | ||||
| 	"github.com/usememos/memos/server/route/api/auth" | ||||
| 	"github.com/usememos/memos/store" | ||||
| ) | ||||
|  | ||||
| @@ -363,7 +362,7 @@ func (s *APIV1Service) ListUserAccessTokens(ctx context.Context, _ *v1pb.ListUse | ||||
|  | ||||
| 	accessTokens := []*v1pb.UserAccessToken{} | ||||
| 	for _, userAccessToken := range userAccessTokens { | ||||
| 		claims := &auth.ClaimsMessage{} | ||||
| 		claims := &ClaimsMessage{} | ||||
| 		_, err := jwt.ParseWithClaims(userAccessToken.AccessToken, claims, func(t *jwt.Token) (any, error) { | ||||
| 			if t.Method.Alg() != jwt.SigningMethodHS256.Name { | ||||
| 				return nil, errors.Errorf("unexpected access token signing method=%v, expect %v", t.Header["alg"], jwt.SigningMethodHS256) | ||||
| @@ -412,12 +411,12 @@ func (s *APIV1Service) CreateUserAccessToken(ctx context.Context, request *v1pb. | ||||
| 		expiresAt = request.ExpiresAt.AsTime() | ||||
| 	} | ||||
|  | ||||
| 	accessToken, err := auth.GenerateAccessToken(user.Username, user.ID, expiresAt, []byte(s.Secret)) | ||||
| 	accessToken, err := GenerateAccessToken(user.Username, user.ID, expiresAt, []byte(s.Secret)) | ||||
| 	if err != nil { | ||||
| 		return nil, status.Errorf(codes.Internal, "failed to generate access token: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	claims := &auth.ClaimsMessage{} | ||||
| 	claims := &ClaimsMessage{} | ||||
| 	_, err = jwt.ParseWithClaims(accessToken, claims, func(t *jwt.Token) (any, error) { | ||||
| 		if t.Method.Alg() != jwt.SigningMethodHS256.Name { | ||||
| 			return nil, errors.Errorf("unexpected access token signing method=%v, expect %v", t.Header["alg"], jwt.SigningMethodHS256) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user