diff --git a/blahd/src/event.rs b/blahd/src/event.rs index db38adb..2afd052 100644 --- a/blahd/src/event.rs +++ b/blahd/src/event.rs @@ -12,13 +12,13 @@ use axum::extract::ws::{close_code, CloseFrame, Message, WebSocket}; use axum::extract::WebSocketUpgrade; use axum::response::Response; use blah_types::msg::{AuthPayload, SignedChatMsg}; -use blah_types::server::ClientEvent; +use blah_types::server::{ClientEvent, ServerEvent}; use blah_types::Signed; use futures_util::future::Either; use futures_util::stream::SplitSink; use futures_util::{stream_select, SinkExt as _, Stream, StreamExt}; use parking_lot::Mutex; -use serde::{de, Deserialize, Serialize}; +use serde::{de, Deserialize}; use serde_inline_default::serde_inline_default; use tokio::sync::broadcast; use tokio_stream::wrappers::errors::BroadcastStreamRecvError; @@ -27,18 +27,6 @@ use tokio_stream::wrappers::BroadcastStream; use crate::database::TransactionOps; use crate::{AppState, ArcState}; -// We a borrowed type rather than an owned type. -// So redefine it. Not sure if there is a better way. -#[derive(Debug, Serialize)] -#[serde(rename_all = "snake_case")] -enum ServerEvent<'a> { - /// A message from a joined room. - Msg(&'a SignedChatMsg), - /// The receiver is too slow to receive and some events and are dropped. - // FIXME: Should we indefinitely buffer them or just disconnect the client instead? - Lagged, -} - #[serde_inline_default] #[derive(Debug, Clone, Deserialize)] #[serde(default, deny_unknown_fields)] @@ -66,7 +54,25 @@ fn de_duration_sec<'de, D: de::Deserializer<'de>>(de: D) -> Result>, + user_listeners: Mutex>, +} + +impl State { + pub fn on_room_msg(&self, msg: SignedChatMsg, room_members: Vec) { + let listeners = self.user_listeners.lock(); + let mut cnt = 0usize; + let msg = Arc::new(ServerEvent::Msg(msg)); + for uid in &room_members { + if let Some(tx) = listeners.get(uid) { + if tx.send(msg.clone()).is_ok() { + cnt += 1; + } + } + } + if cnt != 0 { + tracing::debug!("broadcasted event to {cnt} clients"); + } + } } #[derive(Debug)] @@ -86,7 +92,7 @@ struct WsSenderWrapper<'ws, 'c> { } impl WsSenderWrapper<'_, '_> { - async fn send(&mut self, msg: &ServerEvent<'_>) -> Result<()> { + async fn send(&mut self, msg: &ServerEvent) -> Result<()> { let data = serde_json::to_string(&msg).expect("serialization cannot fail"); let fut = tokio::time::timeout( self.config.send_timeout_sec, @@ -100,11 +106,11 @@ impl WsSenderWrapper<'_, '_> { } } -type UserEventSender = broadcast::Sender>; +type UserEventSender = broadcast::Sender>; #[derive(Debug)] struct UserEventReceiver { - rx: BroadcastStream>, + rx: BroadcastStream>, st: Arc, uid: i64, } @@ -122,7 +128,7 @@ impl Drop for UserEventReceiver { } impl Stream for UserEventReceiver { - type Item = Result, BroadcastStreamRecvError>; + type Item = Result, BroadcastStreamRecvError>; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { self.rx.poll_next_unpin(cx) @@ -183,7 +189,7 @@ async fn handle_ws(st: Arc, ws: &mut WebSocket) -> Result let rx = match st.event.user_listeners.lock().entry(uid) { Entry::Occupied(ent) => ent.get().subscribe(), Entry::Vacant(ent) => { - let (tx, rx) = broadcast::channel(config.event_queue_len); + let (tx, rx) = broadcast::channel::>(config.event_queue_len); ent.insert(tx); rx } @@ -200,12 +206,12 @@ async fn handle_ws(st: Arc, ws: &mut WebSocket) -> Result match stream.next().await.ok_or(StreamEnded)? { Either::Left(msg) => match serde_json::from_str::(&msg?)? {}, Either::Right(ret) => { - let msg = match &ret { - Ok(chat) => ServerEvent::Msg(chat), - Err(BroadcastStreamRecvError::Lagged(_)) => ServerEvent::Lagged, + let event = match &ret { + Ok(event) => &**event, + Err(BroadcastStreamRecvError::Lagged(_)) => &ServerEvent::Lagged, }; // TODO: Concurrent send. - ws_tx.send(&msg).await?; + ws_tx.send(event).await?; } } } diff --git a/blahd/src/lib.rs b/blahd/src/lib.rs index 01a6154..a0979d4 100644 --- a/blahd/src/lib.rs +++ b/blahd/src/lib.rs @@ -459,20 +459,8 @@ async fn post_room_msg( Ok((cid, members)) })?; - let chat = Arc::new(chat); // FIXME: Optimize this to not traverses over all members. - let listeners = st.event.user_listeners.lock(); - let mut cnt = 0usize; - for uid in members { - if let Some(tx) = listeners.get(&uid) { - if tx.send(chat.clone()).is_ok() { - cnt += 1; - } - } - } - if cnt != 0 { - tracing::debug!("broadcasted event to {cnt} clients"); - } + st.event.on_room_msg(chat, members); Ok(Json(cid)) }