[PM-6404] Initial Clear Events Code (#8029)

* Add New KeyDefinitionOption

* Add New Services

* Add WebStorageServiceProvider Tests

* Update Error Message

* Add `UserKeyDefinition`

* Fix Deserialization Helpers

* Fix KeyDefinition

* Add `UserKeyDefinition`

* Fix Deserialization Helpers

* Fix KeyDefinition

* Move `ClearEvent`

* Cleanup

* Fix Imports

* Remove `updateMock`

* Call Super in Web Implementation

* Use Better Type to Avoid Casting

* Better Error Docs

* Move StorageKey Creation to Function

* Throw Aggregated Error for Failures
This commit is contained in:
Justin Baur 2024-02-27 15:58:31 -06:00 committed by GitHub
parent 929b5ebec3
commit 87c75e5ac8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 553 additions and 10 deletions

View File

@ -0,0 +1,36 @@
// eslint-disable-next-line import/no-restricted-paths
import { StateEventRegistrarService } from "@bitwarden/common/platform/state/state-event-registrar.service";
import { CachedServices, FactoryOptions, factory } from "./factory-options";
import {
GlobalStateProviderInitOptions,
globalStateProviderFactory,
} from "./global-state-provider.factory";
import {
StorageServiceProviderInitOptions,
storageServiceProviderFactory,
} from "./storage-service-provider.factory";
type StateEventRegistrarServiceFactoryOptions = FactoryOptions;
export type StateEventRegistrarServiceInitOptions = StateEventRegistrarServiceFactoryOptions &
GlobalStateProviderInitOptions &
StorageServiceProviderInitOptions;
export async function stateEventRegistrarServiceFactory(
cache: {
stateEventRegistrarService?: StateEventRegistrarService;
} & CachedServices,
opts: StateEventRegistrarServiceInitOptions,
): Promise<StateEventRegistrarService> {
return factory(
cache,
"stateEventRegistrarService",
opts,
async () =>
new StateEventRegistrarService(
await globalStateProviderFactory(cache, opts),
await storageServiceProviderFactory(cache, opts),
),
);
}

View File

@ -0,0 +1,33 @@
import { StorageServiceProvider } from "@bitwarden/common/platform/services/storage-service.provider";
import { CachedServices, FactoryOptions, factory } from "./factory-options";
import {
DiskStorageServiceInitOptions,
MemoryStorageServiceInitOptions,
observableDiskStorageServiceFactory,
observableMemoryStorageServiceFactory,
} from "./storage-service.factory";
type StorageServiceProviderFactoryOptions = FactoryOptions;
export type StorageServiceProviderInitOptions = StorageServiceProviderFactoryOptions &
MemoryStorageServiceInitOptions &
DiskStorageServiceInitOptions;
export async function storageServiceProviderFactory(
cache: {
storageServiceProvider?: StorageServiceProvider;
} & CachedServices,
opts: StorageServiceProviderInitOptions,
): Promise<StorageServiceProvider> {
return factory(
cache,
"storageServiceProvider",
opts,
async () =>
new StorageServiceProvider(
await observableDiskStorageServiceFactory(cache, opts),
await observableMemoryStorageServiceFactory(cache, opts),
),
);
}

View File

@ -0,0 +1,63 @@
import { mock } from "jest-mock-extended";
import {
AbstractStorageService,
ObservableStorageService,
} from "@bitwarden/common/platform/abstractions/storage.service";
import { PossibleLocation } from "@bitwarden/common/platform/services/storage-service.provider";
import {
ClientLocations,
StorageLocation,
// eslint-disable-next-line import/no-restricted-paths
} from "@bitwarden/common/platform/state/state-definition";
import { WebStorageServiceProvider } from "./web-storage-service.provider";
describe("WebStorageServiceProvider", () => {
const mockDiskStorage = mock<AbstractStorageService & ObservableStorageService>();
const mockMemoryStorage = mock<AbstractStorageService & ObservableStorageService>();
const mockDiskLocalStorage = mock<AbstractStorageService & ObservableStorageService>();
const sut = new WebStorageServiceProvider(
mockDiskStorage,
mockMemoryStorage,
mockDiskLocalStorage,
);
describe("get", () => {
const getTests = [
{
input: { default: "disk", overrides: {} },
expected: "disk",
},
{
input: { default: "memory", overrides: {} },
expected: "memory",
},
{
input: { default: "disk", overrides: { web: "disk-local" } },
expected: "disk-local",
},
{
input: { default: "disk", overrides: { web: "memory" } },
expected: "memory",
},
{
input: { default: "memory", overrides: { web: "disk" } },
expected: "disk",
},
] satisfies {
input: { default: StorageLocation; overrides: Partial<ClientLocations> };
expected: PossibleLocation;
}[];
it.each(getTests)("computes properly based on %s", ({ input, expected: expectedLocation }) => {
const [actualLocation] = sut.get(input.default, input.overrides);
expect(actualLocation).toStrictEqual(expectedLocation);
});
it("throws on unsupported option", () => {
expect(() => sut.get("blah" as any, {})).toThrow();
});
});
});

View File

@ -0,0 +1,37 @@
import {
AbstractStorageService,
ObservableStorageService,
} from "@bitwarden/common/platform/abstractions/storage.service";
import {
PossibleLocation,
StorageServiceProvider,
} from "@bitwarden/common/platform/services/storage-service.provider";
import {
ClientLocations,
// eslint-disable-next-line import/no-restricted-paths
} from "@bitwarden/common/platform/state/state-definition";
export class WebStorageServiceProvider extends StorageServiceProvider {
constructor(
diskStorageService: AbstractStorageService & ObservableStorageService,
memoryStorageService: AbstractStorageService & ObservableStorageService,
private readonly diskLocalStorageService: AbstractStorageService & ObservableStorageService,
) {
super(diskStorageService, memoryStorageService);
}
override get(
defaultLocation: PossibleLocation,
overrides: Partial<ClientLocations>,
): [location: PossibleLocation, service: AbstractStorageService & ObservableStorageService] {
const location = overrides["web"] ?? defaultLocation;
switch (location) {
case "disk-local":
return ["disk-local", this.diskLocalStorageService];
default:
// Pass in computed location to super because they could have
// overriden default "disk" with web "memory".
return super.get(location, overrides);
}
}
}

View File

@ -41,10 +41,10 @@ export class FakeGlobalState<T> implements GlobalState<T> {
this.stateSubject.next(initialValue ?? null); this.stateSubject.next(initialValue ?? null);
} }
update: <TCombine>( async update<TCombine>(
configureState: (state: T, dependency: TCombine) => T, configureState: (state: T, dependency: TCombine) => T,
options?: StateUpdateOptions<T, TCombine>, options?: StateUpdateOptions<T, TCombine>,
) => Promise<T> = jest.fn(async (configureState, options) => { ): Promise<T> {
options = populateOptionsWithDefault(options); options = populateOptionsWithDefault(options);
if (this.stateSubject["_buffer"].length == 0) { if (this.stateSubject["_buffer"].length == 0) {
// throw a more helpful not initialized error // throw a more helpful not initialized error
@ -64,9 +64,7 @@ export class FakeGlobalState<T> implements GlobalState<T> {
this.stateSubject.next(newState); this.stateSubject.next(newState);
this.nextMock(newState); this.nextMock(newState);
return newState; return newState;
}); }
updateMock = this.update as jest.MockedFunction<typeof this.update>;
/** Tracks update values resolved by `FakeState.update` */ /** Tracks update values resolved by `FakeState.update` */
nextMock = jest.fn<void, [T]>(); nextMock = jest.fn<void, [T]>();
@ -128,8 +126,6 @@ export class FakeSingleUserState<T> implements SingleUserState<T> {
return newState; return newState;
} }
updateMock = this.update as jest.MockedFunction<typeof this.update>;
/** Tracks update values resolved by `FakeState.update` */ /** Tracks update values resolved by `FakeState.update` */
nextMock = jest.fn<void, [T]>(); nextMock = jest.fn<void, [T]>();
private _keyDefinition: UserKeyDefinition<T> | null = null; private _keyDefinition: UserKeyDefinition<T> | null = null;
@ -191,8 +187,6 @@ export class FakeActiveUserState<T> implements ActiveUserState<T> {
return [this.userId, newState]; return [this.userId, newState];
} }
updateMock = this.update as jest.MockedFunction<typeof this.update>;
/** Tracks update values resolved by `FakeState.update` */ /** Tracks update values resolved by `FakeState.update` */
nextMock = jest.fn<void, [[UserId, T]]>(); nextMock = jest.fn<void, [[UserId, T]]>();

View File

@ -0,0 +1,28 @@
import { mock } from "jest-mock-extended";
import { AbstractStorageService, ObservableStorageService } from "../abstractions/storage.service";
import { StorageServiceProvider } from "./storage-service.provider";
describe("StorageServiceProvider", () => {
const mockDiskStorage = mock<AbstractStorageService & ObservableStorageService>();
const mockMemoryStorage = mock<AbstractStorageService & ObservableStorageService>();
const sut = new StorageServiceProvider(mockDiskStorage, mockMemoryStorage);
describe("get", () => {
it("gets disk service when default location is disk", () => {
const [computedLocation, computedService] = sut.get("disk", {});
expect(computedLocation).toBe("disk");
expect(computedService).toStrictEqual(mockDiskStorage);
});
it("gets memory service when default location is memory", () => {
const [computedLocation, computedService] = sut.get("memory", {});
expect(computedLocation).toBe("memory");
expect(computedService).toStrictEqual(mockMemoryStorage);
});
});
});

View File

@ -0,0 +1,39 @@
import { AbstractStorageService, ObservableStorageService } from "../abstractions/storage.service";
// eslint-disable-next-line import/no-restricted-paths
import { ClientLocations, StorageLocation } from "../state/state-definition";
export type PossibleLocation = StorageLocation | ClientLocations[keyof ClientLocations];
/**
* A provider for getting client specific computed storage locations and services.
*/
export class StorageServiceProvider {
constructor(
protected readonly diskStorageService: AbstractStorageService & ObservableStorageService,
protected readonly memoryStorageService: AbstractStorageService & ObservableStorageService,
) {}
/**
* Computes the location and corresponding service for a given client.
*
* **NOTE** The default implementation does not respect client overrides and if clients
* have special overrides they are responsible for implementing this service.
* @param defaultLocation The default location to use if no client specific override is preferred.
* @param overrides Client specific overrides
* @returns The computed storage location and corresponding storage service to use to get/store state.
* @throws If there is no configured storage service for the given inputs.
*/
get(
defaultLocation: PossibleLocation,
overrides: Partial<ClientLocations>,
): [location: PossibleLocation, service: AbstractStorageService & ObservableStorageService] {
switch (defaultLocation) {
case "disk":
return [defaultLocation, this.diskStorageService];
case "memory":
return [defaultLocation, this.memoryStorageService];
default:
throw new Error(`Unexpected location: ${defaultLocation}`);
}
}
}

View File

@ -9,5 +9,6 @@ export { ActiveUserStateProvider, SingleUserStateProvider } from "./user-state.p
export { KeyDefinition } from "./key-definition"; export { KeyDefinition } from "./key-definition";
export { StateUpdateOptions } from "./state-update-options"; export { StateUpdateOptions } from "./state-update-options";
export { UserKeyDefinition } from "./user-key-definition"; export { UserKeyDefinition } from "./user-key-definition";
export { StateEventRunnerService } from "./state-event-runner.service";
export * from "./state-definitions"; export * from "./state-definitions";

View File

@ -65,6 +65,8 @@ export const VAULT_FILTER_DISK = new StateDefinition("vaultFilter", "disk", {
web: "disk-local", web: "disk-local",
}); });
export const CLEAR_EVENT_DISK = new StateDefinition("clearEvent", "disk");
export const NEW_WEB_LAYOUT_BANNER_DISK = new StateDefinition("newWebLayoutBanner", "disk", { export const NEW_WEB_LAYOUT_BANNER_DISK = new StateDefinition("newWebLayoutBanner", "disk", {
web: "disk-local", web: "disk-local",
}); });

