Merge pull request #117 from ouchadam/feature/import-keys-improvements
Import keys improvements
This commit is contained in:
commit
7fc8060d34
|
@ -1,5 +1,9 @@
|
|||
package app.dapk.st.core
|
||||
|
||||
import kotlinx.coroutines.flow.Flow
|
||||
import kotlinx.coroutines.flow.onCompletion
|
||||
import kotlinx.coroutines.flow.onStart
|
||||
|
||||
enum class AppLogTag(val key: String) {
|
||||
NOTIFICATION("notification"),
|
||||
PERFORMANCE("performance"),
|
||||
|
@ -26,3 +30,13 @@ suspend fun <T> logP(area: String, block: suspend () -> T): T {
|
|||
log(AppLogTag.PERFORMANCE, "$area: took $timeTaken ms")
|
||||
}
|
||||
}
|
||||
|
||||
fun <T> Flow<T>.logP(area: String): Flow<T> {
|
||||
var start = -1L
|
||||
return this
|
||||
.onStart { start = System.currentTimeMillis() }
|
||||
.onCompletion {
|
||||
val timeTaken = System.currentTimeMillis() - start
|
||||
log(AppLogTag.PERFORMANCE, "$area: took $timeTaken ms")
|
||||
}
|
||||
}
|
|
@ -6,3 +6,9 @@ sealed interface Lce<T> {
|
|||
data class Content<T>(val value: T) : Lce<T>
|
||||
}
|
||||
|
||||
sealed interface LceWithProgress<T> {
|
||||
data class Loading<T>(val progress: T) : LceWithProgress<T>
|
||||
data class Error<T>(val cause: Throwable) : LceWithProgress<T>
|
||||
data class Content<T>(val value: T) : LceWithProgress<T>
|
||||
}
|
||||
|
||||
|
|
|
@ -47,6 +47,10 @@ class OlmPersistenceWrapper(
|
|||
olmPersistence.persist(sessionId, SerializedObject(inboundGroupSession.serialize()))
|
||||
}
|
||||
|
||||
override suspend fun transaction(action: suspend () -> Unit) {
|
||||
olmPersistence.startTransaction { action() }
|
||||
}
|
||||
|
||||
override suspend fun readInbound(sessionId: SessionId): OlmInboundGroupSession? {
|
||||
return olmPersistence.readInbound(sessionId)?.value?.deserialize()
|
||||
}
|
||||
|
|
|
@ -12,6 +12,7 @@ interface OlmStore {
|
|||
suspend fun read(): OlmAccount?
|
||||
suspend fun persist(olmAccount: OlmAccount)
|
||||
|
||||
suspend fun transaction(action: suspend () -> Unit)
|
||||
suspend fun readOutbound(roomId: RoomId): Pair<Long, OlmOutboundGroupSession>?
|
||||
suspend fun persistOutbound(roomId: RoomId, creationTimestampUtc: Long, outboundGroupSession: OlmOutboundGroupSession)
|
||||
suspend fun persistSession(identity: Curve25519, sessionId: SessionId, olmSession: OlmSession)
|
||||
|
|
|
@ -46,13 +46,15 @@ class OlmWrapper(
|
|||
|
||||
override suspend fun import(keys: List<SharedRoomKey>) {
|
||||
interactWithOlm()
|
||||
keys.forEach {
|
||||
val inBound = when (it.isExported) {
|
||||
true -> OlmInboundGroupSession.importSession(it.sessionKey)
|
||||
false -> OlmInboundGroupSession(it.sessionKey)
|
||||
|
||||
olmStore.transaction {
|
||||
keys.forEach {
|
||||
val inBound = when (it.isExported) {
|
||||
true -> OlmInboundGroupSession.importSession(it.sessionKey)
|
||||
false -> OlmInboundGroupSession(it.sessionKey)
|
||||
}
|
||||
olmStore.persist(it.sessionId, inBound)
|
||||
}
|
||||
logger.crypto("import megolm ${it.sessionKey}")
|
||||
olmStore.persist(it.sessionId, inBound)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -4,79 +4,109 @@ import app.dapk.db.DapkDb
|
|||
import app.dapk.db.model.DbCryptoAccount
|
||||
import app.dapk.db.model.DbCryptoMegolmInbound
|
||||
import app.dapk.db.model.DbCryptoMegolmOutbound
|
||||
import app.dapk.st.core.CoroutineDispatchers
|
||||
import app.dapk.st.core.withIoContext
|
||||
import app.dapk.st.matrix.common.CredentialsStore
|
||||
import app.dapk.st.matrix.common.Curve25519
|
||||
import app.dapk.st.matrix.common.RoomId
|
||||
import app.dapk.st.matrix.common.SessionId
|
||||
import com.squareup.sqldelight.TransactionWithoutReturn
|
||||
import kotlinx.coroutines.CoroutineScope
|
||||
import kotlinx.coroutines.launch
|
||||
import kotlinx.coroutines.withContext
|
||||
|
||||
class OlmPersistence(
|
||||
private val database: DapkDb,
|
||||
private val credentialsStore: CredentialsStore,
|
||||
private val dispatchers: CoroutineDispatchers,
|
||||
) {
|
||||
|
||||
suspend fun read(): String? {
|
||||
return database.cryptoQueries
|
||||
.selectAccount(credentialsStore.credentials()!!.userId.value)
|
||||
.executeAsOneOrNull()
|
||||
return dispatchers.withIoContext {
|
||||
database.cryptoQueries
|
||||
.selectAccount(credentialsStore.credentials()!!.userId.value)
|
||||
.executeAsOneOrNull()
|
||||
}
|
||||
}
|
||||
|
||||
suspend fun persist(olmAccount: SerializedObject) {
|
||||
database.cryptoQueries.insertAccount(
|
||||
DbCryptoAccount(
|
||||
user_id = credentialsStore.credentials()!!.userId.value,
|
||||
blob = olmAccount.value
|
||||
dispatchers.withIoContext {
|
||||
database.cryptoQueries.insertAccount(
|
||||
DbCryptoAccount(
|
||||
user_id = credentialsStore.credentials()!!.userId.value,
|
||||
blob = olmAccount.value
|
||||
)
|
||||
)
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
suspend fun readOutbound(roomId: RoomId): Pair<Long, String>? {
|
||||
return database.cryptoQueries
|
||||
.selectMegolmOutbound(roomId.value)
|
||||
.executeAsOneOrNull()?.let {
|
||||
it.utcEpochMillis to it.blob
|
||||
}
|
||||
return dispatchers.withIoContext {
|
||||
database.cryptoQueries
|
||||
.selectMegolmOutbound(roomId.value)
|
||||
.executeAsOneOrNull()?.let {
|
||||
it.utcEpochMillis to it.blob
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
suspend fun persistOutbound(roomId: RoomId, creationTimestampUtc: Long, outboundGroupSession: SerializedObject) {
|
||||
database.cryptoQueries.insertMegolmOutbound(
|
||||
DbCryptoMegolmOutbound(
|
||||
room_id = roomId.value,
|
||||
blob = outboundGroupSession.value,
|
||||
utcEpochMillis = creationTimestampUtc,
|
||||
dispatchers.withIoContext {
|
||||
database.cryptoQueries.insertMegolmOutbound(
|
||||
DbCryptoMegolmOutbound(
|
||||
room_id = roomId.value,
|
||||
blob = outboundGroupSession.value,
|
||||
utcEpochMillis = creationTimestampUtc,
|
||||
)
|
||||
)
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
suspend fun persistSession(identity: Curve25519, sessionId: SessionId, olmSession: SerializedObject) {
|
||||
database.cryptoQueries.insertOlmSession(
|
||||
identity_key = identity.value,
|
||||
session_id = sessionId.value,
|
||||
blob = olmSession.value,
|
||||
)
|
||||
withContext(dispatchers.io) {
|
||||
database.cryptoQueries.insertOlmSession(
|
||||
identity_key = identity.value,
|
||||
session_id = sessionId.value,
|
||||
blob = olmSession.value,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
suspend fun readSessions(identities: List<Curve25519>): List<Pair<Curve25519, String>>? {
|
||||
return database.cryptoQueries
|
||||
.selectOlmSession(identities.map { it.value })
|
||||
.executeAsList()
|
||||
.map { Curve25519(it.identity_key) to it.blob }
|
||||
.takeIf { it.isNotEmpty() }
|
||||
return withContext(dispatchers.io) {
|
||||
database.cryptoQueries
|
||||
.selectOlmSession(identities.map { it.value })
|
||||
.executeAsList()
|
||||
.map { Curve25519(it.identity_key) to it.blob }
|
||||
.takeIf { it.isNotEmpty() }
|
||||
}
|
||||
}
|
||||
|
||||
suspend fun startTransaction(action: suspend TransactionWithoutReturn.() -> Unit) {
|
||||
val scope = CoroutineScope(dispatchers.io)
|
||||
database.cryptoQueries.transaction {
|
||||
scope.launch { action() }
|
||||
}
|
||||
}
|
||||
|
||||
suspend fun persist(sessionId: SessionId, inboundGroupSession: SerializedObject) {
|
||||
database.cryptoQueries.insertMegolmInbound(
|
||||
DbCryptoMegolmInbound(
|
||||
session_id = sessionId.value,
|
||||
blob = inboundGroupSession.value
|
||||
withContext(dispatchers.io) {
|
||||
database.cryptoQueries.insertMegolmInbound(
|
||||
DbCryptoMegolmInbound(
|
||||
session_id = sessionId.value,
|
||||
blob = inboundGroupSession.value
|
||||
)
|
||||
)
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
suspend fun readInbound(sessionId: SessionId): SerializedObject? {
|
||||
return database.cryptoQueries
|
||||
.selectMegolmInbound(sessionId.value)
|
||||
.executeAsOneOrNull()
|
||||
?.let { SerializedObject((it)) }
|
||||
return withContext(dispatchers.io) {
|
||||
database.cryptoQueries
|
||||
.selectMegolmInbound(sessionId.value)
|
||||
.executeAsOneOrNull()
|
||||
?.let { SerializedObject((it)) }
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -39,7 +39,7 @@ class StoreModule(
|
|||
|
||||
fun applicationStore() = ApplicationPreferences(preferences)
|
||||
|
||||
fun olmStore() = OlmPersistence(database, credentialsStore())
|
||||
fun olmStore() = OlmPersistence(database, credentialsStore(), coroutineDispatchers)
|
||||
fun knownDevicesStore() = DevicePersistence(database, KnownDevicesCache(), coroutineDispatchers)
|
||||
|
||||
fun profileStore(): ProfileStore = ProfilePersistence(preferences)
|
||||
|
|
|
@ -13,12 +13,14 @@ FROM dbEventLog;
|
|||
selectLatestByLog:
|
||||
SELECT id, tag, content, time(utcEpochSeconds,'unixepoch')
|
||||
FROM dbEventLog
|
||||
WHERE logParent = ?;
|
||||
WHERE logParent = ?
|
||||
ORDER BY utcEpochSeconds DESC;
|
||||
|
||||
selectLatestByLogFiltered:
|
||||
SELECT id, tag, content, time(utcEpochSeconds,'unixepoch')
|
||||
FROM dbEventLog
|
||||
WHERE logParent = ? AND tag = ?;
|
||||
WHERE logParent = ? AND tag = ?
|
||||
ORDER BY utcEpochSeconds DESC;
|
||||
|
||||
insert:
|
||||
INSERT INTO dbEventLog(tag, content, utcEpochSeconds, logParent)
|
||||
|
|
|
@ -31,11 +31,13 @@ import androidx.compose.ui.text.input.ImeAction
|
|||
import androidx.compose.ui.text.input.KeyboardType
|
||||
import androidx.compose.ui.text.input.PasswordVisualTransformation
|
||||
import androidx.compose.ui.text.input.VisualTransformation
|
||||
import androidx.compose.ui.text.style.TextAlign
|
||||
import androidx.compose.ui.text.style.TextOverflow
|
||||
import androidx.compose.ui.unit.dp
|
||||
import androidx.compose.ui.unit.sp
|
||||
import androidx.core.net.toUri
|
||||
import app.dapk.st.core.Lce
|
||||
import app.dapk.st.core.LceWithProgress
|
||||
import app.dapk.st.core.StartObserving
|
||||
import app.dapk.st.core.components.CenteredLoading
|
||||
import app.dapk.st.core.components.Header
|
||||
|
@ -43,6 +45,7 @@ import app.dapk.st.design.components.SettingsTextRow
|
|||
import app.dapk.st.design.components.Spider
|
||||
import app.dapk.st.design.components.SpiderPage
|
||||
import app.dapk.st.design.components.TextRow
|
||||
import app.dapk.st.matrix.crypto.ImportResult
|
||||
import app.dapk.st.navigator.Navigator
|
||||
import app.dapk.st.settings.SettingsEvent.*
|
||||
import app.dapk.st.settings.eventlogger.EventLogActivity
|
||||
|
@ -72,7 +75,7 @@ internal fun SettingsScreen(viewModel: SettingsViewModel, onSignOut: () -> Unit,
|
|||
PushProviders(viewModel, it)
|
||||
}
|
||||
item(Page.Routes.importRoomKeys) {
|
||||
when (it.importProgress) {
|
||||
when (val result = it.importProgress) {
|
||||
null -> {
|
||||
Box(
|
||||
modifier = Modifier.fillMaxSize(),
|
||||
|
@ -136,10 +139,11 @@ internal fun SettingsScreen(viewModel: SettingsViewModel, onSignOut: () -> Unit,
|
|||
}
|
||||
}
|
||||
|
||||
is Lce.Content -> {
|
||||
is ImportResult.Success -> {
|
||||
Box(Modifier.fillMaxSize(), contentAlignment = Alignment.Center) {
|
||||
Column(horizontalAlignment = Alignment.CenterHorizontally) {
|
||||
Text(text = "Import success")
|
||||
Text(text = "Successfully imported ${result.totalImportedKeysCount} keys")
|
||||
Spacer(modifier = Modifier.height(12.dp))
|
||||
Button(onClick = { navigator.navigate.upToHome() }) {
|
||||
Text(text = "Close".uppercase())
|
||||
}
|
||||
|
@ -147,10 +151,18 @@ internal fun SettingsScreen(viewModel: SettingsViewModel, onSignOut: () -> Unit,
|
|||
}
|
||||
}
|
||||
|
||||
is Lce.Error -> {
|
||||
is ImportResult.Error -> {
|
||||
Box(Modifier.fillMaxSize(), contentAlignment = Alignment.Center) {
|
||||
Column(horizontalAlignment = Alignment.CenterHorizontally) {
|
||||
Text(text = "Import failed")
|
||||
val message = when(val type = result.cause) {
|
||||
ImportResult.Error.Type.NoKeysFound -> "No keys found in the file"
|
||||
ImportResult.Error.Type.UnexpectedDecryptionOutput -> "Unable to decrypt file, double check your passphrase"
|
||||
is ImportResult.Error.Type.Unknown -> "${type.cause::class.java.simpleName}: ${type.cause.message}"
|
||||
ImportResult.Error.Type.UnableToOpenFile -> "Unable to open file"
|
||||
}
|
||||
|
||||
Text(text = "Import failed\n$message", textAlign = TextAlign.Center)
|
||||
Spacer(modifier = Modifier.height(12.dp))
|
||||
Button(onClick = { navigator.navigate.upToHome() }) {
|
||||
Text(text = "Close".uppercase())
|
||||
}
|
||||
|
@ -158,7 +170,15 @@ internal fun SettingsScreen(viewModel: SettingsViewModel, onSignOut: () -> Unit,
|
|||
}
|
||||
}
|
||||
|
||||
is Lce.Loading -> CenteredLoading()
|
||||
is ImportResult.Update -> {
|
||||
Box(Modifier.fillMaxSize(), contentAlignment = Alignment.Center) {
|
||||
Column(horizontalAlignment = Alignment.CenterHorizontally) {
|
||||
Text(text = "Importing ${result.importedKeysCount} keys...")
|
||||
Spacer(modifier = Modifier.height(12.dp))
|
||||
CircularProgressIndicator(Modifier.wrapContentSize())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2,8 +2,10 @@ package app.dapk.st.settings
|
|||
|
||||
import android.net.Uri
|
||||
import app.dapk.st.core.Lce
|
||||
import app.dapk.st.core.LceWithProgress
|
||||
import app.dapk.st.design.components.Route
|
||||
import app.dapk.st.design.components.SpiderPage
|
||||
import app.dapk.st.matrix.crypto.ImportResult
|
||||
import app.dapk.st.push.Registrar
|
||||
|
||||
internal data class SettingsScreenState(
|
||||
|
@ -15,7 +17,7 @@ internal sealed interface Page {
|
|||
object Security : Page
|
||||
data class ImportRoomKey(
|
||||
val selectedFile: NamedUri? = null,
|
||||
val importProgress: Lce<Unit>? = null,
|
||||
val importProgress: ImportResult? = null,
|
||||
) : Page
|
||||
|
||||
data class PushProviders(
|
||||
|
|
|
@ -7,6 +7,7 @@ import app.dapk.st.core.Lce
|
|||
import app.dapk.st.design.components.SpiderPage
|
||||
import app.dapk.st.domain.StoreCleaner
|
||||
import app.dapk.st.matrix.crypto.CryptoService
|
||||
import app.dapk.st.matrix.crypto.ImportResult
|
||||
import app.dapk.st.matrix.sync.SyncService
|
||||
import app.dapk.st.push.PushTokenRegistrars
|
||||
import app.dapk.st.push.Registrar
|
||||
|
@ -15,6 +16,8 @@ import app.dapk.st.settings.SettingsEvent.*
|
|||
import app.dapk.st.viewmodel.DapkViewModel
|
||||
import app.dapk.st.viewmodel.MutableStateFactory
|
||||
import app.dapk.st.viewmodel.defaultStateFactory
|
||||
import kotlinx.coroutines.flow.launchIn
|
||||
import kotlinx.coroutines.flow.onEach
|
||||
import kotlinx.coroutines.launch
|
||||
|
||||
private const val PRIVACY_POLICY_URL = "https://ouchadam.github.io/small-talk/privacy/"
|
||||
|
@ -120,17 +123,34 @@ internal class SettingsViewModel(
|
|||
}
|
||||
|
||||
fun importFromFileKeys(file: Uri, passphrase: String) {
|
||||
updatePageState<Page.ImportRoomKey> { copy(importProgress = Lce.Loading()) }
|
||||
updatePageState<Page.ImportRoomKey> { copy(importProgress = ImportResult.Update(0)) }
|
||||
viewModelScope.launch {
|
||||
kotlin.runCatching {
|
||||
with(cryptoService) {
|
||||
val roomsToRefresh = contentResolver.openInputStream(file)?.importRoomKeys(passphrase)
|
||||
roomsToRefresh?.let { syncService.forceManualRefresh(roomsToRefresh) }
|
||||
}
|
||||
}.fold(
|
||||
onSuccess = { updatePageState<Page.ImportRoomKey> { copy(importProgress = Lce.Content(Unit)) } },
|
||||
onFailure = { updatePageState<Page.ImportRoomKey> { copy(importProgress = Lce.Error(it)) } }
|
||||
)
|
||||
with(cryptoService) {
|
||||
runCatching { contentResolver.openInputStream(file)!! }
|
||||
.fold(
|
||||
onSuccess = { fileStream ->
|
||||
fileStream.importRoomKeys(passphrase)
|
||||
.onEach {
|
||||
updatePageState<Page.ImportRoomKey> { copy(importProgress = it) }
|
||||
when (it) {
|
||||
is ImportResult.Error -> {
|
||||
// do nothing
|
||||
}
|
||||
is ImportResult.Update -> {
|
||||
// do nothing
|
||||
}
|
||||
is ImportResult.Success -> {
|
||||
syncService.forceManualRefresh(it.roomIds.toList())
|
||||
}
|
||||
}
|
||||
}
|
||||
.launchIn(viewModelScope)
|
||||
},
|
||||
onFailure = {
|
||||
updatePageState<Page.ImportRoomKey> { copy(importProgress = ImportResult.Error(ImportResult.Error.Type.UnableToOpenFile)) }
|
||||
}
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -3,6 +3,7 @@ package app.dapk.st.settings
|
|||
import ViewModelTest
|
||||
import app.dapk.st.core.Lce
|
||||
import app.dapk.st.design.components.SpiderPage
|
||||
import app.dapk.st.matrix.crypto.ImportResult
|
||||
import fake.*
|
||||
import fixture.FakeStoreCleaner
|
||||
import fixture.aRoomId
|
||||
|
@ -10,6 +11,7 @@ import internalfake.FakeSettingsItemFactory
|
|||
import internalfake.FakeUriFilenameResolver
|
||||
import internalfixture.aImportRoomKeysPage
|
||||
import internalfixture.aSettingTextItem
|
||||
import kotlinx.coroutines.flow.flowOf
|
||||
import org.junit.Test
|
||||
|
||||
private const val APP_PRIVACY_POLICY_URL = "https://ouchadam.github.io/small-talk/privacy/"
|
||||
|
@ -21,6 +23,8 @@ private val A_IMPORT_ROOM_KEYS_PAGE_WITH_SELECTION = aImportRoomKeysPage(
|
|||
state = Page.ImportRoomKey(selectedFile = NamedUri(A_FILENAME, A_URI.instance))
|
||||
)
|
||||
private val A_LIST_OF_ROOM_IDS = listOf(aRoomId())
|
||||
private val AN_IMPORT_SUCCESS = ImportResult.Success(A_LIST_OF_ROOM_IDS.toSet(), totalImportedKeysCount = 5)
|
||||
private val AN_IMPORT_FILE_ERROR = ImportResult.Error(ImportResult.Error.Type.UnableToOpenFile)
|
||||
private val AN_INPUT_STREAM = FakeInputStream()
|
||||
private const val A_PASSPHRASE = "passphrase"
|
||||
private val AN_ERROR = RuntimeException()
|
||||
|
@ -166,15 +170,15 @@ internal class SettingsViewModelTest {
|
|||
fun `given success when importing room keys, then emits progress`() = runViewModelTest {
|
||||
fakeSyncService.expectUnit { it.forceManualRefresh(A_LIST_OF_ROOM_IDS) }
|
||||
fakeContentResolver.givenFile(A_URI.instance).returns(AN_INPUT_STREAM.instance)
|
||||
fakeCryptoService.givenImportKeys(AN_INPUT_STREAM.instance, A_PASSPHRASE).returns(A_LIST_OF_ROOM_IDS)
|
||||
fakeCryptoService.givenImportKeys(AN_INPUT_STREAM.instance, A_PASSPHRASE).returns(flowOf(AN_IMPORT_SUCCESS))
|
||||
|
||||
viewModel
|
||||
.test(initialState = SettingsScreenState(A_IMPORT_ROOM_KEYS_PAGE_WITH_SELECTION))
|
||||
.importFromFileKeys(A_URI.instance, A_PASSPHRASE)
|
||||
|
||||
assertStates<SettingsScreenState>(
|
||||
{ copy(page = page.updateState<Page.ImportRoomKey> { copy(importProgress = Lce.Loading()) }) },
|
||||
{ copy(page = page.updateState<Page.ImportRoomKey> { copy(importProgress = Lce.Content(Unit)) }) },
|
||||
{ copy(page = page.updateState<Page.ImportRoomKey> { copy(importProgress = ImportResult.Update(0L)) }) },
|
||||
{ copy(page = page.updateState<Page.ImportRoomKey> { copy(importProgress = AN_IMPORT_SUCCESS) }) },
|
||||
)
|
||||
assertNoEvents<SettingsEvent>()
|
||||
verifyExpects()
|
||||
|
@ -189,8 +193,8 @@ internal class SettingsViewModelTest {
|
|||
.importFromFileKeys(A_URI.instance, A_PASSPHRASE)
|
||||
|
||||
assertStates<SettingsScreenState>(
|
||||
{ copy(page = page.updateState<Page.ImportRoomKey> { copy(importProgress = Lce.Loading()) }) },
|
||||
{ copy(page = page.updateState<Page.ImportRoomKey> { copy(importProgress = Lce.Error(AN_ERROR)) }) },
|
||||
{ copy(page = page.updateState<Page.ImportRoomKey> { copy(importProgress = ImportResult.Update(0L)) }) },
|
||||
{ copy(page = page.updateState<Page.ImportRoomKey> { copy(importProgress = AN_IMPORT_FILE_ERROR) }) },
|
||||
)
|
||||
assertNoEvents<SettingsEvent>()
|
||||
}
|
||||
|
|
|
@ -18,7 +18,7 @@ interface CryptoService : MatrixService {
|
|||
suspend fun encrypt(roomId: RoomId, credentials: DeviceCredentials, messageJson: JsonString): Crypto.EncryptionResult
|
||||
suspend fun decrypt(encryptedPayload: EncryptedMessageContent): DecryptionResult
|
||||
suspend fun importRoomKeys(keys: List<SharedRoomKey>)
|
||||
suspend fun InputStream.importRoomKeys(password: String): List<RoomId>
|
||||
suspend fun InputStream.importRoomKeys(password: String): Flow<ImportResult>
|
||||
|
||||
suspend fun maybeCreateMoreKeys(serverKeyCount: ServerKeyCount)
|
||||
suspend fun updateOlmSession(userIds: List<UserId>, syncToken: SyncToken?)
|
||||
|
@ -159,4 +159,19 @@ fun MatrixServiceProvider.cryptoService(): CryptoService = this.getService(key =
|
|||
|
||||
fun interface RoomMembersProvider {
|
||||
suspend fun userIdsForRoom(roomId: RoomId): List<UserId>
|
||||
}
|
||||
}
|
||||
|
||||
sealed interface ImportResult {
|
||||
data class Success(val roomIds: Set<RoomId>, val totalImportedKeysCount: Long) : ImportResult
|
||||
data class Error(val cause: Type) : ImportResult {
|
||||
|
||||
sealed interface Type {
|
||||
data class Unknown(val cause: Throwable): Type
|
||||
object NoKeysFound: Type
|
||||
object UnexpectedDecryptionOutput: Type
|
||||
object UnableToOpenFile: Type
|
||||
}
|
||||
|
||||
}
|
||||
data class Update(val importedKeysCount: Long) : ImportResult
|
||||
}
|
||||
|
|
|
@ -1,8 +1,10 @@
|
|||
package app.dapk.st.matrix.crypto.internal
|
||||
|
||||
import app.dapk.st.core.logP
|
||||
import app.dapk.st.matrix.common.*
|
||||
import app.dapk.st.matrix.crypto.Crypto
|
||||
import app.dapk.st.matrix.crypto.CryptoService
|
||||
import app.dapk.st.matrix.crypto.ImportResult
|
||||
import app.dapk.st.matrix.crypto.Verification
|
||||
import kotlinx.coroutines.flow.Flow
|
||||
import java.io.InputStream
|
||||
|
@ -47,9 +49,11 @@ internal class DefaultCryptoService(
|
|||
verificationHandler.onUserVerificationAction(verificationAction)
|
||||
}
|
||||
|
||||
override suspend fun InputStream.importRoomKeys(password: String): List<RoomId> {
|
||||
override suspend fun InputStream.importRoomKeys(password: String): Flow<ImportResult> {
|
||||
return with(roomKeyImporter) {
|
||||
importRoomKeys(password) { importRoomKeys(it) }
|
||||
importRoomKeys(password) {
|
||||
importRoomKeys(it)
|
||||
}.logP("import room keys")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -14,7 +14,6 @@ internal class OlmCrypto(
|
|||
) {
|
||||
|
||||
suspend fun importRoomKeys(keys: List<SharedRoomKey>) {
|
||||
logger.crypto("import room keys : ${keys.size}")
|
||||
olm.import(keys)
|
||||
}
|
||||
|
||||
|
|
|
@ -2,15 +2,18 @@ package app.dapk.st.matrix.crypto.internal
|
|||
|
||||
import app.dapk.st.core.Base64
|
||||
import app.dapk.st.core.CoroutineDispatchers
|
||||
import app.dapk.st.core.withIoContext
|
||||
import app.dapk.st.matrix.common.AlgorithmName
|
||||
import app.dapk.st.matrix.common.RoomId
|
||||
import app.dapk.st.matrix.common.SessionId
|
||||
import app.dapk.st.matrix.common.SharedRoomKey
|
||||
import app.dapk.st.matrix.crypto.ImportResult
|
||||
import kotlinx.coroutines.flow.Flow
|
||||
import kotlinx.coroutines.flow.FlowCollector
|
||||
import kotlinx.coroutines.flow.flow
|
||||
import kotlinx.coroutines.flow.flowOn
|
||||
import kotlinx.serialization.SerialName
|
||||
import kotlinx.serialization.Serializable
|
||||
import kotlinx.serialization.json.Json
|
||||
import java.io.IOException
|
||||
import java.io.InputStream
|
||||
import java.nio.charset.Charset
|
||||
import javax.crypto.Cipher
|
||||
|
@ -28,53 +31,29 @@ class RoomKeyImporter(
|
|||
private val dispatchers: CoroutineDispatchers,
|
||||
) {
|
||||
|
||||
suspend fun InputStream.importRoomKeys(password: String, onChunk: suspend (List<SharedRoomKey>) -> Unit): List<RoomId> {
|
||||
return dispatchers.withIoContext {
|
||||
val decryptCipher = Cipher.getInstance("AES/CTR/NoPadding")
|
||||
var jsonSegment = ""
|
||||
|
||||
fun <T> Sequence<T>.accumulateJson() = this.mapNotNull {
|
||||
val withLatest = jsonSegment + it
|
||||
try {
|
||||
when (val objectRange = withLatest.findClosingIndex()) {
|
||||
null -> {
|
||||
jsonSegment = withLatest
|
||||
null
|
||||
}
|
||||
else -> {
|
||||
val string = withLatest.substring(objectRange)
|
||||
importJson.decodeFromString(ElementMegolmExportObject.serializer(), string).also {
|
||||
jsonSegment = withLatest.replace(string, "").removePrefix(",")
|
||||
}
|
||||
}
|
||||
suspend fun InputStream.importRoomKeys(password: String, onChunk: suspend (List<SharedRoomKey>) -> Unit): Flow<ImportResult> {
|
||||
return flow {
|
||||
runCatching { this@importRoomKeys.import(password, onChunk, this) }
|
||||
.onFailure {
|
||||
when (it) {
|
||||
is ImportException -> emit(ImportResult.Error(it.type))
|
||||
else -> emit(ImportResult.Error(ImportResult.Error.Type.Unknown(it)))
|
||||
}
|
||||
} catch (error: Throwable) {
|
||||
jsonSegment = withLatest
|
||||
null
|
||||
}
|
||||
}
|
||||
}.flowOn(dispatchers.io)
|
||||
}
|
||||
|
||||
this@importRoomKeys.bufferedReader().use {
|
||||
val roomIds = mutableSetOf<RoomId>()
|
||||
private suspend fun InputStream.import(password: String, onChunk: suspend (List<SharedRoomKey>) -> Unit, collector: FlowCollector<ImportResult>) {
|
||||
var importedKeysCount = 0L
|
||||
val roomIds = mutableSetOf<RoomId>()
|
||||
|
||||
this.bufferedReader().use {
|
||||
with(JsonAccumulator()) {
|
||||
it.useLines { sequence ->
|
||||
sequence
|
||||
.filterNot { it == HEADER_LINE || it == TRAILER_LINE || it.isEmpty() }
|
||||
.chunked(2)
|
||||
.withIndex()
|
||||
.map { (index, it) ->
|
||||
val line = it.joinToString(separator = "").replace("\n", "")
|
||||
val toByteArray = base64.decode(line)
|
||||
if (index == 0) {
|
||||
decryptCipher.initialize(toByteArray, password)
|
||||
toByteArray.copyOfRange(37, toByteArray.size).decrypt(decryptCipher).also {
|
||||
if (!it.startsWith("[{")) {
|
||||
throw IllegalArgumentException("Unable to decrypt, assumed invalid password")
|
||||
}
|
||||
}
|
||||
} else {
|
||||
toByteArray.decrypt(decryptCipher)
|
||||
}
|
||||
}
|
||||
.chunked(5)
|
||||
.decrypt(password)
|
||||
.accumulateJson()
|
||||
.map { decoded ->
|
||||
roomIds.add(decoded.roomId)
|
||||
|
@ -86,13 +65,39 @@ class RoomKeyImporter(
|
|||
isExported = true,
|
||||
)
|
||||
}
|
||||
.chunked(50)
|
||||
.forEach { onChunk(it) }
|
||||
}
|
||||
roomIds.toList().ifEmpty {
|
||||
throw IOException("Found no rooms to import in the file")
|
||||
.chunked(500)
|
||||
.forEach {
|
||||
onChunk(it)
|
||||
importedKeysCount += it.size
|
||||
collector.emit(ImportResult.Update(importedKeysCount))
|
||||
}
|
||||
}
|
||||
}
|
||||
when {
|
||||
roomIds.isEmpty() -> collector.emit(ImportResult.Error(ImportResult.Error.Type.NoKeysFound))
|
||||
else -> collector.emit(ImportResult.Success(roomIds, importedKeysCount))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private fun Sequence<List<String>>.decrypt(password: String): Sequence<String> {
|
||||
val decryptCipher = Cipher.getInstance("AES/CTR/NoPadding")
|
||||
return this.withIndex().map { (index, it) ->
|
||||
val line = it.joinToString(separator = "").replace("\n", "")
|
||||
val toByteArray = base64.decode(line)
|
||||
if (index == 0) {
|
||||
decryptCipher.initialize(toByteArray, password)
|
||||
toByteArray
|
||||
.copyOfRange(37, toByteArray.size)
|
||||
.decrypt(decryptCipher)
|
||||
.also {
|
||||
if (!it.startsWith("[{")) {
|
||||
throw ImportException(ImportResult.Error.Type.UnexpectedDecryptionOutput)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
toByteArray.decrypt(decryptCipher)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -149,28 +154,6 @@ class RoomKeyImporter(
|
|||
|
||||
private fun Byte.toUnsignedInt() = toInt() and 0xff
|
||||
|
||||
private fun String.findClosingIndex(): IntRange? {
|
||||
var opens = 0
|
||||
var openIndex = -1
|
||||
this.forEachIndexed { index, c ->
|
||||
when {
|
||||
c == '{' -> {
|
||||
if (opens == 0) {
|
||||
openIndex = index
|
||||
}
|
||||
opens++
|
||||
}
|
||||
c == '}' -> {
|
||||
opens--
|
||||
if (opens == 0) {
|
||||
return IntRange(openIndex, index)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return null
|
||||
}
|
||||
|
||||
@Serializable
|
||||
private data class ElementMegolmExportObject(
|
||||
@SerialName("room_id") val roomId: RoomId,
|
||||
|
@ -178,3 +161,53 @@ private data class ElementMegolmExportObject(
|
|||
@SerialName("session_id") val sessionId: SessionId,
|
||||
@SerialName("algorithm") val algorithmName: AlgorithmName,
|
||||
)
|
||||
|
||||
private class ImportException(val type: ImportResult.Error.Type) : Throwable()
|
||||
|
||||
private class JsonAccumulator {
|
||||
|
||||
private var jsonSegment = ""
|
||||
|
||||
fun <T> Sequence<T>.accumulateJson() = this.mapNotNull {
|
||||
val withLatest = jsonSegment + it
|
||||
try {
|
||||
when (val objectRange = withLatest.findClosingIndex()) {
|
||||
null -> {
|
||||
jsonSegment = withLatest
|
||||
null
|
||||
}
|
||||
else -> {
|
||||
val string = withLatest.substring(objectRange)
|
||||
importJson.decodeFromString(ElementMegolmExportObject.serializer(), string).also {
|
||||
jsonSegment = withLatest.replace(string, "").removePrefix(",")
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch (error: Throwable) {
|
||||
jsonSegment = withLatest
|
||||
null
|
||||
}
|
||||
}
|
||||
|
||||
private fun String.findClosingIndex(): IntRange? {
|
||||
var opens = 0
|
||||
var openIndex = -1
|
||||
this.forEachIndexed { index, c ->
|
||||
when {
|
||||
c == '{' -> {
|
||||
if (opens == 0) {
|
||||
openIndex = index
|
||||
}
|
||||
opens++
|
||||
}
|
||||
c == '}' -> {
|
||||
opens--
|
||||
if (opens == 0) {
|
||||
return IntRange(openIndex, index)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return null
|
||||
}
|
||||
}
|
|
@ -4,11 +4,13 @@ import app.dapk.st.matrix.common.HomeServerUrl
|
|||
import app.dapk.st.matrix.common.RoomId
|
||||
import app.dapk.st.matrix.common.RoomMember
|
||||
import app.dapk.st.matrix.common.UserId
|
||||
import app.dapk.st.matrix.crypto.ImportResult
|
||||
import app.dapk.st.matrix.crypto.Verification
|
||||
import app.dapk.st.matrix.crypto.cryptoService
|
||||
import app.dapk.st.matrix.room.roomService
|
||||
import app.dapk.st.matrix.sync.syncService
|
||||
import kotlinx.coroutines.flow.Flow
|
||||
import kotlinx.coroutines.flow.first
|
||||
import kotlinx.coroutines.flow.onEach
|
||||
import kotlinx.coroutines.test.runTest
|
||||
import org.amshove.kluent.shouldBeEqualTo
|
||||
|
@ -97,10 +99,13 @@ class SmokeTest {
|
|||
val stream = loadResourceStream("element-keys.txt")
|
||||
|
||||
val result = with(cryptoService) {
|
||||
stream.importRoomKeys(password = "aaaaaa")
|
||||
stream.importRoomKeys(password = "aaaaaa").first { it is ImportResult.Success }
|
||||
}
|
||||
|
||||
result shouldBeEqualTo listOf(RoomId(value = "!qOSENTtFUuCEKJSVzl:matrix.org"))
|
||||
result shouldBeEqualTo ImportResult.Success(
|
||||
roomIds = setOf(RoomId(value = "!qOSENTtFUuCEKJSVzl:matrix.org")),
|
||||
totalImportedKeysCount = 28,
|
||||
)
|
||||
}
|
||||
|
||||
private fun testTextMessaging(isEncrypted: Boolean) = testAfterInitialSync { alice, bob ->
|
||||
|
|
|
@ -150,7 +150,6 @@ const incrementVersionFile = async (github, branchName) => {
|
|||
name: updatedVersionName,
|
||||
}
|
||||
|
||||
|
||||
const encodedContentUpdate = Buffer.from(JSON.stringify(updatedVersionFile, null, 2)).toString('base64')
|
||||
await github.rest.repos.createOrUpdateFileContents({
|
||||
owner: config.owner,
|
||||
|
|
Loading…
Reference in New Issue