Merge pull request #117 from ouchadam/feature/import-keys-improvements

Import keys improvements
This commit is contained in:
Adam Brown 2022-09-04 14:18:34 +01:00 committed by GitHub
commit 7fc8060d34
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 307 additions and 147 deletions

View File

@ -1,5 +1,9 @@
package app.dapk.st.core package app.dapk.st.core
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.onCompletion
import kotlinx.coroutines.flow.onStart
enum class AppLogTag(val key: String) { enum class AppLogTag(val key: String) {
NOTIFICATION("notification"), NOTIFICATION("notification"),
PERFORMANCE("performance"), PERFORMANCE("performance"),
@ -26,3 +30,13 @@ suspend fun <T> logP(area: String, block: suspend () -> T): T {
log(AppLogTag.PERFORMANCE, "$area: took $timeTaken ms") log(AppLogTag.PERFORMANCE, "$area: took $timeTaken ms")
} }
} }
fun <T> Flow<T>.logP(area: String): Flow<T> {
var start = -1L
return this
.onStart { start = System.currentTimeMillis() }
.onCompletion {
val timeTaken = System.currentTimeMillis() - start
log(AppLogTag.PERFORMANCE, "$area: took $timeTaken ms")
}
}

View File

@ -6,3 +6,9 @@ sealed interface Lce<T> {
data class Content<T>(val value: T) : Lce<T> data class Content<T>(val value: T) : Lce<T>
} }
sealed interface LceWithProgress<T> {
data class Loading<T>(val progress: T) : LceWithProgress<T>
data class Error<T>(val cause: Throwable) : LceWithProgress<T>
data class Content<T>(val value: T) : LceWithProgress<T>
}

View File

@ -47,6 +47,10 @@ class OlmPersistenceWrapper(
olmPersistence.persist(sessionId, SerializedObject(inboundGroupSession.serialize())) olmPersistence.persist(sessionId, SerializedObject(inboundGroupSession.serialize()))
} }
override suspend fun transaction(action: suspend () -> Unit) {
olmPersistence.startTransaction { action() }
}
override suspend fun readInbound(sessionId: SessionId): OlmInboundGroupSession? { override suspend fun readInbound(sessionId: SessionId): OlmInboundGroupSession? {
return olmPersistence.readInbound(sessionId)?.value?.deserialize() return olmPersistence.readInbound(sessionId)?.value?.deserialize()
} }

View File