View File

@ -0,0 +1,85 @@
import { mock } from "jest-mock-extended";
import { FakeGlobalStateProvider } from "../../../spec";
import { AbstractStorageService, ObservableStorageService } from "../abstractions/storage.service";
import { StorageServiceProvider } from "../services/storage-service.provider";
import { StateDefinition } from "./state-definition";
import { STATE_LOCK_EVENT, StateEventRegistrarService } from "./state-event-registrar.service";
import { UserKeyDefinition } from "./user-key-definition";
describe("StateEventRegistrarService", () => {
const globalStateProvider = new FakeGlobalStateProvider();
const lockState = globalStateProvider.getFake(STATE_LOCK_EVENT);
const storageServiceProvider = mock<StorageServiceProvider>();
const sut = new StateEventRegistrarService(globalStateProvider, storageServiceProvider);
describe("registerEvents", () => {
const fakeKeyDefinition = new UserKeyDefinition<boolean>(
new StateDefinition("fakeState", "disk"),
"fakeKey",
{
deserializer: (s) => s,
clearOn: ["lock"],
},
);
beforeEach(() => {
jest.resetAllMocks();
});
it("adds event on null storage", async () => {
storageServiceProvider.get.mockReturnValue([
"disk",
mock<AbstractStorageService & ObservableStorageService>(),
]);
await sut.registerEvents(fakeKeyDefinition);
expect(lockState.nextMock).toHaveBeenCalledWith([
{
key: "fakeKey",
location: "disk",
state: "fakeState",
},
]);
});
it("adds event on empty array in storage", async () => {
lockState.stateSubject.next([]);
storageServiceProvider.get.mockReturnValue([
"disk",
mock<AbstractStorageService & ObservableStorageService>(),
]);
await sut.registerEvents(fakeKeyDefinition);
expect(lockState.nextMock).toHaveBeenCalledWith([
{
key: "fakeKey",
location: "disk",
state: "fakeState",
},
]);
});
it("doesn't add a duplicate", async () => {
lockState.stateSubject.next([
{
key: "fakeKey",
location: "disk",
state: "fakeState",
},
]);
storageServiceProvider.get.mockReturnValue([
"disk",
mock<AbstractStorageService & ObservableStorageService>(),
]);
await sut.registerEvents(fakeKeyDefinition);
expect(lockState.nextMock).not.toHaveBeenCalled();
});
});
});

