using the import result directly in the UI, will allow separate error displays per reason

This commit is contained in:
Adam Brown 2022-09-04 13:36:48 +01:00
parent 6a3c594481
commit 71af573d06
5 changed files with 89 additions and 61 deletions

View File

@ -44,6 +44,7 @@ import app.dapk.st.design.components.SettingsTextRow
import app.dapk.st.design.components.Spider
import app.dapk.st.design.components.SpiderPage
import app.dapk.st.design.components.TextRow
import app.dapk.st.matrix.crypto.ImportResult
import app.dapk.st.navigator.Navigator
import app.dapk.st.settings.SettingsEvent.*
import app.dapk.st.settings.eventlogger.EventLogActivity
@ -137,10 +138,10 @@ internal fun SettingsScreen(viewModel: SettingsViewModel, onSignOut: () -> Unit,
}
}
is LceWithProgress.Content -> {
is ImportResult.Success -> {
Box(Modifier.fillMaxSize(), contentAlignment = Alignment.Center) {
Column(horizontalAlignment = Alignment.CenterHorizontally) {
Text(text = "Successfully imported ${it.importProgress.value} keys")
Text(text = "Successfully imported ${it.importProgress.totalImportedKeysCount} keys")
Spacer(modifier = Modifier.height(12.dp))
Button(onClick = { navigator.navigate.upToHome() }) {
Text(text = "Close".uppercase())
@ -149,7 +150,7 @@ internal fun SettingsScreen(viewModel: SettingsViewModel, onSignOut: () -> Unit,
}
}
is LceWithProgress.Error -> {
is ImportResult.Error -> {
Box(Modifier.fillMaxSize(), contentAlignment = Alignment.Center) {
Column(horizontalAlignment = Alignment.CenterHorizontally) {
Text(text = "Import failed")
@ -160,10 +161,10 @@ internal fun SettingsScreen(viewModel: SettingsViewModel, onSignOut: () -> Unit,
}
}
is LceWithProgress.Loading -> {
is ImportResult.Update -> {
Box(Modifier.fillMaxSize(), contentAlignment = Alignment.Center) {
Column(horizontalAlignment = Alignment.CenterHorizontally) {
Text(text = "Imported ${it.importProgress.progress} keys")
Text(text = "Imported ${it.importProgress.importedKeysCount} keys")
Spacer(modifier = Modifier.height(12.dp))
CircularProgressIndicator(Modifier.wrapContentSize())
}

View File

@ -5,6 +5,7 @@ import app.dapk.st.core.Lce
import app.dapk.st.core.LceWithProgress
import app.dapk.st.design.components.Route
import app.dapk.st.design.components.SpiderPage
import app.dapk.st.matrix.crypto.ImportResult
import app.dapk.st.push.Registrar
internal data class SettingsScreenState(
@ -16,7 +17,7 @@ internal sealed interface Page {
object Security : Page
data class ImportRoomKey(
val selectedFile: NamedUri? = null,
val importProgress: LceWithProgress<Long>? = null,
val importProgress: ImportResult? = null,
) : Page
data class PushProviders(

View File

@ -2,7 +2,6 @@ package app.dapk.st.settings
import android.content.ContentResolver
import android.net.Uri
import android.util.Log
import androidx.lifecycle.viewModelScope
import app.dapk.st.core.Lce
import app.dapk.st.core.LceWithProgress
@ -125,21 +124,21 @@ internal class SettingsViewModel(
}
fun importFromFileKeys(file: Uri, passphrase: String) {
updatePageState<Page.ImportRoomKey> { copy(importProgress = LceWithProgress.Loading(0L)) }
updatePageState<Page.ImportRoomKey> { copy(importProgress = ImportResult.Update(0)) }
viewModelScope.launch {
with(cryptoService) {
contentResolver.openInputStream(file)?.importRoomKeys(passphrase)
?.onEach {
updatePageState<Page.ImportRoomKey> { copy(importProgress = it) }
when (it) {
is ImportResult.Error -> {
updatePageState<Page.ImportRoomKey> { copy(importProgress = LceWithProgress.Error(it.cause)) }
// do nothing
}
is ImportResult.Update -> {
updatePageState<Page.ImportRoomKey> { copy(importProgress = LceWithProgress.Loading(it.importedKeysCount)) }
// do nothing
}
is ImportResult.Success -> {
syncService.forceManualRefresh(it.roomIds.toList())
updatePageState<Page.ImportRoomKey> { copy(importProgress = LceWithProgress.Content(it.totalImportedKeysCount)) }
}
}
}

View File

@ -163,6 +163,14 @@ fun interface RoomMembersProvider {
sealed interface ImportResult {
data class Success(val roomIds: Set<RoomId>, val totalImportedKeysCount: Long) : ImportResult
data class Error(val cause: Throwable) : ImportResult
data class Error(val cause: Type) : ImportResult {
sealed interface Type {
data class Unknown(val cause: Throwable): Type
object NoKeysFound: Type
object UnexpectedDecryptionOutput: Type
}
}
data class Update(val importedKeysCount: Long) : ImportResult
}

View File

@ -2,17 +2,18 @@ package app.dapk.st.matrix.crypto.internal
import app.dapk.st.core.Base64
import app.dapk.st.core.CoroutineDispatchers
import app.dapk.st.core.withIoContext
import app.dapk.st.matrix.common.AlgorithmName
import app.dapk.st.matrix.common.RoomId
import app.dapk.st.matrix.common.SessionId
import app.dapk.st.matrix.common.SharedRoomKey
import app.dapk.st.matrix.crypto.ImportResult
import kotlinx.coroutines.flow.*
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.FlowCollector
import kotlinx.coroutines.flow.flow
import kotlinx.coroutines.flow.flowOn
import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable
import kotlinx.serialization.json.Json
import java.io.IOException
import java.io.InputStream
import java.nio.charset.Charset
import javax.crypto.Cipher
@ -32,56 +33,72 @@ class RoomKeyImporter(
suspend fun InputStream.importRoomKeys(password: String, onChunk: suspend (List<SharedRoomKey>) -> Unit): Flow<ImportResult> {
return flow {
var importedKeysCount = 0L
val roomIds = mutableSetOf<RoomId>()
val decryptCipher = Cipher.getInstance("AES/CTR/NoPadding")
this@importRoomKeys.bufferedReader().use {
with(JsonAccumulator()) {
it.useLines { sequence ->
sequence
.filterNot { it == HEADER_LINE || it == TRAILER_LINE || it.isEmpty() }
.chunked(2)
.withIndex()
.map { (index, it) ->
val line = it.joinToString(separator = "").replace("\n", "")
val toByteArray = base64.decode(line)
if (index == 0) {
decryptCipher.initialize(toByteArray, password)
toByteArray.copyOfRange(37, toByteArray.size).decrypt(decryptCipher).also {
if (!it.startsWith("[{")) {
throw IllegalArgumentException("Unable to decrypt, assumed invalid password")
}
}
} else {
toByteArray.decrypt(decryptCipher)
}
}
.accumulateJson()
.map { decoded ->
roomIds.add(decoded.roomId)
SharedRoomKey(
decoded.algorithmName,
decoded.roomId,
decoded.sessionId,
decoded.sessionKey,
isExported = true,
)
}
.chunked(500)
.forEach {
onChunk(it)
importedKeysCount += it.size
emit(ImportResult.Update(importedKeysCount))
}
runCatching { this@importRoomKeys.import(password, onChunk, this) }
.onFailure {
when (it) {
is ImportException -> emit(ImportResult.Error(it.type))
else -> emit(ImportResult.Error(ImportResult.Error.Type.Unknown(it)))
}
}
if (roomIds.isEmpty()) {
emit(ImportResult.Error(IOException("Found no rooms to import in the file")))
} else {
emit(ImportResult.Success(roomIds, importedKeysCount))
}.flowOn(dispatchers.io)
}
private suspend fun InputStream.import(password: String, onChunk: suspend (List<SharedRoomKey>) -> Unit, collector: FlowCollector<ImportResult>) {
var importedKeysCount = 0L
val roomIds = mutableSetOf<RoomId>()
this.bufferedReader().use {
with(JsonAccumulator()) {
it.useLines { sequence ->
sequence
.filterNot { it == HEADER_LINE || it == TRAILER_LINE || it.isEmpty() }
.chunked(5)
.decrypt(password)
.accumulateJson()
.map { decoded ->
roomIds.add(decoded.roomId)
SharedRoomKey(
decoded.algorithmName,
decoded.roomId,
decoded.sessionId,
decoded.sessionKey,
isExported = true,
)
}
.chunked(500)
.forEach {
onChunk(it)
importedKeysCount += it.size
collector.emit(ImportResult.Update(importedKeysCount))
}
}
}
}.flowOn(dispatchers.io)
when {
roomIds.isEmpty() -> collector.emit(ImportResult.Error(ImportResult.Error.Type.NoKeysFound))
else -> collector.emit(ImportResult.Success(roomIds, importedKeysCount))
}
}
}
private fun Sequence<List<String>>.decrypt(password: String): Sequence<String> {
val decryptCipher = Cipher.getInstance("AES/CTR/NoPadding")
return this.withIndex().map { (index, it) ->
val line = it.joinToString(separator = "").replace("\n", "")
val toByteArray = base64.decode(line)
if (index == 0) {
decryptCipher.initialize(toByteArray, password)
toByteArray
.copyOfRange(37, toByteArray.size)
.decrypt(decryptCipher)
.also {
if (!it.startsWith("[{")) {
throw ImportException(ImportResult.Error.Type.UnexpectedDecryptionOutput)
}
}
} else {
toByteArray.decrypt(decryptCipher)
}
}
}
private fun Cipher.initialize(payload: ByteArray, passphrase: String) {
@ -145,6 +162,8 @@ private data class ElementMegolmExportObject(
@SerialName("algorithm") val algorithmName: AlgorithmName,
)
private class ImportException(val type: ImportResult.Error.Type) : Throwable()
private class JsonAccumulator {
private var jsonSegment = ""