Merge pull request #145 from ouchadam/feature/synapse-cache-error

Fixing synapse caching breaking encrypted smoke test
This commit is contained in:
Adam Brown 2022-09-15 20:40:17 +01:00 committed by GitHub
commit f8976cc1bc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 81 additions and 41 deletions

View File

@ -25,7 +25,7 @@ jobs:
- name: Create pip requirements - name: Create pip requirements
run: | run: |
echo "matrix-synapse==v1.60.0" > requirements.txt echo "matrix-synapse" > requirements.txt
- name: Set up Python 3.8 - name: Set up Python 3.8
uses: actions/setup-python@v2 uses: actions/setup-python@v2

View File

@ -8,3 +8,7 @@ fun OlmAccount.readIdentityKeys(): Pair<Ed25519, Curve25519> {
val identityKeys = this.identityKeys() val identityKeys = this.identityKeys()
return Ed25519(identityKeys["ed25519"]!!) to Curve25519(identityKeys["curve25519"]!!) return Ed25519(identityKeys["ed25519"]!!) to Curve25519(identityKeys["curve25519"]!!)
} }
fun OlmAccount.oneTimeCurveKeys(): List<Pair<String, Curve25519>> {
return this.oneTimeKeys()["curve25519"]?.map { it.key to Curve25519(it.value) } ?: emptyList()
}

View File

@ -67,7 +67,7 @@ class OlmWrapper(
private suspend fun accountCrypto(deviceCredentials: DeviceCredentials): AccountCryptoSession? { private suspend fun accountCrypto(deviceCredentials: DeviceCredentials): AccountCryptoSession? {
return olmStore.read()?.let { olmAccount -> return olmStore.read()?.let { olmAccount ->
createAccountCryptoSession(deviceCredentials, olmAccount) createAccountCryptoSession(deviceCredentials, olmAccount, isNew = false)
} }
} }
@ -80,12 +80,12 @@ class OlmWrapper(
val olmAccount = this.olmAccount as OlmAccount val olmAccount = this.olmAccount as OlmAccount
olmAccount.generateOneTimeKeys(count) olmAccount.generateOneTimeKeys(count)
val oneTimeKeys = DeviceService.OneTimeKeys(olmAccount.oneTimeKeys()["curve25519"]!!.map { val oneTimeKeys = DeviceService.OneTimeKeys(olmAccount.oneTimeCurveKeys().map { (key, value) ->
DeviceService.OneTimeKeys.Key.SignedCurve( DeviceService.OneTimeKeys.Key.SignedCurve(
keyId = it.key, keyId = key,
value = it.value, value = value.value,
signature = DeviceService.OneTimeKeys.Key.SignedCurve.Ed25519Signature( signature = DeviceService.OneTimeKeys.Key.SignedCurve.Ed25519Signature(
value = it.value.toSignedJson(olmAccount), value = value.value.toSignedJson(olmAccount),
deviceId = credentials.deviceId, deviceId = credentials.deviceId,
userId = credentials.userId, userId = credentials.userId,
) )
@ -98,20 +98,21 @@ class OlmWrapper(
private suspend fun createAccountCrypto(deviceCredentials: DeviceCredentials, action: suspend (AccountCryptoSession) -> Unit): AccountCryptoSession { private suspend fun createAccountCrypto(deviceCredentials: DeviceCredentials, action: suspend (AccountCryptoSession) -> Unit): AccountCryptoSession {
val olmAccount = OlmAccount() val olmAccount = OlmAccount()
return createAccountCryptoSession(deviceCredentials, olmAccount).also { return createAccountCryptoSession(deviceCredentials, olmAccount, isNew = true).also {
action(it) action(it)
olmStore.persist(olmAccount) olmStore.persist(olmAccount)
} }
} }
private fun createAccountCryptoSession(credentials: DeviceCredentials, olmAccount: OlmAccount): AccountCryptoSession { private fun createAccountCryptoSession(credentials: DeviceCredentials, olmAccount: OlmAccount, isNew: Boolean): AccountCryptoSession {
val (identityKey, senderKey) = olmAccount.readIdentityKeys() val (identityKey, senderKey) = olmAccount.readIdentityKeys()
return AccountCryptoSession( return AccountCryptoSession(
fingerprint = identityKey, fingerprint = identityKey,
senderKey = senderKey, senderKey = senderKey,
deviceKeys = deviceKeyFactory.create(credentials.userId, credentials.deviceId, identityKey, senderKey, olmAccount), deviceKeys = deviceKeyFactory.create(credentials.userId, credentials.deviceId, identityKey, senderKey, olmAccount),
olmAccount = olmAccount, olmAccount = olmAccount,
maxKeys = olmAccount.maxOneTimeKeys().toInt() maxKeys = olmAccount.maxOneTimeKeys().toInt(),
hasKeys = !isNew,
) )
} }
@ -136,6 +137,7 @@ class OlmWrapper(
singletonFlows.update("room-${roomId.value}", rotatedSession) singletonFlows.update("room-${roomId.value}", rotatedSession)
} }
} }
else -> this else -> this
} }
} }
@ -277,10 +279,12 @@ class OlmWrapper(
} }
} }
} }
OlmMessage.MESSAGE_TYPE_MESSAGE -> { OlmMessage.MESSAGE_TYPE_MESSAGE -> {
logger.crypto("decrypting olm message type") logger.crypto("decrypting olm message type")
session.decryptMessage(olmMessage)?.let { JsonString(it) } session.decryptMessage(olmMessage)?.let { JsonString(it) }
} }
else -> throw IllegalArgumentException("Unknown message type: $type") else -> throw IllegalArgumentException("Unknown message type: $type")
} }
}.onFailure { }.onFailure {
@ -297,7 +301,7 @@ class OlmWrapper(
} }
private suspend fun AccountCryptoSession.updateAccountInstance(olmAccount: OlmAccount) { private suspend fun AccountCryptoSession.updateAccountInstance(olmAccount: OlmAccount) {
singletonFlows.update("account-crypto", this.copy(olmAccount = olmAccount)) singletonFlows.update("account-crypto", this.copy(olmAccount = olmAccount, hasKeys = true))
olmStore.persist(olmAccount) olmStore.persist(olmAccount)
} }

