porting media decryption to the chat engine

This commit is contained in:
Adam Brown 2022-10-09 19:38:33 +01:00
parent e61dea7ba7
commit 8bbf7258be
10 changed files with 51 additions and 43 deletions

View File

@ -179,7 +179,6 @@ internal class FeatureModules internal constructor(
MessengerModule( MessengerModule(
matrixModules.engine, matrixModules.engine,
context, context,
base64,
storeModule.value.messageStore(), storeModule.value.messageStore(),
) )
} }
@ -466,10 +465,8 @@ internal class MatrixModules(
val push by unsafeLazy { matrix.pushService() } val push by unsafeLazy { matrix.pushService() }
val sync by unsafeLazy { matrix.syncService() } val sync by unsafeLazy { matrix.syncService() }
val message by unsafeLazy { matrix.messageService() }
val room by unsafeLazy { matrix.roomService() } val room by unsafeLazy { matrix.roomService() }
val profile by unsafeLazy { matrix.profileService() } val profile by unsafeLazy { matrix.profileService() }
val crypto by unsafeLazy { matrix.cryptoService() }
} }
internal class DomainModules( internal class DomainModules(

View File

@ -16,9 +16,19 @@ interface ChatEngine {
suspend fun me(forceRefresh: Boolean): Me suspend fun me(forceRefresh: Boolean): Me
suspend fun refresh(roomIds: List<RoomId>)
suspend fun InputStream.importRoomKeys(password: String): Flow<ImportResult> suspend fun InputStream.importRoomKeys(password: String): Flow<ImportResult>
suspend fun send(message: SendMessage, room: RoomOverview) suspend fun send(message: SendMessage, room: RoomOverview)
fun mediaDecrypter(): MediaDecrypter
} }
interface MediaDecrypter {
fun decrypt(input: InputStream, k: String, iv: String): Collector
fun interface Collector {
fun collect(partial: (ByteArray) -> Unit)
}
}

View File

@ -2,10 +2,9 @@ package app.dapk.st.messenger
import android.content.Context import android.content.Context
import android.os.Environment import android.os.Environment
import app.dapk.st.core.Base64 import app.dapk.st.engine.MediaDecrypter
import app.dapk.st.engine.RoomEvent
import app.dapk.st.matrix.common.RoomId import app.dapk.st.matrix.common.RoomId
import app.dapk.st.matrix.crypto.MediaDecrypter
import app.dapk.st.matrix.sync.RoomEvent
import coil.ImageLoader import coil.ImageLoader
import coil.decode.DataSource import coil.decode.DataSource
import coil.decode.ImageSource import coil.decode.ImageSource
@ -20,9 +19,11 @@ import okio.BufferedSource
import okio.Path.Companion.toOkioPath import okio.Path.Companion.toOkioPath
import java.io.File import java.io.File
class DecryptingFetcherFactory(private val context: Context, base64: Base64, private val roomId: RoomId) : Fetcher.Factory<RoomEvent.Image> { class DecryptingFetcherFactory(
private val context: Context,
private val mediaDecrypter = MediaDecrypter(base64) private val roomId: RoomId,
private val mediaDecrypter: MediaDecrypter,
) : Fetcher.Factory<RoomEvent.Image> {
override fun create(data: RoomEvent.Image, options: Options, imageLoader: ImageLoader): Fetcher { override fun create(data: RoomEvent.Image, options: Options, imageLoader: ImageLoader): Fetcher {
return DecryptingFetcher(data, context, mediaDecrypter, roomId) return DecryptingFetcher(data, context, mediaDecrypter, roomId)

View File

@ -1,7 +1,6 @@
package app.dapk.st.messenger package app.dapk.st.messenger
import android.content.Context import android.content.Context
import app.dapk.st.core.Base64
import app.dapk.st.core.ProvidableModule import app.dapk.st.core.ProvidableModule
import app.dapk.st.domain.application.message.MessageOptionsStore import app.dapk.st.domain.application.message.MessageOptionsStore
import app.dapk.st.engine.ChatEngine import app.dapk.st.engine.ChatEngine
@ -10,7 +9,6 @@ import app.dapk.st.matrix.common.RoomId
class MessengerModule( class MessengerModule(
private val chatEngine: ChatEngine, private val chatEngine: ChatEngine,
private val context: Context, private val context: Context,
private val base64: Base64,
private val messageOptionsStore: MessageOptionsStore, private val messageOptionsStore: MessageOptionsStore,
) : ProvidableModule { ) : ProvidableModule {
@ -21,5 +19,5 @@ class MessengerModule(
) )
} }
internal fun decryptingFetcherFactory(roomId: RoomId) = DecryptingFetcherFactory(context, base64, roomId) internal fun decryptingFetcherFactory(roomId: RoomId) = DecryptingFetcherFactory(context, roomId, chatEngine.mediaDecrypter())
} }

View File

@ -147,19 +147,6 @@ internal class SettingsViewModel(
fileStream.importRoomKeys(passphrase) fileStream.importRoomKeys(passphrase)
.onEach { .onEach {
updatePageState<Page.ImportRoomKey> { copy(importProgress = it) } updatePageState<Page.ImportRoomKey> { copy(importProgress = it) }
when (it) {
is ImportResult.Error -> {
// do nothing
}
is ImportResult.Update -> {
// do nothing
}
is ImportResult.Success -> {
chatEngine.refresh(it.roomIds.toList())
}
}
} }
.launchIn(viewModelScope) .launchIn(viewModelScope)
}, },

View File

@ -10,10 +10,7 @@ import app.dapk.st.matrix.auth.DeviceDisplayNameGenerator
import app.dapk.st.matrix.auth.authService import app.dapk.st.matrix.auth.authService
import app.dapk.st.matrix.auth.installAuthService import app.dapk.st.matrix.auth.installAuthService
import app.dapk.st.matrix.common.* import app.dapk.st.matrix.common.*
import app.dapk.st.matrix.crypto.RoomMembersProvider import app.dapk.st.matrix.crypto.*
import app.dapk.st.matrix.crypto.Verification
import app.dapk.st.matrix.crypto.cryptoService
import app.dapk.st.matrix.crypto.installCryptoService
import app.dapk.st.matrix.device.KnownDeviceStore import app.dapk.st.matrix.device.KnownDeviceStore
import app.dapk.st.matrix.device.deviceService import app.dapk.st.matrix.device.deviceService
import app.dapk.st.matrix.device.installEncryptionService import app.dapk.st.matrix.device.installEncryptionService
@ -30,6 +27,7 @@ import app.dapk.st.olm.OlmStore
import app.dapk.st.olm.OlmWrapper import app.dapk.st.olm.OlmWrapper
import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.map import kotlinx.coroutines.flow.map
import kotlinx.coroutines.flow.onEach
import java.io.InputStream import java.io.InputStream
import java.time.Clock import java.time.Clock
@ -38,6 +36,7 @@ class MatrixEngine internal constructor(
private val matrix: Lazy<MatrixClient>, private val matrix: Lazy<MatrixClient>,
private val timelineUseCase: Lazy<ReadMarkingTimeline>, private val timelineUseCase: Lazy<ReadMarkingTimeline>,
private val sendMessageUseCase: Lazy<SendMessageUseCase>, private val sendMessageUseCase: Lazy<SendMessageUseCase>,
private val matrixMediaDecrypter: Lazy<MatrixMediaDecrypter>,
) : ChatEngine { ) : ChatEngine {
override fun directory() = directoryUseCase.value.state() override fun directory() = directoryUseCase.value.state()
@ -57,14 +56,18 @@ class MatrixEngine internal constructor(
return matrix.value.profileService().me(forceRefresh).engine() return matrix.value.profileService().me(forceRefresh).engine()
} }
override suspend fun refresh(roomIds: List<RoomId>) {
matrix.value.syncService().forceManualRefresh(roomIds)
}
override suspend fun InputStream.importRoomKeys(password: String): Flow<ImportResult> { override suspend fun InputStream.importRoomKeys(password: String): Flow<ImportResult> {
return with(matrix.value.cryptoService()) { return with(matrix.value.cryptoService()) {
importRoomKeys(password).map { it.engine() } importRoomKeys(password).map { it.engine() }.onEach {
when (it) {
is ImportResult.Error,
is ImportResult.Update -> {
// do nothing
}
is ImportResult.Success -> matrix.value.syncService().forceManualRefresh(it.roomIds)
}
}
} }
} }
@ -72,6 +75,17 @@ class MatrixEngine internal constructor(
sendMessageUseCase.value.send(message, room) sendMessageUseCase.value.send(message, room)
} }
override fun mediaDecrypter(): MediaDecrypter {
val mediaDecrypter = matrixMediaDecrypter.value
return object : MediaDecrypter {
override fun decrypt(input: InputStream, k: String, iv: String): MediaDecrypter.Collector {
return MediaDecrypter.Collector {
mediaDecrypter.decrypt(input, k, iv).collect(it)
}
}
}
}
class Factory { class Factory {
fun create( fun create(
@ -138,8 +152,9 @@ class MatrixEngine internal constructor(
SendMessageUseCase(matrix.messageService(), LocalIdFactory(), imageContentReader, Clock.systemUTC()) SendMessageUseCase(matrix.messageService(), LocalIdFactory(), imageContentReader, Clock.systemUTC())
} }
return MatrixEngine(directoryUseCase, lazyMatrix, timelineUseCase, sendMessageUseCase) val mediaDecrypter = unsafeLazy { MatrixMediaDecrypter(base64) }
return MatrixEngine(directoryUseCase, lazyMatrix, timelineUseCase, sendMessageUseCase, mediaDecrypter)
} }
} }

