From 8bbf7258bea7b8560e48f51d8b62f9d27dde904f Mon Sep 17 00:00:00 2001 From: Adam Brown Date: Sun, 9 Oct 2022 19:38:33 +0100 Subject: [PATCH] porting media decryption to the chat engine --- .../kotlin/app/dapk/st/graph/AppModule.kt | 3 -- .../kotlin/app/dapk/st/engine/ChatEngine.kt | 14 ++++++- .../dapk/st/messenger/DecryptingFetcher.kt | 13 ++++--- .../app/dapk/st/messenger/MessengerModule.kt | 4 +- .../app/dapk/st/settings/SettingsViewModel.kt | 13 ------- .../kotlin/app/dapk/st/engine/MatrixEngine.kt | 37 +++++++++++++------ ...iaDecrypter.kt => MatrixMediaDecrypter.kt} | 2 +- .../app/dapk/st/matrix/sync/SyncService.kt | 2 +- .../sync/internal/DefaultSyncService.kt | 2 +- test-harness/src/test/kotlin/test/Test.kt | 4 +- 10 files changed, 51 insertions(+), 43 deletions(-) rename matrix/services/crypto/src/main/kotlin/app/dapk/st/matrix/crypto/{MediaDecrypter.kt => MatrixMediaDecrypter.kt} (96%) diff --git a/app/src/main/kotlin/app/dapk/st/graph/AppModule.kt b/app/src/main/kotlin/app/dapk/st/graph/AppModule.kt index e904613..87c181d 100644 --- a/app/src/main/kotlin/app/dapk/st/graph/AppModule.kt +++ b/app/src/main/kotlin/app/dapk/st/graph/AppModule.kt @@ -179,7 +179,6 @@ internal class FeatureModules internal constructor( MessengerModule( matrixModules.engine, context, - base64, storeModule.value.messageStore(), ) } @@ -466,10 +465,8 @@ internal class MatrixModules( val push by unsafeLazy { matrix.pushService() } val sync by unsafeLazy { matrix.syncService() } - val message by unsafeLazy { matrix.messageService() } val room by unsafeLazy { matrix.roomService() } val profile by unsafeLazy { matrix.profileService() } - val crypto by unsafeLazy { matrix.cryptoService() } } internal class DomainModules( diff --git a/chat-engine/src/main/kotlin/app/dapk/st/engine/ChatEngine.kt b/chat-engine/src/main/kotlin/app/dapk/st/engine/ChatEngine.kt index a47c76e..1d0544f 100644 --- a/chat-engine/src/main/kotlin/app/dapk/st/engine/ChatEngine.kt +++ b/chat-engine/src/main/kotlin/app/dapk/st/engine/ChatEngine.kt @@ -16,9 +16,19 @@ interface ChatEngine { suspend fun me(forceRefresh: Boolean): Me - suspend fun refresh(roomIds: List) - suspend fun InputStream.importRoomKeys(password: String): Flow suspend fun send(message: SendMessage, room: RoomOverview) + + fun mediaDecrypter(): MediaDecrypter } + +interface MediaDecrypter { + + fun decrypt(input: InputStream, k: String, iv: String): Collector + + fun interface Collector { + fun collect(partial: (ByteArray) -> Unit) + } + +} \ No newline at end of file diff --git a/features/messenger/src/main/kotlin/app/dapk/st/messenger/DecryptingFetcher.kt b/features/messenger/src/main/kotlin/app/dapk/st/messenger/DecryptingFetcher.kt index a345ffc..d77e319 100644 --- a/features/messenger/src/main/kotlin/app/dapk/st/messenger/DecryptingFetcher.kt +++ b/features/messenger/src/main/kotlin/app/dapk/st/messenger/DecryptingFetcher.kt @@ -2,10 +2,9 @@ package app.dapk.st.messenger import android.content.Context import android.os.Environment -import app.dapk.st.core.Base64 +import app.dapk.st.engine.MediaDecrypter +import app.dapk.st.engine.RoomEvent import app.dapk.st.matrix.common.RoomId -import app.dapk.st.matrix.crypto.MediaDecrypter -import app.dapk.st.matrix.sync.RoomEvent import coil.ImageLoader import coil.decode.DataSource import coil.decode.ImageSource @@ -20,9 +19,11 @@ import okio.BufferedSource import okio.Path.Companion.toOkioPath import java.io.File -class DecryptingFetcherFactory(private val context: Context, base64: Base64, private val roomId: RoomId) : Fetcher.Factory { - - private val mediaDecrypter = MediaDecrypter(base64) +class DecryptingFetcherFactory( + private val context: Context, + private val roomId: RoomId, + private val mediaDecrypter: MediaDecrypter, +) : Fetcher.Factory { override fun create(data: RoomEvent.Image, options: Options, imageLoader: ImageLoader): Fetcher { return DecryptingFetcher(data, context, mediaDecrypter, roomId) diff --git a/features/messenger/src/main/kotlin/app/dapk/st/messenger/MessengerModule.kt b/features/messenger/src/main/kotlin/app/dapk/st/messenger/MessengerModule.kt index 85b10ac..e46f37a 100644 --- a/features/messenger/src/main/kotlin/app/dapk/st/messenger/MessengerModule.kt +++ b/features/messenger/src/main/kotlin/app/dapk/st/messenger/MessengerModule.kt @@ -1,7 +1,6 @@ package app.dapk.st.messenger import android.content.Context -import app.dapk.st.core.Base64 import app.dapk.st.core.ProvidableModule import app.dapk.st.domain.application.message.MessageOptionsStore import app.dapk.st.engine.ChatEngine @@ -10,7 +9,6 @@ import app.dapk.st.matrix.common.RoomId class MessengerModule( private val chatEngine: ChatEngine, private val context: Context, - private val base64: Base64, private val messageOptionsStore: MessageOptionsStore, ) : ProvidableModule { @@ -21,5 +19,5 @@ class MessengerModule( ) } - internal fun decryptingFetcherFactory(roomId: RoomId) = DecryptingFetcherFactory(context, base64, roomId) + internal fun decryptingFetcherFactory(roomId: RoomId) = DecryptingFetcherFactory(context, roomId, chatEngine.mediaDecrypter()) } \ No newline at end of file diff --git a/features/settings/src/main/kotlin/app/dapk/st/settings/SettingsViewModel.kt b/features/settings/src/main/kotlin/app/dapk/st/settings/SettingsViewModel.kt index a05783d..b74daf8 100644 --- a/features/settings/src/main/kotlin/app/dapk/st/settings/SettingsViewModel.kt +++ b/features/settings/src/main/kotlin/app/dapk/st/settings/SettingsViewModel.kt @@ -147,19 +147,6 @@ internal class SettingsViewModel( fileStream.importRoomKeys(passphrase) .onEach { updatePageState { copy(importProgress = it) } - when (it) { - is ImportResult.Error -> { - // do nothing - } - - is ImportResult.Update -> { - // do nothing - } - - is ImportResult.Success -> { - chatEngine.refresh(it.roomIds.toList()) - } - } } .launchIn(viewModelScope) }, diff --git a/matrix-chat-engine/src/main/kotlin/app/dapk/st/engine/MatrixEngine.kt b/matrix-chat-engine/src/main/kotlin/app/dapk/st/engine/MatrixEngine.kt index 6312713..305aa56 100644 --- a/matrix-chat-engine/src/main/kotlin/app/dapk/st/engine/MatrixEngine.kt +++ b/matrix-chat-engine/src/main/kotlin/app/dapk/st/engine/MatrixEngine.kt @@ -10,10 +10,7 @@ import app.dapk.st.matrix.auth.DeviceDisplayNameGenerator import app.dapk.st.matrix.auth.authService import app.dapk.st.matrix.auth.installAuthService import app.dapk.st.matrix.common.* -import app.dapk.st.matrix.crypto.RoomMembersProvider -import app.dapk.st.matrix.crypto.Verification -import app.dapk.st.matrix.crypto.cryptoService -import app.dapk.st.matrix.crypto.installCryptoService +import app.dapk.st.matrix.crypto.* import app.dapk.st.matrix.device.KnownDeviceStore import app.dapk.st.matrix.device.deviceService import app.dapk.st.matrix.device.installEncryptionService @@ -30,6 +27,7 @@ import app.dapk.st.olm.OlmStore import app.dapk.st.olm.OlmWrapper import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.map +import kotlinx.coroutines.flow.onEach import java.io.InputStream import java.time.Clock @@ -38,6 +36,7 @@ class MatrixEngine internal constructor( private val matrix: Lazy, private val timelineUseCase: Lazy, private val sendMessageUseCase: Lazy, + private val matrixMediaDecrypter: Lazy, ) : ChatEngine { override fun directory() = directoryUseCase.value.state() @@ -57,14 +56,18 @@ class MatrixEngine internal constructor( return matrix.value.profileService().me(forceRefresh).engine() } - override suspend fun refresh(roomIds: List) { - matrix.value.syncService().forceManualRefresh(roomIds) - - } - override suspend fun InputStream.importRoomKeys(password: String): Flow { return with(matrix.value.cryptoService()) { - importRoomKeys(password).map { it.engine() } + importRoomKeys(password).map { it.engine() }.onEach { + when (it) { + is ImportResult.Error, + is ImportResult.Update -> { + // do nothing + } + + is ImportResult.Success -> matrix.value.syncService().forceManualRefresh(it.roomIds) + } + } } } @@ -72,6 +75,17 @@ class MatrixEngine internal constructor( sendMessageUseCase.value.send(message, room) } + override fun mediaDecrypter(): MediaDecrypter { + val mediaDecrypter = matrixMediaDecrypter.value + return object : MediaDecrypter { + override fun decrypt(input: InputStream, k: String, iv: String): MediaDecrypter.Collector { + return MediaDecrypter.Collector { + mediaDecrypter.decrypt(input, k, iv).collect(it) + } + } + } + } + class Factory { fun create( @@ -138,8 +152,9 @@ class MatrixEngine internal constructor( SendMessageUseCase(matrix.messageService(), LocalIdFactory(), imageContentReader, Clock.systemUTC()) } - return MatrixEngine(directoryUseCase, lazyMatrix, timelineUseCase, sendMessageUseCase) + val mediaDecrypter = unsafeLazy { MatrixMediaDecrypter(base64) } + return MatrixEngine(directoryUseCase, lazyMatrix, timelineUseCase, sendMessageUseCase, mediaDecrypter) } } diff --git a/matrix/services/crypto/src/main/kotlin/app/dapk/st/matrix/crypto/MediaDecrypter.kt b/matrix/services/crypto/src/main/kotlin/app/dapk/st/matrix/crypto/MatrixMediaDecrypter.kt similarity index 96% rename from matrix/services/crypto/src/main/kotlin/app/dapk/st/matrix/crypto/MediaDecrypter.kt rename to matrix/services/crypto/src/main/kotlin/app/dapk/st/matrix/crypto/MatrixMediaDecrypter.kt index df513d2..65dde9e 100644 --- a/matrix/services/crypto/src/main/kotlin/app/dapk/st/matrix/crypto/MediaDecrypter.kt +++ b/matrix/services/crypto/src/main/kotlin/app/dapk/st/matrix/crypto/MatrixMediaDecrypter.kt @@ -12,7 +12,7 @@ private const val CIPHER_ALGORITHM = "AES/CTR/NoPadding" private const val SECRET_KEY_SPEC_ALGORITHM = "AES" private const val MESSAGE_DIGEST_ALGORITHM = "SHA-256" -class MediaDecrypter(private val base64: Base64) { +class MatrixMediaDecrypter(private val base64: Base64) { fun decrypt(input: InputStream, k: String, iv: String): Collector { val key = base64.decode(k.replace('-', '+').replace('_', '/')) diff --git a/matrix/services/sync/src/main/kotlin/app/dapk/st/matrix/sync/SyncService.kt b/matrix/services/sync/src/main/kotlin/app/dapk/st/matrix/sync/SyncService.kt index 20ddf80..7ad567c 100644 --- a/matrix/services/sync/src/main/kotlin/app/dapk/st/matrix/sync/SyncService.kt +++ b/matrix/services/sync/src/main/kotlin/app/dapk/st/matrix/sync/SyncService.kt @@ -21,7 +21,7 @@ interface SyncService : MatrixService { fun startSyncing(): Flow fun events(roomId: RoomId? = null): Flow> suspend fun observeEvent(eventId: EventId): Flow - suspend fun forceManualRefresh(roomIds: List) + suspend fun forceManualRefresh(roomIds: Set) @JvmInline value class FilterId(val value: String) diff --git a/matrix/services/sync/src/main/kotlin/app/dapk/st/matrix/sync/internal/DefaultSyncService.kt b/matrix/services/sync/src/main/kotlin/app/dapk/st/matrix/sync/internal/DefaultSyncService.kt index 779c2f0..aa4c0c7 100644 --- a/matrix/services/sync/src/main/kotlin/app/dapk/st/matrix/sync/internal/DefaultSyncService.kt +++ b/matrix/services/sync/src/main/kotlin/app/dapk/st/matrix/sync/internal/DefaultSyncService.kt @@ -110,7 +110,7 @@ internal class DefaultSyncService( override fun room(roomId: RoomId) = roomStore.latest(roomId) override fun events(roomId: RoomId?) = roomId?.let { syncEventsFlow.map { it.filter { it.roomId == roomId } }.distinctUntilChanged() } ?: syncEventsFlow override suspend fun observeEvent(eventId: EventId) = roomStore.observeEvent(eventId) - override suspend fun forceManualRefresh(roomIds: List) { + override suspend fun forceManualRefresh(roomIds: Set) { coroutineDispatchers.withIoContext { roomIds.map { async { diff --git a/test-harness/src/test/kotlin/test/Test.kt b/test-harness/src/test/kotlin/test/Test.kt index 7ff5604..e0403a2 100644 --- a/test-harness/src/test/kotlin/test/Test.kt +++ b/test-harness/src/test/kotlin/test/Test.kt @@ -7,7 +7,7 @@ import TestUser import app.dapk.st.core.extensions.ifNull import app.dapk.st.matrix.common.RoomId import app.dapk.st.matrix.common.RoomMember -import app.dapk.st.matrix.crypto.MediaDecrypter +import app.dapk.st.matrix.crypto.MatrixMediaDecrypter import app.dapk.st.matrix.message.MessageService import app.dapk.st.matrix.message.messageService import app.dapk.st.matrix.sync.RoomEvent @@ -153,7 +153,7 @@ class MatrixTestScope(private val testScope: TestScope) { null -> output.readBytes().md5Hash() else -> { val byteStream = ByteArrayOutputStream() - MediaDecrypter(this.base64).decrypt(output.inputStream(), keys.k, keys.iv).collect { + MatrixMediaDecrypter(this.base64).decrypt(output.inputStream(), keys.k, keys.iv).collect { byteStream.write(it) } byteStream.toByteArray().md5Hash()