use mutex on suspend and not synchronized

This commit is contained in:
Valere 2022-03-01 19:16:25 +01:00
parent 49d33f3a4b
commit 6546f98858
1 changed files with 69 additions and 65 deletions

View File

@ -16,6 +16,8 @@
package org.matrix.android.sdk.internal.crypto.actions package org.matrix.android.sdk.internal.crypto.actions
import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock
import kotlinx.coroutines.withContext import kotlinx.coroutines.withContext
import org.matrix.android.sdk.api.MatrixCoroutineDispatchers import org.matrix.android.sdk.api.MatrixCoroutineDispatchers
import org.matrix.android.sdk.api.logger.LoggerTag import org.matrix.android.sdk.api.logger.LoggerTag
@ -39,82 +41,84 @@ internal class EnsureOlmSessionsForDevicesAction @Inject constructor(
private val coroutineDispatchers: MatrixCoroutineDispatchers, private val coroutineDispatchers: MatrixCoroutineDispatchers,
private val oneTimeKeysForUsersDeviceTask: ClaimOneTimeKeysForUsersDeviceTask) { private val oneTimeKeysForUsersDeviceTask: ClaimOneTimeKeysForUsersDeviceTask) {
private val ensureMutex = Mutex()
/** /**
* We want to synchronize a bit here, because we are iterating to check existing olm session and * We want to synchronize a bit here, because we are iterating to check existing olm session and
* also adding some * also adding some
*/ */
@Synchronized
suspend fun handle(devicesByUser: Map<String, List<CryptoDeviceInfo>>, force: Boolean = false): MXUsersDevicesMap<MXOlmSessionResult> { suspend fun handle(devicesByUser: Map<String, List<CryptoDeviceInfo>>, force: Boolean = false): MXUsersDevicesMap<MXOlmSessionResult> {
val deviceList = devicesByUser.flatMap { it.value } ensureMutex.withLock {
Timber.tag(loggerTag.value) val results = MXUsersDevicesMap<MXOlmSessionResult>()
.d("ensure olm forced:$force for ${deviceList.joinToString { it.shortDebugString() }}") val deviceList = devicesByUser.flatMap { it.value }
val results = MXUsersDevicesMap<MXOlmSessionResult>() Timber.tag(loggerTag.value)
val devicesToCreateSessionWith = mutableListOf<CryptoDeviceInfo>() .d("ensure olm forced:$force for ${deviceList.joinToString { it.shortDebugString() }}")
if (force) { val devicesToCreateSessionWith = mutableListOf<CryptoDeviceInfo>()
// we take all devices and will query otk for them if (force) {
devicesToCreateSessionWith.addAll(deviceList) // we take all devices and will query otk for them
} else { devicesToCreateSessionWith.addAll(deviceList)
// only peek devices without active session } else {
deviceList.forEach { deviceInfo -> // only peek devices without active session
val deviceId = deviceInfo.deviceId deviceList.forEach { deviceInfo ->
val userId = deviceInfo.userId val deviceId = deviceInfo.deviceId
val key = deviceInfo.identityKey() ?: return@forEach Unit.also { val userId = deviceInfo.userId
Timber.tag(loggerTag.value).w("Ignoring device ${deviceInfo.shortDebugString()} without identity key") val key = deviceInfo.identityKey() ?: return@forEach Unit.also {
} Timber.tag(loggerTag.value).w("Ignoring device ${deviceInfo.shortDebugString()} without identity key")
}
// is there a session that as been already used? // is there a session that as been already used?
val sessionId = olmDevice.getSessionId(key) val sessionId = olmDevice.getSessionId(key)
if (sessionId.isNullOrEmpty()) { if (sessionId.isNullOrEmpty()) {
Timber.tag(loggerTag.value).d("Found no existing olm session ${deviceInfo.shortDebugString()} add to claim list") Timber.tag(loggerTag.value).d("Found no existing olm session ${deviceInfo.shortDebugString()} add to claim list")
devicesToCreateSessionWith.add(deviceInfo) devicesToCreateSessionWith.add(deviceInfo)
} else { } else {
Timber.tag(loggerTag.value).d("using olm session $sessionId for (${deviceInfo.userId}|$deviceId)") Timber.tag(loggerTag.value).d("using olm session $sessionId for (${deviceInfo.userId}|$deviceId)")
val olmSessionResult = MXOlmSessionResult(deviceInfo, sessionId) val olmSessionResult = MXOlmSessionResult(deviceInfo, sessionId)
results.setObject(userId, deviceId, olmSessionResult) results.setObject(userId, deviceId, olmSessionResult)
}
} }
} }
}
if (devicesToCreateSessionWith.isEmpty()) { if (devicesToCreateSessionWith.isEmpty()) {
// no session to create // no session to create
return results
}
val usersDevicesToClaim = MXUsersDevicesMap<String>().apply {
devicesToCreateSessionWith.forEach {
setObject(it.userId, it.deviceId, MXKey.KEY_SIGNED_CURVE_25519_TYPE)
}
}
// Let's now claim one time keys
val claimParams = ClaimOneTimeKeysForUsersDeviceTask.Params(usersDevicesToClaim)
val oneTimeKeys = withContext(coroutineDispatchers.io) {
oneTimeKeysForUsersDeviceTask.executeRetry(claimParams, ONE_TIME_KEYS_RETRY_COUNT)
}
// let now start olm session using the new otks
devicesToCreateSessionWith.forEach { deviceInfo ->
val userId = deviceInfo.userId
val deviceId = deviceInfo.deviceId
// Did we get an OTK
val oneTimeKey = oneTimeKeys.getObject(userId, deviceId)
if (oneTimeKey == null) {
Timber.tag(loggerTag.value).d("No otk for ${deviceInfo.shortDebugString()}")
} else if (oneTimeKey.type != MXKey.KEY_SIGNED_CURVE_25519_TYPE) {
Timber.tag(loggerTag.value).d("Bad otk type (${oneTimeKey.type}) for ${deviceInfo.shortDebugString()}")
} else {
val olmSessionId = verifyKeyAndStartSession(oneTimeKey, userId, deviceInfo)
if (olmSessionId != null) {
val olmSessionResult = MXOlmSessionResult(deviceInfo, olmSessionId)
results.setObject(userId, deviceId, olmSessionResult)
} else {
Timber
.tag(loggerTag.value)
.d("## CRYPTO | cant unwedge failed to create outbound ${deviceInfo.shortDebugString()}")
}
}
}
return results return results
} }
val usersDevicesToClaim = MXUsersDevicesMap<String>().apply {
devicesToCreateSessionWith.forEach {
setObject(it.userId, it.deviceId, MXKey.KEY_SIGNED_CURVE_25519_TYPE)
}
}
// Let's now claim one time keys
val claimParams = ClaimOneTimeKeysForUsersDeviceTask.Params(usersDevicesToClaim)
val oneTimeKeys = withContext(coroutineDispatchers.io) {
oneTimeKeysForUsersDeviceTask.executeRetry(claimParams, ONE_TIME_KEYS_RETRY_COUNT)
}
// let now start olm session using the new otks
devicesToCreateSessionWith.forEach { deviceInfo ->
val userId = deviceInfo.userId
val deviceId = deviceInfo.deviceId
// Did we get an OTK
val oneTimeKey = oneTimeKeys.getObject(userId, deviceId)
if (oneTimeKey == null) {
Timber.tag(loggerTag.value).d("No otk for ${deviceInfo.shortDebugString()}")
} else if (oneTimeKey.type != MXKey.KEY_SIGNED_CURVE_25519_TYPE) {
Timber.tag(loggerTag.value).d("Bad otk type (${oneTimeKey.type}) for ${deviceInfo.shortDebugString()}")
} else {
val olmSessionId = verifyKeyAndStartSession(oneTimeKey, userId, deviceInfo)
if (olmSessionId != null) {
val olmSessionResult = MXOlmSessionResult(deviceInfo, olmSessionId)
results.setObject(userId, deviceId, olmSessionResult)
} else {
Timber
.tag(loggerTag.value)
.d("## CRYPTO | cant unwedge failed to create outbound ${deviceInfo.shortDebugString()}")
}
}
}
return results
} }
private fun verifyKeyAndStartSession(oneTimeKey: MXKey, userId: String, deviceInfo: CryptoDeviceInfo): String? { private fun verifyKeyAndStartSession(oneTimeKey: MXKey, userId: String, deviceInfo: CryptoDeviceInfo): String? {