using the import result directly in the UI, will allow separate error displays per reason
This commit is contained in:
parent
6a3c594481
commit
71af573d06
|
@ -44,6 +44,7 @@ import app.dapk.st.design.components.SettingsTextRow
|
||||||
import app.dapk.st.design.components.Spider
|
import app.dapk.st.design.components.Spider
|
||||||
import app.dapk.st.design.components.SpiderPage
|
import app.dapk.st.design.components.SpiderPage
|
||||||
import app.dapk.st.design.components.TextRow
|
import app.dapk.st.design.components.TextRow
|
||||||
|
import app.dapk.st.matrix.crypto.ImportResult
|
||||||
import app.dapk.st.navigator.Navigator
|
import app.dapk.st.navigator.Navigator
|
||||||
import app.dapk.st.settings.SettingsEvent.*
|
import app.dapk.st.settings.SettingsEvent.*
|
||||||
import app.dapk.st.settings.eventlogger.EventLogActivity
|
import app.dapk.st.settings.eventlogger.EventLogActivity
|
||||||
|
@ -137,10 +138,10 @@ internal fun SettingsScreen(viewModel: SettingsViewModel, onSignOut: () -> Unit,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
is LceWithProgress.Content -> {
|
is ImportResult.Success -> {
|
||||||
Box(Modifier.fillMaxSize(), contentAlignment = Alignment.Center) {
|
Box(Modifier.fillMaxSize(), contentAlignment = Alignment.Center) {
|
||||||
Column(horizontalAlignment = Alignment.CenterHorizontally) {
|
Column(horizontalAlignment = Alignment.CenterHorizontally) {
|
||||||
Text(text = "Successfully imported ${it.importProgress.value} keys")
|
Text(text = "Successfully imported ${it.importProgress.totalImportedKeysCount} keys")
|
||||||
Spacer(modifier = Modifier.height(12.dp))
|
Spacer(modifier = Modifier.height(12.dp))
|
||||||
Button(onClick = { navigator.navigate.upToHome() }) {
|
Button(onClick = { navigator.navigate.upToHome() }) {
|
||||||
Text(text = "Close".uppercase())
|
Text(text = "Close".uppercase())
|
||||||
|
@ -149,7 +150,7 @@ internal fun SettingsScreen(viewModel: SettingsViewModel, onSignOut: () -> Unit,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
is LceWithProgress.Error -> {
|
is ImportResult.Error -> {
|
||||||
Box(Modifier.fillMaxSize(), contentAlignment = Alignment.Center) {
|
Box(Modifier.fillMaxSize(), contentAlignment = Alignment.Center) {
|
||||||
Column(horizontalAlignment = Alignment.CenterHorizontally) {
|
Column(horizontalAlignment = Alignment.CenterHorizontally) {
|
||||||
Text(text = "Import failed")
|
Text(text = "Import failed")
|
||||||
|
@ -160,10 +161,10 @@ internal fun SettingsScreen(viewModel: SettingsViewModel, onSignOut: () -> Unit,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
is LceWithProgress.Loading -> {
|
is ImportResult.Update -> {
|
||||||
Box(Modifier.fillMaxSize(), contentAlignment = Alignment.Center) {
|
Box(Modifier.fillMaxSize(), contentAlignment = Alignment.Center) {
|
||||||
Column(horizontalAlignment = Alignment.CenterHorizontally) {
|
Column(horizontalAlignment = Alignment.CenterHorizontally) {
|
||||||
Text(text = "Imported ${it.importProgress.progress} keys")
|
Text(text = "Imported ${it.importProgress.importedKeysCount} keys")
|
||||||
Spacer(modifier = Modifier.height(12.dp))
|
Spacer(modifier = Modifier.height(12.dp))
|
||||||
CircularProgressIndicator(Modifier.wrapContentSize())
|
CircularProgressIndicator(Modifier.wrapContentSize())
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,6 +5,7 @@ import app.dapk.st.core.Lce
|
||||||
import app.dapk.st.core.LceWithProgress
|
import app.dapk.st.core.LceWithProgress
|
||||||
import app.dapk.st.design.components.Route
|
import app.dapk.st.design.components.Route
|
||||||
import app.dapk.st.design.components.SpiderPage
|
import app.dapk.st.design.components.SpiderPage
|
||||||
|
import app.dapk.st.matrix.crypto.ImportResult
|
||||||
import app.dapk.st.push.Registrar
|
import app.dapk.st.push.Registrar
|
||||||
|
|
||||||
internal data class SettingsScreenState(
|
internal data class SettingsScreenState(
|
||||||
|
@ -16,7 +17,7 @@ internal sealed interface Page {
|
||||||
object Security : Page
|
object Security : Page
|
||||||
data class ImportRoomKey(
|
data class ImportRoomKey(
|
||||||
val selectedFile: NamedUri? = null,
|
val selectedFile: NamedUri? = null,
|
||||||
val importProgress: LceWithProgress<Long>? = null,
|
val importProgress: ImportResult? = null,
|
||||||
) : Page
|
) : Page
|
||||||
|
|
||||||
data class PushProviders(
|
data class PushProviders(
|
||||||
|
|
|
@ -2,7 +2,6 @@ package app.dapk.st.settings
|
||||||
|
|
||||||
import android.content.ContentResolver
|
import android.content.ContentResolver
|
||||||
import android.net.Uri
|
import android.net.Uri
|
||||||
import android.util.Log
|
|
||||||
import androidx.lifecycle.viewModelScope
|
import androidx.lifecycle.viewModelScope
|
||||||
import app.dapk.st.core.Lce
|
import app.dapk.st.core.Lce
|
||||||
import app.dapk.st.core.LceWithProgress
|
import app.dapk.st.core.LceWithProgress
|
||||||
|
@ -125,21 +124,21 @@ internal class SettingsViewModel(
|
||||||
}
|
}
|
||||||
|
|
||||||
fun importFromFileKeys(file: Uri, passphrase: String) {
|
fun importFromFileKeys(file: Uri, passphrase: String) {
|
||||||
updatePageState<Page.ImportRoomKey> { copy(importProgress = LceWithProgress.Loading(0L)) }
|
updatePageState<Page.ImportRoomKey> { copy(importProgress = ImportResult.Update(0)) }
|
||||||
viewModelScope.launch {
|
viewModelScope.launch {
|
||||||
with(cryptoService) {
|
with(cryptoService) {
|
||||||
contentResolver.openInputStream(file)?.importRoomKeys(passphrase)
|
contentResolver.openInputStream(file)?.importRoomKeys(passphrase)
|
||||||
?.onEach {
|
?.onEach {
|
||||||
|
updatePageState<Page.ImportRoomKey> { copy(importProgress = it) }
|
||||||
when (it) {
|
when (it) {
|
||||||
is ImportResult.Error -> {
|
is ImportResult.Error -> {
|
||||||
updatePageState<Page.ImportRoomKey> { copy(importProgress = LceWithProgress.Error(it.cause)) }
|
// do nothing
|
||||||
}
|
}
|
||||||
is ImportResult.Update -> {
|
is ImportResult.Update -> {
|
||||||
updatePageState<Page.ImportRoomKey> { copy(importProgress = LceWithProgress.Loading(it.importedKeysCount)) }
|
// do nothing
|
||||||
}
|
}
|
||||||
is ImportResult.Success -> {
|
is ImportResult.Success -> {
|
||||||
syncService.forceManualRefresh(it.roomIds.toList())
|
syncService.forceManualRefresh(it.roomIds.toList())
|
||||||
updatePageState<Page.ImportRoomKey> { copy(importProgress = LceWithProgress.Content(it.totalImportedKeysCount)) }
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -163,6 +163,14 @@ fun interface RoomMembersProvider {
|
||||||
|
|
||||||
sealed interface ImportResult {
|
sealed interface ImportResult {
|
||||||
data class Success(val roomIds: Set<RoomId>, val totalImportedKeysCount: Long) : ImportResult
|
data class Success(val roomIds: Set<RoomId>, val totalImportedKeysCount: Long) : ImportResult
|
||||||
data class Error(val cause: Throwable) : ImportResult
|
data class Error(val cause: Type) : ImportResult {
|
||||||
|
|
||||||
|
sealed interface Type {
|
||||||
|
data class Unknown(val cause: Throwable): Type
|
||||||
|
object NoKeysFound: Type
|
||||||
|
object UnexpectedDecryptionOutput: Type
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
data class Update(val importedKeysCount: Long) : ImportResult
|
data class Update(val importedKeysCount: Long) : ImportResult
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,17 +2,18 @@ package app.dapk.st.matrix.crypto.internal
|
||||||
|
|
||||||
import app.dapk.st.core.Base64
|
import app.dapk.st.core.Base64
|
||||||
import app.dapk.st.core.CoroutineDispatchers
|
import app.dapk.st.core.CoroutineDispatchers
|
||||||
import app.dapk.st.core.withIoContext
|
|
||||||
import app.dapk.st.matrix.common.AlgorithmName
|
import app.dapk.st.matrix.common.AlgorithmName
|
||||||
import app.dapk.st.matrix.common.RoomId
|
import app.dapk.st.matrix.common.RoomId
|
||||||
import app.dapk.st.matrix.common.SessionId
|
import app.dapk.st.matrix.common.SessionId
|
||||||
import app.dapk.st.matrix.common.SharedRoomKey
|
import app.dapk.st.matrix.common.SharedRoomKey
|
||||||
import app.dapk.st.matrix.crypto.ImportResult
|
import app.dapk.st.matrix.crypto.ImportResult
|
||||||
import kotlinx.coroutines.flow.*
|
import kotlinx.coroutines.flow.Flow
|
||||||
|
import kotlinx.coroutines.flow.FlowCollector
|
||||||
|
import kotlinx.coroutines.flow.flow
|
||||||
|
import kotlinx.coroutines.flow.flowOn
|
||||||
import kotlinx.serialization.SerialName
|
import kotlinx.serialization.SerialName
|
||||||
import kotlinx.serialization.Serializable
|
import kotlinx.serialization.Serializable
|
||||||
import kotlinx.serialization.json.Json
|
import kotlinx.serialization.json.Json
|
||||||
import java.io.IOException
|
|
||||||
import java.io.InputStream
|
import java.io.InputStream
|
||||||
import java.nio.charset.Charset
|
import java.nio.charset.Charset
|
||||||
import javax.crypto.Cipher
|
import javax.crypto.Cipher
|
||||||
|
@ -32,30 +33,27 @@ class RoomKeyImporter(
|
||||||
|
|
||||||
suspend fun InputStream.importRoomKeys(password: String, onChunk: suspend (List<SharedRoomKey>) -> Unit): Flow<ImportResult> {
|
suspend fun InputStream.importRoomKeys(password: String, onChunk: suspend (List<SharedRoomKey>) -> Unit): Flow<ImportResult> {
|
||||||
return flow {
|
return flow {
|
||||||
|
runCatching { this@importRoomKeys.import(password, onChunk, this) }
|
||||||
|
.onFailure {
|
||||||
|
when (it) {
|
||||||
|
is ImportException -> emit(ImportResult.Error(it.type))
|
||||||
|
else -> emit(ImportResult.Error(ImportResult.Error.Type.Unknown(it)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}.flowOn(dispatchers.io)
|
||||||
|
}
|
||||||
|
|
||||||
|
private suspend fun InputStream.import(password: String, onChunk: suspend (List<SharedRoomKey>) -> Unit, collector: FlowCollector<ImportResult>) {
|
||||||
var importedKeysCount = 0L
|
var importedKeysCount = 0L
|
||||||
val roomIds = mutableSetOf<RoomId>()
|
val roomIds = mutableSetOf<RoomId>()
|
||||||
val decryptCipher = Cipher.getInstance("AES/CTR/NoPadding")
|
|
||||||
this@importRoomKeys.bufferedReader().use {
|
this.bufferedReader().use {
|
||||||
with(JsonAccumulator()) {
|
with(JsonAccumulator()) {
|
||||||
it.useLines { sequence ->
|
it.useLines { sequence ->
|
||||||
sequence
|
sequence
|
||||||
.filterNot { it == HEADER_LINE || it == TRAILER_LINE || it.isEmpty() }
|
.filterNot { it == HEADER_LINE || it == TRAILER_LINE || it.isEmpty() }
|
||||||
.chunked(2)
|
.chunked(5)
|
||||||
.withIndex()
|
.decrypt(password)
|
||||||
.map { (index, it) ->
|
|
||||||
val line = it.joinToString(separator = "").replace("\n", "")
|
|
||||||
val toByteArray = base64.decode(line)
|
|
||||||
if (index == 0) {
|
|
||||||
decryptCipher.initialize(toByteArray, password)
|
|
||||||
toByteArray.copyOfRange(37, toByteArray.size).decrypt(decryptCipher).also {
|
|
||||||
if (!it.startsWith("[{")) {
|
|
||||||
throw IllegalArgumentException("Unable to decrypt, assumed invalid password")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
toByteArray.decrypt(decryptCipher)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
.accumulateJson()
|
.accumulateJson()
|
||||||
.map { decoded ->
|
.map { decoded ->
|
||||||
roomIds.add(decoded.roomId)
|
roomIds.add(decoded.roomId)
|
||||||
|
@ -71,17 +69,36 @@ class RoomKeyImporter(
|
||||||
.forEach {
|
.forEach {
|
||||||
onChunk(it)
|
onChunk(it)
|
||||||
importedKeysCount += it.size
|
importedKeysCount += it.size
|
||||||
emit(ImportResult.Update(importedKeysCount))
|
collector.emit(ImportResult.Update(importedKeysCount))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (roomIds.isEmpty()) {
|
when {
|
||||||
emit(ImportResult.Error(IOException("Found no rooms to import in the file")))
|
roomIds.isEmpty() -> collector.emit(ImportResult.Error(ImportResult.Error.Type.NoKeysFound))
|
||||||
|
else -> collector.emit(ImportResult.Success(roomIds, importedKeysCount))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private fun Sequence<List<String>>.decrypt(password: String): Sequence<String> {
|
||||||
|
val decryptCipher = Cipher.getInstance("AES/CTR/NoPadding")
|
||||||
|
return this.withIndex().map { (index, it) ->
|
||||||
|
val line = it.joinToString(separator = "").replace("\n", "")
|
||||||
|
val toByteArray = base64.decode(line)
|
||||||
|
if (index == 0) {
|
||||||
|
decryptCipher.initialize(toByteArray, password)
|
||||||
|
toByteArray
|
||||||
|
.copyOfRange(37, toByteArray.size)
|
||||||
|
.decrypt(decryptCipher)
|
||||||
|
.also {
|
||||||
|
if (!it.startsWith("[{")) {
|
||||||
|
throw ImportException(ImportResult.Error.Type.UnexpectedDecryptionOutput)
|
||||||
|
}
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
emit(ImportResult.Success(roomIds, importedKeysCount))
|
toByteArray.decrypt(decryptCipher)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}.flowOn(dispatchers.io)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private fun Cipher.initialize(payload: ByteArray, passphrase: String) {
|
private fun Cipher.initialize(payload: ByteArray, passphrase: String) {
|
||||||
|
@ -145,6 +162,8 @@ private data class ElementMegolmExportObject(
|
||||||
@SerialName("algorithm") val algorithmName: AlgorithmName,
|
@SerialName("algorithm") val algorithmName: AlgorithmName,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
private class ImportException(val type: ImportResult.Error.Type) : Throwable()
|
||||||
|
|
||||||
private class JsonAccumulator {
|
private class JsonAccumulator {
|
||||||
|
|
||||||
private var jsonSegment = ""
|
private var jsonSegment = ""
|
||||||
|
|
Loading…
Reference in New Issue