porting media decryption to the chat engine

This commit is contained in:
Adam Brown 2022-10-09 19:38:33 +01:00
parent e61dea7ba7
commit 8bbf7258be
10 changed files with 51 additions and 43 deletions

View File

@ -179,7 +179,6 @@ internal class FeatureModules internal constructor(
MessengerModule(
matrixModules.engine,
context,
base64,
storeModule.value.messageStore(),
)
}
@ -466,10 +465,8 @@ internal class MatrixModules(
val push by unsafeLazy { matrix.pushService() }
val sync by unsafeLazy { matrix.syncService() }
val message by unsafeLazy { matrix.messageService() }
val room by unsafeLazy { matrix.roomService() }
val profile by unsafeLazy { matrix.profileService() }
val crypto by unsafeLazy { matrix.cryptoService() }
}
internal class DomainModules(

View File

@ -16,9 +16,19 @@ interface ChatEngine {
suspend fun me(forceRefresh: Boolean): Me
suspend fun refresh(roomIds: List<RoomId>)
suspend fun InputStream.importRoomKeys(password: String): Flow<ImportResult>
suspend fun send(message: SendMessage, room: RoomOverview)
fun mediaDecrypter(): MediaDecrypter
}
interface MediaDecrypter {
fun decrypt(input: InputStream, k: String, iv: String): Collector
fun interface Collector {
fun collect(partial: (ByteArray) -> Unit)
}
}

View File

@ -2,10 +2,9 @@ package app.dapk.st.messenger
import android.content.Context
import android.os.Environment
import app.dapk.st.core.Base64
import app.dapk.st.engine.MediaDecrypter
import app.dapk.st.engine.RoomEvent
import app.dapk.st.matrix.common.RoomId
import app.dapk.st.matrix.crypto.MediaDecrypter
import app.dapk.st.matrix.sync.RoomEvent
import coil.ImageLoader
import coil.decode.DataSource
import coil.decode.ImageSource
@ -20,9 +19,11 @@ import okio.BufferedSource
import okio.Path.Companion.toOkioPath
import java.io.File
class DecryptingFetcherFactory(private val context: Context, base64: Base64, private val roomId: RoomId) : Fetcher.Factory<RoomEvent.Image> {
private val mediaDecrypter = MediaDecrypter(base64)
class DecryptingFetcherFactory(
private val context: Context,
private val roomId: RoomId,
private val mediaDecrypter: MediaDecrypter,
) : Fetcher.Factory<RoomEvent.Image> {
override fun create(data: RoomEvent.Image, options: Options, imageLoader: ImageLoader): Fetcher {
return DecryptingFetcher(data, context, mediaDecrypter, roomId)

View File

@ -1,7 +1,6 @@
package app.dapk.st.messenger
import android.content.Context
import app.dapk.st.core.Base64
import app.dapk.st.core.ProvidableModule
import app.dapk.st.domain.application.message.MessageOptionsStore
import app.dapk.st.engine.ChatEngine
@ -10,7 +9,6 @@ import app.dapk.st.matrix.common.RoomId
class MessengerModule(
private val chatEngine: ChatEngine,
private val context: Context,
private val base64: Base64,
private val messageOptionsStore: MessageOptionsStore,
) : ProvidableModule {
@ -21,5 +19,5 @@ class MessengerModule(
)
}
internal fun decryptingFetcherFactory(roomId: RoomId) = DecryptingFetcherFactory(context, base64, roomId)
internal fun decryptingFetcherFactory(roomId: RoomId) = DecryptingFetcherFactory(context, roomId, chatEngine.mediaDecrypter())
}

View File

@ -147,19 +147,6 @@ internal class SettingsViewModel(
fileStream.importRoomKeys(passphrase)
.onEach {
updatePageState<Page.ImportRoomKey> { copy(importProgress = it) }
when (it) {
is ImportResult.Error -> {
// do nothing
}
is ImportResult.Update -> {
// do nothing
}
is ImportResult.Success -> {
chatEngine.refresh(it.roomIds.toList())
}
}
}
.launchIn(viewModelScope)
},

View File

@ -10,10 +10,7 @@ import app.dapk.st.matrix.auth.DeviceDisplayNameGenerator
import app.dapk.st.matrix.auth.authService
import app.dapk.st.matrix.auth.installAuthService
import app.dapk.st.matrix.common.*
import app.dapk.st.matrix.crypto.RoomMembersProvider
import app.dapk.st.matrix.crypto.Verification
import app.dapk.st.matrix.crypto.cryptoService
import app.dapk.st.matrix.crypto.installCryptoService
import app.dapk.st.matrix.crypto.*
import app.dapk.st.matrix.device.KnownDeviceStore
import app.dapk.st.matrix.device.deviceService
import app.dapk.st.matrix.device.installEncryptionService
@ -30,6 +27,7 @@ import app.dapk.st.olm.OlmStore
import app.dapk.st.olm.OlmWrapper
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.map
import kotlinx.coroutines.flow.onEach
import java.io.InputStream
import java.time.Clock
@ -38,6 +36,7 @@ class MatrixEngine internal constructor(
private val matrix: Lazy<MatrixClient>,
private val timelineUseCase: Lazy<ReadMarkingTimeline>,
private val sendMessageUseCase: Lazy<SendMessageUseCase>,
private val matrixMediaDecrypter: Lazy<MatrixMediaDecrypter>,
) : ChatEngine {
override fun directory() = directoryUseCase.value.state()
@ -57,14 +56,18 @@ class MatrixEngine internal constructor(
return matrix.value.profileService().me(forceRefresh).engine()
}
override suspend fun refresh(roomIds: List<RoomId>) {
matrix.value.syncService().forceManualRefresh(roomIds)
}
override suspend fun InputStream.importRoomKeys(password: String): Flow<ImportResult> {
return with(matrix.value.cryptoService()) {
importRoomKeys(password).map { it.engine() }
importRoomKeys(password).map { it.engine() }.onEach {
when (it) {
is ImportResult.Error,
is ImportResult.Update -> {
// do nothing
}
is ImportResult.Success -> matrix.value.syncService().forceManualRefresh(it.roomIds)
}
}
}
}
@ -72,6 +75,17 @@ class MatrixEngine internal constructor(
sendMessageUseCase.value.send(message, room)
}
override fun mediaDecrypter(): MediaDecrypter {
val mediaDecrypter = matrixMediaDecrypter.value
return object : MediaDecrypter {
override fun decrypt(input: InputStream, k: String, iv: String): MediaDecrypter.Collector {
return MediaDecrypter.Collector {
mediaDecrypter.decrypt(input, k, iv).collect(it)
}
}
}
}
class Factory {
fun create(
@ -138,8 +152,9 @@ class MatrixEngine internal constructor(
SendMessageUseCase(matrix.messageService(), LocalIdFactory(), imageContentReader, Clock.systemUTC())
}
return MatrixEngine(directoryUseCase, lazyMatrix, timelineUseCase, sendMessageUseCase)
val mediaDecrypter = unsafeLazy { MatrixMediaDecrypter(base64) }
return MatrixEngine(directoryUseCase, lazyMatrix, timelineUseCase, sendMessageUseCase, mediaDecrypter)
}
}

View File

@ -12,7 +12,7 @@ private const val CIPHER_ALGORITHM = "AES/CTR/NoPadding"
private const val SECRET_KEY_SPEC_ALGORITHM = "AES"
private const val MESSAGE_DIGEST_ALGORITHM = "SHA-256"
class MediaDecrypter(private val base64: Base64) {
class MatrixMediaDecrypter(private val base64: Base64) {
fun decrypt(input: InputStream, k: String, iv: String): Collector {
val key = base64.decode(k.replace('-', '+').replace('_', '/'))

View File

@ -21,7 +21,7 @@ interface SyncService : MatrixService {
fun startSyncing(): Flow<Unit>
fun events(roomId: RoomId? = null): Flow<List<SyncEvent>>
suspend fun observeEvent(eventId: EventId): Flow<EventId>
suspend fun forceManualRefresh(roomIds: List<RoomId>)
suspend fun forceManualRefresh(roomIds: Set<RoomId>)
@JvmInline
value class FilterId(val value: String)

View File

@ -110,7 +110,7 @@ internal class DefaultSyncService(
override fun room(roomId: RoomId) = roomStore.latest(roomId)
override fun events(roomId: RoomId?) = roomId?.let { syncEventsFlow.map { it.filter { it.roomId == roomId } }.distinctUntilChanged() } ?: syncEventsFlow
override suspend fun observeEvent(eventId: EventId) = roomStore.observeEvent(eventId)
override suspend fun forceManualRefresh(roomIds: List<RoomId>) {
override suspend fun forceManualRefresh(roomIds: Set<RoomId>) {
coroutineDispatchers.withIoContext {
roomIds.map {
async {

View File

@ -7,7 +7,7 @@ import TestUser
import app.dapk.st.core.extensions.ifNull
import app.dapk.st.matrix.common.RoomId
import app.dapk.st.matrix.common.RoomMember
import app.dapk.st.matrix.crypto.MediaDecrypter
import app.dapk.st.matrix.crypto.MatrixMediaDecrypter
import app.dapk.st.matrix.message.MessageService
import app.dapk.st.matrix.message.messageService
import app.dapk.st.matrix.sync.RoomEvent
@ -153,7 +153,7 @@ class MatrixTestScope(private val testScope: TestScope) {
null -> output.readBytes().md5Hash()
else -> {
val byteStream = ByteArrayOutputStream()
MediaDecrypter(this.base64).decrypt(output.inputStream(), keys.k, keys.iv).collect {
MatrixMediaDecrypter(this.base64).decrypt(output.inputStream(), keys.k, keys.iv).collect {
byteStream.write(it)
}
byteStream.toByteArray().md5Hash()