From 10ea166b2a898468afb723d79e5cf79231e69b13 Mon Sep 17 00:00:00 2001 From: Valere Date: Fri, 25 Feb 2022 16:07:06 +0100 Subject: [PATCH] Extract olm cache store --- .../sdk/account/AccountCreationTest.kt | 152 +++++++++++ .../sdk/internal/crypto/SimpleE2EEChatTest.kt | 244 ++++++++++++++++++ .../sdk/internal/crypto/MXOlmDevice.kt | 22 +- .../sdk/internal/crypto/OlmSessionStore.kt | 152 +++++++++++ .../crypto/algorithms/olm/MXOlmDecryption.kt | 2 +- .../crypto/store/db/RealmCryptoStore.kt | 41 +-- 6 files changed, 575 insertions(+), 38 deletions(-) create mode 100644 matrix-sdk-android/src/androidTest/java/org/matrix/android/sdk/internal/crypto/SimpleE2EEChatTest.kt create mode 100644 matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/OlmSessionStore.kt diff --git a/matrix-sdk-android/src/androidTest/java/org/matrix/android/sdk/account/AccountCreationTest.kt b/matrix-sdk-android/src/androidTest/java/org/matrix/android/sdk/account/AccountCreationTest.kt index 486bc02769..6f25b24d9c 100644 --- a/matrix-sdk-android/src/androidTest/java/org/matrix/android/sdk/account/AccountCreationTest.kt +++ b/matrix-sdk-android/src/androidTest/java/org/matrix/android/sdk/account/AccountCreationTest.kt @@ -16,7 +16,16 @@ package org.matrix.android.sdk.account +import android.util.Log import androidx.test.filters.LargeTest +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.Deferred +import kotlinx.coroutines.SupervisorJob +import kotlinx.coroutines.asCoroutineDispatcher +import kotlinx.coroutines.async +import kotlinx.coroutines.awaitAll +import kotlinx.coroutines.delay +import kotlinx.coroutines.launch import org.junit.FixMethodOrder import org.junit.Ignore import org.junit.Test @@ -28,6 +37,9 @@ import org.matrix.android.sdk.common.CommonTestHelper import org.matrix.android.sdk.common.CryptoTestHelper import org.matrix.android.sdk.common.SessionTestParams import org.matrix.android.sdk.common.TestConstants +import java.util.concurrent.Executors +import java.util.concurrent.atomic.AtomicBoolean +import kotlin.random.Random @RunWith(JUnit4::class) @FixMethodOrder(MethodSorters.JVM) @@ -62,4 +74,144 @@ class AccountCreationTest : InstrumentedTest { res.cleanUp(commonTestHelper) } + + @Test + fun testConcurrentDecrypt() { +// val res = cryptoTestHelper.doE2ETestWithAliceAndBobInARoom() + + // ============================= + // ARRANGE + // ============================= + + val aliceSession = commonTestHelper.createAccount(TestConstants.USER_ALICE, SessionTestParams(true)) + val bobSession = commonTestHelper.createAccount(TestConstants.USER_BOB, SessionTestParams(true)) + cryptoTestHelper.initializeCrossSigning(bobSession) + val bobSession2 = commonTestHelper.logIntoAccount(bobSession.myUserId, SessionTestParams(true)) + + bobSession2.cryptoService().verificationService().markedLocallyAsManuallyVerified(bobSession.myUserId, bobSession.sessionParams.deviceId ?: "") + bobSession.cryptoService().verificationService().markedLocallyAsManuallyVerified(bobSession.myUserId, bobSession2.sessionParams.deviceId ?: "") + + val roomId = cryptoTestHelper.createDM(aliceSession, bobSession) + val roomAlicePOV = aliceSession.getRoom(roomId)!! + + // ============================= + // ACT + // ============================= + + val timelineEvent = commonTestHelper.sendTextMessage(roomAlicePOV, "Hello Bob", 1).first() + val secondEvent = commonTestHelper.sendTextMessage(roomAlicePOV, "Hello Bob 2", 1).first() + val thirdEvent = commonTestHelper.sendTextMessage(roomAlicePOV, "Hello Bob 3", 1).first() + val forthEvent = commonTestHelper.sendTextMessage(roomAlicePOV, "Hello Bob 4", 1).first() + + // await for bob unverified session to get the message + commonTestHelper.waitWithLatch { latch -> + commonTestHelper.retryPeriodicallyWithLatch(latch) { + bobSession.getRoom(roomId)?.getTimeLineEvent(forthEvent.eventId) != null + } + } + + val eventBobPOV = bobSession.getRoom(roomId)?.getTimeLineEvent(timelineEvent.eventId)!! + val secondEventBobPOV = bobSession.getRoom(roomId)?.getTimeLineEvent(secondEvent.eventId)!! + val thirdEventBobPOV = bobSession.getRoom(roomId)?.getTimeLineEvent(thirdEvent.eventId)!! + val forthEventBobPOV = bobSession.getRoom(roomId)?.getTimeLineEvent(forthEvent.eventId)!! + + // let's try to decrypt concurrently and check that we are not getting exceptions + val dispatcher = Executors + .newFixedThreadPool(100) + .asCoroutineDispatcher() + val coroutineScope = CoroutineScope(SupervisorJob() + dispatcher) + + val eventList = listOf(eventBobPOV, secondEventBobPOV, thirdEventBobPOV, forthEventBobPOV) + +// commonTestHelper.runBlockingTest { +// val export = bobSession.cryptoService().exportRoomKeys("foo") + +// } + val atomicAsError = AtomicBoolean() + val deff = mutableListOf>() +// for (i in 1..100) { +// GlobalScope.launch { +// val index = Random.nextInt(eventList.size) +// try { +// val event = eventList[index] +// bobSession.cryptoService().decryptEvent(event.root, "") +// Log.d("#TEST", "Decrypt Success $index :${Thread.currentThread().name}") +// } catch (failure: Throwable) { +// Log.d("#TEST", "Failed to decrypt $index :$failure") +// } +// } +// } + val cryptoService = bobSession.cryptoService() + + coroutineScope.launch { + for (spawn in 1..100) { + delay((Random.nextFloat() * 1000).toLong()) + aliceSession.cryptoService().requestRoomKeyForEvent(eventList.random().root) + } + } + + for (spawn in 1..8000) { + eventList.random().let { event -> + coroutineScope.async { + try { + cryptoService.decryptEvent(event.root, "") + Log.d("#TEST", "[$spawn] Decrypt Success ${event.eventId} :${Thread.currentThread().name}") + } catch (failure: Throwable) { + atomicAsError.set(true) + Log.e("#TEST", "Failed to decrypt $spawn/${event.eventId} :$failure") + } + }.let { + deff.add(it) + } + } +// coroutineScope.async { +// val index = Random.nextInt(eventList.size) +// try { +// val event = eventList[index] +// cryptoService.decryptEvent(event.root, "") +// for (other in eventList.indices) { +// if (other != index) { +// cryptoService.decryptEventAsync(eventList[other].root, "", object : MatrixCallback { +// override fun onFailure(failure: Throwable) { +// Log.e("#TEST", "Failed to decrypt $spawn/$index :$failure") +// } +// }) +// } +// } +// Log.d("#TEST", "[$spawn] Decrypt Success $index :${Thread.currentThread().name}") +// } catch (failure: Throwable) { +// Log.e("#TEST", "Failed to decrypt $spawn/$index :$failure") +// } +// }.let { +// deff.add(it) +// } + } + + coroutineScope.launch { + for (spawn in 1..100) { + delay((Random.nextFloat() * 1000).toLong()) + bobSession.cryptoService().requestRoomKeyForEvent(eventList.random().root) + } + } + + commonTestHelper.runBlockingTest(10 * 60_000) { + deff.awaitAll() + delay(10_000) + assert(!atomicAsError.get()) + // There should be no errors? +// deff.map { it.await() }.forEach { +// it.fold({ +// Log.d("#TEST", "Decrypt Success :${it}") +// }, { +// Log.d("#TEST", "Failed to decrypt :$it") +// }) +// val hasFailure = deff.any { it.await().exceptionOrNull() != null } +// assert(!hasFailure) +// } + + commonTestHelper.signOutAndClose(aliceSession) + commonTestHelper.signOutAndClose(bobSession) + commonTestHelper.signOutAndClose(bobSession2) + } + } } diff --git a/matrix-sdk-android/src/androidTest/java/org/matrix/android/sdk/internal/crypto/SimpleE2EEChatTest.kt b/matrix-sdk-android/src/androidTest/java/org/matrix/android/sdk/internal/crypto/SimpleE2EEChatTest.kt new file mode 100644 index 0000000000..d2b2495b76 --- /dev/null +++ b/matrix-sdk-android/src/androidTest/java/org/matrix/android/sdk/internal/crypto/SimpleE2EEChatTest.kt @@ -0,0 +1,244 @@ +/* + * Copyright (c) 2022 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.matrix.android.sdk.internal.crypto + +import android.util.Log +import androidx.test.filters.LargeTest +import kotlinx.coroutines.delay +import org.junit.Assert +import org.junit.FixMethodOrder +import org.junit.Test +import org.junit.runner.RunWith +import org.junit.runners.JUnit4 +import org.junit.runners.MethodSorters +import org.matrix.android.sdk.InstrumentedTest +import org.matrix.android.sdk.api.session.Session +import org.matrix.android.sdk.api.session.crypto.MXCryptoError +import org.matrix.android.sdk.api.session.events.model.EventType +import org.matrix.android.sdk.api.session.events.model.toModel +import org.matrix.android.sdk.api.session.room.Room +import org.matrix.android.sdk.api.session.room.failure.JoinRoomFailure +import org.matrix.android.sdk.api.session.room.model.Membership +import org.matrix.android.sdk.api.session.room.model.message.MessageContent +import org.matrix.android.sdk.api.session.room.send.SendState +import org.matrix.android.sdk.api.session.room.timeline.TimelineSettings +import org.matrix.android.sdk.common.CommonTestHelper +import org.matrix.android.sdk.common.CryptoTestHelper +import org.matrix.android.sdk.common.SessionTestParams + +@RunWith(JUnit4::class) +@FixMethodOrder(MethodSorters.JVM) +@LargeTest +class SimpleE2EEChatTest : InstrumentedTest { + + private val testHelper = CommonTestHelper(context()) + private val cryptoTestHelper = CryptoTestHelper(testHelper) + + /** + * Simple test that create an e2ee room. + * Some new members are added, and a message is sent. + * We check that the message is e2e and can be decrypted. + * + * Additional users join, we check that they can't decrypt history + * + * Alice sends a new message, then check that the new one can be decrypted + */ + @Test + fun testSendingE2EEMessages() { + val cryptoTestData = cryptoTestHelper.doE2ETestWithAliceAndBobInARoom(true) + val aliceSession = cryptoTestData.firstSession + val e2eRoomID = cryptoTestData.roomId + + val aliceRoomPOV = aliceSession.getRoom(e2eRoomID)!! + + // add some more users and invite them + val otherAccounts = listOf("benoit", "valere", "ganfra") // , "adam", "manu") + .map { + testHelper.createAccount(it, SessionTestParams(true)) + } + + Log.v("#E2E TEST", "All accounts created") + // we want to invite them in the room + otherAccounts.forEach { + testHelper.runBlockingTest { + Log.v("#E2E TEST", "Alice invites ${it.myUserId}") + aliceRoomPOV.invite(it.myUserId) + } + } + + // All user should accept invite + otherAccounts.forEach { otherSession -> + waitForAndAcceptInviteInRoom(otherSession, e2eRoomID) + Log.v("#E2E TEST", "${otherSession.myUserId} joined room $e2eRoomID") + } + + // check that alice see them as joined (not really necessary?) + ensureMembersHaveJoined(aliceSession, otherAccounts, e2eRoomID) + + Log.v("#E2E TEST", "All users have joined the room") + + Log.v("#E2E TEST", "Alice is sending the message") + + val text = "This is my message" + val sentEventId: String? = sendMessageInRoom(aliceRoomPOV, text) + // val sentEvent = testHelper.sendTextMessage(aliceRoomPOV, "Hello all", 1).first() + Assert.assertTrue("Message should be sent", sentEventId != null) + + // All should be able to decrypt + otherAccounts.forEach { otherSession -> + testHelper.waitWithLatch { latch -> + testHelper.retryPeriodicallyWithLatch(latch) { + val timeLineEvent = otherSession.getRoom(e2eRoomID)?.getTimeLineEvent(sentEventId!!) + timeLineEvent != null && + timeLineEvent.isEncrypted() && + timeLineEvent.root.getClearType() == EventType.MESSAGE + } + } + } + + // Add a new user to the room, and check that he can't decrypt + val newAccount = listOf("adam") // , "adam", "manu") + .map { + testHelper.createAccount(it, SessionTestParams(true)) + } + + newAccount.forEach { + testHelper.runBlockingTest { + Log.v("#E2E TEST", "Alice invites ${it.myUserId}") + aliceRoomPOV.invite(it.myUserId) + } + } + + newAccount.forEach { + waitForAndAcceptInviteInRoom(it, e2eRoomID) + } + + ensureMembersHaveJoined(aliceSession, newAccount, e2eRoomID) + + // wait a bit + testHelper.runBlockingTest { + delay(3_000) + } + + // check that messages are encrypted (uisi) + newAccount.forEach { otherSession -> + testHelper.waitWithLatch { latch -> + testHelper.retryPeriodicallyWithLatch(latch) { + val timeLineEvent = otherSession.getRoom(e2eRoomID)?.getTimeLineEvent(sentEventId!!).also { + Log.v("#E2E TEST", "Event seen by new user ${it?.root?.getClearType()}|${it?.root?.mCryptoError}") + } + timeLineEvent != null && + timeLineEvent.root.getClearType() == EventType.ENCRYPTED && + timeLineEvent.root.mCryptoError == MXCryptoError.ErrorType.UNKNOWN_INBOUND_SESSION_ID + } + } + } + + // Let alice send a new message + Log.v("#E2E TEST", "Alice sends a new message") + + val secondMessage = "2 This is my message" + val secondSentEventId: String? = sendMessageInRoom(aliceRoomPOV, secondMessage) + + // new members should be able to decrypt it + newAccount.forEach { otherSession -> + testHelper.waitWithLatch { latch -> + testHelper.retryPeriodicallyWithLatch(latch) { + val timeLineEvent = otherSession.getRoom(e2eRoomID)?.getTimeLineEvent(secondSentEventId!!).also { + Log.v("#E2E TEST", "Second Event seen by new user ${it?.root?.getClearType()}|${it?.root?.mCryptoError}") + } + timeLineEvent != null && + timeLineEvent.root.getClearType() == EventType.MESSAGE && + secondMessage.equals(timeLineEvent.root.getClearContent().toModel()?.body) + } + } + } + + otherAccounts.forEach { + testHelper.signOutAndClose(it) + } + newAccount.forEach { testHelper.signOutAndClose(it) } + + cryptoTestData.cleanUp(testHelper) + } + + private fun sendMessageInRoom(aliceRoomPOV: Room, text: String): String? { + aliceRoomPOV.sendTextMessage(text) + var sentEventId: String? = null + testHelper.waitWithLatch(4 * 60_000) { + val timeline = aliceRoomPOV.createTimeline(null, TimelineSettings(60)) + timeline.start() + + testHelper.retryPeriodicallyWithLatch(it) { + val decryptedMsg = timeline.getSnapshot() + .filter { it.root.getClearType() == EventType.MESSAGE } + .also { + Log.v("#E2E TEST", "Timeline snapshot is ${it.map { "${it.root.type}|${it.root.sendState}" }.joinToString(",", "[", "]")}") + } + .filter { it.root.sendState == SendState.SYNCED } + .firstOrNull { it.root.getClearContent().toModel()?.body?.startsWith(text) == true } + sentEventId = decryptedMsg?.eventId + decryptedMsg != null + } + + timeline.dispose() + } + return sentEventId + } + + private fun ensureMembersHaveJoined(aliceSession: Session, otherAccounts: List, e2eRoomID: String) { + testHelper.waitWithLatch { + testHelper.retryPeriodicallyWithLatch(it) { + otherAccounts.map { + aliceSession.getRoomMember(it.myUserId, e2eRoomID)?.membership + }.all { + it == Membership.JOIN + } + } + } + } + + private fun waitForAndAcceptInviteInRoom(otherSession: Session, e2eRoomID: String) { + testHelper.waitWithLatch { + testHelper.retryPeriodicallyWithLatch(it) { + val roomSummary = otherSession.getRoomSummary(e2eRoomID) + (roomSummary != null && roomSummary.membership == Membership.INVITE).also { + if (it) { + Log.v("#E2E TEST", "${otherSession.myUserId} can see the invite from alice") + } + } + } + } + + testHelper.runBlockingTest(60_000) { + Log.v("#E2E TEST", "${otherSession.myUserId} tries to join room $e2eRoomID") + try { + otherSession.joinRoom(e2eRoomID) + } catch (ex: JoinRoomFailure.JoinedWithTimeout) { + // it's ok we will wait after + } + } + + Log.v("#E2E TEST", "${otherSession.myUserId} waiting for join echo ...") + testHelper.waitWithLatch { + testHelper.retryPeriodicallyWithLatch(it) { + val roomSummary = otherSession.getRoomSummary(e2eRoomID) + roomSummary != null && roomSummary.membership == Membership.JOIN + } + } + } +} diff --git a/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/MXOlmDevice.kt b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/MXOlmDevice.kt index e1a706df79..bfe986d1fd 100755 --- a/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/MXOlmDevice.kt +++ b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/MXOlmDevice.kt @@ -47,6 +47,7 @@ internal class MXOlmDevice @Inject constructor( * The store where crypto data is saved. */ private val store: IMXCryptoStore, + private val olmSessionStore: OlmSessionStore, private val inboundGroupSessionStore: InboundGroupSessionStore ) { @@ -190,6 +191,7 @@ internal class MXOlmDevice @Inject constructor( it.groupSession.releaseSession() } outboundGroupSessionCache.clear() + olmSessionStore.clear() } /** @@ -257,7 +259,8 @@ internal class MXOlmDevice @Inject constructor( // this session olmSessionWrapper.onMessageReceived() - store.storeSession(olmSessionWrapper, theirIdentityKey) + olmSessionStore.storeSession(olmSessionWrapper, theirIdentityKey) +// store.storeSession(olmSessionWrapper, theirIdentityKey) val sessionIdentifier = olmSession.sessionIdentifier() @@ -324,7 +327,7 @@ internal class MXOlmDevice @Inject constructor( // This counts as a received message: set last received message time to now olmSessionWrapper.onMessageReceived() - store.storeSession(olmSessionWrapper, theirDeviceIdentityKey) + olmSessionStore.storeSession(olmSessionWrapper, theirDeviceIdentityKey) } catch (e: Exception) { Timber.e(e, "## createInboundSession() : decryptMessage failed") } @@ -357,8 +360,8 @@ internal class MXOlmDevice @Inject constructor( * @param theirDeviceIdentityKey the Curve25519 identity key for the remote device. * @return a list of known session ids for the device. */ - fun getSessionIds(theirDeviceIdentityKey: String): List? { - return store.getDeviceSessionIds(theirDeviceIdentityKey) + fun getSessionIds(theirDeviceIdentityKey: String): List { + return olmSessionStore.getDeviceSessionIds(theirDeviceIdentityKey) } /** @@ -368,7 +371,7 @@ internal class MXOlmDevice @Inject constructor( * @return the session id, or null if no established session. */ fun getSessionId(theirDeviceIdentityKey: String): String? { - return store.getLastUsedSessionId(theirDeviceIdentityKey) + return olmSessionStore.getLastUsedSessionId(theirDeviceIdentityKey) } /** @@ -390,7 +393,8 @@ internal class MXOlmDevice @Inject constructor( // Timber.v("## encryptMessage() : payloadString: " + payloadString); olmMessage = olmSessionWrapper.olmSession.encryptMessage(payloadString) - store.storeSession(olmSessionWrapper, theirDeviceIdentityKey) +// store.storeSession(olmSessionWrapper, theirDeviceIdentityKey) + olmSessionStore.storeSession(olmSessionWrapper, theirDeviceIdentityKey) res = HashMap() res["body"] = olmMessage.mCipherText @@ -427,7 +431,8 @@ internal class MXOlmDevice @Inject constructor( try { payloadString = olmSessionWrapper.olmSession.decryptMessage(olmMessage) olmSessionWrapper.onMessageReceived() - store.storeSession(olmSessionWrapper, theirDeviceIdentityKey) + olmSessionStore.storeSession(olmSessionWrapper, theirDeviceIdentityKey) +// store.storeSession(olmSessionWrapper, theirDeviceIdentityKey) } catch (e: Exception) { Timber.e(e, "## decryptMessage() : decryptMessage failed") } @@ -819,7 +824,8 @@ internal class MXOlmDevice @Inject constructor( private fun getSessionForDevice(theirDeviceIdentityKey: String, sessionId: String): OlmSessionWrapper? { // sanity check return if (theirDeviceIdentityKey.isEmpty() || sessionId.isEmpty()) null else { - store.getDeviceSession(sessionId, theirDeviceIdentityKey) + olmSessionStore.getDeviceSession(sessionId, theirDeviceIdentityKey) +// store.getDeviceSession(sessionId, theirDeviceIdentityKey) } } diff --git a/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/OlmSessionStore.kt b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/OlmSessionStore.kt new file mode 100644 index 0000000000..6044095e7b --- /dev/null +++ b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/OlmSessionStore.kt @@ -0,0 +1,152 @@ +/* + * Copyright (c) 2022 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.matrix.android.sdk.internal.crypto + +import org.matrix.android.sdk.api.extensions.tryOrNull +import org.matrix.android.sdk.internal.crypto.model.OlmSessionWrapper +import org.matrix.android.sdk.internal.crypto.store.IMXCryptoStore +import javax.inject.Inject + +/** + * Keep the used olm session in memory and load them from the data layer when needed + * Access is synchronized for thread safety + */ +internal class OlmSessionStore @Inject constructor(private val store: IMXCryptoStore) { + + /* + * map of device key to list of olm sessions (it is possible to have several active sessions with a device) + */ + private val olmSessions = HashMap>() + + /** + * Store a session between the logged-in user and another device. + * This will be called after creation but also after any use of the ratchet + * in order to persist the correct state for next run + * @param olmSessionWrapper the end-to-end session. + * @param deviceKey the public key of the other device. + */ + @Synchronized + fun storeSession(olmSessionWrapper: OlmSessionWrapper, deviceKey: String) { + // This could be a newly created session or one that was just created + // Anyhow we should persist ratchet state for futur app lifecycle + addNewSessionInCache(olmSessionWrapper, deviceKey) + store.storeSession(olmSessionWrapper, deviceKey) + } + + /** + * Retrieve the end-to-end session ids between the logged-in user and another + * device. + * + * @param deviceKey the public key of the other device. + * @return A set of sessionId, or empty if device is not known + */ + @Synchronized + fun getDeviceSessionIds(deviceKey: String): List { + return internalGetAllSessions(deviceKey) + } + + private fun internalGetAllSessions(deviceKey: String): MutableList { + // we need to get the persisted ids first + val persistedKnownSessions = store.getDeviceSessionIds(deviceKey) + .orEmpty() + .toMutableList() + // Do we have some in cache not yet persisted? + olmSessions.getOrPut(deviceKey) { mutableListOf() }.forEach { cached -> + tryOrNull("Olm session was released") { cached.olmSession.sessionIdentifier() }?.let { cachedSessionId -> + if (!persistedKnownSessions.contains(cachedSessionId)) { + persistedKnownSessions.add(cachedSessionId) + } + } + } + return persistedKnownSessions + } + + /** + * Retrieve an end-to-end session between the logged-in user and another + * device. + * + * @param sessionId the session Id. + * @param deviceKey the public key of the other device. + * @return The Base64 end-to-end session, or null if not found + */ + @Synchronized + fun getDeviceSession(sessionId: String, deviceKey: String): OlmSessionWrapper? { + // get from cache or load and add to cache + return internalGetSession(sessionId, deviceKey) + } + + /** + * Retrieve the last used sessionId, regarding `lastReceivedMessageTs`, or null if no session exist + * + * @param deviceKey the public key of the other device. + * @return last used sessionId, or null if not found + */ + @Synchronized + fun getLastUsedSessionId(deviceKey: String): String? { + // We want to avoid to load in memory old session if possible + val lastPersistedUsedSession = store.getLastUsedSessionId(deviceKey) + var candidate = lastPersistedUsedSession?.let { internalGetSession(it, deviceKey) } + // we should check if we have one in cache with a higher last message received? + olmSessions[deviceKey].orEmpty().forEach { inCache -> + if (inCache.lastReceivedMessageTs > (candidate?.lastReceivedMessageTs ?: 0L)) { + candidate = inCache + } + } + + return candidate?.olmSession?.sessionIdentifier() + } + + /** + * Release all sessions and clear cache + */ + @Synchronized + fun clear() { + olmSessions.entries.onEach { entry -> + entry.value.onEach { it.olmSession.releaseSession() } + } + olmSessions.clear() + } + + private fun internalGetSession(sessionId: String, deviceKey: String): OlmSessionWrapper? { + return getSessionInCache(sessionId, deviceKey) + ?: // deserialize from store + return store.getDeviceSession(sessionId, deviceKey)?.also { + addNewSessionInCache(it, deviceKey) + } + } + + private fun getSessionInCache(sessionId: String, deviceKey: String): OlmSessionWrapper? { + return olmSessions[deviceKey]?.firstOrNull { + it.olmSession.isReleased && it.olmSession.sessionIdentifier() == sessionId + } + } + + private fun addNewSessionInCache(session: OlmSessionWrapper, deviceKey: String) { + val sessionId = tryOrNull { session.olmSession.sessionIdentifier() } ?: return + olmSessions.getOrPut(deviceKey) { mutableListOf() }.let { + val existing = it.firstOrNull { tryOrNull { it.olmSession.sessionIdentifier() } == sessionId } + it.add(session) + // remove and release if was there but with different instance + if (existing != null && existing.olmSession != session.olmSession) { + // mm not sure when this could happen + // anyhow we should remove and release the one known + it.remove(existing) + existing.olmSession.releaseSession() + } + } + } +} diff --git a/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/algorithms/olm/MXOlmDecryption.kt b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/algorithms/olm/MXOlmDecryption.kt index f1bca4fbc6..eb93abfb61 100644 --- a/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/algorithms/olm/MXOlmDecryption.kt +++ b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/algorithms/olm/MXOlmDecryption.kt @@ -154,7 +154,7 @@ internal class MXOlmDecryption( * @return payload, if decrypted successfully. */ private fun decryptMessage(message: JsonDict, theirDeviceIdentityKey: String): String? { - val sessionIds = olmDevice.getSessionIds(theirDeviceIdentityKey).orEmpty() + val sessionIds = olmDevice.getSessionIds(theirDeviceIdentityKey) val messageBody = message["body"] as? String ?: return null val messageType = when (val typeAsVoid = message["type"]) { diff --git a/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/store/db/RealmCryptoStore.kt b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/store/db/RealmCryptoStore.kt index a07827c033..c13ed77eb7 100644 --- a/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/store/db/RealmCryptoStore.kt +++ b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/store/db/RealmCryptoStore.kt @@ -125,7 +125,7 @@ internal class RealmCryptoStore @Inject constructor( private var olmAccount: OlmAccount? = null // Cache for OlmSession, to release them properly - private val olmSessionsToRelease = HashMap() +// private val olmSessionsToRelease = HashMap() // Cache for InboundGroupSession, to release them properly private val inboundGroupSessionToRelease = HashMap() @@ -213,11 +213,6 @@ internal class RealmCryptoStore @Inject constructor( monarchyWriteAsyncExecutor.awaitTermination(1, TimeUnit.MINUTES) } - olmSessionsToRelease.forEach { - it.value.olmSession.releaseSession() - } - olmSessionsToRelease.clear() - inboundGroupSessionToRelease.forEach { it.value.olmInboundGroupSession?.releaseSession() } @@ -680,13 +675,6 @@ internal class RealmCryptoStore @Inject constructor( if (sessionIdentifier != null) { val key = OlmSessionEntity.createPrimaryKey(sessionIdentifier, deviceKey) - // Release memory of previously known session, if it is not the same one - if (olmSessionsToRelease[key]?.olmSession != olmSessionWrapper.olmSession) { - olmSessionsToRelease[key]?.olmSession?.releaseSession() - } - - olmSessionsToRelease[key] = olmSessionWrapper - doRealmTransaction(realmConfiguration) { val realmOlmSession = OlmSessionEntity().apply { primaryKey = key @@ -703,23 +691,18 @@ internal class RealmCryptoStore @Inject constructor( override fun getDeviceSession(sessionId: String, deviceKey: String): OlmSessionWrapper? { val key = OlmSessionEntity.createPrimaryKey(sessionId, deviceKey) - - // If not in cache (or not found), try to read it from realm - if (olmSessionsToRelease[key] == null) { - doRealmQueryAndCopy(realmConfiguration) { - it.where() - .equalTo(OlmSessionEntityFields.PRIMARY_KEY, key) - .findFirst() - } - ?.let { - val olmSession = it.getOlmSession() - if (olmSession != null && it.sessionId != null) { - olmSessionsToRelease[key] = OlmSessionWrapper(olmSession, it.lastReceivedMessageTs) - } - } + return doRealmQueryAndCopy(realmConfiguration) { + it.where() + .equalTo(OlmSessionEntityFields.PRIMARY_KEY, key) + .findFirst() } - - return olmSessionsToRelease[key] + ?.let { + val olmSession = it.getOlmSession() + if (olmSession != null && it.sessionId != null) { + return@let OlmSessionWrapper(olmSession, it.lastReceivedMessageTs) + } + null + } } override fun getLastUsedSessionId(deviceKey: String): String? {