Merge pull request #117 from ouchadam/feature/import-keys-improvements

Import keys improvements
This commit is contained in:
Adam Brown 2022-09-04 14:18:34 +01:00 committed by GitHub
commit 7fc8060d34
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 307 additions and 147 deletions

View File

@ -1,5 +1,9 @@
package app.dapk.st.core
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.onCompletion
import kotlinx.coroutines.flow.onStart
enum class AppLogTag(val key: String) {
NOTIFICATION("notification"),
PERFORMANCE("performance"),
@ -26,3 +30,13 @@ suspend fun <T> logP(area: String, block: suspend () -> T): T {
log(AppLogTag.PERFORMANCE, "$area: took $timeTaken ms")
}
}
fun <T> Flow<T>.logP(area: String): Flow<T> {
var start = -1L
return this
.onStart { start = System.currentTimeMillis() }
.onCompletion {
val timeTaken = System.currentTimeMillis() - start
log(AppLogTag.PERFORMANCE, "$area: took $timeTaken ms")
}
}

View File

@ -6,3 +6,9 @@ sealed interface Lce<T> {
data class Content<T>(val value: T) : Lce<T>
}
sealed interface LceWithProgress<T> {
data class Loading<T>(val progress: T) : LceWithProgress<T>
data class Error<T>(val cause: Throwable) : LceWithProgress<T>
data class Content<T>(val value: T) : LceWithProgress<T>
}

View File

@ -47,6 +47,10 @@ class OlmPersistenceWrapper(
olmPersistence.persist(sessionId, SerializedObject(inboundGroupSession.serialize()))
}
override suspend fun transaction(action: suspend () -> Unit) {
olmPersistence.startTransaction { action() }
}
override suspend fun readInbound(sessionId: SessionId): OlmInboundGroupSession? {
return olmPersistence.readInbound(sessionId)?.value?.deserialize()
}

View File

