Thread safe use of OlmSAS

This commit is contained in:
Hugh Nimmo-Smith 2022-10-17 12:01:12 +01:00
parent 506fa729ea
commit 4306c57236
1 changed files with 38 additions and 31 deletions

View File

@ -90,9 +90,7 @@ class ECDHRendezvousChannel(override var transport: RendezvousTransport, theirPu
} }
override suspend fun connect(): String { override suspend fun connect(): String {
if (olmSAS == null) { olmSAS ?.let { olmSAS ->
throw RuntimeException("Channel closed")
}
val isInitiator = theirPublicKey == null val isInitiator = theirPublicKey == null
if (isInitiator) { if (isInitiator) {
@ -117,17 +115,21 @@ class ECDHRendezvousChannel(override var transport: RendezvousTransport, theirPu
) )
} }
olmSAS!!.setTheirPublicKey(Base64.encodeToString(theirPublicKey, Base64.NO_WRAP)) synchronized(olmSAS) {
olmSAS.setTheirPublicKey(Base64.encodeToString(theirPublicKey, Base64.NO_WRAP))
olmSAS.setTheirPublicKey(Base64.encodeToString(theirPublicKey, Base64.NO_WRAP))
val initiatorKey = Base64.encodeToString(if (isInitiator) ourPublicKey else theirPublicKey, Base64.NO_WRAP) val initiatorKey = Base64.encodeToString(if (isInitiator) ourPublicKey else theirPublicKey, Base64.NO_WRAP)
val recipientKey = Base64.encodeToString(if (isInitiator) theirPublicKey else ourPublicKey, Base64.NO_WRAP) val recipientKey = Base64.encodeToString(if (isInitiator) theirPublicKey else ourPublicKey, Base64.NO_WRAP)
val aesInfo = "${SecureRendezvousChannelAlgorithm.ECDH_V1.value}|$initiatorKey|$recipientKey" val aesInfo = "${SecureRendezvousChannelAlgorithm.ECDH_V1.value}|$initiatorKey|$recipientKey"
aesKey = olmSAS!!.generateShortCode(aesInfo, 32) aesKey = olmSAS.generateShortCode(aesInfo, 32)
val rawChecksum = olmSAS!!.generateShortCode(aesInfo, 5) val rawChecksum = olmSAS.generateShortCode(aesInfo, 5)
return getDecimalCodeRepresentation(rawChecksum) return getDecimalCodeRepresentation(rawChecksum)
} }
} ?: throw RuntimeException("Channel closed")
}
private suspend fun send(payload: ECDHPayload) { private suspend fun send(payload: ECDHPayload) {
transport.send("application/json".toMediaType(), ecdhAdapter.toJson(payload).toByteArray(Charsets.UTF_8)) transport.send("application/json".toMediaType(), ecdhAdapter.toJson(payload).toByteArray(Charsets.UTF_8))
@ -174,9 +176,14 @@ class ECDHRendezvousChannel(override var transport: RendezvousTransport, theirPu
} }
override suspend fun close() { override suspend fun close() {
olmSAS?.releaseSas() olmSAS ?.let {
synchronized(it) {
// this does a double release check already so we don't re-check ourselves
it.releaseSas()
olmSAS = null olmSAS = null
} }
}
}
private fun encrypt(plainText: ByteArray): ECDHPayload { private fun encrypt(plainText: ByteArray): ECDHPayload {
val iv = ByteArray(16) val iv = ByteArray(16)