improving image decrypting pipeline to use one less copy and adding smoke test to sending encrypted images

This commit is contained in:
Adam Brown 2022-09-21 22:42:04 +01:00 committed by Adam Brown
parent 854a4c17ce
commit e70ed9f6e5
6 changed files with 56 additions and 24 deletions

View File

@ -4,6 +4,7 @@ apply plugin: 'kotlin-parcelize'
dependencies {
implementation project(":matrix:services:sync")
implementation project(":matrix:services:message")
implementation project(":matrix:services:crypto")
implementation project(":matrix:services:room")
implementation project(":domains:android:compose-core")
implementation project(":domains:android:viewmodel")

View File

@ -2,6 +2,7 @@ package app.dapk.st.messenger
import android.content.Context
import app.dapk.st.core.Base64
import app.dapk.st.matrix.crypto.MediaDecrypter
import app.dapk.st.matrix.sync.RoomEvent
import coil.ImageLoader
import coil.decode.DataSource
@ -42,7 +43,11 @@ class DecryptingFetcher(
}
private fun handleEncrypted(response: Response, keys: RoomEvent.Image.ImageMeta.Keys): Buffer {
return response.body?.byteStream()?.let { mediaDecrypter.decrypt(it, keys.k, keys.iv) } ?: Buffer()
return response.body?.byteStream()?.let { byteStream ->
Buffer().also { buffer ->
mediaDecrypter.decrypt(byteStream, keys.k, keys.iv).collect { buffer.write(it) }
}
} ?: Buffer()
}
}

View File

@ -1,7 +1,6 @@
package app.dapk.st.messenger
package app.dapk.st.matrix.crypto
import app.dapk.st.core.Base64
import okio.Buffer
import java.io.InputStream
import java.security.MessageDigest
import javax.crypto.Cipher
@ -15,7 +14,7 @@ private const val MESSAGE_DIGEST_ALGORITHM = "SHA-256"
class MediaDecrypter(private val base64: Base64) {
fun decrypt(input: InputStream, k: String, iv: String): Buffer {
fun decrypt(input: InputStream, k: String, iv: String): Collector {
val key = base64.decode(k.replace('-', '+').replace('_', '/'))
val initVectorBytes = base64.decode(iv)
@ -30,17 +29,22 @@ class MediaDecrypter(private val base64: Base64) {
val d = ByteArray(CRYPTO_BUFFER_SIZE)
var decodedBytes: ByteArray
val outputStream = Buffer()
input.use {
read = it.read(d)
while (read != -1) {
messageDigest.update(d, 0, read)
decodedBytes = decryptCipher.update(d, 0, read)
outputStream.write(decodedBytes)
return Collector { partial ->
input.use {
read = it.read(d)
while (read != -1) {
messageDigest.update(d, 0, read)
decodedBytes = decryptCipher.update(d, 0, read)
partial(decodedBytes)
read = it.read(d)
}
}
}
return outputStream
}
}
}
fun interface Collector {
fun collect(partial: (ByteArray) -> Unit)
}

View File

@ -71,16 +71,25 @@ class SmokeTest {
@Order(5)
fun `can send and receive encrypted text messages`() = testTextMessaging(isEncrypted = true)
@Test
@Order(6)
fun `can send and receive clear image messages`() = testAfterInitialSync { alice, bob ->
val testImage = loadResourceFile("test-image.png")
alice.sendImageMessage(SharedState.sharedRoom, testImage, isEncrypted = false)
bob.expectImageMessage(SharedState.sharedRoom, testImage, SharedState.alice.roomMember)
}
// @Test
// @Order(6)
// fun `can send and receive clear image messages`() = testAfterInitialSync { alice, bob ->
// val testImage = loadResourceFile("test-image.png")
// alice.sendImageMessage(SharedState.sharedRoom, testImage, isEncrypted = false)
// bob.expectImageMessage(SharedState.sharedRoom, testImage, SharedState.alice.roomMember, isEncrypted = false)
// }
@Test
@Order(7)
fun `can send and receive encrypted image messages`() = testAfterInitialSync { alice, bob ->
val testImage = loadResourceFile("test-image.png")
alice.sendImageMessage(SharedState.sharedRoom, testImage, isEncrypted = true)
bob.expectImageMessage(SharedState.sharedRoom, testImage, SharedState.alice.roomMember)
}
@Test
@Order(8)
fun `can request and verify devices`() = testAfterInitialSync { alice, bob ->
alice.client.cryptoService().verificationAction(Verification.Action.Request(bob.userId(), bob.deviceId()))
alice.client.cryptoService().verificationState().automaticVerification(alice).expectAsync { it == Verification.State.Done }

View File

@ -7,6 +7,7 @@ import TestUser
import app.dapk.st.core.extensions.ifNull
import app.dapk.st.matrix.common.RoomId
import app.dapk.st.matrix.common.RoomMember
import app.dapk.st.matrix.crypto.MediaDecrypter
import app.dapk.st.matrix.message.MessageService
import app.dapk.st.matrix.message.messageService
import app.dapk.st.matrix.sync.RoomEvent
@ -22,6 +23,7 @@ import kotlinx.coroutines.test.TestScope
import kotlinx.coroutines.test.runTest
import org.amshove.kluent.fail
import org.amshove.kluent.shouldBeEqualTo
import java.io.ByteArrayOutputStream
import java.io.File
import java.math.BigInteger
import java.security.MessageDigest
@ -145,10 +147,21 @@ class MatrixTestScope(private val testScope: TestScope) {
this.client.syncService().room(roomId)
.map {
it.events.filterIsInstance<RoomEvent.Image>().map {
println("found: ${it.imageMeta.url}")
println("found: ${it}")
val output = File(image.parentFile.absolutePath, "output.png")
HttpClient().request(it.imageMeta.url).bodyAsChannel().copyAndClose(output.writeChannel())
output.readBytes().md5Hash() to it.author
val md5Hash = when (val keys = it.imageMeta.keys) {
null -> output.readBytes().md5Hash()
else -> {
val byteStream = ByteArrayOutputStream()
MediaDecrypter(this.base64).decrypt(output.inputStream(), keys.k, keys.iv).collect {
byteStream.write(it)
}
byteStream.toByteArray().md5Hash()
}
}
md5Hash to it.author
}.firstOrNull()
}
.assert(image.readBytes().md5Hash() to author)

View File

@ -82,6 +82,7 @@ class TestMatrix(
},
coroutineDispatchers = coroutineDispatchers
)
val base64 = JavaBase64()
val client = MatrixClient(
KtorMatrixHttpClientFactory(
@ -94,7 +95,6 @@ class TestMatrix(
installAuthService(storeModule.credentialsStore())
installEncryptionService(storeModule.knownDevicesStore())
val base64 = JavaBase64()
val olmAccountStore = OlmPersistenceWrapper(storeModule.olmStore(), base64)
val olm = OlmWrapper(
olmStore = olmAccountStore,
@ -349,7 +349,7 @@ class JavaImageContentReader : ImageContentReader {
size = size,
mimeType = "image/${file.extension}",
fileName = file.name,
content = file.readBytes()
uri = file.toURI(),
)
}