View File

@ -0,0 +1,76 @@
import { PossibleLocation, StorageServiceProvider } from "../services/storage-service.provider";
import { GlobalState } from "./global-state";
import { GlobalStateProvider } from "./global-state.provider";
import { KeyDefinition } from "./key-definition";
import { CLEAR_EVENT_DISK } from "./state-definitions";
import { ClearEvent, UserKeyDefinition } from "./user-key-definition";
export type StateEventInfo = {
state: string;
key: string;
location: PossibleLocation;
};
export const STATE_LOCK_EVENT = KeyDefinition.array<StateEventInfo>(CLEAR_EVENT_DISK, "lock", {
deserializer: (e) => e,
});
export const STATE_LOGOUT_EVENT = KeyDefinition.array<StateEventInfo>(CLEAR_EVENT_DISK, "logout", {
deserializer: (e) => e,
});
export class StateEventRegistrarService {
private readonly stateEventStateMap: { [Prop in ClearEvent]: GlobalState<StateEventInfo[]> };
constructor(
globalStateProvider: GlobalStateProvider,
private storageServiceProvider: StorageServiceProvider,
) {
this.stateEventStateMap = {
lock: globalStateProvider.get(STATE_LOCK_EVENT),
logout: globalStateProvider.get(STATE_LOGOUT_EVENT),
};
}
async registerEvents(keyDefinition: UserKeyDefinition<unknown>) {
for (const clearEvent of keyDefinition.clearOn) {
const eventState = this.stateEventStateMap[clearEvent];
// Determine the storage location for this
const [storageLocation] = this.storageServiceProvider.get(
keyDefinition.stateDefinition.defaultStorageLocation,
keyDefinition.stateDefinition.storageLocationOverrides,
);
const newEvent: StateEventInfo = {
state: keyDefinition.stateDefinition.name,
key: keyDefinition.key,
location: storageLocation,
};
// Only update the event state if the existing list doesn't have a matching entry
await eventState.update(
(existingTickets) => {
existingTickets ??= [];
existingTickets.push(newEvent);
return existingTickets;
},
{
shouldUpdate: (currentTickets) => {
return (
// If the current tickets are null, then it will for sure be added
currentTickets == null ||
// If an existing match couldn't be found, we also need to add one
currentTickets.findIndex(
(e) =>
e.state === newEvent.state &&
e.key === newEvent.key &&
e.location === newEvent.location,
) === -1
);
},
},
);
}
}
}

