refactor(event): decouple states from handlers and remove aux-types

This commit is contained in:
oxalica 2024-10-08 21:17:34 -04:00
parent c611396331
commit 814fac1974
2 changed files with 31 additions and 37 deletions

View file

@ -12,13 +12,13 @@ use axum::extract::ws::{close_code, CloseFrame, Message, WebSocket};
use axum::extract::WebSocketUpgrade;
use axum::response::Response;
use blah_types::msg::{AuthPayload, SignedChatMsg};
use blah_types::server::ClientEvent;
use blah_types::server::{ClientEvent, ServerEvent};
use blah_types::Signed;
use futures_util::future::Either;
use futures_util::stream::SplitSink;
use futures_util::{stream_select, SinkExt as _, Stream, StreamExt};
use parking_lot::Mutex;
use serde::{de, Deserialize, Serialize};
use serde::{de, Deserialize};
use serde_inline_default::serde_inline_default;
use tokio::sync::broadcast;
use tokio_stream::wrappers::errors::BroadcastStreamRecvError;
@ -27,18 +27,6 @@ use tokio_stream::wrappers::BroadcastStream;
use crate::database::TransactionOps;
use crate::{AppState, ArcState};
// We a borrowed type rather than an owned type.
// So redefine it. Not sure if there is a better way.
#[derive(Debug, Serialize)]
#[serde(rename_all = "snake_case")]
enum ServerEvent<'a> {
/// A message from a joined room.
Msg(&'a SignedChatMsg),
/// The receiver is too slow to receive and some events and are dropped.
// FIXME: Should we indefinitely buffer them or just disconnect the client instead?
Lagged,
}
#[serde_inline_default]
#[derive(Debug, Clone, Deserialize)]
#[serde(default, deny_unknown_fields)]
@ -66,7 +54,25 @@ fn de_duration_sec<'de, D: de::Deserializer<'de>>(de: D) -> Result<Duration, D::
#[derive(Debug, Default)]
pub struct State {
pub user_listeners: Mutex<HashMap<i64, UserEventSender>>,
user_listeners: Mutex<HashMap<i64, UserEventSender>>,
}
impl State {
pub fn on_room_msg(&self, msg: SignedChatMsg, room_members: Vec<i64>) {
let listeners = self.user_listeners.lock();
let mut cnt = 0usize;
let msg = Arc::new(ServerEvent::Msg(msg));
for uid in &room_members {
if let Some(tx) = listeners.get(uid) {
if tx.send(msg.clone()).is_ok() {
cnt += 1;
}
}
}
if cnt != 0 {
tracing::debug!("broadcasted event to {cnt} clients");
}
}
}
#[derive(Debug)]
@ -86,7 +92,7 @@ struct WsSenderWrapper<'ws, 'c> {
}
impl WsSenderWrapper<'_, '_> {
async fn send(&mut self, msg: &ServerEvent<'_>) -> Result<()> {
async fn send(&mut self, msg: &ServerEvent) -> Result<()> {
let data = serde_json::to_string(&msg).expect("serialization cannot fail");
let fut = tokio::time::timeout(
self.config.send_timeout_sec,
@ -100,11 +106,11 @@ impl WsSenderWrapper<'_, '_> {
}
}
type UserEventSender = broadcast::Sender<Arc<SignedChatMsg>>;
type UserEventSender = broadcast::Sender<Arc<ServerEvent>>;
#[derive(Debug)]
struct UserEventReceiver {
rx: BroadcastStream<Arc<SignedChatMsg>>,
rx: BroadcastStream<Arc<ServerEvent>>,
st: Arc<AppState>,
uid: i64,
}
@ -122,7 +128,7 @@ impl Drop for UserEventReceiver {
}
impl Stream for UserEventReceiver {
type Item = Result<Arc<SignedChatMsg>, BroadcastStreamRecvError>;
type Item = Result<Arc<ServerEvent>, BroadcastStreamRecvError>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
self.rx.poll_next_unpin(cx)
@ -183,7 +189,7 @@ async fn handle_ws(st: Arc<AppState>, ws: &mut WebSocket) -> Result<Infallible>
let rx = match st.event.user_listeners.lock().entry(uid) {
Entry::Occupied(ent) => ent.get().subscribe(),
Entry::Vacant(ent) => {
let (tx, rx) = broadcast::channel(config.event_queue_len);
let (tx, rx) = broadcast::channel::<Arc<ServerEvent>>(config.event_queue_len);
ent.insert(tx);
rx
}
@ -200,12 +206,12 @@ async fn handle_ws(st: Arc<AppState>, ws: &mut WebSocket) -> Result<Infallible>
match stream.next().await.ok_or(StreamEnded)? {
Either::Left(msg) => match serde_json::from_str::<ClientEvent>(&msg?)? {},
Either::Right(ret) => {
let msg = match &ret {
Ok(chat) => ServerEvent::Msg(chat),
Err(BroadcastStreamRecvError::Lagged(_)) => ServerEvent::Lagged,
let event = match &ret {
Ok(event) => &**event,
Err(BroadcastStreamRecvError::Lagged(_)) => &ServerEvent::Lagged,
};
// TODO: Concurrent send.
ws_tx.send(&msg).await?;
ws_tx.send(event).await?;
}
}
}

View file

@ -459,20 +459,8 @@ async fn post_room_msg(
Ok((cid, members))
})?;
let chat = Arc::new(chat);
// FIXME: Optimize this to not traverses over all members.
let listeners = st.event.user_listeners.lock();
let mut cnt = 0usize;
for uid in members {
if let Some(tx) = listeners.get(&uid) {
if tx.send(chat.clone()).is_ok() {
cnt += 1;
}
}
}
if cnt != 0 {
tracing::debug!("broadcasted event to {cnt} clients");
}
st.event.on_room_msg(chat, members);
Ok(Json(cid))
}