protect race on prekey + logs

This commit is contained in:
Valere 2022-02-28 23:25:20 +01:00
parent 9b3c5d2153
commit ade16a0aa1
1 changed files with 50 additions and 22 deletions

View File

@ -16,6 +16,8 @@
package org.matrix.android.sdk.internal.crypto.algorithms.olm package org.matrix.android.sdk.internal.crypto.algorithms.olm
import kotlinx.coroutines.sync.withLock
import org.matrix.android.sdk.api.logger.LoggerTag
import org.matrix.android.sdk.api.session.crypto.MXCryptoError import org.matrix.android.sdk.api.session.crypto.MXCryptoError
import org.matrix.android.sdk.api.session.events.model.Event import org.matrix.android.sdk.api.session.events.model.Event
import org.matrix.android.sdk.api.session.events.model.toModel import org.matrix.android.sdk.api.session.events.model.toModel
@ -30,6 +32,8 @@ import org.matrix.android.sdk.internal.di.MoshiProvider
import org.matrix.android.sdk.internal.util.convertFromUTF8 import org.matrix.android.sdk.internal.util.convertFromUTF8
import timber.log.Timber import timber.log.Timber
private val loggerTag = LoggerTag("MXOlmDecryption", LoggerTag.CRYPTO)
internal class MXOlmDecryption( internal class MXOlmDecryption(
// The olm device interface // The olm device interface
private val olmDevice: MXOlmDevice, private val olmDevice: MXOlmDevice,
@ -40,25 +44,25 @@ internal class MXOlmDecryption(
@Throws(MXCryptoError::class) @Throws(MXCryptoError::class)
override suspend fun decryptEvent(event: Event, timeline: String): MXEventDecryptionResult { override suspend fun decryptEvent(event: Event, timeline: String): MXEventDecryptionResult {
val olmEventContent = event.content.toModel<OlmEventContent>() ?: run { val olmEventContent = event.content.toModel<OlmEventContent>() ?: run {
Timber.e("## decryptEvent() : bad event format") Timber.tag(loggerTag.value).e("## decryptEvent() : bad event format")
throw MXCryptoError.Base(MXCryptoError.ErrorType.BAD_EVENT_FORMAT, throw MXCryptoError.Base(MXCryptoError.ErrorType.BAD_EVENT_FORMAT,
MXCryptoError.BAD_EVENT_FORMAT_TEXT_REASON) MXCryptoError.BAD_EVENT_FORMAT_TEXT_REASON)
} }
val cipherText = olmEventContent.ciphertext ?: run { val cipherText = olmEventContent.ciphertext ?: run {
Timber.e("## decryptEvent() : missing cipher text") Timber.tag(loggerTag.value).e("## decryptEvent() : missing cipher text")
throw MXCryptoError.Base(MXCryptoError.ErrorType.MISSING_CIPHER_TEXT, throw MXCryptoError.Base(MXCryptoError.ErrorType.MISSING_CIPHER_TEXT,
MXCryptoError.MISSING_CIPHER_TEXT_REASON) MXCryptoError.MISSING_CIPHER_TEXT_REASON)
} }
val senderKey = olmEventContent.senderKey ?: run { val senderKey = olmEventContent.senderKey ?: run {
Timber.e("## decryptEvent() : missing sender key") Timber.tag(loggerTag.value).e("## decryptEvent() : missing sender key")
throw MXCryptoError.Base(MXCryptoError.ErrorType.MISSING_SENDER_KEY, throw MXCryptoError.Base(MXCryptoError.ErrorType.MISSING_SENDER_KEY,
MXCryptoError.MISSING_SENDER_KEY_TEXT_REASON) MXCryptoError.MISSING_SENDER_KEY_TEXT_REASON)
} }
val messageAny = cipherText[olmDevice.deviceCurve25519Key] ?: run { val messageAny = cipherText[olmDevice.deviceCurve25519Key] ?: run {
Timber.e("## decryptEvent() : our device ${olmDevice.deviceCurve25519Key} is not included in recipients") Timber.tag(loggerTag.value).e("## decryptEvent() : our device ${olmDevice.deviceCurve25519Key} is not included in recipients")
throw MXCryptoError.Base(MXCryptoError.ErrorType.NOT_INCLUDE_IN_RECIPIENTS, MXCryptoError.NOT_INCLUDED_IN_RECIPIENT_REASON) throw MXCryptoError.Base(MXCryptoError.ErrorType.NOT_INCLUDE_IN_RECIPIENTS, MXCryptoError.NOT_INCLUDED_IN_RECIPIENT_REASON)
} }
@ -69,7 +73,7 @@ internal class MXOlmDecryption(
val decryptedPayload = decryptMessage(message, senderKey) val decryptedPayload = decryptMessage(message, senderKey)
if (decryptedPayload == null) { if (decryptedPayload == null) {
Timber.e("## decryptEvent() Failed to decrypt Olm event (id= ${event.eventId} from $senderKey") Timber.tag(loggerTag.value).e("## decryptEvent() Failed to decrypt Olm event (id= ${event.eventId} from $senderKey")
throw MXCryptoError.Base(MXCryptoError.ErrorType.BAD_ENCRYPTED_MESSAGE, MXCryptoError.BAD_ENCRYPTED_MESSAGE_REASON) throw MXCryptoError.Base(MXCryptoError.ErrorType.BAD_ENCRYPTED_MESSAGE, MXCryptoError.BAD_ENCRYPTED_MESSAGE_REASON)
} }
val payloadString = convertFromUTF8(decryptedPayload) val payloadString = convertFromUTF8(decryptedPayload)
@ -78,30 +82,30 @@ internal class MXOlmDecryption(
val payload = adapter.fromJson(payloadString) val payload = adapter.fromJson(payloadString)
if (payload == null) { if (payload == null) {
Timber.e("## decryptEvent failed : null payload") Timber.tag(loggerTag.value).e("## decryptEvent failed : null payload")
throw MXCryptoError.Base(MXCryptoError.ErrorType.UNABLE_TO_DECRYPT, MXCryptoError.MISSING_CIPHER_TEXT_REASON) throw MXCryptoError.Base(MXCryptoError.ErrorType.UNABLE_TO_DECRYPT, MXCryptoError.MISSING_CIPHER_TEXT_REASON)
} }
val olmPayloadContent = OlmPayloadContent.fromJsonString(payloadString) ?: run { val olmPayloadContent = OlmPayloadContent.fromJsonString(payloadString) ?: run {
Timber.e("## decryptEvent() : bad olmPayloadContent format") Timber.tag(loggerTag.value).e("## decryptEvent() : bad olmPayloadContent format")
throw MXCryptoError.Base(MXCryptoError.ErrorType.BAD_DECRYPTED_FORMAT, MXCryptoError.BAD_DECRYPTED_FORMAT_TEXT_REASON) throw MXCryptoError.Base(MXCryptoError.ErrorType.BAD_DECRYPTED_FORMAT, MXCryptoError.BAD_DECRYPTED_FORMAT_TEXT_REASON)
} }
if (olmPayloadContent.recipient.isNullOrBlank()) { if (olmPayloadContent.recipient.isNullOrBlank()) {
val reason = String.format(MXCryptoError.ERROR_MISSING_PROPERTY_REASON, "recipient") val reason = String.format(MXCryptoError.ERROR_MISSING_PROPERTY_REASON, "recipient")
Timber.e("## decryptEvent() : $reason") Timber.tag(loggerTag.value).e("## decryptEvent() : $reason")
throw MXCryptoError.Base(MXCryptoError.ErrorType.MISSING_PROPERTY, reason) throw MXCryptoError.Base(MXCryptoError.ErrorType.MISSING_PROPERTY, reason)
} }
if (olmPayloadContent.recipient != userId) { if (olmPayloadContent.recipient != userId) {
Timber.e("## decryptEvent() : Event ${event.eventId}:" + Timber.tag(loggerTag.value).e("## decryptEvent() : Event ${event.eventId}:" +
" Intended recipient ${olmPayloadContent.recipient} does not match our id $userId") " Intended recipient ${olmPayloadContent.recipient} does not match our id $userId")
throw MXCryptoError.Base(MXCryptoError.ErrorType.BAD_RECIPIENT, throw MXCryptoError.Base(MXCryptoError.ErrorType.BAD_RECIPIENT,
String.format(MXCryptoError.BAD_RECIPIENT_REASON, olmPayloadContent.recipient)) String.format(MXCryptoError.BAD_RECIPIENT_REASON, olmPayloadContent.recipient))
} }
val recipientKeys = olmPayloadContent.recipientKeys ?: run { val recipientKeys = olmPayloadContent.recipientKeys ?: run {
Timber.e("## decryptEvent() : Olm event (id=${event.eventId}) contains no 'recipient_keys'" + Timber.tag(loggerTag.value).e("## decryptEvent() : Olm event (id=${event.eventId}) contains no 'recipient_keys'" +
" property; cannot prevent unknown-key attack") " property; cannot prevent unknown-key attack")
throw MXCryptoError.Base(MXCryptoError.ErrorType.MISSING_PROPERTY, throw MXCryptoError.Base(MXCryptoError.ErrorType.MISSING_PROPERTY,
String.format(MXCryptoError.ERROR_MISSING_PROPERTY_REASON, "recipient_keys")) String.format(MXCryptoError.ERROR_MISSING_PROPERTY_REASON, "recipient_keys"))
@ -110,31 +114,31 @@ internal class MXOlmDecryption(
val ed25519 = recipientKeys["ed25519"] val ed25519 = recipientKeys["ed25519"]
if (ed25519 != olmDevice.deviceEd25519Key) { if (ed25519 != olmDevice.deviceEd25519Key) {
Timber.e("## decryptEvent() : Event ${event.eventId}: Intended recipient ed25519 key $ed25519 did not match ours") Timber.tag(loggerTag.value).e("## decryptEvent() : Event ${event.eventId}: Intended recipient ed25519 key $ed25519 did not match ours")
throw MXCryptoError.Base(MXCryptoError.ErrorType.BAD_RECIPIENT_KEY, throw MXCryptoError.Base(MXCryptoError.ErrorType.BAD_RECIPIENT_KEY,
MXCryptoError.BAD_RECIPIENT_KEY_REASON) MXCryptoError.BAD_RECIPIENT_KEY_REASON)
} }
if (olmPayloadContent.sender.isNullOrBlank()) { if (olmPayloadContent.sender.isNullOrBlank()) {
Timber.e("## decryptEvent() : Olm event (id=${event.eventId}) contains no 'sender' property; cannot prevent unknown-key attack") Timber.tag(loggerTag.value).e("## decryptEvent() : Olm event (id=${event.eventId}) contains no 'sender' property; cannot prevent unknown-key attack")
throw MXCryptoError.Base(MXCryptoError.ErrorType.MISSING_PROPERTY, throw MXCryptoError.Base(MXCryptoError.ErrorType.MISSING_PROPERTY,
String.format(MXCryptoError.ERROR_MISSING_PROPERTY_REASON, "sender")) String.format(MXCryptoError.ERROR_MISSING_PROPERTY_REASON, "sender"))
} }
if (olmPayloadContent.sender != event.senderId) { if (olmPayloadContent.sender != event.senderId) {
Timber.e("Event ${event.eventId}: original sender ${olmPayloadContent.sender} does not match reported sender ${event.senderId}") Timber.tag(loggerTag.value).e("Event ${event.eventId}: original sender ${olmPayloadContent.sender} does not match reported sender ${event.senderId}")
throw MXCryptoError.Base(MXCryptoError.ErrorType.FORWARDED_MESSAGE, throw MXCryptoError.Base(MXCryptoError.ErrorType.FORWARDED_MESSAGE,
String.format(MXCryptoError.FORWARDED_MESSAGE_REASON, olmPayloadContent.sender)) String.format(MXCryptoError.FORWARDED_MESSAGE_REASON, olmPayloadContent.sender))
} }
if (olmPayloadContent.roomId != event.roomId) { if (olmPayloadContent.roomId != event.roomId) {
Timber.e("## decryptEvent() : Event ${event.eventId}: original room ${olmPayloadContent.roomId} does not match reported room ${event.roomId}") Timber.tag(loggerTag.value).e("## decryptEvent() : Event ${event.eventId}: original room ${olmPayloadContent.roomId} does not match reported room ${event.roomId}")
throw MXCryptoError.Base(MXCryptoError.ErrorType.BAD_ROOM, throw MXCryptoError.Base(MXCryptoError.ErrorType.BAD_ROOM,
String.format(MXCryptoError.BAD_ROOM_REASON, olmPayloadContent.roomId)) String.format(MXCryptoError.BAD_ROOM_REASON, olmPayloadContent.roomId))
} }
val keys = olmPayloadContent.keys ?: run { val keys = olmPayloadContent.keys ?: run {
Timber.e("## decryptEvent failed : null keys") Timber.tag(loggerTag.value).e("## decryptEvent failed : null keys")
throw MXCryptoError.Base(MXCryptoError.ErrorType.UNABLE_TO_DECRYPT, throw MXCryptoError.Base(MXCryptoError.ErrorType.UNABLE_TO_DECRYPT,
MXCryptoError.MISSING_CIPHER_TEXT_REASON) MXCryptoError.MISSING_CIPHER_TEXT_REASON)
} }
@ -166,11 +170,32 @@ internal class MXOlmDecryption(
// Try each session in turn // Try each session in turn
// decryptionErrors = {}; // decryptionErrors = {};
val isPreKey = messageType == 0
// we want to synchronize on prekey if not we could end up create two olm sessions
// Not very clear but it looks like the js-sdk for consistency
return if (isPreKey) {
olmDevice.mutex.withLock {
reallyDecryptMessage(sessionIds, messageBody, messageType, theirDeviceIdentityKey)
}
} else {
reallyDecryptMessage(sessionIds, messageBody, messageType, theirDeviceIdentityKey)
}
}
private suspend fun reallyDecryptMessage(sessionIds: List<String>, messageBody: String, messageType: Int, theirDeviceIdentityKey: String): String? {
Timber.tag(loggerTag.value).d("decryptMessage() try to decrypt olm message type:$messageType from ${sessionIds.size} known sessions")
for (sessionId in sessionIds) { for (sessionId in sessionIds) {
val payload = olmDevice.decryptMessage(messageBody, messageType, sessionId, theirDeviceIdentityKey) val payload = try {
olmDevice.decryptMessage(messageBody, messageType, sessionId, theirDeviceIdentityKey)
} catch (throwable: Exception) {
// As we are trying one by one, we don't really care of the error here
Timber.tag(loggerTag.value).d("decryptMessage() failed with session $sessionId")
null
}
if (null != payload) { if (null != payload) {
Timber.v("## decryptMessage() : Decrypted Olm message from $theirDeviceIdentityKey with session $sessionId") Timber.tag(loggerTag.value).v("## decryptMessage() : Decrypted Olm message from $theirDeviceIdentityKey with session $sessionId")
return payload return payload
} else { } else {
val foundSession = olmDevice.matchesSession(theirDeviceIdentityKey, sessionId, messageType, messageBody) val foundSession = olmDevice.matchesSession(theirDeviceIdentityKey, sessionId, messageType, messageBody)
@ -178,7 +203,7 @@ internal class MXOlmDecryption(
if (foundSession) { if (foundSession) {
// Decryption failed, but it was a prekey message matching this // Decryption failed, but it was a prekey message matching this
// session, so it should have worked. // session, so it should have worked.
Timber.e("## decryptMessage() : Error decrypting prekey message with existing session id $sessionId:TODO") Timber.tag(loggerTag.value).e("## decryptMessage() : Error decrypting prekey message with existing session id $sessionId:TODO")
return null return null
} }
} }
@ -189,9 +214,9 @@ internal class MXOlmDecryption(
// didn't work. // didn't work.
if (sessionIds.isEmpty()) { if (sessionIds.isEmpty()) {
Timber.e("## decryptMessage() : No existing sessions") Timber.tag(loggerTag.value).e("## decryptMessage() : No existing sessions")
} else { } else {
Timber.e("## decryptMessage() : Error decrypting non-prekey message with existing sessions") Timber.tag(loggerTag.value).e("## decryptMessage() : Error decrypting non-prekey message with existing sessions")
} }
return null return null
@ -199,14 +224,17 @@ internal class MXOlmDecryption(
// prekey message which doesn't match any existing sessions: make a new // prekey message which doesn't match any existing sessions: make a new
// session. // session.
// XXXX Possible races here? if concurrent access for same prekey message, we might create 2 sessions?
Timber.tag(loggerTag.value).d("## decryptMessage() : Create inbound group session from prekey sender:$theirDeviceIdentityKey")
val res = olmDevice.createInboundSession(theirDeviceIdentityKey, messageType, messageBody) val res = olmDevice.createInboundSession(theirDeviceIdentityKey, messageType, messageBody)
if (null == res) { if (null == res) {
Timber.e("## decryptMessage() : Error decrypting non-prekey message with existing sessions") Timber.tag(loggerTag.value).e("## decryptMessage() : Error decrypting non-prekey message with existing sessions")
return null return null
} }
Timber.v("## decryptMessage() : Created new inbound Olm session get id ${res["session_id"]} with $theirDeviceIdentityKey") Timber.tag(loggerTag.value).v("## decryptMessage() : Created new inbound Olm session get id ${res["session_id"]} with $theirDeviceIdentityKey")
return res["payload"] return res["payload"]
} }