View File

@ -0,0 +1,69 @@
import { mock } from "jest-mock-extended";
import { FakeGlobalStateProvider } from "../../../spec";
import { UserId } from "../../types/guid";
import { AbstractStorageService, ObservableStorageService } from "../abstractions/storage.service";
import { StorageServiceProvider } from "../services/storage-service.provider";
import { STATE_LOCK_EVENT } from "./state-event-registrar.service";
import { StateEventRunnerService } from "./state-event-runner.service";
describe("EventRunnerService", () => {
const fakeGlobalStateProvider = new FakeGlobalStateProvider();
const lockState = fakeGlobalStateProvider.getFake(STATE_LOCK_EVENT);
const storageServiceProvider = mock<StorageServiceProvider>();
const sut = new StateEventRunnerService(fakeGlobalStateProvider, storageServiceProvider);
describe("handleEvent", () => {
it("does nothing if there are no events in state", async () => {
const mockStorageService = mock<AbstractStorageService & ObservableStorageService>();
storageServiceProvider.get.mockReturnValue(["disk", mockStorageService]);
await sut.handleEvent("lock", "bff09d3c-762a-4551-9275-45b137b2f073" as UserId);
expect(lockState.nextMock).not.toHaveBeenCalled();
});
it("loops through and acts on all events", async () => {
const mockDiskStorageService = mock<AbstractStorageService & ObservableStorageService>();
const mockMemoryStorageService = mock<AbstractStorageService & ObservableStorageService>();
lockState.stateSubject.next([
{
state: "fakeState1",
key: "fakeKey1",
location: "disk",
},
{
state: "fakeState2",
key: "fakeKey2",
location: "memory",
},
]);
storageServiceProvider.get.mockImplementation((defaultLocation, overrides) => {
if (defaultLocation === "disk") {
return [defaultLocation, mockDiskStorageService];
} else if (defaultLocation === "memory") {
return [defaultLocation, mockMemoryStorageService];
}
});
mockMemoryStorageService.get.mockResolvedValue("something");
await sut.handleEvent("lock", "bff09d3c-762a-4551-9275-45b137b2f073" as UserId);
expect(mockDiskStorageService.get).toHaveBeenCalledTimes(1);
expect(mockDiskStorageService.get).toHaveBeenCalledWith(
"user_bff09d3c-762a-4551-9275-45b137b2f073_fakeState1_fakeKey1",
);
expect(mockMemoryStorageService.get).toHaveBeenCalledTimes(1);
expect(mockMemoryStorageService.get).toHaveBeenCalledWith(
"user_bff09d3c-762a-4551-9275-45b137b2f073_fakeState2_fakeKey2",
);
expect(mockMemoryStorageService.remove).toHaveBeenCalledTimes(1);
});
});
});