View File

@ -72,6 +72,7 @@ interface Olm {
val fingerprint: Ed25519, val fingerprint: Ed25519,
val senderKey: Curve25519, val senderKey: Curve25519,
val deviceKeys: DeviceKeys, val deviceKeys: DeviceKeys,
val hasKeys: Boolean,
val maxKeys: Int, val maxKeys: Int,
val olmAccount: Any, val olmAccount: Any,
) )

View File

@ -15,16 +15,28 @@ internal class MaybeCreateAndUploadOneTimeKeysUseCaseImpl(
private val credentialsStore: CredentialsStore, private val credentialsStore: CredentialsStore,
private val deviceService: DeviceService, private val deviceService: DeviceService,
private val logger: MatrixLogger, private val logger: MatrixLogger,
): MaybeCreateAndUploadOneTimeKeysUseCase { ) : MaybeCreateAndUploadOneTimeKeysUseCase {
override suspend fun invoke(currentServerKeyCount: ServerKeyCount) { override suspend fun invoke(currentServerKeyCount: ServerKeyCount) {
val ensureCryptoAccount = fetchAccountCryptoUseCase.invoke() val cryptoAccount = fetchAccountCryptoUseCase.invoke()
val keysDiff = (ensureCryptoAccount.maxKeys / 2) - currentServerKeyCount.value when {
if (keysDiff > 0) { currentServerKeyCount.value == 0 && cryptoAccount.hasKeys -> {
logger.crypto("current otk: $currentServerKeyCount, creating: $keysDiff") logger.crypto("Server has no keys but a crypto instance exists, waiting for next update")
ensureCryptoAccount.createAndUploadOneTimeKeys(countToCreate = keysDiff + (ensureCryptoAccount.maxKeys / 4)) }
} else {
logger.crypto("current otk: $currentServerKeyCount, not creating new keys") else -> {
val keysDiff = (cryptoAccount.maxKeys / 2) - currentServerKeyCount.value
when {
keysDiff > 0 -> {
logger.crypto("current otk: $currentServerKeyCount, creating: $keysDiff")
cryptoAccount.createAndUploadOneTimeKeys(countToCreate = keysDiff + (cryptoAccount.maxKeys / 4))
}
else -> {
logger.crypto("current otk: $currentServerKeyCount, not creating new keys")
}
}
}
} }
} }

View File

@ -22,12 +22,11 @@ class MaybeCreateAndUploadOneTimeKeysUseCaseTest {
private val fakeDeviceService = FakeDeviceService() private val fakeDeviceService = FakeDeviceService()
private val fakeOlm = FakeOlm() private val fakeOlm = FakeOlm()
private val fakeCredentialsStore = FakeCredentialsStore().also { private val fakeCredentialsStore = FakeCredentialsStore().also { it.givenCredentials().returns(A_USER_CREDENTIALS) }
it.givenCredentials().returns(A_USER_CREDENTIALS) private val fakeFetchAccountCryptoUseCase = FakeFetchAccountCryptoUseCase()
}
private val maybeCreateAndUploadOneTimeKeysUseCase = MaybeCreateAndUploadOneTimeKeysUseCaseImpl( private val maybeCreateAndUploadOneTimeKeysUseCase = MaybeCreateAndUploadOneTimeKeysUseCaseImpl(
FakeFetchAccountCryptoUseCase().also { it.givenFetch().returns(AN_ACCOUNT_CRYPTO_SESSION) }, fakeFetchAccountCryptoUseCase.also { it.givenFetch().returns(AN_ACCOUNT_CRYPTO_SESSION) },
fakeOlm, fakeOlm,
fakeCredentialsStore, fakeCredentialsStore,
fakeDeviceService, fakeDeviceService,
@ -43,6 +42,16 @@ class MaybeCreateAndUploadOneTimeKeysUseCaseTest {
fakeDeviceService.verifyDidntUploadOneTimeKeys() fakeDeviceService.verifyDidntUploadOneTimeKeys()
} }
@Test
fun `given account has keys and server count is 0 then does nothing`() = runTest {
fakeFetchAccountCryptoUseCase.givenFetch().returns(AN_ACCOUNT_CRYPTO_SESSION.copy(hasKeys = true))
val zeroServiceKeys = ServerKeyCount(0)
maybeCreateAndUploadOneTimeKeysUseCase.invoke(zeroServiceKeys)
fakeDeviceService.verifyDidntUploadOneTimeKeys()
}
@Test @Test
fun `given 0 current keys than generates and uploads 75 percent of the max key capacity`() = runTest { fun `given 0 current keys than generates and uploads 75 percent of the max key capacity`() = runTest {
fakeDeviceService.expect { it.uploadOneTimeKeys(GENERATED_ONE_TIME_KEYS) } fakeDeviceService.expect { it.uploadOneTimeKeys(GENERATED_ONE_TIME_KEYS) }

View File

@ -10,8 +10,9 @@ fun anAccountCryptoSession(
senderKey: Curve25519 = aCurve25519(), senderKey: Curve25519 = aCurve25519(),
deviceKeys: DeviceKeys = aDeviceKeys(), deviceKeys: DeviceKeys = aDeviceKeys(),
maxKeys: Int = 5, maxKeys: Int = 5,
hasKeys: Boolean = false,
olmAccount: Any = mockk(), olmAccount: Any = mockk(),
) = Olm.AccountCryptoSession(fingerprint, senderKey, deviceKeys, maxKeys, olmAccount) ) = Olm.AccountCryptoSession(fingerprint, senderKey, deviceKeys, hasKeys, maxKeys, olmAccount)
fun aRoomCryptoSession( fun aRoomCryptoSession(
creationTimestampUtc: Long = 0L, creationTimestampUtc: Long = 0L,

View File

@ -27,6 +27,7 @@ internal class SyncSideEffects(
response.deviceLists?.changed?.ifEmpty { null }?.let { response.deviceLists?.changed?.ifEmpty { null }?.let {
notifyDevicesUpdated.notifyChanges(it, requestToken) notifyDevicesUpdated.notifyChanges(it, requestToken)
} }
oneTimeKeyProducer.onServerKeyCount(response.oneTimeKeysCount["signed_curve25519"] ?: ServerKeyCount(0)) oneTimeKeyProducer.onServerKeyCount(response.oneTimeKeysCount["signed_curve25519"] ?: ServerKeyCount(0))
val decryptedToDeviceEvents = decryptedToDeviceEvents(response) val decryptedToDeviceEvents = decryptedToDeviceEvents(response)

View File

@ -20,10 +20,7 @@ import org.junit.jupiter.api.MethodOrderer
import org.junit.jupiter.api.Order import org.junit.jupiter.api.Order
import org.junit.jupiter.api.Test import org.junit.jupiter.api.Test
import org.junit.jupiter.api.TestMethodOrder import org.junit.jupiter.api.TestMethodOrder
import test.MatrixTestScope import test.*
import test.TestMatrix
import test.flowTest
import test.restoreLoginAndInitialSync
import java.nio.file.Paths import java.nio.file.Paths
import java.util.* import java.util.*
@ -35,8 +32,8 @@ class SmokeTest {
@Test @Test
@Order(1) @Order(1)
fun `can register accounts`() = runTest { fun `can register accounts`() = runTest {
SharedState._alice = createAndRegisterAccount() SharedState._alice = createAndRegisterAccount("alice")
SharedState._bob = createAndRegisterAccount() SharedState._bob = createAndRegisterAccount("bob")
} }
@Test @Test
@ -94,7 +91,7 @@ class SmokeTest {
@Test @Test
fun `can import E2E room keys file`() = runTest { fun `can import E2E room keys file`() = runTest {
val ignoredUser = TestUser("ignored", RoomMember(UserId("ignored"), null, null), "ignored") val ignoredUser = TestUser("ignored", RoomMember(UserId("ignored"), null, null), "ignored", "ignored")
val cryptoService = TestMatrix(ignoredUser, includeLogging = true).client.cryptoService() val cryptoService = TestMatrix(ignoredUser, includeLogging = true).client.cryptoService()
val stream = loadResourceStream("element-keys.txt") val stream = loadResourceStream("element-keys.txt")
@ -133,10 +130,10 @@ class SmokeTest {
} }
} }
private suspend fun createAndRegisterAccount(): TestUser { private suspend fun createAndRegisterAccount(testUsername: String): TestUser {
val aUserName = "${UUID.randomUUID()}" val aUserName = "${UUID.randomUUID()}"
val userId = UserId("@$aUserName:localhost:8080") val userId = UserId("@$aUserName:localhost:8080")
val aUser = TestUser("aaaa11111zzzz", RoomMember(userId, aUserName, null), HTTPS_TEST_SERVER_URL) val aUser = TestUser("aaaa11111zzzz", RoomMember(userId, aUserName, null), HTTPS_TEST_SERVER_URL, testUsername)
val result = TestMatrix(aUser, includeLogging = true, includeHttpLogging = true) val result = TestMatrix(aUser, includeLogging = true, includeHttpLogging = true)
.client .client
@ -167,26 +164,35 @@ private suspend fun login(user: TestUser) {
} }
object SharedState { object SharedState {
val alice: TestUser val alice: TestUser
get() = _alice!! get() = _alice!!
var _alice: TestUser? = null var _alice: TestUser? = null
set(value) {
field = value!!
TestUsers.users.add(value)
}
val bob: TestUser val bob: TestUser
get() = _bob!! get() = _bob!!
var _bob: TestUser? = null var _bob: TestUser? = null
set(value) {
field = value!!
TestUsers.users.add(value)
}
val sharedRoom: RoomId val sharedRoom: RoomId
get() = _sharedRoom!! get() = _sharedRoom!!
var _sharedRoom: RoomId? = null var _sharedRoom: RoomId? = null
} }
data class TestUser(val password: String, val roomMember: RoomMember, val homeServer: String) data class TestUser(val password: String, val roomMember: RoomMember, val homeServer: String, val testName: String)
data class TestMessage(val content: String, val author: RoomMember) data class TestMessage(val content: String, val author: RoomMember)
fun String.from(roomMember: RoomMember) = TestMessage("$this - ${UUID.randomUUID()}", roomMember) fun String.from(roomMember: RoomMember) = TestMessage("$this - ${UUID.randomUUID()}", roomMember)
fun testAfterInitialSync(block: suspend MatrixTestScope.(TestMatrix, TestMatrix) -> Unit) { fun testAfterInitialSync(block: suspend MatrixTestScope.(TestMatrix, TestMatrix) -> Unit) {
restoreLoginAndInitialSync(TestMatrix(SharedState.alice, includeLogging = false), TestMatrix(SharedState.bob, includeLogging = false), block) restoreLoginAndInitialSync(TestMatrix(SharedState.alice, includeLogging = true), TestMatrix(SharedState.bob, includeLogging = false), block)
} }
private fun Flow<Verification.State>.automaticVerification(testMatrix: TestMatrix) = this.onEach { private fun Flow<Verification.State>.automaticVerification(testMatrix: TestMatrix) = this.onEach {

View File

@ -5,17 +5,13 @@ package test
import TestMessage import TestMessage
import TestUser import TestUser
import app.dapk.st.core.extensions.ifNull import app.dapk.st.core.extensions.ifNull
import app.dapk.st.matrix.common.MxUrl
import app.dapk.st.matrix.common.RoomId import app.dapk.st.matrix.common.RoomId
import app.dapk.st.matrix.common.RoomMember import app.dapk.st.matrix.common.RoomMember
import app.dapk.st.matrix.common.convertMxUrToUrl
import app.dapk.st.matrix.http.MatrixHttpClient
import app.dapk.st.matrix.message.MessageService import app.dapk.st.matrix.message.MessageService
import app.dapk.st.matrix.message.messageService import app.dapk.st.matrix.message.messageService
import app.dapk.st.matrix.sync.RoomEvent import app.dapk.st.matrix.sync.RoomEvent
import app.dapk.st.matrix.sync.syncService import app.dapk.st.matrix.sync.syncService
import io.ktor.client.* import io.ktor.client.*
import io.ktor.client.call.*
import io.ktor.client.request.* import io.ktor.client.request.*
import io.ktor.client.statement.* import io.ktor.client.statement.*
import io.ktor.util.cio.* import io.ktor.util.cio.*
@ -28,7 +24,6 @@ import org.amshove.kluent.fail
import org.amshove.kluent.shouldBeEqualTo import org.amshove.kluent.shouldBeEqualTo
import java.io.File import java.io.File
import java.math.BigInteger import java.math.BigInteger
import java.net.URL
import java.security.MessageDigest import java.security.MessageDigest
import java.util.* import java.util.*

View File

@ -46,6 +46,12 @@ import java.io.File
import java.time.Clock import java.time.Clock
import javax.imageio.ImageIO import javax.imageio.ImageIO
object TestUsers {
val users = mutableSetOf<TestUser>()
}
class TestMatrix( class TestMatrix(
private val user: TestUser, private val user: TestUser,
temporaryDatabase: Boolean = false, temporaryDatabase: Boolean = false,
@ -53,10 +59,11 @@ class TestMatrix(
includeLogging: Boolean = false, includeLogging: Boolean = false,
) { ) {
private val errorTracker = PrintingErrorTracking(prefix = user.roomMember.id.value.split(":")[0]) private val errorTracker = PrintingErrorTracking(prefix = user.testName)
private val logger: MatrixLogger = { tag, message -> private val logger: MatrixLogger = { tag, message ->
if (includeLogging) { if (includeLogging) {
println("${user.roomMember.id.value.split(":")[0]} $tag $message") val messageWithIdReplaceByName = TestUsers.users.fold(message) { acc, user -> acc.replace(user.roomMember.id.value, "*${user.testName}") }
println("${user.testName} $tag $messageWithIdReplaceByName")
} }
} }