From f2387394385eb41b78cdbcc508078712eddceef8 Mon Sep 17 00:00:00 2001 From: Valere Date: Tue, 1 Mar 2022 16:18:43 +0100 Subject: [PATCH] Clean ensure olm, fix unwedging, better logs --- .../internal/crypto/DefaultCryptoService.kt | 8 + .../sdk/internal/crypto/EventDecryptor.kt | 157 +++++++++++------- .../EnsureOlmSessionsForDevicesAction.kt | 115 ++++++------- .../internal/crypto/model/CryptoDeviceInfo.kt | 2 + .../session/sync/handler/CryptoSyncHandler.kt | 4 +- 5 files changed, 165 insertions(+), 121 deletions(-) diff --git a/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/DefaultCryptoService.kt b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/DefaultCryptoService.kt index a66e0d4077..2667c25ea4 100755 --- a/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/DefaultCryptoService.kt +++ b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/DefaultCryptoService.kt @@ -434,6 +434,14 @@ internal class DefaultCryptoService @Inject constructor( val currentCount = syncResponse.deviceOneTimeKeysCount.signedCurve25519 ?: 0 oneTimeKeysUploader.updateOneTimeKeyCount(currentCount) } + + // unwedge if needed + try { + eventDecryptor.unwedgeDevicesIfNeeded() + } catch (failure: Throwable) { + Timber.tag(loggerTag.value).w("unwedgeDevicesIfNeeded failed") + } + // There is a limit of to_device events returned per sync. // If we are in a case of such limited to_device sync we can't try to generate/upload // new otk now, because there might be some pending olm pre-key to_device messages that would fail if we rotate diff --git a/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/EventDecryptor.kt b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/EventDecryptor.kt index 8a11e45740..9bf5cd594e 100644 --- a/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/EventDecryptor.kt +++ b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/EventDecryptor.kt @@ -18,17 +18,15 @@ package org.matrix.android.sdk.internal.crypto import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.launch -import kotlinx.coroutines.withContext import org.matrix.android.sdk.api.MatrixCallback import org.matrix.android.sdk.api.MatrixCoroutineDispatchers +import org.matrix.android.sdk.api.logger.LoggerTag import org.matrix.android.sdk.api.session.crypto.MXCryptoError import org.matrix.android.sdk.api.session.events.model.Event import org.matrix.android.sdk.api.session.events.model.EventType import org.matrix.android.sdk.api.session.events.model.toModel import org.matrix.android.sdk.internal.crypto.actions.EnsureOlmSessionsForDevicesAction import org.matrix.android.sdk.internal.crypto.actions.MessageEncrypter -import org.matrix.android.sdk.internal.crypto.model.CryptoDeviceInfo -import org.matrix.android.sdk.internal.crypto.model.MXOlmSessionResult import org.matrix.android.sdk.internal.crypto.model.MXUsersDevicesMap import org.matrix.android.sdk.internal.crypto.model.event.OlmEventContent import org.matrix.android.sdk.internal.crypto.store.IMXCryptoStore @@ -40,6 +38,8 @@ import javax.inject.Inject private const val SEND_TO_DEVICE_RETRY_COUNT = 3 +private val loggerTag = LoggerTag("CryptoSyncHandler", LoggerTag.CRYPTO) + @SessionScope internal class EventDecryptor @Inject constructor( private val cryptoCoroutineScope: CoroutineScope, @@ -47,13 +47,22 @@ internal class EventDecryptor @Inject constructor( private val roomDecryptorProvider: RoomDecryptorProvider, private val messageEncrypter: MessageEncrypter, private val sendToDeviceTask: SendToDeviceTask, + private val deviceListManager: DeviceListManager, private val ensureOlmSessionsForDevicesAction: EnsureOlmSessionsForDevicesAction, private val cryptoStore: IMXCryptoStore ) { - // The date of the last time we forced establishment - // of a new session for each user:device. - private val lastNewSessionForcedDates = MXUsersDevicesMap() + /** + * Rate limit unwedge attempt, should we persist that? + */ + private val lastNewSessionForcedDates = mutableMapOf() + + data class WedgedDeviceInfo( + val userId: String, + val senderKey: String? + ) + + private val wedgedDevices = mutableListOf() /** * Decrypt an event @@ -94,35 +103,29 @@ internal class EventDecryptor @Inject constructor( private suspend fun internalDecryptEvent(event: Event, timeline: String): MXEventDecryptionResult { val eventContent = event.content if (eventContent == null) { - Timber.e("## CRYPTO | decryptEvent : empty event content") + Timber.e("decryptEvent : empty event content") throw MXCryptoError.Base(MXCryptoError.ErrorType.BAD_ENCRYPTED_MESSAGE, MXCryptoError.BAD_ENCRYPTED_MESSAGE_REASON) } else { val algorithm = eventContent["algorithm"]?.toString() val alg = roomDecryptorProvider.getOrCreateRoomDecryptor(event.roomId, algorithm) if (alg == null) { val reason = String.format(MXCryptoError.UNABLE_TO_DECRYPT_REASON, event.eventId, algorithm) - Timber.e("## CRYPTO | decryptEvent() : $reason") + Timber.e("decryptEvent() : $reason") throw MXCryptoError.Base(MXCryptoError.ErrorType.UNABLE_TO_DECRYPT, reason) } else { try { return alg.decryptEvent(event, timeline) } catch (mxCryptoError: MXCryptoError) { - Timber.v("## CRYPTO | internalDecryptEvent : Failed to decrypt ${event.eventId} reason: $mxCryptoError") + Timber.d("internalDecryptEvent : Failed to decrypt ${event.eventId} reason: $mxCryptoError") if (algorithm == MXCRYPTO_ALGORITHM_OLM) { if (mxCryptoError is MXCryptoError.Base && mxCryptoError.errorType == MXCryptoError.ErrorType.BAD_ENCRYPTED_MESSAGE) { // need to find sending device - cryptoCoroutineScope.launch(coroutineDispatchers.crypto) { - val olmContent = event.content.toModel() - cryptoStore.getUserDevices(event.senderId ?: "") - ?.values - ?.firstOrNull { it.identityKey() == olmContent?.senderKey } - ?.let { - markOlmSessionForUnwedging(event.senderId ?: "", it) - } - ?: run { - Timber.i("## CRYPTO | internalDecryptEvent() : Failed to find sender crypto device for unwedging") - } + val olmContent = event.content.toModel() + if (event.senderId != null && olmContent?.senderKey != null) { + markOlmSessionForUnwedging(event.senderId, olmContent.senderKey) + } else { + Timber.tag(loggerTag.value).d("Can't mark as wedge malformed") } } } @@ -132,53 +135,87 @@ internal class EventDecryptor @Inject constructor( } } - // coroutineDispatchers.crypto scope - private fun markOlmSessionForUnwedging(senderId: String, deviceInfo: CryptoDeviceInfo) { - val deviceKey = deviceInfo.identityKey() + private fun markOlmSessionForUnwedging(senderId: String, senderKey: String) { + val info = WedgedDeviceInfo(senderId, senderKey) + if (!wedgedDevices.contains(info)) { + Timber.tag(loggerTag.value).d("Marking device from $senderId key:$senderKey as wedged") + wedgedDevices.add(info) + } + } - val lastForcedDate = lastNewSessionForcedDates.getObject(senderId, deviceKey) ?: 0 + // coroutineDispatchers.crypto scope + suspend fun unwedgeDevicesIfNeeded() { + // handle wedged devices + // Some olm decryption have failed and some device are wedged + // we should force start a new session for those + Timber.tag(loggerTag.value).d("Unwedging: ${wedgedDevices.size} are wedged") + // get the one that should be retried according to rate limit val now = System.currentTimeMillis() - if (now - lastForcedDate < DefaultCryptoService.CRYPTO_MIN_FORCE_SESSION_PERIOD_MILLIS) { - Timber.w("## CRYPTO | markOlmSessionForUnwedging: New session already forced with device at $lastForcedDate. Not forcing another") + val toUnwedge = wedgedDevices.filter { + val lastForcedDate = lastNewSessionForcedDates[it] ?: 0 + if (now - lastForcedDate < DefaultCryptoService.CRYPTO_MIN_FORCE_SESSION_PERIOD_MILLIS) { + Timber.tag(loggerTag.value).d("Unwedging, New session for $it already forced with device at $lastForcedDate") + return@filter false + } + // let's already mark that we tried now + lastNewSessionForcedDates[it] = now + true + } + + if (toUnwedge.isEmpty()) { + Timber.tag(loggerTag.value).d("Nothing to unwedge") return } + Timber.tag(loggerTag.value).d("Unwedging, trying to create new session for ${toUnwedge.size} devices") - Timber.i("## CRYPTO | markOlmSessionForUnwedging from $senderId:${deviceInfo.deviceId}") - lastNewSessionForcedDates.setObject(senderId, deviceKey, now) + toUnwedge + .chunked(100) // safer to chunk if we ever have lots of wedged devices + .forEach { wedgedList -> + // lets download keys if needed + deviceListManager.downloadKeys(wedgedList.map { it.userId }, false) - // offload this from crypto thread (?) - cryptoCoroutineScope.launch(coroutineDispatchers.computation) { - runCatching { ensureOlmSessionsForDevicesAction.handle(mapOf(senderId to listOf(deviceInfo)), force = true) }.fold( - onSuccess = { sendDummyToDevice(ensured = it, deviceInfo, senderId) }, - onFailure = { - Timber.e("## CRYPTO | markOlmSessionForUnwedging() : failed to ensure device info ${senderId}${deviceInfo.deviceId}") - } - ) - } - } + // find the matching devices + wedgedList.groupBy { it.userId } + .map { groupedByUser -> + val userId = groupedByUser.key + val wedgeSenderKeysForUser = groupedByUser.value.map { it.senderKey } + val knownDevices = cryptoStore.getUserDevices(userId)?.values.orEmpty() + userId to wedgeSenderKeysForUser.mapNotNull { senderKey -> + knownDevices.firstOrNull { it.identityKey() == senderKey } + } + } + .toMap() + .let { deviceList -> + try { + // force creating new outbound session and mark them as most recent to + // be used for next encryption (dummy) + val sessionToUse = ensureOlmSessionsForDevicesAction.handle(deviceList, true) + Timber.tag(loggerTag.value).d("Unwedging, found ${sessionToUse.map.size} to send dummy to") - private suspend fun sendDummyToDevice(ensured: MXUsersDevicesMap, deviceInfo: CryptoDeviceInfo, senderId: String) { - Timber.i("## CRYPTO | markOlmSessionForUnwedging() : ensureOlmSessionsForDevicesAction isEmpty:${ensured.isEmpty}") + // Now send a blank message on that session so the other side knows about it. + // (The keyshare request is sent in the clear so that won't do) + val payloadJson = mapOf( + "type" to EventType.DUMMY + ) + val sendToDeviceMap = MXUsersDevicesMap() + sessionToUse.map.values + .flatMap { it.values } + .map { it.deviceInfo } + .forEach { deviceInfo -> + Timber.tag(loggerTag.value).v("encrypting dummy to ${deviceInfo.deviceId}") + val encodedPayload = messageEncrypter.encryptMessage(payloadJson, listOf(deviceInfo)) + sendToDeviceMap.setObject(deviceInfo.userId, deviceInfo.deviceId, encodedPayload) + } - // Now send a blank message on that session so the other side knows about it. - // (The keyshare request is sent in the clear so that won't do) - // We send this first such that, as long as the toDevice messages arrive in the - // same order we sent them, the other end will get this first, set up the new session, - // then get the keyshare request and send the key over this new session (because it - // is the session it has most recently received a message on). - val payloadJson = mapOf("type" to EventType.DUMMY) - - val encodedPayload = messageEncrypter.encryptMessage(payloadJson, listOf(deviceInfo)) - val sendToDeviceMap = MXUsersDevicesMap() - sendToDeviceMap.setObject(senderId, deviceInfo.deviceId, encodedPayload) - Timber.i("## CRYPTO | markOlmSessionForUnwedging() : sending dummy to $senderId:${deviceInfo.deviceId}") - withContext(coroutineDispatchers.io) { - val sendToDeviceParams = SendToDeviceTask.Params(EventType.ENCRYPTED, sendToDeviceMap) - try { - sendToDeviceTask.executeRetry(sendToDeviceParams, remainingRetry = SEND_TO_DEVICE_RETRY_COUNT) - } catch (failure: Throwable) { - Timber.e(failure, "## CRYPTO | markOlmSessionForUnwedging() : failed to send dummy to $senderId:${deviceInfo.deviceId}") - } - } + // now let's send that + val sendToDeviceParams = SendToDeviceTask.Params(EventType.ENCRYPTED, sendToDeviceMap) + sendToDeviceTask.executeRetry(sendToDeviceParams, remainingRetry = SEND_TO_DEVICE_RETRY_COUNT) + } catch (failure: Throwable) { + deviceList.flatMap { it.value }.joinToString { it.shortDebugString() }.let { + Timber.tag(loggerTag.value).e(failure, "## Failed to unwedge devices: $it}") + } + } + } + } } } diff --git a/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/actions/EnsureOlmSessionsForDevicesAction.kt b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/actions/EnsureOlmSessionsForDevicesAction.kt index ab2ed04dfb..58765a7043 100644 --- a/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/actions/EnsureOlmSessionsForDevicesAction.kt +++ b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/actions/EnsureOlmSessionsForDevicesAction.kt @@ -22,8 +22,8 @@ import org.matrix.android.sdk.internal.crypto.model.CryptoDeviceInfo import org.matrix.android.sdk.internal.crypto.model.MXKey import org.matrix.android.sdk.internal.crypto.model.MXOlmSessionResult import org.matrix.android.sdk.internal.crypto.model.MXUsersDevicesMap -import org.matrix.android.sdk.internal.crypto.model.toDebugString import org.matrix.android.sdk.internal.crypto.tasks.ClaimOneTimeKeysForUsersDeviceTask +import org.matrix.android.sdk.internal.session.SessionScope import timber.log.Timber import javax.inject.Inject @@ -31,89 +31,84 @@ private const val ONE_TIME_KEYS_RETRY_COUNT = 3 private val loggerTag = LoggerTag("EnsureOlmSessionsForDevicesAction", LoggerTag.CRYPTO) +@SessionScope internal class EnsureOlmSessionsForDevicesAction @Inject constructor( private val olmDevice: MXOlmDevice, private val oneTimeKeysForUsersDeviceTask: ClaimOneTimeKeysForUsersDeviceTask) { + /** + * We want to synchronize a bit here, because we are iterating to check existing olm session and + * also adding some + */ + @Synchronized suspend fun handle(devicesByUser: Map>, force: Boolean = false): MXUsersDevicesMap { - val devicesWithoutSession = ArrayList() - + val deviceList = devicesByUser.flatMap { it.value } + Timber.tag(loggerTag.value) + .d("ensure olm forced:$force for ${deviceList.joinToString { it.shortDebugString() }}") val results = MXUsersDevicesMap() - - for ((userId, deviceList) in devicesByUser) { - for (deviceInfo in deviceList) { + val devicesToCreateSessionWith = mutableListOf() + if (force) { + // we take all devices and will query otk for them + devicesToCreateSessionWith.addAll(deviceList) + } else { + // only peek devices without active session + deviceList.forEach { deviceInfo -> val deviceId = deviceInfo.deviceId - val key = deviceInfo.identityKey() - if (key == null) { - Timber.w("## CRYPTO | Ignoring device (${deviceInfo.userId}|$deviceId) without identity key") - continue + val userId = deviceInfo.userId + val key = deviceInfo.identityKey() ?: return@forEach Unit.also { + Timber.tag(loggerTag.value).w("Ignoring device ${deviceInfo.shortDebugString()} without identity key") } + // is there a session that as been already used? val sessionId = olmDevice.getSessionId(key) - - if (sessionId.isNullOrEmpty() || force) { - Timber.tag(loggerTag.value).d("Found no existing olm session (${deviceInfo.userId}|$deviceId) (force=$force)") - devicesWithoutSession.add(deviceInfo) + if (sessionId.isNullOrEmpty()) { + Timber.tag(loggerTag.value).d("Found no existing olm session ${deviceInfo.shortDebugString()} add to claim list") + devicesToCreateSessionWith.add(deviceInfo) } else { Timber.tag(loggerTag.value).d("using olm session $sessionId for (${deviceInfo.userId}|$deviceId)") + val olmSessionResult = MXOlmSessionResult(deviceInfo, sessionId) + results.setObject(userId, deviceId, olmSessionResult) } - - val olmSessionResult = MXOlmSessionResult(deviceInfo, sessionId) - results.setObject(userId, deviceId, olmSessionResult) } } - Timber.tag(loggerTag.value).d("Devices without olm session (count:${devicesWithoutSession.size}) :" + - " ${devicesWithoutSession.joinToString { "${it.userId}|${it.deviceId}" }}") - if (devicesWithoutSession.size == 0) { + if (devicesToCreateSessionWith.isEmpty()) { + // no session to create return results } - - // Prepare the request for claiming one-time keys - val usersDevicesToClaim = MXUsersDevicesMap() - - val oneTimeKeyAlgorithm = MXKey.KEY_SIGNED_CURVE_25519_TYPE - - for (device in devicesWithoutSession) { - usersDevicesToClaim.setObject(device.userId, device.deviceId, oneTimeKeyAlgorithm) + val usersDevicesToClaim = MXUsersDevicesMap().apply { + devicesToCreateSessionWith.forEach { + setObject(it.userId, it.deviceId, MXKey.KEY_SIGNED_CURVE_25519_TYPE) + } } - // TODO: this has a race condition - if we try to send another message - // while we are claiming a key, we will end up claiming two and setting up - // two sessions. - // - // That should eventually resolve itself, but it's poor form. - - Timber.tag(loggerTag.value).i("claimOneTimeKeysForUsersDevices() : ${usersDevicesToClaim.toDebugString()}") - + // Let's now claim one time keys val claimParams = ClaimOneTimeKeysForUsersDeviceTask.Params(usersDevicesToClaim) - val oneTimeKeys = oneTimeKeysForUsersDeviceTask.executeRetry(claimParams, remainingRetry = ONE_TIME_KEYS_RETRY_COUNT) - Timber.tag(loggerTag.value).v("claimOneTimeKeysForUsersDevices() : keysClaimResponse.oneTimeKeys: $oneTimeKeys") - for ((userId, deviceInfos) in devicesByUser) { - for (deviceInfo in deviceInfos) { - var oneTimeKey: MXKey? = null - val deviceIds = oneTimeKeys.getUserDeviceIds(userId) - if (null != deviceIds) { - for (deviceId in deviceIds) { - val olmSessionResult = results.getObject(userId, deviceId) - if (olmSessionResult?.sessionId != null && !force) { - // We already have a result for this device - continue - } - val key = oneTimeKeys.getObject(userId, deviceId) - if (key?.type == oneTimeKeyAlgorithm) { - oneTimeKey = key - } - if (oneTimeKey == null) { - Timber.tag(loggerTag.value).d("No one time key for $userId|$deviceId") - continue - } - // Update the result for this device in results - olmSessionResult?.sessionId = verifyKeyAndStartSession(oneTimeKey, userId, deviceInfo) - } + val oneTimeKeys = oneTimeKeysForUsersDeviceTask.executeRetry(claimParams, ONE_TIME_KEYS_RETRY_COUNT) + + // let now start olm session using the new otks + devicesToCreateSessionWith.forEach { deviceInfo -> + val userId = deviceInfo.userId + val deviceId = deviceInfo.deviceId + // Did we get an OTK + val oneTimeKey = oneTimeKeys.getObject(userId, deviceId) + if (oneTimeKey == null) { + Timber.tag(loggerTag.value).d("No otk for ${deviceInfo.shortDebugString()}") + } else if (oneTimeKey.type != MXKey.KEY_SIGNED_CURVE_25519_TYPE) { + Timber.tag(loggerTag.value).d("Bad otk type (${oneTimeKey.type}) for ${deviceInfo.shortDebugString()}") + } else { + val olmSessionId = verifyKeyAndStartSession(oneTimeKey, userId, deviceInfo) + if (olmSessionId != null) { + val olmSessionResult = MXOlmSessionResult(deviceInfo, olmSessionId) + results.setObject(userId, deviceId, olmSessionResult) + } else { + Timber + .tag(loggerTag.value) + .d("## CRYPTO | cant unwedge failed to create outbound ${deviceInfo.shortDebugString()}") } } } + return results } diff --git a/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/model/CryptoDeviceInfo.kt b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/model/CryptoDeviceInfo.kt index 5e7744853a..b3638dc414 100644 --- a/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/model/CryptoDeviceInfo.kt +++ b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/model/CryptoDeviceInfo.kt @@ -70,6 +70,8 @@ data class CryptoDeviceInfo( keys?.let { map["keys"] = it } return map } + + fun shortDebugString() = "$userId|$deviceId" } internal fun CryptoDeviceInfo.toRest(): DeviceKeys { diff --git a/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/session/sync/handler/CryptoSyncHandler.kt b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/session/sync/handler/CryptoSyncHandler.kt index 28cfbc7342..9ae7b82777 100644 --- a/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/session/sync/handler/CryptoSyncHandler.kt +++ b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/session/sync/handler/CryptoSyncHandler.kt @@ -93,7 +93,9 @@ internal class CryptoSyncHandler @Inject constructor(private val cryptoService: ) return true } else { - // should not happen + // Could happen for to device events + // None of the known session could decrypt the message + // In this case unwedging process might have been started (rate limited) Timber.e("## CRYPTO | ERROR NULL DECRYPTION RESULT from ${event.senderId}") } }