diff --git a/apps/desktop/desktop_native/core/Cargo.toml b/apps/desktop/desktop_native/core/Cargo.toml index e53d12c702..039cd19c6a 100644 --- a/apps/desktop/desktop_native/core/Cargo.toml +++ b/apps/desktop/desktop_native/core/Cargo.toml @@ -53,7 +53,7 @@ ssh-key = { version = "=0.6.6", default-features = false, features = [ bitwarden-russh = { git = "https://github.com/bitwarden/bitwarden-russh.git", branch = "km/pm-10098/clean-russh-implementation" } tokio = { version = "=1.40.0", features = ["io-util", "sync", "macros", "net"] } tokio-stream = { version = "=0.1.15", features = ["net"] } -tokio-util = "=0.7.12" +tokio-util = { version = "0.7.11", features = ["codec"] } thiserror = "=1.0.69" typenum = "=1.17.0" rand_chacha = "=0.3.1" diff --git a/apps/desktop/desktop_native/core/src/ipc/client.rs b/apps/desktop/desktop_native/core/src/ipc/client.rs index 7eff8a1097..6c4ca0a605 100644 --- a/apps/desktop/desktop_native/core/src/ipc/client.rs +++ b/apps/desktop/desktop_native/core/src/ipc/client.rs @@ -1,13 +1,11 @@ use std::path::PathBuf; +use futures::{SinkExt, StreamExt}; use interprocess::local_socket::{ tokio::{prelude::*, Stream}, GenericFilePath, ToFsName, }; use log::{error, info}; -use tokio::io::{AsyncReadExt, AsyncWriteExt}; - -use crate::ipc::NATIVE_MESSAGING_BUFFER_SIZE; pub async fn connect( path: PathBuf, @@ -17,7 +15,9 @@ pub async fn connect( info!("Attempting to connect to {}", path.display()); let name = path.as_os_str().to_fs_name::()?; - let mut conn = Stream::connect(name).await?; + let conn = Stream::connect(name).await?; + + let mut conn = crate::ipc::internal_ipc_codec(conn); info!("Connected to {}", path.display()); @@ -26,8 +26,6 @@ pub async fn connect( // As it's only two, we hardcode the JSON values to avoid pulling in a JSON library. send.send("{\"command\":\"connected\"}".to_owned()).await?; - let mut buffer = vec![0; NATIVE_MESSAGING_BUFFER_SIZE]; - // Listen to IPC messages loop { tokio::select! { @@ -35,7 +33,7 @@ pub async fn connect( msg = recv.recv() => { match msg { Some(msg) => { - conn.write_all(msg.as_bytes()).await?; + conn.send(msg.into()).await?; } None => { info!("Client channel closed"); @@ -45,18 +43,18 @@ pub async fn connect( }, // Forward messages from the IPC server - res = conn.read(&mut buffer[..]) => { + res = conn.next() => { match res { - Err(e) => { + Some(Err(e)) => { error!("Error reading from IPC server: {e}"); break; } - Ok(0) => { + None => { info!("Connection closed"); break; } - Ok(n) => { - let message = String::from_utf8_lossy(&buffer[..n]).to_string(); + Some(Ok(bytes)) => { + let message = String::from_utf8_lossy(&bytes).to_string(); send.send(message).await?; } } diff --git a/apps/desktop/desktop_native/core/src/ipc/mod.rs b/apps/desktop/desktop_native/core/src/ipc/mod.rs index d406b6aa13..6873f0cfb8 100644 --- a/apps/desktop/desktop_native/core/src/ipc/mod.rs +++ b/apps/desktop/desktop_native/core/src/ipc/mod.rs @@ -1,3 +1,6 @@ +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio_util::codec::{Framed, LengthDelimitedCodec}; + pub mod client; pub mod server; @@ -16,6 +19,16 @@ pub const NATIVE_MESSAGING_BUFFER_SIZE: usize = 1024 * 1024; /// but ideally the messages should be processed as quickly as possible. pub const MESSAGE_CHANNEL_BUFFER: usize = 32; +/// This is the codec used for communication through the UNIX socket / Windows named pipe. +/// It's an internal implementation detail, but we want to make sure that both the client +/// and the server use the same one. +fn internal_ipc_codec(inner: T) -> Framed { + LengthDelimitedCodec::builder() + .max_frame_length(NATIVE_MESSAGING_BUFFER_SIZE) + .native_endian() + .new_framed(inner) +} + /// Resolve the path to the IPC socket. pub fn path(name: &str) -> std::path::PathBuf { #[cfg(target_os = "windows")] diff --git a/apps/desktop/desktop_native/core/src/ipc/server.rs b/apps/desktop/desktop_native/core/src/ipc/server.rs index 0aa1cf3017..a1c77e7ab1 100644 --- a/apps/desktop/desktop_native/core/src/ipc/server.rs +++ b/apps/desktop/desktop_native/core/src/ipc/server.rs @@ -1,21 +1,20 @@ use std::{ error::Error, path::{Path, PathBuf}, - vec, }; -use futures::TryFutureExt; +use futures::{SinkExt, StreamExt, TryFutureExt}; use anyhow::Result; use interprocess::local_socket::{tokio::prelude::*, GenericFilePath, ListenerOptions}; use log::{error, info}; use tokio::{ - io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}, + io::{AsyncRead, AsyncWrite}, sync::{broadcast, mpsc}, }; use tokio_util::sync::CancellationToken; -use super::{MESSAGE_CHANNEL_BUFFER, NATIVE_MESSAGING_BUFFER_SIZE}; +use super::MESSAGE_CHANNEL_BUFFER; #[derive(Debug)] pub struct Message { @@ -158,7 +157,7 @@ async fn listen_incoming( } async fn handle_connection( - mut client_stream: impl AsyncRead + AsyncWrite + Unpin, + client_stream: impl AsyncRead + AsyncWrite + Unpin, client_to_server_send: mpsc::Sender, mut server_to_clients_recv: broadcast::Receiver, cancel_token: CancellationToken, @@ -172,7 +171,7 @@ async fn handle_connection( }) .await?; - let mut buf = vec![0u8; NATIVE_MESSAGING_BUFFER_SIZE]; + let mut client_stream = crate::ipc::internal_ipc_codec(client_stream); loop { tokio::select! { @@ -185,7 +184,7 @@ async fn handle_connection( msg = server_to_clients_recv.recv() => { match msg { Ok(msg) => { - client_stream.write_all(msg.as_bytes()).await?; + client_stream.send(msg.into()).await?; }, Err(e) => { info!("Error reading message: {}", e); @@ -197,9 +196,9 @@ async fn handle_connection( // Forwards messages from the IPC clients to the server // Note that we also send connect and disconnect events so that // the server can keep track of multiple clients - result = client_stream.read(&mut buf) => { + result = client_stream.next() => { match result { - Err(e) => { + Some(Err(e)) => { info!("Error reading from client {client_id}: {e}"); client_to_server_send.send(Message { @@ -209,7 +208,7 @@ async fn handle_connection( }).await?; break; }, - Ok(0) => { + None => { info!("Client {client_id} disconnected."); client_to_server_send.send(Message { @@ -219,8 +218,8 @@ async fn handle_connection( }).await?; break; }, - Ok(size) => { - let msg = std::str::from_utf8(&buf[..size])?; + Some(Ok(bytes)) => { + let msg = std::str::from_utf8(&bytes)?; client_to_server_send.send(Message { client_id,