mirror of
https://github.com/ouchadam/small-talk.git
synced 2025-02-01 20:16:44 +01:00
fixing keys being rotated out when the server returns a 0 key count but we have local keys
This commit is contained in:
parent
1478f4a7ff
commit
f4007bff76
@ -8,3 +8,7 @@ fun OlmAccount.readIdentityKeys(): Pair<Ed25519, Curve25519> {
|
||||
val identityKeys = this.identityKeys()
|
||||
return Ed25519(identityKeys["ed25519"]!!) to Curve25519(identityKeys["curve25519"]!!)
|
||||
}
|
||||
|
||||
fun OlmAccount.oneTimeCurveKeys(): List<Pair<String, Curve25519>> {
|
||||
return this.oneTimeKeys()["curve25519"]?.map { it.key to Curve25519(it.value) } ?: emptyList()
|
||||
}
|
@ -67,7 +67,7 @@ class OlmWrapper(
|
||||
|
||||
private suspend fun accountCrypto(deviceCredentials: DeviceCredentials): AccountCryptoSession? {
|
||||
return olmStore.read()?.let { olmAccount ->
|
||||
createAccountCryptoSession(deviceCredentials, olmAccount)
|
||||
createAccountCryptoSession(deviceCredentials, olmAccount, isNew = false)
|
||||
}
|
||||
}
|
||||
|
||||
@ -80,12 +80,12 @@ class OlmWrapper(
|
||||
val olmAccount = this.olmAccount as OlmAccount
|
||||
olmAccount.generateOneTimeKeys(count)
|
||||
|
||||
val oneTimeKeys = DeviceService.OneTimeKeys(olmAccount.oneTimeKeys()["curve25519"]!!.map {
|
||||
val oneTimeKeys = DeviceService.OneTimeKeys(olmAccount.oneTimeCurveKeys().map { (key, value) ->
|
||||
DeviceService.OneTimeKeys.Key.SignedCurve(
|
||||
keyId = it.key,
|
||||
value = it.value,
|
||||
keyId = key,
|
||||
value = value.value,
|
||||
signature = DeviceService.OneTimeKeys.Key.SignedCurve.Ed25519Signature(
|
||||
value = it.value.toSignedJson(olmAccount),
|
||||
value = value.value.toSignedJson(olmAccount),
|
||||
deviceId = credentials.deviceId,
|
||||
userId = credentials.userId,
|
||||
)
|
||||
@ -98,20 +98,21 @@ class OlmWrapper(
|
||||
|
||||
private suspend fun createAccountCrypto(deviceCredentials: DeviceCredentials, action: suspend (AccountCryptoSession) -> Unit): AccountCryptoSession {
|
||||
val olmAccount = OlmAccount()
|
||||
return createAccountCryptoSession(deviceCredentials, olmAccount).also {
|
||||
return createAccountCryptoSession(deviceCredentials, olmAccount, isNew = true).also {
|
||||
action(it)
|
||||
olmStore.persist(olmAccount)
|
||||
}
|
||||
}
|
||||
|
||||
private fun createAccountCryptoSession(credentials: DeviceCredentials, olmAccount: OlmAccount): AccountCryptoSession {
|
||||
private fun createAccountCryptoSession(credentials: DeviceCredentials, olmAccount: OlmAccount, isNew: Boolean): AccountCryptoSession {
|
||||
val (identityKey, senderKey) = olmAccount.readIdentityKeys()
|
||||
return AccountCryptoSession(
|
||||
fingerprint = identityKey,
|
||||
senderKey = senderKey,
|
||||
deviceKeys = deviceKeyFactory.create(credentials.userId, credentials.deviceId, identityKey, senderKey, olmAccount),
|
||||
olmAccount = olmAccount,
|
||||
maxKeys = olmAccount.maxOneTimeKeys().toInt()
|
||||
maxKeys = olmAccount.maxOneTimeKeys().toInt(),
|
||||
hasKeys = !isNew,
|
||||
)
|
||||
}
|
||||
|
||||
@ -136,6 +137,7 @@ class OlmWrapper(
|
||||
singletonFlows.update("room-${roomId.value}", rotatedSession)
|
||||
}
|
||||
}
|
||||
|
||||
else -> this
|
||||
}
|
||||
}
|
||||
@ -277,10 +279,12 @@ class OlmWrapper(
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
OlmMessage.MESSAGE_TYPE_MESSAGE -> {
|
||||
logger.crypto("decrypting olm message type")
|
||||
session.decryptMessage(olmMessage)?.let { JsonString(it) }
|
||||
}
|
||||
|
||||
else -> throw IllegalArgumentException("Unknown message type: $type")
|
||||
}
|
||||
}.onFailure {
|
||||
@ -297,7 +301,7 @@ class OlmWrapper(
|
||||
}
|
||||
|
||||
private suspend fun AccountCryptoSession.updateAccountInstance(olmAccount: OlmAccount) {
|
||||
singletonFlows.update("account-crypto", this.copy(olmAccount = olmAccount))
|
||||
singletonFlows.update("account-crypto", this.copy(olmAccount = olmAccount, hasKeys = true))
|
||||
olmStore.persist(olmAccount)
|
||||
}
|
||||
|
||||
|
@ -72,6 +72,7 @@ interface Olm {
|
||||
val fingerprint: Ed25519,
|
||||
val senderKey: Curve25519,
|
||||
val deviceKeys: DeviceKeys,
|
||||
val hasKeys: Boolean,
|
||||
val maxKeys: Int,
|
||||
val olmAccount: Any,
|
||||
)
|
||||
|
@ -15,16 +15,28 @@ internal class MaybeCreateAndUploadOneTimeKeysUseCaseImpl(
|
||||
private val credentialsStore: CredentialsStore,
|
||||
private val deviceService: DeviceService,
|
||||
private val logger: MatrixLogger,
|
||||
): MaybeCreateAndUploadOneTimeKeysUseCase {
|
||||
) : MaybeCreateAndUploadOneTimeKeysUseCase {
|
||||
|
||||
override suspend fun invoke(currentServerKeyCount: ServerKeyCount) {
|
||||
val ensureCryptoAccount = fetchAccountCryptoUseCase.invoke()
|
||||
val keysDiff = (ensureCryptoAccount.maxKeys / 2) - currentServerKeyCount.value
|
||||
if (keysDiff > 0) {
|
||||
logger.crypto("current otk: $currentServerKeyCount, creating: $keysDiff")
|
||||
ensureCryptoAccount.createAndUploadOneTimeKeys(countToCreate = keysDiff + (ensureCryptoAccount.maxKeys / 4))
|
||||
} else {
|
||||
logger.crypto("current otk: $currentServerKeyCount, not creating new keys")
|
||||
val cryptoAccount = fetchAccountCryptoUseCase.invoke()
|
||||
when {
|
||||
currentServerKeyCount.value == 0 && cryptoAccount.hasKeys -> {
|
||||
logger.crypto("Server has no keys but a crypto instance exists, waiting for next update")
|
||||
}
|
||||
|
||||
else -> {
|
||||
val keysDiff = (cryptoAccount.maxKeys / 2) - currentServerKeyCount.value
|
||||
when {
|
||||
keysDiff > 0 -> {
|
||||
logger.crypto("current otk: $currentServerKeyCount, creating: $keysDiff")
|
||||
cryptoAccount.createAndUploadOneTimeKeys(countToCreate = keysDiff + (cryptoAccount.maxKeys / 4))
|
||||
}
|
||||
|
||||
else -> {
|
||||
logger.crypto("current otk: $currentServerKeyCount, not creating new keys")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -22,12 +22,11 @@ class MaybeCreateAndUploadOneTimeKeysUseCaseTest {
|
||||
|
||||
private val fakeDeviceService = FakeDeviceService()
|
||||
private val fakeOlm = FakeOlm()
|
||||
private val fakeCredentialsStore = FakeCredentialsStore().also {
|
||||
it.givenCredentials().returns(A_USER_CREDENTIALS)
|
||||
}
|
||||
private val fakeCredentialsStore = FakeCredentialsStore().also { it.givenCredentials().returns(A_USER_CREDENTIALS) }
|
||||
private val fakeFetchAccountCryptoUseCase = FakeFetchAccountCryptoUseCase()
|
||||
|
||||
private val maybeCreateAndUploadOneTimeKeysUseCase = MaybeCreateAndUploadOneTimeKeysUseCaseImpl(
|
||||
FakeFetchAccountCryptoUseCase().also { it.givenFetch().returns(AN_ACCOUNT_CRYPTO_SESSION) },
|
||||
fakeFetchAccountCryptoUseCase.also { it.givenFetch().returns(AN_ACCOUNT_CRYPTO_SESSION) },
|
||||
fakeOlm,
|
||||
fakeCredentialsStore,
|
||||
fakeDeviceService,
|
||||
@ -43,6 +42,16 @@ class MaybeCreateAndUploadOneTimeKeysUseCaseTest {
|
||||
fakeDeviceService.verifyDidntUploadOneTimeKeys()
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `given account has keys and server count is 0 then does nothing`() = runTest {
|
||||
fakeFetchAccountCryptoUseCase.givenFetch().returns(AN_ACCOUNT_CRYPTO_SESSION.copy(hasKeys = true))
|
||||
val zeroServiceKeys = ServerKeyCount(0)
|
||||
|
||||
maybeCreateAndUploadOneTimeKeysUseCase.invoke(zeroServiceKeys)
|
||||
|
||||
fakeDeviceService.verifyDidntUploadOneTimeKeys()
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `given 0 current keys than generates and uploads 75 percent of the max key capacity`() = runTest {
|
||||
fakeDeviceService.expect { it.uploadOneTimeKeys(GENERATED_ONE_TIME_KEYS) }
|
||||
|
@ -10,8 +10,9 @@ fun anAccountCryptoSession(
|
||||
senderKey: Curve25519 = aCurve25519(),
|
||||
deviceKeys: DeviceKeys = aDeviceKeys(),
|
||||
maxKeys: Int = 5,
|
||||
hasKeys: Boolean = false,
|
||||
olmAccount: Any = mockk(),
|
||||
) = Olm.AccountCryptoSession(fingerprint, senderKey, deviceKeys, maxKeys, olmAccount)
|
||||
) = Olm.AccountCryptoSession(fingerprint, senderKey, deviceKeys, hasKeys, maxKeys, olmAccount)
|
||||
|
||||
fun aRoomCryptoSession(
|
||||
creationTimestampUtc: Long = 0L,
|
||||
|
@ -27,6 +27,7 @@ internal class SyncSideEffects(
|
||||
response.deviceLists?.changed?.ifEmpty { null }?.let {
|
||||
notifyDevicesUpdated.notifyChanges(it, requestToken)
|
||||
}
|
||||
|
||||
oneTimeKeyProducer.onServerKeyCount(response.oneTimeKeysCount["signed_curve25519"] ?: ServerKeyCount(0))
|
||||
|
||||
val decryptedToDeviceEvents = decryptedToDeviceEvents(response)
|
||||
|
@ -5,17 +5,13 @@ package test
|
||||
import TestMessage
|
||||
import TestUser
|
||||
import app.dapk.st.core.extensions.ifNull
|
||||
import app.dapk.st.matrix.common.MxUrl
|
||||
import app.dapk.st.matrix.common.RoomId
|
||||
import app.dapk.st.matrix.common.RoomMember
|
||||
import app.dapk.st.matrix.common.convertMxUrToUrl
|
||||
import app.dapk.st.matrix.http.MatrixHttpClient
|
||||
import app.dapk.st.matrix.message.MessageService
|
||||
import app.dapk.st.matrix.message.messageService
|
||||
import app.dapk.st.matrix.sync.RoomEvent
|
||||
import app.dapk.st.matrix.sync.syncService
|
||||
import io.ktor.client.*
|
||||
import io.ktor.client.call.*
|
||||
import io.ktor.client.request.*
|
||||
import io.ktor.client.statement.*
|
||||
import io.ktor.util.cio.*
|
||||
@ -28,7 +24,6 @@ import org.amshove.kluent.fail
|
||||
import org.amshove.kluent.shouldBeEqualTo
|
||||
import java.io.File
|
||||
import java.math.BigInteger
|
||||
import java.net.URL
|
||||
import java.security.MessageDigest
|
||||
import java.util.*
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user