adding support for sending encrypted images

This commit is contained in:
Adam Brown 2022-09-21 20:35:39 +01:00
parent 06947301ac
commit f6ef073689
11 changed files with 348 additions and 95 deletions

View File

@ -272,7 +272,7 @@ internal class MatrixModules(
coroutineDispatchers = coroutineDispatchers,
)
val imageContentReader = AndroidImageContentReader(contentResolver)
installMessageService(store.localEchoStore, BackgroundWorkAdapter(workModule.workScheduler()), imageContentReader) { serviceProvider ->
installMessageService(store.localEchoStore, BackgroundWorkAdapter(workModule.workScheduler()), imageContentReader, base64) { serviceProvider ->
MessageEncrypter { message ->
val result = serviceProvider.cryptoService().encrypt(
roomId = message.roomId,

View File

@ -0,0 +1,26 @@
package app.dapk.st.matrix.device.internal
import app.dapk.st.matrix.common.MessageType
import app.dapk.st.matrix.common.RoomId
import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable
@Serializable
sealed class ApiMessage {
@Serializable
@SerialName("text_message")
data class TextMessage(
@SerialName("content") val content: TextContent,
@SerialName("room_id") val roomId: RoomId,
@SerialName("type") val type: String,
) : ApiMessage() {
@Serializable
data class TextContent(
@SerialName("body") val body: String,
@SerialName("msgtype") val type: String = MessageType.TEXT.value,
)
}
}

View File

@ -2,6 +2,8 @@ plugins { id 'java-test-fixtures' }
applyMatrixServiceModule(project)
dependencies {
implementation project(":core")
kotlinFixtures(it)
testFixturesImplementation(testFixtures(project(":core")))
testFixturesImplementation(testFixtures(project(":matrix:common")))

View File

@ -1,10 +1,14 @@
package app.dapk.st.matrix.message
import app.dapk.st.core.Base64
import app.dapk.st.matrix.MatrixService
import app.dapk.st.matrix.MatrixServiceInstaller
import app.dapk.st.matrix.MatrixServiceProvider
import app.dapk.st.matrix.ServiceDepFactory
import app.dapk.st.matrix.common.*
import app.dapk.st.matrix.common.AlgorithmName
import app.dapk.st.matrix.common.EventId
import app.dapk.st.matrix.common.MessageType
import app.dapk.st.matrix.common.RoomId
import app.dapk.st.matrix.message.internal.DefaultMessageService
import app.dapk.st.matrix.message.internal.ImageContentReader
import kotlinx.coroutines.flow.Flow
@ -67,21 +71,6 @@ interface MessageService : MatrixService {
@SerialName("uri") val uri: String,
) : Content()
@Serializable
data class ImageContent(
@SerialName("url") val url: MxUrl,
@SerialName("body") val filename: String,
@SerialName("info") val info: Info,
@SerialName("msgtype") val type: String = MessageType.IMAGE.value,
) : Content() {
@Serializable
data class Info(
@SerialName("h") val height: Int,
@SerialName("w") val width: Int,
@SerialName("size") val size: Long,
)
}
}
}
@ -141,10 +130,18 @@ fun MatrixServiceInstaller.installMessageService(
localEchoStore: LocalEchoStore,
backgroundScheduler: BackgroundScheduler,
imageContentReader: ImageContentReader,
base64: Base64,
messageEncrypter: ServiceDepFactory<MessageEncrypter> = ServiceDepFactory { MissingMessageEncrypter },
) {
this.install { (httpClient, _, installedServices) ->
SERVICE_KEY to DefaultMessageService(httpClient, localEchoStore, backgroundScheduler, messageEncrypter.create(installedServices), imageContentReader)
SERVICE_KEY to DefaultMessageService(
httpClient,
localEchoStore,
backgroundScheduler,
base64,
messageEncrypter.create(installedServices),
imageContentReader
)
}
}

View File

@ -1,6 +1,7 @@
package app.dapk.st.matrix.message.internal
import app.dapk.st.matrix.common.MessageType
import app.dapk.st.matrix.common.MxUrl
import app.dapk.st.matrix.common.RoomId
import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable
@ -20,7 +21,52 @@ sealed class ApiMessage {
data class TextContent(
@SerialName("body") val body: String,
@SerialName("msgtype") val type: String = MessageType.TEXT.value,
)
) : ApiMessageContent
}
@Serializable
@SerialName("image_message")
data class ImageMessage(
@SerialName("content") val content: ImageContent,
@SerialName("room_id") val roomId: RoomId,
@SerialName("type") val type: String,
) : ApiMessage() {
@Serializable
data class ImageContent(
@SerialName("url") val url: MxUrl?,
@SerialName("body") val filename: String,
@SerialName("info") val info: Info,
@SerialName("msgtype") val type: String = MessageType.IMAGE.value,
@SerialName("file") val file: File? = null,
) : ApiMessageContent {
@Serializable
data class Info(
@SerialName("h") val height: Int,
@SerialName("w") val width: Int,
@SerialName("size") val size: Long,
)
@Serializable
data class File(
@SerialName("url") val url: MxUrl,
@SerialName("key") val key: EncryptionMeta,
@SerialName("iv") val iv: String,
@SerialName("hashes") val hashes: Map<String, String>,
@SerialName("v") val v: String
) {
@Serializable
data class EncryptionMeta(
@SerialName("alg") val algorithm: String,
@SerialName("ext") val ext: Boolean,
@SerialName("key_ops") val keyOperations: List<String>,
@SerialName("kty") val kty: String,
@SerialName("k") val k: String
)
}
}
}
}
sealed interface ApiMessageContent

View File

@ -1,5 +1,6 @@
package app.dapk.st.matrix.message.internal
import app.dapk.st.core.Base64
import app.dapk.st.matrix.MatrixTaskRunner
import app.dapk.st.matrix.common.RoomId
import app.dapk.st.matrix.http.MatrixHttpClient
@ -19,17 +20,18 @@ internal class DefaultMessageService(
httpClient: MatrixHttpClient,
private val localEchoStore: LocalEchoStore,
private val backgroundScheduler: BackgroundScheduler,
base64: Base64,
messageEncrypter: MessageEncrypter,
imageContentReader: ImageContentReader,
) : MessageService, MatrixTaskRunner {
private val sendMessageUseCase = SendMessageUseCase(httpClient, messageEncrypter, imageContentReader)
private val sendMessageUseCase = SendMessageUseCase(httpClient, messageEncrypter, imageContentReader, base64)
private val sendEventMessageUseCase = SendEventMessageUseCase(httpClient)
override suspend fun canRun(task: MatrixTaskRunner.MatrixTask) = task.type == MATRIX_MESSAGE_TASK_TYPE || task.type == MATRIX_IMAGE_MESSAGE_TASK_TYPE
override suspend fun run(task: MatrixTaskRunner.MatrixTask): MatrixTaskRunner.TaskResult {
val message = when(task.type) {
val message = when (task.type) {
MATRIX_MESSAGE_TASK_TYPE -> Json.decodeFromString(MessageService.Message.TextMessage.serializer(), task.jsonPayload)
MATRIX_IMAGE_MESSAGE_TASK_TYPE -> Json.decodeFromString(MessageService.Message.ImageMessage.serializer(), task.jsonPayload)
else -> throw IllegalStateException("Unhandled task type: ${task.type}")

View File

@ -0,0 +1,100 @@
package app.dapk.st.matrix.message.internal
import app.dapk.st.core.Base64
import java.io.File
import java.io.InputStream
import java.security.MessageDigest
import java.security.SecureRandom
import javax.crypto.Cipher
import javax.crypto.spec.IvParameterSpec
import javax.crypto.spec.SecretKeySpec
private const val CRYPTO_BUFFER_SIZE = 32 * 1024
private const val CIPHER_ALGORITHM = "AES/CTR/NoPadding"
private const val SECRET_KEY_SPEC_ALGORITHM = "AES"
private const val MESSAGE_DIGEST_ALGORITHM = "SHA-256"
class MediaEncrypter(private val base64: Base64) {
fun encrypt(input: InputStream, name: String): Result {
val secureRandom = SecureRandom()
val initVectorBytes = ByteArray(16) { 0.toByte() }
val ivRandomPart = ByteArray(8)
secureRandom.nextBytes(ivRandomPart)
System.arraycopy(ivRandomPart, 0, initVectorBytes, 0, ivRandomPart.size)
val key = ByteArray(32)
secureRandom.nextBytes(key)
val messageDigest = MessageDigest.getInstance(MESSAGE_DIGEST_ALGORITHM)
val outputFile = File.createTempFile("_encrypt-${name.hashCode()}", ".png")
val outputStream = outputFile.outputStream()
outputStream.use { s ->
val encryptCipher = Cipher.getInstance(CIPHER_ALGORITHM)
val secretKeySpec = SecretKeySpec(key, SECRET_KEY_SPEC_ALGORITHM)
val ivParameterSpec = IvParameterSpec(initVectorBytes)
encryptCipher.init(Cipher.ENCRYPT_MODE, secretKeySpec, ivParameterSpec)
val data = ByteArray(CRYPTO_BUFFER_SIZE)
var read: Int
var encodedBytes: ByteArray
input.use { inputStream ->
read = inputStream.read(data)
var totalRead = read
while (read != -1) {
encodedBytes = encryptCipher.update(data, 0, read)
messageDigest.update(encodedBytes, 0, encodedBytes.size)
s.write(encodedBytes)
read = inputStream.read(data)
totalRead += read
}
}
encodedBytes = encryptCipher.doFinal()
messageDigest.update(encodedBytes, 0, encodedBytes.size)
s.write(encodedBytes)
}
return Result(
contents = outputFile.readBytes(),
algorithm = "A256CTR",
ext = true,
keyOperations = listOf("encrypt", "decrypt"),
kty = "oct",
k = base64ToBase64Url(base64.encode(key)),
iv = base64.encode(initVectorBytes).replace("\n", "").replace("=", ""),
hashes = mapOf("sha256" to base64ToUnpaddedBase64(base64.encode(messageDigest.digest()))),
v = "v2"
)
}
data class Result(
val contents: ByteArray,
val algorithm: String,
val ext: Boolean,
val keyOperations: List<String>,
val kty: String,
val k: String,
val iv: String,
val hashes: Map<String, String>,
val v: String,
)
}
private fun base64ToBase64Url(base64: String): String {
return base64.replace("\n".toRegex(), "")
.replace("\\+".toRegex(), "-")
.replace('/', '_')
.replace("=", "")
}
private fun base64ToUnpaddedBase64(base64: String): String {
return base64.replace("\n".toRegex(), "")
.replace("=", "")
}

View File

@ -1,86 +1,167 @@
package app.dapk.st.matrix.message.internal
import app.dapk.st.matrix.common.EventId
import app.dapk.st.matrix.common.EventType
import app.dapk.st.matrix.common.JsonString
import app.dapk.st.core.Base64
import app.dapk.st.matrix.common.*
import app.dapk.st.matrix.http.MatrixHttpClient
import app.dapk.st.matrix.http.MatrixHttpClient.HttpRequest
import app.dapk.st.matrix.message.ApiSendResponse
import app.dapk.st.matrix.message.MessageEncrypter
import app.dapk.st.matrix.message.MessageService
import app.dapk.st.matrix.message.MessageService.Message
import java.io.ByteArrayInputStream
internal class SendMessageUseCase(
private val httpClient: MatrixHttpClient,
private val messageEncrypter: MessageEncrypter,
private val imageContentReader: ImageContentReader,
private val base64: Base64,
) {
suspend fun sendMessage(message: MessageService.Message): EventId {
return when (message) {
is MessageService.Message.TextMessage -> {
val request = when (message.sendEncrypted) {
true -> {
val content = JsonString(
MatrixHttpClient.jsonWithDefaults.encodeToString(
ApiMessage.TextMessage.serializer(),
ApiMessage.TextMessage(
content = ApiMessage.TextMessage.TextContent(
message.content.body,
message.content.type,
),
roomId = message.roomId,
type = EventType.ROOM_MESSAGE.value
)
)
)
private val mapper = ApiMessageMapper()
sendRequest(
roomId = message.roomId,
eventType = EventType.ENCRYPTED,
txId = message.localId,
content = messageEncrypter.encrypt(MessageEncrypter.ClearMessagePayload(message.roomId, content)),
)
}
false -> {
sendRequest(
roomId = message.roomId,
eventType = EventType.ROOM_MESSAGE,
txId = message.localId,
content = message.content,
)
}
suspend fun sendMessage(message: Message): EventId {
return with(mapper) {
when (message) {
is Message.TextMessage -> {
val request = textMessageRequest(message)
httpClient.execute(request).eventId
}
httpClient.execute(request).eventId
is Message.ImageMessage -> {
val request = imageMessageRequest(message)
httpClient.execute(request).eventId
}
}
}
}
private suspend fun ApiMessageMapper.textMessageRequest(message: Message.TextMessage): HttpRequest<ApiSendResponse> {
val contents = message.toContents()
return when (message.sendEncrypted) {
true -> sendRequest(
roomId = message.roomId,
eventType = EventType.ENCRYPTED,
txId = message.localId,
content = messageEncrypter.encrypt(
MessageEncrypter.ClearMessagePayload(
message.roomId,
contents.toMessageJson(message.roomId)
)
),
)
false -> sendRequest(
roomId = message.roomId,
eventType = EventType.ROOM_MESSAGE,
txId = message.localId,
content = contents,
)
}
}
private suspend fun ApiMessageMapper.imageMessageRequest(message: Message.ImageMessage): HttpRequest<ApiSendResponse> {
val imageContent = imageContentReader.read(message.content.uri)
return when (message.sendEncrypted) {
true -> {
val result = MediaEncrypter(base64).encrypt(
ByteArrayInputStream(imageContent.content),
imageContent.fileName,
)
val uri = httpClient.execute(uploadRequest(result.contents, imageContent.fileName, "application/octet-stream")).contentUri
val content = ApiMessage.ImageMessage.ImageContent(
url = null,
filename = imageContent.fileName,
file = ApiMessage.ImageMessage.ImageContent.File(
url = uri,
key = ApiMessage.ImageMessage.ImageContent.File.EncryptionMeta(
algorithm = result.algorithm,
ext = result.ext,
keyOperations = result.keyOperations,
kty = result.kty,
k = result.k,
),
iv = result.iv,
hashes = result.hashes,
v = result.v,
),
info = ApiMessage.ImageMessage.ImageContent.Info(
height = imageContent.height,
width = imageContent.width,
size = imageContent.size
)
)
val json = JsonString(
MatrixHttpClient.jsonWithDefaults.encodeToString(
ApiMessage.ImageMessage.serializer(),
ApiMessage.ImageMessage(
content = content,
roomId = message.roomId,
type = EventType.ROOM_MESSAGE.value,
)
)
)
sendRequest(
roomId = message.roomId,
eventType = EventType.ENCRYPTED,
txId = message.localId,
content = messageEncrypter.encrypt(MessageEncrypter.ClearMessagePayload(message.roomId, json)),
)
}
is MessageService.Message.ImageMessage -> {
val request = when (message.sendEncrypted) {
true -> {
throw IllegalStateException()
}
false -> {
val imageContent = imageContentReader.read(message.content.uri)
val uri = httpClient.execute(uploadRequest(imageContent.content, imageContent.fileName, imageContent.mimeType)).contentUri
sendRequest(
roomId = message.roomId,
eventType = EventType.ROOM_MESSAGE,
txId = message.localId,
content = MessageService.Message.Content.ImageContent(
url = uri,
filename = imageContent.fileName,
MessageService.Message.Content.ImageContent.Info(
height = imageContent.height,
width = imageContent.width,
size = imageContent.size
)
),
false -> {
val uri = httpClient.execute(uploadRequest(imageContent.content, imageContent.fileName, imageContent.mimeType)).contentUri
sendRequest(
roomId = message.roomId,
eventType = EventType.ROOM_MESSAGE,
txId = message.localId,
content = ApiMessage.ImageMessage.ImageContent(
url = uri,
filename = imageContent.fileName,
ApiMessage.ImageMessage.ImageContent.Info(
height = imageContent.height,
width = imageContent.width,
size = imageContent.size
)
}
}
httpClient.execute(request).eventId
),
)
}
}
}
}
class ApiMessageMapper {
fun Message.TextMessage.toContents() = ApiMessage.TextMessage.TextContent(
this.content.body,
this.content.type,
)
fun ApiMessage.TextMessage.TextContent.toMessageJson(roomId: RoomId) = JsonString(
MatrixHttpClient.jsonWithDefaults.encodeToString(
ApiMessage.TextMessage.serializer(),
ApiMessage.TextMessage(
content = this,
roomId = roomId,
type = EventType.ROOM_MESSAGE.value
)
)
)
fun Message.ImageMessage.toContents(uri: MxUrl, image: ImageContentReader.ImageContent) = ApiMessage.ImageMessage.ImageContent(
url = uri,
filename = image.fileName,
ApiMessage.ImageMessage.ImageContent.Info(
height = image.height,
width = image.width,
size = image.size
)
)
}

View File

@ -9,18 +9,18 @@ import app.dapk.st.matrix.message.ApiSendResponse
import app.dapk.st.matrix.message.ApiUploadResponse
import app.dapk.st.matrix.message.MessageEncrypter
import app.dapk.st.matrix.message.MessageService.EventMessage
import app.dapk.st.matrix.message.MessageService.Message
import app.dapk.st.matrix.message.internal.ApiMessage.ImageMessage
import app.dapk.st.matrix.message.internal.ApiMessage.TextMessage
import io.ktor.content.*
import io.ktor.http.*
import java.util.*
internal fun sendRequest(roomId: RoomId, eventType: EventType, txId: String, content: Message.Content) = httpRequest<ApiSendResponse>(
internal fun sendRequest(roomId: RoomId, eventType: EventType, txId: String, content: ApiMessageContent) = httpRequest<ApiSendResponse>(
path = "_matrix/client/r0/rooms/${roomId.value}/send/${eventType.value}/${txId}",
method = MatrixHttpClient.Method.PUT,
body = when (content) {
is Message.Content.TextContent -> jsonBody(Message.Content.TextContent.serializer(), content, MatrixHttpClient.jsonWithDefaults)
is Message.Content.ImageContent -> jsonBody(Message.Content.ImageContent.serializer(), content, MatrixHttpClient.jsonWithDefaults)
is Message.Content.ApiImageContent -> throw IllegalArgumentException()
is TextMessage.TextContent -> jsonBody(TextMessage.TextContent.serializer(), content, MatrixHttpClient.jsonWithDefaults)
is ImageMessage.ImageContent -> jsonBody(ImageMessage.ImageContent.serializer(), content, MatrixHttpClient.jsonWithDefaults)
}
)
@ -45,5 +45,4 @@ internal fun uploadRequest(body: ByteArray, filename: String, contentType: Strin
body = ByteArrayContent(body, ContentType.parse(contentType)),
)
fun txId() = "local.${UUID.randomUUID()}"

View File

@ -75,7 +75,7 @@ class SmokeTest {
@Order(6)
fun `can send and receive clear image messages`() = testAfterInitialSync { alice, bob ->
val testImage = loadResourceFile("test-image.png")
alice.sendImageMessage(SharedState.sharedRoom, testImage, isEncrypted = false)
alice.sendImageMessage(SharedState.sharedRoom, testImage, isEncrypted = true)
bob.expectImageMessage(SharedState.sharedRoom, testImage, SharedState.alice.roomMember)
}

View File

@ -121,7 +121,7 @@ class TestMatrix(
coroutineDispatchers = coroutineDispatchers,
)
installMessageService(storeModule.localEchoStore, InstantScheduler(it), JavaImageContentReader()) { serviceProvider ->
installMessageService(storeModule.localEchoStore, InstantScheduler(it), JavaImageContentReader(), base64) { serviceProvider ->
MessageEncrypter { message ->
val result = serviceProvider.cryptoService().encrypt(
roomId = message.roomId,