From 4ccf920da8d373bcfa6c745fda8d27043ce421bd Mon Sep 17 00:00:00 2001 From: Matt Gibson Date: Wed, 15 May 2024 17:40:16 -0400 Subject: [PATCH] [PM-8155] Keep crypto derive dependencies in lockstep (#9191) * Keep derive dependencies in lockstep This reduces emissions in general due to updates of multiple inputs and removes decryption errors due to partially updated dependencies * Fix provider encrypted org keys * Fix provider state test types * Type fixes --- .../domain/encrypted-organization-key.ts | 38 +++++++++------ .../src/platform/services/crypto.service.ts | 47 ++++++++++--------- .../services/key-state/org-keys.state.spec.ts | 30 +++++++----- .../services/key-state/org-keys.state.ts | 35 +++++++++----- .../key-state/provider-keys.state.spec.ts | 7 +-- .../services/key-state/provider-keys.state.ts | 20 ++++---- .../services/key-state/user-key.state.spec.ts | 18 ++----- .../services/key-state/user-key.state.ts | 20 +++----- 8 files changed, 112 insertions(+), 103 deletions(-) diff --git a/libs/common/src/admin-console/models/domain/encrypted-organization-key.ts b/libs/common/src/admin-console/models/domain/encrypted-organization-key.ts index 470fa2317e..1f8c4e8c42 100644 --- a/libs/common/src/admin-console/models/domain/encrypted-organization-key.ts +++ b/libs/common/src/admin-console/models/domain/encrypted-organization-key.ts @@ -1,11 +1,11 @@ -import { CryptoService } from "../../../platform/abstractions/crypto.service"; +import { EncryptService } from "../../../platform/abstractions/encrypt.service"; import { EncString } from "../../../platform/models/domain/enc-string"; import { SymmetricCryptoKey } from "../../../platform/models/domain/symmetric-crypto-key"; -import { OrgKey } from "../../../types/key"; +import { OrgKey, UserPrivateKey } from "../../../types/key"; import { EncryptedOrganizationKeyData } from "../data/encrypted-organization-key.data"; export abstract class BaseEncryptedOrganizationKey { - decrypt: (cryptoService: CryptoService) => Promise; + abstract get encryptedOrganizationKey(): EncString; static fromData(data: EncryptedOrganizationKeyData) { switch (data.type) { @@ -19,22 +19,26 @@ export abstract class BaseEncryptedOrganizationKey { return null; } } + + static isProviderEncrypted( + key: EncryptedOrganizationKey | ProviderEncryptedOrganizationKey, + ): key is ProviderEncryptedOrganizationKey { + return key.toData().type === "provider"; + } } export class EncryptedOrganizationKey implements BaseEncryptedOrganizationKey { constructor(private key: string) {} - async decrypt(cryptoService: CryptoService) { - const activeUserPrivateKey = await cryptoService.getPrivateKey(); - - if (activeUserPrivateKey == null) { - throw new Error("Active user does not have a private key, cannot decrypt organization key."); - } - - const decValue = await cryptoService.rsaDecrypt(this.key, activeUserPrivateKey); + async decrypt(encryptService: EncryptService, privateKey: UserPrivateKey) { + const decValue = await encryptService.rsaDecrypt(this.encryptedOrganizationKey, privateKey); return new SymmetricCryptoKey(decValue) as OrgKey; } + get encryptedOrganizationKey() { + return new EncString(this.key); + } + toData(): EncryptedOrganizationKeyData { return { type: "organization", @@ -49,12 +53,18 @@ export class ProviderEncryptedOrganizationKey implements BaseEncryptedOrganizati private providerId: string, ) {} - async decrypt(cryptoService: CryptoService) { - const providerKey = await cryptoService.getProviderKey(this.providerId); - const decValue = await cryptoService.decryptToBytes(new EncString(this.key), providerKey); + async decrypt(encryptService: EncryptService, providerKeys: Record) { + const decValue = await encryptService.decryptToBytes( + new EncString(this.key), + providerKeys[this.providerId], + ); return new SymmetricCryptoKey(decValue) as OrgKey; } + get encryptedOrganizationKey() { + return new EncString(this.key); + } + toData(): EncryptedOrganizationKeyData { return { type: "provider", diff --git a/libs/common/src/platform/services/crypto.service.ts b/libs/common/src/platform/services/crypto.service.ts index fed22e06a0..2813bfb960 100644 --- a/libs/common/src/platform/services/crypto.service.ts +++ b/libs/common/src/platform/services/crypto.service.ts @@ -1,5 +1,5 @@ import * as bigInt from "big-integer"; -import { Observable, filter, firstValueFrom, map } from "rxjs"; +import { Observable, filter, firstValueFrom, map, zip } from "rxjs"; import { PinServiceAbstraction } from "../../../../auth/src/common/abstractions"; import { EncryptedOrganizationKeyData } from "../../admin-console/models/data/encrypted-organization-key.data"; @@ -97,13 +97,12 @@ export class CryptoService implements CryptoServiceAbstraction { // User Asymmetric Key Pair this.activeUserEncryptedPrivateKeyState = stateProvider.getActive(USER_ENCRYPTED_PRIVATE_KEY); this.activeUserPrivateKeyState = stateProvider.getDerived( - this.activeUserEncryptedPrivateKeyState.combinedState$.pipe( - filter(([_userId, key]) => key != null), + zip(this.activeUserEncryptedPrivateKeyState.state$, this.activeUserKey$).pipe( + filter(([, userKey]) => !!userKey), ), USER_PRIVATE_KEY, { encryptService: this.encryptService, - getUserKey: (userId) => this.getUserKey(userId), }, ); this.activeUserPrivateKey$ = this.activeUserPrivateKeyState.state$; // may be null @@ -116,27 +115,34 @@ export class CryptoService implements CryptoServiceAbstraction { ); this.activeUserPublicKey$ = this.activeUserPublicKeyState.state$; // may be null - // Organization keys - this.activeUserEncryptedOrgKeysState = stateProvider.getActive( - USER_ENCRYPTED_ORGANIZATION_KEYS, - ); - this.activeUserOrgKeysState = stateProvider.getDerived( - this.activeUserEncryptedOrgKeysState.state$.pipe(filter((keys) => keys != null)), - USER_ORGANIZATION_KEYS, - { cryptoService: this }, - ); - this.activeUserOrgKeys$ = this.activeUserOrgKeysState.state$; // null handled by `derive` function - // Provider keys this.activeUserEncryptedProviderKeysState = stateProvider.getActive( USER_ENCRYPTED_PROVIDER_KEYS, ); this.activeUserProviderKeysState = stateProvider.getDerived( - this.activeUserEncryptedProviderKeysState.state$.pipe(filter((keys) => keys != null)), + zip( + this.activeUserEncryptedProviderKeysState.state$.pipe(filter((keys) => keys != null)), + this.activeUserPrivateKey$, + ).pipe(filter(([, privateKey]) => !!privateKey)), USER_PROVIDER_KEYS, - { encryptService: this.encryptService, cryptoService: this }, + { encryptService: this.encryptService }, ); this.activeUserProviderKeys$ = this.activeUserProviderKeysState.state$; // null handled by `derive` function + + // Organization keys + this.activeUserEncryptedOrgKeysState = stateProvider.getActive( + USER_ENCRYPTED_ORGANIZATION_KEYS, + ); + this.activeUserOrgKeysState = stateProvider.getDerived( + zip( + this.activeUserEncryptedOrgKeysState.state$.pipe(filter((keys) => keys != null)), + this.activeUserPrivateKey$, + this.activeUserProviderKeys$, + ).pipe(filter(([, privateKey]) => !!privateKey)), + USER_ORGANIZATION_KEYS, + { encryptService: this.encryptService }, + ); + this.activeUserOrgKeys$ = this.activeUserOrgKeysState.state$; // null handled by `derive` function } async setUserKey(key: UserKey, userId?: UserId): Promise { @@ -656,17 +662,14 @@ export class CryptoService implements CryptoServiceAbstraction { } try { - const [userId, encPrivateKey] = await firstValueFrom( - this.activeUserEncryptedPrivateKeyState.combinedState$, - ); + const encPrivateKey = await firstValueFrom(this.activeUserEncryptedPrivateKeyState.state$); if (encPrivateKey == null) { return false; } // Can decrypt private key - const privateKey = await USER_PRIVATE_KEY.derive([userId, encPrivateKey], { + const privateKey = await USER_PRIVATE_KEY.derive([encPrivateKey, key], { encryptService: this.encryptService, - getUserKey: () => Promise.resolve(key), }); if (privateKey == null) { diff --git a/libs/common/src/platform/services/key-state/org-keys.state.spec.ts b/libs/common/src/platform/services/key-state/org-keys.state.spec.ts index 6b547a491a..98e0139cc4 100644 --- a/libs/common/src/platform/services/key-state/org-keys.state.spec.ts +++ b/libs/common/src/platform/services/key-state/org-keys.state.spec.ts @@ -1,8 +1,8 @@ import { mock } from "jest-mock-extended"; import { makeEncString, makeStaticByteArray } from "../../../../spec"; -import { OrgKey } from "../../../types/key"; -import { CryptoService } from "../../abstractions/crypto.service"; +import { OrgKey, UserPrivateKey } from "../../../types/key"; +import { EncryptService } from "../../abstractions/encrypt.service"; import { SymmetricCryptoKey } from "../../models/domain/symmetric-crypto-key"; import { USER_ENCRYPTED_ORGANIZATION_KEYS, USER_ORGANIZATION_KEYS } from "./org-keys.state"; @@ -30,7 +30,8 @@ describe("encrypted org keys", () => { }); describe("derived decrypted org keys", () => { - const cryptoService = mock(); + const encryptService = mock(); + const userPrivateKey = makeStaticByteArray(64, 3) as UserPrivateKey; const sut = USER_ORGANIZATION_KEYS; afterEach(() => { @@ -65,15 +66,11 @@ describe("derived decrypted org keys", () => { "org-id-2": new SymmetricCryptoKey(makeStaticByteArray(64, 2)) as OrgKey, }; - const userPrivateKey = makeStaticByteArray(64, 3); - - cryptoService.getPrivateKey.mockResolvedValue(userPrivateKey); - // TODO: How to not have to mock these decryptions. They are internal concerns of EncryptedOrganizationKey - cryptoService.rsaDecrypt.mockResolvedValueOnce(decryptedOrgKeys["org-id-1"].key); - cryptoService.rsaDecrypt.mockResolvedValueOnce(decryptedOrgKeys["org-id-2"].key); + encryptService.rsaDecrypt.mockResolvedValueOnce(decryptedOrgKeys["org-id-1"].key); + encryptService.rsaDecrypt.mockResolvedValueOnce(decryptedOrgKeys["org-id-2"].key); - const result = await sut.derive(encryptedOrgKeys, { cryptoService }); + const result = await sut.derive([encryptedOrgKeys, userPrivateKey, {}], { encryptService }); expect(result).toEqual(decryptedOrgKeys); }); @@ -92,16 +89,23 @@ describe("derived decrypted org keys", () => { }, }; + const providerKeys = { + "provider-id-1": new SymmetricCryptoKey(makeStaticByteArray(64, 1)), + "provider-id-2": new SymmetricCryptoKey(makeStaticByteArray(64, 2)), + }; + const decryptedOrgKeys = { "org-id-1": new SymmetricCryptoKey(makeStaticByteArray(64, 1)) as OrgKey, "org-id-2": new SymmetricCryptoKey(makeStaticByteArray(64, 2)) as OrgKey, }; // TODO: How to not have to mock these decryptions. They are internal concerns of ProviderEncryptedOrganizationKey - cryptoService.decryptToBytes.mockResolvedValueOnce(decryptedOrgKeys["org-id-1"].key); - cryptoService.decryptToBytes.mockResolvedValueOnce(decryptedOrgKeys["org-id-2"].key); + encryptService.decryptToBytes.mockResolvedValueOnce(decryptedOrgKeys["org-id-1"].key); + encryptService.decryptToBytes.mockResolvedValueOnce(decryptedOrgKeys["org-id-2"].key); - const result = await sut.derive(encryptedOrgKeys, { cryptoService }); + const result = await sut.derive([encryptedOrgKeys, userPrivateKey, providerKeys], { + encryptService, + }); expect(result).toEqual(decryptedOrgKeys); }); diff --git a/libs/common/src/platform/services/key-state/org-keys.state.ts b/libs/common/src/platform/services/key-state/org-keys.state.ts index f67e64b653..8a42e242b1 100644 --- a/libs/common/src/platform/services/key-state/org-keys.state.ts +++ b/libs/common/src/platform/services/key-state/org-keys.state.ts @@ -1,10 +1,10 @@ import { EncryptedOrganizationKeyData } from "../../../admin-console/models/data/encrypted-organization-key.data"; import { BaseEncryptedOrganizationKey } from "../../../admin-console/models/domain/encrypted-organization-key"; -import { OrganizationId } from "../../../types/guid"; -import { OrgKey } from "../../../types/key"; -import { CryptoService } from "../../abstractions/crypto.service"; +import { OrganizationId, ProviderId } from "../../../types/guid"; +import { OrgKey, ProviderKey, UserPrivateKey } from "../../../types/key"; +import { EncryptService } from "../../abstractions/encrypt.service"; import { SymmetricCryptoKey } from "../../models/domain/symmetric-crypto-key"; -import { CRYPTO_DISK, DeriveDefinition, UserKeyDefinition } from "../../state"; +import { CRYPTO_DISK, CRYPTO_MEMORY, DeriveDefinition, UserKeyDefinition } from "../../state"; export const USER_ENCRYPTED_ORGANIZATION_KEYS = UserKeyDefinition.record< EncryptedOrganizationKeyData, @@ -14,11 +14,15 @@ export const USER_ENCRYPTED_ORGANIZATION_KEYS = UserKeyDefinition.record< clearOn: ["logout"], }); -export const USER_ORGANIZATION_KEYS = DeriveDefinition.from< - Record, +export const USER_ORGANIZATION_KEYS = new DeriveDefinition< + [ + Record, + UserPrivateKey, + Record, + ], Record, - { cryptoService: CryptoService } ->(USER_ENCRYPTED_ORGANIZATION_KEYS, { + { encryptService: EncryptService } +>(CRYPTO_MEMORY, "organizationKeys", { deserializer: (obj) => { const result: Record = {}; for (const orgId of Object.keys(obj ?? {}) as OrganizationId[]) { @@ -26,14 +30,21 @@ export const USER_ORGANIZATION_KEYS = DeriveDefinition.from< } return result; }, - derive: async (from, { cryptoService }) => { + derive: async ([encryptedOrgKeys, privateKey, providerKeys], { encryptService }) => { const result: Record = {}; - for (const orgId of Object.keys(from ?? {}) as OrganizationId[]) { + for (const orgId of Object.keys(encryptedOrgKeys ?? {}) as OrganizationId[]) { if (result[orgId] != null) { continue; } - const encrypted = BaseEncryptedOrganizationKey.fromData(from[orgId]); - const decrypted = await encrypted.decrypt(cryptoService); + const encrypted = BaseEncryptedOrganizationKey.fromData(encryptedOrgKeys[orgId]); + + let decrypted: OrgKey; + + if (BaseEncryptedOrganizationKey.isProviderEncrypted(encrypted)) { + decrypted = await encrypted.decrypt(encryptService, providerKeys); + } else { + decrypted = await encrypted.decrypt(encryptService, privateKey); + } result[orgId] = decrypted; } diff --git a/libs/common/src/platform/services/key-state/provider-keys.state.spec.ts b/libs/common/src/platform/services/key-state/provider-keys.state.spec.ts index 78e61e0391..ca84d4a6ea 100644 --- a/libs/common/src/platform/services/key-state/provider-keys.state.spec.ts +++ b/libs/common/src/platform/services/key-state/provider-keys.state.spec.ts @@ -6,7 +6,6 @@ import { ProviderKey, UserPrivateKey } from "../../../types/key"; import { EncryptService } from "../../abstractions/encrypt.service"; import { EncryptedString } from "../../models/domain/enc-string"; import { SymmetricCryptoKey } from "../../models/domain/symmetric-crypto-key"; -import { CryptoService } from "../crypto.service"; import { USER_ENCRYPTED_PROVIDER_KEYS, USER_PROVIDER_KEYS } from "./provider-keys.state"; @@ -27,7 +26,6 @@ describe("encrypted provider keys", () => { describe("derived decrypted provider keys", () => { const encryptService = mock(); - const cryptoService = mock(); const userPrivateKey = makeStaticByteArray(64, 0) as UserPrivateKey; const sut = USER_PROVIDER_KEYS; @@ -59,9 +57,8 @@ describe("derived decrypted provider keys", () => { encryptService.rsaDecrypt.mockResolvedValueOnce(decryptedProviderKeys["provider-id-1"].key); encryptService.rsaDecrypt.mockResolvedValueOnce(decryptedProviderKeys["provider-id-2"].key); - cryptoService.getPrivateKey.mockResolvedValueOnce(userPrivateKey); - const result = await sut.derive(encryptedProviderKeys, { encryptService, cryptoService }); + const result = await sut.derive([encryptedProviderKeys, userPrivateKey], { encryptService }); expect(result).toEqual(decryptedProviderKeys); }); @@ -69,7 +66,7 @@ describe("derived decrypted provider keys", () => { it("should handle null input values", async () => { const encryptedProviderKeys: Record = null; - const result = await sut.derive(encryptedProviderKeys, { encryptService, cryptoService }); + const result = await sut.derive([encryptedProviderKeys, userPrivateKey], { encryptService }); expect(result).toEqual({}); }); diff --git a/libs/common/src/platform/services/key-state/provider-keys.state.ts b/libs/common/src/platform/services/key-state/provider-keys.state.ts index 776fdc77d8..dfda71be21 100644 --- a/libs/common/src/platform/services/key-state/provider-keys.state.ts +++ b/libs/common/src/platform/services/key-state/provider-keys.state.ts @@ -1,10 +1,9 @@ import { ProviderId } from "../../../types/guid"; -import { ProviderKey } from "../../../types/key"; +import { ProviderKey, UserPrivateKey } from "../../../types/key"; import { EncryptService } from "../../abstractions/encrypt.service"; import { EncString, EncryptedString } from "../../models/domain/enc-string"; import { SymmetricCryptoKey } from "../../models/domain/symmetric-crypto-key"; -import { CRYPTO_DISK, DeriveDefinition, UserKeyDefinition } from "../../state"; -import { CryptoService } from "../crypto.service"; +import { CRYPTO_DISK, CRYPTO_MEMORY, DeriveDefinition, UserKeyDefinition } from "../../state"; export const USER_ENCRYPTED_PROVIDER_KEYS = UserKeyDefinition.record( CRYPTO_DISK, @@ -15,11 +14,11 @@ export const USER_ENCRYPTED_PROVIDER_KEYS = UserKeyDefinition.record, +export const USER_PROVIDER_KEYS = new DeriveDefinition< + [Record, UserPrivateKey], Record, - { encryptService: EncryptService; cryptoService: CryptoService } // TODO: This should depend on an active user private key observable directly ->(USER_ENCRYPTED_PROVIDER_KEYS, { + { encryptService: EncryptService } +>(CRYPTO_MEMORY, "providerKeys", { deserializer: (obj) => { const result: Record = {}; for (const providerId of Object.keys(obj ?? {}) as ProviderId[]) { @@ -27,14 +26,13 @@ export const USER_PROVIDER_KEYS = DeriveDefinition.from< } return result; }, - derive: async (from, { encryptService, cryptoService }) => { + derive: async ([encryptedProviderKeys, privateKey], { encryptService }) => { const result: Record = {}; - for (const providerId of Object.keys(from ?? {}) as ProviderId[]) { + for (const providerId of Object.keys(encryptedProviderKeys ?? {}) as ProviderId[]) { if (result[providerId] != null) { continue; } - const encrypted = new EncString(from[providerId]); - const privateKey = await cryptoService.getPrivateKey(); + const encrypted = new EncString(encryptedProviderKeys[providerId]); const decrypted = await encryptService.rsaDecrypt(encrypted, privateKey); const providerKey = new SymmetricCryptoKey(decrypted) as ProviderKey; diff --git a/libs/common/src/platform/services/key-state/user-key.state.spec.ts b/libs/common/src/platform/services/key-state/user-key.state.spec.ts index 5c5c5ac70c..63273f1c79 100644 --- a/libs/common/src/platform/services/key-state/user-key.state.spec.ts +++ b/libs/common/src/platform/services/key-state/user-key.state.spec.ts @@ -1,7 +1,6 @@ import { mock } from "jest-mock-extended"; import { makeStaticByteArray } from "../../../../spec"; -import { UserId } from "../../../types/guid"; import { UserKey, UserPrivateKey, UserPublicKey } from "../../../types/key"; import { CryptoFunctionService } from "../../abstractions/crypto-function.service"; import { EncryptService } from "../../abstractions/encrypt.service"; @@ -70,7 +69,6 @@ describe("User public key", () => { describe("Derived decrypted private key", () => { const sut = USER_PRIVATE_KEY; - const userId = "userId" as UserId; const userKey = mock(); const encryptedPrivateKey = makeEncString().encryptedString; const decryptedPrivateKey = makeStaticByteArray(64, 1); @@ -88,37 +86,31 @@ describe("Derived decrypted private key", () => { }); it("should derive decrypted private key", async () => { - const getUserKey = jest.fn(async () => userKey); const encryptService = mock(); encryptService.decryptToBytes.mockResolvedValue(decryptedPrivateKey); - const result = await sut.derive([userId, encryptedPrivateKey], { + const result = await sut.derive([encryptedPrivateKey, userKey], { encryptService, - getUserKey, }); expect(result).toEqual(decryptedPrivateKey); }); - it("should handle null input values", async () => { - const getUserKey = jest.fn(async () => userKey); + it("should handle null encryptedPrivateKey", async () => { const encryptService = mock(); - const result = await sut.derive([userId, null], { + const result = await sut.derive([null, userKey], { encryptService, - getUserKey, }); expect(result).toEqual(null); }); - it("should handle null user key", async () => { - const getUserKey = jest.fn(async () => null); + it("should handle null userKey", async () => { const encryptService = mock(); - const result = await sut.derive([userId, encryptedPrivateKey], { + const result = await sut.derive([encryptedPrivateKey, null], { encryptService, - getUserKey, }); expect(result).toEqual(null); diff --git a/libs/common/src/platform/services/key-state/user-key.state.ts b/libs/common/src/platform/services/key-state/user-key.state.ts index 3df3b2044b..c2b84d6a24 100644 --- a/libs/common/src/platform/services/key-state/user-key.state.ts +++ b/libs/common/src/platform/services/key-state/user-key.state.ts @@ -1,4 +1,3 @@ -import { UserId } from "../../../types/guid"; import { UserPrivateKey, UserPublicKey, UserKey } from "../../../types/key"; import { CryptoFunctionService } from "../../abstractions/crypto-function.service"; import { EncryptService } from "../../abstractions/encrypt.service"; @@ -24,20 +23,14 @@ export const USER_ENCRYPTED_PRIVATE_KEY = new UserKeyDefinition }, ); -export const USER_PRIVATE_KEY = DeriveDefinition.fromWithUserId< - EncryptedString, +export const USER_PRIVATE_KEY = new DeriveDefinition< + [EncryptedString, UserKey], UserPrivateKey, - // TODO: update cryptoService to user key directly - { encryptService: EncryptService; getUserKey: (userId: UserId) => Promise } ->(USER_ENCRYPTED_PRIVATE_KEY, { + { encryptService: EncryptService } +>(CRYPTO_MEMORY, "privateKey", { deserializer: (obj) => new Uint8Array(Object.values(obj)) as UserPrivateKey, - derive: async ([userId, encPrivateKeyString], { encryptService, getUserKey }) => { - if (encPrivateKeyString == null) { - return null; - } - - const userKey = await getUserKey(userId); - if (userKey == null) { + derive: async ([encPrivateKeyString, userKey], { encryptService }) => { + if (encPrivateKeyString == null || userKey == null) { return null; } @@ -64,6 +57,7 @@ export const USER_PUBLIC_KEY = DeriveDefinition.from< return (await cryptoFunctionService.rsaExtractPublicKey(privateKey)) as UserPublicKey; }, }); + export const USER_KEY = new UserKeyDefinition(CRYPTO_MEMORY, "userKey", { deserializer: (obj) => SymmetricCryptoKey.fromJSON(obj) as UserKey, clearOn: ["logout", "lock"],