using the import result directly in the UI, will allow separate error displays per reason

This commit is contained in:
Adam Brown 2022-09-04 13:36:48 +01:00
parent 6a3c594481
commit 71af573d06
5 changed files with 89 additions and 61 deletions

View File

@ -44,6 +44,7 @@ import app.dapk.st.design.components.SettingsTextRow
import app.dapk.st.design.components.Spider import app.dapk.st.design.components.Spider
import app.dapk.st.design.components.SpiderPage import app.dapk.st.design.components.SpiderPage
import app.dapk.st.design.components.TextRow import app.dapk.st.design.components.TextRow
import app.dapk.st.matrix.crypto.ImportResult
import app.dapk.st.navigator.Navigator import app.dapk.st.navigator.Navigator
import app.dapk.st.settings.SettingsEvent.* import app.dapk.st.settings.SettingsEvent.*
import app.dapk.st.settings.eventlogger.EventLogActivity import app.dapk.st.settings.eventlogger.EventLogActivity
@ -137,10 +138,10 @@ internal fun SettingsScreen(viewModel: SettingsViewModel, onSignOut: () -> Unit,
} }
} }
is LceWithProgress.Content -> { is ImportResult.Success -> {
Box(Modifier.fillMaxSize(), contentAlignment = Alignment.Center) { Box(Modifier.fillMaxSize(), contentAlignment = Alignment.Center) {
Column(horizontalAlignment = Alignment.CenterHorizontally) { Column(horizontalAlignment = Alignment.CenterHorizontally) {
Text(text = "Successfully imported ${it.importProgress.value} keys") Text(text = "Successfully imported ${it.importProgress.totalImportedKeysCount} keys")
Spacer(modifier = Modifier.height(12.dp)) Spacer(modifier = Modifier.height(12.dp))
Button(onClick = { navigator.navigate.upToHome() }) { Button(onClick = { navigator.navigate.upToHome() }) {
Text(text = "Close".uppercase()) Text(text = "Close".uppercase())
@ -149,7 +150,7 @@ internal fun SettingsScreen(viewModel: SettingsViewModel, onSignOut: () -> Unit,
} }
} }
is LceWithProgress.Error -> { is ImportResult.Error -> {
Box(Modifier.fillMaxSize(), contentAlignment = Alignment.Center) { Box(Modifier.fillMaxSize(), contentAlignment = Alignment.Center) {
Column(horizontalAlignment = Alignment.CenterHorizontally) { Column(horizontalAlignment = Alignment.CenterHorizontally) {
Text(text = "Import failed") Text(text = "Import failed")
@ -160,10 +161,10 @@ internal fun SettingsScreen(viewModel: SettingsViewModel, onSignOut: () -> Unit,
} }
} }
is LceWithProgress.Loading -> { is ImportResult.Update -> {
Box(Modifier.fillMaxSize(), contentAlignment = Alignment.Center) { Box(Modifier.fillMaxSize(), contentAlignment = Alignment.Center) {
Column(horizontalAlignment = Alignment.CenterHorizontally) { Column(horizontalAlignment = Alignment.CenterHorizontally) {
Text(text = "Imported ${it.importProgress.progress} keys") Text(text = "Imported ${it.importProgress.importedKeysCount} keys")
Spacer(modifier = Modifier.height(12.dp)) Spacer(modifier = Modifier.height(12.dp))
CircularProgressIndicator(Modifier.wrapContentSize()) CircularProgressIndicator(Modifier.wrapContentSize())
} }

View File

@ -5,6 +5,7 @@ import app.dapk.st.core.Lce
import app.dapk.st.core.LceWithProgress import app.dapk.st.core.LceWithProgress
import app.dapk.st.design.components.Route import app.dapk.st.design.components.Route
import app.dapk.st.design.components.SpiderPage import app.dapk.st.design.components.SpiderPage
import app.dapk.st.matrix.crypto.ImportResult
import app.dapk.st.push.Registrar import app.dapk.st.push.Registrar
internal data class SettingsScreenState( internal data class SettingsScreenState(
@ -16,7 +17,7 @@ internal sealed interface Page {
object Security : Page object Security : Page
data class ImportRoomKey( data class ImportRoomKey(
val selectedFile: NamedUri? = null, val selectedFile: NamedUri? = null,
val importProgress: LceWithProgress<Long>? = null, val importProgress: ImportResult? = null,
) : Page ) : Page
data class PushProviders( data class PushProviders(

View File

@ -2,7 +2,6 @@ package app.dapk.st.settings
import android.content.ContentResolver import android.content.ContentResolver
import android.net.Uri import android.net.Uri
import android.util.Log
import androidx.lifecycle.viewModelScope import androidx.lifecycle.viewModelScope
import app.dapk.st.core.Lce import app.dapk.st.core.Lce
import app.dapk.st.core.LceWithProgress import app.dapk.st.core.LceWithProgress
@ -125,21 +124,21 @@ internal class SettingsViewModel(
} }
fun importFromFileKeys(file: Uri, passphrase: String) { fun importFromFileKeys(file: Uri, passphrase: String) {
updatePageState<Page.ImportRoomKey> { copy(importProgress = LceWithProgress.Loading(0L)) } updatePageState<Page.ImportRoomKey> { copy(importProgress = ImportResult.Update(0)) }
viewModelScope.launch { viewModelScope.launch {
with(cryptoService) { with(cryptoService) {
contentResolver.openInputStream(file)?.importRoomKeys(passphrase) contentResolver.openInputStream(file)?.importRoomKeys(passphrase)
?.onEach { ?.onEach {
updatePageState<Page.ImportRoomKey> { copy(importProgress = it) }
when (it) { when (it) {
is ImportResult.Error -> { is ImportResult.Error -> {
updatePageState<Page.ImportRoomKey> { copy(importProgress = LceWithProgress.Error(it.cause)) } // do nothing
} }
is ImportResult.Update -> { is ImportResult.Update -> {
updatePageState<Page.ImportRoomKey> { copy(importProgress = LceWithProgress.Loading(it.importedKeysCount)) } // do nothing
} }
is ImportResult.Success -> { is ImportResult.Success -> {
syncService.forceManualRefresh(it.roomIds.toList()) syncService.forceManualRefresh(it.roomIds.toList())
updatePageState<Page.ImportRoomKey> { copy(importProgress = LceWithProgress.Content(it.totalImportedKeysCount)) }
} }
} }
} }

View File

@ -163,6 +163,14 @@ fun interface RoomMembersProvider {
sealed interface ImportResult { sealed interface ImportResult {
data class Success(val roomIds: Set<RoomId>, val totalImportedKeysCount: Long) : ImportResult data class Success(val roomIds: Set<RoomId>, val totalImportedKeysCount: Long) : ImportResult
data class Error(val cause: Throwable) : ImportResult data class Error(val cause: Type) : ImportResult {
sealed interface Type {
data class Unknown(val cause: Throwable): Type
object NoKeysFound: Type
object UnexpectedDecryptionOutput: Type
}
}
data class Update(val importedKeysCount: Long) : ImportResult data class Update(val importedKeysCount: Long) : ImportResult
} }

View File

@ -2,17 +2,18 @@ package app.dapk.st.matrix.crypto.internal
import app.dapk.st.core.Base64 import app.dapk.st.core.Base64
import app.dapk.st.core.CoroutineDispatchers import app.dapk.st.core.CoroutineDispatchers
import app.dapk.st.core.withIoContext
import app.dapk.st.matrix.common.AlgorithmName import app.dapk.st.matrix.common.AlgorithmName
import app.dapk.st.matrix.common.RoomId import app.dapk.st.matrix.common.RoomId
import app.dapk.st.matrix.common.SessionId import app.dapk.st.matrix.common.SessionId
import app.dapk.st.matrix.common.SharedRoomKey import app.dapk.st.matrix.common.SharedRoomKey
import app.dapk.st.matrix.crypto.ImportResult import app.dapk.st.matrix.crypto.ImportResult
import kotlinx.coroutines.flow.* import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.FlowCollector
import kotlinx.coroutines.flow.flow
import kotlinx.coroutines.flow.flowOn
import kotlinx.serialization.SerialName import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable import kotlinx.serialization.Serializable
import kotlinx.serialization.json.Json import kotlinx.serialization.json.Json
import java.io.IOException
import java.io.InputStream import java.io.InputStream
import java.nio.charset.Charset import java.nio.charset.Charset
import javax.crypto.Cipher import javax.crypto.Cipher
@ -32,30 +33,27 @@ class RoomKeyImporter(
suspend fun InputStream.importRoomKeys(password: String, onChunk: suspend (List<SharedRoomKey>) -> Unit): Flow<ImportResult> { suspend fun InputStream.importRoomKeys(password: String, onChunk: suspend (List<SharedRoomKey>) -> Unit): Flow<ImportResult> {
return flow { return flow {
runCatching { this@importRoomKeys.import(password, onChunk, this) }
.onFailure {
when (it) {
is ImportException -> emit(ImportResult.Error(it.type))
else -> emit(ImportResult.Error(ImportResult.Error.Type.Unknown(it)))
}
}
}.flowOn(dispatchers.io)
}
private suspend fun InputStream.import(password: String, onChunk: suspend (List<SharedRoomKey>) -> Unit, collector: FlowCollector<ImportResult>) {
var importedKeysCount = 0L var importedKeysCount = 0L
val roomIds = mutableSetOf<RoomId>() val roomIds = mutableSetOf<RoomId>()
val decryptCipher = Cipher.getInstance("AES/CTR/NoPadding")
this@importRoomKeys.bufferedReader().use { this.bufferedReader().use {
with(JsonAccumulator()) { with(JsonAccumulator()) {
it.useLines { sequence -> it.useLines { sequence ->
sequence sequence
.filterNot { it == HEADER_LINE || it == TRAILER_LINE || it.isEmpty() } .filterNot { it == HEADER_LINE || it == TRAILER_LINE || it.isEmpty() }
.chunked(2) .chunked(5)
.withIndex() .decrypt(password)
.map { (index, it) ->
val line = it.joinToString(separator = "").replace("\n", "")
val toByteArray = base64.decode(line)
if (index == 0) {
decryptCipher.initialize(toByteArray, password)
toByteArray.copyOfRange(37, toByteArray.size).decrypt(decryptCipher).also {
if (!it.startsWith("[{")) {
throw IllegalArgumentException("Unable to decrypt, assumed invalid password")
}
}
} else {
toByteArray.decrypt(decryptCipher)
}
}
.accumulateJson() .accumulateJson()
.map { decoded -> .map { decoded ->
roomIds.add(decoded.roomId) roomIds.add(decoded.roomId)
@ -71,17 +69,36 @@ class RoomKeyImporter(
.forEach { .forEach {
onChunk(it) onChunk(it)
importedKeysCount += it.size importedKeysCount += it.size
emit(ImportResult.Update(importedKeysCount)) collector.emit(ImportResult.Update(importedKeysCount))
} }
} }
} }
if (roomIds.isEmpty()) { when {
emit(ImportResult.Error(IOException("Found no rooms to import in the file"))) roomIds.isEmpty() -> collector.emit(ImportResult.Error(ImportResult.Error.Type.NoKeysFound))
else -> collector.emit(ImportResult.Success(roomIds, importedKeysCount))
}
}
}
private fun Sequence<List<String>>.decrypt(password: String): Sequence<String> {
val decryptCipher = Cipher.getInstance("AES/CTR/NoPadding")
return this.withIndex().map { (index, it) ->
val line = it.joinToString(separator = "").replace("\n", "")
val toByteArray = base64.decode(line)
if (index == 0) {
decryptCipher.initialize(toByteArray, password)
toByteArray
.copyOfRange(37, toByteArray.size)
.decrypt(decryptCipher)
.also {
if (!it.startsWith("[{")) {
throw ImportException(ImportResult.Error.Type.UnexpectedDecryptionOutput)
}
}
} else { } else {
emit(ImportResult.Success(roomIds, importedKeysCount)) toByteArray.decrypt(decryptCipher)
} }
} }
}.flowOn(dispatchers.io)
} }
private fun Cipher.initialize(payload: ByteArray, passphrase: String) { private fun Cipher.initialize(payload: ByteArray, passphrase: String) {
@ -145,6 +162,8 @@ private data class ElementMegolmExportObject(
@SerialName("algorithm") val algorithmName: AlgorithmName, @SerialName("algorithm") val algorithmName: AlgorithmName,
) )
private class ImportException(val type: ImportResult.Error.Type) : Throwable()
private class JsonAccumulator { private class JsonAccumulator {
private var jsonSegment = "" private var jsonSegment = ""