View File

@ -12,7 +12,7 @@ private const val CIPHER_ALGORITHM = "AES/CTR/NoPadding"
private const val SECRET_KEY_SPEC_ALGORITHM = "AES" private const val SECRET_KEY_SPEC_ALGORITHM = "AES"
private const val MESSAGE_DIGEST_ALGORITHM = "SHA-256" private const val MESSAGE_DIGEST_ALGORITHM = "SHA-256"
class MediaDecrypter(private val base64: Base64) { class MatrixMediaDecrypter(private val base64: Base64) {
fun decrypt(input: InputStream, k: String, iv: String): Collector { fun decrypt(input: InputStream, k: String, iv: String): Collector {
val key = base64.decode(k.replace('-', '+').replace('_', '/')) val key = base64.decode(k.replace('-', '+').replace('_', '/'))

View File

@ -21,7 +21,7 @@ interface SyncService : MatrixService {
fun startSyncing(): Flow<Unit> fun startSyncing(): Flow<Unit>
fun events(roomId: RoomId? = null): Flow<List<SyncEvent>> fun events(roomId: RoomId? = null): Flow<List<SyncEvent>>
suspend fun observeEvent(eventId: EventId): Flow<EventId> suspend fun observeEvent(eventId: EventId): Flow<EventId>
suspend fun forceManualRefresh(roomIds: List<RoomId>) suspend fun forceManualRefresh(roomIds: Set<RoomId>)
@JvmInline @JvmInline
value class FilterId(val value: String) value class FilterId(val value: String)

View File

@ -110,7 +110,7 @@ internal class DefaultSyncService(
override fun room(roomId: RoomId) = roomStore.latest(roomId) override fun room(roomId: RoomId) = roomStore.latest(roomId)
override fun events(roomId: RoomId?) = roomId?.let { syncEventsFlow.map { it.filter { it.roomId == roomId } }.distinctUntilChanged() } ?: syncEventsFlow override fun events(roomId: RoomId?) = roomId?.let { syncEventsFlow.map { it.filter { it.roomId == roomId } }.distinctUntilChanged() } ?: syncEventsFlow
override suspend fun observeEvent(eventId: EventId) = roomStore.observeEvent(eventId) override suspend fun observeEvent(eventId: EventId) = roomStore.observeEvent(eventId)
override suspend fun forceManualRefresh(roomIds: List<RoomId>) { override suspend fun forceManualRefresh(roomIds: Set<RoomId>) {
coroutineDispatchers.withIoContext { coroutineDispatchers.withIoContext {
roomIds.map { roomIds.map {
async { async {

View File

@ -7,7 +7,7 @@ import TestUser
import app.dapk.st.core.extensions.ifNull import app.dapk.st.core.extensions.ifNull
import app.dapk.st.matrix.common.RoomId import app.dapk.st.matrix.common.RoomId
import app.dapk.st.matrix.common.RoomMember import app.dapk.st.matrix.common.RoomMember
import app.dapk.st.matrix.crypto.MediaDecrypter import app.dapk.st.matrix.crypto.MatrixMediaDecrypter
import app.dapk.st.matrix.message.MessageService import app.dapk.st.matrix.message.MessageService
import app.dapk.st.matrix.message.messageService import app.dapk.st.matrix.message.messageService
import app.dapk.st.matrix.sync.RoomEvent import app.dapk.st.matrix.sync.RoomEvent
@ -153,7 +153,7 @@ class MatrixTestScope(private val testScope: TestScope) {
null -> output.readBytes().md5Hash() null -> output.readBytes().md5Hash()
else -> { else -> {
val byteStream = ByteArrayOutputStream() val byteStream = ByteArrayOutputStream()
MediaDecrypter(this.base64).decrypt(output.inputStream(), keys.k, keys.iv).collect { MatrixMediaDecrypter(this.base64).decrypt(output.inputStream(), keys.k, keys.iv).collect {
byteStream.write(it) byteStream.write(it)
} }
byteStream.toByteArray().md5Hash() byteStream.toByteArray().md5Hash()