diff --git a/Cargo.toml b/Cargo.toml index 7e46468..5e28bac 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,7 +22,7 @@ serde-constant = "0.1.0" serde_json = "1.0.127" tokio = { version = "1.39.3", features = ["macros", "rt-multi-thread", "sync"] } tokio-stream = { version = "0.1.15", features = ["sync"] } -tower-http = { version = "0.5.2", features = ["cors"] } +tower-http = { version = "0.5.2", features = ["cors", "limit"] } tracing = "0.1.40" tracing-subscriber = "0.3.18" uuid = { version = "1.10.0", features = ["serde", "v4"] } diff --git a/src/main.rs b/src/main.rs index c200972..b3ca7c7 100644 --- a/src/main.rs +++ b/src/main.rs @@ -22,8 +22,9 @@ use tokio::sync::broadcast; use tokio_stream::StreamExt; use uuid::Uuid; -const PAGE_LEN: usize = 1024; +const PAGE_LEN: usize = 64; const EVENT_QUEUE_LEN: usize = 1024; +const MAX_BODY_LEN: usize = 4 << 10; // 4KiB #[derive(Debug, clap::Parser)] struct Cli { @@ -94,6 +95,9 @@ async fn main_async(opt: Cli, st: AppState) -> Result<()> { .route("/room/:ruuid/event", get(room_event)) .route("/room/:ruuid/item", get(room_get_item).post(room_post_item)) .with_state(Arc::new(st)) + .layer(tower_http::limit::RequestBodyLimitLayer::new(MAX_BODY_LEN)) + // NB. This comes at last (outmost layer), so inner errors will still be wraped with + // correct CORS headers. .layer(tower_http::cors::CorsLayer::permissive()); let listener = tokio::net::TcpListener::bind(&opt.listen) @@ -466,7 +470,27 @@ async fn room_event( } }; - let stream = tokio_stream::wrappers::BroadcastStream::new(rx).filter_map(|ret| { + // Do clean up when this stream is closed. + struct CleanOnDrop { + st: Arc, + rid: u64, + } + impl Drop for CleanOnDrop { + fn drop(&mut self) { + if let Ok(mut listeners) = self.st.room_listeners.lock() { + if let Some(tx) = listeners.get(&self.rid) { + if tx.receiver_count() == 0 { + listeners.remove(&self.rid); + } + } + } + } + } + + let _guard = CleanOnDrop { st: st.0, rid }; + + let stream = tokio_stream::wrappers::BroadcastStream::new(rx).filter_map(move |ret| { + let _guard = &_guard; // On stream closure or lagging, close the current stream so client can retry. let item = ret.ok()?; let evt = sse::Event::default()