Improve inbound group session cache + mutex

This commit is contained in:
Valere 2022-02-28 22:00:39 +01:00
parent c97de48474
commit 9b3c5d2153
7 changed files with 138 additions and 138 deletions

View File

@ -19,8 +19,10 @@ package org.matrix.android.sdk.internal.crypto
import android.util.LruCache import android.util.LruCache
import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.launch import kotlinx.coroutines.launch
import kotlinx.coroutines.sync.Mutex
import org.matrix.android.sdk.api.MatrixCoroutineDispatchers import org.matrix.android.sdk.api.MatrixCoroutineDispatchers
import org.matrix.android.sdk.api.extensions.tryOrNull import org.matrix.android.sdk.api.extensions.tryOrNull
import org.matrix.android.sdk.api.logger.LoggerTag
import org.matrix.android.sdk.internal.crypto.model.OlmInboundGroupSessionWrapper2 import org.matrix.android.sdk.internal.crypto.model.OlmInboundGroupSessionWrapper2
import org.matrix.android.sdk.internal.crypto.store.IMXCryptoStore import org.matrix.android.sdk.internal.crypto.store.IMXCryptoStore
import timber.log.Timber import timber.log.Timber
@ -28,6 +30,14 @@ import java.util.Timer
import java.util.TimerTask import java.util.TimerTask
import javax.inject.Inject import javax.inject.Inject
data class InboundGroupSessionHolder(
val wrapper: OlmInboundGroupSessionWrapper2,
val mutex: Mutex = Mutex()
)
private val loggerTag = LoggerTag("InboundGroupSessionStore", LoggerTag.CRYPTO)
/** /**
* Allows to cache and batch store operations on inbound group session store. * Allows to cache and batch store operations on inbound group session store.
* Because it is used in the decrypt flow, that can be called quite rapidly * Because it is used in the decrypt flow, that can be called quite rapidly
@ -42,12 +52,13 @@ internal class InboundGroupSessionStore @Inject constructor(
val senderKey: String val senderKey: String
) )
private val sessionCache = object : LruCache<CacheKey, OlmInboundGroupSessionWrapper2>(30) { private val sessionCache = object : LruCache<CacheKey, InboundGroupSessionHolder>(100) {
override fun entryRemoved(evicted: Boolean, key: CacheKey?, oldValue: OlmInboundGroupSessionWrapper2?, newValue: OlmInboundGroupSessionWrapper2?) { override fun entryRemoved(evicted: Boolean, key: CacheKey?, oldValue: InboundGroupSessionHolder?, newValue: InboundGroupSessionHolder?) {
if (evicted && oldValue != null) { if (oldValue != null) {
cryptoCoroutineScope.launch(coroutineDispatchers.crypto) { cryptoCoroutineScope.launch(coroutineDispatchers.crypto) {
Timber.v("## Inbound: entryRemoved ${oldValue.roomId}-${oldValue.senderKey}") Timber.tag(loggerTag.value).v("## Inbound: entryRemoved ${oldValue.wrapper.roomId}-${oldValue.wrapper.senderKey}")
store.storeInboundGroupSessions(listOf(oldValue)) store.storeInboundGroupSessions(listOf(oldValue).map { it.wrapper })
oldValue.wrapper.olmInboundGroupSession?.releaseSession()
} }
} }
} }
@ -59,41 +70,50 @@ internal class InboundGroupSessionStore @Inject constructor(
private val dirtySession = mutableListOf<OlmInboundGroupSessionWrapper2>() private val dirtySession = mutableListOf<OlmInboundGroupSessionWrapper2>()
@Synchronized @Synchronized
fun getInboundGroupSession(sessionId: String, senderKey: String): OlmInboundGroupSessionWrapper2? { fun clear() {
synchronized(sessionCache) { sessionCache.evictAll()
val known = sessionCache[CacheKey(sessionId, senderKey)]
Timber.v("## Inbound: getInboundGroupSession in cache ${known != null}")
return known ?: store.getInboundGroupSession(sessionId, senderKey)?.also {
Timber.v("## Inbound: getInboundGroupSession cache populate ${it.roomId}")
sessionCache.put(CacheKey(sessionId, senderKey), it)
}
}
} }
@Synchronized @Synchronized
fun replaceGroupSession(old: OlmInboundGroupSessionWrapper2, new: OlmInboundGroupSessionWrapper2, sessionId: String, senderKey: String) { fun getInboundGroupSession(sessionId: String, senderKey: String): InboundGroupSessionHolder? {
Timber.v("## Replacing outdated session ${old.roomId}-${old.senderKey}") val known = sessionCache[CacheKey(sessionId, senderKey)]
dirtySession.remove(old) Timber.tag(loggerTag.value).v("## Inbound: getInboundGroupSession $sessionId in cache ${known != null}")
return known
?: store.getInboundGroupSession(sessionId, senderKey)?.also {
Timber.tag(loggerTag.value).v("## Inbound: getInboundGroupSession cache populate ${it.roomId}")
sessionCache.put(CacheKey(sessionId, senderKey), InboundGroupSessionHolder(it))
}?.let {
InboundGroupSessionHolder(it)
}
}
@Synchronized
fun replaceGroupSession(old: InboundGroupSessionHolder, new: InboundGroupSessionHolder, sessionId: String, senderKey: String) {
Timber.tag(loggerTag.value).v("## Replacing outdated session ${old.wrapper.roomId}-${old.wrapper.senderKey}")
dirtySession.remove(old.wrapper)
store.removeInboundGroupSession(sessionId, senderKey) store.removeInboundGroupSession(sessionId, senderKey)
sessionCache.remove(CacheKey(sessionId, senderKey)) sessionCache.remove(CacheKey(sessionId, senderKey))
// release removed session
old.wrapper.olmInboundGroupSession?.releaseSession()
internalStoreGroupSession(new, sessionId, senderKey) internalStoreGroupSession(new, sessionId, senderKey)
} }
@Synchronized @Synchronized
fun storeInBoundGroupSession(wrapper: OlmInboundGroupSessionWrapper2, sessionId: String, senderKey: String) { fun storeInBoundGroupSession(holder: InboundGroupSessionHolder, sessionId: String, senderKey: String) {
internalStoreGroupSession(wrapper, sessionId, senderKey) internalStoreGroupSession(holder, sessionId, senderKey)
} }
private fun internalStoreGroupSession(wrapper: OlmInboundGroupSessionWrapper2, sessionId: String, senderKey: String) { private fun internalStoreGroupSession(holder: InboundGroupSessionHolder, sessionId: String, senderKey: String) {
Timber.v("## Inbound: getInboundGroupSession mark as dirty ${wrapper.roomId}-${wrapper.senderKey}") Timber.tag(loggerTag.value).v("## Inbound: getInboundGroupSession mark as dirty ${holder.wrapper.roomId}-${holder.wrapper.senderKey}")
// We want to batch this a bit for performances // We want to batch this a bit for performances
dirtySession.add(wrapper) dirtySession.add(holder.wrapper)
if (sessionCache[CacheKey(sessionId, senderKey)] == null) { if (sessionCache[CacheKey(sessionId, senderKey)] == null) {
// first time seen, put it in memory cache while waiting for batch insert // first time seen, put it in memory cache while waiting for batch insert
// If it's already known, no need to update cache it's already there // If it's already known, no need to update cache it's already there
sessionCache.put(CacheKey(sessionId, senderKey), wrapper) sessionCache.put(CacheKey(sessionId, senderKey), holder)
} }
timerTask?.cancel() timerTask?.cancel()
@ -110,7 +130,7 @@ internal class InboundGroupSessionStore @Inject constructor(
val toSave = mutableListOf<OlmInboundGroupSessionWrapper2>().apply { addAll(dirtySession) } val toSave = mutableListOf<OlmInboundGroupSessionWrapper2>().apply { addAll(dirtySession) }
dirtySession.clear() dirtySession.clear()
cryptoCoroutineScope.launch(coroutineDispatchers.crypto) { cryptoCoroutineScope.launch(coroutineDispatchers.crypto) {
Timber.v("## Inbound: getInboundGroupSession batching save of ${dirtySession.size}") Timber.tag(loggerTag.value).v("## Inbound: getInboundGroupSession batching save of ${toSave.size}")
tryOrNull { tryOrNull {
store.storeInboundGroupSessions(toSave) store.storeInboundGroupSessions(toSave)
} }

View File

@ -208,6 +208,7 @@ internal class MXOlmDevice @Inject constructor(
it.groupSession.releaseSession() it.groupSession.releaseSession()
} }
outboundGroupSessionCache.clear() outboundGroupSessionCache.clear()
inboundGroupSessionStore.clear()
olmSessionStore.clear() olmSessionStore.clear()
} }
@ -585,7 +586,7 @@ internal class MXOlmDevice @Inject constructor(
if (sessionId.isNotEmpty() && payloadString.isNotEmpty()) { if (sessionId.isNotEmpty() && payloadString.isNotEmpty()) {
try { try {
return outboundGroupSessionCache[sessionId]!!.groupSession.encryptMessage(payloadString) return outboundGroupSessionCache[sessionId]!!.groupSession.encryptMessage(payloadString)
} catch (e: Exception) { } catch (e: Throwable) {
Timber.e(e, "## encryptGroupMessage() : failed") Timber.e(e, "## encryptGroupMessage() : failed")
} }
} }
@ -614,7 +615,8 @@ internal class MXOlmDevice @Inject constructor(
keysClaimed: Map<String, String>, keysClaimed: Map<String, String>,
exportFormat: Boolean): Boolean { exportFormat: Boolean): Boolean {
val candidateSession = OlmInboundGroupSessionWrapper2(sessionKey, exportFormat) val candidateSession = OlmInboundGroupSessionWrapper2(sessionKey, exportFormat)
val existingSession = tryOrNull { getInboundGroupSession(sessionId, senderKey, roomId) } val existingSessionHolder = tryOrNull { getInboundGroupSession(sessionId, senderKey, roomId) }
val existingSession = existingSessionHolder?.wrapper
// If we have an existing one we should check if the new one is not better // If we have an existing one we should check if the new one is not better
if (existingSession != null) { if (existingSession != null) {
Timber.d("## addInboundGroupSession() check if known session is better than candidate session") Timber.d("## addInboundGroupSession() check if known session is better than candidate session")
@ -666,9 +668,9 @@ internal class MXOlmDevice @Inject constructor(
candidateSession.forwardingCurve25519KeyChain = forwardingCurve25519KeyChain candidateSession.forwardingCurve25519KeyChain = forwardingCurve25519KeyChain
if (existingSession != null) { if (existingSession != null) {
inboundGroupSessionStore.replaceGroupSession(existingSession, candidateSession, sessionId, senderKey) inboundGroupSessionStore.replaceGroupSession(existingSessionHolder, InboundGroupSessionHolder(candidateSession), sessionId, senderKey)
} else { } else {
inboundGroupSessionStore.storeInBoundGroupSession(candidateSession, sessionId, senderKey) inboundGroupSessionStore.storeInBoundGroupSession(InboundGroupSessionHolder(candidateSession), sessionId, senderKey)
} }
return true return true
@ -715,7 +717,8 @@ internal class MXOlmDevice @Inject constructor(
continue continue
} }
val existingSession = tryOrNull { getInboundGroupSession(sessionId, senderKey, roomId) } val existingSessionHolder = tryOrNull { getInboundGroupSession(sessionId, senderKey, roomId) }
val existingSession = existingSessionHolder?.wrapper
if (existingSession == null) { if (existingSession == null) {
// Session does not already exist, add it // Session does not already exist, add it
@ -736,7 +739,7 @@ internal class MXOlmDevice @Inject constructor(
candidateOlmInboundGroupSession.releaseSession() candidateOlmInboundGroupSession.releaseSession()
} else { } else {
// update cache with better session // update cache with better session
inboundGroupSessionStore.replaceGroupSession(existingSession, candidateSessionToImport, sessionId, senderKey) inboundGroupSessionStore.replaceGroupSession(existingSessionHolder, InboundGroupSessionHolder(candidateSessionToImport), sessionId, senderKey)
sessions.add(candidateSessionToImport) sessions.add(candidateSessionToImport)
} }
} }
@ -748,18 +751,6 @@ internal class MXOlmDevice @Inject constructor(
return sessions return sessions
} }
/**
* Remove an inbound group session
*
* @param sessionId the session identifier.
* @param sessionKey base64-encoded secret key.
*/
fun removeInboundGroupSession(sessionId: String?, sessionKey: String?) {
if (null != sessionId && null != sessionKey) {
store.removeInboundGroupSession(sessionId, sessionKey)
}
}
/** /**
* Decrypt a received message with an inbound group session. * Decrypt a received message with an inbound group session.
* *
@ -771,17 +762,22 @@ internal class MXOlmDevice @Inject constructor(
* @return the decrypting result. Nil if the sessionId is unknown. * @return the decrypting result. Nil if the sessionId is unknown.
*/ */
@Throws(MXCryptoError::class) @Throws(MXCryptoError::class)
fun decryptGroupMessage(body: String, suspend fun decryptGroupMessage(body: String,
roomId: String, roomId: String,
timeline: String?, timeline: String?,
sessionId: String, sessionId: String,
senderKey: String): OlmDecryptionResult { senderKey: String): OlmDecryptionResult {
val session = getInboundGroupSession(sessionId, senderKey, roomId) val sessionHolder = getInboundGroupSession(sessionId, senderKey, roomId)
val wrapper = sessionHolder.wrapper
val inboundGroupSession = wrapper.olmInboundGroupSession
?: throw MXCryptoError.Base(MXCryptoError.ErrorType.UNABLE_TO_DECRYPT, "Session is null")
// Check that the room id matches the original one for the session. This stops // Check that the room id matches the original one for the session. This stops
// the HS pretending a message was targeting a different room. // the HS pretending a message was targeting a different room.
if (roomId == session.roomId) { if (roomId == wrapper.roomId) {
val decryptResult = try { val decryptResult = try {
session.olmInboundGroupSession!!.decryptMessage(body) sessionHolder.mutex.withLock {
inboundGroupSession.decryptMessage(body)
}
} catch (e: OlmException) { } catch (e: OlmException) {
Timber.e(e, "## decryptGroupMessage () : decryptMessage failed") Timber.e(e, "## decryptGroupMessage () : decryptMessage failed")
throw MXCryptoError.OlmError(e) throw MXCryptoError.OlmError(e)
@ -801,7 +797,7 @@ internal class MXOlmDevice @Inject constructor(
timelineSet.add(messageIndexKey) timelineSet.add(messageIndexKey)
} }
inboundGroupSessionStore.storeInBoundGroupSession(session, sessionId, senderKey) inboundGroupSessionStore.storeInBoundGroupSession(sessionHolder, sessionId, senderKey)
val payload = try { val payload = try {
val adapter = MoshiProvider.providesMoshi().adapter<JsonDict>(JSON_DICT_PARAMETERIZED_TYPE) val adapter = MoshiProvider.providesMoshi().adapter<JsonDict>(JSON_DICT_PARAMETERIZED_TYPE)
val payloadString = convertFromUTF8(decryptResult.mDecryptedMessage) val payloadString = convertFromUTF8(decryptResult.mDecryptedMessage)
@ -813,12 +809,12 @@ internal class MXOlmDevice @Inject constructor(
return OlmDecryptionResult( return OlmDecryptionResult(
payload, payload,
session.keysClaimed, wrapper.keysClaimed,
senderKey, senderKey,
session.forwardingCurve25519KeyChain wrapper.forwardingCurve25519KeyChain
) )
} else { } else {
val reason = String.format(MXCryptoError.INBOUND_SESSION_MISMATCH_ROOM_ID_REASON, roomId, session.roomId) val reason = String.format(MXCryptoError.INBOUND_SESSION_MISMATCH_ROOM_ID_REASON, roomId, wrapper.roomId)
Timber.e("## decryptGroupMessage() : $reason") Timber.e("## decryptGroupMessage() : $reason")
throw MXCryptoError.Base(MXCryptoError.ErrorType.INBOUND_SESSION_MISMATCH_ROOM_ID, reason) throw MXCryptoError.Base(MXCryptoError.ErrorType.INBOUND_SESSION_MISMATCH_ROOM_ID, reason)
} }
@ -885,12 +881,13 @@ internal class MXOlmDevice @Inject constructor(
* @param senderKey the base64-encoded curve25519 key of the sender. * @param senderKey the base64-encoded curve25519 key of the sender.
* @return the inbound group session. * @return the inbound group session.
*/ */
fun getInboundGroupSession(sessionId: String?, senderKey: String?, roomId: String?): OlmInboundGroupSessionWrapper2 { fun getInboundGroupSession(sessionId: String?, senderKey: String?, roomId: String?): InboundGroupSessionHolder {
if (sessionId.isNullOrBlank() || senderKey.isNullOrBlank()) { if (sessionId.isNullOrBlank() || senderKey.isNullOrBlank()) {
throw MXCryptoError.Base(MXCryptoError.ErrorType.MISSING_SENDER_KEY, MXCryptoError.ERROR_MISSING_PROPERTY_REASON) throw MXCryptoError.Base(MXCryptoError.ErrorType.MISSING_SENDER_KEY, MXCryptoError.ERROR_MISSING_PROPERTY_REASON)
} }
val session = inboundGroupSessionStore.getInboundGroupSession(sessionId, senderKey) val holder = inboundGroupSessionStore.getInboundGroupSession(sessionId, senderKey)
val session = holder?.wrapper
if (session != null) { if (session != null) {
// Check that the room id matches the original one for the session. This stops // Check that the room id matches the original one for the session. This stops
@ -900,7 +897,7 @@ internal class MXOlmDevice @Inject constructor(
Timber.e("## getInboundGroupSession() : $errorDescription") Timber.e("## getInboundGroupSession() : $errorDescription")
throw MXCryptoError.Base(MXCryptoError.ErrorType.INBOUND_SESSION_MISMATCH_ROOM_ID, errorDescription) throw MXCryptoError.Base(MXCryptoError.ErrorType.INBOUND_SESSION_MISMATCH_ROOM_ID, errorDescription)
} else { } else {
return session return holder
} }
} else { } else {
Timber.w("## getInboundGroupSession() : Cannot retrieve inbound group session $sessionId") Timber.w("## getInboundGroupSession() : Cannot retrieve inbound group session $sessionId")

View File

@ -45,7 +45,7 @@ internal class OlmSessionStore @Inject constructor(private val store: IMXCryptoS
@Synchronized @Synchronized
fun storeSession(olmSessionWrapper: OlmSessionWrapper, deviceKey: String) { fun storeSession(olmSessionWrapper: OlmSessionWrapper, deviceKey: String) {
// This could be a newly created session or one that was just created // This could be a newly created session or one that was just created
// Anyhow we should persist ratchet state for futur app lifecycle // Anyhow we should persist ratchet state for future app lifecycle
addNewSessionInCache(olmSessionWrapper, deviceKey) addNewSessionInCache(olmSessionWrapper, deviceKey)
store.storeSession(olmSessionWrapper, deviceKey) store.storeSession(olmSessionWrapper, deviceKey)
} }

View File

@ -45,7 +45,7 @@ internal interface IMXGroupEncryption {
* *
* @return true in case of success * @return true in case of success
*/ */
suspend fun reshareKey(sessionId: String, suspend fun reshareKey(groupSessionId: String,
userId: String, userId: String,
deviceId: String, deviceId: String,
senderKey: String): Boolean senderKey: String): Boolean

View File

@ -19,6 +19,7 @@ package org.matrix.android.sdk.internal.crypto.algorithms.megolm
import dagger.Lazy import dagger.Lazy
import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.launch import kotlinx.coroutines.launch
import kotlinx.coroutines.sync.withLock
import org.matrix.android.sdk.api.MatrixCoroutineDispatchers import org.matrix.android.sdk.api.MatrixCoroutineDispatchers
import org.matrix.android.sdk.api.logger.LoggerTag import org.matrix.android.sdk.api.logger.LoggerTag
import org.matrix.android.sdk.api.session.crypto.MXCryptoError import org.matrix.android.sdk.api.session.crypto.MXCryptoError
@ -79,7 +80,7 @@ internal class MXMegolmDecryption(private val userId: String,
} }
@Throws(MXCryptoError::class) @Throws(MXCryptoError::class)
private fun decryptEvent(event: Event, timeline: String, requestKeysOnFail: Boolean): MXEventDecryptionResult { private suspend fun decryptEvent(event: Event, timeline: String, requestKeysOnFail: Boolean): MXEventDecryptionResult {
Timber.tag(loggerTag.value).v("decryptEvent ${event.eventId}, requestKeysOnFail:$requestKeysOnFail") Timber.tag(loggerTag.value).v("decryptEvent ${event.eventId}, requestKeysOnFail:$requestKeysOnFail")
if (event.roomId.isNullOrBlank()) { if (event.roomId.isNullOrBlank()) {
throw MXCryptoError.Base(MXCryptoError.ErrorType.MISSING_FIELDS, MXCryptoError.MISSING_FIELDS_REASON) throw MXCryptoError.Base(MXCryptoError.ErrorType.MISSING_FIELDS, MXCryptoError.MISSING_FIELDS_REASON)
@ -345,7 +346,23 @@ internal class MXMegolmDecryption(private val userId: String,
return return
} }
val userId = request.userId ?: return val userId = request.userId ?: return
cryptoCoroutineScope.launch(coroutineDispatchers.crypto) { cryptoCoroutineScope.launch(coroutineDispatchers.crypto) {
val body = request.requestBody
val sessionHolder = try {
olmDevice.getInboundGroupSession(body.sessionId, body.senderKey, body.roomId)
} catch (failure: Throwable) {
Timber.tag(loggerTag.value).e(failure, "shareKeysWithDevice: failed to get session for request $body")
return@launch
}
val export = sessionHolder.mutex.withLock {
sessionHolder.wrapper.exportKeys()
} ?: return@launch Unit.also {
Timber.tag(loggerTag.value).e("shareKeysWithDevice: failed to export group session ${body.sessionId}")
}
runCatching { deviceListManager.downloadKeys(listOf(userId), false) } runCatching { deviceListManager.downloadKeys(listOf(userId), false) }
.mapCatching { .mapCatching {
val deviceId = request.deviceId val deviceId = request.deviceId
@ -355,7 +372,6 @@ internal class MXMegolmDecryption(private val userId: String,
} else { } else {
val devicesByUser = mapOf(userId to listOf(deviceInfo)) val devicesByUser = mapOf(userId to listOf(deviceInfo))
val usersDeviceMap = ensureOlmSessionsForDevicesAction.handle(devicesByUser) val usersDeviceMap = ensureOlmSessionsForDevicesAction.handle(devicesByUser)
val body = request.requestBody
val olmSessionResult = usersDeviceMap.getObject(userId, deviceId) val olmSessionResult = usersDeviceMap.getObject(userId, deviceId)
if (olmSessionResult?.sessionId == null) { if (olmSessionResult?.sessionId == null) {
// no session with this device, probably because there // no session with this device, probably because there
@ -365,19 +381,10 @@ internal class MXMegolmDecryption(private val userId: String,
} }
Timber.tag(loggerTag.value).i("shareKeysWithDevice() : sharing session ${body.sessionId} with device $userId:$deviceId") Timber.tag(loggerTag.value).i("shareKeysWithDevice() : sharing session ${body.sessionId} with device $userId:$deviceId")
val payloadJson = mutableMapOf<String, Any>("type" to EventType.FORWARDED_ROOM_KEY) val payloadJson = mapOf(
runCatching { olmDevice.getInboundGroupSession(body.sessionId, body.senderKey, body.roomId) } "type" to EventType.FORWARDED_ROOM_KEY,
.fold( "content" to export
{ )
// TODO
payloadJson["content"] = it.exportKeys() ?: ""
},
{
// TODO
Timber.tag(loggerTag.value).e(it, "shareKeysWithDevice: failed to get session for request $body")
}
)
val encodedPayload = messageEncrypter.encryptMessage(payloadJson, listOf(deviceInfo)) val encodedPayload = messageEncrypter.encryptMessage(payloadJson, listOf(deviceInfo))
val sendToDeviceMap = MXUsersDevicesMap<Any>() val sendToDeviceMap = MXUsersDevicesMap<Any>()

View File

@ -18,6 +18,7 @@ package org.matrix.android.sdk.internal.crypto.algorithms.megolm
import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.launch import kotlinx.coroutines.launch
import kotlinx.coroutines.sync.withLock
import org.matrix.android.sdk.api.MatrixCoroutineDispatchers import org.matrix.android.sdk.api.MatrixCoroutineDispatchers
import org.matrix.android.sdk.api.logger.LoggerTag import org.matrix.android.sdk.api.logger.LoggerTag
import org.matrix.android.sdk.api.session.crypto.MXCryptoError import org.matrix.android.sdk.api.session.crypto.MXCryptoError
@ -432,20 +433,20 @@ internal class MXMegolmEncryption(
} }
} }
override suspend fun reshareKey(sessionId: String, override suspend fun reshareKey(groupSessionId: String,
userId: String, userId: String,
deviceId: String, deviceId: String,
senderKey: String): Boolean { senderKey: String): Boolean {
Timber.tag(loggerTag.value).i("process reshareKey for $sessionId to $userId:$deviceId") Timber.tag(loggerTag.value).i("process reshareKey for $groupSessionId to $userId:$deviceId")
val deviceInfo = cryptoStore.getUserDevice(userId, deviceId) ?: return false val deviceInfo = cryptoStore.getUserDevice(userId, deviceId) ?: return false
.also { Timber.tag(loggerTag.value).w("reshareKey: Device not found") } .also { Timber.tag(loggerTag.value).w("reshareKey: Device not found") }
// Get the chain index of the key we previously sent this device // Get the chain index of the key we previously sent this device
val wasSessionSharedWithUser = cryptoStore.getSharedSessionInfo(roomId, sessionId, deviceInfo) val wasSessionSharedWithUser = cryptoStore.getSharedSessionInfo(roomId, groupSessionId, deviceInfo)
if (!wasSessionSharedWithUser.found) { if (!wasSessionSharedWithUser.found) {
// This session was never shared with this user // This session was never shared with this user
// Send a room key with held // Send a room key with held
notifyKeyWithHeld(listOf(UserDevice(userId, deviceId)), sessionId, senderKey, WithHeldCode.UNAUTHORISED) notifyKeyWithHeld(listOf(UserDevice(userId, deviceId)), groupSessionId, senderKey, WithHeldCode.UNAUTHORISED)
Timber.tag(loggerTag.value).w("reshareKey: ERROR : Never shared megolm with this device") Timber.tag(loggerTag.value).w("reshareKey: ERROR : Never shared megolm with this device")
return false return false
} }
@ -456,42 +457,48 @@ internal class MXMegolmEncryption(
} }
val devicesByUser = mapOf(userId to listOf(deviceInfo)) val devicesByUser = mapOf(userId to listOf(deviceInfo))
val usersDeviceMap = ensureOlmSessionsForDevicesAction.handle(devicesByUser) val usersDeviceMap = try {
val olmSessionResult = usersDeviceMap.getObject(userId, deviceId) ensureOlmSessionsForDevicesAction.handle(devicesByUser)
olmSessionResult?.sessionId // no session with this device, probably because there were no one-time keys. } catch (failure: Throwable) {
// ensureOlmSessionsForDevicesAction has already done the logging, so just skip it. null
?: return false.also { }
Timber.tag(loggerTag.value).w("reshareKey: no session with this device, probably because there were no one-time keys") val olmSessionResult = usersDeviceMap?.getObject(userId, deviceId)
} if (olmSessionResult?.sessionId == null) {
return false.also {
Timber.tag(loggerTag.value).w("reshareKey: no session with this device, probably because there were no one-time keys")
}
}
Timber.tag(loggerTag.value).i(" reshareKey: $groupSessionId:$chainIndex with device $userId:$deviceId using session ${olmSessionResult.sessionId}")
Timber.tag(loggerTag.value).i(" reshareKey: sharing keys for session $senderKey|$sessionId:$chainIndex with device $userId:$deviceId") val sessionHolder = try {
olmDevice.getInboundGroupSession(groupSessionId, senderKey, roomId)
} catch (failure: Throwable) {
Timber.tag(loggerTag.value).e(failure, "shareKeysWithDevice: failed to get session $groupSessionId")
return false
}
val payloadJson = mutableMapOf<String, Any>("type" to EventType.FORWARDED_ROOM_KEY) val export = sessionHolder.mutex.withLock {
sessionHolder.wrapper.exportKeys()
} ?: return false.also {
Timber.tag(loggerTag.value).e("shareKeysWithDevice: failed to export group session ${groupSessionId}")
}
runCatching { olmDevice.getInboundGroupSession(sessionId, senderKey, roomId) } val payloadJson = mapOf(
.fold( "type" to EventType.FORWARDED_ROOM_KEY,
{ "content" to export
// TODO )
payloadJson["content"] = it.exportKeys(chainIndex.toLong()) ?: ""
},
{
// TODO
Timber.tag(loggerTag.value).e(it, "reshareKey: failed to get session $sessionId|$senderKey|$roomId")
}
)
val encodedPayload = messageEncrypter.encryptMessage(payloadJson, listOf(deviceInfo)) val encodedPayload = messageEncrypter.encryptMessage(payloadJson, listOf(deviceInfo))
val sendToDeviceMap = MXUsersDevicesMap<Any>() val sendToDeviceMap = MXUsersDevicesMap<Any>()
sendToDeviceMap.setObject(userId, deviceId, encodedPayload) sendToDeviceMap.setObject(userId, deviceId, encodedPayload)
Timber.tag(loggerTag.value).i("reshareKey() : sending session $sessionId to $userId:$deviceId") Timber.tag(loggerTag.value).i("reshareKey() : sending session $groupSessionId to $userId:$deviceId")
val sendToDeviceParams = SendToDeviceTask.Params(EventType.ENCRYPTED, sendToDeviceMap) val sendToDeviceParams = SendToDeviceTask.Params(EventType.ENCRYPTED, sendToDeviceMap)
return try { return try {
sendToDeviceTask.execute(sendToDeviceParams) sendToDeviceTask.execute(sendToDeviceParams)
Timber.tag(loggerTag.value).i("reshareKey() : successfully send <$sessionId> to $userId:$deviceId") Timber.tag(loggerTag.value).i("reshareKey() : successfully send <$groupSessionId> to $userId:$deviceId")
true true
} catch (failure: Throwable) { } catch (failure: Throwable) {
Timber.tag(loggerTag.value).e(failure, "reshareKey() : fail to send <$sessionId> to $userId:$deviceId") Timber.tag(loggerTag.value).e(failure, "reshareKey() : fail to send <$groupSessionId> to $userId:$deviceId")
false false
} }
} }

View File

@ -104,7 +104,6 @@ import timber.log.Timber
import java.util.concurrent.Executors import java.util.concurrent.Executors
import java.util.concurrent.TimeUnit import java.util.concurrent.TimeUnit
import javax.inject.Inject import javax.inject.Inject
import kotlin.collections.set
@SessionScope @SessionScope
internal class RealmCryptoStore @Inject constructor( internal class RealmCryptoStore @Inject constructor(
@ -124,12 +123,6 @@ internal class RealmCryptoStore @Inject constructor(
// The olm account // The olm account
private var olmAccount: OlmAccount? = null private var olmAccount: OlmAccount? = null
// Cache for OlmSession, to release them properly
// private val olmSessionsToRelease = HashMap<String, OlmSessionWrapper>()
// Cache for InboundGroupSession, to release them properly
private val inboundGroupSessionToRelease = HashMap<String, OlmInboundGroupSessionWrapper2>()
private val newSessionListeners = ArrayList<NewSessionListener>() private val newSessionListeners = ArrayList<NewSessionListener>()
override fun addNewSessionListener(listener: NewSessionListener) { override fun addNewSessionListener(listener: NewSessionListener) {
@ -213,11 +206,6 @@ internal class RealmCryptoStore @Inject constructor(
monarchyWriteAsyncExecutor.awaitTermination(1, TimeUnit.MINUTES) monarchyWriteAsyncExecutor.awaitTermination(1, TimeUnit.MINUTES)
} }
inboundGroupSessionToRelease.forEach {
it.value.olmInboundGroupSession?.releaseSession()
}
inboundGroupSessionToRelease.clear()
olmAccount?.releaseAccount() olmAccount?.releaseAccount()
realmLocker?.close() realmLocker?.close()
@ -745,13 +733,6 @@ internal class RealmCryptoStore @Inject constructor(
if (sessionIdentifier != null) { if (sessionIdentifier != null) {
val key = OlmInboundGroupSessionEntity.createPrimaryKey(sessionIdentifier, session.senderKey) val key = OlmInboundGroupSessionEntity.createPrimaryKey(sessionIdentifier, session.senderKey)
// Release memory of previously known session, if it is not the same one
if (inboundGroupSessionToRelease[key] != session) {
inboundGroupSessionToRelease[key]?.olmInboundGroupSession?.releaseSession()
}
inboundGroupSessionToRelease[key] = session
val realmOlmInboundGroupSession = OlmInboundGroupSessionEntity().apply { val realmOlmInboundGroupSession = OlmInboundGroupSessionEntity().apply {
primaryKey = key primaryKey = key
sessionId = sessionIdentifier sessionId = sessionIdentifier
@ -768,20 +749,12 @@ internal class RealmCryptoStore @Inject constructor(
override fun getInboundGroupSession(sessionId: String, senderKey: String): OlmInboundGroupSessionWrapper2? { override fun getInboundGroupSession(sessionId: String, senderKey: String): OlmInboundGroupSessionWrapper2? {
val key = OlmInboundGroupSessionEntity.createPrimaryKey(sessionId, senderKey) val key = OlmInboundGroupSessionEntity.createPrimaryKey(sessionId, senderKey)
// If not in cache (or not found), try to read it from realm return doWithRealm(realmConfiguration) {
if (inboundGroupSessionToRelease[key] == null) { it.where<OlmInboundGroupSessionEntity>()
doWithRealm(realmConfiguration) { .equalTo(OlmInboundGroupSessionEntityFields.PRIMARY_KEY, key)
it.where<OlmInboundGroupSessionEntity>() .findFirst()
.equalTo(OlmInboundGroupSessionEntityFields.PRIMARY_KEY, key) ?.getInboundGroupSession()
.findFirst()
?.getInboundGroupSession()
}
?.let {
inboundGroupSessionToRelease[key] = it
}
} }
return inboundGroupSessionToRelease[key]
} }
override fun getCurrentOutboundGroupSessionForRoom(roomId: String): OutboundGroupSessionWrapper? { override fun getCurrentOutboundGroupSessionForRoom(roomId: String): OutboundGroupSessionWrapper? {
@ -837,10 +810,6 @@ internal class RealmCryptoStore @Inject constructor(
override fun removeInboundGroupSession(sessionId: String, senderKey: String) { override fun removeInboundGroupSession(sessionId: String, senderKey: String) {
val key = OlmInboundGroupSessionEntity.createPrimaryKey(sessionId, senderKey) val key = OlmInboundGroupSessionEntity.createPrimaryKey(sessionId, senderKey)
// Release memory of previously known session
inboundGroupSessionToRelease[key]?.olmInboundGroupSession?.releaseSession()
inboundGroupSessionToRelease.remove(key)
doRealmTransaction(realmConfiguration) { doRealmTransaction(realmConfiguration) {
it.where<OlmInboundGroupSessionEntity>() it.where<OlmInboundGroupSessionEntity>()
.equalTo(OlmInboundGroupSessionEntityFields.PRIMARY_KEY, key) .equalTo(OlmInboundGroupSessionEntityFields.PRIMARY_KEY, key)