use mutex on suspend and not synchronized
This commit is contained in:
parent
49d33f3a4b
commit
6546f98858
|
@ -16,6 +16,8 @@
|
|||
|
||||
package org.matrix.android.sdk.internal.crypto.actions
|
||||
|
||||
import kotlinx.coroutines.sync.Mutex
|
||||
import kotlinx.coroutines.sync.withLock
|
||||
import kotlinx.coroutines.withContext
|
||||
import org.matrix.android.sdk.api.MatrixCoroutineDispatchers
|
||||
import org.matrix.android.sdk.api.logger.LoggerTag
|
||||
|
@ -39,82 +41,84 @@ internal class EnsureOlmSessionsForDevicesAction @Inject constructor(
|
|||
private val coroutineDispatchers: MatrixCoroutineDispatchers,
|
||||
private val oneTimeKeysForUsersDeviceTask: ClaimOneTimeKeysForUsersDeviceTask) {
|
||||
|
||||
private val ensureMutex = Mutex()
|
||||
|
||||
/**
|
||||
* We want to synchronize a bit here, because we are iterating to check existing olm session and
|
||||
* also adding some
|
||||
*/
|
||||
@Synchronized
|
||||
suspend fun handle(devicesByUser: Map<String, List<CryptoDeviceInfo>>, force: Boolean = false): MXUsersDevicesMap<MXOlmSessionResult> {
|
||||
val deviceList = devicesByUser.flatMap { it.value }
|
||||
Timber.tag(loggerTag.value)
|
||||
.d("ensure olm forced:$force for ${deviceList.joinToString { it.shortDebugString() }}")
|
||||
val results = MXUsersDevicesMap<MXOlmSessionResult>()
|
||||
val devicesToCreateSessionWith = mutableListOf<CryptoDeviceInfo>()
|
||||
if (force) {
|
||||
// we take all devices and will query otk for them
|
||||
devicesToCreateSessionWith.addAll(deviceList)
|
||||
} else {
|
||||
// only peek devices without active session
|
||||
deviceList.forEach { deviceInfo ->
|
||||
val deviceId = deviceInfo.deviceId
|
||||
val userId = deviceInfo.userId
|
||||
val key = deviceInfo.identityKey() ?: return@forEach Unit.also {
|
||||
Timber.tag(loggerTag.value).w("Ignoring device ${deviceInfo.shortDebugString()} without identity key")
|
||||
}
|
||||
ensureMutex.withLock {
|
||||
val results = MXUsersDevicesMap<MXOlmSessionResult>()
|
||||
val deviceList = devicesByUser.flatMap { it.value }
|
||||
Timber.tag(loggerTag.value)
|
||||
.d("ensure olm forced:$force for ${deviceList.joinToString { it.shortDebugString() }}")
|
||||
val devicesToCreateSessionWith = mutableListOf<CryptoDeviceInfo>()
|
||||
if (force) {
|
||||
// we take all devices and will query otk for them
|
||||
devicesToCreateSessionWith.addAll(deviceList)
|
||||
} else {
|
||||
// only peek devices without active session
|
||||
deviceList.forEach { deviceInfo ->
|
||||
val deviceId = deviceInfo.deviceId
|
||||
val userId = deviceInfo.userId
|
||||
val key = deviceInfo.identityKey() ?: return@forEach Unit.also {
|
||||
Timber.tag(loggerTag.value).w("Ignoring device ${deviceInfo.shortDebugString()} without identity key")
|
||||
}
|
||||
|
||||
// is there a session that as been already used?
|
||||
val sessionId = olmDevice.getSessionId(key)
|
||||
if (sessionId.isNullOrEmpty()) {
|
||||
Timber.tag(loggerTag.value).d("Found no existing olm session ${deviceInfo.shortDebugString()} add to claim list")
|
||||
devicesToCreateSessionWith.add(deviceInfo)
|
||||
} else {
|
||||
Timber.tag(loggerTag.value).d("using olm session $sessionId for (${deviceInfo.userId}|$deviceId)")
|
||||
val olmSessionResult = MXOlmSessionResult(deviceInfo, sessionId)
|
||||
results.setObject(userId, deviceId, olmSessionResult)
|
||||
// is there a session that as been already used?
|
||||
val sessionId = olmDevice.getSessionId(key)
|
||||
if (sessionId.isNullOrEmpty()) {
|
||||
Timber.tag(loggerTag.value).d("Found no existing olm session ${deviceInfo.shortDebugString()} add to claim list")
|
||||
devicesToCreateSessionWith.add(deviceInfo)
|
||||
} else {
|
||||
Timber.tag(loggerTag.value).d("using olm session $sessionId for (${deviceInfo.userId}|$deviceId)")
|
||||
val olmSessionResult = MXOlmSessionResult(deviceInfo, sessionId)
|
||||
results.setObject(userId, deviceId, olmSessionResult)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (devicesToCreateSessionWith.isEmpty()) {
|
||||
// no session to create
|
||||
if (devicesToCreateSessionWith.isEmpty()) {
|
||||
// no session to create
|
||||
return results
|
||||
}
|
||||
val usersDevicesToClaim = MXUsersDevicesMap<String>().apply {
|
||||
devicesToCreateSessionWith.forEach {
|
||||
setObject(it.userId, it.deviceId, MXKey.KEY_SIGNED_CURVE_25519_TYPE)
|
||||
}
|
||||
}
|
||||
|
||||
// Let's now claim one time keys
|
||||
val claimParams = ClaimOneTimeKeysForUsersDeviceTask.Params(usersDevicesToClaim)
|
||||
val oneTimeKeys = withContext(coroutineDispatchers.io) {
|
||||
oneTimeKeysForUsersDeviceTask.executeRetry(claimParams, ONE_TIME_KEYS_RETRY_COUNT)
|
||||
}
|
||||
|
||||
// let now start olm session using the new otks
|
||||
devicesToCreateSessionWith.forEach { deviceInfo ->
|
||||
val userId = deviceInfo.userId
|
||||
val deviceId = deviceInfo.deviceId
|
||||
// Did we get an OTK
|
||||
val oneTimeKey = oneTimeKeys.getObject(userId, deviceId)
|
||||
if (oneTimeKey == null) {
|
||||
Timber.tag(loggerTag.value).d("No otk for ${deviceInfo.shortDebugString()}")
|
||||
} else if (oneTimeKey.type != MXKey.KEY_SIGNED_CURVE_25519_TYPE) {
|
||||
Timber.tag(loggerTag.value).d("Bad otk type (${oneTimeKey.type}) for ${deviceInfo.shortDebugString()}")
|
||||
} else {
|
||||
val olmSessionId = verifyKeyAndStartSession(oneTimeKey, userId, deviceInfo)
|
||||
if (olmSessionId != null) {
|
||||
val olmSessionResult = MXOlmSessionResult(deviceInfo, olmSessionId)
|
||||
results.setObject(userId, deviceId, olmSessionResult)
|
||||
} else {
|
||||
Timber
|
||||
.tag(loggerTag.value)
|
||||
.d("## CRYPTO | cant unwedge failed to create outbound ${deviceInfo.shortDebugString()}")
|
||||
}
|
||||
}
|
||||
}
|
||||
return results
|
||||
}
|
||||
val usersDevicesToClaim = MXUsersDevicesMap<String>().apply {
|
||||
devicesToCreateSessionWith.forEach {
|
||||
setObject(it.userId, it.deviceId, MXKey.KEY_SIGNED_CURVE_25519_TYPE)
|
||||
}
|
||||
}
|
||||
|
||||
// Let's now claim one time keys
|
||||
val claimParams = ClaimOneTimeKeysForUsersDeviceTask.Params(usersDevicesToClaim)
|
||||
val oneTimeKeys = withContext(coroutineDispatchers.io) {
|
||||
oneTimeKeysForUsersDeviceTask.executeRetry(claimParams, ONE_TIME_KEYS_RETRY_COUNT)
|
||||
}
|
||||
|
||||
// let now start olm session using the new otks
|
||||
devicesToCreateSessionWith.forEach { deviceInfo ->
|
||||
val userId = deviceInfo.userId
|
||||
val deviceId = deviceInfo.deviceId
|
||||
// Did we get an OTK
|
||||
val oneTimeKey = oneTimeKeys.getObject(userId, deviceId)
|
||||
if (oneTimeKey == null) {
|
||||
Timber.tag(loggerTag.value).d("No otk for ${deviceInfo.shortDebugString()}")
|
||||
} else if (oneTimeKey.type != MXKey.KEY_SIGNED_CURVE_25519_TYPE) {
|
||||
Timber.tag(loggerTag.value).d("Bad otk type (${oneTimeKey.type}) for ${deviceInfo.shortDebugString()}")
|
||||
} else {
|
||||
val olmSessionId = verifyKeyAndStartSession(oneTimeKey, userId, deviceInfo)
|
||||
if (olmSessionId != null) {
|
||||
val olmSessionResult = MXOlmSessionResult(deviceInfo, olmSessionId)
|
||||
results.setObject(userId, deviceId, olmSessionResult)
|
||||
} else {
|
||||
Timber
|
||||
.tag(loggerTag.value)
|
||||
.d("## CRYPTO | cant unwedge failed to create outbound ${deviceInfo.shortDebugString()}")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return results
|
||||
}
|
||||
|
||||
private fun verifyKeyAndStartSession(oneTimeKey: MXKey, userId: String, deviceInfo: CryptoDeviceInfo): String? {
|
||||
|
|
Loading…
Reference in New Issue