mirror of
https://github.com/ouchadam/small-talk.git
synced 2025-02-07 06:44:36 +01:00
using the import result directly in the UI, will allow separate error displays per reason
This commit is contained in:
parent
6a3c594481
commit
71af573d06
@ -44,6 +44,7 @@ import app.dapk.st.design.components.SettingsTextRow
|
||||
import app.dapk.st.design.components.Spider
|
||||
import app.dapk.st.design.components.SpiderPage
|
||||
import app.dapk.st.design.components.TextRow
|
||||
import app.dapk.st.matrix.crypto.ImportResult
|
||||
import app.dapk.st.navigator.Navigator
|
||||
import app.dapk.st.settings.SettingsEvent.*
|
||||
import app.dapk.st.settings.eventlogger.EventLogActivity
|
||||
@ -137,10 +138,10 @@ internal fun SettingsScreen(viewModel: SettingsViewModel, onSignOut: () -> Unit,
|
||||
}
|
||||
}
|
||||
|
||||
is LceWithProgress.Content -> {
|
||||
is ImportResult.Success -> {
|
||||
Box(Modifier.fillMaxSize(), contentAlignment = Alignment.Center) {
|
||||
Column(horizontalAlignment = Alignment.CenterHorizontally) {
|
||||
Text(text = "Successfully imported ${it.importProgress.value} keys")
|
||||
Text(text = "Successfully imported ${it.importProgress.totalImportedKeysCount} keys")
|
||||
Spacer(modifier = Modifier.height(12.dp))
|
||||
Button(onClick = { navigator.navigate.upToHome() }) {
|
||||
Text(text = "Close".uppercase())
|
||||
@ -149,7 +150,7 @@ internal fun SettingsScreen(viewModel: SettingsViewModel, onSignOut: () -> Unit,
|
||||
}
|
||||
}
|
||||
|
||||
is LceWithProgress.Error -> {
|
||||
is ImportResult.Error -> {
|
||||
Box(Modifier.fillMaxSize(), contentAlignment = Alignment.Center) {
|
||||
Column(horizontalAlignment = Alignment.CenterHorizontally) {
|
||||
Text(text = "Import failed")
|
||||
@ -160,10 +161,10 @@ internal fun SettingsScreen(viewModel: SettingsViewModel, onSignOut: () -> Unit,
|
||||
}
|
||||
}
|
||||
|
||||
is LceWithProgress.Loading -> {
|
||||
is ImportResult.Update -> {
|
||||
Box(Modifier.fillMaxSize(), contentAlignment = Alignment.Center) {
|
||||
Column(horizontalAlignment = Alignment.CenterHorizontally) {
|
||||
Text(text = "Imported ${it.importProgress.progress} keys")
|
||||
Text(text = "Imported ${it.importProgress.importedKeysCount} keys")
|
||||
Spacer(modifier = Modifier.height(12.dp))
|
||||
CircularProgressIndicator(Modifier.wrapContentSize())
|
||||
}
|
||||
|
@ -5,6 +5,7 @@ import app.dapk.st.core.Lce
|
||||
import app.dapk.st.core.LceWithProgress
|
||||
import app.dapk.st.design.components.Route
|
||||
import app.dapk.st.design.components.SpiderPage
|
||||
import app.dapk.st.matrix.crypto.ImportResult
|
||||
import app.dapk.st.push.Registrar
|
||||
|
||||
internal data class SettingsScreenState(
|
||||
@ -16,7 +17,7 @@ internal sealed interface Page {
|
||||
object Security : Page
|
||||
data class ImportRoomKey(
|
||||
val selectedFile: NamedUri? = null,
|
||||
val importProgress: LceWithProgress<Long>? = null,
|
||||
val importProgress: ImportResult? = null,
|
||||
) : Page
|
||||
|
||||
data class PushProviders(
|
||||
|
@ -2,7 +2,6 @@ package app.dapk.st.settings
|
||||
|
||||
import android.content.ContentResolver
|
||||
import android.net.Uri
|
||||
import android.util.Log
|
||||
import androidx.lifecycle.viewModelScope
|
||||
import app.dapk.st.core.Lce
|
||||
import app.dapk.st.core.LceWithProgress
|
||||
@ -125,21 +124,21 @@ internal class SettingsViewModel(
|
||||
}
|
||||
|
||||
fun importFromFileKeys(file: Uri, passphrase: String) {
|
||||
updatePageState<Page.ImportRoomKey> { copy(importProgress = LceWithProgress.Loading(0L)) }
|
||||
updatePageState<Page.ImportRoomKey> { copy(importProgress = ImportResult.Update(0)) }
|
||||
viewModelScope.launch {
|
||||
with(cryptoService) {
|
||||
contentResolver.openInputStream(file)?.importRoomKeys(passphrase)
|
||||
?.onEach {
|
||||
updatePageState<Page.ImportRoomKey> { copy(importProgress = it) }
|
||||
when (it) {
|
||||
is ImportResult.Error -> {
|
||||
updatePageState<Page.ImportRoomKey> { copy(importProgress = LceWithProgress.Error(it.cause)) }
|
||||
// do nothing
|
||||
}
|
||||
is ImportResult.Update -> {
|
||||
updatePageState<Page.ImportRoomKey> { copy(importProgress = LceWithProgress.Loading(it.importedKeysCount)) }
|
||||
// do nothing
|
||||
}
|
||||
is ImportResult.Success -> {
|
||||
syncService.forceManualRefresh(it.roomIds.toList())
|
||||
updatePageState<Page.ImportRoomKey> { copy(importProgress = LceWithProgress.Content(it.totalImportedKeysCount)) }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -163,6 +163,14 @@ fun interface RoomMembersProvider {
|
||||
|
||||
sealed interface ImportResult {
|
||||
data class Success(val roomIds: Set<RoomId>, val totalImportedKeysCount: Long) : ImportResult
|
||||
data class Error(val cause: Throwable) : ImportResult
|
||||
data class Error(val cause: Type) : ImportResult {
|
||||
|
||||
sealed interface Type {
|
||||
data class Unknown(val cause: Throwable): Type
|
||||
object NoKeysFound: Type
|
||||
object UnexpectedDecryptionOutput: Type
|
||||
}
|
||||
|
||||
}
|
||||
data class Update(val importedKeysCount: Long) : ImportResult
|
||||
}
|
||||
|
@ -2,17 +2,18 @@ package app.dapk.st.matrix.crypto.internal
|
||||
|
||||
import app.dapk.st.core.Base64
|
||||
import app.dapk.st.core.CoroutineDispatchers
|
||||
import app.dapk.st.core.withIoContext
|
||||
import app.dapk.st.matrix.common.AlgorithmName
|
||||
import app.dapk.st.matrix.common.RoomId
|
||||
import app.dapk.st.matrix.common.SessionId
|
||||
import app.dapk.st.matrix.common.SharedRoomKey
|
||||
import app.dapk.st.matrix.crypto.ImportResult
|
||||
import kotlinx.coroutines.flow.*
|
||||
import kotlinx.coroutines.flow.Flow
|
||||
import kotlinx.coroutines.flow.FlowCollector
|
||||
import kotlinx.coroutines.flow.flow
|
||||
import kotlinx.coroutines.flow.flowOn
|
||||
import kotlinx.serialization.SerialName
|
||||
import kotlinx.serialization.Serializable
|
||||
import kotlinx.serialization.json.Json
|
||||
import java.io.IOException
|
||||
import java.io.InputStream
|
||||
import java.nio.charset.Charset
|
||||
import javax.crypto.Cipher
|
||||
@ -32,56 +33,72 @@ class RoomKeyImporter(
|
||||
|
||||
suspend fun InputStream.importRoomKeys(password: String, onChunk: suspend (List<SharedRoomKey>) -> Unit): Flow<ImportResult> {
|
||||
return flow {
|
||||
var importedKeysCount = 0L
|
||||
val roomIds = mutableSetOf<RoomId>()
|
||||
val decryptCipher = Cipher.getInstance("AES/CTR/NoPadding")
|
||||
this@importRoomKeys.bufferedReader().use {
|
||||
with(JsonAccumulator()) {
|
||||
it.useLines { sequence ->
|
||||
sequence
|
||||
.filterNot { it == HEADER_LINE || it == TRAILER_LINE || it.isEmpty() }
|
||||
.chunked(2)
|
||||
.withIndex()
|
||||
.map { (index, it) ->
|
||||
val line = it.joinToString(separator = "").replace("\n", "")
|
||||
val toByteArray = base64.decode(line)
|
||||
if (index == 0) {
|
||||
decryptCipher.initialize(toByteArray, password)
|
||||
toByteArray.copyOfRange(37, toByteArray.size).decrypt(decryptCipher).also {
|
||||
if (!it.startsWith("[{")) {
|
||||
throw IllegalArgumentException("Unable to decrypt, assumed invalid password")
|
||||
}
|
||||
}
|
||||
} else {
|
||||
toByteArray.decrypt(decryptCipher)
|
||||
}
|
||||
}
|
||||
.accumulateJson()
|
||||
.map { decoded ->
|
||||
roomIds.add(decoded.roomId)
|
||||
SharedRoomKey(
|
||||
decoded.algorithmName,
|
||||
decoded.roomId,
|
||||
decoded.sessionId,
|
||||
decoded.sessionKey,
|
||||
isExported = true,
|
||||
)
|
||||
}
|
||||
.chunked(500)
|
||||
.forEach {
|
||||
onChunk(it)
|
||||
importedKeysCount += it.size
|
||||
emit(ImportResult.Update(importedKeysCount))
|
||||
}
|
||||
runCatching { this@importRoomKeys.import(password, onChunk, this) }
|
||||
.onFailure {
|
||||
when (it) {
|
||||
is ImportException -> emit(ImportResult.Error(it.type))
|
||||
else -> emit(ImportResult.Error(ImportResult.Error.Type.Unknown(it)))
|
||||
}
|
||||
}
|
||||
if (roomIds.isEmpty()) {
|
||||
emit(ImportResult.Error(IOException("Found no rooms to import in the file")))
|
||||
} else {
|
||||
emit(ImportResult.Success(roomIds, importedKeysCount))
|
||||
}.flowOn(dispatchers.io)
|
||||
}
|
||||
|
||||
private suspend fun InputStream.import(password: String, onChunk: suspend (List<SharedRoomKey>) -> Unit, collector: FlowCollector<ImportResult>) {
|
||||
var importedKeysCount = 0L
|
||||
val roomIds = mutableSetOf<RoomId>()
|
||||
|
||||
this.bufferedReader().use {
|
||||
with(JsonAccumulator()) {
|
||||
it.useLines { sequence ->
|
||||
sequence
|
||||
.filterNot { it == HEADER_LINE || it == TRAILER_LINE || it.isEmpty() }
|
||||
.chunked(5)
|
||||
.decrypt(password)
|
||||
.accumulateJson()
|
||||
.map { decoded ->
|
||||
roomIds.add(decoded.roomId)
|
||||
SharedRoomKey(
|
||||
decoded.algorithmName,
|
||||
decoded.roomId,
|
||||
decoded.sessionId,
|
||||
decoded.sessionKey,
|
||||
isExported = true,
|
||||
)
|
||||
}
|
||||
.chunked(500)
|
||||
.forEach {
|
||||
onChunk(it)
|
||||
importedKeysCount += it.size
|
||||
collector.emit(ImportResult.Update(importedKeysCount))
|
||||
}
|
||||
}
|
||||
}
|
||||
}.flowOn(dispatchers.io)
|
||||
when {
|
||||
roomIds.isEmpty() -> collector.emit(ImportResult.Error(ImportResult.Error.Type.NoKeysFound))
|
||||
else -> collector.emit(ImportResult.Success(roomIds, importedKeysCount))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private fun Sequence<List<String>>.decrypt(password: String): Sequence<String> {
|
||||
val decryptCipher = Cipher.getInstance("AES/CTR/NoPadding")
|
||||
return this.withIndex().map { (index, it) ->
|
||||
val line = it.joinToString(separator = "").replace("\n", "")
|
||||
val toByteArray = base64.decode(line)
|
||||
if (index == 0) {
|
||||
decryptCipher.initialize(toByteArray, password)
|
||||
toByteArray
|
||||
.copyOfRange(37, toByteArray.size)
|
||||
.decrypt(decryptCipher)
|
||||
.also {
|
||||
if (!it.startsWith("[{")) {
|
||||
throw ImportException(ImportResult.Error.Type.UnexpectedDecryptionOutput)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
toByteArray.decrypt(decryptCipher)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private fun Cipher.initialize(payload: ByteArray, passphrase: String) {
|
||||
@ -145,6 +162,8 @@ private data class ElementMegolmExportObject(
|
||||
@SerialName("algorithm") val algorithmName: AlgorithmName,
|
||||
)
|
||||
|
||||
private class ImportException(val type: ImportResult.Error.Type) : Throwable()
|
||||
|
||||
private class JsonAccumulator {
|
||||
|
||||
private var jsonSegment = ""
|
||||
|
Loading…
x
Reference in New Issue
Block a user