@ -12,6 +12,7 @@ interface OlmStore {
suspend fun read(): OlmAccount? suspend fun read(): OlmAccount?
suspend fun persist(olmAccount: OlmAccount) suspend fun persist(olmAccount: OlmAccount)
suspend fun transaction(action: suspend () -> Unit)
suspend fun readOutbound(roomId: RoomId): Pair<Long, OlmOutboundGroupSession>? suspend fun readOutbound(roomId: RoomId): Pair<Long, OlmOutboundGroupSession>?
suspend fun persistOutbound(roomId: RoomId, creationTimestampUtc: Long, outboundGroupSession: OlmOutboundGroupSession) suspend fun persistOutbound(roomId: RoomId, creationTimestampUtc: Long, outboundGroupSession: OlmOutboundGroupSession)
suspend fun persistSession(identity: Curve25519, sessionId: SessionId, olmSession: OlmSession) suspend fun persistSession(identity: Curve25519, sessionId: SessionId, olmSession: OlmSession)

View File

@ -46,13 +46,15 @@ class OlmWrapper(
override suspend fun import(keys: List<SharedRoomKey>) { override suspend fun import(keys: List<SharedRoomKey>) {
interactWithOlm() interactWithOlm()
keys.forEach {
val inBound = when (it.isExported) { olmStore.transaction {
true -> OlmInboundGroupSession.importSession(it.sessionKey) keys.forEach {
false -> OlmInboundGroupSession(it.sessionKey) val inBound = when (it.isExported) {
true -> OlmInboundGroupSession.importSession(it.sessionKey)
false -> OlmInboundGroupSession(it.sessionKey)
}
olmStore.persist(it.sessionId, inBound)
} }
logger.crypto("import megolm ${it.sessionKey}")
olmStore.persist(it.sessionId, inBound)
} }
} }

View File

@ -4,79 +4,109 @@ import app.dapk.db.DapkDb
import app.dapk.db.model.DbCryptoAccount import app.dapk.db.model.DbCryptoAccount
import app.dapk.db.model.DbCryptoMegolmInbound import app.dapk.db.model.DbCryptoMegolmInbound
import app.dapk.db.model.DbCryptoMegolmOutbound import app.dapk.db.model.DbCryptoMegolmOutbound
import app.dapk.st.core.CoroutineDispatchers
import app.dapk.st.core.withIoContext
import app.dapk.st.matrix.common.CredentialsStore import app.dapk.st.matrix.common.CredentialsStore
import app.dapk.st.matrix.common.Curve25519 import app.dapk.st.matrix.common.Curve25519
import app.dapk.st.matrix.common.RoomId import app.dapk.st.matrix.common.RoomId
import app.dapk.st.matrix.common.SessionId import app.dapk.st.matrix.common.SessionId
import com.squareup.sqldelight.TransactionWithoutReturn
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.launch
import kotlinx.coroutines.withContext
class OlmPersistence( class OlmPersistence(
private val database: DapkDb, private val database: DapkDb,
private val credentialsStore: CredentialsStore, private val credentialsStore: CredentialsStore,
private val dispatchers: CoroutineDispatchers,
) { ) {
suspend fun read(): String? { suspend fun read(): String? {
return database.cryptoQueries return dispatchers.withIoContext {
.selectAccount(credentialsStore.credentials()!!.userId.value) database.cryptoQueries
.executeAsOneOrNull() .selectAccount(credentialsStore.credentials()!!.userId.value)
.executeAsOneOrNull()
}
} }
suspend fun persist(olmAccount: SerializedObject) { suspend fun persist(olmAccount: SerializedObject) {
database.cryptoQueries.insertAccount( dispatchers.withIoContext {
DbCryptoAccount( database.cryptoQueries.insertAccount(
user_id = credentialsStore.credentials()!!.userId.value, DbCryptoAccount(
blob = olmAccount.value user_id = credentialsStore.credentials()!!.userId.value,
blob = olmAccount.value
)
) )
) }
} }
suspend fun readOutbound(roomId: RoomId): Pair<Long, String>? { suspend fun readOutbound(roomId: RoomId): Pair<Long, String>? {
return database.cryptoQueries return dispatchers.withIoContext {
.selectMegolmOutbound(roomId.value) database.cryptoQueries
.executeAsOneOrNull()?.let { .selectMegolmOutbound(roomId.value)
it.utcEpochMillis to it.blob .executeAsOneOrNull()?.let {
} it.utcEpochMillis to it.blob
}
}
} }
suspend fun persistOutbound(roomId: RoomId, creationTimestampUtc: Long, outboundGroupSession: SerializedObject) { suspend fun persistOutbound(roomId: RoomId, creationTimestampUtc: Long, outboundGroupSession: SerializedObject) {
database.cryptoQueries.insertMegolmOutbound( dispatchers.withIoContext {
DbCryptoMegolmOutbound( database.cryptoQueries.insertMegolmOutbound(
room_id = roomId.value, DbCryptoMegolmOutbound(
blob = outboundGroupSession.value, room_id = roomId.value,
utcEpochMillis = creationTimestampUtc, blob = outboundGroupSession.value,
utcEpochMillis = creationTimestampUtc,
)
) )
) }
} }
suspend fun persistSession(identity: Curve25519, sessionId: SessionId, olmSession: SerializedObject) { suspend fun persistSession(identity: Curve25519, sessionId: SessionId, olmSession: SerializedObject) {
database.cryptoQueries.insertOlmSession( withContext(dispatchers.io) {
identity_key = identity.value, database.cryptoQueries.insertOlmSession(
session_id = sessionId.value, identity_key = identity.value,
blob = olmSession.value, session_id = sessionId.value,
) blob = olmSession.value,
)
}
} }
suspend fun readSessions(identities: List<Curve25519>): List<Pair<Curve25519, String>>? { suspend fun readSessions(identities: List<Curve25519>): List<Pair<Curve25519, String>>? {
return database.cryptoQueries return withContext(dispatchers.io) {
.selectOlmSession(identities.map { it.value }) database.cryptoQueries
.executeAsList() .selectOlmSession(identities.map { it.value })
.map { Curve25519(it.identity_key) to it.blob } .executeAsList()
.takeIf { it.isNotEmpty() } .map { Curve25519(it.identity_key) to it.blob }
.takeIf { it.isNotEmpty() }
}
}
suspend fun startTransaction(action: suspend TransactionWithoutReturn.() -> Unit) {
val scope = CoroutineScope(dispatchers.io)
database.cryptoQueries.transaction {
scope.launch { action() }
}
} }
suspend fun persist(sessionId: SessionId, inboundGroupSession: SerializedObject) { suspend fun persist(sessionId: SessionId, inboundGroupSession: SerializedObject) {
database.cryptoQueries.insertMegolmInbound( withContext(dispatchers.io) {
DbCryptoMegolmInbound( database.cryptoQueries.insertMegolmInbound(
session_id = sessionId.value, DbCryptoMegolmInbound(
blob = inboundGroupSession.value session_id = sessionId.value,
blob = inboundGroupSession.value
)
) )
) }
} }
suspend fun readInbound(sessionId: SessionId): SerializedObject? { suspend fun readInbound(sessionId: SessionId): SerializedObject? {
return database.cryptoQueries return withContext(dispatchers.io) {
.selectMegolmInbound(sessionId.value) database.cryptoQueries
.executeAsOneOrNull() .selectMegolmInbound(sessionId.value)
?.let { SerializedObject((it)) } .executeAsOneOrNull()
?.let { SerializedObject((it)) }
}
} }
} }

View File

@ -39,7 +39,7 @@ class StoreModule(
fun applicationStore() = ApplicationPreferences(preferences) fun applicationStore() = ApplicationPreferences(preferences)
fun olmStore() = OlmPersistence(database, credentialsStore()) fun olmStore() = OlmPersistence(database, credentialsStore(), coroutineDispatchers)
fun knownDevicesStore() = DevicePersistence(database, KnownDevicesCache(), coroutineDispatchers) fun knownDevicesStore() = DevicePersistence(database, KnownDevicesCache(), coroutineDispatchers)
fun profileStore(): ProfileStore = ProfilePersistence(preferences) fun profileStore(): ProfileStore = ProfilePersistence(preferences)

View File

@ -13,12 +13,14 @@ FROM dbEventLog;
selectLatestByLog: selectLatestByLog:
SELECT id, tag, content, time(utcEpochSeconds,'unixepoch') SELECT id, tag, content, time(utcEpochSeconds,'unixepoch')
FROM dbEventLog FROM dbEventLog
WHERE logParent = ?; WHERE logParent = ?
ORDER BY utcEpochSeconds DESC;
selectLatestByLogFiltered: selectLatestByLogFiltered:
SELECT id, tag, content, time(utcEpochSeconds,'unixepoch') SELECT id, tag, content, time(utcEpochSeconds,'unixepoch')
FROM dbEventLog FROM dbEventLog
WHERE logParent = ? AND tag = ?; WHERE logParent = ? AND tag = ?
ORDER BY utcEpochSeconds DESC;
insert: insert:
INSERT INTO dbEventLog(tag, content, utcEpochSeconds, logParent) INSERT INTO dbEventLog(tag, content, utcEpochSeconds, logParent)

View File

@ -31,11 +31,13 @@ import androidx.compose.ui.text.input.ImeAction
import androidx.compose.ui.text.input.KeyboardType import androidx.compose.ui.text.input.KeyboardType
import androidx.compose.ui.text.input.PasswordVisualTransformation import androidx.compose.ui.text.input.PasswordVisualTransformation
import androidx.compose.ui.text.input.VisualTransformation import androidx.compose.ui.text.input.VisualTransformation
import androidx.compose.ui.text.style.TextAlign
import androidx.compose.ui.text.style.TextOverflow import androidx.compose.ui.text.style.TextOverflow
import androidx.compose.ui.unit.dp import androidx.compose.ui.unit.dp
import androidx.compose.ui.unit.sp import androidx.compose.ui.unit.sp
import androidx.core.net.toUri import androidx.core.net.toUri
import app.dapk.st.core.Lce import app.dapk.st.core.Lce
import app.dapk.st.core.LceWithProgress
import app.dapk.st.core.StartObserving import app.dapk.st.core.StartObserving
import app.dapk.st.core.components.CenteredLoading import app.dapk.st.core.components.CenteredLoading
import app.dapk.st.core.components.Header import app.dapk.st.core.components.Header
@ -43,6 +45,7 @@ import app.dapk.st.design.components.SettingsTextRow
import app.dapk.st.design.components.Spider import app.dapk.st.design.components.Spider
import app.dapk.st.design.components.SpiderPage import app.dapk.st.design.components.SpiderPage
import app.dapk.st.design.components.TextRow import app.dapk.st.design.components.TextRow
import app.dapk.st.matrix.crypto.ImportResult
import app.dapk.st.navigator.Navigator import app.dapk.st.navigator.Navigator
import app.dapk.st.settings.SettingsEvent.* import app.dapk.st.settings.SettingsEvent.*
import app.dapk.st.settings.eventlogger.EventLogActivity import app.dapk.st.settings.eventlogger.EventLogActivity
@ -72,7 +75,7 @@ internal fun SettingsScreen(viewModel: SettingsViewModel, onSignOut: () -> Unit,
PushProviders(viewModel, it) PushProviders(viewModel, it)
} }
item(Page.Routes.importRoomKeys) { item(Page.Routes.importRoomKeys) {
when (it.importProgress) { when (val result = it.importProgress) {
null -> { null -> {
Box( Box(
modifier = Modifier.fillMaxSize(), modifier = Modifier.fillMaxSize(),
@ -136,10 +139,11 @@ internal fun SettingsScreen(viewModel: SettingsViewModel, onSignOut: () -> Unit,
} }
} }
is Lce.Content -> { is ImportResult.Success -> {
Box(Modifier.fillMaxSize(), contentAlignment = Alignment.Center) { Box(Modifier.fillMaxSize(), contentAlignment = Alignment.Center) {
Column(horizontalAlignment = Alignment.CenterHorizontally) { Column(horizontalAlignment = Alignment.CenterHorizontally) {
Text(text = "Import success") Text(text = "Successfully imported ${result.totalImportedKeysCount} keys")
Spacer(modifier = Modifier.height(12.dp))
Button(onClick = { navigator.navigate.upToHome() }) { Button(onClick = { navigator.navigate.upToHome() }) {
Text(text = "Close".uppercase()) Text(text = "Close".uppercase())
} }
@ -147,10 +151,18 @@ internal fun SettingsScreen(viewModel: SettingsViewModel, onSignOut: () -> Unit,
} }
} }
is Lce.Error -> { is ImportResult.Error -> {
Box(Modifier.fillMaxSize(), contentAlignment = Alignment.Center) { Box(Modifier.fillMaxSize(), contentAlignment = Alignment.Center) {
Column(horizontalAlignment = Alignment.CenterHorizontally) { Column(horizontalAlignment = Alignment.CenterHorizontally) {
Text(text = "Import failed") val message = when(val type = result.cause) {
ImportResult.Error.Type.NoKeysFound -> "No keys found in the file"
ImportResult.Error.Type.UnexpectedDecryptionOutput -> "Unable to decrypt file, double check your passphrase"
is ImportResult.Error.Type.Unknown -> "${type.cause::class.java.simpleName}: ${type.cause.message}"
ImportResult.Error.Type.UnableToOpenFile -> "Unable to open file"
}
Text(text = "Import failed\n$message", textAlign = TextAlign.Center)
Spacer(modifier = Modifier.height(12.dp))
Button(onClick = { navigator.navigate.upToHome() }) { Button(onClick = { navigator.navigate.upToHome() }) {
Text(text = "Close".uppercase()) Text(text = "Close".uppercase())
} }
@ -158,7 +170,15 @@ internal fun SettingsScreen(viewModel: SettingsViewModel, onSignOut: () -> Unit,
} }
} }
is Lce.Loading -> CenteredLoading() is ImportResult.Update -> {
Box(Modifier.fillMaxSize(), contentAlignment = Alignment.Center) {
Column(horizontalAlignment = Alignment.CenterHorizontally) {
Text(text = "Importing ${result.importedKeysCount} keys...")
Spacer(modifier = Modifier.height(12.dp))
CircularProgressIndicator(Modifier.wrapContentSize())
}
}
}
} }
} }
} }

View File

@ -2,8 +2,10 @@ package app.dapk.st.settings
import android.net.Uri import android.net.Uri
import app.dapk.st.core.Lce import app.dapk.st.core.Lce
import app.dapk.st.core.LceWithProgress
import app.dapk.st.design.components.Route import app.dapk.st.design.components.Route
import app.dapk.st.design.components.SpiderPage import app.dapk.st.design.components.SpiderPage
import app.dapk.st.matrix.crypto.ImportResult
import app.dapk.st.push.Registrar import app.dapk.st.push.Registrar
internal data class SettingsScreenState( internal data class SettingsScreenState(
@ -15,7 +17,7 @@ internal sealed interface Page {
object Security : Page object Security : Page
data class ImportRoomKey( data class ImportRoomKey(
val selectedFile: NamedUri? = null, val selectedFile: NamedUri? = null,
val importProgress: Lce<Unit>? = null, val importProgress: ImportResult? = null,
) : Page ) : Page
data class PushProviders( data class PushProviders(

View File

@ -7,6 +7,7 @@ import app.dapk.st.core.Lce
import app.dapk.st.design.components.SpiderPage import app.dapk.st.design.components.SpiderPage
import app.dapk.st.domain.StoreCleaner import app.dapk.st.domain.StoreCleaner
import app.dapk.st.matrix.crypto.CryptoService import app.dapk.st.matrix.crypto.CryptoService
import app.dapk.st.matrix.crypto.ImportResult
import app.dapk.st.matrix.sync.SyncService import app.dapk.st.matrix.sync.SyncService
import app.dapk.st.push.PushTokenRegistrars import app.dapk.st.push.PushTokenRegistrars
import app.dapk.st.push.Registrar import app.dapk.st.push.Registrar
@ -15,6 +16,8 @@ import app.dapk.st.settings.SettingsEvent.*
import app.dapk.st.viewmodel.DapkViewModel import app.dapk.st.viewmodel.DapkViewModel
import app.dapk.st.viewmodel.MutableStateFactory import app.dapk.st.viewmodel.MutableStateFactory
import app.dapk.st.viewmodel.defaultStateFactory import app.dapk.st.viewmodel.defaultStateFactory
import kotlinx.coroutines.flow.launchIn
import kotlinx.coroutines.flow.onEach
import kotlinx.coroutines.launch import kotlinx.coroutines.launch
private const val PRIVACY_POLICY_URL = "https://ouchadam.github.io/small-talk/privacy/" private const val PRIVACY_POLICY_URL = "https://ouchadam.github.io/small-talk/privacy/"
@ -120,17 +123,34 @@ internal class SettingsViewModel(
} }
fun importFromFileKeys(file: Uri, passphrase: String) { fun importFromFileKeys(file: Uri, passphrase: String) {
updatePageState<Page.ImportRoomKey> { copy(importProgress = Lce.Loading()) } updatePageState<Page.ImportRoomKey> { copy(importProgress = ImportResult.Update(0)) }
viewModelScope.launch { viewModelScope.launch {
kotlin.runCatching { with(cryptoService) {
with(cryptoService) { runCatching { contentResolver.openInputStream(file)!! }
val roomsToRefresh = contentResolver.openInputStream(file)?.importRoomKeys(passphrase) .fold(
roomsToRefresh?.let { syncService.forceManualRefresh(roomsToRefresh) } onSuccess = { fileStream ->
} fileStream.importRoomKeys(passphrase)
}.fold( .onEach {
onSuccess = { updatePageState<Page.ImportRoomKey> { copy(importProgress = Lce.Content(Unit)) } }, updatePageState<Page.ImportRoomKey> { copy(importProgress = it) }
onFailure = { updatePageState<Page.ImportRoomKey> { copy(importProgress = Lce.Error(it)) } } when (it) {
) is ImportResult.Error -> {
// do nothing
}
is ImportResult.Update -> {
// do nothing
}
is ImportResult.Success -> {
syncService.forceManualRefresh(it.roomIds.toList())
}
}
}
.launchIn(viewModelScope)
},
onFailure = {
updatePageState<Page.ImportRoomKey> { copy(importProgress = ImportResult.Error(ImportResult.Error.Type.UnableToOpenFile)) }
}
)
}
} }
} }

View File

@ -3,6 +3,7 @@ package app.dapk.st.settings
import ViewModelTest import ViewModelTest
import app.dapk.st.core.Lce import app.dapk.st.core.Lce
import app.dapk.st.design.components.SpiderPage import app.dapk.st.design.components.SpiderPage
import app.dapk.st.matrix.crypto.ImportResult
import fake.* import fake.*
import fixture.FakeStoreCleaner import fixture.FakeStoreCleaner
import fixture.aRoomId import fixture.aRoomId
@ -10,6 +11,7 @@ import internalfake.FakeSettingsItemFactory
import internalfake.FakeUriFilenameResolver import internalfake.FakeUriFilenameResolver
import internalfixture.aImportRoomKeysPage import internalfixture.aImportRoomKeysPage
import internalfixture.aSettingTextItem import internalfixture.aSettingTextItem
import kotlinx.coroutines.flow.flowOf
import org.junit.Test import org.junit.Test
private const val APP_PRIVACY_POLICY_URL = "https://ouchadam.github.io/small-talk/privacy/" private const val APP_PRIVACY_POLICY_URL = "https://ouchadam.github.io/small-talk/privacy/"
@ -21,6 +23,8 @@ private val A_IMPORT_ROOM_KEYS_PAGE_WITH_SELECTION = aImportRoomKeysPage(
state = Page.ImportRoomKey(selectedFile = NamedUri(A_FILENAME, A_URI.instance)) state = Page.ImportRoomKey(selectedFile = NamedUri(A_FILENAME, A_URI.instance))
) )
private val A_LIST_OF_ROOM_IDS = listOf(aRoomId()) private val A_LIST_OF_ROOM_IDS = listOf(aRoomId())
private val AN_IMPORT_SUCCESS = ImportResult.Success(A_LIST_OF_ROOM_IDS.toSet(), totalImportedKeysCount = 5)
private val AN_IMPORT_FILE_ERROR = ImportResult.Error(ImportResult.Error.Type.UnableToOpenFile)
private val AN_INPUT_STREAM = FakeInputStream() private val AN_INPUT_STREAM = FakeInputStream()
private const val A_PASSPHRASE = "passphrase" private const val A_PASSPHRASE = "passphrase"
private val AN_ERROR = RuntimeException() private val AN_ERROR = RuntimeException()
@ -166,15 +170,15 @@ internal class SettingsViewModelTest {
fun `given success when importing room keys, then emits progress`() = runViewModelTest { fun `given success when importing room keys, then emits progress`() = runViewModelTest {
fakeSyncService.expectUnit { it.forceManualRefresh(A_LIST_OF_ROOM_IDS) } fakeSyncService.expectUnit { it.forceManualRefresh(A_LIST_OF_ROOM_IDS) }
fakeContentResolver.givenFile(A_URI.instance).returns(AN_INPUT_STREAM.instance) fakeContentResolver.givenFile(A_URI.instance).returns(AN_INPUT_STREAM.instance)
fakeCryptoService.givenImportKeys(AN_INPUT_STREAM.instance, A_PASSPHRASE).returns(A_LIST_OF_ROOM_IDS) fakeCryptoService.givenImportKeys(AN_INPUT_STREAM.instance, A_PASSPHRASE).returns(flowOf(AN_IMPORT_SUCCESS))
viewModel viewModel
.test(initialState = SettingsScreenState(A_IMPORT_ROOM_KEYS_PAGE_WITH_SELECTION)) .test(initialState = SettingsScreenState(A_IMPORT_ROOM_KEYS_PAGE_WITH_SELECTION))
.importFromFileKeys(A_URI.instance, A_PASSPHRASE) .importFromFileKeys(A_URI.instance, A_PASSPHRASE)
assertStates<SettingsScreenState>( assertStates<SettingsScreenState>(
{ copy(page = page.updateState<Page.ImportRoomKey> { copy(importProgress = Lce.Loading()) }) }, { copy(page = page.updateState<Page.ImportRoomKey> { copy(importProgress = ImportResult.Update(0L)) }) },
{ copy(page = page.updateState<Page.ImportRoomKey> { copy(importProgress = Lce.Content(Unit)) }) }, { copy(page = page.updateState<Page.ImportRoomKey> { copy(importProgress = AN_IMPORT_SUCCESS) }) },
) )
assertNoEvents<SettingsEvent>() assertNoEvents<SettingsEvent>()
verifyExpects() verifyExpects()
@ -189,8 +193,8 @@ internal class SettingsViewModelTest {
.importFromFileKeys(A_URI.instance, A_PASSPHRASE) .importFromFileKeys(A_URI.instance, A_PASSPHRASE)
assertStates<SettingsScreenState>( assertStates<SettingsScreenState>(
{ copy(page = page.updateState<Page.ImportRoomKey> { copy(importProgress = Lce.Loading()) }) }, { copy(page = page.updateState<Page.ImportRoomKey> { copy(importProgress = ImportResult.Update(0L)) }) },
{ copy(page = page.updateState<Page.ImportRoomKey> { copy(importProgress = Lce.Error(AN_ERROR)) }) }, { copy(page = page.updateState<Page.ImportRoomKey> { copy(importProgress = AN_IMPORT_FILE_ERROR) }) },
) )
assertNoEvents<SettingsEvent>() assertNoEvents<SettingsEvent>()
} }

View File

@ -18,7 +18,7 @@ interface CryptoService : MatrixService {
suspend fun encrypt(roomId: RoomId, credentials: DeviceCredentials, messageJson: JsonString): Crypto.EncryptionResult suspend fun encrypt(roomId: RoomId, credentials: DeviceCredentials, messageJson: JsonString): Crypto.EncryptionResult
suspend fun decrypt(encryptedPayload: EncryptedMessageContent): DecryptionResult suspend fun decrypt(encryptedPayload: EncryptedMessageContent): DecryptionResult
suspend fun importRoomKeys(keys: List<SharedRoomKey>) suspend fun importRoomKeys(keys: List<SharedRoomKey>)
suspend fun InputStream.importRoomKeys(password: String): List<RoomId> suspend fun InputStream.importRoomKeys(password: String): Flow<ImportResult>
suspend fun maybeCreateMoreKeys(serverKeyCount: ServerKeyCount) suspend fun maybeCreateMoreKeys(serverKeyCount: ServerKeyCount)
suspend fun updateOlmSession(userIds: List<UserId>, syncToken: SyncToken?) suspend fun updateOlmSession(userIds: List<UserId>, syncToken: SyncToken?)
@ -159,4 +159,19 @@ fun MatrixServiceProvider.cryptoService(): CryptoService = this.getService(key =
fun interface RoomMembersProvider { fun interface RoomMembersProvider {
suspend fun userIdsForRoom(roomId: RoomId): List<UserId> suspend fun userIdsForRoom(roomId: RoomId): List<UserId>
} }
sealed interface ImportResult {
data class Success(val roomIds: Set<RoomId>, val totalImportedKeysCount: Long) : ImportResult
data class Error(val cause: Type) : ImportResult {
sealed interface Type {
data class Unknown(val cause: Throwable): Type
object NoKeysFound: Type
object UnexpectedDecryptionOutput: Type
object UnableToOpenFile: Type
}
}
data class Update(val importedKeysCount: Long) : ImportResult
}

View File

@ -1,8 +1,10 @@
package app.dapk.st.matrix.crypto.internal package app.dapk.st.matrix.crypto.internal
import app.dapk.st.core.logP
import app.dapk.st.matrix.common.* import app.dapk.st.matrix.common.*
import app.dapk.st.matrix.crypto.Crypto import app.dapk.st.matrix.crypto.Crypto
import app.dapk.st.matrix.crypto.CryptoService import app.dapk.st.matrix.crypto.CryptoService
import app.dapk.st.matrix.crypto.ImportResult
import app.dapk.st.matrix.crypto.Verification import app.dapk.st.matrix.crypto.Verification
import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.Flow
import java.io.InputStream import java.io.InputStream
@ -47,9 +49,11 @@ internal class DefaultCryptoService(
verificationHandler.onUserVerificationAction(verificationAction) verificationHandler.onUserVerificationAction(verificationAction)
} }
override suspend fun InputStream.importRoomKeys(password: String): List<RoomId> { override suspend fun InputStream.importRoomKeys(password: String): Flow<ImportResult> {
return with(roomKeyImporter) { return with(roomKeyImporter) {
importRoomKeys(password) { importRoomKeys(it) } importRoomKeys(password) {
importRoomKeys(it)
}.logP("import room keys")
} }
} }
} }

View File

@ -14,7 +14,6 @@ internal class OlmCrypto(
) { ) {
suspend fun importRoomKeys(keys: List<SharedRoomKey>) { suspend fun importRoomKeys(keys: List<SharedRoomKey>) {
logger.crypto("import room keys : ${keys.size}")
olm.import(keys) olm.import(keys)
} }

View File

@ -2,15 +2,18 @@ package app.dapk.st.matrix.crypto.internal
import app.dapk.st.core.Base64 import app.dapk.st.core.Base64
import app.dapk.st.core.CoroutineDispatchers import app.dapk.st.core.CoroutineDispatchers
import app.dapk.st.core.withIoContext
import app.dapk.st.matrix.common.AlgorithmName import app.dapk.st.matrix.common.AlgorithmName
import app.dapk.st.matrix.common.RoomId import app.dapk.st.matrix.common.RoomId
import app.dapk.st.matrix.common.SessionId import app.dapk.st.matrix.common.SessionId
import app.dapk.st.matrix.common.SharedRoomKey import app.dapk.st.matrix.common.SharedRoomKey
import app.dapk.st.matrix.crypto.ImportResult
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.FlowCollector
import kotlinx.coroutines.flow.flow
import kotlinx.coroutines.flow.flowOn
import kotlinx.serialization.SerialName import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable import kotlinx.serialization.Serializable
import kotlinx.serialization.json.Json import kotlinx.serialization.json.Json
import java.io.IOException
import java.io.InputStream import java.io.InputStream
import java.nio.charset.Charset import java.nio.charset.Charset
import javax.crypto.Cipher import javax.crypto.Cipher
@ -28,53 +31,29 @@ class RoomKeyImporter(
private val dispatchers: CoroutineDispatchers, private val dispatchers: CoroutineDispatchers,
) { ) {
suspend fun InputStream.importRoomKeys(password: String, onChunk: suspend (List<SharedRoomKey>) -> Unit): List<RoomId> { suspend fun InputStream.importRoomKeys(password: String, onChunk: suspend (List<SharedRoomKey>) -> Unit): Flow<ImportResult> {
return dispatchers.withIoContext { return flow {
val decryptCipher = Cipher.getInstance("AES/CTR/NoPadding") runCatching { this@importRoomKeys.import(password, onChunk, this) }
var jsonSegment = "" .onFailure {
when (it) {
fun <T> Sequence<T>.accumulateJson() = this.mapNotNull { is ImportException -> emit(ImportResult.Error(it.type))
val withLatest = jsonSegment + it else -> emit(ImportResult.Error(ImportResult.Error.Type.Unknown(it)))
try {
when (val objectRange = withLatest.findClosingIndex()) {
null -> {
jsonSegment = withLatest
null
}
else -> {
val string = withLatest.substring(objectRange)
importJson.decodeFromString(ElementMegolmExportObject.serializer(), string).also {
jsonSegment = withLatest.replace(string, "").removePrefix(",")
}
}
} }
} catch (error: Throwable) {
jsonSegment = withLatest
null
} }
} }.flowOn(dispatchers.io)
}
this@importRoomKeys.bufferedReader().use { private suspend fun InputStream.import(password: String, onChunk: suspend (List<SharedRoomKey>) -> Unit, collector: FlowCollector<ImportResult>) {
val roomIds = mutableSetOf<RoomId>() var importedKeysCount = 0L
val roomIds = mutableSetOf<RoomId>()
this.bufferedReader().use {
with(JsonAccumulator()) {
it.useLines { sequence -> it.useLines { sequence ->
sequence sequence
.filterNot { it == HEADER_LINE || it == TRAILER_LINE || it.isEmpty() } .filterNot { it == HEADER_LINE || it == TRAILER_LINE || it.isEmpty() }
.chunked(2) .chunked(5)
.withIndex() .decrypt(password)
.map { (index, it) ->
val line = it.joinToString(separator = "").replace("\n", "")
val toByteArray = base64.decode(line)
if (index == 0) {
decryptCipher.initialize(toByteArray, password)
toByteArray.copyOfRange(37, toByteArray.size).decrypt(decryptCipher).also {
if (!it.startsWith("[{")) {
throw IllegalArgumentException("Unable to decrypt, assumed invalid password")
}
}
} else {
toByteArray.decrypt(decryptCipher)
}
}
.accumulateJson() .accumulateJson()
.map { decoded -> .map { decoded ->
roomIds.add(decoded.roomId) roomIds.add(decoded.roomId)
@ -86,13 +65,39 @@ class RoomKeyImporter(
isExported = true, isExported = true,
) )
} }
.chunked(50) .chunked(500)
.forEach { onChunk(it) } .forEach {
} onChunk(it)
roomIds.toList().ifEmpty { importedKeysCount += it.size
throw IOException("Found no rooms to import in the file") collector.emit(ImportResult.Update(importedKeysCount))
}
} }
} }
when {
roomIds.isEmpty() -> collector.emit(ImportResult.Error(ImportResult.Error.Type.NoKeysFound))
else -> collector.emit(ImportResult.Success(roomIds, importedKeysCount))
}
}
}
private fun Sequence<List<String>>.decrypt(password: String): Sequence<String> {
val decryptCipher = Cipher.getInstance("AES/CTR/NoPadding")
return this.withIndex().map { (index, it) ->
val line = it.joinToString(separator = "").replace("\n", "")
val toByteArray = base64.decode(line)
if (index == 0) {
decryptCipher.initialize(toByteArray, password)
toByteArray
.copyOfRange(37, toByteArray.size)
.decrypt(decryptCipher)
.also {
if (!it.startsWith("[{")) {
throw ImportException(ImportResult.Error.Type.UnexpectedDecryptionOutput)
}
}
} else {
toByteArray.decrypt(decryptCipher)
}
} }
} }
@ -149,28 +154,6 @@ class RoomKeyImporter(
private fun Byte.toUnsignedInt() = toInt() and 0xff private fun Byte.toUnsignedInt() = toInt() and 0xff
private fun String.findClosingIndex(): IntRange? {
var opens = 0
var openIndex = -1
this.forEachIndexed { index, c ->
when {
c == '{' -> {
if (opens == 0) {
openIndex = index
}
opens++
}
c == '}' -> {
opens--
if (opens == 0) {
return IntRange(openIndex, index)
}
}
}
}
return null
}
@Serializable @Serializable
private data class ElementMegolmExportObject( private data class ElementMegolmExportObject(
@SerialName("room_id") val roomId: RoomId, @SerialName("room_id") val roomId: RoomId,
@ -178,3 +161,53 @@ private data class ElementMegolmExportObject(
@SerialName("session_id") val sessionId: SessionId, @SerialName("session_id") val sessionId: SessionId,
@SerialName("algorithm") val algorithmName: AlgorithmName, @SerialName("algorithm") val algorithmName: AlgorithmName,
) )
private class ImportException(val type: ImportResult.Error.Type) : Throwable()
private class JsonAccumulator {
private var jsonSegment = ""
fun <T> Sequence<T>.accumulateJson() = this.mapNotNull {
val withLatest = jsonSegment + it
try {
when (val objectRange = withLatest.findClosingIndex()) {
null -> {
jsonSegment = withLatest
null
}
else -> {
val string = withLatest.substring(objectRange)
importJson.decodeFromString(ElementMegolmExportObject.serializer(), string).also {
jsonSegment = withLatest.replace(string, "").removePrefix(",")
}
}
}
} catch (error: Throwable) {
jsonSegment = withLatest
null
}
}
private fun String.findClosingIndex(): IntRange? {
var opens = 0
var openIndex = -1
this.forEachIndexed { index, c ->
when {
c == '{' -> {
if (opens == 0) {
openIndex = index
}
opens++
}
c == '}' -> {
opens--
if (opens == 0) {
return IntRange(openIndex, index)
}
}
}
}
return null
}
}

View File

@ -4,11 +4,13 @@ import app.dapk.st.matrix.common.HomeServerUrl
import app.dapk.st.matrix.common.RoomId import app.dapk.st.matrix.common.RoomId
import app.dapk.st.matrix.common.RoomMember import app.dapk.st.matrix.common.RoomMember
import app.dapk.st.matrix.common.UserId import app.dapk.st.matrix.common.UserId
import app.dapk.st.matrix.crypto.ImportResult
import app.dapk.st.matrix.crypto.Verification import app.dapk.st.matrix.crypto.Verification
import app.dapk.st.matrix.crypto.cryptoService import app.dapk.st.matrix.crypto.cryptoService
import app.dapk.st.matrix.room.roomService import app.dapk.st.matrix.room.roomService
import app.dapk.st.matrix.sync.syncService import app.dapk.st.matrix.sync.syncService
import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.first
import kotlinx.coroutines.flow.onEach import kotlinx.coroutines.flow.onEach
import kotlinx.coroutines.test.runTest import kotlinx.coroutines.test.runTest
import org.amshove.kluent.shouldBeEqualTo import org.amshove.kluent.shouldBeEqualTo
@ -97,10 +99,13 @@ class SmokeTest {
val stream = loadResourceStream("element-keys.txt") val stream = loadResourceStream("element-keys.txt")
val result = with(cryptoService) { val result = with(cryptoService) {
stream.importRoomKeys(password = "aaaaaa") stream.importRoomKeys(password = "aaaaaa").first { it is ImportResult.Success }
} }
result shouldBeEqualTo listOf(RoomId(value = "!qOSENTtFUuCEKJSVzl:matrix.org")) result shouldBeEqualTo ImportResult.Success(
roomIds = setOf(RoomId(value = "!qOSENTtFUuCEKJSVzl:matrix.org")),
totalImportedKeysCount = 28,
)
} }
private fun testTextMessaging(isEncrypted: Boolean) = testAfterInitialSync { alice, bob -> private fun testTextMessaging(isEncrypted: Boolean) = testAfterInitialSync { alice, bob ->

View File

@ -150,7 +150,6 @@ const incrementVersionFile = async (github, branchName) => {
name: updatedVersionName, name: updatedVersionName,
} }
const encodedContentUpdate = Buffer.from(JSON.stringify(updatedVersionFile, null, 2)).toString('base64') const encodedContentUpdate = Buffer.from(JSON.stringify(updatedVersionFile, null, 2)).toString('base64')
await github.rest.repos.createOrUpdateFileContents({ await github.rest.repos.createOrUpdateFileContents({
owner: config.owner, owner: config.owner,