diff --git a/api/v1/auth.go b/api/v1/auth.go index a1168d3c..8a52edb1 100644 --- a/api/v1/auth.go +++ b/api/v1/auth.go @@ -98,7 +98,7 @@ func (s *APIV1Service) SignIn(c echo.Context) error { return echo.NewHTTPError(http.StatusUnauthorized, "Incorrect login credentials, please try again") } - accessToken, err := auth.GenerateAccessToken(user.Email, user.ID, time.Now().Add(auth.AccessTokenDuration), s.Secret) + accessToken, err := auth.GenerateAccessToken(user.Username, user.ID, time.Now().Add(auth.AccessTokenDuration), s.Secret) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("failed to generate tokens, err: %s", err)).SetInternal(err) } @@ -222,7 +222,7 @@ func (s *APIV1Service) SignInSSO(c echo.Context) error { return echo.NewHTTPError(http.StatusForbidden, fmt.Sprintf("User has been archived with username %s", userInfo.Identifier)) } - accessToken, err := auth.GenerateAccessToken(user.Email, user.ID, time.Now().Add(auth.AccessTokenDuration), s.Secret) + accessToken, err := auth.GenerateAccessToken(user.Username, user.ID, time.Now().Add(auth.AccessTokenDuration), s.Secret) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("failed to generate tokens, err: %s", err)).SetInternal(err) } @@ -318,7 +318,7 @@ func (s *APIV1Service) SignUp(c echo.Context) error { if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create user").SetInternal(err) } - accessToken, err := auth.GenerateAccessToken(user.Email, user.ID, time.Now().Add(auth.AccessTokenDuration), s.Secret) + accessToken, err := auth.GenerateAccessToken(user.Username, user.ID, time.Now().Add(auth.AccessTokenDuration), s.Secret) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("failed to generate tokens, err: %s", err)).SetInternal(err) } diff --git a/api/v1/jwt.go b/api/v1/jwt.go index 343f2179..2af4b91f 100644 --- a/api/v1/jwt.go +++ b/api/v1/jwt.go @@ -10,6 +10,7 @@ import ( "github.com/pkg/errors" "github.com/usememos/memos/api/auth" "github.com/usememos/memos/common/util" + storepb "github.com/usememos/memos/proto/gen/store" "github.com/usememos/memos/store" ) @@ -66,13 +67,15 @@ func JWTMiddleware(server *APIV1Service, next echo.HandlerFunc, secret string) e return next(c) } + println("path", path) + // Skip validation for server status endpoints. if util.HasPrefixes(path, "/api/v1/ping", "/api/v1/idp", "/api/v1/status", "/api/v1/user") && path != "/api/v1/user/me" && method == http.MethodGet { return next(c) } - token := findAccessToken(c) - if token == "" { + accessToken := findAccessToken(c) + if accessToken == "" { // Allow the user to access the public endpoints. if util.HasPrefixes(path, "/o") { return next(c) @@ -85,7 +88,7 @@ func JWTMiddleware(server *APIV1Service, next echo.HandlerFunc, secret string) e } claims := &auth.ClaimsMessage{} - _, err := jwt.ParseWithClaims(token, claims, func(t *jwt.Token) (any, error) { + _, err := jwt.ParseWithClaims(accessToken, claims, func(t *jwt.Token) (any, error) { if t.Method.Alg() != jwt.SigningMethodHS256.Name { return nil, errors.Errorf("unexpected access token signing method=%v, expect %v", t.Header["alg"], jwt.SigningMethodHS256) } @@ -98,6 +101,7 @@ func JWTMiddleware(server *APIV1Service, next echo.HandlerFunc, secret string) e }) if err != nil { + RemoveTokensAndCookies(c) return echo.NewHTTPError(http.StatusUnauthorized, errors.Wrap(err, "Invalid or expired access token")) } if !audienceContains(claims.Audience, auth.AccessTokenAudienceName) { @@ -110,6 +114,15 @@ func JWTMiddleware(server *APIV1Service, next echo.HandlerFunc, secret string) e return echo.NewHTTPError(http.StatusUnauthorized, "Malformed ID in the token.") } + accessTokens, err := server.Store.GetUserAccessTokens(ctx, userID) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, "Failed to get user access tokens.").WithInternal(err) + } + if !validateAccessToken(accessToken, accessTokens) { + RemoveTokensAndCookies(c) + return echo.NewHTTPError(http.StatusUnauthorized, "Invalid access token.") + } + // Even if there is no error, we still need to make sure the user still exists. user, err := server.Store.GetUser(ctx, &store.FindUser{ ID: &userID, @@ -127,13 +140,16 @@ func JWTMiddleware(server *APIV1Service, next echo.HandlerFunc, secret string) e } } -func (s *APIV1Service) defaultAuthSkipper(c echo.Context) bool { +func (*APIV1Service) defaultAuthSkipper(c echo.Context) bool { path := c.Path() + return util.HasPrefixes(path, "/api/v1/auth") +} - // Skip auth. - if util.HasPrefixes(path, "/api/v1/auth") { - return true +func validateAccessToken(accessTokenString string, userAccessTokens []*storepb.AccessTokensUserSetting_AccessToken) bool { + for _, userAccessToken := range userAccessTokens { + if accessTokenString == userAccessToken.AccessToken { + return true + } } - return false } diff --git a/api/v2/acl.go b/api/v2/acl.go index 1affc904..c67f0369 100644 --- a/api/v2/acl.go +++ b/api/v2/acl.go @@ -8,6 +8,7 @@ import ( "github.com/golang-jwt/jwt/v4" "github.com/pkg/errors" "github.com/usememos/memos/api/auth" + storepb "github.com/usememos/memos/proto/gen/store" "github.com/usememos/memos/store" "google.golang.org/grpc" "google.golang.org/grpc/codes" @@ -44,12 +45,12 @@ func (in *GRPCAuthInterceptor) AuthenticationInterceptor(ctx context.Context, re if !ok { return nil, status.Errorf(codes.Unauthenticated, "failed to parse metadata from incoming context") } - accessTokenStr, err := getTokenFromMetadata(md) + accessToken, err := getTokenFromMetadata(md) if err != nil { return nil, status.Errorf(codes.Unauthenticated, err.Error()) } - username, err := in.authenticate(ctx, accessTokenStr) + username, err := in.authenticate(ctx, accessToken) if err != nil { if isUnauthorizeAllowedMethod(serverInfo.FullMethod) { return handler(ctx, request) @@ -74,12 +75,12 @@ func (in *GRPCAuthInterceptor) AuthenticationInterceptor(ctx context.Context, re return handler(childCtx, request) } -func (in *GRPCAuthInterceptor) authenticate(ctx context.Context, accessTokenStr string) (string, error) { - if accessTokenStr == "" { +func (in *GRPCAuthInterceptor) authenticate(ctx context.Context, accessToken string) (string, error) { + if accessToken == "" { return "", status.Errorf(codes.Unauthenticated, "access token not found") } claims := &auth.ClaimsMessage{} - _, err := jwt.ParseWithClaims(accessTokenStr, claims, func(t *jwt.Token) (any, error) { + _, err := jwt.ParseWithClaims(accessToken, claims, func(t *jwt.Token) (any, error) { if t.Method.Alg() != jwt.SigningMethodHS256.Name { return nil, status.Errorf(codes.Unauthenticated, "unexpected access token signing method=%v, expect %v", t.Header["alg"], jwt.SigningMethodHS256) } @@ -115,6 +116,14 @@ func (in *GRPCAuthInterceptor) authenticate(ctx context.Context, accessTokenStr return "", errors.Errorf("user %q is archived", username) } + accessTokens, err := in.Store.GetUserAccessTokens(ctx, user.ID) + if err != nil { + return "", errors.Wrapf(err, "failed to get user access tokens") + } + if !validateAccessToken(accessToken, accessTokens) { + return "", status.Errorf(codes.Unauthenticated, "invalid access token") + } + return username, nil } @@ -148,3 +157,12 @@ func audienceContains(audience jwt.ClaimStrings, token string) bool { } return false } + +func validateAccessToken(accessTokenString string, userAccessTokens []*storepb.AccessTokensUserSetting_AccessToken) bool { + for _, userAccessToken := range userAccessTokens { + if accessTokenString == userAccessToken.AccessToken { + return true + } + } + return false +} diff --git a/api/v2/user_service.go b/api/v2/user_service.go index 6a9e3e83..0f980c2c 100644 --- a/api/v2/user_service.go +++ b/api/v2/user_service.go @@ -167,7 +167,7 @@ func (s *UserService) CreateUserAccessToken(ctx context.Context, request *apiv2p return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err) } - accessToken, err := auth.GenerateAccessToken(user.Email, user.ID, request.UserAccessToken.ExpiresAt.AsTime(), s.Secret) + accessToken, err := auth.GenerateAccessToken(user.Username, user.ID, request.UserAccessToken.ExpiresAt.AsTime(), s.Secret) if err != nil { return nil, status.Errorf(codes.Internal, "failed to generate access token: %v", err) } diff --git a/store/cache.go b/store/cache.go index 2a060f9b..671a6a61 100644 --- a/store/cache.go +++ b/store/cache.go @@ -7,3 +7,7 @@ import ( func getUserSettingCacheKey(userID int32, key string) string { return fmt.Sprintf("%d-%s", userID, key) } + +func getUserSettingV1CacheKey(userID int32, key string) string { + return fmt.Sprintf("%d-%s-v1", userID, key) +} diff --git a/store/user_setting.go b/store/user_setting.go index 1d8301af..014f450f 100644 --- a/store/user_setting.go +++ b/store/user_setting.go @@ -136,7 +136,7 @@ func (s *Store) UpsertUserSettingV1(ctx context.Context, upsert *storepb.UserSet } userSettingMessage := upsert - s.userSettingCache.Store(getUserSettingCacheKey(userSettingMessage.UserId, userSettingMessage.Key.String()), userSettingMessage) + s.userSettingCache.Store(getUserSettingV1CacheKey(userSettingMessage.UserId, userSettingMessage.Key.String()), userSettingMessage) return userSettingMessage, nil } @@ -195,14 +195,14 @@ func (s *Store) ListUserSettingsV1(ctx context.Context, find *FindUserSettingV1) } for _, userSetting := range userSettingList { - s.userSettingCache.Store(getUserSettingCacheKey(userSetting.UserId, userSetting.Key.String()), userSetting) + s.userSettingCache.Store(getUserSettingV1CacheKey(userSetting.UserId, userSetting.Key.String()), userSetting) } return userSettingList, nil } func (s *Store) GetUserSettingV1(ctx context.Context, find *FindUserSettingV1) (*storepb.UserSetting, error) { if find.UserID != nil { - if cache, ok := s.userSettingCache.Load(getUserSettingCacheKey(*find.UserID, find.Key.String())); ok { + if cache, ok := s.userSettingCache.Load(getUserSettingV1CacheKey(*find.UserID, find.Key.String())); ok { return cache.(*storepb.UserSetting), nil } } @@ -217,7 +217,7 @@ func (s *Store) GetUserSettingV1(ctx context.Context, find *FindUserSettingV1) ( } userSetting := list[0] - s.userSettingCache.Store(getUserSettingCacheKey(userSetting.UserId, userSetting.Key.String()), userSetting) + s.userSettingCache.Store(getUserSettingV1CacheKey(userSetting.UserId, userSetting.Key.String()), userSetting) return userSetting, nil } diff --git a/web/src/components/CreateAccessTokenDialog.tsx b/web/src/components/CreateAccessTokenDialog.tsx new file mode 100644 index 00000000..7be95b44 --- /dev/null +++ b/web/src/components/CreateAccessTokenDialog.tsx @@ -0,0 +1,145 @@ +import { Button, Input, Radio, RadioGroup } from "@mui/joy"; +import axios from "axios"; +import React, { useState } from "react"; +import { toast } from "react-hot-toast"; +import useCurrentUser from "@/hooks/useCurrentUser"; +import useLoading from "@/hooks/useLoading"; +import { useTranslate } from "@/utils/i18n"; +import { generateDialog } from "./Dialog"; +import Icon from "./Icon"; + +interface Props extends DialogProps { + onConfirm: () => void; +} + +const expirationOptions = [ + { + label: "8 hours", + value: 3600 * 8, + }, + { + label: "1 month", + value: 3600 * 24 * 30, + }, + { + label: "Never", + value: 0, + }, +]; + +interface State { + description: string; + expiration: number; +} + +const CreateAccessTokenDialog: React.FC = (props: Props) => { + const { destroy, onConfirm } = props; + const t = useTranslate(); + const currentUser = useCurrentUser(); + const [state, setState] = useState({ + description: "", + expiration: 3600 * 8, + }); + const requestState = useLoading(false); + + const setPartialState = (partialState: Partial) => { + setState({ + ...state, + ...partialState, + }); + }; + + const handleDescriptionInputChange = (e: React.ChangeEvent) => { + setPartialState({ + description: e.target.value, + }); + }; + + const handleRoleInputChange = (e: React.ChangeEvent) => { + setPartialState({ + expiration: Number(e.target.value), + }); + }; + + const handleSaveBtnClick = async () => { + if (!state.description) { + toast.error("Description is required"); + return; + } + + try { + await axios.post(`/api/v2/users/${currentUser.id}/access_tokens`, { + description: state.description, + expiresAt: new Date(Date.now() + state.expiration * 1000), + }); + + onConfirm(); + destroy(); + } catch (error: any) { + console.error(error); + toast.error(error.response.data.message); + } + }; + + return ( + <> +
+

Create access token

+ +
+
+
+ + Description * + +
+ +
+
+
+ + Expiration * + +
+ + {expirationOptions.map((option) => ( + + ))} + +
+
+
+ + +
+
+ + ); +}; + +function showCreateAccessTokenDialog(onConfirm: () => void) { + generateDialog( + { + className: "create-access-token-dialog", + dialogName: "create-access-token-dialog", + }, + CreateAccessTokenDialog, + { + onConfirm, + } + ); +} + +export default showCreateAccessTokenDialog; diff --git a/web/src/components/Settings/AccessTokenSection.tsx b/web/src/components/Settings/AccessTokenSection.tsx new file mode 100644 index 00000000..08f0e1ef --- /dev/null +++ b/web/src/components/Settings/AccessTokenSection.tsx @@ -0,0 +1,148 @@ +import { Button, IconButton } from "@mui/joy"; +import axios from "axios"; +import copy from "copy-to-clipboard"; +import { useEffect, useState } from "react"; +import { toast } from "react-hot-toast"; +import useCurrentUser from "@/hooks/useCurrentUser"; +import { ListUserAccessTokensResponse, UserAccessToken } from "@/types/proto/api/v2/user_service_pb"; +import { useTranslate } from "@/utils/i18n"; +import showCreateAccessTokenDialog from "../CreateAccessTokenDialog"; +import { showCommonDialog } from "../Dialog/CommonDialog"; +import Icon from "../Icon"; + +const listAccessTokens = async (username: string) => { + const { data } = await axios.get(`/api/v2/users/${username}/access_tokens`); + return data.accessTokens; +}; + +const AccessTokenSection = () => { + const t = useTranslate(); + const currentUser = useCurrentUser(); + const [userAccessTokens, setUserAccessTokens] = useState([]); + + useEffect(() => { + listAccessTokens(currentUser.username).then((accessTokens) => { + setUserAccessTokens(accessTokens); + }); + }, []); + + const handleCreateAccessTokenDialogConfirm = async () => { + const accessTokens = await listAccessTokens(currentUser.username); + setUserAccessTokens(accessTokens); + }; + + const copyAccessToken = (accessToken: string) => { + copy(accessToken); + toast.success("Access token copied to clipboard"); + }; + + const handleDeleteAccessToken = async (accessToken: string) => { + showCommonDialog({ + title: "Delete Access Token", + content: `Are you sure to delete access token \`${getFormatedAccessToken(accessToken)}\`? You cannot undo this action.`, + style: "danger", + dialogName: "delete-access-token-dialog", + onConfirm: async () => { + await axios.delete(`/api/v2/users/${currentUser.id}/access_tokens/${accessToken}`); + setUserAccessTokens(userAccessTokens.filter((token) => token.accessToken !== accessToken)); + }, + }); + }; + + const getFormatedAccessToken = (accessToken: string) => { + return `${accessToken.slice(0, 4)}****${accessToken.slice(-4)}`; + }; + + return ( + <> +
+
+
+
+

Access Tokens

+

A list of all access tokens for your account.

+
+
+ +
+
+
+
+
+ + + + + + + + + + + + {userAccessTokens.map((userAccessToken) => ( + + + + + + + + ))} + +
+ Token + + Description + + Created At + + Expires At + + {t("common.delete")} +
+ {getFormatedAccessToken(userAccessToken.accessToken)} + copyAccessToken(userAccessToken.accessToken)} + > + + + + {userAccessToken.description} + + {String(userAccessToken.issuedAt)} + + {String(userAccessToken.expiresAt ?? "Never")} + + { + handleDeleteAccessToken(userAccessToken.accessToken); + }} + > + + +
+
+
+
+
+
+ + ); +}; + +export default AccessTokenSection; diff --git a/web/src/components/Settings/MyAccountSection.tsx b/web/src/components/Settings/MyAccountSection.tsx index 5a2b03f2..0561a249 100644 --- a/web/src/components/Settings/MyAccountSection.tsx +++ b/web/src/components/Settings/MyAccountSection.tsx @@ -4,6 +4,7 @@ import { useTranslate } from "@/utils/i18n"; import showChangePasswordDialog from "../ChangePasswordDialog"; import showUpdateAccountDialog from "../UpdateAccountDialog"; import UserAvatar from "../UserAvatar"; +import AccessTokenSection from "./AccessTokenSection"; const MyAccountSection = () => { const t = useTranslate(); @@ -12,7 +13,7 @@ const MyAccountSection = () => { return ( <>
-

{t("setting.account-section.title")}

+

{t("setting.account-section.title")}

{user.nickname} @@ -27,6 +28,8 @@ const MyAccountSection = () => { {t("setting.account-section.change-password")}
+ +
); diff --git a/web/src/pages/EmbedMemo.tsx b/web/src/pages/EmbedMemo.tsx index 769d0002..a80968cb 100644 --- a/web/src/pages/EmbedMemo.tsx +++ b/web/src/pages/EmbedMemo.tsx @@ -42,7 +42,7 @@ const EmbedMemo = () => { return (
{!loadingState.isLoading && ( -
+
{getDateTimeString(state.memo.displayTs)} @@ -53,7 +53,7 @@ const EmbedMemo = () => { undefined} />
-
+ )}
); diff --git a/web/src/pages/Explore.tsx b/web/src/pages/Explore.tsx index 55ef505f..e3f34905 100644 --- a/web/src/pages/Explore.tsx +++ b/web/src/pages/Explore.tsx @@ -90,7 +90,7 @@ const Explore = () => {
{!loadingState.isLoading && ( -
+
{sortedMemos.map((memo) => { return ; @@ -107,7 +107,7 @@ const Explore = () => { {t("memo.fetch-more")}

)} -
+ )}
); diff --git a/web/src/pages/Home.tsx b/web/src/pages/Home.tsx index 8bf3b73a..f45c5807 100644 --- a/web/src/pages/Home.tsx +++ b/web/src/pages/Home.tsx @@ -13,15 +13,13 @@ const Home = () => { const user = useCurrentUser(); useEffect(() => { - if (user) { - return; - } - - const systemStatus = globalStore.state.systemStatus; - if (systemStatus.disablePublicMemos) { - window.location.href = "/auth"; - } else { - window.location.href = "/explore"; + if (!user) { + const systemStatus = globalStore.state.systemStatus; + if (systemStatus.disablePublicMemos) { + window.location.href = "/auth"; + } else { + window.location.href = "/explore"; + } } }, []); diff --git a/web/src/pages/MemoDetail.tsx b/web/src/pages/MemoDetail.tsx index f775d28c..dbc675b1 100644 --- a/web/src/pages/MemoDetail.tsx +++ b/web/src/pages/MemoDetail.tsx @@ -47,9 +47,9 @@ const MemoDetail = () => { {!loadingState.isLoading && (memo ? ( <> -
+
-
+ ) : ( <> diff --git a/web/src/pages/Setting.tsx b/web/src/pages/Setting.tsx index 0866d633..cd72a7a5 100644 --- a/web/src/pages/Setting.tsx +++ b/web/src/pages/Setting.tsx @@ -49,7 +49,7 @@ const Setting = () => { }; return ( -
+
@@ -100,7 +100,7 @@ const Setting = () => { ) : null}
-
+