Clean ensure olm, fix unwedging, better logs

This commit is contained in:
Valere 2022-03-01 16:18:43 +01:00
parent 2d9beb67b4
commit f238739438
5 changed files with 165 additions and 121 deletions

View File

@ -434,6 +434,14 @@ internal class DefaultCryptoService @Inject constructor(
val currentCount = syncResponse.deviceOneTimeKeysCount.signedCurve25519 ?: 0
oneTimeKeysUploader.updateOneTimeKeyCount(currentCount)
}
// unwedge if needed
try {
eventDecryptor.unwedgeDevicesIfNeeded()
} catch (failure: Throwable) {
Timber.tag(loggerTag.value).w("unwedgeDevicesIfNeeded failed")
}
// There is a limit of to_device events returned per sync.
// If we are in a case of such limited to_device sync we can't try to generate/upload
// new otk now, because there might be some pending olm pre-key to_device messages that would fail if we rotate

View File

@ -18,17 +18,15 @@ package org.matrix.android.sdk.internal.crypto
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.launch
import kotlinx.coroutines.withContext
import org.matrix.android.sdk.api.MatrixCallback
import org.matrix.android.sdk.api.MatrixCoroutineDispatchers
import org.matrix.android.sdk.api.logger.LoggerTag
import org.matrix.android.sdk.api.session.crypto.MXCryptoError
import org.matrix.android.sdk.api.session.events.model.Event
import org.matrix.android.sdk.api.session.events.model.EventType
import org.matrix.android.sdk.api.session.events.model.toModel
import org.matrix.android.sdk.internal.crypto.actions.EnsureOlmSessionsForDevicesAction
import org.matrix.android.sdk.internal.crypto.actions.MessageEncrypter
import org.matrix.android.sdk.internal.crypto.model.CryptoDeviceInfo
import org.matrix.android.sdk.internal.crypto.model.MXOlmSessionResult
import org.matrix.android.sdk.internal.crypto.model.MXUsersDevicesMap
import org.matrix.android.sdk.internal.crypto.model.event.OlmEventContent
import org.matrix.android.sdk.internal.crypto.store.IMXCryptoStore
@ -40,6 +38,8 @@ import javax.inject.Inject
private const val SEND_TO_DEVICE_RETRY_COUNT = 3
private val loggerTag = LoggerTag("CryptoSyncHandler", LoggerTag.CRYPTO)
@SessionScope
internal class EventDecryptor @Inject constructor(
private val cryptoCoroutineScope: CoroutineScope,
@ -47,13 +47,22 @@ internal class EventDecryptor @Inject constructor(
private val roomDecryptorProvider: RoomDecryptorProvider,
private val messageEncrypter: MessageEncrypter,
private val sendToDeviceTask: SendToDeviceTask,
private val deviceListManager: DeviceListManager,
private val ensureOlmSessionsForDevicesAction: EnsureOlmSessionsForDevicesAction,
private val cryptoStore: IMXCryptoStore
) {
// The date of the last time we forced establishment
// of a new session for each user:device.
private val lastNewSessionForcedDates = MXUsersDevicesMap<Long>()
/**
* Rate limit unwedge attempt, should we persist that?
*/
private val lastNewSessionForcedDates = mutableMapOf<WedgedDeviceInfo, Long>()
data class WedgedDeviceInfo(
val userId: String,
val senderKey: String?
)
private val wedgedDevices = mutableListOf<WedgedDeviceInfo>()
/**
* Decrypt an event
@ -94,35 +103,29 @@ internal class EventDecryptor @Inject constructor(
private suspend fun internalDecryptEvent(event: Event, timeline: String): MXEventDecryptionResult {
val eventContent = event.content
if (eventContent == null) {
Timber.e("## CRYPTO | decryptEvent : empty event content")
Timber.e("decryptEvent : empty event content")
throw MXCryptoError.Base(MXCryptoError.ErrorType.BAD_ENCRYPTED_MESSAGE, MXCryptoError.BAD_ENCRYPTED_MESSAGE_REASON)
} else {
val algorithm = eventContent["algorithm"]?.toString()
val alg = roomDecryptorProvider.getOrCreateRoomDecryptor(event.roomId, algorithm)
if (alg == null) {
val reason = String.format(MXCryptoError.UNABLE_TO_DECRYPT_REASON, event.eventId, algorithm)
Timber.e("## CRYPTO | decryptEvent() : $reason")
Timber.e("decryptEvent() : $reason")
throw MXCryptoError.Base(MXCryptoError.ErrorType.UNABLE_TO_DECRYPT, reason)
} else {
try {
return alg.decryptEvent(event, timeline)
} catch (mxCryptoError: MXCryptoError) {
Timber.v("## CRYPTO | internalDecryptEvent : Failed to decrypt ${event.eventId} reason: $mxCryptoError")
Timber.d("internalDecryptEvent : Failed to decrypt ${event.eventId} reason: $mxCryptoError")
if (algorithm == MXCRYPTO_ALGORITHM_OLM) {
if (mxCryptoError is MXCryptoError.Base &&
mxCryptoError.errorType == MXCryptoError.ErrorType.BAD_ENCRYPTED_MESSAGE) {
// need to find sending device
cryptoCoroutineScope.launch(coroutineDispatchers.crypto) {
val olmContent = event.content.toModel<OlmEventContent>()
cryptoStore.getUserDevices(event.senderId ?: "")
?.values
?.firstOrNull { it.identityKey() == olmContent?.senderKey }
?.let {
markOlmSessionForUnwedging(event.senderId ?: "", it)
}
?: run {
Timber.i("## CRYPTO | internalDecryptEvent() : Failed to find sender crypto device for unwedging")
}
val olmContent = event.content.toModel<OlmEventContent>()
if (event.senderId != null && olmContent?.senderKey != null) {
markOlmSessionForUnwedging(event.senderId, olmContent.senderKey)
} else {
Timber.tag(loggerTag.value).d("Can't mark as wedge malformed")
}
}
}
@ -132,53 +135,87 @@ internal class EventDecryptor @Inject constructor(
}
}
// coroutineDispatchers.crypto scope
private fun markOlmSessionForUnwedging(senderId: String, deviceInfo: CryptoDeviceInfo) {
val deviceKey = deviceInfo.identityKey()
private fun markOlmSessionForUnwedging(senderId: String, senderKey: String) {
val info = WedgedDeviceInfo(senderId, senderKey)
if (!wedgedDevices.contains(info)) {
Timber.tag(loggerTag.value).d("Marking device from $senderId key:$senderKey as wedged")
wedgedDevices.add(info)
}
}
val lastForcedDate = lastNewSessionForcedDates.getObject(senderId, deviceKey) ?: 0
// coroutineDispatchers.crypto scope
suspend fun unwedgeDevicesIfNeeded() {
// handle wedged devices
// Some olm decryption have failed and some device are wedged
// we should force start a new session for those
Timber.tag(loggerTag.value).d("Unwedging: ${wedgedDevices.size} are wedged")
// get the one that should be retried according to rate limit
val now = System.currentTimeMillis()
if (now - lastForcedDate < DefaultCryptoService.CRYPTO_MIN_FORCE_SESSION_PERIOD_MILLIS) {
Timber.w("## CRYPTO | markOlmSessionForUnwedging: New session already forced with device at $lastForcedDate. Not forcing another")
val toUnwedge = wedgedDevices.filter {
val lastForcedDate = lastNewSessionForcedDates[it] ?: 0
if (now - lastForcedDate < DefaultCryptoService.CRYPTO_MIN_FORCE_SESSION_PERIOD_MILLIS) {
Timber.tag(loggerTag.value).d("Unwedging, New session for $it already forced with device at $lastForcedDate")
return@filter false
}
// let's already mark that we tried now
lastNewSessionForcedDates[it] = now
true
}
if (toUnwedge.isEmpty()) {
Timber.tag(loggerTag.value).d("Nothing to unwedge")
return
}
Timber.tag(loggerTag.value).d("Unwedging, trying to create new session for ${toUnwedge.size} devices")
Timber.i("## CRYPTO | markOlmSessionForUnwedging from $senderId:${deviceInfo.deviceId}")
lastNewSessionForcedDates.setObject(senderId, deviceKey, now)
toUnwedge
.chunked(100) // safer to chunk if we ever have lots of wedged devices
.forEach { wedgedList ->
// lets download keys if needed
deviceListManager.downloadKeys(wedgedList.map { it.userId }, false)
// offload this from crypto thread (?)
cryptoCoroutineScope.launch(coroutineDispatchers.computation) {
runCatching { ensureOlmSessionsForDevicesAction.handle(mapOf(senderId to listOf(deviceInfo)), force = true) }.fold(
onSuccess = { sendDummyToDevice(ensured = it, deviceInfo, senderId) },
onFailure = {
Timber.e("## CRYPTO | markOlmSessionForUnwedging() : failed to ensure device info ${senderId}${deviceInfo.deviceId}")
}
)
}
}
// find the matching devices
wedgedList.groupBy { it.userId }
.map { groupedByUser ->
val userId = groupedByUser.key
val wedgeSenderKeysForUser = groupedByUser.value.map { it.senderKey }
val knownDevices = cryptoStore.getUserDevices(userId)?.values.orEmpty()
userId to wedgeSenderKeysForUser.mapNotNull { senderKey ->
knownDevices.firstOrNull { it.identityKey() == senderKey }
}
}
.toMap()
.let { deviceList ->
try {
// force creating new outbound session and mark them as most recent to
// be used for next encryption (dummy)
val sessionToUse = ensureOlmSessionsForDevicesAction.handle(deviceList, true)
Timber.tag(loggerTag.value).d("Unwedging, found ${sessionToUse.map.size} to send dummy to")
private suspend fun sendDummyToDevice(ensured: MXUsersDevicesMap<MXOlmSessionResult>, deviceInfo: CryptoDeviceInfo, senderId: String) {
Timber.i("## CRYPTO | markOlmSessionForUnwedging() : ensureOlmSessionsForDevicesAction isEmpty:${ensured.isEmpty}")
// Now send a blank message on that session so the other side knows about it.
// (The keyshare request is sent in the clear so that won't do)
val payloadJson = mapOf(
"type" to EventType.DUMMY
)
val sendToDeviceMap = MXUsersDevicesMap<Any>()
sessionToUse.map.values
.flatMap { it.values }
.map { it.deviceInfo }
.forEach { deviceInfo ->
Timber.tag(loggerTag.value).v("encrypting dummy to ${deviceInfo.deviceId}")
val encodedPayload = messageEncrypter.encryptMessage(payloadJson, listOf(deviceInfo))
sendToDeviceMap.setObject(deviceInfo.userId, deviceInfo.deviceId, encodedPayload)
}
// Now send a blank message on that session so the other side knows about it.
// (The keyshare request is sent in the clear so that won't do)
// We send this first such that, as long as the toDevice messages arrive in the
// same order we sent them, the other end will get this first, set up the new session,
// then get the keyshare request and send the key over this new session (because it
// is the session it has most recently received a message on).
val payloadJson = mapOf<String, Any>("type" to EventType.DUMMY)
val encodedPayload = messageEncrypter.encryptMessage(payloadJson, listOf(deviceInfo))
val sendToDeviceMap = MXUsersDevicesMap<Any>()
sendToDeviceMap.setObject(senderId, deviceInfo.deviceId, encodedPayload)
Timber.i("## CRYPTO | markOlmSessionForUnwedging() : sending dummy to $senderId:${deviceInfo.deviceId}")
withContext(coroutineDispatchers.io) {
val sendToDeviceParams = SendToDeviceTask.Params(EventType.ENCRYPTED, sendToDeviceMap)
try {
sendToDeviceTask.executeRetry(sendToDeviceParams, remainingRetry = SEND_TO_DEVICE_RETRY_COUNT)
} catch (failure: Throwable) {
Timber.e(failure, "## CRYPTO | markOlmSessionForUnwedging() : failed to send dummy to $senderId:${deviceInfo.deviceId}")
}
}
// now let's send that
val sendToDeviceParams = SendToDeviceTask.Params(EventType.ENCRYPTED, sendToDeviceMap)
sendToDeviceTask.executeRetry(sendToDeviceParams, remainingRetry = SEND_TO_DEVICE_RETRY_COUNT)
} catch (failure: Throwable) {
deviceList.flatMap { it.value }.joinToString { it.shortDebugString() }.let {
Timber.tag(loggerTag.value).e(failure, "## Failed to unwedge devices: $it}")
}
}
}
}
}
}

View File

@ -22,8 +22,8 @@ import org.matrix.android.sdk.internal.crypto.model.CryptoDeviceInfo
import org.matrix.android.sdk.internal.crypto.model.MXKey
import org.matrix.android.sdk.internal.crypto.model.MXOlmSessionResult
import org.matrix.android.sdk.internal.crypto.model.MXUsersDevicesMap
import org.matrix.android.sdk.internal.crypto.model.toDebugString
import org.matrix.android.sdk.internal.crypto.tasks.ClaimOneTimeKeysForUsersDeviceTask
import org.matrix.android.sdk.internal.session.SessionScope
import timber.log.Timber
import javax.inject.Inject
@ -31,89 +31,84 @@ private const val ONE_TIME_KEYS_RETRY_COUNT = 3
private val loggerTag = LoggerTag("EnsureOlmSessionsForDevicesAction", LoggerTag.CRYPTO)
@SessionScope
internal class EnsureOlmSessionsForDevicesAction @Inject constructor(
private val olmDevice: MXOlmDevice,
private val oneTimeKeysForUsersDeviceTask: ClaimOneTimeKeysForUsersDeviceTask) {
/**
* We want to synchronize a bit here, because we are iterating to check existing olm session and
* also adding some
*/
@Synchronized
suspend fun handle(devicesByUser: Map<String, List<CryptoDeviceInfo>>, force: Boolean = false): MXUsersDevicesMap<MXOlmSessionResult> {
val devicesWithoutSession = ArrayList<CryptoDeviceInfo>()
val deviceList = devicesByUser.flatMap { it.value }
Timber.tag(loggerTag.value)
.d("ensure olm forced:$force for ${deviceList.joinToString { it.shortDebugString() }}")
val results = MXUsersDevicesMap<MXOlmSessionResult>()
for ((userId, deviceList) in devicesByUser) {
for (deviceInfo in deviceList) {
val devicesToCreateSessionWith = mutableListOf<CryptoDeviceInfo>()
if (force) {
// we take all devices and will query otk for them
devicesToCreateSessionWith.addAll(deviceList)
} else {
// only peek devices without active session
deviceList.forEach { deviceInfo ->
val deviceId = deviceInfo.deviceId
val key = deviceInfo.identityKey()
if (key == null) {
Timber.w("## CRYPTO | Ignoring device (${deviceInfo.userId}|$deviceId) without identity key")
continue
val userId = deviceInfo.userId
val key = deviceInfo.identityKey() ?: return@forEach Unit.also {
Timber.tag(loggerTag.value).w("Ignoring device ${deviceInfo.shortDebugString()} without identity key")
}
// is there a session that as been already used?
val sessionId = olmDevice.getSessionId(key)
if (sessionId.isNullOrEmpty() || force) {
Timber.tag(loggerTag.value).d("Found no existing olm session (${deviceInfo.userId}|$deviceId) (force=$force)")
devicesWithoutSession.add(deviceInfo)
if (sessionId.isNullOrEmpty()) {
Timber.tag(loggerTag.value).d("Found no existing olm session ${deviceInfo.shortDebugString()} add to claim list")
devicesToCreateSessionWith.add(deviceInfo)
} else {
Timber.tag(loggerTag.value).d("using olm session $sessionId for (${deviceInfo.userId}|$deviceId)")
val olmSessionResult = MXOlmSessionResult(deviceInfo, sessionId)
results.setObject(userId, deviceId, olmSessionResult)
}
val olmSessionResult = MXOlmSessionResult(deviceInfo, sessionId)
results.setObject(userId, deviceId, olmSessionResult)
}
}
Timber.tag(loggerTag.value).d("Devices without olm session (count:${devicesWithoutSession.size}) :" +
" ${devicesWithoutSession.joinToString { "${it.userId}|${it.deviceId}" }}")
if (devicesWithoutSession.size == 0) {
if (devicesToCreateSessionWith.isEmpty()) {
// no session to create
return results
}
// Prepare the request for claiming one-time keys
val usersDevicesToClaim = MXUsersDevicesMap<String>()
val oneTimeKeyAlgorithm = MXKey.KEY_SIGNED_CURVE_25519_TYPE
for (device in devicesWithoutSession) {
usersDevicesToClaim.setObject(device.userId, device.deviceId, oneTimeKeyAlgorithm)
val usersDevicesToClaim = MXUsersDevicesMap<String>().apply {
devicesToCreateSessionWith.forEach {
setObject(it.userId, it.deviceId, MXKey.KEY_SIGNED_CURVE_25519_TYPE)
}
}
// TODO: this has a race condition - if we try to send another message
// while we are claiming a key, we will end up claiming two and setting up
// two sessions.
//
// That should eventually resolve itself, but it's poor form.
Timber.tag(loggerTag.value).i("claimOneTimeKeysForUsersDevices() : ${usersDevicesToClaim.toDebugString()}")
// Let's now claim one time keys
val claimParams = ClaimOneTimeKeysForUsersDeviceTask.Params(usersDevicesToClaim)
val oneTimeKeys = oneTimeKeysForUsersDeviceTask.executeRetry(claimParams, remainingRetry = ONE_TIME_KEYS_RETRY_COUNT)
Timber.tag(loggerTag.value).v("claimOneTimeKeysForUsersDevices() : keysClaimResponse.oneTimeKeys: $oneTimeKeys")
for ((userId, deviceInfos) in devicesByUser) {
for (deviceInfo in deviceInfos) {
var oneTimeKey: MXKey? = null
val deviceIds = oneTimeKeys.getUserDeviceIds(userId)
if (null != deviceIds) {
for (deviceId in deviceIds) {
val olmSessionResult = results.getObject(userId, deviceId)
if (olmSessionResult?.sessionId != null && !force) {
// We already have a result for this device
continue
}
val key = oneTimeKeys.getObject(userId, deviceId)
if (key?.type == oneTimeKeyAlgorithm) {
oneTimeKey = key
}
if (oneTimeKey == null) {
Timber.tag(loggerTag.value).d("No one time key for $userId|$deviceId")
continue
}
// Update the result for this device in results
olmSessionResult?.sessionId = verifyKeyAndStartSession(oneTimeKey, userId, deviceInfo)
}
val oneTimeKeys = oneTimeKeysForUsersDeviceTask.executeRetry(claimParams, ONE_TIME_KEYS_RETRY_COUNT)
// let now start olm session using the new otks
devicesToCreateSessionWith.forEach { deviceInfo ->
val userId = deviceInfo.userId
val deviceId = deviceInfo.deviceId
// Did we get an OTK
val oneTimeKey = oneTimeKeys.getObject(userId, deviceId)
if (oneTimeKey == null) {
Timber.tag(loggerTag.value).d("No otk for ${deviceInfo.shortDebugString()}")
} else if (oneTimeKey.type != MXKey.KEY_SIGNED_CURVE_25519_TYPE) {
Timber.tag(loggerTag.value).d("Bad otk type (${oneTimeKey.type}) for ${deviceInfo.shortDebugString()}")
} else {
val olmSessionId = verifyKeyAndStartSession(oneTimeKey, userId, deviceInfo)
if (olmSessionId != null) {
val olmSessionResult = MXOlmSessionResult(deviceInfo, olmSessionId)
results.setObject(userId, deviceId, olmSessionResult)
} else {
Timber
.tag(loggerTag.value)
.d("## CRYPTO | cant unwedge failed to create outbound ${deviceInfo.shortDebugString()}")
}
}
}
return results
}

View File

@ -70,6 +70,8 @@ data class CryptoDeviceInfo(
keys?.let { map["keys"] = it }
return map
}
fun shortDebugString() = "$userId|$deviceId"
}
internal fun CryptoDeviceInfo.toRest(): DeviceKeys {

View File

@ -93,7 +93,9 @@ internal class CryptoSyncHandler @Inject constructor(private val cryptoService:
)
return true
} else {
// should not happen
// Could happen for to device events
// None of the known session could decrypt the message
// In this case unwedging process might have been started (rate limited)
Timber.e("## CRYPTO | ERROR NULL DECRYPTION RESULT from ${event.senderId}")
}
}