Auth/pm 7672/Update token service to return new token from state (#9706)

* Changed return structure

* Object changes

* Added missing assert.

* Updated tests to use SetTokensResult

* Fixed constructor

* PM-7672 - Fix tests + add new setTokens test around refresh token

* Removed change to refreshIdentityToken.

* Updated return definition.

---------

Co-authored-by: Jared Snider <jsnider@bitwarden.com>
This commit is contained in:
Todd Martin 2024-06-19 11:51:12 -04:00 committed by GitHub
parent 7e3ba087ec
commit 88cc37e37f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 173 additions and 62 deletions

View File

@ -3,6 +3,7 @@ import { Observable } from "rxjs";
import { VaultTimeoutAction } from "../../enums/vault-timeout-action.enum"; import { VaultTimeoutAction } from "../../enums/vault-timeout-action.enum";
import { UserId } from "../../types/guid"; import { UserId } from "../../types/guid";
import { VaultTimeout } from "../../types/vault-timeout.type"; import { VaultTimeout } from "../../types/vault-timeout.type";
import { SetTokensResult } from "../models/domain/set-tokens-result";
import { DecodedAccessToken } from "../services/token.service"; import { DecodedAccessToken } from "../services/token.service";
export abstract class TokenService { export abstract class TokenService {
@ -23,7 +24,7 @@ export abstract class TokenService {
* @param refreshToken The optional refresh token to set. Note: this is undefined when using the CLI Login Via API Key flow * @param refreshToken The optional refresh token to set. Note: this is undefined when using the CLI Login Via API Key flow
* @param clientIdClientSecret The API Key Client ID and Client Secret to set. * @param clientIdClientSecret The API Key Client ID and Client Secret to set.
* *
* @returns A promise that resolves when the tokens have been set. * @returns A promise that resolves with the SetTokensResult containing the tokens that were set.
*/ */
setTokens: ( setTokens: (
accessToken: string, accessToken: string,
@ -31,7 +32,7 @@ export abstract class TokenService {
vaultTimeout: VaultTimeout, vaultTimeout: VaultTimeout,
refreshToken?: string, refreshToken?: string,
clientIdClientSecret?: [string, string], clientIdClientSecret?: [string, string],
) => Promise<void>; ) => Promise<SetTokensResult>;
/** /**
* Clears the access token, refresh token, API Key Client ID, and API Key Client Secret out of memory, disk, and secure storage if supported. * Clears the access token, refresh token, API Key Client ID, and API Key Client Secret out of memory, disk, and secure storage if supported.
@ -47,13 +48,13 @@ export abstract class TokenService {
* @param accessToken The access token to set. * @param accessToken The access token to set.
* @param vaultTimeoutAction The action to take when the vault times out. * @param vaultTimeoutAction The action to take when the vault times out.
* @param vaultTimeout The timeout for the vault. * @param vaultTimeout The timeout for the vault.
* @returns A promise that resolves when the access token has been set. * @returns A promise that resolves with the access token that has been set.
*/ */
setAccessToken: ( setAccessToken: (
accessToken: string, accessToken: string,
vaultTimeoutAction: VaultTimeoutAction, vaultTimeoutAction: VaultTimeoutAction,
vaultTimeout: VaultTimeout, vaultTimeout: VaultTimeout,
) => Promise<void>; ) => Promise<string>;
// TODO: revisit having this public clear method approach once the state service is fully deprecated. // TODO: revisit having this public clear method approach once the state service is fully deprecated.
/** /**
@ -86,14 +87,14 @@ export abstract class TokenService {
* @param clientId The API Key Client ID to set. * @param clientId The API Key Client ID to set.
* @param vaultTimeoutAction The action to take when the vault times out. * @param vaultTimeoutAction The action to take when the vault times out.
* @param vaultTimeout The timeout for the vault. * @param vaultTimeout The timeout for the vault.
* @returns A promise that resolves when the API Key Client ID has been set. * @returns A promise that resolves with the API Key Client ID that has been set.
*/ */
setClientId: ( setClientId: (
clientId: string, clientId: string,
vaultTimeoutAction: VaultTimeoutAction, vaultTimeoutAction: VaultTimeoutAction,
vaultTimeout: VaultTimeout, vaultTimeout: VaultTimeout,
userId?: UserId, userId?: UserId,
) => Promise<void>; ) => Promise<string>;
/** /**
* Gets the API Key Client ID for the active user. * Gets the API Key Client ID for the active user.
@ -106,14 +107,14 @@ export abstract class TokenService {
* @param clientSecret The API Key Client Secret to set. * @param clientSecret The API Key Client Secret to set.
* @param vaultTimeoutAction The action to take when the vault times out. * @param vaultTimeoutAction The action to take when the vault times out.
* @param vaultTimeout The timeout for the vault. * @param vaultTimeout The timeout for the vault.
* @returns A promise that resolves when the API Key Client Secret has been set. * @returns A promise that resolves with the client secret that has been set.
*/ */
setClientSecret: ( setClientSecret: (
clientSecret: string, clientSecret: string,
vaultTimeoutAction: VaultTimeoutAction, vaultTimeoutAction: VaultTimeoutAction,
vaultTimeout: VaultTimeout, vaultTimeout: VaultTimeout,
userId?: UserId, userId?: UserId,
) => Promise<void>; ) => Promise<string>;
/** /**
* Gets the API Key Client Secret for the active user. * Gets the API Key Client Secret for the active user.

View File

@ -0,0 +1,10 @@
export class SetTokensResult {
constructor(accessToken: string, refreshToken?: string, clientIdSecretPair?: [string, string]) {
this.accessToken = accessToken;
this.refreshToken = refreshToken;
this.clientIdSecretPair = clientIdSecretPair;
}
accessToken: string;
refreshToken?: string;
clientIdSecretPair?: [string, string];
}

View File

@ -15,6 +15,7 @@ import { SymmetricCryptoKey } from "../../platform/models/domain/symmetric-crypt
import { CsprngArray } from "../../types/csprng"; import { CsprngArray } from "../../types/csprng";
import { UserId } from "../../types/guid"; import { UserId } from "../../types/guid";
import { VaultTimeout, VaultTimeoutStringType } from "../../types/vault-timeout.type"; import { VaultTimeout, VaultTimeoutStringType } from "../../types/vault-timeout.type";
import { SetTokensResult } from "../models/domain/set-tokens-result";
import { ACCOUNT_ACTIVE_ACCOUNT_ID } from "./account.service"; import { ACCOUNT_ACTIVE_ACCOUNT_ID } from "./account.service";
import { import {
@ -232,7 +233,7 @@ describe("TokenService", () => {
describe("Memory storage tests", () => { describe("Memory storage tests", () => {
it("set the access token in memory", async () => { it("set the access token in memory", async () => {
// Act // Act
await tokenService.setAccessToken( const result = await tokenService.setAccessToken(
accessTokenJwt, accessTokenJwt,
memoryVaultTimeoutAction, memoryVaultTimeoutAction,
memoryVaultTimeout, memoryVaultTimeout,
@ -241,13 +242,14 @@ describe("TokenService", () => {
expect( expect(
singleUserStateProvider.getFake(userIdFromAccessToken, ACCESS_TOKEN_MEMORY).nextMock, singleUserStateProvider.getFake(userIdFromAccessToken, ACCESS_TOKEN_MEMORY).nextMock,
).toHaveBeenCalledWith(accessTokenJwt); ).toHaveBeenCalledWith(accessTokenJwt);
expect(result).toEqual(accessTokenJwt);
}); });
}); });
describe("Disk storage tests (secure storage not supported on platform)", () => { describe("Disk storage tests (secure storage not supported on platform)", () => {
it("should set the access token in disk", async () => { it("should set the access token in disk", async () => {
// Act // Act
await tokenService.setAccessToken( const result = await tokenService.setAccessToken(
accessTokenJwt, accessTokenJwt,
diskVaultTimeoutAction, diskVaultTimeoutAction,
diskVaultTimeout, diskVaultTimeout,
@ -256,6 +258,7 @@ describe("TokenService", () => {
expect( expect(
singleUserStateProvider.getFake(userIdFromAccessToken, ACCESS_TOKEN_DISK).nextMock, singleUserStateProvider.getFake(userIdFromAccessToken, ACCESS_TOKEN_DISK).nextMock,
).toHaveBeenCalledWith(accessTokenJwt); ).toHaveBeenCalledWith(accessTokenJwt);
expect(result).toEqual(accessTokenJwt);
}); });
}); });
@ -295,7 +298,7 @@ describe("TokenService", () => {
secureStorageService.get.mockResolvedValueOnce(null).mockResolvedValue(accessTokenKeyB64); secureStorageService.get.mockResolvedValueOnce(null).mockResolvedValue(accessTokenKeyB64);
// Act // Act
await tokenService.setAccessToken( const result = await tokenService.setAccessToken(
accessTokenJwt, accessTokenJwt,
diskVaultTimeoutAction, diskVaultTimeoutAction,
diskVaultTimeout, diskVaultTimeout,
@ -318,6 +321,9 @@ describe("TokenService", () => {
expect( expect(
singleUserStateProvider.getFake(userIdFromAccessToken, ACCESS_TOKEN_MEMORY).nextMock, singleUserStateProvider.getFake(userIdFromAccessToken, ACCESS_TOKEN_MEMORY).nextMock,
).toHaveBeenCalledWith(null); ).toHaveBeenCalledWith(null);
// assert that the decrypted access token was returned
expect(result).toEqual(accessTokenJwt);
}); });
it("should fallback to disk storage for the access token if the access token cannot be set in secure storage", async () => { it("should fallback to disk storage for the access token if the access token cannot be set in secure storage", async () => {
@ -331,7 +337,7 @@ describe("TokenService", () => {
secureStorageService.get.mockResolvedValueOnce(null).mockResolvedValue(null); secureStorageService.get.mockResolvedValueOnce(null).mockResolvedValue(null);
// Act // Act
await tokenService.setAccessToken( const result = await tokenService.setAccessToken(
accessTokenJwt, accessTokenJwt,
diskVaultTimeoutAction, diskVaultTimeoutAction,
diskVaultTimeout, diskVaultTimeout,
@ -355,6 +361,9 @@ describe("TokenService", () => {
expect( expect(
singleUserStateProvider.getFake(userIdFromAccessToken, ACCESS_TOKEN_DISK).nextMock, singleUserStateProvider.getFake(userIdFromAccessToken, ACCESS_TOKEN_DISK).nextMock,
).toHaveBeenCalledWith(accessTokenJwt); ).toHaveBeenCalledWith(accessTokenJwt);
// assert that the decrypted access token was returned
expect(result).toEqual(accessTokenJwt);
}); });
it("should fallback to disk storage for the access token if secure storage errors on trying to get an existing access token key", async () => { it("should fallback to disk storage for the access token if secure storage errors on trying to get an existing access token key", async () => {
@ -368,7 +377,7 @@ describe("TokenService", () => {
secureStorageService.get.mockRejectedValue(new Error(secureStorageError)); secureStorageService.get.mockRejectedValue(new Error(secureStorageError));
// Act // Act
await tokenService.setAccessToken( const result = await tokenService.setAccessToken(
accessTokenJwt, accessTokenJwt,
diskVaultTimeoutAction, diskVaultTimeoutAction,
diskVaultTimeout, diskVaultTimeout,
@ -385,6 +394,9 @@ describe("TokenService", () => {
expect( expect(
singleUserStateProvider.getFake(userIdFromAccessToken, ACCESS_TOKEN_DISK).nextMock, singleUserStateProvider.getFake(userIdFromAccessToken, ACCESS_TOKEN_DISK).nextMock,
).toHaveBeenCalledWith(accessTokenJwt); ).toHaveBeenCalledWith(accessTokenJwt);
// assert that the decrypted access token was returned
expect(result).toEqual(accessTokenJwt);
}); });
}); });
}); });
@ -2376,18 +2388,21 @@ describe("TokenService", () => {
const clientId = "clientId"; const clientId = "clientId";
const clientSecret = "clientSecret"; const clientSecret = "clientSecret";
(tokenService as any)._setAccessToken = jest.fn();
// any hack allows for mocking private method. // any hack allows for mocking private method.
(tokenService as any).setRefreshToken = jest.fn(); (tokenService as any)._setAccessToken = jest.fn().mockReturnValue(accessTokenJwt);
tokenService.setClientId = jest.fn(); (tokenService as any).setRefreshToken = jest.fn().mockReturnValue(refreshToken);
tokenService.setClientSecret = jest.fn(); tokenService.setClientId = jest.fn().mockReturnValue(clientId);
tokenService.setClientSecret = jest.fn().mockReturnValue(clientSecret);
// Act // Act
// Note: passing a valid access token so that a valid user id can be determined from the access token // Note: passing a valid access token so that a valid user id can be determined from the access token
await tokenService.setTokens(accessTokenJwt, vaultTimeoutAction, vaultTimeout, refreshToken, [ const result = await tokenService.setTokens(
clientId, accessTokenJwt,
clientSecret, vaultTimeoutAction,
]); vaultTimeout,
refreshToken,
[clientId, clientSecret],
);
// Assert // Assert
expect((tokenService as any)._setAccessToken).toHaveBeenCalledWith( expect((tokenService as any)._setAccessToken).toHaveBeenCalledWith(
@ -2417,6 +2432,44 @@ describe("TokenService", () => {
vaultTimeout, vaultTimeout,
userIdFromAccessToken, userIdFromAccessToken,
); );
expect(result).toStrictEqual(
new SetTokensResult(accessTokenJwt, refreshToken, [clientId, clientSecret]),
);
});
it("does not try to set the refresh token when it is not passed in", async () => {
// Arrange
const vaultTimeoutAction = VaultTimeoutAction.Lock;
const vaultTimeout = 30;
(tokenService as any)._setAccessToken = jest.fn().mockReturnValue(accessTokenJwt);
(tokenService as any).setRefreshToken = jest.fn();
tokenService.setClientId = jest.fn();
tokenService.setClientSecret = jest.fn();
// Act
const result = await tokenService.setTokens(
accessTokenJwt,
vaultTimeoutAction,
vaultTimeout,
null,
);
// Assert
expect((tokenService as any)._setAccessToken).toHaveBeenCalledWith(
accessTokenJwt,
vaultTimeoutAction,
vaultTimeout,
userIdFromAccessToken,
);
// any hack allows for testing private methods
expect((tokenService as any).setRefreshToken).not.toHaveBeenCalled();
expect(tokenService.setClientId).not.toHaveBeenCalled();
expect(tokenService.setClientSecret).not.toHaveBeenCalled();
expect(result).toStrictEqual(new SetTokensResult(accessTokenJwt));
}); });
it("does not try to set client id and client secret when they are not passed in", async () => { it("does not try to set client id and client secret when they are not passed in", async () => {
@ -2425,13 +2478,18 @@ describe("TokenService", () => {
const vaultTimeoutAction = VaultTimeoutAction.Lock; const vaultTimeoutAction = VaultTimeoutAction.Lock;
const vaultTimeout = 30; const vaultTimeout = 30;
(tokenService as any)._setAccessToken = jest.fn(); (tokenService as any)._setAccessToken = jest.fn().mockReturnValue(accessTokenJwt);
(tokenService as any).setRefreshToken = jest.fn(); (tokenService as any).setRefreshToken = jest.fn().mockReturnValue(refreshToken);
tokenService.setClientId = jest.fn(); tokenService.setClientId = jest.fn();
tokenService.setClientSecret = jest.fn(); tokenService.setClientSecret = jest.fn();
// Act // Act
await tokenService.setTokens(accessTokenJwt, vaultTimeoutAction, vaultTimeout, refreshToken); const result = await tokenService.setTokens(
accessTokenJwt,
vaultTimeoutAction,
vaultTimeout,
refreshToken,
);
// Assert // Assert
expect((tokenService as any)._setAccessToken).toHaveBeenCalledWith( expect((tokenService as any)._setAccessToken).toHaveBeenCalledWith(
@ -2451,6 +2509,8 @@ describe("TokenService", () => {
expect(tokenService.setClientId).not.toHaveBeenCalled(); expect(tokenService.setClientId).not.toHaveBeenCalled();
expect(tokenService.setClientSecret).not.toHaveBeenCalled(); expect(tokenService.setClientSecret).not.toHaveBeenCalled();
expect(result).toStrictEqual(new SetTokensResult(accessTokenJwt, refreshToken));
}); });
it("throws an error when the access token is invalid", async () => { it("throws an error when the access token is invalid", async () => {
@ -2535,10 +2595,16 @@ describe("TokenService", () => {
(tokenService as any).setRefreshToken = jest.fn(); (tokenService as any).setRefreshToken = jest.fn();
// Act // Act
await tokenService.setTokens(accessTokenJwt, vaultTimeoutAction, vaultTimeout, refreshToken); const result = await tokenService.setTokens(
accessTokenJwt,
vaultTimeoutAction,
vaultTimeout,
refreshToken,
);
// Assert // Assert
expect((tokenService as any).setRefreshToken).not.toHaveBeenCalled(); expect((tokenService as any).setRefreshToken).not.toHaveBeenCalled();
expect(result).toStrictEqual(new SetTokensResult(accessTokenJwt));
}); });
}); });

View File

@ -21,6 +21,7 @@ import {
import { UserId } from "../../types/guid"; import { UserId } from "../../types/guid";
import { VaultTimeout, VaultTimeoutStringType } from "../../types/vault-timeout.type"; import { VaultTimeout, VaultTimeoutStringType } from "../../types/vault-timeout.type";
import { TokenService as TokenServiceAbstraction } from "../abstractions/token.service"; import { TokenService as TokenServiceAbstraction } from "../abstractions/token.service";
import { SetTokensResult } from "../models/domain/set-tokens-result";
import { ACCOUNT_ACTIVE_ACCOUNT_ID } from "./account.service"; import { ACCOUNT_ACTIVE_ACCOUNT_ID } from "./account.service";
import { import {
@ -160,7 +161,7 @@ export class TokenService implements TokenServiceAbstraction {
vaultTimeout: VaultTimeout, vaultTimeout: VaultTimeout,
refreshToken?: string, refreshToken?: string,
clientIdClientSecret?: [string, string], clientIdClientSecret?: [string, string],
): Promise<void> { ): Promise<SetTokensResult> {
if (!accessToken) { if (!accessToken) {
throw new Error("Access token is required."); throw new Error("Access token is required.");
} }
@ -181,16 +182,40 @@ export class TokenService implements TokenServiceAbstraction {
throw new Error("User id not found. Cannot set tokens."); throw new Error("User id not found. Cannot set tokens.");
} }
await this._setAccessToken(accessToken, vaultTimeoutAction, vaultTimeout, userId); const newAccessToken = await this._setAccessToken(
accessToken,
vaultTimeoutAction,
vaultTimeout,
userId,
);
const newTokens = new SetTokensResult(newAccessToken);
if (refreshToken) { if (refreshToken) {
await this.setRefreshToken(refreshToken, vaultTimeoutAction, vaultTimeout, userId); newTokens.refreshToken = await this.setRefreshToken(
refreshToken,
vaultTimeoutAction,
vaultTimeout,
userId,
);
} }
if (clientIdClientSecret != null) { if (clientIdClientSecret != null) {
await this.setClientId(clientIdClientSecret[0], vaultTimeoutAction, vaultTimeout, userId); const clientId = await this.setClientId(
await this.setClientSecret(clientIdClientSecret[1], vaultTimeoutAction, vaultTimeout, userId); clientIdClientSecret[0],
vaultTimeoutAction,
vaultTimeout,
userId,
);
const clientSecret = await this.setClientSecret(
clientIdClientSecret[1],
vaultTimeoutAction,
vaultTimeout,
userId,
);
newTokens.clientIdSecretPair = [clientId, clientSecret];
} }
return newTokens;
} }
private async getAccessTokenKey(userId: UserId): Promise<AccessTokenKey | null> { private async getAccessTokenKey(userId: UserId): Promise<AccessTokenKey | null> {
@ -289,7 +314,7 @@ export class TokenService implements TokenServiceAbstraction {
vaultTimeoutAction: VaultTimeoutAction, vaultTimeoutAction: VaultTimeoutAction,
vaultTimeout: VaultTimeout, vaultTimeout: VaultTimeout,
userId: UserId, userId: UserId,
): Promise<void> { ): Promise<string> {
const storageLocation = await this.determineStorageLocation( const storageLocation = await this.determineStorageLocation(
vaultTimeoutAction, vaultTimeoutAction,
vaultTimeout, vaultTimeout,
@ -302,6 +327,8 @@ export class TokenService implements TokenServiceAbstraction {
// store the access token directly. Instead, we encrypt with accessTokenKey and store that // store the access token directly. Instead, we encrypt with accessTokenKey and store that
// in secure storage. // in secure storage.
let decryptedAccessToken: string = null;
try { try {
const encryptedAccessToken: EncString = await this.encryptAccessToken( const encryptedAccessToken: EncString = await this.encryptAccessToken(
accessToken, accessToken,
@ -313,6 +340,10 @@ export class TokenService implements TokenServiceAbstraction {
.get(userId, ACCESS_TOKEN_DISK) .get(userId, ACCESS_TOKEN_DISK)
.update((_) => encryptedAccessToken.encryptedString); .update((_) => encryptedAccessToken.encryptedString);
// If we've successfully stored the encrypted access token to disk, we can return the decrypted access token
// so that the caller can use it immediately.
decryptedAccessToken = accessToken;
// TODO: PM-6408 // TODO: PM-6408
// 2024-02-20: Remove access token from memory so that we migrate to encrypt the access token over time. // 2024-02-20: Remove access token from memory so that we migrate to encrypt the access token over time.
// Remove this call to remove the access token from memory after 3 months. // Remove this call to remove the access token from memory after 3 months.
@ -324,25 +355,23 @@ export class TokenService implements TokenServiceAbstraction {
); );
// Fall back to disk storage for unecrypted access token // Fall back to disk storage for unecrypted access token
await this.singleUserStateProvider decryptedAccessToken = await this.singleUserStateProvider
.get(userId, ACCESS_TOKEN_DISK) .get(userId, ACCESS_TOKEN_DISK)
.update((_) => accessToken); .update((_) => accessToken);
} }
return; return decryptedAccessToken;
} }
case TokenStorageLocation.Disk: case TokenStorageLocation.Disk:
// Access token stored on disk unencrypted as platform does not support secure storage // Access token stored on disk unencrypted as platform does not support secure storage
await this.singleUserStateProvider return await this.singleUserStateProvider
.get(userId, ACCESS_TOKEN_DISK) .get(userId, ACCESS_TOKEN_DISK)
.update((_) => accessToken); .update((_) => accessToken);
return;
case TokenStorageLocation.Memory: case TokenStorageLocation.Memory:
// Access token stored in memory due to vault timeout settings // Access token stored in memory due to vault timeout settings
await this.singleUserStateProvider return await this.singleUserStateProvider
.get(userId, ACCESS_TOKEN_MEMORY) .get(userId, ACCESS_TOKEN_MEMORY)
.update((_) => accessToken); .update((_) => accessToken);
return;
} }
} }
@ -350,7 +379,7 @@ export class TokenService implements TokenServiceAbstraction {
accessToken: string, accessToken: string,
vaultTimeoutAction: VaultTimeoutAction, vaultTimeoutAction: VaultTimeoutAction,
vaultTimeout: VaultTimeout, vaultTimeout: VaultTimeout,
): Promise<void> { ): Promise<string> {
if (!accessToken) { if (!accessToken) {
throw new Error("Access token is required."); throw new Error("Access token is required.");
} }
@ -370,7 +399,7 @@ export class TokenService implements TokenServiceAbstraction {
throw new Error("Vault Timeout Action is required."); throw new Error("Vault Timeout Action is required.");
} }
await this._setAccessToken(accessToken, vaultTimeoutAction, vaultTimeout, userId); return await this._setAccessToken(accessToken, vaultTimeoutAction, vaultTimeout, userId);
} }
async clearAccessToken(userId?: UserId): Promise<void> { async clearAccessToken(userId?: UserId): Promise<void> {
@ -486,7 +515,7 @@ export class TokenService implements TokenServiceAbstraction {
vaultTimeoutAction: VaultTimeoutAction, vaultTimeoutAction: VaultTimeoutAction,
vaultTimeout: VaultTimeout, vaultTimeout: VaultTimeout,
userId: UserId, userId: UserId,
): Promise<void> { ): Promise<string> {
// If we don't have a user id, we can't save the value // If we don't have a user id, we can't save the value
if (!userId) { if (!userId) {
throw new Error("User id not found. Cannot save refresh token."); throw new Error("User id not found. Cannot save refresh token.");
@ -509,6 +538,8 @@ export class TokenService implements TokenServiceAbstraction {
switch (storageLocation) { switch (storageLocation) {
case TokenStorageLocation.SecureStorage: { case TokenStorageLocation.SecureStorage: {
let decryptedRefreshToken: string = null;
try { try {
await this.saveStringToSecureStorage( await this.saveStringToSecureStorage(
userId, userId,
@ -530,6 +561,10 @@ export class TokenService implements TokenServiceAbstraction {
throw new Error("Refresh token failed to save to secure storage."); throw new Error("Refresh token failed to save to secure storage.");
} }
// If we've successfully stored the encrypted refresh token, we can return the decrypted refresh token
// so that the caller can use it immediately.
decryptedRefreshToken = refreshToken;
// TODO: PM-6408 // TODO: PM-6408
// 2024-02-20: Remove refresh token from memory and disk so that we migrate to secure storage over time. // 2024-02-20: Remove refresh token from memory and disk so that we migrate to secure storage over time.
// Remove these 2 calls to remove the refresh token from memory and disk after 3 months. // Remove these 2 calls to remove the refresh token from memory and disk after 3 months.
@ -544,24 +579,22 @@ export class TokenService implements TokenServiceAbstraction {
); );
// Fall back to disk storage for refresh token // Fall back to disk storage for refresh token
await this.singleUserStateProvider decryptedRefreshToken = await this.singleUserStateProvider
.get(userId, REFRESH_TOKEN_DISK) .get(userId, REFRESH_TOKEN_DISK)
.update((_) => refreshToken); .update((_) => refreshToken);
} }
return; return decryptedRefreshToken;
} }
case TokenStorageLocation.Disk: case TokenStorageLocation.Disk:
await this.singleUserStateProvider return await this.singleUserStateProvider
.get(userId, REFRESH_TOKEN_DISK) .get(userId, REFRESH_TOKEN_DISK)
.update((_) => refreshToken); .update((_) => refreshToken);
return;
case TokenStorageLocation.Memory: case TokenStorageLocation.Memory:
await this.singleUserStateProvider return await this.singleUserStateProvider
.get(userId, REFRESH_TOKEN_MEMORY) .get(userId, REFRESH_TOKEN_MEMORY)
.update((_) => refreshToken); .update((_) => refreshToken);
return;
} }
} }
@ -644,7 +677,7 @@ export class TokenService implements TokenServiceAbstraction {
vaultTimeoutAction: VaultTimeoutAction, vaultTimeoutAction: VaultTimeoutAction,
vaultTimeout: VaultTimeout, vaultTimeout: VaultTimeout,
userId?: UserId, userId?: UserId,
): Promise<void> { ): Promise<string> {
userId ??= await firstValueFrom(this.activeUserIdGlobalState.state$); userId ??= await firstValueFrom(this.activeUserIdGlobalState.state$);
// If we don't have a user id, we can't save the value // If we don't have a user id, we can't save the value
@ -668,11 +701,11 @@ export class TokenService implements TokenServiceAbstraction {
); );
if (storageLocation === TokenStorageLocation.Disk) { if (storageLocation === TokenStorageLocation.Disk) {
await this.singleUserStateProvider return await this.singleUserStateProvider
.get(userId, API_KEY_CLIENT_ID_DISK) .get(userId, API_KEY_CLIENT_ID_DISK)
.update((_) => clientId); .update((_) => clientId);
} else if (storageLocation === TokenStorageLocation.Memory) { } else if (storageLocation === TokenStorageLocation.Memory) {
await this.singleUserStateProvider return await this.singleUserStateProvider
.get(userId, API_KEY_CLIENT_ID_MEMORY) .get(userId, API_KEY_CLIENT_ID_MEMORY)
.update((_) => clientId); .update((_) => clientId);
} }
@ -721,7 +754,7 @@ export class TokenService implements TokenServiceAbstraction {
vaultTimeoutAction: VaultTimeoutAction, vaultTimeoutAction: VaultTimeoutAction,
vaultTimeout: VaultTimeout, vaultTimeout: VaultTimeout,
userId?: UserId, userId?: UserId,
): Promise<void> { ): Promise<string> {
userId ??= await firstValueFrom(this.activeUserIdGlobalState.state$); userId ??= await firstValueFrom(this.activeUserIdGlobalState.state$);
if (!userId) { if (!userId) {
@ -744,11 +777,11 @@ export class TokenService implements TokenServiceAbstraction {
); );
if (storageLocation === TokenStorageLocation.Disk) { if (storageLocation === TokenStorageLocation.Disk) {
await this.singleUserStateProvider return await this.singleUserStateProvider
.get(userId, API_KEY_CLIENT_SECRET_DISK) .get(userId, API_KEY_CLIENT_SECRET_DISK)
.update((_) => clientSecret); .update((_) => clientSecret);
} else if (storageLocation === TokenStorageLocation.Memory) { } else if (storageLocation === TokenStorageLocation.Memory) {
await this.singleUserStateProvider return await this.singleUserStateProvider
.get(userId, API_KEY_CLIENT_SECRET_MEMORY) .get(userId, API_KEY_CLIENT_SECRET_MEMORY)
.update((_) => clientSecret); .update((_) => clientSecret);
} }

View File

@ -249,7 +249,7 @@ export class ApiService implements ApiServiceAbstraction {
async refreshIdentityToken(): Promise<any> { async refreshIdentityToken(): Promise<any> {
try { try {
await this.doAuthRefresh(); await this.refreshToken();
} catch (e) { } catch (e) {
this.logService.error("Error refreshing access token: ", e); this.logService.error("Error refreshing access token: ", e);
throw e; throw e;
@ -1566,8 +1566,7 @@ export class ApiService implements ApiServiceAbstraction {
async getActiveBearerToken(): Promise<string> { async getActiveBearerToken(): Promise<string> {
let accessToken = await this.tokenService.getAccessToken(); let accessToken = await this.tokenService.getAccessToken();
if (await this.tokenService.tokenNeedsRefresh()) { if (await this.tokenService.tokenNeedsRefresh()) {
await this.doAuthRefresh(); accessToken = await this.refreshToken();
accessToken = await this.tokenService.getAccessToken();
} }
return accessToken; return accessToken;
} }
@ -1707,16 +1706,16 @@ export class ApiService implements ApiServiceAbstraction {
); );
} }
protected async doAuthRefresh(): Promise<void> { protected async refreshToken(): Promise<string> {
const refreshToken = await this.tokenService.getRefreshToken(); const refreshToken = await this.tokenService.getRefreshToken();
if (refreshToken != null && refreshToken !== "") { if (refreshToken != null && refreshToken !== "") {
return this.doRefreshToken(); return this.refreshAccessToken();
} }
const clientId = await this.tokenService.getClientId(); const clientId = await this.tokenService.getClientId();
const clientSecret = await this.tokenService.getClientSecret(); const clientSecret = await this.tokenService.getClientSecret();
if (!Utils.isNullOrWhitespace(clientId) && !Utils.isNullOrWhitespace(clientSecret)) { if (!Utils.isNullOrWhitespace(clientId) && !Utils.isNullOrWhitespace(clientSecret)) {
return this.doApiTokenRefresh(); return this.refreshApiToken();
} }
this.refreshAccessTokenErrorCallback(); this.refreshAccessTokenErrorCallback();
@ -1724,7 +1723,7 @@ export class ApiService implements ApiServiceAbstraction {
throw new Error("Cannot refresh access token, no refresh token or api keys are stored."); throw new Error("Cannot refresh access token, no refresh token or api keys are stored.");
} }
protected async doRefreshToken(): Promise<void> { protected async refreshAccessToken(): Promise<string> {
const refreshToken = await this.tokenService.getRefreshToken(); const refreshToken = await this.tokenService.getRefreshToken();
if (refreshToken == null || refreshToken === "") { if (refreshToken == null || refreshToken === "") {
throw new Error(); throw new Error();
@ -1770,19 +1769,20 @@ export class ApiService implements ApiServiceAbstraction {
this.vaultTimeoutSettingsService.getVaultTimeoutByUserId$(userId), this.vaultTimeoutSettingsService.getVaultTimeoutByUserId$(userId),
); );
await this.tokenService.setTokens( const refreshedTokens = await this.tokenService.setTokens(
tokenResponse.accessToken, tokenResponse.accessToken,
vaultTimeoutAction as VaultTimeoutAction, vaultTimeoutAction as VaultTimeoutAction,
vaultTimeout, vaultTimeout,
tokenResponse.refreshToken, tokenResponse.refreshToken,
); );
return refreshedTokens.accessToken;
} else { } else {
const error = await this.handleError(response, true, true); const error = await this.handleError(response, true, true);
return Promise.reject(error); return Promise.reject(error);
} }
} }
protected async doApiTokenRefresh(): Promise<void> { protected async refreshApiToken(): Promise<string> {
const clientId = await this.tokenService.getClientId(); const clientId = await this.tokenService.getClientId();
const clientSecret = await this.tokenService.getClientSecret(); const clientSecret = await this.tokenService.getClientSecret();
@ -1810,11 +1810,12 @@ export class ApiService implements ApiServiceAbstraction {
this.vaultTimeoutSettingsService.getVaultTimeoutByUserId$(userId), this.vaultTimeoutSettingsService.getVaultTimeoutByUserId$(userId),
); );
await this.tokenService.setAccessToken( const refreshedToken = await this.tokenService.setAccessToken(
response.accessToken, response.accessToken,
vaultTimeoutAction as VaultTimeoutAction, vaultTimeoutAction as VaultTimeoutAction,
vaultTimeout, vaultTimeout,
); );
return refreshedToken;
} }
async send( async send(