Auth/pm 7672/Update token service to return new token from state (#9706)

* Changed return structure

* Object changes

* Added missing assert.

* Updated tests to use SetTokensResult

* Fixed constructor

* PM-7672 - Fix tests + add new setTokens test around refresh token

* Removed change to refreshIdentityToken.

* Updated return definition.

---------

Co-authored-by: Jared Snider <jsnider@bitwarden.com>
This commit is contained in:
Todd Martin 2024-06-19 11:51:12 -04:00 committed by GitHub
parent 7e3ba087ec
commit 88cc37e37f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 173 additions and 62 deletions

View File

@ -3,6 +3,7 @@ import { Observable } from "rxjs";
import { VaultTimeoutAction } from "../../enums/vault-timeout-action.enum";
import { UserId } from "../../types/guid";
import { VaultTimeout } from "../../types/vault-timeout.type";
import { SetTokensResult } from "../models/domain/set-tokens-result";
import { DecodedAccessToken } from "../services/token.service";
export abstract class TokenService {
@ -23,7 +24,7 @@ export abstract class TokenService {
* @param refreshToken The optional refresh token to set. Note: this is undefined when using the CLI Login Via API Key flow
* @param clientIdClientSecret The API Key Client ID and Client Secret to set.
*
* @returns A promise that resolves when the tokens have been set.
* @returns A promise that resolves with the SetTokensResult containing the tokens that were set.
*/
setTokens: (
accessToken: string,
@ -31,7 +32,7 @@ export abstract class TokenService {
vaultTimeout: VaultTimeout,
refreshToken?: string,
clientIdClientSecret?: [string, string],
) => Promise<void>;
) => Promise<SetTokensResult>;
/**
* Clears the access token, refresh token, API Key Client ID, and API Key Client Secret out of memory, disk, and secure storage if supported.
@ -47,13 +48,13 @@ export abstract class TokenService {
* @param accessToken The access token to set.
* @param vaultTimeoutAction The action to take when the vault times out.
* @param vaultTimeout The timeout for the vault.
* @returns A promise that resolves when the access token has been set.
* @returns A promise that resolves with the access token that has been set.
*/
setAccessToken: (
accessToken: string,
vaultTimeoutAction: VaultTimeoutAction,
vaultTimeout: VaultTimeout,
) => Promise<void>;
) => Promise<string>;
// TODO: revisit having this public clear method approach once the state service is fully deprecated.
/**
@ -86,14 +87,14 @@ export abstract class TokenService {
* @param clientId The API Key Client ID to set.
* @param vaultTimeoutAction The action to take when the vault times out.
* @param vaultTimeout The timeout for the vault.
* @returns A promise that resolves when the API Key Client ID has been set.
* @returns A promise that resolves with the API Key Client ID that has been set.
*/
setClientId: (
clientId: string,
vaultTimeoutAction: VaultTimeoutAction,
vaultTimeout: VaultTimeout,
userId?: UserId,
) => Promise<void>;
) => Promise<string>;
/**
* Gets the API Key Client ID for the active user.
@ -106,14 +107,14 @@ export abstract class TokenService {
* @param clientSecret The API Key Client Secret to set.
* @param vaultTimeoutAction The action to take when the vault times out.
* @param vaultTimeout The timeout for the vault.
* @returns A promise that resolves when the API Key Client Secret has been set.
* @returns A promise that resolves with the client secret that has been set.
*/
setClientSecret: (
clientSecret: string,
vaultTimeoutAction: VaultTimeoutAction,
vaultTimeout: VaultTimeout,
userId?: UserId,
) => Promise<void>;
) => Promise<string>;
/**
* Gets the API Key Client Secret for the active user.

View File

@ -0,0 +1,10 @@
export class SetTokensResult {
constructor(accessToken: string, refreshToken?: string, clientIdSecretPair?: [string, string]) {
this.accessToken = accessToken;
this.refreshToken = refreshToken;
this.clientIdSecretPair = clientIdSecretPair;
}
accessToken: string;
refreshToken?: string;
clientIdSecretPair?: [string, string];
}

View File

@ -15,6 +15,7 @@ import { SymmetricCryptoKey } from "../../platform/models/domain/symmetric-crypt
import { CsprngArray } from "../../types/csprng";
import { UserId } from "../../types/guid";
import { VaultTimeout, VaultTimeoutStringType } from "../../types/vault-timeout.type";
import { SetTokensResult } from "../models/domain/set-tokens-result";
import { ACCOUNT_ACTIVE_ACCOUNT_ID } from "./account.service";
import {
@ -232,7 +233,7 @@ describe("TokenService", () => {
describe("Memory storage tests", () => {
it("set the access token in memory", async () => {
// Act
await tokenService.setAccessToken(
const result = await tokenService.setAccessToken(
accessTokenJwt,
memoryVaultTimeoutAction,
memoryVaultTimeout,
@ -241,13 +242,14 @@ describe("TokenService", () => {
expect(
singleUserStateProvider.getFake(userIdFromAccessToken, ACCESS_TOKEN_MEMORY).nextMock,
).toHaveBeenCalledWith(accessTokenJwt);
expect(result).toEqual(accessTokenJwt);
});
});
describe("Disk storage tests (secure storage not supported on platform)", () => {
it("should set the access token in disk", async () => {
// Act
await tokenService.setAccessToken(
const result = await tokenService.setAccessToken(
accessTokenJwt,
diskVaultTimeoutAction,
diskVaultTimeout,
@ -256,6 +258,7 @@ describe("TokenService", () => {
expect(
singleUserStateProvider.getFake(userIdFromAccessToken, ACCESS_TOKEN_DISK).nextMock,
).toHaveBeenCalledWith(accessTokenJwt);
expect(result).toEqual(accessTokenJwt);
});
});
@ -295,7 +298,7 @@ describe("TokenService", () => {
secureStorageService.get.mockResolvedValueOnce(null).mockResolvedValue(accessTokenKeyB64);
// Act
await tokenService.setAccessToken(
const result = await tokenService.setAccessToken(
accessTokenJwt,
diskVaultTimeoutAction,
diskVaultTimeout,
@ -318,6 +321,9 @@ describe("TokenService", () => {
expect(
singleUserStateProvider.getFake(userIdFromAccessToken, ACCESS_TOKEN_MEMORY).nextMock,
).toHaveBeenCalledWith(null);
// assert that the decrypted access token was returned
expect(result).toEqual(accessTokenJwt);
});
it("should fallback to disk storage for the access token if the access token cannot be set in secure storage", async () => {
@ -331,7 +337,7 @@ describe("TokenService", () => {
secureStorageService.get.mockResolvedValueOnce(null).mockResolvedValue(null);
// Act
await tokenService.setAccessToken(
const result = await tokenService.setAccessToken(
accessTokenJwt,
diskVaultTimeoutAction,
diskVaultTimeout,
@ -355,6 +361,9 @@ describe("TokenService", () => {
expect(
singleUserStateProvider.getFake(userIdFromAccessToken, ACCESS_TOKEN_DISK).nextMock,
).toHaveBeenCalledWith(accessTokenJwt);
// assert that the decrypted access token was returned
expect(result).toEqual(accessTokenJwt);
});
it("should fallback to disk storage for the access token if secure storage errors on trying to get an existing access token key", async () => {
@ -368,7 +377,7 @@ describe("TokenService", () => {
secureStorageService.get.mockRejectedValue(new Error(secureStorageError));
// Act
await tokenService.setAccessToken(
const result = await tokenService.setAccessToken(
accessTokenJwt,
diskVaultTimeoutAction,
diskVaultTimeout,
@ -385,6 +394,9 @@ describe("TokenService", () => {
expect(
singleUserStateProvider.getFake(userIdFromAccessToken, ACCESS_TOKEN_DISK).nextMock,
).toHaveBeenCalledWith(accessTokenJwt);
// assert that the decrypted access token was returned
expect(result).toEqual(accessTokenJwt);
});
});
});
@ -2376,18 +2388,21 @@ describe("TokenService", () => {
const clientId = "clientId";
const clientSecret = "clientSecret";
(tokenService as any)._setAccessToken = jest.fn();
// any hack allows for mocking private method.
(tokenService as any).setRefreshToken = jest.fn();
tokenService.setClientId = jest.fn();
tokenService.setClientSecret = jest.fn();
(tokenService as any)._setAccessToken = jest.fn().mockReturnValue(accessTokenJwt);
(tokenService as any).setRefreshToken = jest.fn().mockReturnValue(refreshToken);
tokenService.setClientId = jest.fn().mockReturnValue(clientId);
tokenService.setClientSecret = jest.fn().mockReturnValue(clientSecret);
// Act
// Note: passing a valid access token so that a valid user id can be determined from the access token
await tokenService.setTokens(accessTokenJwt, vaultTimeoutAction, vaultTimeout, refreshToken, [
clientId,
clientSecret,
]);
const result = await tokenService.setTokens(
accessTokenJwt,
vaultTimeoutAction,
vaultTimeout,
refreshToken,
[clientId, clientSecret],
);
// Assert
expect((tokenService as any)._setAccessToken).toHaveBeenCalledWith(
@ -2417,6 +2432,44 @@ describe("TokenService", () => {
vaultTimeout,
userIdFromAccessToken,
);
expect(result).toStrictEqual(
new SetTokensResult(accessTokenJwt, refreshToken, [clientId, clientSecret]),
);
});
it("does not try to set the refresh token when it is not passed in", async () => {
// Arrange
const vaultTimeoutAction = VaultTimeoutAction.Lock;
const vaultTimeout = 30;
(tokenService as any)._setAccessToken = jest.fn().mockReturnValue(accessTokenJwt);
(tokenService as any).setRefreshToken = jest.fn();
tokenService.setClientId = jest.fn();
tokenService.setClientSecret = jest.fn();
// Act
const result = await tokenService.setTokens(
accessTokenJwt,
vaultTimeoutAction,
vaultTimeout,
null,
);
// Assert
expect((tokenService as any)._setAccessToken).toHaveBeenCalledWith(
accessTokenJwt,
vaultTimeoutAction,
vaultTimeout,
userIdFromAccessToken,
);
// any hack allows for testing private methods
expect((tokenService as any).setRefreshToken).not.toHaveBeenCalled();
expect(tokenService.setClientId).not.toHaveBeenCalled();
expect(tokenService.setClientSecret).not.toHaveBeenCalled();
expect(result).toStrictEqual(new SetTokensResult(accessTokenJwt));
});
it("does not try to set client id and client secret when they are not passed in", async () => {
@ -2425,13 +2478,18 @@ describe("TokenService", () => {
const vaultTimeoutAction = VaultTimeoutAction.Lock;
const vaultTimeout = 30;
(tokenService as any)._setAccessToken = jest.fn();
(tokenService as any).setRefreshToken = jest.fn();
(tokenService as any)._setAccessToken = jest.fn().mockReturnValue(accessTokenJwt);
(tokenService as any).setRefreshToken = jest.fn().mockReturnValue(refreshToken);
tokenService.setClientId = jest.fn();
tokenService.setClientSecret = jest.fn();
// Act
await tokenService.setTokens(accessTokenJwt, vaultTimeoutAction, vaultTimeout, refreshToken);
const result = await tokenService.setTokens(
accessTokenJwt,
vaultTimeoutAction,
vaultTimeout,
refreshToken,
);
// Assert
expect((tokenService as any)._setAccessToken).toHaveBeenCalledWith(
@ -2451,6 +2509,8 @@ describe("TokenService", () => {
expect(tokenService.setClientId).not.toHaveBeenCalled();
expect(tokenService.setClientSecret).not.toHaveBeenCalled();
expect(result).toStrictEqual(new SetTokensResult(accessTokenJwt, refreshToken));
});
it("throws an error when the access token is invalid", async () => {
@ -2535,10 +2595,16 @@ describe("TokenService", () => {
(tokenService as any).setRefreshToken = jest.fn();
// Act
await tokenService.setTokens(accessTokenJwt, vaultTimeoutAction, vaultTimeout, refreshToken);
const result = await tokenService.setTokens(
accessTokenJwt,
vaultTimeoutAction,
vaultTimeout,
refreshToken,
);
// Assert
expect((tokenService as any).setRefreshToken).not.toHaveBeenCalled();
expect(result).toStrictEqual(new SetTokensResult(accessTokenJwt));
});
});

View File

@ -21,6 +21,7 @@ import {
import { UserId } from "../../types/guid";
import { VaultTimeout, VaultTimeoutStringType } from "../../types/vault-timeout.type";
import { TokenService as TokenServiceAbstraction } from "../abstractions/token.service";
import { SetTokensResult } from "../models/domain/set-tokens-result";
import { ACCOUNT_ACTIVE_ACCOUNT_ID } from "./account.service";
import {
@ -160,7 +161,7 @@ export class TokenService implements TokenServiceAbstraction {
vaultTimeout: VaultTimeout,
refreshToken?: string,
clientIdClientSecret?: [string, string],
): Promise<void> {
): Promise<SetTokensResult> {
if (!accessToken) {
throw new Error("Access token is required.");
}
@ -181,16 +182,40 @@ export class TokenService implements TokenServiceAbstraction {
throw new Error("User id not found. Cannot set tokens.");
}
await this._setAccessToken(accessToken, vaultTimeoutAction, vaultTimeout, userId);
const newAccessToken = await this._setAccessToken(
accessToken,
vaultTimeoutAction,
vaultTimeout,
userId,
);
const newTokens = new SetTokensResult(newAccessToken);
if (refreshToken) {
await this.setRefreshToken(refreshToken, vaultTimeoutAction, vaultTimeout, userId);
newTokens.refreshToken = await this.setRefreshToken(
refreshToken,
vaultTimeoutAction,
vaultTimeout,
userId,
);
}
if (clientIdClientSecret != null) {
await this.setClientId(clientIdClientSecret[0], vaultTimeoutAction, vaultTimeout, userId);
await this.setClientSecret(clientIdClientSecret[1], vaultTimeoutAction, vaultTimeout, userId);
const clientId = await this.setClientId(
clientIdClientSecret[0],
vaultTimeoutAction,
vaultTimeout,
userId,
);
const clientSecret = await this.setClientSecret(
clientIdClientSecret[1],
vaultTimeoutAction,
vaultTimeout,
userId,
);
newTokens.clientIdSecretPair = [clientId, clientSecret];
}
return newTokens;
}
private async getAccessTokenKey(userId: UserId): Promise<AccessTokenKey | null> {
@ -289,7 +314,7 @@ export class TokenService implements TokenServiceAbstraction {
vaultTimeoutAction: VaultTimeoutAction,
vaultTimeout: VaultTimeout,
userId: UserId,
): Promise<void> {
): Promise<string> {
const storageLocation = await this.determineStorageLocation(
vaultTimeoutAction,
vaultTimeout,
@ -302,6 +327,8 @@ export class TokenService implements TokenServiceAbstraction {
// store the access token directly. Instead, we encrypt with accessTokenKey and store that
// in secure storage.
let decryptedAccessToken: string = null;
try {
const encryptedAccessToken: EncString = await this.encryptAccessToken(
accessToken,
@ -313,6 +340,10 @@ export class TokenService implements TokenServiceAbstraction {
.get(userId, ACCESS_TOKEN_DISK)
.update((_) => encryptedAccessToken.encryptedString);
// If we've successfully stored the encrypted access token to disk, we can return the decrypted access token
// so that the caller can use it immediately.
decryptedAccessToken = accessToken;
// TODO: PM-6408
// 2024-02-20: Remove access token from memory so that we migrate to encrypt the access token over time.
// Remove this call to remove the access token from memory after 3 months.
@ -324,25 +355,23 @@ export class TokenService implements TokenServiceAbstraction {
);
// Fall back to disk storage for unecrypted access token
await this.singleUserStateProvider
decryptedAccessToken = await this.singleUserStateProvider
.get(userId, ACCESS_TOKEN_DISK)
.update((_) => accessToken);
}
return;
return decryptedAccessToken;
}
case TokenStorageLocation.Disk:
// Access token stored on disk unencrypted as platform does not support secure storage
await this.singleUserStateProvider
return await this.singleUserStateProvider
.get(userId, ACCESS_TOKEN_DISK)
.update((_) => accessToken);
return;
case TokenStorageLocation.Memory:
// Access token stored in memory due to vault timeout settings
await this.singleUserStateProvider
return await this.singleUserStateProvider
.get(userId, ACCESS_TOKEN_MEMORY)
.update((_) => accessToken);
return;
}
}
@ -350,7 +379,7 @@ export class TokenService implements TokenServiceAbstraction {
accessToken: string,
vaultTimeoutAction: VaultTimeoutAction,
vaultTimeout: VaultTimeout,
): Promise<void> {
): Promise<string> {
if (!accessToken) {
throw new Error("Access token is required.");
}
@ -370,7 +399,7 @@ export class TokenService implements TokenServiceAbstraction {
throw new Error("Vault Timeout Action is required.");
}
await this._setAccessToken(accessToken, vaultTimeoutAction, vaultTimeout, userId);
return await this._setAccessToken(accessToken, vaultTimeoutAction, vaultTimeout, userId);
}
async clearAccessToken(userId?: UserId): Promise<void> {
@ -486,7 +515,7 @@ export class TokenService implements TokenServiceAbstraction {
vaultTimeoutAction: VaultTimeoutAction,
vaultTimeout: VaultTimeout,
userId: UserId,
): Promise<void> {
): Promise<string> {
// If we don't have a user id, we can't save the value
if (!userId) {
throw new Error("User id not found. Cannot save refresh token.");
@ -509,6 +538,8 @@ export class TokenService implements TokenServiceAbstraction {
switch (storageLocation) {
case TokenStorageLocation.SecureStorage: {
let decryptedRefreshToken: string = null;
try {
await this.saveStringToSecureStorage(
userId,
@ -530,6 +561,10 @@ export class TokenService implements TokenServiceAbstraction {
throw new Error("Refresh token failed to save to secure storage.");
}
// If we've successfully stored the encrypted refresh token, we can return the decrypted refresh token
// so that the caller can use it immediately.
decryptedRefreshToken = refreshToken;
// TODO: PM-6408
// 2024-02-20: Remove refresh token from memory and disk so that we migrate to secure storage over time.
// Remove these 2 calls to remove the refresh token from memory and disk after 3 months.
@ -544,24 +579,22 @@ export class TokenService implements TokenServiceAbstraction {
);
// Fall back to disk storage for refresh token
await this.singleUserStateProvider
decryptedRefreshToken = await this.singleUserStateProvider
.get(userId, REFRESH_TOKEN_DISK)
.update((_) => refreshToken);
}
return;
return decryptedRefreshToken;
}
case TokenStorageLocation.Disk:
await this.singleUserStateProvider
return await this.singleUserStateProvider
.get(userId, REFRESH_TOKEN_DISK)
.update((_) => refreshToken);
return;
case TokenStorageLocation.Memory:
await this.singleUserStateProvider
return await this.singleUserStateProvider
.get(userId, REFRESH_TOKEN_MEMORY)
.update((_) => refreshToken);
return;
}
}
@ -644,7 +677,7 @@ export class TokenService implements TokenServiceAbstraction {
vaultTimeoutAction: VaultTimeoutAction,
vaultTimeout: VaultTimeout,
userId?: UserId,
): Promise<void> {
): Promise<string> {
userId ??= await firstValueFrom(this.activeUserIdGlobalState.state$);
// If we don't have a user id, we can't save the value
@ -668,11 +701,11 @@ export class TokenService implements TokenServiceAbstraction {
);
if (storageLocation === TokenStorageLocation.Disk) {
await this.singleUserStateProvider
return await this.singleUserStateProvider
.get(userId, API_KEY_CLIENT_ID_DISK)
.update((_) => clientId);
} else if (storageLocation === TokenStorageLocation.Memory) {
await this.singleUserStateProvider
return await this.singleUserStateProvider
.get(userId, API_KEY_CLIENT_ID_MEMORY)
.update((_) => clientId);
}
@ -721,7 +754,7 @@ export class TokenService implements TokenServiceAbstraction {
vaultTimeoutAction: VaultTimeoutAction,
vaultTimeout: VaultTimeout,
userId?: UserId,
): Promise<void> {
): Promise<string> {
userId ??= await firstValueFrom(this.activeUserIdGlobalState.state$);
if (!userId) {
@ -744,11 +777,11 @@ export class TokenService implements TokenServiceAbstraction {
);
if (storageLocation === TokenStorageLocation.Disk) {
await this.singleUserStateProvider
return await this.singleUserStateProvider
.get(userId, API_KEY_CLIENT_SECRET_DISK)
.update((_) => clientSecret);
} else if (storageLocation === TokenStorageLocation.Memory) {
await this.singleUserStateProvider
return await this.singleUserStateProvider
.get(userId, API_KEY_CLIENT_SECRET_MEMORY)
.update((_) => clientSecret);
}

View File

@ -249,7 +249,7 @@ export class ApiService implements ApiServiceAbstraction {
async refreshIdentityToken(): Promise<any> {
try {
await this.doAuthRefresh();
await this.refreshToken();
} catch (e) {
this.logService.error("Error refreshing access token: ", e);
throw e;
@ -1566,8 +1566,7 @@ export class ApiService implements ApiServiceAbstraction {
async getActiveBearerToken(): Promise<string> {
let accessToken = await this.tokenService.getAccessToken();
if (await this.tokenService.tokenNeedsRefresh()) {
await this.doAuthRefresh();
accessToken = await this.tokenService.getAccessToken();
accessToken = await this.refreshToken();
}
return accessToken;
}
@ -1707,16 +1706,16 @@ export class ApiService implements ApiServiceAbstraction {
);
}
protected async doAuthRefresh(): Promise<void> {
protected async refreshToken(): Promise<string> {
const refreshToken = await this.tokenService.getRefreshToken();
if (refreshToken != null && refreshToken !== "") {
return this.doRefreshToken();
return this.refreshAccessToken();
}
const clientId = await this.tokenService.getClientId();
const clientSecret = await this.tokenService.getClientSecret();
if (!Utils.isNullOrWhitespace(clientId) && !Utils.isNullOrWhitespace(clientSecret)) {
return this.doApiTokenRefresh();
return this.refreshApiToken();
}
this.refreshAccessTokenErrorCallback();
@ -1724,7 +1723,7 @@ export class ApiService implements ApiServiceAbstraction {
throw new Error("Cannot refresh access token, no refresh token or api keys are stored.");
}
protected async doRefreshToken(): Promise<void> {
protected async refreshAccessToken(): Promise<string> {
const refreshToken = await this.tokenService.getRefreshToken();
if (refreshToken == null || refreshToken === "") {
throw new Error();
@ -1770,19 +1769,20 @@ export class ApiService implements ApiServiceAbstraction {
this.vaultTimeoutSettingsService.getVaultTimeoutByUserId$(userId),
);
await this.tokenService.setTokens(
const refreshedTokens = await this.tokenService.setTokens(
tokenResponse.accessToken,
vaultTimeoutAction as VaultTimeoutAction,
vaultTimeout,
tokenResponse.refreshToken,
);
return refreshedTokens.accessToken;
} else {
const error = await this.handleError(response, true, true);
return Promise.reject(error);
}
}
protected async doApiTokenRefresh(): Promise<void> {
protected async refreshApiToken(): Promise<string> {
const clientId = await this.tokenService.getClientId();
const clientSecret = await this.tokenService.getClientSecret();
@ -1810,11 +1810,12 @@ export class ApiService implements ApiServiceAbstraction {
this.vaultTimeoutSettingsService.getVaultTimeoutByUserId$(userId),
);
await this.tokenService.setAccessToken(
const refreshedToken = await this.tokenService.setAccessToken(
response.accessToken,
vaultTimeoutAction as VaultTimeoutAction,
vaultTimeout,
);
return refreshedToken;
}
async send(