@ -12,6 +12,7 @@ interface OlmStore {
suspend fun read(): OlmAccount?
suspend fun persist(olmAccount: OlmAccount)
suspend fun transaction(action: suspend () -> Unit)
suspend fun readOutbound(roomId: RoomId): Pair<Long, OlmOutboundGroupSession>?
suspend fun persistOutbound(roomId: RoomId, creationTimestampUtc: Long, outboundGroupSession: OlmOutboundGroupSession)
suspend fun persistSession(identity: Curve25519, sessionId: SessionId, olmSession: OlmSession)

View File

@ -46,13 +46,15 @@ class OlmWrapper(
override suspend fun import(keys: List<SharedRoomKey>) {
interactWithOlm()
keys.forEach {
val inBound = when (it.isExported) {
true -> OlmInboundGroupSession.importSession(it.sessionKey)
false -> OlmInboundGroupSession(it.sessionKey)
olmStore.transaction {
keys.forEach {
val inBound = when (it.isExported) {
true -> OlmInboundGroupSession.importSession(it.sessionKey)
false -> OlmInboundGroupSession(it.sessionKey)
}
olmStore.persist(it.sessionId, inBound)
}
logger.crypto("import megolm ${it.sessionKey}")
olmStore.persist(it.sessionId, inBound)
}
}

View File

@ -4,79 +4,109 @@ import app.dapk.db.DapkDb
import app.dapk.db.model.DbCryptoAccount
import app.dapk.db.model.DbCryptoMegolmInbound
import app.dapk.db.model.DbCryptoMegolmOutbound
import app.dapk.st.core.CoroutineDispatchers
import app.dapk.st.core.withIoContext
import app.dapk.st.matrix.common.CredentialsStore
import app.dapk.st.matrix.common.Curve25519
import app.dapk.st.matrix.common.RoomId
import app.dapk.st.matrix.common.SessionId
import com.squareup.sqldelight.TransactionWithoutReturn
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.launch
import kotlinx.coroutines.withContext
class OlmPersistence(
private val database: DapkDb,
private val credentialsStore: CredentialsStore,
private val dispatchers: CoroutineDispatchers,
) {
suspend fun read(): String? {
return database.cryptoQueries
.selectAccount(credentialsStore.credentials()!!.userId.value)
.executeAsOneOrNull()
return dispatchers.withIoContext {
database.cryptoQueries
.selectAccount(credentialsStore.credentials()!!.userId.value)
.executeAsOneOrNull()
}
}
suspend fun persist(olmAccount: SerializedObject) {
database.cryptoQueries.insertAccount(
DbCryptoAccount(
user_id = credentialsStore.credentials()!!.userId.value,
blob = olmAccount.value
dispatchers.withIoContext {
database.cryptoQueries.insertAccount(
DbCryptoAccount(
user_id = credentialsStore.credentials()!!.userId.value,
blob = olmAccount.value
)
)
)
}
}
suspend fun readOutbound(roomId: RoomId): Pair<Long, String>? {
return database.cryptoQueries
.selectMegolmOutbound(roomId.value)
.executeAsOneOrNull()?.let {
it.utcEpochMillis to it.blob
}
return dispatchers.withIoContext {
database.cryptoQueries
.selectMegolmOutbound(roomId.value)
.executeAsOneOrNull()?.let {
it.utcEpochMillis to it.blob
}
}
}
suspend fun persistOutbound(roomId: RoomId, creationTimestampUtc: Long, outboundGroupSession: SerializedObject) {
database.cryptoQueries.insertMegolmOutbound(
DbCryptoMegolmOutbound(
room_id = roomId.value,
blob = outboundGroupSession.value,
utcEpochMillis = creationTimestampUtc,
dispatchers.withIoContext {
database.cryptoQueries.insertMegolmOutbound(
DbCryptoMegolmOutbound(
room_id = roomId.value,
blob = outboundGroupSession.value,
utcEpochMillis = creationTimestampUtc,
)
)
)
}
}
suspend fun persistSession(identity: Curve25519, sessionId: SessionId, olmSession: SerializedObject) {
database.cryptoQueries.insertOlmSession(
identity_key = identity.value,
session_id = sessionId.value,
blob = olmSession.value,
)
withContext(dispatchers.io) {
database.cryptoQueries.insertOlmSession(
identity_key = identity.value,
session_id = sessionId.value,
blob = olmSession.value,
)
}
}
suspend fun readSessions(identities: List<Curve25519>): List<Pair<Curve25519, String>>? {
return database.cryptoQueries
.selectOlmSession(identities.map { it.value })
.executeAsList()
.map { Curve25519(it.identity_key) to it.blob }
.takeIf { it.isNotEmpty() }
return withContext(dispatchers.io) {
database.cryptoQueries
.selectOlmSession(identities.map { it.value })
.executeAsList()
.map { Curve25519(it.identity_key) to it.blob }
.takeIf { it.isNotEmpty() }
}
}
suspend fun startTransaction(action: suspend TransactionWithoutReturn.() -> Unit) {
val scope = CoroutineScope(dispatchers.io)
database.cryptoQueries.transaction {
scope.launch { action() }
}
}
suspend fun persist(sessionId: SessionId, inboundGroupSession: SerializedObject) {
database.cryptoQueries.insertMegolmInbound(
DbCryptoMegolmInbound(
session_id = sessionId.value,
blob = inboundGroupSession.value
withContext(dispatchers.io) {
database.cryptoQueries.insertMegolmInbound(
DbCryptoMegolmInbound(
session_id = sessionId.value,
blob = inboundGroupSession.value
)
)
)
}
}
suspend fun readInbound(sessionId: SessionId): SerializedObject? {
return database.cryptoQueries
.selectMegolmInbound(sessionId.value)
.executeAsOneOrNull()
?.let { SerializedObject((it)) }
return withContext(dispatchers.io) {
database.cryptoQueries
.selectMegolmInbound(sessionId.value)
.executeAsOneOrNull()
?.let { SerializedObject((it)) }
}
}
}

View File

@ -39,7 +39,7 @@ class StoreModule(
fun applicationStore() = ApplicationPreferences(preferences)
fun olmStore() = OlmPersistence(database, credentialsStore())
fun olmStore() = OlmPersistence(database, credentialsStore(), coroutineDispatchers)
fun knownDevicesStore() = DevicePersistence(database, KnownDevicesCache(), coroutineDispatchers)
fun profileStore(): ProfileStore = ProfilePersistence(preferences)

View File

@ -13,12 +13,14 @@ FROM dbEventLog;
selectLatestByLog:
SELECT id, tag, content, time(utcEpochSeconds,'unixepoch')
FROM dbEventLog
WHERE logParent = ?;
WHERE logParent = ?
ORDER BY utcEpochSeconds DESC;
selectLatestByLogFiltered:
SELECT id, tag, content, time(utcEpochSeconds,'unixepoch')
FROM dbEventLog
WHERE logParent = ? AND tag = ?;
WHERE logParent = ? AND tag = ?
ORDER BY utcEpochSeconds DESC;
insert:
INSERT INTO dbEventLog(tag, content, utcEpochSeconds, logParent)

View File

@ -31,11 +31,13 @@ import androidx.compose.ui.text.input.ImeAction
import androidx.compose.ui.text.input.KeyboardType
import androidx.compose.ui.text.input.PasswordVisualTransformation
import androidx.compose.ui.text.input.VisualTransformation
import androidx.compose.ui.text.style.TextAlign
import androidx.compose.ui.text.style.TextOverflow
import androidx.compose.ui.unit.dp
import androidx.compose.ui.unit.sp
import androidx.core.net.toUri
import app.dapk.st.core.Lce
import app.dapk.st.core.LceWithProgress
import app.dapk.st.core.StartObserving
import app.dapk.st.core.components.CenteredLoading
import app.dapk.st.core.components.Header
@ -43,6 +45,7 @@ import app.dapk.st.design.components.SettingsTextRow
import app.dapk.st.design.components.Spider
import app.dapk.st.design.components.SpiderPage
import app.dapk.st.design.components.TextRow
import app.dapk.st.matrix.crypto.ImportResult
import app.dapk.st.navigator.Navigator
import app.dapk.st.settings.SettingsEvent.*
import app.dapk.st.settings.eventlogger.EventLogActivity
@ -72,7 +75,7 @@ internal fun SettingsScreen(viewModel: SettingsViewModel, onSignOut: () -> Unit,
PushProviders(viewModel, it)
}
item(Page.Routes.importRoomKeys) {
when (it.importProgress) {
when (val result = it.importProgress) {
null -> {
Box(
modifier = Modifier.fillMaxSize(),
@ -136,10 +139,11 @@ internal fun SettingsScreen(viewModel: SettingsViewModel, onSignOut: () -> Unit,
}
}
is Lce.Content -> {
is ImportResult.Success -> {
Box(Modifier.fillMaxSize(), contentAlignment = Alignment.Center) {
Column(horizontalAlignment = Alignment.CenterHorizontally) {
Text(text = "Import success")
Text(text = "Successfully imported ${result.totalImportedKeysCount} keys")
Spacer(modifier = Modifier.height(12.dp))
Button(onClick = { navigator.navigate.upToHome() }) {
Text(text = "Close".uppercase())
}
@ -147,10 +151,18 @@ internal fun SettingsScreen(viewModel: SettingsViewModel, onSignOut: () -> Unit,
}
}
is Lce.Error -> {
is ImportResult.Error -> {
Box(Modifier.fillMaxSize(), contentAlignment = Alignment.Center) {
Column(horizontalAlignment = Alignment.CenterHorizontally) {
Text(text = "Import failed")
val message = when(val type = result.cause) {
ImportResult.Error.Type.NoKeysFound -> "No keys found in the file"
ImportResult.Error.Type.UnexpectedDecryptionOutput -> "Unable to decrypt file, double check your passphrase"
is ImportResult.Error.Type.Unknown -> "${type.cause::class.java.simpleName}: ${type.cause.message}"
ImportResult.Error.Type.UnableToOpenFile -> "Unable to open file"
}
Text(text = "Import failed\n$message", textAlign = TextAlign.Center)
Spacer(modifier = Modifier.height(12.dp))
Button(onClick = { navigator.navigate.upToHome() }) {
Text(text = "Close".uppercase())
}
@ -158,7 +170,15 @@ internal fun SettingsScreen(viewModel: SettingsViewModel, onSignOut: () -> Unit,
}
}
is Lce.Loading -> CenteredLoading()
is ImportResult.Update -> {
Box(Modifier.fillMaxSize(), contentAlignment = Alignment.Center) {
Column(horizontalAlignment = Alignment.CenterHorizontally) {
Text(text = "Importing ${result.importedKeysCount} keys...")
Spacer(modifier = Modifier.height(12.dp))
CircularProgressIndicator(Modifier.wrapContentSize())
}
}
}
}
}
}

View File

@ -2,8 +2,10 @@ package app.dapk.st.settings
import android.net.Uri
import app.dapk.st.core.Lce
import app.dapk.st.core.LceWithProgress
import app.dapk.st.design.components.Route
import app.dapk.st.design.components.SpiderPage
import app.dapk.st.matrix.crypto.ImportResult
import app.dapk.st.push.Registrar
internal data class SettingsScreenState(
@ -15,7 +17,7 @@ internal sealed interface Page {
object Security : Page
data class ImportRoomKey(
val selectedFile: NamedUri? = null,
val importProgress: Lce<Unit>? = null,
val importProgress: ImportResult? = null,
) : Page
data class PushProviders(

View File

@ -7,6 +7,7 @@ import app.dapk.st.core.Lce
import app.dapk.st.design.components.SpiderPage
import app.dapk.st.domain.StoreCleaner
import app.dapk.st.matrix.crypto.CryptoService
import app.dapk.st.matrix.crypto.ImportResult
import app.dapk.st.matrix.sync.SyncService
import app.dapk.st.push.PushTokenRegistrars
import app.dapk.st.push.Registrar
@ -15,6 +16,8 @@ import app.dapk.st.settings.SettingsEvent.*
import app.dapk.st.viewmodel.DapkViewModel
import app.dapk.st.viewmodel.MutableStateFactory
import app.dapk.st.viewmodel.defaultStateFactory
import kotlinx.coroutines.flow.launchIn
import kotlinx.coroutines.flow.onEach
import kotlinx.coroutines.launch
private const val PRIVACY_POLICY_URL = "https://ouchadam.github.io/small-talk/privacy/"
@ -120,17 +123,34 @@ internal class SettingsViewModel(
}
fun importFromFileKeys(file: Uri, passphrase: String) {
updatePageState<Page.ImportRoomKey> { copy(importProgress = Lce.Loading()) }
updatePageState<Page.ImportRoomKey> { copy(importProgress = ImportResult.Update(0)) }
viewModelScope.launch {
kotlin.runCatching {
with(cryptoService) {
val roomsToRefresh = contentResolver.openInputStream(file)?.importRoomKeys(passphrase)
roomsToRefresh?.let { syncService.forceManualRefresh(roomsToRefresh) }
}
}.fold(
onSuccess = { updatePageState<Page.ImportRoomKey> { copy(importProgress = Lce.Content(Unit)) } },
onFailure = { updatePageState<Page.ImportRoomKey> { copy(importProgress = Lce.Error(it)) } }
)
with(cryptoService) {
runCatching { contentResolver.openInputStream(file)!! }
.fold(
onSuccess = { fileStream ->
fileStream.importRoomKeys(passphrase)
.onEach {
updatePageState<Page.ImportRoomKey> { copy(importProgress = it) }
when (it) {
is ImportResult.Error -> {
// do nothing
}
is ImportResult.Update -> {
// do nothing
}
is ImportResult.Success -> {
syncService.forceManualRefresh(it.roomIds.toList())
}
}
}
.launchIn(viewModelScope)
},
onFailure = {
updatePageState<Page.ImportRoomKey> { copy(importProgress = ImportResult.Error(ImportResult.Error.Type.UnableToOpenFile)) }
}
)
}
}
}

View File

@ -3,6 +3,7 @@ package app.dapk.st.settings
import ViewModelTest
import app.dapk.st.core.Lce
import app.dapk.st.design.components.SpiderPage
import app.dapk.st.matrix.crypto.ImportResult
import fake.*
import fixture.FakeStoreCleaner
import fixture.aRoomId
@ -10,6 +11,7 @@ import internalfake.FakeSettingsItemFactory
import internalfake.FakeUriFilenameResolver
import internalfixture.aImportRoomKeysPage
import internalfixture.aSettingTextItem
import kotlinx.coroutines.flow.flowOf
import org.junit.Test
private const val APP_PRIVACY_POLICY_URL = "https://ouchadam.github.io/small-talk/privacy/"
@ -21,6 +23,8 @@ private val A_IMPORT_ROOM_KEYS_PAGE_WITH_SELECTION = aImportRoomKeysPage(
state = Page.ImportRoomKey(selectedFile = NamedUri(A_FILENAME, A_URI.instance))
)
private val A_LIST_OF_ROOM_IDS = listOf(aRoomId())
private val AN_IMPORT_SUCCESS = ImportResult.Success(A_LIST_OF_ROOM_IDS.toSet(), totalImportedKeysCount = 5)
private val AN_IMPORT_FILE_ERROR = ImportResult.Error(ImportResult.Error.Type.UnableToOpenFile)
private val AN_INPUT_STREAM = FakeInputStream()
private const val A_PASSPHRASE = "passphrase"
private val AN_ERROR = RuntimeException()
@ -166,15 +170,15 @@ internal class SettingsViewModelTest {
fun `given success when importing room keys, then emits progress`() = runViewModelTest {
fakeSyncService.expectUnit { it.forceManualRefresh(A_LIST_OF_ROOM_IDS) }
fakeContentResolver.givenFile(A_URI.instance).returns(AN_INPUT_STREAM.instance)
fakeCryptoService.givenImportKeys(AN_INPUT_STREAM.instance, A_PASSPHRASE).returns(A_LIST_OF_ROOM_IDS)
fakeCryptoService.givenImportKeys(AN_INPUT_STREAM.instance, A_PASSPHRASE).returns(flowOf(AN_IMPORT_SUCCESS))
viewModel
.test(initialState = SettingsScreenState(A_IMPORT_ROOM_KEYS_PAGE_WITH_SELECTION))
.importFromFileKeys(A_URI.instance, A_PASSPHRASE)
assertStates<SettingsScreenState>(
{ copy(page = page.updateState<Page.ImportRoomKey> { copy(importProgress = Lce.Loading()) }) },
{ copy(page = page.updateState<Page.ImportRoomKey> { copy(importProgress = Lce.Content(Unit)) }) },
{ copy(page = page.updateState<Page.ImportRoomKey> { copy(importProgress = ImportResult.Update(0L)) }) },
{ copy(page = page.updateState<Page.ImportRoomKey> { copy(importProgress = AN_IMPORT_SUCCESS) }) },
)
assertNoEvents<SettingsEvent>()
verifyExpects()
@ -189,8 +193,8 @@ internal class SettingsViewModelTest {
.importFromFileKeys(A_URI.instance, A_PASSPHRASE)
assertStates<SettingsScreenState>(
{ copy(page = page.updateState<Page.ImportRoomKey> { copy(importProgress = Lce.Loading()) }) },
{ copy(page = page.updateState<Page.ImportRoomKey> { copy(importProgress = Lce.Error(AN_ERROR)) }) },
{ copy(page = page.updateState<Page.ImportRoomKey> { copy(importProgress = ImportResult.Update(0L)) }) },
{ copy(page = page.updateState<Page.ImportRoomKey> { copy(importProgress = AN_IMPORT_FILE_ERROR) }) },
)
assertNoEvents<SettingsEvent>()
}

View File

@ -18,7 +18,7 @@ interface CryptoService : MatrixService {
suspend fun encrypt(roomId: RoomId, credentials: DeviceCredentials, messageJson: JsonString): Crypto.EncryptionResult
suspend fun decrypt(encryptedPayload: EncryptedMessageContent): DecryptionResult
suspend fun importRoomKeys(keys: List<SharedRoomKey>)
suspend fun InputStream.importRoomKeys(password: String): List<RoomId>
suspend fun InputStream.importRoomKeys(password: String): Flow<ImportResult>
suspend fun maybeCreateMoreKeys(serverKeyCount: ServerKeyCount)
suspend fun updateOlmSession(userIds: List<UserId>, syncToken: SyncToken?)
@ -160,3 +160,18 @@ fun MatrixServiceProvider.cryptoService(): CryptoService = this.getService(key =
fun interface RoomMembersProvider {
suspend fun userIdsForRoom(roomId: RoomId): List<UserId>
}
sealed interface ImportResult {
data class Success(val roomIds: Set<RoomId>, val totalImportedKeysCount: Long) : ImportResult
data class Error(val cause: Type) : ImportResult {
sealed interface Type {
data class Unknown(val cause: Throwable): Type
object NoKeysFound: Type
object UnexpectedDecryptionOutput: Type
object UnableToOpenFile: Type
}
}
data class Update(val importedKeysCount: Long) : ImportResult
}

View File

@ -1,8 +1,10 @@
package app.dapk.st.matrix.crypto.internal
import app.dapk.st.core.logP
import app.dapk.st.matrix.common.*
import app.dapk.st.matrix.crypto.Crypto
import app.dapk.st.matrix.crypto.CryptoService
import app.dapk.st.matrix.crypto.ImportResult
import app.dapk.st.matrix.crypto.Verification
import kotlinx.coroutines.flow.Flow
import java.io.InputStream
@ -47,9 +49,11 @@ internal class DefaultCryptoService(
verificationHandler.onUserVerificationAction(verificationAction)
}
override suspend fun InputStream.importRoomKeys(password: String): List<RoomId> {
override suspend fun InputStream.importRoomKeys(password: String): Flow<ImportResult> {
return with(roomKeyImporter) {
importRoomKeys(password) { importRoomKeys(it) }
importRoomKeys(password) {
importRoomKeys(it)
}.logP("import room keys")
}
}
}

View File

@ -14,7 +14,6 @@ internal class OlmCrypto(
) {
suspend fun importRoomKeys(keys: List<SharedRoomKey>) {
logger.crypto("import room keys : ${keys.size}")
olm.import(keys)
}

View File

@ -2,15 +2,18 @@ package app.dapk.st.matrix.crypto.internal
import app.dapk.st.core.Base64
import app.dapk.st.core.CoroutineDispatchers
import app.dapk.st.core.withIoContext
import app.dapk.st.matrix.common.AlgorithmName
import app.dapk.st.matrix.common.RoomId
import app.dapk.st.matrix.common.SessionId
import app.dapk.st.matrix.common.SharedRoomKey
import app.dapk.st.matrix.crypto.ImportResult
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.FlowCollector
import kotlinx.coroutines.flow.flow
import kotlinx.coroutines.flow.flowOn
import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable
import kotlinx.serialization.json.Json
import java.io.IOException
import java.io.InputStream
import java.nio.charset.Charset
import javax.crypto.Cipher
@ -28,53 +31,29 @@ class RoomKeyImporter(
private val dispatchers: CoroutineDispatchers,
) {
suspend fun InputStream.importRoomKeys(password: String, onChunk: suspend (List<SharedRoomKey>) -> Unit): List<RoomId> {
return dispatchers.withIoContext {
val decryptCipher = Cipher.getInstance("AES/CTR/NoPadding")
var jsonSegment = ""
fun <T> Sequence<T>.accumulateJson() = this.mapNotNull {
val withLatest = jsonSegment + it
try {
when (val objectRange = withLatest.findClosingIndex()) {
null -> {
jsonSegment = withLatest
null
}
else -> {
val string = withLatest.substring(objectRange)
importJson.decodeFromString(ElementMegolmExportObject.serializer(), string).also {
jsonSegment = withLatest.replace(string, "").removePrefix(",")
}
}
suspend fun InputStream.importRoomKeys(password: String, onChunk: suspend (List<SharedRoomKey>) -> Unit): Flow<ImportResult> {
return flow {
runCatching { this@importRoomKeys.import(password, onChunk, this) }
.onFailure {
when (it) {
is ImportException -> emit(ImportResult.Error(it.type))
else -> emit(ImportResult.Error(ImportResult.Error.Type.Unknown(it)))
}
} catch (error: Throwable) {
jsonSegment = withLatest
null
}
}
}.flowOn(dispatchers.io)
}
this@importRoomKeys.bufferedReader().use {
val roomIds = mutableSetOf<RoomId>()
private suspend fun InputStream.import(password: String, onChunk: suspend (List<SharedRoomKey>) -> Unit, collector: FlowCollector<ImportResult>) {
var importedKeysCount = 0L
val roomIds = mutableSetOf<RoomId>()
this.bufferedReader().use {
with(JsonAccumulator()) {
it.useLines { sequence ->
sequence
.filterNot { it == HEADER_LINE || it == TRAILER_LINE || it.isEmpty() }
.chunked(2)
.withIndex()
.map { (index, it) ->
val line = it.joinToString(separator = "").replace("\n", "")
val toByteArray = base64.decode(line)
if (index == 0) {
decryptCipher.initialize(toByteArray, password)
toByteArray.copyOfRange(37, toByteArray.size).decrypt(decryptCipher).also {
if (!it.startsWith("[{")) {
throw IllegalArgumentException("Unable to decrypt, assumed invalid password")
}
}
} else {
toByteArray.decrypt(decryptCipher)
}
}
.chunked(5)
.decrypt(password)
.accumulateJson()
.map { decoded ->
roomIds.add(decoded.roomId)
@ -86,13 +65,39 @@ class RoomKeyImporter(
isExported = true,
)
}
.chunked(50)
.forEach { onChunk(it) }
}
roomIds.toList().ifEmpty {
throw IOException("Found no rooms to import in the file")
.chunked(500)
.forEach {
onChunk(it)
importedKeysCount += it.size
collector.emit(ImportResult.Update(importedKeysCount))
}
}
}
when {
roomIds.isEmpty() -> collector.emit(ImportResult.Error(ImportResult.Error.Type.NoKeysFound))
else -> collector.emit(ImportResult.Success(roomIds, importedKeysCount))
}
}
}
private fun Sequence<List<String>>.decrypt(password: String): Sequence<String> {
val decryptCipher = Cipher.getInstance("AES/CTR/NoPadding")
return this.withIndex().map { (index, it) ->
val line = it.joinToString(separator = "").replace("\n", "")
val toByteArray = base64.decode(line)
if (index == 0) {
decryptCipher.initialize(toByteArray, password)
toByteArray
.copyOfRange(37, toByteArray.size)
.decrypt(decryptCipher)
.also {
if (!it.startsWith("[{")) {
throw ImportException(ImportResult.Error.Type.UnexpectedDecryptionOutput)
}
}
} else {
toByteArray.decrypt(decryptCipher)
}
}
}
@ -149,28 +154,6 @@ class RoomKeyImporter(
private fun Byte.toUnsignedInt() = toInt() and 0xff
private fun String.findClosingIndex(): IntRange? {
var opens = 0
var openIndex = -1
this.forEachIndexed { index, c ->
when {
c == '{' -> {
if (opens == 0) {
openIndex = index
}
opens++
}
c == '}' -> {
opens--
if (opens == 0) {
return IntRange(openIndex, index)
}
}
}
}
return null
}
@Serializable
private data class ElementMegolmExportObject(
@SerialName("room_id") val roomId: RoomId,
@ -178,3 +161,53 @@ private data class ElementMegolmExportObject(
@SerialName("session_id") val sessionId: SessionId,
@SerialName("algorithm") val algorithmName: AlgorithmName,
)
private class ImportException(val type: ImportResult.Error.Type) : Throwable()
private class JsonAccumulator {
private var jsonSegment = ""
fun <T> Sequence<T>.accumulateJson() = this.mapNotNull {
val withLatest = jsonSegment + it
try {
when (val objectRange = withLatest.findClosingIndex()) {
null -> {
jsonSegment = withLatest
null
}
else -> {
val string = withLatest.substring(objectRange)
importJson.decodeFromString(ElementMegolmExportObject.serializer(), string).also {
jsonSegment = withLatest.replace(string, "").removePrefix(",")
}
}
}
} catch (error: Throwable) {
jsonSegment = withLatest
null
}
}
private fun String.findClosingIndex(): IntRange? {
var opens = 0
var openIndex = -1
this.forEachIndexed { index, c ->
when {
c == '{' -> {
if (opens == 0) {
openIndex = index
}
opens++
}
c == '}' -> {
opens--
if (opens == 0) {
return IntRange(openIndex, index)
}
}
}
}
return null
}
}

View File

@ -4,11 +4,13 @@ import app.dapk.st.matrix.common.HomeServerUrl
import app.dapk.st.matrix.common.RoomId
import app.dapk.st.matrix.common.RoomMember
import app.dapk.st.matrix.common.UserId
import app.dapk.st.matrix.crypto.ImportResult
import app.dapk.st.matrix.crypto.Verification
import app.dapk.st.matrix.crypto.cryptoService
import app.dapk.st.matrix.room.roomService
import app.dapk.st.matrix.sync.syncService
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.first
import kotlinx.coroutines.flow.onEach
import kotlinx.coroutines.test.runTest
import org.amshove.kluent.shouldBeEqualTo
@ -97,10 +99,13 @@ class SmokeTest {
val stream = loadResourceStream("element-keys.txt")
val result = with(cryptoService) {
stream.importRoomKeys(password = "aaaaaa")
stream.importRoomKeys(password = "aaaaaa").first { it is ImportResult.Success }
}
result shouldBeEqualTo listOf(RoomId(value = "!qOSENTtFUuCEKJSVzl:matrix.org"))
result shouldBeEqualTo ImportResult.Success(
roomIds = setOf(RoomId(value = "!qOSENTtFUuCEKJSVzl:matrix.org")),
totalImportedKeysCount = 28,
)
}
private fun testTextMessaging(isEncrypted: Boolean) = testAfterInitialSync { alice, bob ->

View File

@ -150,7 +150,6 @@ const incrementVersionFile = async (github, branchName) => {
name: updatedVersionName,
}
const encodedContentUpdate = Buffer.from(JSON.stringify(updatedVersionFile, null, 2)).toString('base64')
await github.rest.repos.createOrUpdateFileContents({
owner: config.owner,