diff --git a/Development Assets/DevelopmentModels.swift b/Development Assets/DevelopmentModels.swift index 6b521fb..67a14c5 100644 --- a/Development Assets/DevelopmentModels.swift +++ b/Development Assets/DevelopmentModels.swift @@ -56,7 +56,16 @@ extension IdentityDatabase { } extension Identity { - static let development = try! IdentityDatabase.development.identity(id: devIdentityID)! + static let development: Identity = { + var identity: Identity? + + IdentityDatabase.development.identityObservation(id: devIdentityID) + .assertNoFailure() + .sink(receiveValue: { identity = $0 }) + .store(in: &cancellables) + + return identity! + }() } extension SceneViewModel { diff --git a/Shared/Extensions/Publisher+Extensions.swift b/Shared/Extensions/Publisher+Extensions.swift index a52c111..35a3e30 100644 --- a/Shared/Extensions/Publisher+Extensions.swift +++ b/Shared/Extensions/Publisher+Extensions.swift @@ -4,12 +4,12 @@ import Foundation import Combine extension Publisher { - func assignErrorsToAlertItem( + func assignErrorsToAlertItem( to keyPath: ReferenceWritableKeyPath, on object: Root) -> AnyPublisher { - self.catch { error -> AnyPublisher in + self.catch { [weak object] error -> AnyPublisher in DispatchQueue.main.async { - object[keyPath: keyPath] = AlertItem(error: error) + object?[keyPath: keyPath] = AlertItem(error: error) } return Empty().eraseToAnyPublisher() diff --git a/Shared/Model/IdentityDatabase.swift b/Shared/Model/IdentityDatabase.swift index 4e6e30a..536a8ba 100644 --- a/Shared/Model/IdentityDatabase.swift +++ b/Shared/Model/IdentityDatabase.swift @@ -26,16 +26,12 @@ struct IdentityDatabase { } extension IdentityDatabase { - func createIdentity(id: String, url: URL) -> AnyPublisher { - databaseQueue.writePublisher { - try StoredIdentity(id: id, url: url, instanceURI: nil).save($0) - - return Identity(id: id, url: url, instance: nil, account: nil) - } - .eraseToAnyPublisher() + func createIdentity(id: String, url: URL) -> AnyPublisher { + databaseQueue.writePublisher(updates: StoredIdentity(id: id, url: url, instanceURI: nil).save) + .eraseToAnyPublisher() } - func updateInstance(_ instance: Instance, forIdentityID identityID: String) -> AnyPublisher { + func updateInstance(_ instance: Instance, forIdentityID identityID: String) -> AnyPublisher { databaseQueue.writePublisher { try Identity.Instance( uri: instance.uri, @@ -46,15 +42,13 @@ extension IdentityDatabase { try StoredIdentity .filter(Column("id") == identityID) .updateAll($0, Column("instanceURI").set(to: instance.uri)) - - return try Self.fetchIdentity(id: identityID, db: $0) } .eraseToAnyPublisher() } - func updateAccount(_ account: Account, forIdentityID identityID: String) -> AnyPublisher { - databaseQueue.writePublisher { - try Identity.Account( + func updateAccount(_ account: Account, forIdentityID identityID: String) -> AnyPublisher { + databaseQueue.writePublisher( + updates: Identity.Account( id: account.id, identityID: identityID, username: account.username, @@ -63,15 +57,26 @@ extension IdentityDatabase { avatarStatic: account.avatarStatic, header: account.header, headerStatic: account.headerStatic) - .save($0) - - return try Self.fetchIdentity(id: identityID, db: $0) - } - .eraseToAnyPublisher() + .save) + .eraseToAnyPublisher() } - func identity(id: String) throws -> Identity? { - try databaseQueue.read { try Self.fetchIdentity(id: id, db: $0) } + func identityObservation(id: String) -> AnyPublisher { + ValueObservation.tracking( + StoredIdentity + .filter(Column("id") == id) + .including(optional: StoredIdentity.instance) + .including(optional: StoredIdentity.account) + .asRequest(of: IdentityResult.self) + .fetchOne) + .removeDuplicates() + .publisher(in: databaseQueue, scheduling: .immediate) + .map { + guard let result = $0 else { return nil } + + return Identity(result: result) + } + .eraseToAnyPublisher() } } @@ -112,19 +117,6 @@ private extension IdentityDatabase { try migrator.migrate(writer) } - - private static func fetchIdentity(id: String, db: Database) throws -> Identity? { - if let result = try StoredIdentity - .filter(Column("id") == id) - .including(optional: StoredIdentity.instance) - .including(optional: StoredIdentity.account) - .asRequest(of: IdentityResult.self) - .fetchOne(db) { - return Identity(result: result) - } - - return nil - } } private struct StoredIdentity: Codable, Hashable, TableRecord, FetchableRecord, PersistableRecord { diff --git a/Shared/View Models/AddIdentityViewModel.swift b/Shared/View Models/AddIdentityViewModel.swift index 32ec2b5..71718fe 100644 --- a/Shared/View Models/AddIdentityViewModel.swift +++ b/Shared/View Models/AddIdentityViewModel.swift @@ -8,15 +8,13 @@ class AddIdentityViewModel: ObservableObject { @Published var urlFieldText = "" @Published var alertItem: AlertItem? @Published private(set) var loading = false - private(set) var addedIdentity: AnyPublisher + @Published private(set) var addedIdentityID: String? private let networkClient: HTTPClient private let identityDatabase: IdentityDatabase private let secrets: Secrets private let webAuthenticationSessionType: WebAuthenticationSessionType.Type private let webAuthenticationSessionContextProvider = WebAuthenticationSessionContextProvider() - private let addedIdentityInput = PassthroughSubject() - private var cancellables = Set() init( networkClient: HTTPClient, @@ -27,7 +25,6 @@ class AddIdentityViewModel: ObservableObject { self.identityDatabase = identityDatabase self.secrets = secrets self.webAuthenticationSessionType = webAuthenticationSessionType - addedIdentity = addedIdentityInput.eraseToAnyPublisher() } func goTapped() { @@ -65,11 +62,12 @@ class AddIdentityViewModel: ObservableObject { identityDatabase: identityDatabase, secrets: secrets) .assignErrorsToAlertItem(to: \.alertItem, on: self) + .receive(on: RunLoop.main) .handleEvents( receiveSubscription: { [weak self] _ in self?.loading = true }, receiveCompletion: { [weak self] _ in self?.loading = false }) - .sink(receiveValue: addedIdentityInput.send) - .store(in: &cancellables) + .map { $0 as String? } + .assign(to: &$addedIdentityID) } } @@ -196,13 +194,14 @@ private extension Publisher where Output == AccessToken { id: String, instanceURL: URL, identityDatabase: IdentityDatabase, - secrets: Secrets) -> AnyPublisher { + secrets: Secrets) -> AnyPublisher { tryMap { accessToken -> (String, URL) in try secrets.set(accessToken.accessToken, forItem: .accessToken, forIdentityID: id) return (id, instanceURL) } .flatMap(identityDatabase.createIdentity) + .map { id } .eraseToAnyPublisher() } } diff --git a/Shared/View Models/SceneViewModel.swift b/Shared/View Models/SceneViewModel.swift index 510f63a..4565352 100644 --- a/Shared/View Models/SceneViewModel.swift +++ b/Shared/View Models/SceneViewModel.swift @@ -4,21 +4,7 @@ import Foundation import Combine class SceneViewModel: ObservableObject { - @Published private(set) var identity: Identity? { - didSet { - if let identity = identity { - recentIdentityID = identity.id - networkClient.instanceURL = identity.url - - do { - networkClient.accessToken = try secrets.item(.accessToken, forIdentityID: identity.id) - } catch { - alertItem = AlertItem(error: error) - } - } - } - } - + @Published private(set) var identity: Identity? @Published var alertItem: AlertItem? @Published var presentingSettings = false var selectedTopLevelNavigation: TopLevelNavigation? = .timelines @@ -39,8 +25,7 @@ class SceneViewModel: ObservableObject { self.userDefaults = userDefaults if let recentIdentityID = recentIdentityID { - identity = try? identityDatabase.identity(id: recentIdentityID) - refreshIdentity() + changeIdentity(id: recentIdentityID) } } } @@ -54,7 +39,7 @@ extension SceneViewModel { .map { ($0, identity.id) } .flatMap(identityDatabase.updateAccount) .assignErrorsToAlertItem(to: \.alertItem, on: self) - .assign(to: \.identity, on: self) + .sink(receiveValue: {}) .store(in: &cancellables) } @@ -62,7 +47,7 @@ extension SceneViewModel { .map { ($0, identity.id) } .flatMap(identityDatabase.updateInstance) .assignErrorsToAlertItem(to: \.alertItem, on: self) - .assign(to: \.identity, on: self) + .sink(receiveValue: {}) .store(in: &cancellables) } @@ -72,8 +57,9 @@ extension SceneViewModel { identityDatabase: identityDatabase, secrets: secrets) - addAccountViewModel.addedIdentity - .sink(receiveValue: addIdentity(_:)) + addAccountViewModel.$addedIdentityID + .compactMap { $0 } + .sink(receiveValue: changeIdentity(id:)) .store(in: &cancellables) return addAccountViewModel @@ -88,8 +74,23 @@ private extension SceneViewModel { set { userDefaults.set(newValue, forKey: Self.recentIdentityIDKey) } } - private func addIdentity(_ identity: Identity) { - self.identity = identity + private func changeIdentity(id: String) { + identityDatabase.identityObservation(id: id) + .assignErrorsToAlertItem(to: \.alertItem, on: self) + .handleEvents(receiveOutput: { [weak self] in + guard let self = self, let identity = $0 else { return } + + self.recentIdentityID = identity.id + self.networkClient.instanceURL = identity.url + + do { + self.networkClient.accessToken = try self.secrets.item(.accessToken, forIdentityID: identity.id) + } catch { + self.alertItem = AlertItem(error: error) + } + }) + .assign(to: &$identity) + refreshIdentity() } } diff --git a/Tests/View Models/AddIdentityViewModelTests.swift b/Tests/View Models/AddIdentityViewModelTests.swift index c05fcfd..f84b55b 100644 --- a/Tests/View Models/AddIdentityViewModelTests.swift +++ b/Tests/View Models/AddIdentityViewModelTests.swift @@ -23,23 +23,26 @@ class AddIdentityViewModelTests: XCTestCase { identityDatabase: identityDatabase, secrets: secrets, webAuthenticationSessionType: SuccessfulStubbingWebAuthenticationSession.self) - let recorder = sut.addedIdentity.record() + let addedIDRecorder = sut.$addedIdentityID.record() + _ = try wait(for: addedIDRecorder.next(), timeout: 1) sut.urlFieldText = "https://mastodon.social" sut.goTapped() - let addedIdentity = try wait(for: recorder.next(), timeout: 1) + let addedIdentityID = try wait(for: addedIDRecorder.next(), timeout: 1)! + let identityRecorder = identityDatabase.identityObservation(id: addedIdentityID).record() + let addedIdentity = try wait(for: identityRecorder.next(), timeout: 1)! - XCTAssertEqual(try identityDatabase.identity(id: addedIdentity.id), addedIdentity) + XCTAssertEqual(addedIdentity.id, addedIdentityID) XCTAssertEqual(addedIdentity.url, URL(string: "https://mastodon.social")!) XCTAssertEqual( - try secrets.item(.clientID, forIdentityID: addedIdentity.id) as String?, + try secrets.item(.clientID, forIdentityID: addedIdentityID) as String?, "AUTHORIZATION_CLIENT_ID_STUB_VALUE") XCTAssertEqual( - try secrets.item(.clientSecret, forIdentityID: addedIdentity.id) as String?, + try secrets.item(.clientSecret, forIdentityID: addedIdentityID) as String?, "AUTHORIZATION_CLIENT_SECRET_STUB_VALUE") XCTAssertEqual( - try secrets.item(.accessToken, forIdentityID: addedIdentity.id) as String?, + try secrets.item(.accessToken, forIdentityID: addedIdentityID) as String?, "ACCESS_TOKEN_STUB_VALUE") } @@ -49,14 +52,16 @@ class AddIdentityViewModelTests: XCTestCase { identityDatabase: identityDatabase, secrets: secrets, webAuthenticationSessionType: SuccessfulStubbingWebAuthenticationSession.self) - let recorder = sut.addedIdentity.record() + let addedIDRecorder = sut.$addedIdentityID.record() + _ = try wait(for: addedIDRecorder.next(), timeout: 1) sut.urlFieldText = "mastodon.social" sut.goTapped() - let addedIdentity = try wait(for: recorder.next(), timeout: 1) + let addedIdentityID = try wait(for: addedIDRecorder.next(), timeout: 1)! + let identityRecorder = identityDatabase.identityObservation(id: addedIdentityID).record() + let addedIdentity = try wait(for: identityRecorder.next(), timeout: 1)! - XCTAssertEqual(try identityDatabase.identity(id: addedIdentity.id), addedIdentity) XCTAssertEqual(addedIdentity.url, URL(string: "https://mastodon.social")!) }