diff --git a/.env.template b/.env.template index 46aa6271..61cd046b 100644 --- a/.env.template +++ b/.env.template @@ -84,12 +84,8 @@ ### WebSocket ### ################# -## Enables websocket notifications -# WEBSOCKET_ENABLED=false - -## Controls the WebSocket server address and port -# WEBSOCKET_ADDRESS=0.0.0.0 -# WEBSOCKET_PORT=3012 +## Enable websocket notifications +# ENABLE_WEBSOCKET=true ########################## ### Push notifications ### diff --git a/Cargo.lock b/Cargo.lock index 86f0f234..b83eb071 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3784,7 +3784,6 @@ dependencies = [ "syslog", "time", "tokio", - "tokio-tungstenite", "totp-lite", "tracing", "url", diff --git a/Cargo.toml b/Cargo.toml index e5a3edd9..26916626 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -60,7 +60,6 @@ rocket = { version = "0.5.0", features = ["tls", "json"], default-features = fal rocket_ws = { version ="0.1.0" } # WebSockets libraries -tokio-tungstenite = "0.20.1" rmpv = "1.0.1" # MessagePack library # Concurrent HashMap used for WebSocket messaging and favicons diff --git a/src/api/mod.rs b/src/api/mod.rs index 99915bdf..c6838aaa 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -23,7 +23,7 @@ pub use crate::api::{ icons::routes as icons_routes, identity::routes as identity_routes, notifications::routes as notifications_routes, - notifications::{start_notification_server, AnonymousNotify, Notify, UpdateType, WS_ANONYMOUS_SUBSCRIPTIONS}, + notifications::{AnonymousNotify, Notify, UpdateType, WS_ANONYMOUS_SUBSCRIPTIONS, WS_USERS}, push::{ push_cipher_update, push_folder_update, push_logout, push_send_update, push_user_update, register_push_device, unregister_push_device, diff --git a/src/api/notifications.rs b/src/api/notifications.rs index da2664cf..1f64b86e 100644 --- a/src/api/notifications.rs +++ b/src/api/notifications.rs @@ -1,23 +1,11 @@ -use std::{ - net::{IpAddr, SocketAddr}, - sync::Arc, - time::Duration, -}; +use std::{net::IpAddr, sync::Arc, time::Duration}; use chrono::{NaiveDateTime, Utc}; use rmpv::Value; -use rocket::{ - futures::{SinkExt, StreamExt}, - Route, -}; -use tokio::{ - net::{TcpListener, TcpStream}, - sync::mpsc::Sender, -}; -use tokio_tungstenite::{ - accept_hdr_async, - tungstenite::{handshake, Message}, -}; +use rocket::{futures::StreamExt, Route}; +use tokio::sync::mpsc::Sender; + +use rocket_ws::{Message, WebSocket}; use crate::{ auth::{ClientIp, WsAccessTokenHeader}, @@ -30,7 +18,7 @@ use crate::{ use once_cell::sync::Lazy; -static WS_USERS: Lazy> = Lazy::new(|| { +pub static WS_USERS: Lazy> = Lazy::new(|| { Arc::new(WebSocketUsers { map: Arc::new(dashmap::DashMap::new()), }) @@ -47,8 +35,15 @@ use super::{ push_send_update, push_user_update, }; +static NOTIFICATIONS_DISABLED: Lazy = Lazy::new(|| !CONFIG.enable_websocket() && !CONFIG.push_enabled()); + pub fn routes() -> Vec { - routes![websockets_hub, anonymous_websockets_hub] + if CONFIG.enable_websocket() { + routes![websockets_hub, anonymous_websockets_hub] + } else { + info!("WebSocket are disabled, realtime sync functionality will not work!"); + routes![] + } } #[derive(FromForm, Debug)] @@ -108,7 +103,7 @@ impl Drop for WSAnonymousEntryMapGuard { #[get("/hub?")] fn websockets_hub<'r>( - ws: rocket_ws::WebSocket, + ws: WebSocket, data: WsAccessToken, ip: ClientIp, header_token: WsAccessTokenHeader, @@ -192,11 +187,7 @@ fn websockets_hub<'r>( } #[get("/anonymous-hub?")] -fn anonymous_websockets_hub<'r>( - ws: rocket_ws::WebSocket, - token: String, - ip: ClientIp, -) -> Result { +fn anonymous_websockets_hub<'r>(ws: WebSocket, token: String, ip: ClientIp) -> Result { let addr = ip.ip; info!("Accepting Anonymous Rocket WS connection from {addr}"); @@ -349,13 +340,19 @@ impl WebSocketUsers { // NOTE: The last modified date needs to be updated before calling these methods pub async fn send_user_update(&self, ut: UpdateType, user: &User) { + // Skip any processing if both WebSockets and Push are not active + if *NOTIFICATIONS_DISABLED { + return; + } let data = create_update( vec![("UserId".into(), user.uuid.clone().into()), ("Date".into(), serialize_date(user.updated_at))], ut, None, ); - self.send_update(&user.uuid, &data).await; + if CONFIG.enable_websocket() { + self.send_update(&user.uuid, &data).await; + } if CONFIG.push_enabled() { push_user_update(ut, user); @@ -363,13 +360,19 @@ impl WebSocketUsers { } pub async fn send_logout(&self, user: &User, acting_device_uuid: Option) { + // Skip any processing if both WebSockets and Push are not active + if *NOTIFICATIONS_DISABLED { + return; + } let data = create_update( vec![("UserId".into(), user.uuid.clone().into()), ("Date".into(), serialize_date(user.updated_at))], UpdateType::LogOut, acting_device_uuid.clone(), ); - self.send_update(&user.uuid, &data).await; + if CONFIG.enable_websocket() { + self.send_update(&user.uuid, &data).await; + } if CONFIG.push_enabled() { push_logout(user, acting_device_uuid); @@ -383,6 +386,10 @@ impl WebSocketUsers { acting_device_uuid: &String, conn: &mut DbConn, ) { + // Skip any processing if both WebSockets and Push are not active + if *NOTIFICATIONS_DISABLED { + return; + } let data = create_update( vec![ ("Id".into(), folder.uuid.clone().into()), @@ -393,7 +400,9 @@ impl WebSocketUsers { Some(acting_device_uuid.into()), ); - self.send_update(&folder.user_uuid, &data).await; + if CONFIG.enable_websocket() { + self.send_update(&folder.user_uuid, &data).await; + } if CONFIG.push_enabled() { push_folder_update(ut, folder, acting_device_uuid, conn).await; @@ -409,6 +418,10 @@ impl WebSocketUsers { collection_uuids: Option>, conn: &mut DbConn, ) { + // Skip any processing if both WebSockets and Push are not active + if *NOTIFICATIONS_DISABLED { + return; + } let org_uuid = convert_option(cipher.organization_uuid.clone()); // Depending if there are collections provided or not, we need to have different values for the following variables. // The user_uuid should be `null`, and the revision date should be set to now, else the clients won't sync the collection change. @@ -434,8 +447,10 @@ impl WebSocketUsers { Some(acting_device_uuid.into()), ); - for uuid in user_uuids { - self.send_update(uuid, &data).await; + if CONFIG.enable_websocket() { + for uuid in user_uuids { + self.send_update(uuid, &data).await; + } } if CONFIG.push_enabled() && user_uuids.len() == 1 { @@ -451,6 +466,10 @@ impl WebSocketUsers { acting_device_uuid: &String, conn: &mut DbConn, ) { + // Skip any processing if both WebSockets and Push are not active + if *NOTIFICATIONS_DISABLED { + return; + } let user_uuid = convert_option(send.user_uuid.clone()); let data = create_update( @@ -463,8 +482,10 @@ impl WebSocketUsers { None, ); - for uuid in user_uuids { - self.send_update(uuid, &data).await; + if CONFIG.enable_websocket() { + for uuid in user_uuids { + self.send_update(uuid, &data).await; + } } if CONFIG.push_enabled() && user_uuids.len() == 1 { push_send_update(ut, send, acting_device_uuid, conn).await; @@ -478,12 +499,18 @@ impl WebSocketUsers { acting_device_uuid: &String, conn: &mut DbConn, ) { + // Skip any processing if both WebSockets and Push are not active + if *NOTIFICATIONS_DISABLED { + return; + } let data = create_update( vec![("Id".into(), auth_request_uuid.clone().into()), ("UserId".into(), user_uuid.clone().into())], UpdateType::AuthRequest, Some(acting_device_uuid.to_string()), ); - self.send_update(user_uuid, &data).await; + if CONFIG.enable_websocket() { + self.send_update(user_uuid, &data).await; + } if CONFIG.push_enabled() { push_auth_request(user_uuid.to_string(), auth_request_uuid.to_string(), conn).await; @@ -497,12 +524,18 @@ impl WebSocketUsers { approving_device_uuid: String, conn: &mut DbConn, ) { + // Skip any processing if both WebSockets and Push are not active + if *NOTIFICATIONS_DISABLED { + return; + } let data = create_update( vec![("Id".into(), auth_response_uuid.to_owned().into()), ("UserId".into(), user_uuid.clone().into())], UpdateType::AuthRequestResponse, approving_device_uuid.clone().into(), ); - self.send_update(auth_response_uuid, &data).await; + if CONFIG.enable_websocket() { + self.send_update(auth_response_uuid, &data).await; + } if CONFIG.push_enabled() { push_auth_response(user_uuid.to_string(), auth_response_uuid.to_string(), approving_device_uuid, conn) @@ -526,6 +559,9 @@ impl AnonymousWebSocketSubscriptions { } pub async fn send_auth_response(&self, user_uuid: &String, auth_response_uuid: &str) { + if !CONFIG.enable_websocket() { + return; + } let data = create_anonymous_update( vec![("Id".into(), auth_response_uuid.to_owned().into()), ("UserId".into(), user_uuid.clone().into())], UpdateType::AuthRequestResponse, @@ -620,127 +656,3 @@ pub enum UpdateType { pub type Notify<'a> = &'a rocket::State>; pub type AnonymousNotify<'a> = &'a rocket::State>; - -pub fn start_notification_server() -> Arc { - let users = Arc::clone(&WS_USERS); - if CONFIG.websocket_enabled() { - let users2 = Arc::::clone(&users); - tokio::spawn(async move { - let addr = (CONFIG.websocket_address(), CONFIG.websocket_port()); - info!("Starting WebSockets server on {}:{}", addr.0, addr.1); - let listener = TcpListener::bind(addr).await.expect("Can't listen on websocket port"); - - let (shutdown_tx, mut shutdown_rx) = tokio::sync::oneshot::channel::<()>(); - CONFIG.set_ws_shutdown_handle(shutdown_tx); - - loop { - tokio::select! { - Ok((stream, addr)) = listener.accept() => { - tokio::spawn(handle_connection(stream, Arc::::clone(&users2), addr)); - } - - _ = &mut shutdown_rx => { - break; - } - } - } - - info!("Shutting down WebSockets server!") - }); - } - - users -} - -async fn handle_connection(stream: TcpStream, users: Arc, addr: SocketAddr) -> Result<(), Error> { - let mut user_uuid: Option = None; - - info!("Accepting WS connection from {addr}"); - - // Accept connection, do initial handshake, validate auth token and get the user ID - use handshake::server::{Request, Response}; - let mut stream = accept_hdr_async(stream, |req: &Request, res: Response| { - if let Some(token) = get_request_token(req) { - if let Ok(claims) = crate::auth::decode_login(&token) { - user_uuid = Some(claims.sub); - return Ok(res); - } - } - Err(Response::builder().status(401).body(None).unwrap()) - }) - .await?; - - let user_uuid = user_uuid.expect("User UUID should be set after the handshake"); - - let (mut rx, guard) = { - // Add a channel to send messages to this client to the map - let entry_uuid = uuid::Uuid::new_v4(); - let (tx, rx) = tokio::sync::mpsc::channel::(100); - users.map.entry(user_uuid.clone()).or_default().push((entry_uuid, tx)); - - // Once the guard goes out of scope, the connection will have been closed and the entry will be deleted from the map - (rx, WSEntryMapGuard::new(users, user_uuid, entry_uuid, addr.ip())) - }; - - let _guard = guard; - let mut interval = tokio::time::interval(Duration::from_secs(15)); - loop { - tokio::select! { - res = stream.next() => { - match res { - Some(Ok(message)) => { - match message { - // Respond to any pings - Message::Ping(ping) => stream.send(Message::Pong(ping)).await?, - Message::Pong(_) => {/* Ignored */}, - - // We should receive an initial message with the protocol and version, and we will reply to it - Message::Text(ref message) => { - let msg = message.strip_suffix(RECORD_SEPARATOR as char).unwrap_or(message); - - if serde_json::from_str(msg).ok() == Some(INITIAL_MESSAGE) { - stream.send(Message::binary(INITIAL_RESPONSE)).await?; - continue; - } - } - // Just echo anything else the client sends - _ => stream.send(message).await?, - } - } - _ => break, - } - } - - res = rx.recv() => { - match res { - Some(res) => stream.send(res).await?, - None => break, - } - } - - _ = interval.tick() => stream.send(Message::Ping(create_ping())).await? - } - } - - Ok(()) -} - -fn get_request_token(req: &handshake::server::Request) -> Option { - const ACCESS_TOKEN_KEY: &str = "access_token="; - - if let Some(Ok(auth)) = req.headers().get("Authorization").map(|a| a.to_str()) { - if let Some(token_part) = auth.strip_prefix("Bearer ") { - return Some(token_part.to_owned()); - } - } - - if let Some(params) = req.uri().query() { - let params_iter = params.split('&').take(1); - for val in params_iter { - if let Some(stripped) = val.strip_prefix(ACCESS_TOKEN_KEY) { - return Some(stripped.to_owned()); - } - } - } - None -} diff --git a/src/config.rs b/src/config.rs index e174c66b..01f387ec 100644 --- a/src/config.rs +++ b/src/config.rs @@ -39,7 +39,6 @@ macro_rules! make_config { struct Inner { rocket_shutdown_handle: Option, - ws_shutdown_handle: Option>, templates: Handlebars<'static>, config: ConfigItems, @@ -361,7 +360,7 @@ make_config! { /// Sends folder sends_folder: String, false, auto, |c| format!("{}/{}", c.data_folder, "sends"); /// Temp folder |> Used for storing temporary file uploads - tmp_folder: String, false, auto, |c| format!("{}/{}", c.data_folder, "tmp"); + tmp_folder: String, false, auto, |c| format!("{}/{}", c.data_folder, "tmp"); /// Templates folder templates_folder: String, false, auto, |c| format!("{}/{}", c.data_folder, "templates"); /// Session JWT key @@ -371,11 +370,7 @@ make_config! { }, ws { /// Enable websocket notifications - websocket_enabled: bool, false, def, false; - /// Websocket address - websocket_address: String, false, def, "0.0.0.0".to_string(); - /// Websocket port - websocket_port: u16, false, def, 3012; + enable_websocket: bool, false, def, true; }, push { /// Enable push notifications @@ -1071,7 +1066,6 @@ impl Config { Ok(Config { inner: RwLock::new(Inner { rocket_shutdown_handle: None, - ws_shutdown_handle: None, templates: load_templates(&config.templates_folder), config, _env, @@ -1237,16 +1231,8 @@ impl Config { self.inner.write().unwrap().rocket_shutdown_handle = Some(handle); } - pub fn set_ws_shutdown_handle(&self, handle: tokio::sync::oneshot::Sender<()>) { - self.inner.write().unwrap().ws_shutdown_handle = Some(handle); - } - pub fn shutdown(&self) { if let Ok(mut c) = self.inner.write() { - if let Some(handle) = c.ws_shutdown_handle.take() { - handle.send(()).ok(); - } - if let Some(handle) = c.rocket_shutdown_handle.take() { handle.notify(); } diff --git a/src/error.rs b/src/error.rs index f0969bff..784aad6a 100644 --- a/src/error.rs +++ b/src/error.rs @@ -52,7 +52,6 @@ use rocket::error::Error as RocketErr; use serde_json::{Error as SerdeErr, Value}; use std::io::Error as IoErr; use std::time::SystemTimeError as TimeErr; -use tokio_tungstenite::tungstenite::Error as TungstError; use webauthn_rs::error::WebauthnError as WebauthnErr; use yubico::yubicoerror::YubicoError as YubiErr; @@ -91,7 +90,6 @@ make_error! { DieselCon(DieselConErr): _has_source, _api_error, Webauthn(WebauthnErr): _has_source, _api_error, - WebSocket(TungstError): _has_source, _api_error, } impl std::fmt::Debug for Error { diff --git a/src/main.rs b/src/main.rs index 53b72606..285dc33a 100644 --- a/src/main.rs +++ b/src/main.rs @@ -52,7 +52,7 @@ mod ratelimit; mod util; use crate::api::purge_auth_requests; -use crate::api::WS_ANONYMOUS_SUBSCRIPTIONS; +use crate::api::{WS_ANONYMOUS_SUBSCRIPTIONS, WS_USERS}; pub use config::CONFIG; pub use error::{Error, MapResult}; use rocket::data::{Limits, ToByteUnit}; @@ -497,7 +497,7 @@ async fn launch_rocket(pool: db::DbPool, extra_debug: bool) -> Result<(), Error> .register([basepath, "/api"].concat(), api::core_catchers()) .register([basepath, "/admin"].concat(), api::admin_catchers()) .manage(pool) - .manage(api::start_notification_server()) + .manage(Arc::clone(&WS_USERS)) .manage(Arc::clone(&WS_ANONYMOUS_SUBSCRIPTIONS)) .attach(util::AppHeaders()) .attach(util::Cors())