Thread safe use of OlmSAS

This commit is contained in:
Hugh Nimmo-Smith 2022-10-17 12:01:12 +01:00
parent 506fa729ea
commit 4306c57236
1 changed files with 38 additions and 31 deletions

View File

@ -90,43 +90,45 @@ class ECDHRendezvousChannel(override var transport: RendezvousTransport, theirPu
} }
override suspend fun connect(): String { override suspend fun connect(): String {
if (olmSAS == null) { olmSAS ?.let { olmSAS ->
throw RuntimeException("Channel closed") val isInitiator = theirPublicKey == null
}
val isInitiator = theirPublicKey == null
if (isInitiator) { if (isInitiator) {
Timber.tag(TAG).i("Waiting for other device to send their public key") Timber.tag(TAG).i("Waiting for other device to send their public key")
val res = this.receiveAsPayload() ?: throw RuntimeException("No reply from other device") val res = this.receiveAsPayload() ?: throw RuntimeException("No reply from other device")
if (res.key == null) { if (res.key == null) {
throw RendezvousError( throw RendezvousError(
"Unsupported algorithm: ${res.algorithm}", "Unsupported algorithm: ${res.algorithm}",
RendezvousFailureReason.UnsupportedAlgorithm, RendezvousFailureReason.UnsupportedAlgorithm,
)
}
theirPublicKey = Base64.decode(res.key, Base64.NO_WRAP)
} else {
// send our public key unencrypted
Timber.tag(TAG).i("Sending public key")
send(
ECDHPayload(
algorithm = SecureRendezvousChannelAlgorithm.ECDH_V1,
key = Base64.encodeToString(ourPublicKey, Base64.NO_WRAP)
)
) )
} }
theirPublicKey = Base64.decode(res.key, Base64.NO_WRAP)
} else {
// send our public key unencrypted
Timber.tag(TAG).i("Sending public key")
send(
ECDHPayload(
algorithm = SecureRendezvousChannelAlgorithm.ECDH_V1,
key = Base64.encodeToString(ourPublicKey, Base64.NO_WRAP)
)
)
}
olmSAS!!.setTheirPublicKey(Base64.encodeToString(theirPublicKey, Base64.NO_WRAP)) synchronized(olmSAS) {
olmSAS.setTheirPublicKey(Base64.encodeToString(theirPublicKey, Base64.NO_WRAP))
olmSAS.setTheirPublicKey(Base64.encodeToString(theirPublicKey, Base64.NO_WRAP))
val initiatorKey = Base64.encodeToString(if (isInitiator) ourPublicKey else theirPublicKey, Base64.NO_WRAP) val initiatorKey = Base64.encodeToString(if (isInitiator) ourPublicKey else theirPublicKey, Base64.NO_WRAP)
val recipientKey = Base64.encodeToString(if (isInitiator) theirPublicKey else ourPublicKey, Base64.NO_WRAP) val recipientKey = Base64.encodeToString(if (isInitiator) theirPublicKey else ourPublicKey, Base64.NO_WRAP)
val aesInfo = "${SecureRendezvousChannelAlgorithm.ECDH_V1.value}|$initiatorKey|$recipientKey" val aesInfo = "${SecureRendezvousChannelAlgorithm.ECDH_V1.value}|$initiatorKey|$recipientKey"
aesKey = olmSAS!!.generateShortCode(aesInfo, 32) aesKey = olmSAS.generateShortCode(aesInfo, 32)
val rawChecksum = olmSAS!!.generateShortCode(aesInfo, 5) val rawChecksum = olmSAS.generateShortCode(aesInfo, 5)
return getDecimalCodeRepresentation(rawChecksum) return getDecimalCodeRepresentation(rawChecksum)
}
} ?: throw RuntimeException("Channel closed")
} }
private suspend fun send(payload: ECDHPayload) { private suspend fun send(payload: ECDHPayload) {
@ -174,8 +176,13 @@ class ECDHRendezvousChannel(override var transport: RendezvousTransport, theirPu
} }
override suspend fun close() { override suspend fun close() {
olmSAS?.releaseSas() olmSAS ?.let {
olmSAS = null synchronized(it) {
// this does a double release check already so we don't re-check ourselves
it.releaseSas()
olmSAS = null
}
}
} }
private fun encrypt(plainText: ByteArray): ECDHPayload { private fun encrypt(plainText: ByteArray): ECDHPayload {