[PM-8155] Keep crypto derive dependencies in lockstep (#9191)

* Keep derive dependencies in lockstep

This reduces emissions in general due to updates of multiple inputs and removes decryption errors due to partially updated dependencies

* Fix provider encrypted org keys

* Fix provider state test types

* Type fixes
This commit is contained in:
Matt Gibson 2024-05-15 17:40:16 -04:00 committed by GitHub
parent c19a640557
commit 4ccf920da8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 112 additions and 103 deletions

View File

@ -1,11 +1,11 @@
import { CryptoService } from "../../../platform/abstractions/crypto.service";
import { EncryptService } from "../../../platform/abstractions/encrypt.service";
import { EncString } from "../../../platform/models/domain/enc-string";
import { SymmetricCryptoKey } from "../../../platform/models/domain/symmetric-crypto-key";
import { OrgKey } from "../../../types/key";
import { OrgKey, UserPrivateKey } from "../../../types/key";
import { EncryptedOrganizationKeyData } from "../data/encrypted-organization-key.data";
export abstract class BaseEncryptedOrganizationKey {
decrypt: (cryptoService: CryptoService) => Promise<SymmetricCryptoKey>;
abstract get encryptedOrganizationKey(): EncString;
static fromData(data: EncryptedOrganizationKeyData) {
switch (data.type) {
@ -19,22 +19,26 @@ export abstract class BaseEncryptedOrganizationKey {
return null;
}
}
static isProviderEncrypted(
key: EncryptedOrganizationKey | ProviderEncryptedOrganizationKey,
): key is ProviderEncryptedOrganizationKey {
return key.toData().type === "provider";
}
}
export class EncryptedOrganizationKey implements BaseEncryptedOrganizationKey {
constructor(private key: string) {}
async decrypt(cryptoService: CryptoService) {
const activeUserPrivateKey = await cryptoService.getPrivateKey();
if (activeUserPrivateKey == null) {
throw new Error("Active user does not have a private key, cannot decrypt organization key.");
}
const decValue = await cryptoService.rsaDecrypt(this.key, activeUserPrivateKey);
async decrypt(encryptService: EncryptService, privateKey: UserPrivateKey) {
const decValue = await encryptService.rsaDecrypt(this.encryptedOrganizationKey, privateKey);
return new SymmetricCryptoKey(decValue) as OrgKey;
}
get encryptedOrganizationKey() {
return new EncString(this.key);
}
toData(): EncryptedOrganizationKeyData {
return {
type: "organization",
@ -49,12 +53,18 @@ export class ProviderEncryptedOrganizationKey implements BaseEncryptedOrganizati
private providerId: string,
) {}
async decrypt(cryptoService: CryptoService) {
const providerKey = await cryptoService.getProviderKey(this.providerId);
const decValue = await cryptoService.decryptToBytes(new EncString(this.key), providerKey);
async decrypt(encryptService: EncryptService, providerKeys: Record<string, SymmetricCryptoKey>) {
const decValue = await encryptService.decryptToBytes(
new EncString(this.key),
providerKeys[this.providerId],
);
return new SymmetricCryptoKey(decValue) as OrgKey;
}
get encryptedOrganizationKey() {
return new EncString(this.key);
}
toData(): EncryptedOrganizationKeyData {
return {
type: "provider",

View File

@ -1,5 +1,5 @@
import * as bigInt from "big-integer";
import { Observable, filter, firstValueFrom, map } from "rxjs";
import { Observable, filter, firstValueFrom, map, zip } from "rxjs";
import { PinServiceAbstraction } from "../../../../auth/src/common/abstractions";
import { EncryptedOrganizationKeyData } from "../../admin-console/models/data/encrypted-organization-key.data";
@ -97,13 +97,12 @@ export class CryptoService implements CryptoServiceAbstraction {
// User Asymmetric Key Pair
this.activeUserEncryptedPrivateKeyState = stateProvider.getActive(USER_ENCRYPTED_PRIVATE_KEY);
this.activeUserPrivateKeyState = stateProvider.getDerived(
this.activeUserEncryptedPrivateKeyState.combinedState$.pipe(
filter(([_userId, key]) => key != null),
zip(this.activeUserEncryptedPrivateKeyState.state$, this.activeUserKey$).pipe(
filter(([, userKey]) => !!userKey),
),
USER_PRIVATE_KEY,
{
encryptService: this.encryptService,
getUserKey: (userId) => this.getUserKey(userId),
},
);
this.activeUserPrivateKey$ = this.activeUserPrivateKeyState.state$; // may be null
@ -116,27 +115,34 @@ export class CryptoService implements CryptoServiceAbstraction {
);
this.activeUserPublicKey$ = this.activeUserPublicKeyState.state$; // may be null
// Organization keys
this.activeUserEncryptedOrgKeysState = stateProvider.getActive(
USER_ENCRYPTED_ORGANIZATION_KEYS,
);
this.activeUserOrgKeysState = stateProvider.getDerived(
this.activeUserEncryptedOrgKeysState.state$.pipe(filter((keys) => keys != null)),
USER_ORGANIZATION_KEYS,
{ cryptoService: this },
);
this.activeUserOrgKeys$ = this.activeUserOrgKeysState.state$; // null handled by `derive` function
// Provider keys
this.activeUserEncryptedProviderKeysState = stateProvider.getActive(
USER_ENCRYPTED_PROVIDER_KEYS,
);
this.activeUserProviderKeysState = stateProvider.getDerived(
this.activeUserEncryptedProviderKeysState.state$.pipe(filter((keys) => keys != null)),
zip(
this.activeUserEncryptedProviderKeysState.state$.pipe(filter((keys) => keys != null)),
this.activeUserPrivateKey$,
).pipe(filter(([, privateKey]) => !!privateKey)),
USER_PROVIDER_KEYS,
{ encryptService: this.encryptService, cryptoService: this },
{ encryptService: this.encryptService },
);
this.activeUserProviderKeys$ = this.activeUserProviderKeysState.state$; // null handled by `derive` function
// Organization keys
this.activeUserEncryptedOrgKeysState = stateProvider.getActive(
USER_ENCRYPTED_ORGANIZATION_KEYS,
);
this.activeUserOrgKeysState = stateProvider.getDerived(
zip(
this.activeUserEncryptedOrgKeysState.state$.pipe(filter((keys) => keys != null)),
this.activeUserPrivateKey$,
this.activeUserProviderKeys$,
).pipe(filter(([, privateKey]) => !!privateKey)),
USER_ORGANIZATION_KEYS,
{ encryptService: this.encryptService },
);
this.activeUserOrgKeys$ = this.activeUserOrgKeysState.state$; // null handled by `derive` function
}
async setUserKey(key: UserKey, userId?: UserId): Promise<void> {
@ -656,17 +662,14 @@ export class CryptoService implements CryptoServiceAbstraction {
}
try {
const [userId, encPrivateKey] = await firstValueFrom(
this.activeUserEncryptedPrivateKeyState.combinedState$,
);
const encPrivateKey = await firstValueFrom(this.activeUserEncryptedPrivateKeyState.state$);
if (encPrivateKey == null) {
return false;
}
// Can decrypt private key
const privateKey = await USER_PRIVATE_KEY.derive([userId, encPrivateKey], {
const privateKey = await USER_PRIVATE_KEY.derive([encPrivateKey, key], {
encryptService: this.encryptService,
getUserKey: () => Promise.resolve(key),
});
if (privateKey == null) {

View File

@ -1,8 +1,8 @@
import { mock } from "jest-mock-extended";
import { makeEncString, makeStaticByteArray } from "../../../../spec";
import { OrgKey } from "../../../types/key";
import { CryptoService } from "../../abstractions/crypto.service";
import { OrgKey, UserPrivateKey } from "../../../types/key";
import { EncryptService } from "../../abstractions/encrypt.service";
import { SymmetricCryptoKey } from "../../models/domain/symmetric-crypto-key";
import { USER_ENCRYPTED_ORGANIZATION_KEYS, USER_ORGANIZATION_KEYS } from "./org-keys.state";
@ -30,7 +30,8 @@ describe("encrypted org keys", () => {
});
describe("derived decrypted org keys", () => {
const cryptoService = mock<CryptoService>();
const encryptService = mock<EncryptService>();
const userPrivateKey = makeStaticByteArray(64, 3) as UserPrivateKey;
const sut = USER_ORGANIZATION_KEYS;
afterEach(() => {
@ -65,15 +66,11 @@ describe("derived decrypted org keys", () => {
"org-id-2": new SymmetricCryptoKey(makeStaticByteArray(64, 2)) as OrgKey,
};
const userPrivateKey = makeStaticByteArray(64, 3);
cryptoService.getPrivateKey.mockResolvedValue(userPrivateKey);
// TODO: How to not have to mock these decryptions. They are internal concerns of EncryptedOrganizationKey
cryptoService.rsaDecrypt.mockResolvedValueOnce(decryptedOrgKeys["org-id-1"].key);
cryptoService.rsaDecrypt.mockResolvedValueOnce(decryptedOrgKeys["org-id-2"].key);
encryptService.rsaDecrypt.mockResolvedValueOnce(decryptedOrgKeys["org-id-1"].key);
encryptService.rsaDecrypt.mockResolvedValueOnce(decryptedOrgKeys["org-id-2"].key);
const result = await sut.derive(encryptedOrgKeys, { cryptoService });
const result = await sut.derive([encryptedOrgKeys, userPrivateKey, {}], { encryptService });
expect(result).toEqual(decryptedOrgKeys);
});
@ -92,16 +89,23 @@ describe("derived decrypted org keys", () => {
},
};
const providerKeys = {
"provider-id-1": new SymmetricCryptoKey(makeStaticByteArray(64, 1)),
"provider-id-2": new SymmetricCryptoKey(makeStaticByteArray(64, 2)),
};
const decryptedOrgKeys = {
"org-id-1": new SymmetricCryptoKey(makeStaticByteArray(64, 1)) as OrgKey,
"org-id-2": new SymmetricCryptoKey(makeStaticByteArray(64, 2)) as OrgKey,
};
// TODO: How to not have to mock these decryptions. They are internal concerns of ProviderEncryptedOrganizationKey
cryptoService.decryptToBytes.mockResolvedValueOnce(decryptedOrgKeys["org-id-1"].key);
cryptoService.decryptToBytes.mockResolvedValueOnce(decryptedOrgKeys["org-id-2"].key);
encryptService.decryptToBytes.mockResolvedValueOnce(decryptedOrgKeys["org-id-1"].key);
encryptService.decryptToBytes.mockResolvedValueOnce(decryptedOrgKeys["org-id-2"].key);
const result = await sut.derive(encryptedOrgKeys, { cryptoService });
const result = await sut.derive([encryptedOrgKeys, userPrivateKey, providerKeys], {
encryptService,
});
expect(result).toEqual(decryptedOrgKeys);
});

View File

@ -1,10 +1,10 @@
import { EncryptedOrganizationKeyData } from "../../../admin-console/models/data/encrypted-organization-key.data";
import { BaseEncryptedOrganizationKey } from "../../../admin-console/models/domain/encrypted-organization-key";
import { OrganizationId } from "../../../types/guid";
import { OrgKey } from "../../../types/key";
import { CryptoService } from "../../abstractions/crypto.service";
import { OrganizationId, ProviderId } from "../../../types/guid";
import { OrgKey, ProviderKey, UserPrivateKey } from "../../../types/key";
import { EncryptService } from "../../abstractions/encrypt.service";
import { SymmetricCryptoKey } from "../../models/domain/symmetric-crypto-key";
import { CRYPTO_DISK, DeriveDefinition, UserKeyDefinition } from "../../state";
import { CRYPTO_DISK, CRYPTO_MEMORY, DeriveDefinition, UserKeyDefinition } from "../../state";
export const USER_ENCRYPTED_ORGANIZATION_KEYS = UserKeyDefinition.record<
EncryptedOrganizationKeyData,
@ -14,11 +14,15 @@ export const USER_ENCRYPTED_ORGANIZATION_KEYS = UserKeyDefinition.record<
clearOn: ["logout"],
});
export const USER_ORGANIZATION_KEYS = DeriveDefinition.from<
Record<OrganizationId, EncryptedOrganizationKeyData>,
export const USER_ORGANIZATION_KEYS = new DeriveDefinition<
[
Record<OrganizationId, EncryptedOrganizationKeyData>,
UserPrivateKey,
Record<ProviderId, ProviderKey>,
],
Record<OrganizationId, OrgKey>,
{ cryptoService: CryptoService }
>(USER_ENCRYPTED_ORGANIZATION_KEYS, {
{ encryptService: EncryptService }
>(CRYPTO_MEMORY, "organizationKeys", {
deserializer: (obj) => {
const result: Record<OrganizationId, OrgKey> = {};
for (const orgId of Object.keys(obj ?? {}) as OrganizationId[]) {
@ -26,14 +30,21 @@ export const USER_ORGANIZATION_KEYS = DeriveDefinition.from<
}
return result;
},
derive: async (from, { cryptoService }) => {
derive: async ([encryptedOrgKeys, privateKey, providerKeys], { encryptService }) => {
const result: Record<OrganizationId, OrgKey> = {};
for (const orgId of Object.keys(from ?? {}) as OrganizationId[]) {
for (const orgId of Object.keys(encryptedOrgKeys ?? {}) as OrganizationId[]) {
if (result[orgId] != null) {
continue;
}
const encrypted = BaseEncryptedOrganizationKey.fromData(from[orgId]);
const decrypted = await encrypted.decrypt(cryptoService);
const encrypted = BaseEncryptedOrganizationKey.fromData(encryptedOrgKeys[orgId]);
let decrypted: OrgKey;
if (BaseEncryptedOrganizationKey.isProviderEncrypted(encrypted)) {
decrypted = await encrypted.decrypt(encryptService, providerKeys);
} else {
decrypted = await encrypted.decrypt(encryptService, privateKey);
}
result[orgId] = decrypted;
}

View File

@ -6,7 +6,6 @@ import { ProviderKey, UserPrivateKey } from "../../../types/key";
import { EncryptService } from "../../abstractions/encrypt.service";
import { EncryptedString } from "../../models/domain/enc-string";
import { SymmetricCryptoKey } from "../../models/domain/symmetric-crypto-key";
import { CryptoService } from "../crypto.service";
import { USER_ENCRYPTED_PROVIDER_KEYS, USER_PROVIDER_KEYS } from "./provider-keys.state";
@ -27,7 +26,6 @@ describe("encrypted provider keys", () => {
describe("derived decrypted provider keys", () => {
const encryptService = mock<EncryptService>();
const cryptoService = mock<CryptoService>();
const userPrivateKey = makeStaticByteArray(64, 0) as UserPrivateKey;
const sut = USER_PROVIDER_KEYS;
@ -59,9 +57,8 @@ describe("derived decrypted provider keys", () => {
encryptService.rsaDecrypt.mockResolvedValueOnce(decryptedProviderKeys["provider-id-1"].key);
encryptService.rsaDecrypt.mockResolvedValueOnce(decryptedProviderKeys["provider-id-2"].key);
cryptoService.getPrivateKey.mockResolvedValueOnce(userPrivateKey);
const result = await sut.derive(encryptedProviderKeys, { encryptService, cryptoService });
const result = await sut.derive([encryptedProviderKeys, userPrivateKey], { encryptService });
expect(result).toEqual(decryptedProviderKeys);
});
@ -69,7 +66,7 @@ describe("derived decrypted provider keys", () => {
it("should handle null input values", async () => {
const encryptedProviderKeys: Record<ProviderId, EncryptedString> = null;
const result = await sut.derive(encryptedProviderKeys, { encryptService, cryptoService });
const result = await sut.derive([encryptedProviderKeys, userPrivateKey], { encryptService });
expect(result).toEqual({});
});

View File

@ -1,10 +1,9 @@
import { ProviderId } from "../../../types/guid";
import { ProviderKey } from "../../../types/key";
import { ProviderKey, UserPrivateKey } from "../../../types/key";
import { EncryptService } from "../../abstractions/encrypt.service";
import { EncString, EncryptedString } from "../../models/domain/enc-string";
import { SymmetricCryptoKey } from "../../models/domain/symmetric-crypto-key";
import { CRYPTO_DISK, DeriveDefinition, UserKeyDefinition } from "../../state";
import { CryptoService } from "../crypto.service";
import { CRYPTO_DISK, CRYPTO_MEMORY, DeriveDefinition, UserKeyDefinition } from "../../state";
export const USER_ENCRYPTED_PROVIDER_KEYS = UserKeyDefinition.record<EncryptedString, ProviderId>(
CRYPTO_DISK,
@ -15,11 +14,11 @@ export const USER_ENCRYPTED_PROVIDER_KEYS = UserKeyDefinition.record<EncryptedSt
},
);
export const USER_PROVIDER_KEYS = DeriveDefinition.from<
Record<ProviderId, EncryptedString>,
export const USER_PROVIDER_KEYS = new DeriveDefinition<
[Record<ProviderId, EncryptedString>, UserPrivateKey],
Record<ProviderId, ProviderKey>,
{ encryptService: EncryptService; cryptoService: CryptoService } // TODO: This should depend on an active user private key observable directly
>(USER_ENCRYPTED_PROVIDER_KEYS, {
{ encryptService: EncryptService }
>(CRYPTO_MEMORY, "providerKeys", {
deserializer: (obj) => {
const result: Record<ProviderId, ProviderKey> = {};
for (const providerId of Object.keys(obj ?? {}) as ProviderId[]) {
@ -27,14 +26,13 @@ export const USER_PROVIDER_KEYS = DeriveDefinition.from<
}
return result;
},
derive: async (from, { encryptService, cryptoService }) => {
derive: async ([encryptedProviderKeys, privateKey], { encryptService }) => {
const result: Record<ProviderId, ProviderKey> = {};
for (const providerId of Object.keys(from ?? {}) as ProviderId[]) {
for (const providerId of Object.keys(encryptedProviderKeys ?? {}) as ProviderId[]) {
if (result[providerId] != null) {
continue;
}
const encrypted = new EncString(from[providerId]);
const privateKey = await cryptoService.getPrivateKey();
const encrypted = new EncString(encryptedProviderKeys[providerId]);
const decrypted = await encryptService.rsaDecrypt(encrypted, privateKey);
const providerKey = new SymmetricCryptoKey(decrypted) as ProviderKey;

View File

@ -1,7 +1,6 @@
import { mock } from "jest-mock-extended";
import { makeStaticByteArray } from "../../../../spec";
import { UserId } from "../../../types/guid";
import { UserKey, UserPrivateKey, UserPublicKey } from "../../../types/key";
import { CryptoFunctionService } from "../../abstractions/crypto-function.service";
import { EncryptService } from "../../abstractions/encrypt.service";
@ -70,7 +69,6 @@ describe("User public key", () => {
describe("Derived decrypted private key", () => {
const sut = USER_PRIVATE_KEY;
const userId = "userId" as UserId;
const userKey = mock<UserKey>();
const encryptedPrivateKey = makeEncString().encryptedString;
const decryptedPrivateKey = makeStaticByteArray(64, 1);
@ -88,37 +86,31 @@ describe("Derived decrypted private key", () => {
});
it("should derive decrypted private key", async () => {
const getUserKey = jest.fn(async () => userKey);
const encryptService = mock<EncryptService>();
encryptService.decryptToBytes.mockResolvedValue(decryptedPrivateKey);
const result = await sut.derive([userId, encryptedPrivateKey], {
const result = await sut.derive([encryptedPrivateKey, userKey], {
encryptService,
getUserKey,
});
expect(result).toEqual(decryptedPrivateKey);
});
it("should handle null input values", async () => {
const getUserKey = jest.fn(async () => userKey);
it("should handle null encryptedPrivateKey", async () => {
const encryptService = mock<EncryptService>();
const result = await sut.derive([userId, null], {
const result = await sut.derive([null, userKey], {
encryptService,
getUserKey,
});
expect(result).toEqual(null);
});
it("should handle null user key", async () => {
const getUserKey = jest.fn(async () => null);
it("should handle null userKey", async () => {
const encryptService = mock<EncryptService>();
const result = await sut.derive([userId, encryptedPrivateKey], {
const result = await sut.derive([encryptedPrivateKey, null], {
encryptService,
getUserKey,
});
expect(result).toEqual(null);

View File

@ -1,4 +1,3 @@
import { UserId } from "../../../types/guid";
import { UserPrivateKey, UserPublicKey, UserKey } from "../../../types/key";
import { CryptoFunctionService } from "../../abstractions/crypto-function.service";
import { EncryptService } from "../../abstractions/encrypt.service";
@ -24,20 +23,14 @@ export const USER_ENCRYPTED_PRIVATE_KEY = new UserKeyDefinition<EncryptedString>
},
);
export const USER_PRIVATE_KEY = DeriveDefinition.fromWithUserId<
EncryptedString,
export const USER_PRIVATE_KEY = new DeriveDefinition<
[EncryptedString, UserKey],
UserPrivateKey,
// TODO: update cryptoService to user key directly
{ encryptService: EncryptService; getUserKey: (userId: UserId) => Promise<UserKey> }
>(USER_ENCRYPTED_PRIVATE_KEY, {
{ encryptService: EncryptService }
>(CRYPTO_MEMORY, "privateKey", {
deserializer: (obj) => new Uint8Array(Object.values(obj)) as UserPrivateKey,
derive: async ([userId, encPrivateKeyString], { encryptService, getUserKey }) => {
if (encPrivateKeyString == null) {
return null;
}
const userKey = await getUserKey(userId);
if (userKey == null) {
derive: async ([encPrivateKeyString, userKey], { encryptService }) => {
if (encPrivateKeyString == null || userKey == null) {
return null;
}
@ -64,6 +57,7 @@ export const USER_PUBLIC_KEY = DeriveDefinition.from<
return (await cryptoFunctionService.rsaExtractPublicKey(privateKey)) as UserPublicKey;
},
});
export const USER_KEY = new UserKeyDefinition<UserKey>(CRYPTO_MEMORY, "userKey", {
deserializer: (obj) => SymmetricCryptoKey.fromJSON(obj) as UserKey,
clearOn: ["logout", "lock"],