View File

@ -0,0 +1,80 @@
import { firstValueFrom } from "rxjs";
import { UserId } from "../../types/guid";
import { StorageServiceProvider } from "../services/storage-service.provider";
import { GlobalState } from "./global-state";
import { GlobalStateProvider } from "./global-state.provider";
import { StateDefinition, StorageLocation } from "./state-definition";
import {
STATE_LOCK_EVENT,
STATE_LOGOUT_EVENT,
StateEventInfo,
} from "./state-event-registrar.service";
import { ClearEvent, UserKeyDefinition } from "./user-key-definition";
export class StateEventRunnerService {
private readonly stateEventMap: { [Prop in ClearEvent]: GlobalState<StateEventInfo[]> };
constructor(
globalStateProvider: GlobalStateProvider,
private storageServiceProvider: StorageServiceProvider,
) {
this.stateEventMap = {
lock: globalStateProvider.get(STATE_LOCK_EVENT),
logout: globalStateProvider.get(STATE_LOGOUT_EVENT),
};
}
async handleEvent(event: ClearEvent, userId: UserId) {
let tickets = await firstValueFrom(this.stateEventMap[event].state$);
tickets ??= [];
const failures: string[] = [];
for (const ticket of tickets) {
try {
const [, service] = this.storageServiceProvider.get(
ticket.location,
{}, // The storage location is already the computed storage location for this client
);
const ticketStorageKey = this.storageKeyFor(userId, ticket);
// Evaluate current value so we can avoid writing to state if we don't need to
const currentValue = await service.get(ticketStorageKey);
if (currentValue != null) {
await service.remove(ticketStorageKey);
}
} catch (err: unknown) {
let errorMessage = "Unknown Error";
if (typeof err === "object" && "message" in err && typeof err.message === "string") {
errorMessage = err.message;
}
failures.push(
`${errorMessage} in ${ticket.state} > ${ticket.key} located ${ticket.location}`,
);
}
}
if (failures.length > 0) {
// Throw aggregated error
throw new Error(
`One or more errors occurred while handling event '${event}' for user ${userId}.\n${failures.join("\n")}`,
);
}
}
private storageKeyFor(userId: UserId, ticket: StateEventInfo) {
const userKey = new UserKeyDefinition<unknown>(
new StateDefinition(ticket.state, ticket.location as unknown as StorageLocation),
ticket.key,
{
deserializer: (v) => v,
clearOn: [],
},
);
return userKey.buildKey(userId);
}
}

View File

@ -6,7 +6,7 @@ import { array, record } from "./deserialization-helpers";
import { KeyDefinition, KeyDefinitionOptions } from "./key-definition"; import { KeyDefinition, KeyDefinitionOptions } from "./key-definition";
import { StateDefinition } from "./state-definition"; import { StateDefinition } from "./state-definition";
type ClearEvent = "lock" | "logout"; export type ClearEvent = "lock" | "logout";
type UserKeyDefinitionOptions<T> = KeyDefinitionOptions<T> & { type UserKeyDefinitionOptions<T> = KeyDefinitionOptions<T> & {
clearOn: ClearEvent[]; clearOn: ClearEvent[];