diff --git a/matrix-sdk-android/src/androidTest/java/org/matrix/android/sdk/common/TestConstants.kt b/matrix-sdk-android/src/androidTest/java/org/matrix/android/sdk/common/TestConstants.kt index 5c9b79361e..0f79896b2c 100644 --- a/matrix-sdk-android/src/androidTest/java/org/matrix/android/sdk/common/TestConstants.kt +++ b/matrix-sdk-android/src/androidTest/java/org/matrix/android/sdk/common/TestConstants.kt @@ -23,7 +23,7 @@ object TestConstants { const val TESTS_HOME_SERVER_URL = "http://10.0.2.2:8080" // Time out to use when waiting for server response. - private const val AWAIT_TIME_OUT_MILLIS = 30_000 + private const val AWAIT_TIME_OUT_MILLIS = 60_000 // Time out to use when waiting for server response, when the debugger is connected. 10 minutes private const val AWAIT_TIME_OUT_WITH_DEBUGGER_MILLIS = 10 * 60_000 diff --git a/matrix-sdk-android/src/androidTest/java/org/matrix/android/sdk/internal/crypto/E2eeSanityTests.kt b/matrix-sdk-android/src/androidTest/java/org/matrix/android/sdk/internal/crypto/E2eeSanityTests.kt new file mode 100644 index 0000000000..be76961b4c --- /dev/null +++ b/matrix-sdk-android/src/androidTest/java/org/matrix/android/sdk/internal/crypto/E2eeSanityTests.kt @@ -0,0 +1,648 @@ +/* + * Copyright 2022 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.matrix.android.sdk.internal.crypto + +import android.util.Log +import androidx.test.filters.LargeTest +import kotlinx.coroutines.delay +import org.amshove.kluent.fail +import org.amshove.kluent.internal.assertEquals +import org.junit.Assert +import org.junit.FixMethodOrder +import org.junit.Test +import org.junit.runner.RunWith +import org.junit.runners.JUnit4 +import org.junit.runners.MethodSorters +import org.matrix.android.sdk.InstrumentedTest +import org.matrix.android.sdk.api.session.Session +import org.matrix.android.sdk.api.session.crypto.MXCryptoError +import org.matrix.android.sdk.api.session.events.model.EventType +import org.matrix.android.sdk.api.session.events.model.toModel +import org.matrix.android.sdk.api.session.room.Room +import org.matrix.android.sdk.api.session.room.failure.JoinRoomFailure +import org.matrix.android.sdk.api.session.room.model.Membership +import org.matrix.android.sdk.api.session.room.model.message.MessageContent +import org.matrix.android.sdk.api.session.room.send.SendState +import org.matrix.android.sdk.api.session.room.timeline.TimelineSettings +import org.matrix.android.sdk.common.CommonTestHelper +import org.matrix.android.sdk.common.CryptoTestHelper +import org.matrix.android.sdk.common.SessionTestParams +import org.matrix.android.sdk.common.TestMatrixCallback +import org.matrix.android.sdk.internal.crypto.algorithms.olm.OlmDecryptionResult +import org.matrix.android.sdk.internal.crypto.keysbackup.model.MegolmBackupCreationInfo +import org.matrix.android.sdk.internal.crypto.keysbackup.model.rest.KeysVersion +import org.matrix.android.sdk.internal.crypto.keysbackup.model.rest.KeysVersionResult +import org.matrix.android.sdk.internal.crypto.model.ImportRoomKeysResult +import org.matrix.android.sdk.internal.crypto.model.event.EncryptedEventContent + +@RunWith(JUnit4::class) +@FixMethodOrder(MethodSorters.JVM) +@LargeTest +class E2eeSanityTests : InstrumentedTest { + + private val testHelper = CommonTestHelper(context()) + private val cryptoTestHelper = CryptoTestHelper(testHelper) + + /** + * Simple test that create an e2ee room. + * Some new members are added, and a message is sent. + * We check that the message is e2e and can be decrypted. + * + * Additional users join, we check that they can't decrypt history + * + * Alice sends a new message, then check that the new one can be decrypted + */ + @Test + fun testSendingE2EEMessages() { + val cryptoTestData = cryptoTestHelper.doE2ETestWithAliceAndBobInARoom(true) + val aliceSession = cryptoTestData.firstSession + val e2eRoomID = cryptoTestData.roomId + + val aliceRoomPOV = aliceSession.getRoom(e2eRoomID)!! + + // add some more users and invite them + val otherAccounts = listOf("benoit", "valere", "ganfra") // , "adam", "manu") + .map { + testHelper.createAccount(it, SessionTestParams(true)) + } + + Log.v("#E2E TEST", "All accounts created") + // we want to invite them in the room + otherAccounts.forEach { + testHelper.runBlockingTest { + Log.v("#E2E TEST", "Alice invites ${it.myUserId}") + aliceRoomPOV.invite(it.myUserId) + } + } + + // All user should accept invite + otherAccounts.forEach { otherSession -> + waitForAndAcceptInviteInRoom(otherSession, e2eRoomID) + Log.v("#E2E TEST", "${otherSession.myUserId} joined room $e2eRoomID") + } + + // check that alice see them as joined (not really necessary?) + ensureMembersHaveJoined(aliceSession, otherAccounts, e2eRoomID) + + Log.v("#E2E TEST", "All users have joined the room") + + Log.v("#E2E TEST", "Alice is sending the message") + + val text = "This is my message" + val sentEventId: String? = sendMessageInRoom(aliceRoomPOV, text) + // val sentEvent = testHelper.sendTextMessage(aliceRoomPOV, "Hello all", 1).first() + Assert.assertTrue("Message should be sent", sentEventId != null) + + // All should be able to decrypt + otherAccounts.forEach { otherSession -> + testHelper.waitWithLatch { latch -> + testHelper.retryPeriodicallyWithLatch(latch) { + val timeLineEvent = otherSession.getRoom(e2eRoomID)?.getTimeLineEvent(sentEventId!!) + timeLineEvent != null && + timeLineEvent.isEncrypted() && + timeLineEvent.root.getClearType() == EventType.MESSAGE + } + } + } + + // Add a new user to the room, and check that he can't decrypt + val newAccount = listOf("adam") // , "adam", "manu") + .map { + testHelper.createAccount(it, SessionTestParams(true)) + } + + newAccount.forEach { + testHelper.runBlockingTest { + Log.v("#E2E TEST", "Alice invites ${it.myUserId}") + aliceRoomPOV.invite(it.myUserId) + } + } + + newAccount.forEach { + waitForAndAcceptInviteInRoom(it, e2eRoomID) + } + + ensureMembersHaveJoined(aliceSession, newAccount, e2eRoomID) + + // wait a bit + testHelper.runBlockingTest { + delay(3_000) + } + + // check that messages are encrypted (uisi) + newAccount.forEach { otherSession -> + testHelper.waitWithLatch { latch -> + testHelper.retryPeriodicallyWithLatch(latch) { + val timeLineEvent = otherSession.getRoom(e2eRoomID)?.getTimeLineEvent(sentEventId!!).also { + Log.v("#E2E TEST", "Event seen by new user ${it?.root?.getClearType()}|${it?.root?.mCryptoError}") + } + timeLineEvent != null && + timeLineEvent.root.getClearType() == EventType.ENCRYPTED && + timeLineEvent.root.mCryptoError == MXCryptoError.ErrorType.UNKNOWN_INBOUND_SESSION_ID + } + } + } + + // Let alice send a new message + Log.v("#E2E TEST", "Alice sends a new message") + + val secondMessage = "2 This is my message" + val secondSentEventId: String? = sendMessageInRoom(aliceRoomPOV, secondMessage) + + // new members should be able to decrypt it + newAccount.forEach { otherSession -> + testHelper.waitWithLatch { latch -> + testHelper.retryPeriodicallyWithLatch(latch) { + val timeLineEvent = otherSession.getRoom(e2eRoomID)?.getTimeLineEvent(secondSentEventId!!).also { + Log.v("#E2E TEST", "Second Event seen by new user ${it?.root?.getClearType()}|${it?.root?.mCryptoError}") + } + timeLineEvent != null && + timeLineEvent.root.getClearType() == EventType.MESSAGE && + secondMessage.equals(timeLineEvent.root.getClearContent().toModel()?.body) + } + } + } + + otherAccounts.forEach { + testHelper.signOutAndClose(it) + } + newAccount.forEach { testHelper.signOutAndClose(it) } + + cryptoTestData.cleanUp(testHelper) + } + + /** + * Quick test for basic keybackup + * 1. Create e2e between Alice and Bob + * 2. Alice sends 3 messages, using 3 different sessions + * 3. Ensure bob can decrypt + * 4. Create backup for bob and uplaod keys + * + * 5. Sign out alice and bob to ensure no gossiping will happen + * + * 6. Let bob sign in with a new session + * 7. Ensure history is UISI + * 8. Import backup + * 9. Check that new session can decrypt + */ + @Test + fun testBasicBackupImport() { + val cryptoTestData = cryptoTestHelper.doE2ETestWithAliceAndBobInARoom(true) + val aliceSession = cryptoTestData.firstSession + val bobSession = cryptoTestData.secondSession!! + val e2eRoomID = cryptoTestData.roomId + + Log.v("#E2E TEST", "Create and start keybackup for bob ...") + val keysBackupService = bobSession.cryptoService().keysBackupService() + val keyBackupPassword = "FooBarBaz" + val megolmBackupCreationInfo = testHelper.doSync { + keysBackupService.prepareKeysBackupVersion(keyBackupPassword, null, it) + } + val version = testHelper.doSync { + keysBackupService.createKeysBackupVersion(megolmBackupCreationInfo, it) + } + Log.v("#E2E TEST", "... Key backup started and enabled for bob") + // Bob session should now have + + val aliceRoomPOV = aliceSession.getRoom(e2eRoomID)!! + + // let's send a few message to bob + val sentEventIds = mutableListOf() + val messagesText = listOf("1. Hello", "2. Bob", "3. Good morning") + messagesText.forEach { text -> + val sentEventId = sendMessageInRoom(aliceRoomPOV, text)!!.also { + sentEventIds.add(it) + } + + testHelper.waitWithLatch { latch -> + testHelper.retryPeriodicallyWithLatch(latch) { + val timeLineEvent = bobSession.getRoom(e2eRoomID)?.getTimeLineEvent(sentEventId) + timeLineEvent != null && + timeLineEvent.isEncrypted() && + timeLineEvent.root.getClearType() == EventType.MESSAGE + } + } + // we want more so let's discard the session + aliceSession.cryptoService().discardOutboundSession(e2eRoomID) + + testHelper.runBlockingTest { + delay(1_000) + } + } + Log.v("#E2E TEST", "Bob received all and can decrypt") + + // Let's wait a bit to be sure that bob has backed up the session + + Log.v("#E2E TEST", "Force key backup for Bob...") + testHelper.waitWithLatch { latch -> + keysBackupService.backupAllGroupSessions( + null, + TestMatrixCallback(latch, true) + ) + } + Log.v("#E2E TEST", "... Keybackup done for Bob") + + // Now lets logout both alice and bob to ensure that we won't have any gossiping + + val bobUserId = bobSession.myUserId + Log.v("#E2E TEST", "Logout alice and bob...") + testHelper.signOutAndClose(aliceSession) + testHelper.signOutAndClose(bobSession) + Log.v("#E2E TEST", "..Logout alice and bob...") + + testHelper.runBlockingTest { + delay(1_000) + } + + // Create a new session for bob + Log.v("#E2E TEST", "Create a new session for Bob") + val newBobSession = testHelper.logIntoAccount(bobUserId, SessionTestParams(true)) + + // check that bob can't currently decrypt + Log.v("#E2E TEST", "check that bob can't currently decrypt") + sentEventIds.forEach { sentEventId -> + testHelper.waitWithLatch { latch -> + testHelper.retryPeriodicallyWithLatch(latch) { + val timeLineEvent = newBobSession.getRoom(e2eRoomID)?.getTimeLineEvent(sentEventId)?.also { + Log.v("#E2E TEST", "Event seen by new user ${it.root.getClearType()}|${it.root.mCryptoError}") + } + timeLineEvent != null && + timeLineEvent.root.getClearType() == EventType.ENCRYPTED + } + } + } + // after initial sync events are not decrypted, so we have to try manually + ensureCannotDecrypt(sentEventIds, newBobSession, e2eRoomID, MXCryptoError.ErrorType.UNKNOWN_INBOUND_SESSION_ID) + + // Let's now import keys from backup + + newBobSession.cryptoService().keysBackupService().let { keysBackupService -> + val keyVersionResult = testHelper.doSync { + keysBackupService.getVersion(version.version, it) + } + + val importedResult = testHelper.doSync { + keysBackupService.restoreKeyBackupWithPassword(keyVersionResult!!, + keyBackupPassword, + null, + null, + null, it) + } + + assertEquals(3, importedResult.totalNumberOfKeys) + } + + // ensure bob can now decrypt + ensureCanDecrypt(sentEventIds, newBobSession, e2eRoomID, messagesText) + + testHelper.signOutAndClose(newBobSession) + } + + /** + * Check that a new verified session that was not supposed to get the keys initially will + * get them from an older one. + */ + @Test + fun testSimpleGossip() { + val cryptoTestData = cryptoTestHelper.doE2ETestWithAliceAndBobInARoom(true) + val aliceSession = cryptoTestData.firstSession + val bobSession = cryptoTestData.secondSession!! + val e2eRoomID = cryptoTestData.roomId + + val aliceRoomPOV = aliceSession.getRoom(e2eRoomID)!! + + cryptoTestHelper.initializeCrossSigning(bobSession) + + // let's send a few message to bob + val sentEventIds = mutableListOf() + val messagesText = listOf("1. Hello", "2. Bob") + + Log.v("#E2E TEST", "Alice sends some messages") + messagesText.forEach { text -> + val sentEventId = sendMessageInRoom(aliceRoomPOV, text)!!.also { + sentEventIds.add(it) + } + + testHelper.waitWithLatch { latch -> + testHelper.retryPeriodicallyWithLatch(latch) { + val timeLineEvent = bobSession.getRoom(e2eRoomID)?.getTimeLineEvent(sentEventId) + timeLineEvent != null && + timeLineEvent.isEncrypted() && + timeLineEvent.root.getClearType() == EventType.MESSAGE + } + } + } + + // Ensure bob can decrypt + ensureIsDecrypted(sentEventIds, bobSession, e2eRoomID) + + // Let's now add a new bob session + // Create a new session for bob + Log.v("#E2E TEST", "Create a new session for Bob") + val newBobSession = testHelper.logIntoAccount(bobSession.myUserId, SessionTestParams(true)) + + // check that new bob can't currently decrypt + Log.v("#E2E TEST", "check that new bob can't currently decrypt") + + ensureCannotDecrypt(sentEventIds, newBobSession, e2eRoomID, MXCryptoError.ErrorType.UNKNOWN_INBOUND_SESSION_ID) + + // Try to request + sentEventIds.forEach { sentEventId -> + val event = newBobSession.getRoom(e2eRoomID)!!.getTimeLineEvent(sentEventId)!!.root + newBobSession.cryptoService().requestRoomKeyForEvent(event) + } + + // wait a bit + testHelper.runBlockingTest { + delay(10_000) + } + + // Ensure that new bob still can't decrypt (keys must have been withheld) + ensureCannotDecrypt(sentEventIds, newBobSession, e2eRoomID, MXCryptoError.ErrorType.KEYS_WITHHELD) + + // Now mark new bob session as verified + + bobSession.cryptoService().verificationService().markedLocallyAsManuallyVerified(newBobSession.myUserId, newBobSession.sessionParams.deviceId!!) + newBobSession.cryptoService().verificationService().markedLocallyAsManuallyVerified(bobSession.myUserId, bobSession.sessionParams.deviceId!!) + + // now let new session re-request + sentEventIds.forEach { sentEventId -> + val event = newBobSession.getRoom(e2eRoomID)!!.getTimeLineEvent(sentEventId)!!.root + newBobSession.cryptoService().reRequestRoomKeyForEvent(event) + } + + // wait a bit + testHelper.runBlockingTest { + delay(10_000) + } + + ensureCanDecrypt(sentEventIds, newBobSession, e2eRoomID, messagesText) + + cryptoTestData.cleanUp(testHelper) + testHelper.signOutAndClose(newBobSession) + } + + /** + * Test that if a better key is forwared (lower index, it is then used) + */ + @Test + fun testForwardBetterKey() { + val cryptoTestData = cryptoTestHelper.doE2ETestWithAliceAndBobInARoom(true) + val aliceSession = cryptoTestData.firstSession + val bobSessionWithBetterKey = cryptoTestData.secondSession!! + val e2eRoomID = cryptoTestData.roomId + + val aliceRoomPOV = aliceSession.getRoom(e2eRoomID)!! + + cryptoTestHelper.initializeCrossSigning(bobSessionWithBetterKey) + + // let's send a few message to bob + var firstEventId: String + val firstMessage = "1. Hello" + + Log.v("#E2E TEST", "Alice sends some messages") + firstMessage.let { text -> + firstEventId = sendMessageInRoom(aliceRoomPOV, text)!! + + testHelper.waitWithLatch { latch -> + testHelper.retryPeriodicallyWithLatch(latch) { + val timeLineEvent = bobSessionWithBetterKey.getRoom(e2eRoomID)?.getTimeLineEvent(firstEventId) + timeLineEvent != null && + timeLineEvent.isEncrypted() && + timeLineEvent.root.getClearType() == EventType.MESSAGE + } + } + } + + // Ensure bob can decrypt + ensureIsDecrypted(listOf(firstEventId), bobSessionWithBetterKey, e2eRoomID) + + // Let's add a new unverified session from bob + val newBobSession = testHelper.logIntoAccount(bobSessionWithBetterKey.myUserId, SessionTestParams(true)) + + // check that new bob can't currently decrypt + Log.v("#E2E TEST", "check that new bob can't currently decrypt") + ensureCannotDecrypt(listOf(firstEventId), newBobSession, e2eRoomID, MXCryptoError.ErrorType.UNKNOWN_INBOUND_SESSION_ID) + + // Now let alice send a new message. this time the new bob session will be able to decrypt + var secondEventId: String + val secondMessage = "2. New Device?" + + Log.v("#E2E TEST", "Alice sends some messages") + secondMessage.let { text -> + secondEventId = sendMessageInRoom(aliceRoomPOV, text)!! + + testHelper.waitWithLatch { latch -> + testHelper.retryPeriodicallyWithLatch(latch) { + val timeLineEvent = newBobSession.getRoom(e2eRoomID)?.getTimeLineEvent(secondEventId) + timeLineEvent != null && + timeLineEvent.isEncrypted() && + timeLineEvent.root.getClearType() == EventType.MESSAGE + } + } + } + + // check that both messages have same sessionId (it's just that we don't have index 0) + val firstEventNewBobPov = newBobSession.getRoom(e2eRoomID)?.getTimeLineEvent(firstEventId) + val secondEventNewBobPov = newBobSession.getRoom(e2eRoomID)?.getTimeLineEvent(secondEventId) + + val firstSessionId = firstEventNewBobPov!!.root.content.toModel()!!.sessionId!! + val secondSessionId = secondEventNewBobPov!!.root.content.toModel()!!.sessionId!! + + Assert.assertTrue("Should be the same session id", firstSessionId == secondSessionId) + + // Confirm we can decrypt one but not the other + testHelper.runBlockingTest { + try { + newBobSession.cryptoService().decryptEvent(firstEventNewBobPov.root, "") + fail("Should not be able to decrypt event") + } catch (error: MXCryptoError) { + val errorType = (error as? MXCryptoError.Base)?.errorType + assertEquals(MXCryptoError.ErrorType.UNKNOWN_MESSAGE_INDEX, errorType) + } + } + + testHelper.runBlockingTest { + try { + newBobSession.cryptoService().decryptEvent(secondEventNewBobPov.root, "") + } catch (error: MXCryptoError) { + fail("Should be able to decrypt event") + } + } + + // Now let's verify bobs session, and re-request keys + bobSessionWithBetterKey.cryptoService() + .verificationService() + .markedLocallyAsManuallyVerified(newBobSession.myUserId, newBobSession.sessionParams.deviceId!!) + + newBobSession.cryptoService() + .verificationService() + .markedLocallyAsManuallyVerified(bobSessionWithBetterKey.myUserId, bobSessionWithBetterKey.sessionParams.deviceId!!) + + // now let new session request + newBobSession.cryptoService().requestRoomKeyForEvent(firstEventNewBobPov.root) + + // wait a bit + testHelper.runBlockingTest { + delay(10_000) + } + + // old session should have shared the key at earliest known index now + // we should be able to decrypt both + testHelper.runBlockingTest { + try { + newBobSession.cryptoService().decryptEvent(firstEventNewBobPov.root, "") + } catch (error: MXCryptoError) { + fail("Should be able to decrypt first event now $error") + } + } + testHelper.runBlockingTest { + try { + newBobSession.cryptoService().decryptEvent(secondEventNewBobPov.root, "") + } catch (error: MXCryptoError) { + fail("Should be able to decrypt event $error") + } + } + + cryptoTestData.cleanUp(testHelper) + testHelper.signOutAndClose(newBobSession) + } + + private fun sendMessageInRoom(aliceRoomPOV: Room, text: String): String? { + aliceRoomPOV.sendTextMessage(text) + var sentEventId: String? = null + testHelper.waitWithLatch(4 * 60_000) { + val timeline = aliceRoomPOV.createTimeline(null, TimelineSettings(60)) + timeline.start() + + testHelper.retryPeriodicallyWithLatch(it) { + val decryptedMsg = timeline.getSnapshot() + .filter { it.root.getClearType() == EventType.MESSAGE } + .also { + Log.v("#E2E TEST", "Timeline snapshot is ${it.map { "${it.root.type}|${it.root.sendState}" }.joinToString(",", "[", "]")}") + } + .filter { it.root.sendState == SendState.SYNCED } + .firstOrNull { it.root.getClearContent().toModel()?.body?.startsWith(text) == true } + sentEventId = decryptedMsg?.eventId + decryptedMsg != null + } + + timeline.dispose() + } + return sentEventId + } + + private fun ensureMembersHaveJoined(aliceSession: Session, otherAccounts: List, e2eRoomID: String) { + testHelper.waitWithLatch { + testHelper.retryPeriodicallyWithLatch(it) { + otherAccounts.map { + aliceSession.getRoomMember(it.myUserId, e2eRoomID)?.membership + }.all { + it == Membership.JOIN + } + } + } + } + + private fun waitForAndAcceptInviteInRoom(otherSession: Session, e2eRoomID: String) { + testHelper.waitWithLatch { + testHelper.retryPeriodicallyWithLatch(it) { + val roomSummary = otherSession.getRoomSummary(e2eRoomID) + (roomSummary != null && roomSummary.membership == Membership.INVITE).also { + if (it) { + Log.v("#E2E TEST", "${otherSession.myUserId} can see the invite from alice") + } + } + } + } + + testHelper.runBlockingTest(60_000) { + Log.v("#E2E TEST", "${otherSession.myUserId} tries to join room $e2eRoomID") + try { + otherSession.joinRoom(e2eRoomID) + } catch (ex: JoinRoomFailure.JoinedWithTimeout) { + // it's ok we will wait after + } + } + + Log.v("#E2E TEST", "${otherSession.myUserId} waiting for join echo ...") + testHelper.waitWithLatch { + testHelper.retryPeriodicallyWithLatch(it) { + val roomSummary = otherSession.getRoomSummary(e2eRoomID) + roomSummary != null && roomSummary.membership == Membership.JOIN + } + } + } + + private fun ensureCanDecrypt(sentEventIds: MutableList, session: Session, e2eRoomID: String, messagesText: List) { + sentEventIds.forEachIndexed { index, sentEventId -> + testHelper.waitWithLatch { latch -> + testHelper.retryPeriodicallyWithLatch(latch) { + val event = session.getRoom(e2eRoomID)!!.getTimeLineEvent(sentEventId)!!.root + testHelper.runBlockingTest { + try { + session.cryptoService().decryptEvent(event, "").let { result -> + event.mxDecryptionResult = OlmDecryptionResult( + payload = result.clearEvent, + senderKey = result.senderCurve25519Key, + keysClaimed = result.claimedEd25519Key?.let { mapOf("ed25519" to it) }, + forwardingCurve25519KeyChain = result.forwardingCurve25519KeyChain + ) + } + } catch (error: MXCryptoError) { + // nop + } + } + event.getClearType() == EventType.MESSAGE && + messagesText[index] == event.getClearContent()?.toModel()?.body + } + } + } + } + + private fun ensureIsDecrypted(sentEventIds: List, session: Session, e2eRoomID: String) { + testHelper.waitWithLatch { latch -> + sentEventIds.forEach { sentEventId -> + testHelper.retryPeriodicallyWithLatch(latch) { + val timeLineEvent = session.getRoom(e2eRoomID)?.getTimeLineEvent(sentEventId) + timeLineEvent != null && + timeLineEvent.isEncrypted() && + timeLineEvent.root.getClearType() == EventType.MESSAGE + } + } + } + } + + private fun ensureCannotDecrypt(sentEventIds: List, newBobSession: Session, e2eRoomID: String, expectedError: MXCryptoError.ErrorType?) { + sentEventIds.forEach { sentEventId -> + val event = newBobSession.getRoom(e2eRoomID)!!.getTimeLineEvent(sentEventId)!!.root + testHelper.runBlockingTest { + try { + newBobSession.cryptoService().decryptEvent(event, "") + fail("Should not be able to decrypt event") + } catch (error: MXCryptoError) { + val errorType = (error as? MXCryptoError.Base)?.errorType + if (expectedError == null) { + Assert.assertNotNull(errorType) + } else { + assertEquals(expectedError, errorType, "Message expected to be UISI") + } + } + } + } + } +} diff --git a/matrix-sdk-android/src/androidTest/java/org/matrix/android/sdk/internal/crypto/PreShareKeysTest.kt b/matrix-sdk-android/src/androidTest/java/org/matrix/android/sdk/internal/crypto/PreShareKeysTest.kt index a7a81bacf5..46c1dacf78 100644 --- a/matrix-sdk-android/src/androidTest/java/org/matrix/android/sdk/internal/crypto/PreShareKeysTest.kt +++ b/matrix-sdk-android/src/androidTest/java/org/matrix/android/sdk/internal/crypto/PreShareKeysTest.kt @@ -21,7 +21,6 @@ import androidx.test.ext.junit.runners.AndroidJUnit4 import org.junit.Assert.assertEquals import org.junit.Assert.assertNotNull import org.junit.FixMethodOrder -import org.junit.Ignore import org.junit.Test import org.junit.runner.RunWith import org.junit.runners.MethodSorters @@ -41,7 +40,6 @@ class PreShareKeysTest : InstrumentedTest { private val cryptoTestHelper = CryptoTestHelper(testHelper) @Test - @Ignore("This test will be ignored until it is fixed") fun ensure_outbound_session_happy_path() { val testData = cryptoTestHelper.doE2ETestWithAliceAndBobInARoom(true) val e2eRoomID = testData.roomId @@ -92,7 +90,7 @@ class PreShareKeysTest : InstrumentedTest { // Just send a real message as test val sentEvent = testHelper.sendTextMessage(aliceSession.getRoom(e2eRoomID)!!, "Allo", 1).first() - assertEquals(megolmSessionId, sentEvent.root.content.toModel()?.sessionId, "Unexpected megolm session") + assertEquals("Unexpected megolm session", megolmSessionId, sentEvent.root.content.toModel()?.sessionId,) testHelper.waitWithLatch { latch -> testHelper.retryPeriodicallyWithLatch(latch) { bobSession.getRoom(e2eRoomID)?.getTimelineEvent(sentEvent.eventId)?.root?.getClearType() == EventType.MESSAGE diff --git a/matrix-sdk-android/src/androidTest/java/org/matrix/android/sdk/internal/crypto/UnwedgingTest.kt b/matrix-sdk-android/src/androidTest/java/org/matrix/android/sdk/internal/crypto/UnwedgingTest.kt index 0a8ce67680..fb5d58b127 100644 --- a/matrix-sdk-android/src/androidTest/java/org/matrix/android/sdk/internal/crypto/UnwedgingTest.kt +++ b/matrix-sdk-android/src/androidTest/java/org/matrix/android/sdk/internal/crypto/UnwedgingTest.kt @@ -21,7 +21,6 @@ import org.amshove.kluent.shouldBe import org.junit.Assert import org.junit.Before import org.junit.FixMethodOrder -import org.junit.Ignore import org.junit.Test import org.junit.runner.RunWith import org.junit.runners.MethodSorters @@ -85,7 +84,6 @@ class UnwedgingTest : InstrumentedTest { * -> This is automatically fixed after SDKs restarted the olm session */ @Test - @Ignore("This test will be ignored until it is fixed") fun testUnwedging() { val cryptoTestData = cryptoTestHelper.doE2ETestWithAliceAndBobInARoom() @@ -94,9 +92,7 @@ class UnwedgingTest : InstrumentedTest { val bobSession = cryptoTestData.secondSession!! val aliceCryptoStore = (aliceSession.cryptoService() as DefaultCryptoService).cryptoStoreForTesting - - // bobSession.cryptoService().setWarnOnUnknownDevices(false) - // aliceSession.cryptoService().setWarnOnUnknownDevices(false) + val olmDevice = (aliceSession.cryptoService() as DefaultCryptoService).olmDeviceForTest val roomFromBobPOV = bobSession.getRoom(aliceRoomId)!! val roomFromAlicePOV = aliceSession.getRoom(aliceRoomId)!! @@ -175,6 +171,7 @@ class UnwedgingTest : InstrumentedTest { Timber.i("## CRYPTO | testUnwedging: wedge the session now. Set crypto state like after the first message") aliceCryptoStore.storeSession(OlmSessionWrapper(deserializeFromRealm(oldSession)!!), bobSession.cryptoService().getMyDevice().identityKey()!!) + olmDevice.clearOlmSessionCache() Thread.sleep(6_000) // Force new session, and key share @@ -227,8 +224,10 @@ class UnwedgingTest : InstrumentedTest { testHelper.waitWithLatch { testHelper.retryPeriodicallyWithLatch(it) { // we should get back the key and be able to decrypt - val result = tryOrNull { - bobSession.cryptoService().decryptEvent(messagesReceivedByBob[0].root, "") + val result = testHelper.runBlockingTest { + tryOrNull { + bobSession.cryptoService().decryptEvent(messagesReceivedByBob[0].root, "") + } } Timber.i("## CRYPTO | testUnwedging: decrypt result ${result?.clearEvent}") result != null diff --git a/matrix-sdk-android/src/androidTest/java/org/matrix/android/sdk/internal/crypto/gossiping/KeyShareTests.kt b/matrix-sdk-android/src/androidTest/java/org/matrix/android/sdk/internal/crypto/gossiping/KeyShareTests.kt index 82aee454eb..cd20ab477c 100644 --- a/matrix-sdk-android/src/androidTest/java/org/matrix/android/sdk/internal/crypto/gossiping/KeyShareTests.kt +++ b/matrix-sdk-android/src/androidTest/java/org/matrix/android/sdk/internal/crypto/gossiping/KeyShareTests.kt @@ -97,7 +97,9 @@ class KeyShareTests : InstrumentedTest { assert(receivedEvent!!.isEncrypted()) try { - aliceSession2.cryptoService().decryptEvent(receivedEvent.root, "foo") + commonTestHelper.runBlockingTest { + aliceSession2.cryptoService().decryptEvent(receivedEvent.root, "foo") + } fail("should fail") } catch (failure: Throwable) { } @@ -152,7 +154,9 @@ class KeyShareTests : InstrumentedTest { } try { - aliceSession2.cryptoService().decryptEvent(receivedEvent.root, "foo") + commonTestHelper.runBlockingTest { + aliceSession2.cryptoService().decryptEvent(receivedEvent.root, "foo") + } fail("should fail") } catch (failure: Throwable) { } @@ -189,7 +193,9 @@ class KeyShareTests : InstrumentedTest { } try { - aliceSession2.cryptoService().decryptEvent(receivedEvent.root, "foo") + commonTestHelper.runBlockingTest { + aliceSession2.cryptoService().decryptEvent(receivedEvent.root, "foo") + } } catch (failure: Throwable) { fail("should have been able to decrypt") } @@ -384,7 +390,11 @@ class KeyShareTests : InstrumentedTest { val roomRoomBobPov = aliceSession.getRoom(roomId) val beforeJoin = roomRoomBobPov!!.getTimelineEvent(secondEventId) - var dRes = tryOrNull { bobSession.cryptoService().decryptEvent(beforeJoin!!.root, "") } + var dRes = tryOrNull { + commonTestHelper.runBlockingTest { + bobSession.cryptoService().decryptEvent(beforeJoin!!.root, "") + } + } assert(dRes == null) @@ -395,7 +405,11 @@ class KeyShareTests : InstrumentedTest { Thread.sleep(3_000) // With the bug the first session would have improperly reshare that key :/ - dRes = tryOrNull { bobSession.cryptoService().decryptEvent(beforeJoin.root, "") } + dRes = tryOrNull { + commonTestHelper.runBlockingTest { + bobSession.cryptoService().decryptEvent(beforeJoin.root, "") + } + } Log.d("#TEST", "KS: sgould not decrypt that ${beforeJoin.root.getClearContent().toModel()?.body}") assert(dRes?.clearEvent == null) } diff --git a/matrix-sdk-android/src/androidTest/java/org/matrix/android/sdk/internal/crypto/gossiping/WithHeldTests.kt b/matrix-sdk-android/src/androidTest/java/org/matrix/android/sdk/internal/crypto/gossiping/WithHeldTests.kt index 9fda21763a..65c65660b5 100644 --- a/matrix-sdk-android/src/androidTest/java/org/matrix/android/sdk/internal/crypto/gossiping/WithHeldTests.kt +++ b/matrix-sdk-android/src/androidTest/java/org/matrix/android/sdk/internal/crypto/gossiping/WithHeldTests.kt @@ -93,7 +93,9 @@ class WithHeldTests : InstrumentedTest { // Bob should not be able to decrypt because the keys is withheld try { // .. might need to wait a bit for stability? - bobUnverifiedSession.cryptoService().decryptEvent(eventBobPOV.root, "") + testHelper.runBlockingTest { + bobUnverifiedSession.cryptoService().decryptEvent(eventBobPOV.root, "") + } Assert.fail("This session should not be able to decrypt") } catch (failure: Throwable) { val type = (failure as MXCryptoError.Base).errorType @@ -118,7 +120,9 @@ class WithHeldTests : InstrumentedTest { // Previous message should still be undecryptable (partially withheld session) try { // .. might need to wait a bit for stability? - bobUnverifiedSession.cryptoService().decryptEvent(eventBobPOV.root, "") + testHelper.runBlockingTest { + bobUnverifiedSession.cryptoService().decryptEvent(eventBobPOV.root, "") + } Assert.fail("This session should not be able to decrypt") } catch (failure: Throwable) { val type = (failure as MXCryptoError.Base).errorType @@ -165,7 +169,9 @@ class WithHeldTests : InstrumentedTest { val eventBobPOV = bobSession.getRoom(testData.roomId)?.getTimelineEvent(eventId) try { // .. might need to wait a bit for stability? - bobSession.cryptoService().decryptEvent(eventBobPOV!!.root, "") + testHelper.runBlockingTest { + bobSession.cryptoService().decryptEvent(eventBobPOV!!.root, "") + } Assert.fail("This session should not be able to decrypt") } catch (failure: Throwable) { val type = (failure as MXCryptoError.Base).errorType @@ -233,7 +239,11 @@ class WithHeldTests : InstrumentedTest { testHelper.retryPeriodicallyWithLatch(latch) { val timeLineEvent = bobSecondSession.getRoom(testData.roomId)?.getTimelineEvent(eventId)?.also { // try to decrypt and force key request - tryOrNull { bobSecondSession.cryptoService().decryptEvent(it.root, "") } + tryOrNull { + testHelper.runBlockingTest { + bobSecondSession.cryptoService().decryptEvent(it.root, "") + } + } } sessionId = timeLineEvent?.root?.content?.toModel()?.sessionId timeLineEvent != null diff --git a/matrix-sdk-android/src/main/java/org/matrix/android/sdk/api/session/crypto/CryptoService.kt b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/api/session/crypto/CryptoService.kt index e3f00a24b6..65f69e17c9 100644 --- a/matrix-sdk-android/src/main/java/org/matrix/android/sdk/api/session/crypto/CryptoService.kt +++ b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/api/session/crypto/CryptoService.kt @@ -121,7 +121,7 @@ interface CryptoService { fun discardOutboundSession(roomId: String) @Throws(MXCryptoError::class) - fun decryptEvent(event: Event, timeline: String): MXEventDecryptionResult + suspend fun decryptEvent(event: Event, timeline: String): MXEventDecryptionResult fun decryptEventAsync(event: Event, timeline: String, callback: MatrixCallback) diff --git a/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/DefaultCryptoService.kt b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/DefaultCryptoService.kt index 0646e4d2b8..db44abc36f 100755 --- a/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/DefaultCryptoService.kt +++ b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/DefaultCryptoService.kt @@ -434,6 +434,14 @@ internal class DefaultCryptoService @Inject constructor( val currentCount = syncResponse.deviceOneTimeKeysCount.signedCurve25519 ?: 0 oneTimeKeysUploader.updateOneTimeKeyCount(currentCount) } + + // unwedge if needed + try { + eventDecryptor.unwedgeDevicesIfNeeded() + } catch (failure: Throwable) { + Timber.tag(loggerTag.value).w("unwedgeDevicesIfNeeded failed") + } + // There is a limit of to_device events returned per sync. // If we are in a case of such limited to_device sync we can't try to generate/upload // new otk now, because there might be some pending olm pre-key to_device messages that would fail if we rotate @@ -723,7 +731,7 @@ internal class DefaultCryptoService @Inject constructor( * @return the MXEventDecryptionResult data, or throw in case of error */ @Throws(MXCryptoError::class) - override fun decryptEvent(event: Event, timeline: String): MXEventDecryptionResult { + override suspend fun decryptEvent(event: Event, timeline: String): MXEventDecryptionResult { return internalDecryptEvent(event, timeline) } @@ -746,7 +754,7 @@ internal class DefaultCryptoService @Inject constructor( * @return the MXEventDecryptionResult data, or null in case of error */ @Throws(MXCryptoError::class) - private fun internalDecryptEvent(event: Event, timeline: String): MXEventDecryptionResult { + private suspend fun internalDecryptEvent(event: Event, timeline: String): MXEventDecryptionResult { return eventDecryptor.decryptEvent(event, timeline) } @@ -1364,6 +1372,9 @@ internal class DefaultCryptoService @Inject constructor( @VisibleForTesting val cryptoStoreForTesting = cryptoStore + @VisibleForTesting + val olmDeviceForTest = olmDevice + companion object { const val CRYPTO_MIN_FORCE_SESSION_PERIOD_MILLIS = 3_600_000 // one hour } diff --git a/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/EventDecryptor.kt b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/EventDecryptor.kt index 57381eacfb..00efd3d6a8 100644 --- a/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/EventDecryptor.kt +++ b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/EventDecryptor.kt @@ -21,14 +21,13 @@ import kotlinx.coroutines.launch import kotlinx.coroutines.withContext import org.matrix.android.sdk.api.MatrixCallback import org.matrix.android.sdk.api.MatrixCoroutineDispatchers +import org.matrix.android.sdk.api.logger.LoggerTag import org.matrix.android.sdk.api.session.crypto.MXCryptoError import org.matrix.android.sdk.api.session.events.model.Event import org.matrix.android.sdk.api.session.events.model.EventType import org.matrix.android.sdk.api.session.events.model.toModel import org.matrix.android.sdk.internal.crypto.actions.EnsureOlmSessionsForDevicesAction import org.matrix.android.sdk.internal.crypto.actions.MessageEncrypter -import org.matrix.android.sdk.internal.crypto.model.CryptoDeviceInfo -import org.matrix.android.sdk.internal.crypto.model.MXOlmSessionResult import org.matrix.android.sdk.internal.crypto.model.MXUsersDevicesMap import org.matrix.android.sdk.internal.crypto.model.event.OlmEventContent import org.matrix.android.sdk.internal.crypto.store.IMXCryptoStore @@ -40,6 +39,8 @@ import javax.inject.Inject private const val SEND_TO_DEVICE_RETRY_COUNT = 3 +private val loggerTag = LoggerTag("CryptoSyncHandler", LoggerTag.CRYPTO) + @SessionScope internal class EventDecryptor @Inject constructor( private val cryptoCoroutineScope: CoroutineScope, @@ -47,13 +48,22 @@ internal class EventDecryptor @Inject constructor( private val roomDecryptorProvider: RoomDecryptorProvider, private val messageEncrypter: MessageEncrypter, private val sendToDeviceTask: SendToDeviceTask, + private val deviceListManager: DeviceListManager, private val ensureOlmSessionsForDevicesAction: EnsureOlmSessionsForDevicesAction, private val cryptoStore: IMXCryptoStore ) { - // The date of the last time we forced establishment - // of a new session for each user:device. - private val lastNewSessionForcedDates = MXUsersDevicesMap() + /** + * Rate limit unwedge attempt, should we persist that? + */ + private val lastNewSessionForcedDates = mutableMapOf() + + data class WedgedDeviceInfo( + val userId: String, + val senderKey: String? + ) + + private val wedgedDevices = mutableListOf() /** * Decrypt an event @@ -63,7 +73,7 @@ internal class EventDecryptor @Inject constructor( * @return the MXEventDecryptionResult data, or throw in case of error */ @Throws(MXCryptoError::class) - fun decryptEvent(event: Event, timeline: String): MXEventDecryptionResult { + suspend fun decryptEvent(event: Event, timeline: String): MXEventDecryptionResult { return internalDecryptEvent(event, timeline) } @@ -91,38 +101,32 @@ internal class EventDecryptor @Inject constructor( * @return the MXEventDecryptionResult data, or null in case of error */ @Throws(MXCryptoError::class) - private fun internalDecryptEvent(event: Event, timeline: String): MXEventDecryptionResult { + private suspend fun internalDecryptEvent(event: Event, timeline: String): MXEventDecryptionResult { val eventContent = event.content if (eventContent == null) { - Timber.e("## CRYPTO | decryptEvent : empty event content") + Timber.tag(loggerTag.value).e("decryptEvent : empty event content") throw MXCryptoError.Base(MXCryptoError.ErrorType.BAD_ENCRYPTED_MESSAGE, MXCryptoError.BAD_ENCRYPTED_MESSAGE_REASON) } else { val algorithm = eventContent["algorithm"]?.toString() val alg = roomDecryptorProvider.getOrCreateRoomDecryptor(event.roomId, algorithm) if (alg == null) { val reason = String.format(MXCryptoError.UNABLE_TO_DECRYPT_REASON, event.eventId, algorithm) - Timber.e("## CRYPTO | decryptEvent() : $reason") + Timber.tag(loggerTag.value).e("decryptEvent() : $reason") throw MXCryptoError.Base(MXCryptoError.ErrorType.UNABLE_TO_DECRYPT, reason) } else { try { return alg.decryptEvent(event, timeline) } catch (mxCryptoError: MXCryptoError) { - Timber.v("## CRYPTO | internalDecryptEvent : Failed to decrypt ${event.eventId} reason: $mxCryptoError") + Timber.tag(loggerTag.value).d("internalDecryptEvent : Failed to decrypt ${event.eventId} reason: $mxCryptoError") if (algorithm == MXCRYPTO_ALGORITHM_OLM) { if (mxCryptoError is MXCryptoError.Base && mxCryptoError.errorType == MXCryptoError.ErrorType.BAD_ENCRYPTED_MESSAGE) { // need to find sending device - cryptoCoroutineScope.launch(coroutineDispatchers.crypto) { - val olmContent = event.content.toModel() - cryptoStore.getUserDevices(event.senderId ?: "") - ?.values - ?.firstOrNull { it.identityKey() == olmContent?.senderKey } - ?.let { - markOlmSessionForUnwedging(event.senderId ?: "", it) - } - ?: run { - Timber.i("## CRYPTO | internalDecryptEvent() : Failed to find sender crypto device for unwedging") - } + val olmContent = event.content.toModel() + if (event.senderId != null && olmContent?.senderKey != null) { + markOlmSessionForUnwedging(event.senderId, olmContent.senderKey) + } else { + Timber.tag(loggerTag.value).d("Can't mark as wedge malformed") } } } @@ -132,53 +136,91 @@ internal class EventDecryptor @Inject constructor( } } - // coroutineDispatchers.crypto scope - private fun markOlmSessionForUnwedging(senderId: String, deviceInfo: CryptoDeviceInfo) { - val deviceKey = deviceInfo.identityKey() + private fun markOlmSessionForUnwedging(senderId: String, senderKey: String) { + val info = WedgedDeviceInfo(senderId, senderKey) + if (!wedgedDevices.contains(info)) { + Timber.tag(loggerTag.value).d("Marking device from $senderId key:$senderKey as wedged") + wedgedDevices.add(info) + } + } - val lastForcedDate = lastNewSessionForcedDates.getObject(senderId, deviceKey) ?: 0 + // coroutineDispatchers.crypto scope + suspend fun unwedgeDevicesIfNeeded() { + // handle wedged devices + // Some olm decryption have failed and some device are wedged + // we should force start a new session for those + Timber.tag(loggerTag.value).v("Unwedging: ${wedgedDevices.size} are wedged") + // get the one that should be retried according to rate limit val now = System.currentTimeMillis() - if (now - lastForcedDate < DefaultCryptoService.CRYPTO_MIN_FORCE_SESSION_PERIOD_MILLIS) { - Timber.w("## CRYPTO | markOlmSessionForUnwedging: New session already forced with device at $lastForcedDate. Not forcing another") + val toUnwedge = wedgedDevices.filter { + val lastForcedDate = lastNewSessionForcedDates[it] ?: 0 + if (now - lastForcedDate < DefaultCryptoService.CRYPTO_MIN_FORCE_SESSION_PERIOD_MILLIS) { + Timber.tag(loggerTag.value).d("Unwedging, New session for $it already forced with device at $lastForcedDate") + return@filter false + } + // let's already mark that we tried now + lastNewSessionForcedDates[it] = now + true + } + + if (toUnwedge.isEmpty()) { + Timber.tag(loggerTag.value).v("Nothing to unwedge") return } + Timber.tag(loggerTag.value).d("Unwedging, trying to create new session for ${toUnwedge.size} devices") - Timber.i("## CRYPTO | markOlmSessionForUnwedging from $senderId:${deviceInfo.deviceId}") - lastNewSessionForcedDates.setObject(senderId, deviceKey, now) - - // offload this from crypto thread (?) - cryptoCoroutineScope.launch(coroutineDispatchers.computation) { - runCatching { ensureOlmSessionsForDevicesAction.handle(mapOf(senderId to listOf(deviceInfo)), force = true) }.fold( - onSuccess = { sendDummyToDevice(ensured = it, deviceInfo, senderId) }, - onFailure = { - Timber.e("## CRYPTO | markOlmSessionForUnwedging() : failed to ensure device info ${senderId}${deviceInfo.deviceId}") + toUnwedge + .chunked(100) // safer to chunk if we ever have lots of wedged devices + .forEach { wedgedList -> + val groupedByUserId = wedgedList.groupBy { it.userId } + // lets download keys if needed + withContext(coroutineDispatchers.io) { + deviceListManager.downloadKeys(groupedByUserId.keys.toList(), false) } - ) - } - } - private suspend fun sendDummyToDevice(ensured: MXUsersDevicesMap, deviceInfo: CryptoDeviceInfo, senderId: String) { - Timber.i("## CRYPTO | markOlmSessionForUnwedging() : ensureOlmSessionsForDevicesAction isEmpty:${ensured.isEmpty}") + // find the matching devices + groupedByUserId + .map { groupedByUser -> + val userId = groupedByUser.key + val wedgeSenderKeysForUser = groupedByUser.value.map { it.senderKey } + val knownDevices = cryptoStore.getUserDevices(userId)?.values.orEmpty() + userId to wedgeSenderKeysForUser.mapNotNull { senderKey -> + knownDevices.firstOrNull { it.identityKey() == senderKey } + } + } + .toMap() + .let { deviceList -> + try { + // force creating new outbound session and mark them as most recent to + // be used for next encryption (dummy) + val sessionToUse = ensureOlmSessionsForDevicesAction.handle(deviceList, true) + Timber.tag(loggerTag.value).d("Unwedging, found ${sessionToUse.map.size} to send dummy to") - // Now send a blank message on that session so the other side knows about it. - // (The keyshare request is sent in the clear so that won't do) - // We send this first such that, as long as the toDevice messages arrive in the - // same order we sent them, the other end will get this first, set up the new session, - // then get the keyshare request and send the key over this new session (because it - // is the session it has most recently received a message on). - val payloadJson = mapOf("type" to EventType.DUMMY) + // Now send a dummy message on that session so the other side knows about it. + val payloadJson = mapOf( + "type" to EventType.DUMMY + ) + val sendToDeviceMap = MXUsersDevicesMap() + sessionToUse.map.values + .flatMap { it.values } + .map { it.deviceInfo } + .forEach { deviceInfo -> + Timber.tag(loggerTag.value).v("encrypting dummy to ${deviceInfo.deviceId}") + val encodedPayload = messageEncrypter.encryptMessage(payloadJson, listOf(deviceInfo)) + sendToDeviceMap.setObject(deviceInfo.userId, deviceInfo.deviceId, encodedPayload) + } - val encodedPayload = messageEncrypter.encryptMessage(payloadJson, listOf(deviceInfo)) - val sendToDeviceMap = MXUsersDevicesMap() - sendToDeviceMap.setObject(senderId, deviceInfo.deviceId, encodedPayload) - Timber.i("## CRYPTO | markOlmSessionForUnwedging() : sending dummy to $senderId:${deviceInfo.deviceId}") - withContext(coroutineDispatchers.io) { - val sendToDeviceParams = SendToDeviceTask.Params(EventType.ENCRYPTED, sendToDeviceMap) - try { - sendToDeviceTask.executeRetry(sendToDeviceParams, remainingRetry = SEND_TO_DEVICE_RETRY_COUNT) - } catch (failure: Throwable) { - Timber.e(failure, "## CRYPTO | markOlmSessionForUnwedging() : failed to send dummy to $senderId:${deviceInfo.deviceId}") - } - } + // now let's send that + val sendToDeviceParams = SendToDeviceTask.Params(EventType.ENCRYPTED, sendToDeviceMap) + withContext(coroutineDispatchers.io) { + sendToDeviceTask.executeRetry(sendToDeviceParams, remainingRetry = SEND_TO_DEVICE_RETRY_COUNT) + } + } catch (failure: Throwable) { + deviceList.flatMap { it.value }.joinToString { it.shortDebugString() }.let { + Timber.tag(loggerTag.value).e(failure, "## Failed to unwedge devices: $it}") + } + } + } + } } } diff --git a/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/InboundGroupSessionStore.kt b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/InboundGroupSessionStore.kt index e7a46750b0..34bef61c98 100644 --- a/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/InboundGroupSessionStore.kt +++ b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/InboundGroupSessionStore.kt @@ -19,8 +19,10 @@ package org.matrix.android.sdk.internal.crypto import android.util.LruCache import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.launch +import kotlinx.coroutines.sync.Mutex import org.matrix.android.sdk.api.MatrixCoroutineDispatchers import org.matrix.android.sdk.api.extensions.tryOrNull +import org.matrix.android.sdk.api.logger.LoggerTag import org.matrix.android.sdk.internal.crypto.model.OlmInboundGroupSessionWrapper2 import org.matrix.android.sdk.internal.crypto.store.IMXCryptoStore import timber.log.Timber @@ -28,6 +30,13 @@ import java.util.Timer import java.util.TimerTask import javax.inject.Inject +data class InboundGroupSessionHolder( + val wrapper: OlmInboundGroupSessionWrapper2, + val mutex: Mutex = Mutex() +) + +private val loggerTag = LoggerTag("InboundGroupSessionStore", LoggerTag.CRYPTO) + /** * Allows to cache and batch store operations on inbound group session store. * Because it is used in the decrypt flow, that can be called quite rapidly @@ -42,12 +51,13 @@ internal class InboundGroupSessionStore @Inject constructor( val senderKey: String ) - private val sessionCache = object : LruCache(30) { - override fun entryRemoved(evicted: Boolean, key: CacheKey?, oldValue: OlmInboundGroupSessionWrapper2?, newValue: OlmInboundGroupSessionWrapper2?) { - if (evicted && oldValue != null) { + private val sessionCache = object : LruCache(100) { + override fun entryRemoved(evicted: Boolean, key: CacheKey?, oldValue: InboundGroupSessionHolder?, newValue: InboundGroupSessionHolder?) { + if (oldValue != null) { cryptoCoroutineScope.launch(coroutineDispatchers.crypto) { - Timber.v("## Inbound: entryRemoved ${oldValue.roomId}-${oldValue.senderKey}") - store.storeInboundGroupSessions(listOf(oldValue)) + Timber.tag(loggerTag.value).v("## Inbound: entryRemoved ${oldValue.wrapper.roomId}-${oldValue.wrapper.senderKey}") + store.storeInboundGroupSessions(listOf(oldValue).map { it.wrapper }) + oldValue.wrapper.olmInboundGroupSession?.releaseSession() } } } @@ -59,27 +69,50 @@ internal class InboundGroupSessionStore @Inject constructor( private val dirtySession = mutableListOf() @Synchronized - fun getInboundGroupSession(sessionId: String, senderKey: String): OlmInboundGroupSessionWrapper2? { - synchronized(sessionCache) { - val known = sessionCache[CacheKey(sessionId, senderKey)] - Timber.v("## Inbound: getInboundGroupSession in cache ${known != null}") - return known ?: store.getInboundGroupSession(sessionId, senderKey)?.also { - Timber.v("## Inbound: getInboundGroupSession cache populate ${it.roomId}") - sessionCache.put(CacheKey(sessionId, senderKey), it) - } - } + fun clear() { + sessionCache.evictAll() } @Synchronized - fun storeInBoundGroupSession(wrapper: OlmInboundGroupSessionWrapper2, sessionId: String, senderKey: String) { - Timber.v("## Inbound: getInboundGroupSession mark as dirty ${wrapper.roomId}-${wrapper.senderKey}") + fun getInboundGroupSession(sessionId: String, senderKey: String): InboundGroupSessionHolder? { + val known = sessionCache[CacheKey(sessionId, senderKey)] + Timber.tag(loggerTag.value).v("## Inbound: getInboundGroupSession $sessionId in cache ${known != null}") + return known + ?: store.getInboundGroupSession(sessionId, senderKey)?.also { + Timber.tag(loggerTag.value).v("## Inbound: getInboundGroupSession cache populate ${it.roomId}") + sessionCache.put(CacheKey(sessionId, senderKey), InboundGroupSessionHolder(it)) + }?.let { + InboundGroupSessionHolder(it) + } + } + + @Synchronized + fun replaceGroupSession(old: InboundGroupSessionHolder, new: InboundGroupSessionHolder, sessionId: String, senderKey: String) { + Timber.tag(loggerTag.value).v("## Replacing outdated session ${old.wrapper.roomId}-${old.wrapper.senderKey}") + dirtySession.remove(old.wrapper) + store.removeInboundGroupSession(sessionId, senderKey) + sessionCache.remove(CacheKey(sessionId, senderKey)) + + // release removed session + old.wrapper.olmInboundGroupSession?.releaseSession() + + internalStoreGroupSession(new, sessionId, senderKey) + } + + @Synchronized + fun storeInBoundGroupSession(holder: InboundGroupSessionHolder, sessionId: String, senderKey: String) { + internalStoreGroupSession(holder, sessionId, senderKey) + } + + private fun internalStoreGroupSession(holder: InboundGroupSessionHolder, sessionId: String, senderKey: String) { + Timber.tag(loggerTag.value).v("## Inbound: getInboundGroupSession mark as dirty ${holder.wrapper.roomId}-${holder.wrapper.senderKey}") // We want to batch this a bit for performances - dirtySession.add(wrapper) + dirtySession.add(holder.wrapper) if (sessionCache[CacheKey(sessionId, senderKey)] == null) { // first time seen, put it in memory cache while waiting for batch insert // If it's already known, no need to update cache it's already there - sessionCache.put(CacheKey(sessionId, senderKey), wrapper) + sessionCache.put(CacheKey(sessionId, senderKey), holder) } timerTask?.cancel() @@ -96,7 +129,7 @@ internal class InboundGroupSessionStore @Inject constructor( val toSave = mutableListOf().apply { addAll(dirtySession) } dirtySession.clear() cryptoCoroutineScope.launch(coroutineDispatchers.crypto) { - Timber.v("## Inbound: getInboundGroupSession batching save of ${dirtySession.size}") + Timber.tag(loggerTag.value).v("## Inbound: getInboundGroupSession batching save of ${toSave.size}") tryOrNull { store.storeInboundGroupSessions(toSave) } diff --git a/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/MXOlmDevice.kt b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/MXOlmDevice.kt index e1a706df79..501fb42db2 100755 --- a/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/MXOlmDevice.kt +++ b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/MXOlmDevice.kt @@ -16,6 +16,11 @@ package org.matrix.android.sdk.internal.crypto +import androidx.annotation.VisibleForTesting +import kotlinx.coroutines.sync.Mutex +import kotlinx.coroutines.sync.withLock +import org.matrix.android.sdk.api.extensions.tryOrNull +import org.matrix.android.sdk.api.logger.LoggerTag import org.matrix.android.sdk.api.session.crypto.MXCryptoError import org.matrix.android.sdk.api.util.JSON_DICT_PARAMETERIZED_TYPE import org.matrix.android.sdk.api.util.JsonDict @@ -40,6 +45,8 @@ import timber.log.Timber import java.net.URLEncoder import javax.inject.Inject +private val loggerTag = LoggerTag("MXOlmDevice", LoggerTag.CRYPTO) + // The libolm wrapper. @SessionScope internal class MXOlmDevice @Inject constructor( @@ -47,9 +54,12 @@ internal class MXOlmDevice @Inject constructor( * The store where crypto data is saved. */ private val store: IMXCryptoStore, + private val olmSessionStore: OlmSessionStore, private val inboundGroupSessionStore: InboundGroupSessionStore ) { + val mutex = Mutex() + /** * @return the Curve25519 key for the account. */ @@ -93,26 +103,26 @@ internal class MXOlmDevice @Inject constructor( try { store.getOrCreateOlmAccount() } catch (e: Exception) { - Timber.e(e, "MXOlmDevice : cannot initialize olmAccount") + Timber.tag(loggerTag.value).e(e, "MXOlmDevice : cannot initialize olmAccount") } try { olmUtility = OlmUtility() } catch (e: Exception) { - Timber.e(e, "## MXOlmDevice : OlmUtility failed with error") + Timber.tag(loggerTag.value).e(e, "## MXOlmDevice : OlmUtility failed with error") olmUtility = null } try { - deviceCurve25519Key = store.getOlmAccount().identityKeys()[OlmAccount.JSON_KEY_IDENTITY_KEY] + deviceCurve25519Key = store.doWithOlmAccount { it.identityKeys()[OlmAccount.JSON_KEY_IDENTITY_KEY] } } catch (e: Exception) { - Timber.e(e, "## MXOlmDevice : cannot find ${OlmAccount.JSON_KEY_IDENTITY_KEY} with error") + Timber.tag(loggerTag.value).e(e, "## MXOlmDevice : cannot find ${OlmAccount.JSON_KEY_IDENTITY_KEY} with error") } try { - deviceEd25519Key = store.getOlmAccount().identityKeys()[OlmAccount.JSON_KEY_FINGER_PRINT_KEY] + deviceEd25519Key = store.doWithOlmAccount { it.identityKeys()[OlmAccount.JSON_KEY_FINGER_PRINT_KEY] } } catch (e: Exception) { - Timber.e(e, "## MXOlmDevice : cannot find ${OlmAccount.JSON_KEY_FINGER_PRINT_KEY} with error") + Timber.tag(loggerTag.value).e(e, "## MXOlmDevice : cannot find ${OlmAccount.JSON_KEY_FINGER_PRINT_KEY} with error") } } @@ -121,9 +131,9 @@ internal class MXOlmDevice @Inject constructor( */ fun getOneTimeKeys(): Map>? { try { - return store.getOlmAccount().oneTimeKeys() + return store.doWithOlmAccount { it.oneTimeKeys() } } catch (e: Exception) { - Timber.e(e, "## getOneTimeKeys() : failed") + Timber.tag(loggerTag.value).e(e, "## getOneTimeKeys() : failed") } return null @@ -133,7 +143,7 @@ internal class MXOlmDevice @Inject constructor( * @return The maximum number of one-time keys the olm account can store. */ fun getMaxNumberOfOneTimeKeys(): Long { - return store.getOlmAccount().maxOneTimeKeys() + return store.doWithOlmAccount { it.maxOneTimeKeys() } } /** @@ -143,9 +153,9 @@ internal class MXOlmDevice @Inject constructor( */ fun getFallbackKey(): MutableMap>? { try { - return store.getOlmAccount().fallbackKey() + return store.doWithOlmAccount { it.fallbackKey() } } catch (e: Exception) { - Timber.e("## getFallbackKey() : failed") + Timber.tag(loggerTag.value).e("## getFallbackKey() : failed") } return null } @@ -158,12 +168,14 @@ internal class MXOlmDevice @Inject constructor( fun generateFallbackKeyIfNeeded(): Boolean { try { if (!hasUnpublishedFallbackKey()) { - store.getOlmAccount().generateFallbackKey() - store.saveOlmAccount() + store.doWithOlmAccount { + it.generateFallbackKey() + store.saveOlmAccount() + } return true } } catch (e: Exception) { - Timber.e("## generateFallbackKey() : failed") + Timber.tag(loggerTag.value).e("## generateFallbackKey() : failed") } return false } @@ -174,10 +186,12 @@ internal class MXOlmDevice @Inject constructor( fun forgetFallbackKey() { try { - store.getOlmAccount().forgetFallbackKey() - store.saveOlmAccount() + store.doWithOlmAccount { + it.forgetFallbackKey() + store.saveOlmAccount() + } } catch (e: Exception) { - Timber.e("## forgetFallbackKey() : failed") + Timber.tag(loggerTag.value).e("## forgetFallbackKey() : failed") } } @@ -190,6 +204,8 @@ internal class MXOlmDevice @Inject constructor( it.groupSession.releaseSession() } outboundGroupSessionCache.clear() + inboundGroupSessionStore.clear() + olmSessionStore.clear() } /** @@ -200,9 +216,9 @@ internal class MXOlmDevice @Inject constructor( */ fun signMessage(message: String): String? { try { - return store.getOlmAccount().signMessage(message) + return store.doWithOlmAccount { it.signMessage(message) } } catch (e: Exception) { - Timber.e(e, "## signMessage() : failed") + Timber.tag(loggerTag.value).e(e, "## signMessage() : failed") } return null @@ -213,10 +229,12 @@ internal class MXOlmDevice @Inject constructor( */ fun markKeysAsPublished() { try { - store.getOlmAccount().markOneTimeKeysAsPublished() - store.saveOlmAccount() + store.doWithOlmAccount { + it.markOneTimeKeysAsPublished() + store.saveOlmAccount() + } } catch (e: Exception) { - Timber.e(e, "## markKeysAsPublished() : failed") + Timber.tag(loggerTag.value).e(e, "## markKeysAsPublished() : failed") } } @@ -227,10 +245,12 @@ internal class MXOlmDevice @Inject constructor( */ fun generateOneTimeKeys(numKeys: Int) { try { - store.getOlmAccount().generateOneTimeKeys(numKeys) - store.saveOlmAccount() + store.doWithOlmAccount { + it.generateOneTimeKeys(numKeys) + store.saveOlmAccount() + } } catch (e: Exception) { - Timber.e(e, "## generateOneTimeKeys() : failed") + Timber.tag(loggerTag.value).e(e, "## generateOneTimeKeys() : failed") } } @@ -243,12 +263,14 @@ internal class MXOlmDevice @Inject constructor( * @return the session id for the outbound session. */ fun createOutboundSession(theirIdentityKey: String, theirOneTimeKey: String): String? { - Timber.v("## createOutboundSession() ; theirIdentityKey $theirIdentityKey theirOneTimeKey $theirOneTimeKey") + Timber.tag(loggerTag.value).d("## createOutboundSession() ; theirIdentityKey $theirIdentityKey theirOneTimeKey $theirOneTimeKey") var olmSession: OlmSession? = null try { olmSession = OlmSession() - olmSession.initOutboundSession(store.getOlmAccount(), theirIdentityKey, theirOneTimeKey) + store.doWithOlmAccount { olmAccount -> + olmSession.initOutboundSession(olmAccount, theirIdentityKey, theirOneTimeKey) + } val olmSessionWrapper = OlmSessionWrapper(olmSession, 0) @@ -257,14 +279,14 @@ internal class MXOlmDevice @Inject constructor( // this session olmSessionWrapper.onMessageReceived() - store.storeSession(olmSessionWrapper, theirIdentityKey) + olmSessionStore.storeSession(olmSessionWrapper, theirIdentityKey) val sessionIdentifier = olmSession.sessionIdentifier() - Timber.v("## createOutboundSession() ; olmSession.sessionIdentifier: $sessionIdentifier") + Timber.tag(loggerTag.value).v("## createOutboundSession() ; olmSession.sessionIdentifier: $sessionIdentifier") return sessionIdentifier } catch (e: Exception) { - Timber.e(e, "## createOutboundSession() failed") + Timber.tag(loggerTag.value).e(e, "## createOutboundSession() failed") olmSession?.releaseSession() } @@ -281,34 +303,38 @@ internal class MXOlmDevice @Inject constructor( * @return {{payload: string, session_id: string}} decrypted payload, and session id of new session. */ fun createInboundSession(theirDeviceIdentityKey: String, messageType: Int, ciphertext: String): Map? { - Timber.v("## createInboundSession() : theirIdentityKey: $theirDeviceIdentityKey") + Timber.tag(loggerTag.value).d("## createInboundSession() : theirIdentityKey: $theirDeviceIdentityKey") var olmSession: OlmSession? = null try { try { olmSession = OlmSession() - olmSession.initInboundSessionFrom(store.getOlmAccount(), theirDeviceIdentityKey, ciphertext) + store.doWithOlmAccount { olmAccount -> + olmSession.initInboundSessionFrom(olmAccount, theirDeviceIdentityKey, ciphertext) + } } catch (e: Exception) { - Timber.e(e, "## createInboundSession() : the session creation failed") + Timber.tag(loggerTag.value).e(e, "## createInboundSession() : the session creation failed") return null } - Timber.v("## createInboundSession() : sessionId: ${olmSession.sessionIdentifier()}") + Timber.tag(loggerTag.value).v("## createInboundSession() : sessionId: ${olmSession.sessionIdentifier()}") try { - store.getOlmAccount().removeOneTimeKeys(olmSession) - store.saveOlmAccount() + store.doWithOlmAccount { olmAccount -> + olmAccount.removeOneTimeKeys(olmSession) + store.saveOlmAccount() + } } catch (e: Exception) { - Timber.e(e, "## createInboundSession() : removeOneTimeKeys failed") + Timber.tag(loggerTag.value).e(e, "## createInboundSession() : removeOneTimeKeys failed") } - Timber.v("## createInboundSession() : ciphertext: $ciphertext") + Timber.tag(loggerTag.value).v("## createInboundSession() : ciphertext: $ciphertext") try { val sha256 = olmUtility!!.sha256(URLEncoder.encode(ciphertext, "utf-8")) - Timber.v("## createInboundSession() :ciphertext: SHA256: $sha256") + Timber.tag(loggerTag.value).v("## createInboundSession() :ciphertext: SHA256: $sha256") } catch (e: Exception) { - Timber.e(e, "## createInboundSession() :ciphertext: cannot encode ciphertext") + Timber.tag(loggerTag.value).e(e, "## createInboundSession() :ciphertext: cannot encode ciphertext") } val olmMessage = OlmMessage() @@ -324,9 +350,9 @@ internal class MXOlmDevice @Inject constructor( // This counts as a received message: set last received message time to now olmSessionWrapper.onMessageReceived() - store.storeSession(olmSessionWrapper, theirDeviceIdentityKey) + olmSessionStore.storeSession(olmSessionWrapper, theirDeviceIdentityKey) } catch (e: Exception) { - Timber.e(e, "## createInboundSession() : decryptMessage failed") + Timber.tag(loggerTag.value).e(e, "## createInboundSession() : decryptMessage failed") } val res = HashMap() @@ -343,7 +369,7 @@ internal class MXOlmDevice @Inject constructor( return res } catch (e: Exception) { - Timber.e(e, "## createInboundSession() : OlmSession creation failed") + Timber.tag(loggerTag.value).e(e, "## createInboundSession() : OlmSession creation failed") olmSession?.releaseSession() } @@ -357,8 +383,8 @@ internal class MXOlmDevice @Inject constructor( * @param theirDeviceIdentityKey the Curve25519 identity key for the remote device. * @return a list of known session ids for the device. */ - fun getSessionIds(theirDeviceIdentityKey: String): List? { - return store.getDeviceSessionIds(theirDeviceIdentityKey) + fun getSessionIds(theirDeviceIdentityKey: String): List { + return olmSessionStore.getDeviceSessionIds(theirDeviceIdentityKey) } /** @@ -368,7 +394,7 @@ internal class MXOlmDevice @Inject constructor( * @return the session id, or null if no established session. */ fun getSessionId(theirDeviceIdentityKey: String): String? { - return store.getLastUsedSessionId(theirDeviceIdentityKey) + return olmSessionStore.getLastUsedSessionId(theirDeviceIdentityKey) } /** @@ -379,30 +405,30 @@ internal class MXOlmDevice @Inject constructor( * @param payloadString the payload to be encrypted and sent * @return the cipher text */ - fun encryptMessage(theirDeviceIdentityKey: String, sessionId: String, payloadString: String): Map? { - var res: MutableMap? = null - val olmMessage: OlmMessage + suspend fun encryptMessage(theirDeviceIdentityKey: String, sessionId: String, payloadString: String): Map? { val olmSessionWrapper = getSessionForDevice(theirDeviceIdentityKey, sessionId) if (olmSessionWrapper != null) { try { - Timber.v("## encryptMessage() : olmSession.sessionIdentifier: $sessionId") - // Timber.v("## encryptMessage() : payloadString: " + payloadString); + Timber.tag(loggerTag.value).v("## encryptMessage() : olmSession.sessionIdentifier: $sessionId") - olmMessage = olmSessionWrapper.olmSession.encryptMessage(payloadString) - store.storeSession(olmSessionWrapper, theirDeviceIdentityKey) - res = HashMap() - - res["body"] = olmMessage.mCipherText - res["type"] = olmMessage.mType - } catch (e: Exception) { - Timber.e(e, "## encryptMessage() : failed") + val olmMessage = olmSessionWrapper.mutex.withLock { + olmSessionWrapper.olmSession.encryptMessage(payloadString) + } + return mapOf( + "body" to olmMessage.mCipherText, + "type" to olmMessage.mType, + ).also { + olmSessionStore.storeSession(olmSessionWrapper, theirDeviceIdentityKey) + } + } catch (e: Throwable) { + Timber.tag(loggerTag.value).e(e, "## encryptMessage() : failed to encrypt olm with device|session:$theirDeviceIdentityKey|$sessionId") + return null } } else { - Timber.e("## encryptMessage() : Failed to encrypt unknown session $sessionId") + Timber.tag(loggerTag.value).e("## encryptMessage() : Failed to encrypt unknown session $sessionId") + return null } - - return res } /** @@ -414,7 +440,8 @@ internal class MXOlmDevice @Inject constructor( * @param sessionId the id of the active session. * @return the decrypted payload. */ - fun decryptMessage(ciphertext: String, messageType: Int, sessionId: String, theirDeviceIdentityKey: String): String? { + @kotlin.jvm.Throws + suspend fun decryptMessage(ciphertext: String, messageType: Int, sessionId: String, theirDeviceIdentityKey: String): String? { var payloadString: String? = null val olmSessionWrapper = getSessionForDevice(theirDeviceIdentityKey, sessionId) @@ -424,13 +451,13 @@ internal class MXOlmDevice @Inject constructor( olmMessage.mCipherText = ciphertext olmMessage.mType = messageType.toLong() - try { - payloadString = olmSessionWrapper.olmSession.decryptMessage(olmMessage) - olmSessionWrapper.onMessageReceived() - store.storeSession(olmSessionWrapper, theirDeviceIdentityKey) - } catch (e: Exception) { - Timber.e(e, "## decryptMessage() : decryptMessage failed") - } + payloadString = + olmSessionWrapper.mutex.withLock { + olmSessionWrapper.olmSession.decryptMessage(olmMessage).also { + olmSessionWrapper.onMessageReceived() + } + } + olmSessionStore.storeSession(olmSessionWrapper, theirDeviceIdentityKey) } return payloadString @@ -469,7 +496,7 @@ internal class MXOlmDevice @Inject constructor( store.storeCurrentOutboundGroupSessionForRoom(roomId, session) return session.sessionIdentifier() } catch (e: Exception) { - Timber.e(e, "createOutboundGroupSession") + Timber.tag(loggerTag.value).e(e, "createOutboundGroupSession") session?.releaseSession() } @@ -521,7 +548,7 @@ internal class MXOlmDevice @Inject constructor( try { return outboundGroupSessionCache[sessionId]!!.groupSession.sessionKey() } catch (e: Exception) { - Timber.e(e, "## getSessionKey() : failed") + Timber.tag(loggerTag.value).e(e, "## getSessionKey() : failed") } } return null @@ -550,8 +577,8 @@ internal class MXOlmDevice @Inject constructor( if (sessionId.isNotEmpty() && payloadString.isNotEmpty()) { try { return outboundGroupSessionCache[sessionId]!!.groupSession.encryptMessage(payloadString) - } catch (e: Exception) { - Timber.e(e, "## encryptGroupMessage() : failed") + } catch (e: Throwable) { + Timber.tag(loggerTag.value).e(e, "## encryptGroupMessage() : failed") } } return null @@ -578,52 +605,64 @@ internal class MXOlmDevice @Inject constructor( forwardingCurve25519KeyChain: List, keysClaimed: Map, exportFormat: Boolean): Boolean { - val session = OlmInboundGroupSessionWrapper2(sessionKey, exportFormat) - runCatching { getInboundGroupSession(sessionId, senderKey, roomId) } - .fold( - { - // If we already have this session, consider updating it - Timber.e("## addInboundGroupSession() : Update for megolm session $senderKey/$sessionId") + val candidateSession = OlmInboundGroupSessionWrapper2(sessionKey, exportFormat) + val existingSessionHolder = tryOrNull { getInboundGroupSession(sessionId, senderKey, roomId) } + val existingSession = existingSessionHolder?.wrapper + // If we have an existing one we should check if the new one is not better + if (existingSession != null) { + Timber.tag(loggerTag.value).d("## addInboundGroupSession() check if known session is better than candidate session") + try { + val existingFirstKnown = existingSession.firstKnownIndex ?: return false.also { + // This is quite unexpected, could throw if native was released? + Timber.tag(loggerTag.value).e("## addInboundGroupSession() null firstKnownIndex on existing session") + candidateSession.olmInboundGroupSession?.releaseSession() + // Probably should discard it? + } + val newKnownFirstIndex = candidateSession.firstKnownIndex + // If our existing session is better we keep it + if (newKnownFirstIndex != null && existingFirstKnown <= newKnownFirstIndex) { + Timber.tag(loggerTag.value).d("## addInboundGroupSession() : ignore session our is better $senderKey/$sessionId") + candidateSession.olmInboundGroupSession?.releaseSession() + return false + } + } catch (failure: Throwable) { + Timber.tag(loggerTag.value).e("## addInboundGroupSession() Failed to add inbound: ${failure.localizedMessage}") + candidateSession.olmInboundGroupSession?.releaseSession() + return false + } + } - val existingFirstKnown = it.firstKnownIndex!! - val newKnownFirstIndex = session.firstKnownIndex + Timber.tag(loggerTag.value).d("## addInboundGroupSession() : Candidate session should be added $senderKey/$sessionId") - // If our existing session is better we keep it - if (newKnownFirstIndex != null && existingFirstKnown <= newKnownFirstIndex) { - session.olmInboundGroupSession?.releaseSession() - return false - } - }, - { - // Nothing to do in case of error - } - ) - - // sanity check - if (null == session.olmInboundGroupSession) { - Timber.e("## addInboundGroupSession : invalid session") + // sanity check on the new session + val candidateOlmInboundSession = candidateSession.olmInboundGroupSession + if (null == candidateOlmInboundSession) { + Timber.tag(loggerTag.value).e("## addInboundGroupSession : invalid session ") return false } try { - if (session.olmInboundGroupSession!!.sessionIdentifier() != sessionId) { - Timber.e("## addInboundGroupSession : ERROR: Mismatched group session ID from senderKey: $senderKey") - session.olmInboundGroupSession!!.releaseSession() + if (candidateOlmInboundSession.sessionIdentifier() != sessionId) { + Timber.tag(loggerTag.value).e("## addInboundGroupSession : ERROR: Mismatched group session ID from senderKey: $senderKey") + candidateOlmInboundSession.releaseSession() return false } - } catch (e: Exception) { - session.olmInboundGroupSession?.releaseSession() - Timber.e(e, "## addInboundGroupSession : sessionIdentifier() failed") + } catch (e: Throwable) { + candidateOlmInboundSession.releaseSession() + Timber.tag(loggerTag.value).e(e, "## addInboundGroupSession : sessionIdentifier() failed") return false } - session.senderKey = senderKey - session.roomId = roomId - session.keysClaimed = keysClaimed - session.forwardingCurve25519KeyChain = forwardingCurve25519KeyChain + candidateSession.senderKey = senderKey + candidateSession.roomId = roomId + candidateSession.keysClaimed = keysClaimed + candidateSession.forwardingCurve25519KeyChain = forwardingCurve25519KeyChain - inboundGroupSessionStore.storeInBoundGroupSession(session, sessionId, senderKey) -// store.storeInboundGroupSessions(listOf(session)) + if (existingSession != null) { + inboundGroupSessionStore.replaceGroupSession(existingSessionHolder, InboundGroupSessionHolder(candidateSession), sessionId, senderKey) + } else { + inboundGroupSessionStore.storeInBoundGroupSession(InboundGroupSessionHolder(candidateSession), sessionId, senderKey) + } return true } @@ -638,57 +677,70 @@ internal class MXOlmDevice @Inject constructor( val sessions = ArrayList(megolmSessionsData.size) for (megolmSessionData in megolmSessionsData) { - val sessionId = megolmSessionData.sessionId - val senderKey = megolmSessionData.senderKey + val sessionId = megolmSessionData.sessionId ?: continue + val senderKey = megolmSessionData.senderKey ?: continue val roomId = megolmSessionData.roomId - var session: OlmInboundGroupSessionWrapper2? = null + var candidateSessionToImport: OlmInboundGroupSessionWrapper2? = null try { - session = OlmInboundGroupSessionWrapper2(megolmSessionData) + candidateSessionToImport = OlmInboundGroupSessionWrapper2(megolmSessionData) } catch (e: Exception) { - Timber.e(e, "## importInboundGroupSession() : Update for megolm session $senderKey/$sessionId") + Timber.tag(loggerTag.value).e(e, "## importInboundGroupSession() : Update for megolm session $senderKey/$sessionId") } // sanity check - if (session?.olmInboundGroupSession == null) { - Timber.e("## importInboundGroupSession : invalid session") + if (candidateSessionToImport?.olmInboundGroupSession == null) { + Timber.tag(loggerTag.value).e("## importInboundGroupSession : invalid session") continue } + val candidateOlmInboundGroupSession = candidateSessionToImport.olmInboundGroupSession try { - if (session.olmInboundGroupSession?.sessionIdentifier() != sessionId) { - Timber.e("## importInboundGroupSession : ERROR: Mismatched group session ID from senderKey: $senderKey") - if (session.olmInboundGroupSession != null) session.olmInboundGroupSession!!.releaseSession() + if (candidateOlmInboundGroupSession?.sessionIdentifier() != sessionId) { + Timber.tag(loggerTag.value).e("## importInboundGroupSession : ERROR: Mismatched group session ID from senderKey: $senderKey") + candidateOlmInboundGroupSession?.releaseSession() continue } } catch (e: Exception) { - Timber.e(e, "## importInboundGroupSession : sessionIdentifier() failed") - session.olmInboundGroupSession!!.releaseSession() + Timber.tag(loggerTag.value).e(e, "## importInboundGroupSession : sessionIdentifier() failed") + candidateOlmInboundGroupSession?.releaseSession() continue } - runCatching { getInboundGroupSession(sessionId, senderKey, roomId) } - .fold( - { - // If we already have this session, consider updating it - Timber.e("## importInboundGroupSession() : Update for megolm session $senderKey/$sessionId") + val existingSessionHolder = tryOrNull { getInboundGroupSession(sessionId, senderKey, roomId) } + val existingSession = existingSessionHolder?.wrapper - // For now we just ignore updates. TODO: implement something here - if (it.firstKnownIndex!! <= session.firstKnownIndex!!) { - // Ignore this, keep existing - session.olmInboundGroupSession!!.releaseSession() - } else { - sessions.add(session) - } - Unit - }, - { - // Session does not already exist, add it - sessions.add(session) - } + if (existingSession == null) { + // Session does not already exist, add it + Timber.tag(loggerTag.value).d("## importInboundGroupSession() : importing new megolm session $senderKey/$sessionId") + sessions.add(candidateSessionToImport) + } else { + Timber.tag(loggerTag.value).e("## importInboundGroupSession() : Update for megolm session $senderKey/$sessionId") + val existingFirstKnown = tryOrNull { existingSession.firstKnownIndex } + val candidateFirstKnownIndex = tryOrNull { candidateSessionToImport.firstKnownIndex } - ) + if (existingFirstKnown == null || candidateFirstKnownIndex == null) { + // should not happen? + candidateSessionToImport.olmInboundGroupSession?.releaseSession() + Timber.tag(loggerTag.value) + .w("## importInboundGroupSession() : Can't check session null index $existingFirstKnown/$candidateFirstKnownIndex") + } else { + if (existingFirstKnown <= candidateSessionToImport.firstKnownIndex!!) { + // Ignore this, keep existing + candidateOlmInboundGroupSession.releaseSession() + } else { + // update cache with better session + inboundGroupSessionStore.replaceGroupSession( + existingSessionHolder, + InboundGroupSessionHolder(candidateSessionToImport), + sessionId, + senderKey + ) + sessions.add(candidateSessionToImport) + } + } + } } store.storeInboundGroupSessions(sessions) @@ -696,18 +748,6 @@ internal class MXOlmDevice @Inject constructor( return sessions } - /** - * Remove an inbound group session - * - * @param sessionId the session identifier. - * @param sessionKey base64-encoded secret key. - */ - fun removeInboundGroupSession(sessionId: String?, sessionKey: String?) { - if (null != sessionId && null != sessionKey) { - store.removeInboundGroupSession(sessionId, sessionKey) - } - } - /** * Decrypt a received message with an inbound group session. * @@ -719,19 +759,24 @@ internal class MXOlmDevice @Inject constructor( * @return the decrypting result. Nil if the sessionId is unknown. */ @Throws(MXCryptoError::class) - fun decryptGroupMessage(body: String, - roomId: String, - timeline: String?, - sessionId: String, - senderKey: String): OlmDecryptionResult { - val session = getInboundGroupSession(sessionId, senderKey, roomId) + suspend fun decryptGroupMessage(body: String, + roomId: String, + timeline: String?, + sessionId: String, + senderKey: String): OlmDecryptionResult { + val sessionHolder = getInboundGroupSession(sessionId, senderKey, roomId) + val wrapper = sessionHolder.wrapper + val inboundGroupSession = wrapper.olmInboundGroupSession + ?: throw MXCryptoError.Base(MXCryptoError.ErrorType.UNABLE_TO_DECRYPT, "Session is null") // Check that the room id matches the original one for the session. This stops // the HS pretending a message was targeting a different room. - if (roomId == session.roomId) { + if (roomId == wrapper.roomId) { val decryptResult = try { - session.olmInboundGroupSession!!.decryptMessage(body) + sessionHolder.mutex.withLock { + inboundGroupSession.decryptMessage(body) + } } catch (e: OlmException) { - Timber.e(e, "## decryptGroupMessage () : decryptMessage failed") + Timber.tag(loggerTag.value).e(e, "## decryptGroupMessage () : decryptMessage failed") throw MXCryptoError.OlmError(e) } @@ -742,32 +787,32 @@ internal class MXOlmDevice @Inject constructor( if (timelineSet.contains(messageIndexKey)) { val reason = String.format(MXCryptoError.DUPLICATE_MESSAGE_INDEX_REASON, decryptResult.mIndex) - Timber.e("## decryptGroupMessage() : $reason") + Timber.tag(loggerTag.value).e("## decryptGroupMessage() : $reason") throw MXCryptoError.Base(MXCryptoError.ErrorType.DUPLICATED_MESSAGE_INDEX, reason) } timelineSet.add(messageIndexKey) } - inboundGroupSessionStore.storeInBoundGroupSession(session, sessionId, senderKey) + inboundGroupSessionStore.storeInBoundGroupSession(sessionHolder, sessionId, senderKey) val payload = try { val adapter = MoshiProvider.providesMoshi().adapter(JSON_DICT_PARAMETERIZED_TYPE) val payloadString = convertFromUTF8(decryptResult.mDecryptedMessage) adapter.fromJson(payloadString) } catch (e: Exception) { - Timber.e("## decryptGroupMessage() : fails to parse the payload") + Timber.tag(loggerTag.value).e("## decryptGroupMessage() : fails to parse the payload") throw MXCryptoError.Base(MXCryptoError.ErrorType.BAD_DECRYPTED_FORMAT, MXCryptoError.BAD_DECRYPTED_FORMAT_TEXT_REASON) } return OlmDecryptionResult( payload, - session.keysClaimed, + wrapper.keysClaimed, senderKey, - session.forwardingCurve25519KeyChain + wrapper.forwardingCurve25519KeyChain ) } else { - val reason = String.format(MXCryptoError.INBOUND_SESSION_MISMATCH_ROOM_ID_REASON, roomId, session.roomId) - Timber.e("## decryptGroupMessage() : $reason") + val reason = String.format(MXCryptoError.INBOUND_SESSION_MISMATCH_ROOM_ID_REASON, roomId, wrapper.roomId) + Timber.tag(loggerTag.value).e("## decryptGroupMessage() : $reason") throw MXCryptoError.Base(MXCryptoError.ErrorType.INBOUND_SESSION_MISMATCH_ROOM_ID, reason) } } @@ -819,7 +864,7 @@ internal class MXOlmDevice @Inject constructor( private fun getSessionForDevice(theirDeviceIdentityKey: String, sessionId: String): OlmSessionWrapper? { // sanity check return if (theirDeviceIdentityKey.isEmpty() || sessionId.isEmpty()) null else { - store.getDeviceSession(sessionId, theirDeviceIdentityKey) + olmSessionStore.getDeviceSession(sessionId, theirDeviceIdentityKey) } } @@ -832,25 +877,26 @@ internal class MXOlmDevice @Inject constructor( * @param senderKey the base64-encoded curve25519 key of the sender. * @return the inbound group session. */ - fun getInboundGroupSession(sessionId: String?, senderKey: String?, roomId: String?): OlmInboundGroupSessionWrapper2 { + fun getInboundGroupSession(sessionId: String?, senderKey: String?, roomId: String?): InboundGroupSessionHolder { if (sessionId.isNullOrBlank() || senderKey.isNullOrBlank()) { throw MXCryptoError.Base(MXCryptoError.ErrorType.MISSING_SENDER_KEY, MXCryptoError.ERROR_MISSING_PROPERTY_REASON) } - val session = inboundGroupSessionStore.getInboundGroupSession(sessionId, senderKey) + val holder = inboundGroupSessionStore.getInboundGroupSession(sessionId, senderKey) + val session = holder?.wrapper if (session != null) { // Check that the room id matches the original one for the session. This stops // the HS pretending a message was targeting a different room. if (roomId != session.roomId) { val errorDescription = String.format(MXCryptoError.INBOUND_SESSION_MISMATCH_ROOM_ID_REASON, roomId, session.roomId) - Timber.e("## getInboundGroupSession() : $errorDescription") + Timber.tag(loggerTag.value).e("## getInboundGroupSession() : $errorDescription") throw MXCryptoError.Base(MXCryptoError.ErrorType.INBOUND_SESSION_MISMATCH_ROOM_ID, errorDescription) } else { - return session + return holder } } else { - Timber.w("## getInboundGroupSession() : Cannot retrieve inbound group session $sessionId") + Timber.tag(loggerTag.value).w("## getInboundGroupSession() : UISI $sessionId") throw MXCryptoError.Base(MXCryptoError.ErrorType.UNKNOWN_INBOUND_SESSION_ID, MXCryptoError.UNKNOWN_INBOUND_SESSION_ID_REASON) } } @@ -866,4 +912,9 @@ internal class MXOlmDevice @Inject constructor( fun hasInboundSessionKeys(roomId: String, senderKey: String, sessionId: String): Boolean { return runCatching { getInboundGroupSession(sessionId, senderKey, roomId) }.isSuccess } + + @VisibleForTesting + fun clearOlmSessionCache() { + olmSessionStore.clear() + } } diff --git a/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/OlmSessionStore.kt b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/OlmSessionStore.kt new file mode 100644 index 0000000000..f4fbca6a0f --- /dev/null +++ b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/OlmSessionStore.kt @@ -0,0 +1,159 @@ +/* + * Copyright 2022 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.matrix.android.sdk.internal.crypto + +import org.matrix.android.sdk.api.logger.LoggerTag +import org.matrix.android.sdk.internal.crypto.model.OlmSessionWrapper +import org.matrix.android.sdk.internal.crypto.store.IMXCryptoStore +import org.matrix.olm.OlmSession +import timber.log.Timber +import javax.inject.Inject + +private val loggerTag = LoggerTag("OlmSessionStore", LoggerTag.CRYPTO) + +/** + * Keep the used olm session in memory and load them from the data layer when needed + * Access is synchronized for thread safety + */ +internal class OlmSessionStore @Inject constructor(private val store: IMXCryptoStore) { + /** + * map of device key to list of olm sessions (it is possible to have several active sessions with a device) + */ + private val olmSessions = HashMap>() + + /** + * Store a session between our own device and another device. + * This will be called after the session has been created but also every time it has been used + * in order to persist the correct state for next run + * @param olmSessionWrapper the end-to-end session. + * @param deviceKey the public key of the other device. + */ + @Synchronized + fun storeSession(olmSessionWrapper: OlmSessionWrapper, deviceKey: String) { + // This could be a newly created session or one that was just created + // Anyhow we should persist ratchet state for future app lifecycle + addNewSessionInCache(olmSessionWrapper, deviceKey) + store.storeSession(olmSessionWrapper, deviceKey) + } + + /** + * Get all the Olm Sessions we are sharing with the given device. + * + * @param deviceKey the public key of the other device. + * @return A set of sessionId, or empty if device is not known + */ + @Synchronized + fun getDeviceSessionIds(deviceKey: String): List { + // we need to get the persisted ids first + val persistedKnownSessions = store.getDeviceSessionIds(deviceKey) + .orEmpty() + .toMutableList() + // Do we have some in cache not yet persisted? + olmSessions.getOrPut(deviceKey) { mutableListOf() }.forEach { cached -> + getSafeSessionIdentifier(cached.olmSession)?.let { cachedSessionId -> + if (!persistedKnownSessions.contains(cachedSessionId)) { + persistedKnownSessions.add(cachedSessionId) + } + } + } + return persistedKnownSessions + } + + /** + * Retrieve an end-to-end session between our own device and another + * device. + * + * @param sessionId the session Id. + * @param deviceKey the public key of the other device. + * @return the session wrapper if found + */ + @Synchronized + fun getDeviceSession(sessionId: String, deviceKey: String): OlmSessionWrapper? { + // get from cache or load and add to cache + return internalGetSession(sessionId, deviceKey) + } + + /** + * Retrieve the last used sessionId, regarding `lastReceivedMessageTs`, or null if no session exist + * + * @param deviceKey the public key of the other device. + * @return last used sessionId, or null if not found + */ + @Synchronized + fun getLastUsedSessionId(deviceKey: String): String? { + // We want to avoid to load in memory old session if possible + val lastPersistedUsedSession = store.getLastUsedSessionId(deviceKey) + var candidate = lastPersistedUsedSession?.let { internalGetSession(it, deviceKey) } + // we should check if we have one in cache with a higher last message received? + olmSessions[deviceKey].orEmpty().forEach { inCache -> + if (inCache.lastReceivedMessageTs > (candidate?.lastReceivedMessageTs ?: 0L)) { + candidate = inCache + } + } + + return candidate?.olmSession?.sessionIdentifier() + } + + /** + * Release all sessions and clear cache + */ + @Synchronized + fun clear() { + olmSessions.entries.onEach { entry -> + entry.value.onEach { it.olmSession.releaseSession() } + } + olmSessions.clear() + } + + private fun internalGetSession(sessionId: String, deviceKey: String): OlmSessionWrapper? { + return getSessionInCache(sessionId, deviceKey) + ?: // deserialize from store + return store.getDeviceSession(sessionId, deviceKey)?.also { + addNewSessionInCache(it, deviceKey) + } + } + + private fun getSessionInCache(sessionId: String, deviceKey: String): OlmSessionWrapper? { + return olmSessions[deviceKey]?.firstOrNull { + getSafeSessionIdentifier(it.olmSession) == sessionId + } + } + + private fun getSafeSessionIdentifier(session: OlmSession): String? { + return try { + session.sessionIdentifier() + } catch (throwable: Throwable) { + Timber.tag(loggerTag.value).w("Failed to load sessionId from loaded olm session") + null + } + } + + private fun addNewSessionInCache(session: OlmSessionWrapper, deviceKey: String) { + val sessionId = getSafeSessionIdentifier(session.olmSession) ?: return + olmSessions.getOrPut(deviceKey) { mutableListOf() }.let { + val existing = it.firstOrNull { getSafeSessionIdentifier(it.olmSession) == sessionId } + it.add(session) + // remove and release if was there but with different instance + if (existing != null && existing.olmSession != session.olmSession) { + // mm not sure when this could happen + // anyhow we should remove and release the one known + it.remove(existing) + existing.olmSession.releaseSession() + } + } + } +} diff --git a/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/actions/EnsureOlmSessionsForDevicesAction.kt b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/actions/EnsureOlmSessionsForDevicesAction.kt index ab2ed04dfb..87c176612d 100644 --- a/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/actions/EnsureOlmSessionsForDevicesAction.kt +++ b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/actions/EnsureOlmSessionsForDevicesAction.kt @@ -16,14 +16,18 @@ package org.matrix.android.sdk.internal.crypto.actions +import kotlinx.coroutines.sync.Mutex +import kotlinx.coroutines.sync.withLock +import kotlinx.coroutines.withContext +import org.matrix.android.sdk.api.MatrixCoroutineDispatchers import org.matrix.android.sdk.api.logger.LoggerTag import org.matrix.android.sdk.internal.crypto.MXOlmDevice import org.matrix.android.sdk.internal.crypto.model.CryptoDeviceInfo import org.matrix.android.sdk.internal.crypto.model.MXKey import org.matrix.android.sdk.internal.crypto.model.MXOlmSessionResult import org.matrix.android.sdk.internal.crypto.model.MXUsersDevicesMap -import org.matrix.android.sdk.internal.crypto.model.toDebugString import org.matrix.android.sdk.internal.crypto.tasks.ClaimOneTimeKeysForUsersDeviceTask +import org.matrix.android.sdk.internal.session.SessionScope import timber.log.Timber import javax.inject.Inject @@ -31,90 +35,90 @@ private const val ONE_TIME_KEYS_RETRY_COUNT = 3 private val loggerTag = LoggerTag("EnsureOlmSessionsForDevicesAction", LoggerTag.CRYPTO) +@SessionScope internal class EnsureOlmSessionsForDevicesAction @Inject constructor( private val olmDevice: MXOlmDevice, + private val coroutineDispatchers: MatrixCoroutineDispatchers, private val oneTimeKeysForUsersDeviceTask: ClaimOneTimeKeysForUsersDeviceTask) { + private val ensureMutex = Mutex() + + /** + * We want to synchronize a bit here, because we are iterating to check existing olm session and + * also adding some + */ suspend fun handle(devicesByUser: Map>, force: Boolean = false): MXUsersDevicesMap { - val devicesWithoutSession = ArrayList() + ensureMutex.withLock { + val results = MXUsersDevicesMap() + val deviceList = devicesByUser.flatMap { it.value } + Timber.tag(loggerTag.value) + .d("ensure olm forced:$force for ${deviceList.joinToString { it.shortDebugString() }}") + val devicesToCreateSessionWith = mutableListOf() + if (force) { + // we take all devices and will query otk for them + devicesToCreateSessionWith.addAll(deviceList) + } else { + // only peek devices without active session + deviceList.forEach { deviceInfo -> + val deviceId = deviceInfo.deviceId + val userId = deviceInfo.userId + val key = deviceInfo.identityKey() ?: return@forEach Unit.also { + Timber.tag(loggerTag.value).w("Ignoring device ${deviceInfo.shortDebugString()} without identity key") + } - val results = MXUsersDevicesMap() - - for ((userId, deviceList) in devicesByUser) { - for (deviceInfo in deviceList) { - val deviceId = deviceInfo.deviceId - val key = deviceInfo.identityKey() - if (key == null) { - Timber.w("## CRYPTO | Ignoring device (${deviceInfo.userId}|$deviceId) without identity key") - continue - } - - val sessionId = olmDevice.getSessionId(key) - - if (sessionId.isNullOrEmpty() || force) { - Timber.tag(loggerTag.value).d("Found no existing olm session (${deviceInfo.userId}|$deviceId) (force=$force)") - devicesWithoutSession.add(deviceInfo) - } else { - Timber.tag(loggerTag.value).d("using olm session $sessionId for (${deviceInfo.userId}|$deviceId)") - } - - val olmSessionResult = MXOlmSessionResult(deviceInfo, sessionId) - results.setObject(userId, deviceId, olmSessionResult) - } - } - - Timber.tag(loggerTag.value).d("Devices without olm session (count:${devicesWithoutSession.size}) :" + - " ${devicesWithoutSession.joinToString { "${it.userId}|${it.deviceId}" }}") - if (devicesWithoutSession.size == 0) { - return results - } - - // Prepare the request for claiming one-time keys - val usersDevicesToClaim = MXUsersDevicesMap() - - val oneTimeKeyAlgorithm = MXKey.KEY_SIGNED_CURVE_25519_TYPE - - for (device in devicesWithoutSession) { - usersDevicesToClaim.setObject(device.userId, device.deviceId, oneTimeKeyAlgorithm) - } - - // TODO: this has a race condition - if we try to send another message - // while we are claiming a key, we will end up claiming two and setting up - // two sessions. - // - // That should eventually resolve itself, but it's poor form. - - Timber.tag(loggerTag.value).i("claimOneTimeKeysForUsersDevices() : ${usersDevicesToClaim.toDebugString()}") - - val claimParams = ClaimOneTimeKeysForUsersDeviceTask.Params(usersDevicesToClaim) - val oneTimeKeys = oneTimeKeysForUsersDeviceTask.executeRetry(claimParams, remainingRetry = ONE_TIME_KEYS_RETRY_COUNT) - Timber.tag(loggerTag.value).v("claimOneTimeKeysForUsersDevices() : keysClaimResponse.oneTimeKeys: $oneTimeKeys") - for ((userId, deviceInfos) in devicesByUser) { - for (deviceInfo in deviceInfos) { - var oneTimeKey: MXKey? = null - val deviceIds = oneTimeKeys.getUserDeviceIds(userId) - if (null != deviceIds) { - for (deviceId in deviceIds) { - val olmSessionResult = results.getObject(userId, deviceId) - if (olmSessionResult?.sessionId != null && !force) { - // We already have a result for this device - continue - } - val key = oneTimeKeys.getObject(userId, deviceId) - if (key?.type == oneTimeKeyAlgorithm) { - oneTimeKey = key - } - if (oneTimeKey == null) { - Timber.tag(loggerTag.value).d("No one time key for $userId|$deviceId") - continue - } - // Update the result for this device in results - olmSessionResult?.sessionId = verifyKeyAndStartSession(oneTimeKey, userId, deviceInfo) + // is there a session that as been already used? + val sessionId = olmDevice.getSessionId(key) + if (sessionId.isNullOrEmpty()) { + Timber.tag(loggerTag.value).d("Found no existing olm session ${deviceInfo.shortDebugString()} add to claim list") + devicesToCreateSessionWith.add(deviceInfo) + } else { + Timber.tag(loggerTag.value).d("using olm session $sessionId for (${deviceInfo.userId}|$deviceId)") + val olmSessionResult = MXOlmSessionResult(deviceInfo, sessionId) + results.setObject(userId, deviceId, olmSessionResult) } } } + + if (devicesToCreateSessionWith.isEmpty()) { + // no session to create + return results + } + val usersDevicesToClaim = MXUsersDevicesMap().apply { + devicesToCreateSessionWith.forEach { + setObject(it.userId, it.deviceId, MXKey.KEY_SIGNED_CURVE_25519_TYPE) + } + } + + // Let's now claim one time keys + val claimParams = ClaimOneTimeKeysForUsersDeviceTask.Params(usersDevicesToClaim) + val oneTimeKeys = withContext(coroutineDispatchers.io) { + oneTimeKeysForUsersDeviceTask.executeRetry(claimParams, ONE_TIME_KEYS_RETRY_COUNT) + } + + // let now start olm session using the new otks + devicesToCreateSessionWith.forEach { deviceInfo -> + val userId = deviceInfo.userId + val deviceId = deviceInfo.deviceId + // Did we get an OTK + val oneTimeKey = oneTimeKeys.getObject(userId, deviceId) + if (oneTimeKey == null) { + Timber.tag(loggerTag.value).d("No otk for ${deviceInfo.shortDebugString()}") + } else if (oneTimeKey.type != MXKey.KEY_SIGNED_CURVE_25519_TYPE) { + Timber.tag(loggerTag.value).d("Bad otk type (${oneTimeKey.type}) for ${deviceInfo.shortDebugString()}") + } else { + val olmSessionId = verifyKeyAndStartSession(oneTimeKey, userId, deviceInfo) + if (olmSessionId != null) { + val olmSessionResult = MXOlmSessionResult(deviceInfo, olmSessionId) + results.setObject(userId, deviceId, olmSessionResult) + } else { + Timber + .tag(loggerTag.value) + .d("## CRYPTO | cant unwedge failed to create outbound ${deviceInfo.shortDebugString()}") + } + } + } + return results } - return results } private fun verifyKeyAndStartSession(oneTimeKey: MXKey, userId: String, deviceInfo: CryptoDeviceInfo): String? { diff --git a/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/actions/MessageEncrypter.kt b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/actions/MessageEncrypter.kt index 165f200bac..4e158602c8 100644 --- a/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/actions/MessageEncrypter.kt +++ b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/actions/MessageEncrypter.kt @@ -16,6 +16,7 @@ package org.matrix.android.sdk.internal.crypto.actions +import org.matrix.android.sdk.api.logger.LoggerTag import org.matrix.android.sdk.api.session.events.model.Content import org.matrix.android.sdk.internal.crypto.MXCRYPTO_ALGORITHM_OLM import org.matrix.android.sdk.internal.crypto.MXOlmDevice @@ -28,6 +29,8 @@ import org.matrix.android.sdk.internal.util.convertToUTF8 import timber.log.Timber import javax.inject.Inject +private val loggerTag = LoggerTag("MessageEncrypter", LoggerTag.CRYPTO) + internal class MessageEncrypter @Inject constructor( @UserId private val userId: String, @@ -42,7 +45,7 @@ internal class MessageEncrypter @Inject constructor( * @param deviceInfos list of device infos to encrypt for. * @return the content for an m.room.encrypted event. */ - fun encryptMessage(payloadFields: Content, deviceInfos: List): EncryptedMessage { + suspend fun encryptMessage(payloadFields: Content, deviceInfos: List): EncryptedMessage { val deviceInfoParticipantKey = deviceInfos.associateBy { it.identityKey()!! } val payloadJson = payloadFields.toMutableMap() @@ -66,7 +69,7 @@ internal class MessageEncrypter @Inject constructor( val sessionId = olmDevice.getSessionId(deviceKey) if (!sessionId.isNullOrEmpty()) { - Timber.v("Using sessionid $sessionId for device $deviceKey") + Timber.tag(loggerTag.value).d("Using sessionid $sessionId for device $deviceKey") payloadJson["recipient"] = deviceInfo.userId payloadJson["recipient_keys"] = mapOf("ed25519" to deviceInfo.fingerprint()!!) diff --git a/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/algorithms/IMXDecrypting.kt b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/algorithms/IMXDecrypting.kt index 79c7608cbf..b6c1d99aa5 100644 --- a/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/algorithms/IMXDecrypting.kt +++ b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/algorithms/IMXDecrypting.kt @@ -36,7 +36,7 @@ internal interface IMXDecrypting { * @return the decryption information, or an error */ @Throws(MXCryptoError::class) - fun decryptEvent(event: Event, timeline: String): MXEventDecryptionResult + suspend fun decryptEvent(event: Event, timeline: String): MXEventDecryptionResult /** * Handle a key event. diff --git a/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/algorithms/IMXGroupEncryption.kt b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/algorithms/IMXGroupEncryption.kt index 1fd5061a65..6f488def0a 100644 --- a/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/algorithms/IMXGroupEncryption.kt +++ b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/algorithms/IMXGroupEncryption.kt @@ -45,7 +45,7 @@ internal interface IMXGroupEncryption { * * @return true in case of success */ - suspend fun reshareKey(sessionId: String, + suspend fun reshareKey(groupSessionId: String, userId: String, deviceId: String, senderKey: String): Boolean diff --git a/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/algorithms/megolm/MXMegolmDecryption.kt b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/algorithms/megolm/MXMegolmDecryption.kt index 2ee24dfbb0..e94daa0e76 100644 --- a/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/algorithms/megolm/MXMegolmDecryption.kt +++ b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/algorithms/megolm/MXMegolmDecryption.kt @@ -19,6 +19,7 @@ package org.matrix.android.sdk.internal.crypto.algorithms.megolm import dagger.Lazy import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.launch +import kotlinx.coroutines.sync.withLock import org.matrix.android.sdk.api.MatrixCoroutineDispatchers import org.matrix.android.sdk.api.logger.LoggerTag import org.matrix.android.sdk.api.session.crypto.MXCryptoError @@ -71,7 +72,7 @@ internal class MXMegolmDecryption(private val userId: String, // private var pendingEvents: MutableMap>> = HashMap() @Throws(MXCryptoError::class) - override fun decryptEvent(event: Event, timeline: String): MXEventDecryptionResult { + override suspend fun decryptEvent(event: Event, timeline: String): MXEventDecryptionResult { // If cross signing is enabled, we don't send request until the keys are trusted // There could be a race effect here when xsigning is enabled, we should ensure that keys was downloaded once val requestOnFail = cryptoStore.getMyCrossSigningInfo()?.isTrusted() == true @@ -79,7 +80,7 @@ internal class MXMegolmDecryption(private val userId: String, } @Throws(MXCryptoError::class) - private fun decryptEvent(event: Event, timeline: String, requestKeysOnFail: Boolean): MXEventDecryptionResult { + private suspend fun decryptEvent(event: Event, timeline: String, requestKeysOnFail: Boolean): MXEventDecryptionResult { Timber.tag(loggerTag.value).v("decryptEvent ${event.eventId}, requestKeysOnFail:$requestKeysOnFail") if (event.roomId.isNullOrBlank()) { throw MXCryptoError.Base(MXCryptoError.ErrorType.MISSING_FIELDS, MXCryptoError.MISSING_FIELDS_REASON) @@ -345,7 +346,22 @@ internal class MXMegolmDecryption(private val userId: String, return } val userId = request.userId ?: return + cryptoCoroutineScope.launch(coroutineDispatchers.crypto) { + val body = request.requestBody + val sessionHolder = try { + olmDevice.getInboundGroupSession(body.sessionId, body.senderKey, body.roomId) + } catch (failure: Throwable) { + Timber.tag(loggerTag.value).e(failure, "shareKeysWithDevice: failed to get session for request $body") + return@launch + } + + val export = sessionHolder.mutex.withLock { + sessionHolder.wrapper.exportKeys() + } ?: return@launch Unit.also { + Timber.tag(loggerTag.value).e("shareKeysWithDevice: failed to export group session ${body.sessionId}") + } + runCatching { deviceListManager.downloadKeys(listOf(userId), false) } .mapCatching { val deviceId = request.deviceId @@ -355,7 +371,6 @@ internal class MXMegolmDecryption(private val userId: String, } else { val devicesByUser = mapOf(userId to listOf(deviceInfo)) val usersDeviceMap = ensureOlmSessionsForDevicesAction.handle(devicesByUser) - val body = request.requestBody val olmSessionResult = usersDeviceMap.getObject(userId, deviceId) if (olmSessionResult?.sessionId == null) { // no session with this device, probably because there @@ -365,19 +380,10 @@ internal class MXMegolmDecryption(private val userId: String, } Timber.tag(loggerTag.value).i("shareKeysWithDevice() : sharing session ${body.sessionId} with device $userId:$deviceId") - val payloadJson = mutableMapOf("type" to EventType.FORWARDED_ROOM_KEY) - runCatching { olmDevice.getInboundGroupSession(body.sessionId, body.senderKey, body.roomId) } - .fold( - { - // TODO - payloadJson["content"] = it.exportKeys() ?: "" - }, - { - // TODO - Timber.tag(loggerTag.value).e(it, "shareKeysWithDevice: failed to get session for request $body") - } - - ) + val payloadJson = mapOf( + "type" to EventType.FORWARDED_ROOM_KEY, + "content" to export + ) val encodedPayload = messageEncrypter.encryptMessage(payloadJson, listOf(deviceInfo)) val sendToDeviceMap = MXUsersDevicesMap() diff --git a/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/algorithms/megolm/MXMegolmEncryption.kt b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/algorithms/megolm/MXMegolmEncryption.kt index 389036a1f8..cf9733dc2d 100644 --- a/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/algorithms/megolm/MXMegolmEncryption.kt +++ b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/algorithms/megolm/MXMegolmEncryption.kt @@ -18,6 +18,8 @@ package org.matrix.android.sdk.internal.crypto.algorithms.megolm import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.launch +import kotlinx.coroutines.sync.withLock +import kotlinx.coroutines.withContext import org.matrix.android.sdk.api.MatrixCoroutineDispatchers import org.matrix.android.sdk.api.logger.LoggerTag import org.matrix.android.sdk.api.session.crypto.MXCryptoError @@ -88,7 +90,7 @@ internal class MXMegolmEncryption( Timber.tag(loggerTag.value).v("encryptEventContent : getDevicesInRoom") val devices = getDevicesInRoom(userIds) Timber.tag(loggerTag.value).d("encrypt event in room=$roomId - devices count in room ${devices.allowedDevices.toDebugCount()}") - Timber.tag(loggerTag.value).v("encryptEventContent ${System.currentTimeMillis() - ts}: getDevicesInRoom ${devices.allowedDevices.map}") + Timber.tag(loggerTag.value).v("encryptEventContent ${System.currentTimeMillis() - ts}: getDevicesInRoom ${devices.allowedDevices.toDebugString()}") val outboundSession = ensureOutboundSession(devices.allowedDevices) return encryptContent(outboundSession, eventType, eventContent) @@ -142,8 +144,9 @@ internal class MXMegolmEncryption( Timber.tag(loggerTag.value).v("prepareNewSessionInRoom() ") val sessionId = olmDevice.createOutboundGroupSessionForRoom(roomId) - val keysClaimedMap = HashMap() - keysClaimedMap["ed25519"] = olmDevice.deviceEd25519Key!! + val keysClaimedMap = mapOf( + "ed25519" to olmDevice.deviceEd25519Key!! + ) olmDevice.addInboundGroupSession(sessionId!!, olmDevice.getSessionKey(sessionId)!!, roomId, olmDevice.deviceCurve25519Key!!, emptyList(), keysClaimedMap, false) @@ -303,11 +306,13 @@ internal class MXMegolmEncryption( Timber.tag(loggerTag.value).d("sending to device room key for ${session.sessionId} to ${contentMap.toDebugString()}") val sendToDeviceParams = SendToDeviceTask.Params(EventType.ENCRYPTED, contentMap) try { - sendToDeviceTask.execute(sendToDeviceParams) + withContext(coroutineDispatchers.io) { + sendToDeviceTask.execute(sendToDeviceParams) + } Timber.tag(loggerTag.value).i("shareUserDevicesKey() : sendToDevice succeeds after ${System.currentTimeMillis() - t0} ms") } catch (failure: Throwable) { // What to do here... - Timber.tag(loggerTag.value).e("shareUserDevicesKey() : Failed to share session <${session.sessionId}> with $devicesByUser ") + Timber.tag(loggerTag.value).e("shareUserDevicesKey() : Failed to share <${session.sessionId}>") } } else { Timber.tag(loggerTag.value).i("shareUserDevicesKey() : no need to share key") @@ -346,9 +351,12 @@ internal class MXMegolmEncryption( } ) try { - sendToDeviceTask.execute(params) + withContext(coroutineDispatchers.io) { + sendToDeviceTask.execute(params) + } } catch (failure: Throwable) { - Timber.tag(loggerTag.value).e("notifyKeyWithHeld() : Failed to notify withheld key for $targets session: $sessionId ") + Timber.tag(loggerTag.value) + .e("notifyKeyWithHeld() :$sessionId Failed to send withheld ${targets.map { "${it.userId}|${it.deviceId}" }}") } } @@ -432,20 +440,20 @@ internal class MXMegolmEncryption( } } - override suspend fun reshareKey(sessionId: String, + override suspend fun reshareKey(groupSessionId: String, userId: String, deviceId: String, senderKey: String): Boolean { - Timber.tag(loggerTag.value).i("process reshareKey for $sessionId to $userId:$deviceId") + Timber.tag(loggerTag.value).i("process reshareKey for $groupSessionId to $userId:$deviceId") val deviceInfo = cryptoStore.getUserDevice(userId, deviceId) ?: return false .also { Timber.tag(loggerTag.value).w("reshareKey: Device not found") } // Get the chain index of the key we previously sent this device - val wasSessionSharedWithUser = cryptoStore.getSharedSessionInfo(roomId, sessionId, deviceInfo) + val wasSessionSharedWithUser = cryptoStore.getSharedSessionInfo(roomId, groupSessionId, deviceInfo) if (!wasSessionSharedWithUser.found) { // This session was never shared with this user // Send a room key with held - notifyKeyWithHeld(listOf(UserDevice(userId, deviceId)), sessionId, senderKey, WithHeldCode.UNAUTHORISED) + notifyKeyWithHeld(listOf(UserDevice(userId, deviceId)), groupSessionId, senderKey, WithHeldCode.UNAUTHORISED) Timber.tag(loggerTag.value).w("reshareKey: ERROR : Never shared megolm with this device") return false } @@ -456,42 +464,47 @@ internal class MXMegolmEncryption( } val devicesByUser = mapOf(userId to listOf(deviceInfo)) - val usersDeviceMap = ensureOlmSessionsForDevicesAction.handle(devicesByUser) - val olmSessionResult = usersDeviceMap.getObject(userId, deviceId) - olmSessionResult?.sessionId // no session with this device, probably because there were no one-time keys. - // ensureOlmSessionsForDevicesAction has already done the logging, so just skip it. - ?: return false.also { - Timber.tag(loggerTag.value).w("reshareKey: no session with this device, probably because there were no one-time keys") - } + val usersDeviceMap = try { + ensureOlmSessionsForDevicesAction.handle(devicesByUser) + } catch (failure: Throwable) { + null + } + val olmSessionResult = usersDeviceMap?.getObject(userId, deviceId) + if (olmSessionResult?.sessionId == null) { + Timber.tag(loggerTag.value).w("reshareKey: no session with this device, probably because there were no one-time keys") + return false + } + Timber.tag(loggerTag.value).i(" reshareKey: $groupSessionId:$chainIndex with device $userId:$deviceId using session ${olmSessionResult.sessionId}") - Timber.tag(loggerTag.value).i(" reshareKey: sharing keys for session $senderKey|$sessionId:$chainIndex with device $userId:$deviceId") + val sessionHolder = try { + olmDevice.getInboundGroupSession(groupSessionId, senderKey, roomId) + } catch (failure: Throwable) { + Timber.tag(loggerTag.value).e(failure, "shareKeysWithDevice: failed to get session $groupSessionId") + return false + } - val payloadJson = mutableMapOf("type" to EventType.FORWARDED_ROOM_KEY) + val export = sessionHolder.mutex.withLock { + sessionHolder.wrapper.exportKeys() + } ?: return false.also { + Timber.tag(loggerTag.value).e("shareKeysWithDevice: failed to export group session $groupSessionId") + } - runCatching { olmDevice.getInboundGroupSession(sessionId, senderKey, roomId) } - .fold( - { - // TODO - payloadJson["content"] = it.exportKeys(chainIndex.toLong()) ?: "" - }, - { - // TODO - Timber.tag(loggerTag.value).e(it, "reshareKey: failed to get session $sessionId|$senderKey|$roomId") - } - - ) + val payloadJson = mapOf( + "type" to EventType.FORWARDED_ROOM_KEY, + "content" to export + ) val encodedPayload = messageEncrypter.encryptMessage(payloadJson, listOf(deviceInfo)) val sendToDeviceMap = MXUsersDevicesMap() sendToDeviceMap.setObject(userId, deviceId, encodedPayload) - Timber.tag(loggerTag.value).i("reshareKey() : sending session $sessionId to $userId:$deviceId") + Timber.tag(loggerTag.value).i("reshareKey() : sending session $groupSessionId to $userId:$deviceId") val sendToDeviceParams = SendToDeviceTask.Params(EventType.ENCRYPTED, sendToDeviceMap) return try { sendToDeviceTask.execute(sendToDeviceParams) - Timber.tag(loggerTag.value).i("reshareKey() : successfully send <$sessionId> to $userId:$deviceId") + Timber.tag(loggerTag.value).i("reshareKey() : successfully send <$groupSessionId> to $userId:$deviceId") true } catch (failure: Throwable) { - Timber.tag(loggerTag.value).e(failure, "reshareKey() : fail to send <$sessionId> to $userId:$deviceId") + Timber.tag(loggerTag.value).e(failure, "reshareKey() : fail to send <$groupSessionId> to $userId:$deviceId") false } } diff --git a/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/algorithms/olm/MXOlmDecryption.kt b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/algorithms/olm/MXOlmDecryption.kt index f1bca4fbc6..afa249801d 100644 --- a/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/algorithms/olm/MXOlmDecryption.kt +++ b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/algorithms/olm/MXOlmDecryption.kt @@ -16,6 +16,8 @@ package org.matrix.android.sdk.internal.crypto.algorithms.olm +import kotlinx.coroutines.sync.withLock +import org.matrix.android.sdk.api.logger.LoggerTag import org.matrix.android.sdk.api.session.crypto.MXCryptoError import org.matrix.android.sdk.api.session.events.model.Event import org.matrix.android.sdk.api.session.events.model.toModel @@ -30,6 +32,7 @@ import org.matrix.android.sdk.internal.di.MoshiProvider import org.matrix.android.sdk.internal.util.convertFromUTF8 import timber.log.Timber +private val loggerTag = LoggerTag("MXOlmDecryption", LoggerTag.CRYPTO) internal class MXOlmDecryption( // The olm device interface private val olmDevice: MXOlmDevice, @@ -38,27 +41,27 @@ internal class MXOlmDecryption( IMXDecrypting { @Throws(MXCryptoError::class) - override fun decryptEvent(event: Event, timeline: String): MXEventDecryptionResult { + override suspend fun decryptEvent(event: Event, timeline: String): MXEventDecryptionResult { val olmEventContent = event.content.toModel() ?: run { - Timber.e("## decryptEvent() : bad event format") + Timber.tag(loggerTag.value).e("## decryptEvent() : bad event format") throw MXCryptoError.Base(MXCryptoError.ErrorType.BAD_EVENT_FORMAT, MXCryptoError.BAD_EVENT_FORMAT_TEXT_REASON) } val cipherText = olmEventContent.ciphertext ?: run { - Timber.e("## decryptEvent() : missing cipher text") + Timber.tag(loggerTag.value).e("## decryptEvent() : missing cipher text") throw MXCryptoError.Base(MXCryptoError.ErrorType.MISSING_CIPHER_TEXT, MXCryptoError.MISSING_CIPHER_TEXT_REASON) } val senderKey = olmEventContent.senderKey ?: run { - Timber.e("## decryptEvent() : missing sender key") + Timber.tag(loggerTag.value).e("## decryptEvent() : missing sender key") throw MXCryptoError.Base(MXCryptoError.ErrorType.MISSING_SENDER_KEY, MXCryptoError.MISSING_SENDER_KEY_TEXT_REASON) } val messageAny = cipherText[olmDevice.deviceCurve25519Key] ?: run { - Timber.e("## decryptEvent() : our device ${olmDevice.deviceCurve25519Key} is not included in recipients") + Timber.tag(loggerTag.value).e("## decryptEvent() : our device ${olmDevice.deviceCurve25519Key} is not included in recipients") throw MXCryptoError.Base(MXCryptoError.ErrorType.NOT_INCLUDE_IN_RECIPIENTS, MXCryptoError.NOT_INCLUDED_IN_RECIPIENT_REASON) } @@ -69,7 +72,7 @@ internal class MXOlmDecryption( val decryptedPayload = decryptMessage(message, senderKey) if (decryptedPayload == null) { - Timber.e("## decryptEvent() Failed to decrypt Olm event (id= ${event.eventId} from $senderKey") + Timber.tag(loggerTag.value).e("## decryptEvent() Failed to decrypt Olm event (id= ${event.eventId} from $senderKey") throw MXCryptoError.Base(MXCryptoError.ErrorType.BAD_ENCRYPTED_MESSAGE, MXCryptoError.BAD_ENCRYPTED_MESSAGE_REASON) } val payloadString = convertFromUTF8(decryptedPayload) @@ -78,30 +81,30 @@ internal class MXOlmDecryption( val payload = adapter.fromJson(payloadString) if (payload == null) { - Timber.e("## decryptEvent failed : null payload") + Timber.tag(loggerTag.value).e("## decryptEvent failed : null payload") throw MXCryptoError.Base(MXCryptoError.ErrorType.UNABLE_TO_DECRYPT, MXCryptoError.MISSING_CIPHER_TEXT_REASON) } val olmPayloadContent = OlmPayloadContent.fromJsonString(payloadString) ?: run { - Timber.e("## decryptEvent() : bad olmPayloadContent format") + Timber.tag(loggerTag.value).e("## decryptEvent() : bad olmPayloadContent format") throw MXCryptoError.Base(MXCryptoError.ErrorType.BAD_DECRYPTED_FORMAT, MXCryptoError.BAD_DECRYPTED_FORMAT_TEXT_REASON) } if (olmPayloadContent.recipient.isNullOrBlank()) { val reason = String.format(MXCryptoError.ERROR_MISSING_PROPERTY_REASON, "recipient") - Timber.e("## decryptEvent() : $reason") + Timber.tag(loggerTag.value).e("## decryptEvent() : $reason") throw MXCryptoError.Base(MXCryptoError.ErrorType.MISSING_PROPERTY, reason) } if (olmPayloadContent.recipient != userId) { - Timber.e("## decryptEvent() : Event ${event.eventId}:" + + Timber.tag(loggerTag.value).e("## decryptEvent() : Event ${event.eventId}:" + " Intended recipient ${olmPayloadContent.recipient} does not match our id $userId") throw MXCryptoError.Base(MXCryptoError.ErrorType.BAD_RECIPIENT, String.format(MXCryptoError.BAD_RECIPIENT_REASON, olmPayloadContent.recipient)) } val recipientKeys = olmPayloadContent.recipientKeys ?: run { - Timber.e("## decryptEvent() : Olm event (id=${event.eventId}) contains no 'recipient_keys'" + + Timber.tag(loggerTag.value).e("## decryptEvent() : Olm event (id=${event.eventId}) contains no 'recipient_keys'" + " property; cannot prevent unknown-key attack") throw MXCryptoError.Base(MXCryptoError.ErrorType.MISSING_PROPERTY, String.format(MXCryptoError.ERROR_MISSING_PROPERTY_REASON, "recipient_keys")) @@ -110,31 +113,34 @@ internal class MXOlmDecryption( val ed25519 = recipientKeys["ed25519"] if (ed25519 != olmDevice.deviceEd25519Key) { - Timber.e("## decryptEvent() : Event ${event.eventId}: Intended recipient ed25519 key $ed25519 did not match ours") + Timber.tag(loggerTag.value).e("## decryptEvent() : Event ${event.eventId}: Intended recipient ed25519 key $ed25519 did not match ours") throw MXCryptoError.Base(MXCryptoError.ErrorType.BAD_RECIPIENT_KEY, MXCryptoError.BAD_RECIPIENT_KEY_REASON) } if (olmPayloadContent.sender.isNullOrBlank()) { - Timber.e("## decryptEvent() : Olm event (id=${event.eventId}) contains no 'sender' property; cannot prevent unknown-key attack") + Timber.tag(loggerTag.value) + .e("## decryptEvent() : Olm event (id=${event.eventId}) contains no 'sender' property; cannot prevent unknown-key attack") throw MXCryptoError.Base(MXCryptoError.ErrorType.MISSING_PROPERTY, String.format(MXCryptoError.ERROR_MISSING_PROPERTY_REASON, "sender")) } if (olmPayloadContent.sender != event.senderId) { - Timber.e("Event ${event.eventId}: original sender ${olmPayloadContent.sender} does not match reported sender ${event.senderId}") + Timber.tag(loggerTag.value) + .e("Event ${event.eventId}: sender ${olmPayloadContent.sender} does not match reported sender ${event.senderId}") throw MXCryptoError.Base(MXCryptoError.ErrorType.FORWARDED_MESSAGE, String.format(MXCryptoError.FORWARDED_MESSAGE_REASON, olmPayloadContent.sender)) } if (olmPayloadContent.roomId != event.roomId) { - Timber.e("## decryptEvent() : Event ${event.eventId}: original room ${olmPayloadContent.roomId} does not match reported room ${event.roomId}") + Timber.tag(loggerTag.value) + .e("## decryptEvent() : Event ${event.eventId}: room ${olmPayloadContent.roomId} does not match reported room ${event.roomId}") throw MXCryptoError.Base(MXCryptoError.ErrorType.BAD_ROOM, String.format(MXCryptoError.BAD_ROOM_REASON, olmPayloadContent.roomId)) } val keys = olmPayloadContent.keys ?: run { - Timber.e("## decryptEvent failed : null keys") + Timber.tag(loggerTag.value).e("## decryptEvent failed : null keys") throw MXCryptoError.Base(MXCryptoError.ErrorType.UNABLE_TO_DECRYPT, MXCryptoError.MISSING_CIPHER_TEXT_REASON) } @@ -153,8 +159,8 @@ internal class MXOlmDecryption( * @param message message object, with 'type' and 'body' fields. * @return payload, if decrypted successfully. */ - private fun decryptMessage(message: JsonDict, theirDeviceIdentityKey: String): String? { - val sessionIds = olmDevice.getSessionIds(theirDeviceIdentityKey).orEmpty() + private suspend fun decryptMessage(message: JsonDict, theirDeviceIdentityKey: String): String? { + val sessionIds = olmDevice.getSessionIds(theirDeviceIdentityKey) val messageBody = message["body"] as? String ?: return null val messageType = when (val typeAsVoid = message["type"]) { @@ -166,11 +172,32 @@ internal class MXOlmDecryption( // Try each session in turn // decryptionErrors = {}; + + val isPreKey = messageType == 0 + // we want to synchronize on prekey if not we could end up create two olm sessions + // Not very clear but it looks like the js-sdk for consistency + return if (isPreKey) { + olmDevice.mutex.withLock { + reallyDecryptMessage(sessionIds, messageBody, messageType, theirDeviceIdentityKey) + } + } else { + reallyDecryptMessage(sessionIds, messageBody, messageType, theirDeviceIdentityKey) + } + } + + private suspend fun reallyDecryptMessage(sessionIds: List, messageBody: String, messageType: Int, theirDeviceIdentityKey: String): String? { + Timber.tag(loggerTag.value).d("decryptMessage() try to decrypt olm message type:$messageType from ${sessionIds.size} known sessions") for (sessionId in sessionIds) { - val payload = olmDevice.decryptMessage(messageBody, messageType, sessionId, theirDeviceIdentityKey) + val payload = try { + olmDevice.decryptMessage(messageBody, messageType, sessionId, theirDeviceIdentityKey) + } catch (throwable: Exception) { + // As we are trying one by one, we don't really care of the error here + Timber.tag(loggerTag.value).d("decryptMessage() failed with session $sessionId") + null + } if (null != payload) { - Timber.v("## decryptMessage() : Decrypted Olm message from $theirDeviceIdentityKey with session $sessionId") + Timber.tag(loggerTag.value).v("## decryptMessage() : Decrypted Olm message from $theirDeviceIdentityKey with session $sessionId") return payload } else { val foundSession = olmDevice.matchesSession(theirDeviceIdentityKey, sessionId, messageType, messageBody) @@ -178,7 +205,7 @@ internal class MXOlmDecryption( if (foundSession) { // Decryption failed, but it was a prekey message matching this // session, so it should have worked. - Timber.e("## decryptMessage() : Error decrypting prekey message with existing session id $sessionId:TODO") + Timber.tag(loggerTag.value).e("## decryptMessage() : Error decrypting prekey message with existing session id $sessionId:TODO") return null } } @@ -189,9 +216,9 @@ internal class MXOlmDecryption( // didn't work. if (sessionIds.isEmpty()) { - Timber.e("## decryptMessage() : No existing sessions") + Timber.tag(loggerTag.value).e("## decryptMessage() : No existing sessions") } else { - Timber.e("## decryptMessage() : Error decrypting non-prekey message with existing sessions") + Timber.tag(loggerTag.value).e("## decryptMessage() : Error decrypting non-prekey message with existing sessions") } return null @@ -199,14 +226,17 @@ internal class MXOlmDecryption( // prekey message which doesn't match any existing sessions: make a new // session. + // XXXX Possible races here? if concurrent access for same prekey message, we might create 2 sessions? + Timber.tag(loggerTag.value).d("## decryptMessage() : Create inbound group session from prekey sender:$theirDeviceIdentityKey") + val res = olmDevice.createInboundSession(theirDeviceIdentityKey, messageType, messageBody) if (null == res) { - Timber.e("## decryptMessage() : Error decrypting non-prekey message with existing sessions") + Timber.tag(loggerTag.value).e("## decryptMessage() : Error decrypting non-prekey message with existing sessions") return null } - Timber.v("## decryptMessage() : Created new inbound Olm session get id ${res["session_id"]} with $theirDeviceIdentityKey") + Timber.tag(loggerTag.value).v("## decryptMessage() : Created new inbound Olm session get id ${res["session_id"]} with $theirDeviceIdentityKey") return res["payload"] } diff --git a/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/crosssigning/UpdateTrustWorker.kt b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/crosssigning/UpdateTrustWorker.kt index 5cd647ff6f..9325355d28 100644 --- a/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/crosssigning/UpdateTrustWorker.kt +++ b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/crosssigning/UpdateTrustWorker.kt @@ -96,7 +96,7 @@ internal class UpdateTrustWorker(context: Context, params: WorkerParameters, ses if (userList.isNotEmpty()) { // Unfortunately we don't have much info on what did exactly changed (is it the cross signing keys of that user, // or a new device?) So we check all again :/ - Timber.d("## CrossSigning - Updating trust for users: ${userList.logLimit()}") + Timber.v("## CrossSigning - Updating trust for users: ${userList.logLimit()}") updateTrust(userList) } @@ -148,7 +148,7 @@ internal class UpdateTrustWorker(context: Context, params: WorkerParameters, ses myUserId -> myTrustResult else -> { crossSigningService.checkOtherMSKTrusted(myCrossSigningInfo, entry.value).also { - Timber.d("## CrossSigning - user:${entry.key} result:$it") + Timber.v("## CrossSigning - user:${entry.key} result:$it") } } } @@ -178,7 +178,7 @@ internal class UpdateTrustWorker(context: Context, params: WorkerParameters, ses // Update trust if needed devicesEntities?.forEach { device -> val crossSignedVerified = trustMap?.get(device)?.isCrossSignedVerified() - Timber.d("## CrossSigning - Trust for ${device.userId}|${device.deviceId} : cross verified: ${trustMap?.get(device)}") + Timber.v("## CrossSigning - Trust for ${device.userId}|${device.deviceId} : cross verified: ${trustMap?.get(device)}") if (device.trustLevelEntity?.crossSignedVerified != crossSignedVerified) { Timber.d("## CrossSigning - Trust change detected for ${device.userId}|${device.deviceId} : cross verified: $crossSignedVerified") // need to save @@ -216,7 +216,7 @@ internal class UpdateTrustWorker(context: Context, params: WorkerParameters, ses .equalTo(RoomSummaryEntityFields.IS_ENCRYPTED, true) .findFirst() ?.let { roomSummary -> - Timber.d("## CrossSigning - Check shield state for room $roomId") + Timber.v("## CrossSigning - Check shield state for room $roomId") val allActiveRoomMembers = RoomMemberHelper(sessionRealm, roomId).getActiveRoomMemberIds() try { val updatedTrust = computeRoomShield( @@ -277,7 +277,7 @@ internal class UpdateTrustWorker(context: Context, params: WorkerParameters, ses cryptoRealm: Realm, activeMemberUserIds: List, roomSummaryEntity: RoomSummaryEntity): RoomEncryptionTrustLevel { - Timber.d("## CrossSigning - computeRoomShield ${roomSummaryEntity.roomId} -> ${activeMemberUserIds.logLimit()}") + Timber.v("## CrossSigning - computeRoomShield ${roomSummaryEntity.roomId} -> ${activeMemberUserIds.logLimit()}") // The set of “all users” depends on the type of room: // For regular / topic rooms which have more than 2 members (including yourself) are considered when decorating a room // For 1:1 and group DM rooms, all other users (i.e. excluding yourself) are considered when decorating a room diff --git a/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/keysbackup/DefaultKeysBackupService.kt b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/keysbackup/DefaultKeysBackupService.kt index b20168eaa3..954c2dbe43 100644 --- a/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/keysbackup/DefaultKeysBackupService.kt +++ b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/keysbackup/DefaultKeysBackupService.kt @@ -671,7 +671,6 @@ internal class DefaultKeysBackupService @Inject constructor( Timber.e("restoreKeysWithRecoveryKey: Invalid recovery key for this keys version") throw InvalidParameterException("Invalid recovery key") } - // Get a PK decryption instance pkDecryptionFromRecoveryKey(recoveryKey) } @@ -681,6 +680,10 @@ internal class DefaultKeysBackupService @Inject constructor( throw InvalidParameterException("Invalid recovery key") } + // Save for next time and for gossiping + // Save now as it's valid, don't wait for the import as it could take long. + saveBackupRecoveryKey(recoveryKey, keysVersionResult.version) + stepProgressListener?.onStepProgress(StepProgressListener.Step.DownloadingKey) // Get backed up keys from the homeserver @@ -729,8 +732,6 @@ internal class DefaultKeysBackupService @Inject constructor( if (backUp) { maybeBackupKeys() } - // Save for next time and for gossiping - saveBackupRecoveryKey(recoveryKey, keysVersionResult.version) result } }.foldToCallback(callback) diff --git a/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/model/CryptoDeviceInfo.kt b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/model/CryptoDeviceInfo.kt index 5e7744853a..b3638dc414 100644 --- a/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/model/CryptoDeviceInfo.kt +++ b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/model/CryptoDeviceInfo.kt @@ -70,6 +70,8 @@ data class CryptoDeviceInfo( keys?.let { map["keys"] = it } return map } + + fun shortDebugString() = "$userId|$deviceId" } internal fun CryptoDeviceInfo.toRest(): DeviceKeys { diff --git a/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/model/OlmSessionWrapper.kt b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/model/OlmSessionWrapper.kt index 15b92f105a..263cb3b036 100644 --- a/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/model/OlmSessionWrapper.kt +++ b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/model/OlmSessionWrapper.kt @@ -16,6 +16,7 @@ package org.matrix.android.sdk.internal.crypto.model +import kotlinx.coroutines.sync.Mutex import org.matrix.olm.OlmSession /** @@ -25,7 +26,10 @@ data class OlmSessionWrapper( // The associated olm session. val olmSession: OlmSession, // Timestamp at which the session last received a message. - var lastReceivedMessageTs: Long = 0) { + var lastReceivedMessageTs: Long = 0, + + val mutex: Mutex = Mutex() +) { /** * Notify that a message has been received on this olm session so that it updates `lastReceivedMessageTs` diff --git a/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/store/IMXCryptoStore.kt b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/store/IMXCryptoStore.kt index 96ea5c03fa..e662ff74e7 100644 --- a/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/store/IMXCryptoStore.kt +++ b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/store/IMXCryptoStore.kt @@ -54,7 +54,7 @@ internal interface IMXCryptoStore { /** * @return the olm account */ - fun getOlmAccount(): OlmAccount + fun doWithOlmAccount(block: (OlmAccount) -> T): T fun getOrCreateOlmAccount(): OlmAccount @@ -261,7 +261,7 @@ internal interface IMXCryptoStore { fun storeSession(olmSessionWrapper: OlmSessionWrapper, deviceKey: String) /** - * Retrieve the end-to-end session ids between the logged-in user and another + * Retrieve all end-to-end session ids between our own device and another * device. * * @param deviceKey the public key of the other device. @@ -270,7 +270,7 @@ internal interface IMXCryptoStore { fun getDeviceSessionIds(deviceKey: String): List? /** - * Retrieve an end-to-end session between the logged-in user and another + * Retrieve an end-to-end session between our own device and another * device. * * @param sessionId the session Id. diff --git a/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/store/db/RealmCryptoStore.kt b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/store/db/RealmCryptoStore.kt index a07827c033..585b3d2d25 100644 --- a/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/store/db/RealmCryptoStore.kt +++ b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/crypto/store/db/RealmCryptoStore.kt @@ -104,7 +104,6 @@ import timber.log.Timber import java.util.concurrent.Executors import java.util.concurrent.TimeUnit import javax.inject.Inject -import kotlin.collections.set @SessionScope internal class RealmCryptoStore @Inject constructor( @@ -124,12 +123,6 @@ internal class RealmCryptoStore @Inject constructor( // The olm account private var olmAccount: OlmAccount? = null - // Cache for OlmSession, to release them properly - private val olmSessionsToRelease = HashMap() - - // Cache for InboundGroupSession, to release them properly - private val inboundGroupSessionToRelease = HashMap() - private val newSessionListeners = ArrayList() override fun addNewSessionListener(listener: NewSessionListener) { @@ -213,16 +206,6 @@ internal class RealmCryptoStore @Inject constructor( monarchyWriteAsyncExecutor.awaitTermination(1, TimeUnit.MINUTES) } - olmSessionsToRelease.forEach { - it.value.olmSession.releaseSession() - } - olmSessionsToRelease.clear() - - inboundGroupSessionToRelease.forEach { - it.value.olmInboundGroupSession?.releaseSession() - } - inboundGroupSessionToRelease.clear() - olmAccount?.releaseAccount() realmLocker?.close() @@ -247,10 +230,18 @@ internal class RealmCryptoStore @Inject constructor( } } - override fun getOlmAccount(): OlmAccount { - return olmAccount!! + /** + * Olm account access should be synchronized + */ + override fun doWithOlmAccount(block: (OlmAccount) -> T): T { + return olmAccount!!.let { olmAccount -> + synchronized(olmAccount) { + block.invoke(olmAccount) + } + } } + @Synchronized override fun getOrCreateOlmAccount(): OlmAccount { doRealmTransaction(realmConfiguration) { val metaData = it.where().findFirst() @@ -680,13 +671,6 @@ internal class RealmCryptoStore @Inject constructor( if (sessionIdentifier != null) { val key = OlmSessionEntity.createPrimaryKey(sessionIdentifier, deviceKey) - // Release memory of previously known session, if it is not the same one - if (olmSessionsToRelease[key]?.olmSession != olmSessionWrapper.olmSession) { - olmSessionsToRelease[key]?.olmSession?.releaseSession() - } - - olmSessionsToRelease[key] = olmSessionWrapper - doRealmTransaction(realmConfiguration) { val realmOlmSession = OlmSessionEntity().apply { primaryKey = key @@ -703,23 +687,18 @@ internal class RealmCryptoStore @Inject constructor( override fun getDeviceSession(sessionId: String, deviceKey: String): OlmSessionWrapper? { val key = OlmSessionEntity.createPrimaryKey(sessionId, deviceKey) - - // If not in cache (or not found), try to read it from realm - if (olmSessionsToRelease[key] == null) { - doRealmQueryAndCopy(realmConfiguration) { - it.where() - .equalTo(OlmSessionEntityFields.PRIMARY_KEY, key) - .findFirst() - } - ?.let { - val olmSession = it.getOlmSession() - if (olmSession != null && it.sessionId != null) { - olmSessionsToRelease[key] = OlmSessionWrapper(olmSession, it.lastReceivedMessageTs) - } - } + return doRealmQueryAndCopy(realmConfiguration) { + it.where() + .equalTo(OlmSessionEntityFields.PRIMARY_KEY, key) + .findFirst() } - - return olmSessionsToRelease[key] + ?.let { + val olmSession = it.getOlmSession() + if (olmSession != null && it.sessionId != null) { + return@let OlmSessionWrapper(olmSession, it.lastReceivedMessageTs) + } + null + } } override fun getLastUsedSessionId(deviceKey: String): String? { @@ -761,13 +740,6 @@ internal class RealmCryptoStore @Inject constructor( if (sessionIdentifier != null) { val key = OlmInboundGroupSessionEntity.createPrimaryKey(sessionIdentifier, session.senderKey) - // Release memory of previously known session, if it is not the same one - if (inboundGroupSessionToRelease[key] != session) { - inboundGroupSessionToRelease[key]?.olmInboundGroupSession?.releaseSession() - } - - inboundGroupSessionToRelease[key] = session - val realmOlmInboundGroupSession = OlmInboundGroupSessionEntity().apply { primaryKey = key sessionId = sessionIdentifier @@ -784,20 +756,12 @@ internal class RealmCryptoStore @Inject constructor( override fun getInboundGroupSession(sessionId: String, senderKey: String): OlmInboundGroupSessionWrapper2? { val key = OlmInboundGroupSessionEntity.createPrimaryKey(sessionId, senderKey) - // If not in cache (or not found), try to read it from realm - if (inboundGroupSessionToRelease[key] == null) { - doWithRealm(realmConfiguration) { - it.where() - .equalTo(OlmInboundGroupSessionEntityFields.PRIMARY_KEY, key) - .findFirst() - ?.getInboundGroupSession() - } - ?.let { - inboundGroupSessionToRelease[key] = it - } + return doWithRealm(realmConfiguration) { + it.where() + .equalTo(OlmInboundGroupSessionEntityFields.PRIMARY_KEY, key) + .findFirst() + ?.getInboundGroupSession() } - - return inboundGroupSessionToRelease[key] } override fun getCurrentOutboundGroupSessionForRoom(roomId: String): OutboundGroupSessionWrapper? { @@ -853,10 +817,6 @@ internal class RealmCryptoStore @Inject constructor( override fun removeInboundGroupSession(sessionId: String, senderKey: String) { val key = OlmInboundGroupSessionEntity.createPrimaryKey(sessionId, senderKey) - // Release memory of previously known session - inboundGroupSessionToRelease[key]?.olmInboundGroupSession?.releaseSession() - inboundGroupSessionToRelease.remove(key) - doRealmTransaction(realmConfiguration) { it.where() .equalTo(OlmInboundGroupSessionEntityFields.PRIMARY_KEY, key) diff --git a/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/session/room/relation/threads/FetchThreadTimelineTask.kt b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/session/room/relation/threads/FetchThreadTimelineTask.kt index e0d501c515..cd06d47f05 100644 --- a/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/session/room/relation/threads/FetchThreadTimelineTask.kt +++ b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/session/room/relation/threads/FetchThreadTimelineTask.kt @@ -156,7 +156,7 @@ internal class DefaultFetchThreadTimelineTask @Inject constructor( * Invoke the event decryption mechanism for a specific event */ - private fun decryptIfNeeded(event: Event, roomId: String) { + private suspend fun decryptIfNeeded(event: Event, roomId: String) { try { // Event from sync does not have roomId, so add it to the event first val result = cryptoService.decryptEvent(event.copy(roomId = roomId), "") diff --git a/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/session/room/summary/RoomSummaryUpdater.kt b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/session/room/summary/RoomSummaryUpdater.kt index 449fe60dca..c9712c5721 100644 --- a/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/session/room/summary/RoomSummaryUpdater.kt +++ b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/session/room/summary/RoomSummaryUpdater.kt @@ -18,6 +18,7 @@ package org.matrix.android.sdk.internal.session.room.summary import io.realm.Realm import io.realm.kotlin.createObject +import kotlinx.coroutines.runBlocking import org.matrix.android.sdk.api.extensions.orFalse import org.matrix.android.sdk.api.extensions.tryOrNull import org.matrix.android.sdk.api.session.events.model.EventType @@ -165,7 +166,9 @@ internal class RoomSummaryUpdater @Inject constructor( Timber.v("Should decrypt ${latestPreviewableEvent.eventId}") // mmm i want to decrypt now or is it ok to do it async? tryOrNull { - eventDecryptor.decryptEvent(root.asDomain(), "") + runBlocking { + eventDecryptor.decryptEvent(root.asDomain(), "") + } } ?.let { root.setDecryptionResult(it) } } diff --git a/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/session/room/timeline/TimelineEventDecryptor.kt b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/session/room/timeline/TimelineEventDecryptor.kt index 49a8a8b55a..bacac58d84 100644 --- a/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/session/room/timeline/TimelineEventDecryptor.kt +++ b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/session/room/timeline/TimelineEventDecryptor.kt @@ -17,6 +17,7 @@ package org.matrix.android.sdk.internal.session.room.timeline import io.realm.Realm import io.realm.RealmConfiguration +import kotlinx.coroutines.runBlocking import org.matrix.android.sdk.api.session.crypto.CryptoService import org.matrix.android.sdk.api.session.crypto.MXCryptoError import org.matrix.android.sdk.api.session.events.model.Event @@ -99,7 +100,9 @@ internal class TimelineEventDecryptor @Inject constructor( } executor?.execute { Realm.getInstance(realmConfiguration).use { realm -> - processDecryptRequest(request, realm) + runBlocking { + processDecryptRequest(request, realm) + } } } } @@ -115,7 +118,7 @@ internal class TimelineEventDecryptor @Inject constructor( threadsAwarenessHandler.makeEventThreadAware(realm, event.roomId, decryptedEvent, eventEntity) } } - private fun processDecryptRequest(request: DecryptionRequest, realm: Realm) { + private suspend fun processDecryptRequest(request: DecryptionRequest, realm: Realm) { val event = request.event val timelineId = request.timelineId diff --git a/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/session/sync/SyncResponseHandler.kt b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/session/sync/SyncResponseHandler.kt index f93da9705d..1bbf54a788 100644 --- a/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/session/sync/SyncResponseHandler.kt +++ b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/session/sync/SyncResponseHandler.kt @@ -110,6 +110,7 @@ internal class SyncResponseHandler @Inject constructor( // Start one big transaction monarchy.awaitTransaction { realm -> + // IMPORTANT nothing should be suspend here as we are accessing the realm instance (thread local) measureTimeMillis { Timber.v("Handle rooms") reportSubtask(reporter, InitSyncStep.ImportingAccountRoom, 1, 0.7f) { diff --git a/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/session/sync/handler/CryptoSyncHandler.kt b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/session/sync/handler/CryptoSyncHandler.kt index f299d3effa..9ae7b82777 100644 --- a/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/session/sync/handler/CryptoSyncHandler.kt +++ b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/session/sync/handler/CryptoSyncHandler.kt @@ -38,7 +38,7 @@ private val loggerTag = LoggerTag("CryptoSyncHandler", LoggerTag.CRYPTO) internal class CryptoSyncHandler @Inject constructor(private val cryptoService: DefaultCryptoService, private val verificationService: DefaultVerificationService) { - fun handleToDevice(toDevice: ToDeviceSyncResponse, progressReporter: ProgressReporter? = null) { + suspend fun handleToDevice(toDevice: ToDeviceSyncResponse, progressReporter: ProgressReporter? = null) { val total = toDevice.events?.size ?: 0 toDevice.events?.forEachIndexed { index, event -> progressReporter?.reportProgress(index * 100F / total) @@ -66,7 +66,7 @@ internal class CryptoSyncHandler @Inject constructor(private val cryptoService: * @param timelineId the timeline identifier * @return true if the event has been decrypted */ - private fun decryptToDeviceEvent(event: Event, timelineId: String?): Boolean { + private suspend fun decryptToDeviceEvent(event: Event, timelineId: String?): Boolean { Timber.v("## CRYPTO | decryptToDeviceEvent") if (event.getClearType() == EventType.ENCRYPTED) { var result: MXEventDecryptionResult? = null @@ -80,6 +80,8 @@ internal class CryptoSyncHandler @Inject constructor(private val cryptoService: it.identityKey() == senderKey }?.deviceId ?: senderKey Timber.e("## CRYPTO | Failed to decrypt to device event from ${event.senderId}|$deviceId reason:<${event.mCryptoError ?: exception}>") + } catch (failure: Throwable) { + Timber.e(failure, "## CRYPTO | Failed to decrypt to device event from ${event.senderId}") } if (null != result) { @@ -91,7 +93,9 @@ internal class CryptoSyncHandler @Inject constructor(private val cryptoService: ) return true } else { - // should not happen + // Could happen for to device events + // None of the known session could decrypt the message + // In this case unwedging process might have been started (rate limited) Timber.e("## CRYPTO | ERROR NULL DECRYPTION RESULT from ${event.senderId}") } } diff --git a/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/session/sync/handler/room/RoomSyncHandler.kt b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/session/sync/handler/room/RoomSyncHandler.kt index 99e6521eb7..a5ad19bbf8 100644 --- a/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/session/sync/handler/room/RoomSyncHandler.kt +++ b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/session/sync/handler/room/RoomSyncHandler.kt @@ -19,6 +19,7 @@ package org.matrix.android.sdk.internal.session.sync.handler.room import dagger.Lazy import io.realm.Realm import io.realm.kotlin.createObject +import kotlinx.coroutines.runBlocking import org.matrix.android.sdk.api.session.crypto.MXCryptoError import org.matrix.android.sdk.api.session.events.model.Event import org.matrix.android.sdk.api.session.events.model.EventType @@ -379,7 +380,9 @@ internal class RoomSyncHandler @Inject constructor(private val readReceiptHandle val isInitialSync = insertType == EventInsertType.INITIAL_SYNC if (event.isEncrypted() && !isInitialSync) { - decryptIfNeeded(event, roomId) + runBlocking { + decryptIfNeeded(event, roomId) + } } var contentToInject: String? = null if (!isInitialSync) { @@ -455,7 +458,7 @@ internal class RoomSyncHandler @Inject constructor(private val readReceiptHandle return chunkEntity } - private fun decryptIfNeeded(event: Event, roomId: String) { + private suspend fun decryptIfNeeded(event: Event, roomId: String) { try { // Event from sync does not have roomId, so add it to the event first val result = cryptoService.decryptEvent(event.copy(roomId = roomId), "") diff --git a/vector/src/main/java/im/vector/app/features/notifications/NotifiableEventResolver.kt b/vector/src/main/java/im/vector/app/features/notifications/NotifiableEventResolver.kt index ec034173fc..3cdc9e8c76 100644 --- a/vector/src/main/java/im/vector/app/features/notifications/NotifiableEventResolver.kt +++ b/vector/src/main/java/im/vector/app/features/notifications/NotifiableEventResolver.kt @@ -190,7 +190,7 @@ class NotifiableEventResolver @Inject constructor( } } - private fun TimelineEvent.attemptToDecryptIfNeeded(session: Session) { + private suspend fun TimelineEvent.attemptToDecryptIfNeeded(session: Session) { if (root.isEncrypted() && root.mxDecryptionResult == null) { // TODO use a global event decryptor? attache to session and that listen to new sessionId? // for now decrypt sync