Improve key decryption perf

This commit is contained in:
Valere 2021-11-26 13:56:36 +01:00
parent 1635c9730a
commit 69e4b6e8a4
1 changed files with 31 additions and 30 deletions

View File

@ -23,6 +23,8 @@ import androidx.annotation.WorkerThread
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.SupervisorJob
import kotlinx.coroutines.async
import kotlinx.coroutines.awaitAll
import kotlinx.coroutines.delay
import kotlinx.coroutines.launch
import kotlinx.coroutines.withContext
@ -486,42 +488,42 @@ internal class RustKeyBackupService @Inject constructor(
val data = getKeys(sessionId, roomId, keysVersionResult.version)
return withContext(coroutineDispatchers.computation) {
val sessionsData = ArrayList<MegolmSessionData>()
// Restore that data
var sessionsFromHsCount = 0
cryptoCoroutineScope.launch(Dispatchers.Main) {
withContext(Dispatchers.Main) {
stepProgressListener?.onStepProgress(StepProgressListener.Step.DecryptingKey(0, data.roomIdToRoomKeysBackupData.size))
}
var progressDecryptIndex = 0
// TODO this is quite long, could we add some concurrency here?
for ((roomIdLoop, backupData) in data.roomIdToRoomKeysBackupData) {
val roomIndex = progressDecryptIndex
progressDecryptIndex++
cryptoCoroutineScope.launch(Dispatchers.Main) {
stepProgressListener?.onStepProgress(StepProgressListener.Step.DecryptingKey(roomIndex, data.roomIdToRoomKeysBackupData.size))
}
for ((sessionIdLoop, keyBackupData) in backupData.sessionIdToKeyBackupData) {
sessionsFromHsCount++
val sessionData = decryptKeyBackupData(keyBackupData, sessionIdLoop, roomIdLoop, recoveryKey)
// rust is not very lax and will throw if field are missing,
// add a check
// TODO maybe could be done on rust side?
sessionData?.takeIf {
it.isValid().also {
if (!it) {
Timber.w("restoreKeysWithRecoveryKey: malformed sessionData $sessionData")
// Decrypting by chunk of 500 keys in parallel
// we loose proper progress report but tested 3x faster on big backup
val sessionsData = data.roomIdToRoomKeysBackupData
.mapValues {
it.value.sessionIdToKeyBackupData
}
.flatMap { flat ->
flat.value.entries.map { flat.key to it }
}
.chunked(500)
.map { slice ->
async {
slice.mapNotNull { pair ->
decryptKeyBackupData(pair.second.value, pair.second.key, pair.first, recoveryKey)
?.takeIf { sessionData ->
sessionData.isValid().also {
if (!it) {
Timber.w("restoreKeysWithRecoveryKey: malformed sessionData $sessionData")
}
}
}
}
}
}?.let {
sessionsData.add(it)
}
}
.awaitAll()
.flatten()
withContext(Dispatchers.Main) {
stepProgressListener?.onStepProgress(StepProgressListener.Step.DecryptingKey(data.roomIdToRoomKeysBackupData.size, data.roomIdToRoomKeysBackupData.size))
}
Timber.v("restoreKeysWithRecoveryKey: Decrypted ${sessionsData.size} keys out" +
" of $sessionsFromHsCount from the backup store on the homeserver")
" of ${data.roomIdToRoomKeysBackupData.size} rooms from the backup store on the homeserver")
// Do not trigger a backup for them if they come from the backup version we are using
val backUp = keysVersionResult.version != keysBackupVersion?.version
@ -534,7 +536,6 @@ internal class RustKeyBackupService @Inject constructor(
val progressListener = if (stepProgressListener != null) {
object : ProgressListener {
override fun onProgress(progress: Int, total: Int) {
// Note: no need to post to UI thread, importMegolmSessionsData() will do it
cryptoCoroutineScope.launch(Dispatchers.Main) {
stepProgressListener.onStepProgress(StepProgressListener.Step.ImportingKey(progress, total))
}