fixing keys being rotated out when the server returns a 0 key count but we have local keys

This commit is contained in:
Adam Brown 2022-09-15 20:27:25 +01:00
parent 1478f4a7ff
commit f4007bff76
8 changed files with 54 additions and 27 deletions

View File

@ -8,3 +8,7 @@ fun OlmAccount.readIdentityKeys(): Pair<Ed25519, Curve25519> {
val identityKeys = this.identityKeys()
return Ed25519(identityKeys["ed25519"]!!) to Curve25519(identityKeys["curve25519"]!!)
}
fun OlmAccount.oneTimeCurveKeys(): List<Pair<String, Curve25519>> {
return this.oneTimeKeys()["curve25519"]?.map { it.key to Curve25519(it.value) } ?: emptyList()
}

View File

@ -67,7 +67,7 @@ class OlmWrapper(
private suspend fun accountCrypto(deviceCredentials: DeviceCredentials): AccountCryptoSession? {
return olmStore.read()?.let { olmAccount ->
createAccountCryptoSession(deviceCredentials, olmAccount)
createAccountCryptoSession(deviceCredentials, olmAccount, isNew = false)
}
}
@ -80,12 +80,12 @@ class OlmWrapper(
val olmAccount = this.olmAccount as OlmAccount
olmAccount.generateOneTimeKeys(count)
val oneTimeKeys = DeviceService.OneTimeKeys(olmAccount.oneTimeKeys()["curve25519"]!!.map {
val oneTimeKeys = DeviceService.OneTimeKeys(olmAccount.oneTimeCurveKeys().map { (key, value) ->
DeviceService.OneTimeKeys.Key.SignedCurve(
keyId = it.key,
value = it.value,
keyId = key,
value = value.value,
signature = DeviceService.OneTimeKeys.Key.SignedCurve.Ed25519Signature(
value = it.value.toSignedJson(olmAccount),
value = value.value.toSignedJson(olmAccount),
deviceId = credentials.deviceId,
userId = credentials.userId,
)
@ -98,20 +98,21 @@ class OlmWrapper(
private suspend fun createAccountCrypto(deviceCredentials: DeviceCredentials, action: suspend (AccountCryptoSession) -> Unit): AccountCryptoSession {
val olmAccount = OlmAccount()
return createAccountCryptoSession(deviceCredentials, olmAccount).also {
return createAccountCryptoSession(deviceCredentials, olmAccount, isNew = true).also {
action(it)
olmStore.persist(olmAccount)
}
}
private fun createAccountCryptoSession(credentials: DeviceCredentials, olmAccount: OlmAccount): AccountCryptoSession {
private fun createAccountCryptoSession(credentials: DeviceCredentials, olmAccount: OlmAccount, isNew: Boolean): AccountCryptoSession {
val (identityKey, senderKey) = olmAccount.readIdentityKeys()
return AccountCryptoSession(
fingerprint = identityKey,
senderKey = senderKey,
deviceKeys = deviceKeyFactory.create(credentials.userId, credentials.deviceId, identityKey, senderKey, olmAccount),
olmAccount = olmAccount,
maxKeys = olmAccount.maxOneTimeKeys().toInt()
maxKeys = olmAccount.maxOneTimeKeys().toInt(),
hasKeys = !isNew,
)
}
@ -136,6 +137,7 @@ class OlmWrapper(
singletonFlows.update("room-${roomId.value}", rotatedSession)
}
}
else -> this
}
}
@ -277,10 +279,12 @@ class OlmWrapper(
}
}
}
OlmMessage.MESSAGE_TYPE_MESSAGE -> {
logger.crypto("decrypting olm message type")
session.decryptMessage(olmMessage)?.let { JsonString(it) }
}
else -> throw IllegalArgumentException("Unknown message type: $type")
}
}.onFailure {
@ -297,7 +301,7 @@ class OlmWrapper(
}
private suspend fun AccountCryptoSession.updateAccountInstance(olmAccount: OlmAccount) {
singletonFlows.update("account-crypto", this.copy(olmAccount = olmAccount))
singletonFlows.update("account-crypto", this.copy(olmAccount = olmAccount, hasKeys = true))
olmStore.persist(olmAccount)
}

View File

@ -72,6 +72,7 @@ interface Olm {
val fingerprint: Ed25519,
val senderKey: Curve25519,
val deviceKeys: DeviceKeys,
val hasKeys: Boolean,
val maxKeys: Int,
val olmAccount: Any,
)

View File

@ -15,16 +15,28 @@ internal class MaybeCreateAndUploadOneTimeKeysUseCaseImpl(
private val credentialsStore: CredentialsStore,
private val deviceService: DeviceService,
private val logger: MatrixLogger,
): MaybeCreateAndUploadOneTimeKeysUseCase {
) : MaybeCreateAndUploadOneTimeKeysUseCase {
override suspend fun invoke(currentServerKeyCount: ServerKeyCount) {
val ensureCryptoAccount = fetchAccountCryptoUseCase.invoke()
val keysDiff = (ensureCryptoAccount.maxKeys / 2) - currentServerKeyCount.value
if (keysDiff > 0) {
logger.crypto("current otk: $currentServerKeyCount, creating: $keysDiff")
ensureCryptoAccount.createAndUploadOneTimeKeys(countToCreate = keysDiff + (ensureCryptoAccount.maxKeys / 4))
} else {
logger.crypto("current otk: $currentServerKeyCount, not creating new keys")
val cryptoAccount = fetchAccountCryptoUseCase.invoke()
when {
currentServerKeyCount.value == 0 && cryptoAccount.hasKeys -> {
logger.crypto("Server has no keys but a crypto instance exists, waiting for next update")
}
else -> {
val keysDiff = (cryptoAccount.maxKeys / 2) - currentServerKeyCount.value
when {
keysDiff > 0 -> {
logger.crypto("current otk: $currentServerKeyCount, creating: $keysDiff")
cryptoAccount.createAndUploadOneTimeKeys(countToCreate = keysDiff + (cryptoAccount.maxKeys / 4))
}
else -> {
logger.crypto("current otk: $currentServerKeyCount, not creating new keys")
}
}
}
}
}

View File

@ -22,12 +22,11 @@ class MaybeCreateAndUploadOneTimeKeysUseCaseTest {
private val fakeDeviceService = FakeDeviceService()
private val fakeOlm = FakeOlm()
private val fakeCredentialsStore = FakeCredentialsStore().also {
it.givenCredentials().returns(A_USER_CREDENTIALS)
}
private val fakeCredentialsStore = FakeCredentialsStore().also { it.givenCredentials().returns(A_USER_CREDENTIALS) }
private val fakeFetchAccountCryptoUseCase = FakeFetchAccountCryptoUseCase()
private val maybeCreateAndUploadOneTimeKeysUseCase = MaybeCreateAndUploadOneTimeKeysUseCaseImpl(
FakeFetchAccountCryptoUseCase().also { it.givenFetch().returns(AN_ACCOUNT_CRYPTO_SESSION) },
fakeFetchAccountCryptoUseCase.also { it.givenFetch().returns(AN_ACCOUNT_CRYPTO_SESSION) },
fakeOlm,
fakeCredentialsStore,
fakeDeviceService,
@ -43,6 +42,16 @@ class MaybeCreateAndUploadOneTimeKeysUseCaseTest {
fakeDeviceService.verifyDidntUploadOneTimeKeys()
}
@Test
fun `given account has keys and server count is 0 then does nothing`() = runTest {
fakeFetchAccountCryptoUseCase.givenFetch().returns(AN_ACCOUNT_CRYPTO_SESSION.copy(hasKeys = true))
val zeroServiceKeys = ServerKeyCount(0)
maybeCreateAndUploadOneTimeKeysUseCase.invoke(zeroServiceKeys)
fakeDeviceService.verifyDidntUploadOneTimeKeys()
}
@Test
fun `given 0 current keys than generates and uploads 75 percent of the max key capacity`() = runTest {
fakeDeviceService.expect { it.uploadOneTimeKeys(GENERATED_ONE_TIME_KEYS) }

View File

@ -10,8 +10,9 @@ fun anAccountCryptoSession(
senderKey: Curve25519 = aCurve25519(),
deviceKeys: DeviceKeys = aDeviceKeys(),
maxKeys: Int = 5,
hasKeys: Boolean = false,
olmAccount: Any = mockk(),
) = Olm.AccountCryptoSession(fingerprint, senderKey, deviceKeys, maxKeys, olmAccount)
) = Olm.AccountCryptoSession(fingerprint, senderKey, deviceKeys, hasKeys, maxKeys, olmAccount)
fun aRoomCryptoSession(
creationTimestampUtc: Long = 0L,

View File

@ -27,6 +27,7 @@ internal class SyncSideEffects(
response.deviceLists?.changed?.ifEmpty { null }?.let {
notifyDevicesUpdated.notifyChanges(it, requestToken)
}
oneTimeKeyProducer.onServerKeyCount(response.oneTimeKeysCount["signed_curve25519"] ?: ServerKeyCount(0))
val decryptedToDeviceEvents = decryptedToDeviceEvents(response)

View File

@ -5,17 +5,13 @@ package test
import TestMessage
import TestUser
import app.dapk.st.core.extensions.ifNull
import app.dapk.st.matrix.common.MxUrl
import app.dapk.st.matrix.common.RoomId
import app.dapk.st.matrix.common.RoomMember
import app.dapk.st.matrix.common.convertMxUrToUrl
import app.dapk.st.matrix.http.MatrixHttpClient
import app.dapk.st.matrix.message.MessageService
import app.dapk.st.matrix.message.messageService
import app.dapk.st.matrix.sync.RoomEvent
import app.dapk.st.matrix.sync.syncService
import io.ktor.client.*
import io.ktor.client.call.*
import io.ktor.client.request.*
import io.ktor.client.statement.*
import io.ktor.util.cio.*
@ -28,7 +24,6 @@ import org.amshove.kluent.fail
import org.amshove.kluent.shouldBeEqualTo
import java.io.File
import java.math.BigInteger
import java.net.URL
import java.security.MessageDigest
import java.util.*