From 6546f9885805b4e3526386af0a8505b49d2520ed Mon Sep 17 00:00:00 2001 From: Valere Date: Tue, 1 Mar 2022 19:16:25 +0100 Subject: [PATCH] use mutex on suspend and not synchronized --- .../EnsureOlmSessionsForDevicesAction.kt | 134 +++++++++--------- 1 file changed, 69 insertions(+), 65 deletions(-) diff --git a/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/actions/EnsureOlmSessionsForDevicesAction.kt b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/actions/EnsureOlmSessionsForDevicesAction.kt index c7059ae4f8..87c176612d 100644 --- a/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/actions/EnsureOlmSessionsForDevicesAction.kt +++ b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/actions/EnsureOlmSessionsForDevicesAction.kt @@ -16,6 +16,8 @@ package org.matrix.android.sdk.internal.crypto.actions +import kotlinx.coroutines.sync.Mutex +import kotlinx.coroutines.sync.withLock import kotlinx.coroutines.withContext import org.matrix.android.sdk.api.MatrixCoroutineDispatchers import org.matrix.android.sdk.api.logger.LoggerTag @@ -39,82 +41,84 @@ internal class EnsureOlmSessionsForDevicesAction @Inject constructor( private val coroutineDispatchers: MatrixCoroutineDispatchers, private val oneTimeKeysForUsersDeviceTask: ClaimOneTimeKeysForUsersDeviceTask) { + private val ensureMutex = Mutex() + /** * We want to synchronize a bit here, because we are iterating to check existing olm session and * also adding some */ - @Synchronized suspend fun handle(devicesByUser: Map>, force: Boolean = false): MXUsersDevicesMap { - val deviceList = devicesByUser.flatMap { it.value } - Timber.tag(loggerTag.value) - .d("ensure olm forced:$force for ${deviceList.joinToString { it.shortDebugString() }}") - val results = MXUsersDevicesMap() - val devicesToCreateSessionWith = mutableListOf() - if (force) { - // we take all devices and will query otk for them - devicesToCreateSessionWith.addAll(deviceList) - } else { - // only peek devices without active session - deviceList.forEach { deviceInfo -> - val deviceId = deviceInfo.deviceId - val userId = deviceInfo.userId - val key = deviceInfo.identityKey() ?: return@forEach Unit.also { - Timber.tag(loggerTag.value).w("Ignoring device ${deviceInfo.shortDebugString()} without identity key") - } + ensureMutex.withLock { + val results = MXUsersDevicesMap() + val deviceList = devicesByUser.flatMap { it.value } + Timber.tag(loggerTag.value) + .d("ensure olm forced:$force for ${deviceList.joinToString { it.shortDebugString() }}") + val devicesToCreateSessionWith = mutableListOf() + if (force) { + // we take all devices and will query otk for them + devicesToCreateSessionWith.addAll(deviceList) + } else { + // only peek devices without active session + deviceList.forEach { deviceInfo -> + val deviceId = deviceInfo.deviceId + val userId = deviceInfo.userId + val key = deviceInfo.identityKey() ?: return@forEach Unit.also { + Timber.tag(loggerTag.value).w("Ignoring device ${deviceInfo.shortDebugString()} without identity key") + } - // is there a session that as been already used? - val sessionId = olmDevice.getSessionId(key) - if (sessionId.isNullOrEmpty()) { - Timber.tag(loggerTag.value).d("Found no existing olm session ${deviceInfo.shortDebugString()} add to claim list") - devicesToCreateSessionWith.add(deviceInfo) - } else { - Timber.tag(loggerTag.value).d("using olm session $sessionId for (${deviceInfo.userId}|$deviceId)") - val olmSessionResult = MXOlmSessionResult(deviceInfo, sessionId) - results.setObject(userId, deviceId, olmSessionResult) + // is there a session that as been already used? + val sessionId = olmDevice.getSessionId(key) + if (sessionId.isNullOrEmpty()) { + Timber.tag(loggerTag.value).d("Found no existing olm session ${deviceInfo.shortDebugString()} add to claim list") + devicesToCreateSessionWith.add(deviceInfo) + } else { + Timber.tag(loggerTag.value).d("using olm session $sessionId for (${deviceInfo.userId}|$deviceId)") + val olmSessionResult = MXOlmSessionResult(deviceInfo, sessionId) + results.setObject(userId, deviceId, olmSessionResult) + } } } - } - if (devicesToCreateSessionWith.isEmpty()) { - // no session to create + if (devicesToCreateSessionWith.isEmpty()) { + // no session to create + return results + } + val usersDevicesToClaim = MXUsersDevicesMap().apply { + devicesToCreateSessionWith.forEach { + setObject(it.userId, it.deviceId, MXKey.KEY_SIGNED_CURVE_25519_TYPE) + } + } + + // Let's now claim one time keys + val claimParams = ClaimOneTimeKeysForUsersDeviceTask.Params(usersDevicesToClaim) + val oneTimeKeys = withContext(coroutineDispatchers.io) { + oneTimeKeysForUsersDeviceTask.executeRetry(claimParams, ONE_TIME_KEYS_RETRY_COUNT) + } + + // let now start olm session using the new otks + devicesToCreateSessionWith.forEach { deviceInfo -> + val userId = deviceInfo.userId + val deviceId = deviceInfo.deviceId + // Did we get an OTK + val oneTimeKey = oneTimeKeys.getObject(userId, deviceId) + if (oneTimeKey == null) { + Timber.tag(loggerTag.value).d("No otk for ${deviceInfo.shortDebugString()}") + } else if (oneTimeKey.type != MXKey.KEY_SIGNED_CURVE_25519_TYPE) { + Timber.tag(loggerTag.value).d("Bad otk type (${oneTimeKey.type}) for ${deviceInfo.shortDebugString()}") + } else { + val olmSessionId = verifyKeyAndStartSession(oneTimeKey, userId, deviceInfo) + if (olmSessionId != null) { + val olmSessionResult = MXOlmSessionResult(deviceInfo, olmSessionId) + results.setObject(userId, deviceId, olmSessionResult) + } else { + Timber + .tag(loggerTag.value) + .d("## CRYPTO | cant unwedge failed to create outbound ${deviceInfo.shortDebugString()}") + } + } + } return results } - val usersDevicesToClaim = MXUsersDevicesMap().apply { - devicesToCreateSessionWith.forEach { - setObject(it.userId, it.deviceId, MXKey.KEY_SIGNED_CURVE_25519_TYPE) - } - } - - // Let's now claim one time keys - val claimParams = ClaimOneTimeKeysForUsersDeviceTask.Params(usersDevicesToClaim) - val oneTimeKeys = withContext(coroutineDispatchers.io) { - oneTimeKeysForUsersDeviceTask.executeRetry(claimParams, ONE_TIME_KEYS_RETRY_COUNT) - } - - // let now start olm session using the new otks - devicesToCreateSessionWith.forEach { deviceInfo -> - val userId = deviceInfo.userId - val deviceId = deviceInfo.deviceId - // Did we get an OTK - val oneTimeKey = oneTimeKeys.getObject(userId, deviceId) - if (oneTimeKey == null) { - Timber.tag(loggerTag.value).d("No otk for ${deviceInfo.shortDebugString()}") - } else if (oneTimeKey.type != MXKey.KEY_SIGNED_CURVE_25519_TYPE) { - Timber.tag(loggerTag.value).d("Bad otk type (${oneTimeKey.type}) for ${deviceInfo.shortDebugString()}") - } else { - val olmSessionId = verifyKeyAndStartSession(oneTimeKey, userId, deviceInfo) - if (olmSessionId != null) { - val olmSessionResult = MXOlmSessionResult(deviceInfo, olmSessionId) - results.setObject(userId, deviceId, olmSessionResult) - } else { - Timber - .tag(loggerTag.value) - .d("## CRYPTO | cant unwedge failed to create outbound ${deviceInfo.shortDebugString()}") - } - } - } - - return results } private fun verifyKeyAndStartSession(oneTimeKey: MXKey, userId: String, deviceInfo: CryptoDeviceInfo): String? {