diff --git a/apps/browser/src/platform/background/service-factories/state-event-registrar-service.factory.ts b/apps/browser/src/platform/background/service-factories/state-event-registrar-service.factory.ts new file mode 100644 index 0000000000..ca203a810b --- /dev/null +++ b/apps/browser/src/platform/background/service-factories/state-event-registrar-service.factory.ts @@ -0,0 +1,36 @@ +// eslint-disable-next-line import/no-restricted-paths +import { StateEventRegistrarService } from "@bitwarden/common/platform/state/state-event-registrar.service"; + +import { CachedServices, FactoryOptions, factory } from "./factory-options"; +import { + GlobalStateProviderInitOptions, + globalStateProviderFactory, +} from "./global-state-provider.factory"; +import { + StorageServiceProviderInitOptions, + storageServiceProviderFactory, +} from "./storage-service-provider.factory"; + +type StateEventRegistrarServiceFactoryOptions = FactoryOptions; + +export type StateEventRegistrarServiceInitOptions = StateEventRegistrarServiceFactoryOptions & + GlobalStateProviderInitOptions & + StorageServiceProviderInitOptions; + +export async function stateEventRegistrarServiceFactory( + cache: { + stateEventRegistrarService?: StateEventRegistrarService; + } & CachedServices, + opts: StateEventRegistrarServiceInitOptions, +): Promise { + return factory( + cache, + "stateEventRegistrarService", + opts, + async () => + new StateEventRegistrarService( + await globalStateProviderFactory(cache, opts), + await storageServiceProviderFactory(cache, opts), + ), + ); +} diff --git a/apps/browser/src/platform/background/service-factories/storage-service-provider.factory.ts b/apps/browser/src/platform/background/service-factories/storage-service-provider.factory.ts new file mode 100644 index 0000000000..8a2ddeb9e8 --- /dev/null +++ b/apps/browser/src/platform/background/service-factories/storage-service-provider.factory.ts @@ -0,0 +1,33 @@ +import { StorageServiceProvider } from "@bitwarden/common/platform/services/storage-service.provider"; + +import { CachedServices, FactoryOptions, factory } from "./factory-options"; +import { + DiskStorageServiceInitOptions, + MemoryStorageServiceInitOptions, + observableDiskStorageServiceFactory, + observableMemoryStorageServiceFactory, +} from "./storage-service.factory"; + +type StorageServiceProviderFactoryOptions = FactoryOptions; + +export type StorageServiceProviderInitOptions = StorageServiceProviderFactoryOptions & + MemoryStorageServiceInitOptions & + DiskStorageServiceInitOptions; + +export async function storageServiceProviderFactory( + cache: { + storageServiceProvider?: StorageServiceProvider; + } & CachedServices, + opts: StorageServiceProviderInitOptions, +): Promise { + return factory( + cache, + "storageServiceProvider", + opts, + async () => + new StorageServiceProvider( + await observableDiskStorageServiceFactory(cache, opts), + await observableMemoryStorageServiceFactory(cache, opts), + ), + ); +} diff --git a/apps/web/src/app/platform/web-storage-service.provider.spec.ts b/apps/web/src/app/platform/web-storage-service.provider.spec.ts new file mode 100644 index 0000000000..53f047d137 --- /dev/null +++ b/apps/web/src/app/platform/web-storage-service.provider.spec.ts @@ -0,0 +1,63 @@ +import { mock } from "jest-mock-extended"; + +import { + AbstractStorageService, + ObservableStorageService, +} from "@bitwarden/common/platform/abstractions/storage.service"; +import { PossibleLocation } from "@bitwarden/common/platform/services/storage-service.provider"; +import { + ClientLocations, + StorageLocation, + // eslint-disable-next-line import/no-restricted-paths +} from "@bitwarden/common/platform/state/state-definition"; + +import { WebStorageServiceProvider } from "./web-storage-service.provider"; + +describe("WebStorageServiceProvider", () => { + const mockDiskStorage = mock(); + const mockMemoryStorage = mock(); + const mockDiskLocalStorage = mock(); + + const sut = new WebStorageServiceProvider( + mockDiskStorage, + mockMemoryStorage, + mockDiskLocalStorage, + ); + + describe("get", () => { + const getTests = [ + { + input: { default: "disk", overrides: {} }, + expected: "disk", + }, + { + input: { default: "memory", overrides: {} }, + expected: "memory", + }, + { + input: { default: "disk", overrides: { web: "disk-local" } }, + expected: "disk-local", + }, + { + input: { default: "disk", overrides: { web: "memory" } }, + expected: "memory", + }, + { + input: { default: "memory", overrides: { web: "disk" } }, + expected: "disk", + }, + ] satisfies { + input: { default: StorageLocation; overrides: Partial }; + expected: PossibleLocation; + }[]; + + it.each(getTests)("computes properly based on %s", ({ input, expected: expectedLocation }) => { + const [actualLocation] = sut.get(input.default, input.overrides); + expect(actualLocation).toStrictEqual(expectedLocation); + }); + + it("throws on unsupported option", () => { + expect(() => sut.get("blah" as any, {})).toThrow(); + }); + }); +}); diff --git a/apps/web/src/app/platform/web-storage-service.provider.ts b/apps/web/src/app/platform/web-storage-service.provider.ts new file mode 100644 index 0000000000..da9a851785 --- /dev/null +++ b/apps/web/src/app/platform/web-storage-service.provider.ts @@ -0,0 +1,37 @@ +import { + AbstractStorageService, + ObservableStorageService, +} from "@bitwarden/common/platform/abstractions/storage.service"; +import { + PossibleLocation, + StorageServiceProvider, +} from "@bitwarden/common/platform/services/storage-service.provider"; +import { + ClientLocations, + // eslint-disable-next-line import/no-restricted-paths +} from "@bitwarden/common/platform/state/state-definition"; + +export class WebStorageServiceProvider extends StorageServiceProvider { + constructor( + diskStorageService: AbstractStorageService & ObservableStorageService, + memoryStorageService: AbstractStorageService & ObservableStorageService, + private readonly diskLocalStorageService: AbstractStorageService & ObservableStorageService, + ) { + super(diskStorageService, memoryStorageService); + } + + override get( + defaultLocation: PossibleLocation, + overrides: Partial, + ): [location: PossibleLocation, service: AbstractStorageService & ObservableStorageService] { + const location = overrides["web"] ?? defaultLocation; + switch (location) { + case "disk-local": + return ["disk-local", this.diskLocalStorageService]; + default: + // Pass in computed location to super because they could have + // overriden default "disk" with web "memory". + return super.get(location, overrides); + } + } +} diff --git a/libs/common/spec/fake-state.ts b/libs/common/spec/fake-state.ts index 5ec891a851..6b03ef6ef8 100644 --- a/libs/common/spec/fake-state.ts +++ b/libs/common/spec/fake-state.ts @@ -41,10 +41,10 @@ export class FakeGlobalState implements GlobalState { this.stateSubject.next(initialValue ?? null); } - update: ( + async update( configureState: (state: T, dependency: TCombine) => T, options?: StateUpdateOptions, - ) => Promise = jest.fn(async (configureState, options) => { + ): Promise { options = populateOptionsWithDefault(options); if (this.stateSubject["_buffer"].length == 0) { // throw a more helpful not initialized error @@ -64,9 +64,7 @@ export class FakeGlobalState implements GlobalState { this.stateSubject.next(newState); this.nextMock(newState); return newState; - }); - - updateMock = this.update as jest.MockedFunction; + } /** Tracks update values resolved by `FakeState.update` */ nextMock = jest.fn(); @@ -128,8 +126,6 @@ export class FakeSingleUserState implements SingleUserState { return newState; } - updateMock = this.update as jest.MockedFunction; - /** Tracks update values resolved by `FakeState.update` */ nextMock = jest.fn(); private _keyDefinition: UserKeyDefinition | null = null; @@ -191,8 +187,6 @@ export class FakeActiveUserState implements ActiveUserState { return [this.userId, newState]; } - updateMock = this.update as jest.MockedFunction; - /** Tracks update values resolved by `FakeState.update` */ nextMock = jest.fn(); diff --git a/libs/common/src/platform/services/storage-service.provider.spec.ts b/libs/common/src/platform/services/storage-service.provider.spec.ts new file mode 100644 index 0000000000..35f45064d4 --- /dev/null +++ b/libs/common/src/platform/services/storage-service.provider.spec.ts @@ -0,0 +1,28 @@ +import { mock } from "jest-mock-extended"; + +import { AbstractStorageService, ObservableStorageService } from "../abstractions/storage.service"; + +import { StorageServiceProvider } from "./storage-service.provider"; + +describe("StorageServiceProvider", () => { + const mockDiskStorage = mock(); + const mockMemoryStorage = mock(); + + const sut = new StorageServiceProvider(mockDiskStorage, mockMemoryStorage); + + describe("get", () => { + it("gets disk service when default location is disk", () => { + const [computedLocation, computedService] = sut.get("disk", {}); + + expect(computedLocation).toBe("disk"); + expect(computedService).toStrictEqual(mockDiskStorage); + }); + + it("gets memory service when default location is memory", () => { + const [computedLocation, computedService] = sut.get("memory", {}); + + expect(computedLocation).toBe("memory"); + expect(computedService).toStrictEqual(mockMemoryStorage); + }); + }); +}); diff --git a/libs/common/src/platform/services/storage-service.provider.ts b/libs/common/src/platform/services/storage-service.provider.ts new file mode 100644 index 0000000000..c34487403c --- /dev/null +++ b/libs/common/src/platform/services/storage-service.provider.ts @@ -0,0 +1,39 @@ +import { AbstractStorageService, ObservableStorageService } from "../abstractions/storage.service"; +// eslint-disable-next-line import/no-restricted-paths +import { ClientLocations, StorageLocation } from "../state/state-definition"; + +export type PossibleLocation = StorageLocation | ClientLocations[keyof ClientLocations]; + +/** + * A provider for getting client specific computed storage locations and services. + */ +export class StorageServiceProvider { + constructor( + protected readonly diskStorageService: AbstractStorageService & ObservableStorageService, + protected readonly memoryStorageService: AbstractStorageService & ObservableStorageService, + ) {} + + /** + * Computes the location and corresponding service for a given client. + * + * **NOTE** The default implementation does not respect client overrides and if clients + * have special overrides they are responsible for implementing this service. + * @param defaultLocation The default location to use if no client specific override is preferred. + * @param overrides Client specific overrides + * @returns The computed storage location and corresponding storage service to use to get/store state. + * @throws If there is no configured storage service for the given inputs. + */ + get( + defaultLocation: PossibleLocation, + overrides: Partial, + ): [location: PossibleLocation, service: AbstractStorageService & ObservableStorageService] { + switch (defaultLocation) { + case "disk": + return [defaultLocation, this.diskStorageService]; + case "memory": + return [defaultLocation, this.memoryStorageService]; + default: + throw new Error(`Unexpected location: ${defaultLocation}`); + } + } +} diff --git a/libs/common/src/platform/state/index.ts b/libs/common/src/platform/state/index.ts index 72f3aa155f..b457e14c2f 100644 --- a/libs/common/src/platform/state/index.ts +++ b/libs/common/src/platform/state/index.ts @@ -9,5 +9,6 @@ export { ActiveUserStateProvider, SingleUserStateProvider } from "./user-state.p export { KeyDefinition } from "./key-definition"; export { StateUpdateOptions } from "./state-update-options"; export { UserKeyDefinition } from "./user-key-definition"; +export { StateEventRunnerService } from "./state-event-runner.service"; export * from "./state-definitions"; diff --git a/libs/common/src/platform/state/state-definitions.ts b/libs/common/src/platform/state/state-definitions.ts index 0962e7b37a..6b9761481d 100644 --- a/libs/common/src/platform/state/state-definitions.ts +++ b/libs/common/src/platform/state/state-definitions.ts @@ -65,6 +65,8 @@ export const VAULT_FILTER_DISK = new StateDefinition("vaultFilter", "disk", { web: "disk-local", }); +export const CLEAR_EVENT_DISK = new StateDefinition("clearEvent", "disk"); + export const NEW_WEB_LAYOUT_BANNER_DISK = new StateDefinition("newWebLayoutBanner", "disk", { web: "disk-local", }); diff --git a/libs/common/src/platform/state/state-event-registrar.service.spec.ts b/libs/common/src/platform/state/state-event-registrar.service.spec.ts new file mode 100644 index 0000000000..2fae985033 --- /dev/null +++ b/libs/common/src/platform/state/state-event-registrar.service.spec.ts @@ -0,0 +1,85 @@ +import { mock } from "jest-mock-extended"; + +import { FakeGlobalStateProvider } from "../../../spec"; +import { AbstractStorageService, ObservableStorageService } from "../abstractions/storage.service"; +import { StorageServiceProvider } from "../services/storage-service.provider"; + +import { StateDefinition } from "./state-definition"; +import { STATE_LOCK_EVENT, StateEventRegistrarService } from "./state-event-registrar.service"; +import { UserKeyDefinition } from "./user-key-definition"; + +describe("StateEventRegistrarService", () => { + const globalStateProvider = new FakeGlobalStateProvider(); + const lockState = globalStateProvider.getFake(STATE_LOCK_EVENT); + const storageServiceProvider = mock(); + + const sut = new StateEventRegistrarService(globalStateProvider, storageServiceProvider); + + describe("registerEvents", () => { + const fakeKeyDefinition = new UserKeyDefinition( + new StateDefinition("fakeState", "disk"), + "fakeKey", + { + deserializer: (s) => s, + clearOn: ["lock"], + }, + ); + + beforeEach(() => { + jest.resetAllMocks(); + }); + + it("adds event on null storage", async () => { + storageServiceProvider.get.mockReturnValue([ + "disk", + mock(), + ]); + + await sut.registerEvents(fakeKeyDefinition); + + expect(lockState.nextMock).toHaveBeenCalledWith([ + { + key: "fakeKey", + location: "disk", + state: "fakeState", + }, + ]); + }); + + it("adds event on empty array in storage", async () => { + lockState.stateSubject.next([]); + storageServiceProvider.get.mockReturnValue([ + "disk", + mock(), + ]); + + await sut.registerEvents(fakeKeyDefinition); + + expect(lockState.nextMock).toHaveBeenCalledWith([ + { + key: "fakeKey", + location: "disk", + state: "fakeState", + }, + ]); + }); + + it("doesn't add a duplicate", async () => { + lockState.stateSubject.next([ + { + key: "fakeKey", + location: "disk", + state: "fakeState", + }, + ]); + storageServiceProvider.get.mockReturnValue([ + "disk", + mock(), + ]); + + await sut.registerEvents(fakeKeyDefinition); + + expect(lockState.nextMock).not.toHaveBeenCalled(); + }); + }); +}); diff --git a/libs/common/src/platform/state/state-event-registrar.service.ts b/libs/common/src/platform/state/state-event-registrar.service.ts new file mode 100644 index 0000000000..e74d46d3b7 --- /dev/null +++ b/libs/common/src/platform/state/state-event-registrar.service.ts @@ -0,0 +1,76 @@ +import { PossibleLocation, StorageServiceProvider } from "../services/storage-service.provider"; + +import { GlobalState } from "./global-state"; +import { GlobalStateProvider } from "./global-state.provider"; +import { KeyDefinition } from "./key-definition"; +import { CLEAR_EVENT_DISK } from "./state-definitions"; +import { ClearEvent, UserKeyDefinition } from "./user-key-definition"; + +export type StateEventInfo = { + state: string; + key: string; + location: PossibleLocation; +}; + +export const STATE_LOCK_EVENT = KeyDefinition.array(CLEAR_EVENT_DISK, "lock", { + deserializer: (e) => e, +}); + +export const STATE_LOGOUT_EVENT = KeyDefinition.array(CLEAR_EVENT_DISK, "logout", { + deserializer: (e) => e, +}); + +export class StateEventRegistrarService { + private readonly stateEventStateMap: { [Prop in ClearEvent]: GlobalState }; + + constructor( + globalStateProvider: GlobalStateProvider, + private storageServiceProvider: StorageServiceProvider, + ) { + this.stateEventStateMap = { + lock: globalStateProvider.get(STATE_LOCK_EVENT), + logout: globalStateProvider.get(STATE_LOGOUT_EVENT), + }; + } + + async registerEvents(keyDefinition: UserKeyDefinition) { + for (const clearEvent of keyDefinition.clearOn) { + const eventState = this.stateEventStateMap[clearEvent]; + // Determine the storage location for this + const [storageLocation] = this.storageServiceProvider.get( + keyDefinition.stateDefinition.defaultStorageLocation, + keyDefinition.stateDefinition.storageLocationOverrides, + ); + + const newEvent: StateEventInfo = { + state: keyDefinition.stateDefinition.name, + key: keyDefinition.key, + location: storageLocation, + }; + + // Only update the event state if the existing list doesn't have a matching entry + await eventState.update( + (existingTickets) => { + existingTickets ??= []; + existingTickets.push(newEvent); + return existingTickets; + }, + { + shouldUpdate: (currentTickets) => { + return ( + // If the current tickets are null, then it will for sure be added + currentTickets == null || + // If an existing match couldn't be found, we also need to add one + currentTickets.findIndex( + (e) => + e.state === newEvent.state && + e.key === newEvent.key && + e.location === newEvent.location, + ) === -1 + ); + }, + }, + ); + } + } +} diff --git a/libs/common/src/platform/state/state-event-runner.service.spec.ts b/libs/common/src/platform/state/state-event-runner.service.spec.ts new file mode 100644 index 0000000000..1c98099a51 --- /dev/null +++ b/libs/common/src/platform/state/state-event-runner.service.spec.ts @@ -0,0 +1,69 @@ +import { mock } from "jest-mock-extended"; + +import { FakeGlobalStateProvider } from "../../../spec"; +import { UserId } from "../../types/guid"; +import { AbstractStorageService, ObservableStorageService } from "../abstractions/storage.service"; +import { StorageServiceProvider } from "../services/storage-service.provider"; + +import { STATE_LOCK_EVENT } from "./state-event-registrar.service"; +import { StateEventRunnerService } from "./state-event-runner.service"; + +describe("EventRunnerService", () => { + const fakeGlobalStateProvider = new FakeGlobalStateProvider(); + const lockState = fakeGlobalStateProvider.getFake(STATE_LOCK_EVENT); + + const storageServiceProvider = mock(); + + const sut = new StateEventRunnerService(fakeGlobalStateProvider, storageServiceProvider); + + describe("handleEvent", () => { + it("does nothing if there are no events in state", async () => { + const mockStorageService = mock(); + storageServiceProvider.get.mockReturnValue(["disk", mockStorageService]); + + await sut.handleEvent("lock", "bff09d3c-762a-4551-9275-45b137b2f073" as UserId); + + expect(lockState.nextMock).not.toHaveBeenCalled(); + }); + + it("loops through and acts on all events", async () => { + const mockDiskStorageService = mock(); + const mockMemoryStorageService = mock(); + + lockState.stateSubject.next([ + { + state: "fakeState1", + key: "fakeKey1", + location: "disk", + }, + { + state: "fakeState2", + key: "fakeKey2", + location: "memory", + }, + ]); + + storageServiceProvider.get.mockImplementation((defaultLocation, overrides) => { + if (defaultLocation === "disk") { + return [defaultLocation, mockDiskStorageService]; + } else if (defaultLocation === "memory") { + return [defaultLocation, mockMemoryStorageService]; + } + }); + + mockMemoryStorageService.get.mockResolvedValue("something"); + + await sut.handleEvent("lock", "bff09d3c-762a-4551-9275-45b137b2f073" as UserId); + + expect(mockDiskStorageService.get).toHaveBeenCalledTimes(1); + expect(mockDiskStorageService.get).toHaveBeenCalledWith( + "user_bff09d3c-762a-4551-9275-45b137b2f073_fakeState1_fakeKey1", + ); + expect(mockMemoryStorageService.get).toHaveBeenCalledTimes(1); + expect(mockMemoryStorageService.get).toHaveBeenCalledWith( + "user_bff09d3c-762a-4551-9275-45b137b2f073_fakeState2_fakeKey2", + ); + expect(mockMemoryStorageService.remove).toHaveBeenCalledTimes(1); + }); + }); +}); diff --git a/libs/common/src/platform/state/state-event-runner.service.ts b/libs/common/src/platform/state/state-event-runner.service.ts new file mode 100644 index 0000000000..8fcc0710da --- /dev/null +++ b/libs/common/src/platform/state/state-event-runner.service.ts @@ -0,0 +1,80 @@ +import { firstValueFrom } from "rxjs"; + +import { UserId } from "../../types/guid"; +import { StorageServiceProvider } from "../services/storage-service.provider"; + +import { GlobalState } from "./global-state"; +import { GlobalStateProvider } from "./global-state.provider"; +import { StateDefinition, StorageLocation } from "./state-definition"; +import { + STATE_LOCK_EVENT, + STATE_LOGOUT_EVENT, + StateEventInfo, +} from "./state-event-registrar.service"; +import { ClearEvent, UserKeyDefinition } from "./user-key-definition"; + +export class StateEventRunnerService { + private readonly stateEventMap: { [Prop in ClearEvent]: GlobalState }; + + constructor( + globalStateProvider: GlobalStateProvider, + private storageServiceProvider: StorageServiceProvider, + ) { + this.stateEventMap = { + lock: globalStateProvider.get(STATE_LOCK_EVENT), + logout: globalStateProvider.get(STATE_LOGOUT_EVENT), + }; + } + + async handleEvent(event: ClearEvent, userId: UserId) { + let tickets = await firstValueFrom(this.stateEventMap[event].state$); + tickets ??= []; + + const failures: string[] = []; + + for (const ticket of tickets) { + try { + const [, service] = this.storageServiceProvider.get( + ticket.location, + {}, // The storage location is already the computed storage location for this client + ); + + const ticketStorageKey = this.storageKeyFor(userId, ticket); + + // Evaluate current value so we can avoid writing to state if we don't need to + const currentValue = await service.get(ticketStorageKey); + if (currentValue != null) { + await service.remove(ticketStorageKey); + } + } catch (err: unknown) { + let errorMessage = "Unknown Error"; + if (typeof err === "object" && "message" in err && typeof err.message === "string") { + errorMessage = err.message; + } + + failures.push( + `${errorMessage} in ${ticket.state} > ${ticket.key} located ${ticket.location}`, + ); + } + } + + if (failures.length > 0) { + // Throw aggregated error + throw new Error( + `One or more errors occurred while handling event '${event}' for user ${userId}.\n${failures.join("\n")}`, + ); + } + } + + private storageKeyFor(userId: UserId, ticket: StateEventInfo) { + const userKey = new UserKeyDefinition( + new StateDefinition(ticket.state, ticket.location as unknown as StorageLocation), + ticket.key, + { + deserializer: (v) => v, + clearOn: [], + }, + ); + return userKey.buildKey(userId); + } +} diff --git a/libs/common/src/platform/state/user-key-definition.ts b/libs/common/src/platform/state/user-key-definition.ts index 242ca6c1ca..99e3039e1e 100644 --- a/libs/common/src/platform/state/user-key-definition.ts +++ b/libs/common/src/platform/state/user-key-definition.ts @@ -6,7 +6,7 @@ import { array, record } from "./deserialization-helpers"; import { KeyDefinition, KeyDefinitionOptions } from "./key-definition"; import { StateDefinition } from "./state-definition"; -type ClearEvent = "lock" | "logout"; +export type ClearEvent = "lock" | "logout"; type UserKeyDefinitionOptions = KeyDefinitionOptions & { clearOn: ClearEvent[];