From 88cc37e37f84bda55f2df7a8fba64a989938b6ad Mon Sep 17 00:00:00 2001 From: Todd Martin <106564991+trmartin4@users.noreply.github.com> Date: Wed, 19 Jun 2024 11:51:12 -0400 Subject: [PATCH] Auth/pm 7672/Update token service to return new token from state (#9706) * Changed return structure * Object changes * Added missing assert. * Updated tests to use SetTokensResult * Fixed constructor * PM-7672 - Fix tests + add new setTokens test around refresh token * Removed change to refreshIdentityToken. * Updated return definition. --------- Co-authored-by: Jared Snider --- .../src/auth/abstractions/token.service.ts | 17 +-- .../auth/models/domain/set-tokens-result.ts | 10 ++ .../src/auth/services/token.service.spec.ts | 100 +++++++++++++++--- .../common/src/auth/services/token.service.ts | 87 ++++++++++----- libs/common/src/services/api.service.ts | 21 ++-- 5 files changed, 173 insertions(+), 62 deletions(-) create mode 100644 libs/common/src/auth/models/domain/set-tokens-result.ts diff --git a/libs/common/src/auth/abstractions/token.service.ts b/libs/common/src/auth/abstractions/token.service.ts index a88dfbb278..c86b5f1ee3 100644 --- a/libs/common/src/auth/abstractions/token.service.ts +++ b/libs/common/src/auth/abstractions/token.service.ts @@ -3,6 +3,7 @@ import { Observable } from "rxjs"; import { VaultTimeoutAction } from "../../enums/vault-timeout-action.enum"; import { UserId } from "../../types/guid"; import { VaultTimeout } from "../../types/vault-timeout.type"; +import { SetTokensResult } from "../models/domain/set-tokens-result"; import { DecodedAccessToken } from "../services/token.service"; export abstract class TokenService { @@ -23,7 +24,7 @@ export abstract class TokenService { * @param refreshToken The optional refresh token to set. Note: this is undefined when using the CLI Login Via API Key flow * @param clientIdClientSecret The API Key Client ID and Client Secret to set. * - * @returns A promise that resolves when the tokens have been set. + * @returns A promise that resolves with the SetTokensResult containing the tokens that were set. */ setTokens: ( accessToken: string, @@ -31,7 +32,7 @@ export abstract class TokenService { vaultTimeout: VaultTimeout, refreshToken?: string, clientIdClientSecret?: [string, string], - ) => Promise; + ) => Promise; /** * Clears the access token, refresh token, API Key Client ID, and API Key Client Secret out of memory, disk, and secure storage if supported. @@ -47,13 +48,13 @@ export abstract class TokenService { * @param accessToken The access token to set. * @param vaultTimeoutAction The action to take when the vault times out. * @param vaultTimeout The timeout for the vault. - * @returns A promise that resolves when the access token has been set. + * @returns A promise that resolves with the access token that has been set. */ setAccessToken: ( accessToken: string, vaultTimeoutAction: VaultTimeoutAction, vaultTimeout: VaultTimeout, - ) => Promise; + ) => Promise; // TODO: revisit having this public clear method approach once the state service is fully deprecated. /** @@ -86,14 +87,14 @@ export abstract class TokenService { * @param clientId The API Key Client ID to set. * @param vaultTimeoutAction The action to take when the vault times out. * @param vaultTimeout The timeout for the vault. - * @returns A promise that resolves when the API Key Client ID has been set. + * @returns A promise that resolves with the API Key Client ID that has been set. */ setClientId: ( clientId: string, vaultTimeoutAction: VaultTimeoutAction, vaultTimeout: VaultTimeout, userId?: UserId, - ) => Promise; + ) => Promise; /** * Gets the API Key Client ID for the active user. @@ -106,14 +107,14 @@ export abstract class TokenService { * @param clientSecret The API Key Client Secret to set. * @param vaultTimeoutAction The action to take when the vault times out. * @param vaultTimeout The timeout for the vault. - * @returns A promise that resolves when the API Key Client Secret has been set. + * @returns A promise that resolves with the client secret that has been set. */ setClientSecret: ( clientSecret: string, vaultTimeoutAction: VaultTimeoutAction, vaultTimeout: VaultTimeout, userId?: UserId, - ) => Promise; + ) => Promise; /** * Gets the API Key Client Secret for the active user. diff --git a/libs/common/src/auth/models/domain/set-tokens-result.ts b/libs/common/src/auth/models/domain/set-tokens-result.ts new file mode 100644 index 0000000000..3d72edd0d3 --- /dev/null +++ b/libs/common/src/auth/models/domain/set-tokens-result.ts @@ -0,0 +1,10 @@ +export class SetTokensResult { + constructor(accessToken: string, refreshToken?: string, clientIdSecretPair?: [string, string]) { + this.accessToken = accessToken; + this.refreshToken = refreshToken; + this.clientIdSecretPair = clientIdSecretPair; + } + accessToken: string; + refreshToken?: string; + clientIdSecretPair?: [string, string]; +} diff --git a/libs/common/src/auth/services/token.service.spec.ts b/libs/common/src/auth/services/token.service.spec.ts index d7a4c52716..4be945de5f 100644 --- a/libs/common/src/auth/services/token.service.spec.ts +++ b/libs/common/src/auth/services/token.service.spec.ts @@ -15,6 +15,7 @@ import { SymmetricCryptoKey } from "../../platform/models/domain/symmetric-crypt import { CsprngArray } from "../../types/csprng"; import { UserId } from "../../types/guid"; import { VaultTimeout, VaultTimeoutStringType } from "../../types/vault-timeout.type"; +import { SetTokensResult } from "../models/domain/set-tokens-result"; import { ACCOUNT_ACTIVE_ACCOUNT_ID } from "./account.service"; import { @@ -232,7 +233,7 @@ describe("TokenService", () => { describe("Memory storage tests", () => { it("set the access token in memory", async () => { // Act - await tokenService.setAccessToken( + const result = await tokenService.setAccessToken( accessTokenJwt, memoryVaultTimeoutAction, memoryVaultTimeout, @@ -241,13 +242,14 @@ describe("TokenService", () => { expect( singleUserStateProvider.getFake(userIdFromAccessToken, ACCESS_TOKEN_MEMORY).nextMock, ).toHaveBeenCalledWith(accessTokenJwt); + expect(result).toEqual(accessTokenJwt); }); }); describe("Disk storage tests (secure storage not supported on platform)", () => { it("should set the access token in disk", async () => { // Act - await tokenService.setAccessToken( + const result = await tokenService.setAccessToken( accessTokenJwt, diskVaultTimeoutAction, diskVaultTimeout, @@ -256,6 +258,7 @@ describe("TokenService", () => { expect( singleUserStateProvider.getFake(userIdFromAccessToken, ACCESS_TOKEN_DISK).nextMock, ).toHaveBeenCalledWith(accessTokenJwt); + expect(result).toEqual(accessTokenJwt); }); }); @@ -295,7 +298,7 @@ describe("TokenService", () => { secureStorageService.get.mockResolvedValueOnce(null).mockResolvedValue(accessTokenKeyB64); // Act - await tokenService.setAccessToken( + const result = await tokenService.setAccessToken( accessTokenJwt, diskVaultTimeoutAction, diskVaultTimeout, @@ -318,6 +321,9 @@ describe("TokenService", () => { expect( singleUserStateProvider.getFake(userIdFromAccessToken, ACCESS_TOKEN_MEMORY).nextMock, ).toHaveBeenCalledWith(null); + + // assert that the decrypted access token was returned + expect(result).toEqual(accessTokenJwt); }); it("should fallback to disk storage for the access token if the access token cannot be set in secure storage", async () => { @@ -331,7 +337,7 @@ describe("TokenService", () => { secureStorageService.get.mockResolvedValueOnce(null).mockResolvedValue(null); // Act - await tokenService.setAccessToken( + const result = await tokenService.setAccessToken( accessTokenJwt, diskVaultTimeoutAction, diskVaultTimeout, @@ -355,6 +361,9 @@ describe("TokenService", () => { expect( singleUserStateProvider.getFake(userIdFromAccessToken, ACCESS_TOKEN_DISK).nextMock, ).toHaveBeenCalledWith(accessTokenJwt); + + // assert that the decrypted access token was returned + expect(result).toEqual(accessTokenJwt); }); it("should fallback to disk storage for the access token if secure storage errors on trying to get an existing access token key", async () => { @@ -368,7 +377,7 @@ describe("TokenService", () => { secureStorageService.get.mockRejectedValue(new Error(secureStorageError)); // Act - await tokenService.setAccessToken( + const result = await tokenService.setAccessToken( accessTokenJwt, diskVaultTimeoutAction, diskVaultTimeout, @@ -385,6 +394,9 @@ describe("TokenService", () => { expect( singleUserStateProvider.getFake(userIdFromAccessToken, ACCESS_TOKEN_DISK).nextMock, ).toHaveBeenCalledWith(accessTokenJwt); + + // assert that the decrypted access token was returned + expect(result).toEqual(accessTokenJwt); }); }); }); @@ -2376,18 +2388,21 @@ describe("TokenService", () => { const clientId = "clientId"; const clientSecret = "clientSecret"; - (tokenService as any)._setAccessToken = jest.fn(); // any hack allows for mocking private method. - (tokenService as any).setRefreshToken = jest.fn(); - tokenService.setClientId = jest.fn(); - tokenService.setClientSecret = jest.fn(); + (tokenService as any)._setAccessToken = jest.fn().mockReturnValue(accessTokenJwt); + (tokenService as any).setRefreshToken = jest.fn().mockReturnValue(refreshToken); + tokenService.setClientId = jest.fn().mockReturnValue(clientId); + tokenService.setClientSecret = jest.fn().mockReturnValue(clientSecret); // Act // Note: passing a valid access token so that a valid user id can be determined from the access token - await tokenService.setTokens(accessTokenJwt, vaultTimeoutAction, vaultTimeout, refreshToken, [ - clientId, - clientSecret, - ]); + const result = await tokenService.setTokens( + accessTokenJwt, + vaultTimeoutAction, + vaultTimeout, + refreshToken, + [clientId, clientSecret], + ); // Assert expect((tokenService as any)._setAccessToken).toHaveBeenCalledWith( @@ -2417,6 +2432,44 @@ describe("TokenService", () => { vaultTimeout, userIdFromAccessToken, ); + + expect(result).toStrictEqual( + new SetTokensResult(accessTokenJwt, refreshToken, [clientId, clientSecret]), + ); + }); + + it("does not try to set the refresh token when it is not passed in", async () => { + // Arrange + const vaultTimeoutAction = VaultTimeoutAction.Lock; + const vaultTimeout = 30; + + (tokenService as any)._setAccessToken = jest.fn().mockReturnValue(accessTokenJwt); + (tokenService as any).setRefreshToken = jest.fn(); + tokenService.setClientId = jest.fn(); + tokenService.setClientSecret = jest.fn(); + + // Act + const result = await tokenService.setTokens( + accessTokenJwt, + vaultTimeoutAction, + vaultTimeout, + null, + ); + + // Assert + expect((tokenService as any)._setAccessToken).toHaveBeenCalledWith( + accessTokenJwt, + vaultTimeoutAction, + vaultTimeout, + userIdFromAccessToken, + ); + + // any hack allows for testing private methods + expect((tokenService as any).setRefreshToken).not.toHaveBeenCalled(); + expect(tokenService.setClientId).not.toHaveBeenCalled(); + expect(tokenService.setClientSecret).not.toHaveBeenCalled(); + + expect(result).toStrictEqual(new SetTokensResult(accessTokenJwt)); }); it("does not try to set client id and client secret when they are not passed in", async () => { @@ -2425,13 +2478,18 @@ describe("TokenService", () => { const vaultTimeoutAction = VaultTimeoutAction.Lock; const vaultTimeout = 30; - (tokenService as any)._setAccessToken = jest.fn(); - (tokenService as any).setRefreshToken = jest.fn(); + (tokenService as any)._setAccessToken = jest.fn().mockReturnValue(accessTokenJwt); + (tokenService as any).setRefreshToken = jest.fn().mockReturnValue(refreshToken); tokenService.setClientId = jest.fn(); tokenService.setClientSecret = jest.fn(); // Act - await tokenService.setTokens(accessTokenJwt, vaultTimeoutAction, vaultTimeout, refreshToken); + const result = await tokenService.setTokens( + accessTokenJwt, + vaultTimeoutAction, + vaultTimeout, + refreshToken, + ); // Assert expect((tokenService as any)._setAccessToken).toHaveBeenCalledWith( @@ -2451,6 +2509,8 @@ describe("TokenService", () => { expect(tokenService.setClientId).not.toHaveBeenCalled(); expect(tokenService.setClientSecret).not.toHaveBeenCalled(); + + expect(result).toStrictEqual(new SetTokensResult(accessTokenJwt, refreshToken)); }); it("throws an error when the access token is invalid", async () => { @@ -2535,10 +2595,16 @@ describe("TokenService", () => { (tokenService as any).setRefreshToken = jest.fn(); // Act - await tokenService.setTokens(accessTokenJwt, vaultTimeoutAction, vaultTimeout, refreshToken); + const result = await tokenService.setTokens( + accessTokenJwt, + vaultTimeoutAction, + vaultTimeout, + refreshToken, + ); // Assert expect((tokenService as any).setRefreshToken).not.toHaveBeenCalled(); + expect(result).toStrictEqual(new SetTokensResult(accessTokenJwt)); }); }); diff --git a/libs/common/src/auth/services/token.service.ts b/libs/common/src/auth/services/token.service.ts index 38d0a77b52..ef7f23cb05 100644 --- a/libs/common/src/auth/services/token.service.ts +++ b/libs/common/src/auth/services/token.service.ts @@ -21,6 +21,7 @@ import { import { UserId } from "../../types/guid"; import { VaultTimeout, VaultTimeoutStringType } from "../../types/vault-timeout.type"; import { TokenService as TokenServiceAbstraction } from "../abstractions/token.service"; +import { SetTokensResult } from "../models/domain/set-tokens-result"; import { ACCOUNT_ACTIVE_ACCOUNT_ID } from "./account.service"; import { @@ -160,7 +161,7 @@ export class TokenService implements TokenServiceAbstraction { vaultTimeout: VaultTimeout, refreshToken?: string, clientIdClientSecret?: [string, string], - ): Promise { + ): Promise { if (!accessToken) { throw new Error("Access token is required."); } @@ -181,16 +182,40 @@ export class TokenService implements TokenServiceAbstraction { throw new Error("User id not found. Cannot set tokens."); } - await this._setAccessToken(accessToken, vaultTimeoutAction, vaultTimeout, userId); + const newAccessToken = await this._setAccessToken( + accessToken, + vaultTimeoutAction, + vaultTimeout, + userId, + ); + + const newTokens = new SetTokensResult(newAccessToken); if (refreshToken) { - await this.setRefreshToken(refreshToken, vaultTimeoutAction, vaultTimeout, userId); + newTokens.refreshToken = await this.setRefreshToken( + refreshToken, + vaultTimeoutAction, + vaultTimeout, + userId, + ); } if (clientIdClientSecret != null) { - await this.setClientId(clientIdClientSecret[0], vaultTimeoutAction, vaultTimeout, userId); - await this.setClientSecret(clientIdClientSecret[1], vaultTimeoutAction, vaultTimeout, userId); + const clientId = await this.setClientId( + clientIdClientSecret[0], + vaultTimeoutAction, + vaultTimeout, + userId, + ); + const clientSecret = await this.setClientSecret( + clientIdClientSecret[1], + vaultTimeoutAction, + vaultTimeout, + userId, + ); + newTokens.clientIdSecretPair = [clientId, clientSecret]; } + return newTokens; } private async getAccessTokenKey(userId: UserId): Promise { @@ -289,7 +314,7 @@ export class TokenService implements TokenServiceAbstraction { vaultTimeoutAction: VaultTimeoutAction, vaultTimeout: VaultTimeout, userId: UserId, - ): Promise { + ): Promise { const storageLocation = await this.determineStorageLocation( vaultTimeoutAction, vaultTimeout, @@ -302,6 +327,8 @@ export class TokenService implements TokenServiceAbstraction { // store the access token directly. Instead, we encrypt with accessTokenKey and store that // in secure storage. + let decryptedAccessToken: string = null; + try { const encryptedAccessToken: EncString = await this.encryptAccessToken( accessToken, @@ -313,6 +340,10 @@ export class TokenService implements TokenServiceAbstraction { .get(userId, ACCESS_TOKEN_DISK) .update((_) => encryptedAccessToken.encryptedString); + // If we've successfully stored the encrypted access token to disk, we can return the decrypted access token + // so that the caller can use it immediately. + decryptedAccessToken = accessToken; + // TODO: PM-6408 // 2024-02-20: Remove access token from memory so that we migrate to encrypt the access token over time. // Remove this call to remove the access token from memory after 3 months. @@ -324,25 +355,23 @@ export class TokenService implements TokenServiceAbstraction { ); // Fall back to disk storage for unecrypted access token - await this.singleUserStateProvider + decryptedAccessToken = await this.singleUserStateProvider .get(userId, ACCESS_TOKEN_DISK) .update((_) => accessToken); } - return; + return decryptedAccessToken; } case TokenStorageLocation.Disk: // Access token stored on disk unencrypted as platform does not support secure storage - await this.singleUserStateProvider + return await this.singleUserStateProvider .get(userId, ACCESS_TOKEN_DISK) .update((_) => accessToken); - return; case TokenStorageLocation.Memory: // Access token stored in memory due to vault timeout settings - await this.singleUserStateProvider + return await this.singleUserStateProvider .get(userId, ACCESS_TOKEN_MEMORY) .update((_) => accessToken); - return; } } @@ -350,7 +379,7 @@ export class TokenService implements TokenServiceAbstraction { accessToken: string, vaultTimeoutAction: VaultTimeoutAction, vaultTimeout: VaultTimeout, - ): Promise { + ): Promise { if (!accessToken) { throw new Error("Access token is required."); } @@ -370,7 +399,7 @@ export class TokenService implements TokenServiceAbstraction { throw new Error("Vault Timeout Action is required."); } - await this._setAccessToken(accessToken, vaultTimeoutAction, vaultTimeout, userId); + return await this._setAccessToken(accessToken, vaultTimeoutAction, vaultTimeout, userId); } async clearAccessToken(userId?: UserId): Promise { @@ -486,7 +515,7 @@ export class TokenService implements TokenServiceAbstraction { vaultTimeoutAction: VaultTimeoutAction, vaultTimeout: VaultTimeout, userId: UserId, - ): Promise { + ): Promise { // If we don't have a user id, we can't save the value if (!userId) { throw new Error("User id not found. Cannot save refresh token."); @@ -509,6 +538,8 @@ export class TokenService implements TokenServiceAbstraction { switch (storageLocation) { case TokenStorageLocation.SecureStorage: { + let decryptedRefreshToken: string = null; + try { await this.saveStringToSecureStorage( userId, @@ -530,6 +561,10 @@ export class TokenService implements TokenServiceAbstraction { throw new Error("Refresh token failed to save to secure storage."); } + // If we've successfully stored the encrypted refresh token, we can return the decrypted refresh token + // so that the caller can use it immediately. + decryptedRefreshToken = refreshToken; + // TODO: PM-6408 // 2024-02-20: Remove refresh token from memory and disk so that we migrate to secure storage over time. // Remove these 2 calls to remove the refresh token from memory and disk after 3 months. @@ -544,24 +579,22 @@ export class TokenService implements TokenServiceAbstraction { ); // Fall back to disk storage for refresh token - await this.singleUserStateProvider + decryptedRefreshToken = await this.singleUserStateProvider .get(userId, REFRESH_TOKEN_DISK) .update((_) => refreshToken); } - return; + return decryptedRefreshToken; } case TokenStorageLocation.Disk: - await this.singleUserStateProvider + return await this.singleUserStateProvider .get(userId, REFRESH_TOKEN_DISK) .update((_) => refreshToken); - return; case TokenStorageLocation.Memory: - await this.singleUserStateProvider + return await this.singleUserStateProvider .get(userId, REFRESH_TOKEN_MEMORY) .update((_) => refreshToken); - return; } } @@ -644,7 +677,7 @@ export class TokenService implements TokenServiceAbstraction { vaultTimeoutAction: VaultTimeoutAction, vaultTimeout: VaultTimeout, userId?: UserId, - ): Promise { + ): Promise { userId ??= await firstValueFrom(this.activeUserIdGlobalState.state$); // If we don't have a user id, we can't save the value @@ -668,11 +701,11 @@ export class TokenService implements TokenServiceAbstraction { ); if (storageLocation === TokenStorageLocation.Disk) { - await this.singleUserStateProvider + return await this.singleUserStateProvider .get(userId, API_KEY_CLIENT_ID_DISK) .update((_) => clientId); } else if (storageLocation === TokenStorageLocation.Memory) { - await this.singleUserStateProvider + return await this.singleUserStateProvider .get(userId, API_KEY_CLIENT_ID_MEMORY) .update((_) => clientId); } @@ -721,7 +754,7 @@ export class TokenService implements TokenServiceAbstraction { vaultTimeoutAction: VaultTimeoutAction, vaultTimeout: VaultTimeout, userId?: UserId, - ): Promise { + ): Promise { userId ??= await firstValueFrom(this.activeUserIdGlobalState.state$); if (!userId) { @@ -744,11 +777,11 @@ export class TokenService implements TokenServiceAbstraction { ); if (storageLocation === TokenStorageLocation.Disk) { - await this.singleUserStateProvider + return await this.singleUserStateProvider .get(userId, API_KEY_CLIENT_SECRET_DISK) .update((_) => clientSecret); } else if (storageLocation === TokenStorageLocation.Memory) { - await this.singleUserStateProvider + return await this.singleUserStateProvider .get(userId, API_KEY_CLIENT_SECRET_MEMORY) .update((_) => clientSecret); } diff --git a/libs/common/src/services/api.service.ts b/libs/common/src/services/api.service.ts index 48ba643391..ffb228406b 100644 --- a/libs/common/src/services/api.service.ts +++ b/libs/common/src/services/api.service.ts @@ -249,7 +249,7 @@ export class ApiService implements ApiServiceAbstraction { async refreshIdentityToken(): Promise { try { - await this.doAuthRefresh(); + await this.refreshToken(); } catch (e) { this.logService.error("Error refreshing access token: ", e); throw e; @@ -1566,8 +1566,7 @@ export class ApiService implements ApiServiceAbstraction { async getActiveBearerToken(): Promise { let accessToken = await this.tokenService.getAccessToken(); if (await this.tokenService.tokenNeedsRefresh()) { - await this.doAuthRefresh(); - accessToken = await this.tokenService.getAccessToken(); + accessToken = await this.refreshToken(); } return accessToken; } @@ -1707,16 +1706,16 @@ export class ApiService implements ApiServiceAbstraction { ); } - protected async doAuthRefresh(): Promise { + protected async refreshToken(): Promise { const refreshToken = await this.tokenService.getRefreshToken(); if (refreshToken != null && refreshToken !== "") { - return this.doRefreshToken(); + return this.refreshAccessToken(); } const clientId = await this.tokenService.getClientId(); const clientSecret = await this.tokenService.getClientSecret(); if (!Utils.isNullOrWhitespace(clientId) && !Utils.isNullOrWhitespace(clientSecret)) { - return this.doApiTokenRefresh(); + return this.refreshApiToken(); } this.refreshAccessTokenErrorCallback(); @@ -1724,7 +1723,7 @@ export class ApiService implements ApiServiceAbstraction { throw new Error("Cannot refresh access token, no refresh token or api keys are stored."); } - protected async doRefreshToken(): Promise { + protected async refreshAccessToken(): Promise { const refreshToken = await this.tokenService.getRefreshToken(); if (refreshToken == null || refreshToken === "") { throw new Error(); @@ -1770,19 +1769,20 @@ export class ApiService implements ApiServiceAbstraction { this.vaultTimeoutSettingsService.getVaultTimeoutByUserId$(userId), ); - await this.tokenService.setTokens( + const refreshedTokens = await this.tokenService.setTokens( tokenResponse.accessToken, vaultTimeoutAction as VaultTimeoutAction, vaultTimeout, tokenResponse.refreshToken, ); + return refreshedTokens.accessToken; } else { const error = await this.handleError(response, true, true); return Promise.reject(error); } } - protected async doApiTokenRefresh(): Promise { + protected async refreshApiToken(): Promise { const clientId = await this.tokenService.getClientId(); const clientSecret = await this.tokenService.getClientSecret(); @@ -1810,11 +1810,12 @@ export class ApiService implements ApiServiceAbstraction { this.vaultTimeoutSettingsService.getVaultTimeoutByUserId$(userId), ); - await this.tokenService.setAccessToken( + const refreshedToken = await this.tokenService.setAccessToken( response.accessToken, vaultTimeoutAction as VaultTimeoutAction, vaultTimeout, ); + return refreshedToken; } async send(