use std::collections::hash_map::Entry; use std::collections::HashMap; use std::convert::Infallible; use std::fmt; use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; use std::time::Duration; use anyhow::{bail, Context as _, Result}; use axum::extract::ws::{Message, WebSocket}; use blah_types::msg::{AuthPayload, SignedChatMsg}; use blah_types::server::ClientEvent; use blah_types::Signed; use futures_util::future::Either; use futures_util::stream::SplitSink; use futures_util::{stream_select, SinkExt as _, Stream, StreamExt}; use parking_lot::Mutex; use serde::{de, Deserialize, Serialize}; use serde_inline_default::serde_inline_default; use tokio::sync::broadcast; use tokio_stream::wrappers::errors::BroadcastStreamRecvError; use tokio_stream::wrappers::BroadcastStream; use crate::database::TransactionOps; use crate::AppState; // We a borrowed type rather than an owned type. // So redefine it. Not sure if there is a better way. #[derive(Debug, Serialize)] #[serde(rename_all = "snake_case")] enum ServerEvent<'a> { /// A message from a joined room. Msg(&'a SignedChatMsg), /// The receiver is too slow to receive and some events and are dropped. // FIXME: Should we indefinitely buffer them or just disconnect the client instead? Lagged, } #[serde_inline_default] #[derive(Debug, Clone, Deserialize)] #[serde(default, deny_unknown_fields)] pub struct Config { #[serde(deserialize_with = "de_duration_sec")] pub auth_timeout_sec: Duration, #[serde(deserialize_with = "de_duration_sec")] pub send_timeout_sec: Duration, pub event_queue_len: usize, } impl Default for Config { fn default() -> Self { Self { auth_timeout_sec: Duration::from_secs(15), send_timeout_sec: Duration::from_secs(15), event_queue_len: 1024, } } } fn de_duration_sec<'de, D: de::Deserializer<'de>>(de: D) -> Result { ::deserialize(de).map(Duration::from_secs) } #[derive(Debug, Default)] pub struct State { pub user_listeners: Mutex>, } #[derive(Debug)] pub struct StreamEnded; impl fmt::Display for StreamEnded { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.write_str("stream unexpectedly ended") } } impl std::error::Error for StreamEnded {} struct WsSenderWrapper<'ws, 'c> { inner: SplitSink<&'ws mut WebSocket, Message>, config: &'c Config, } impl WsSenderWrapper<'_, '_> { async fn send(&mut self, msg: &ServerEvent<'_>) -> Result<()> { let data = serde_json::to_string(&msg).expect("serialization cannot fail"); let fut = tokio::time::timeout( self.config.send_timeout_sec, self.inner.send(Message::Text(data)), ); match fut.await { Ok(Ok(())) => Ok(()), Ok(Err(_send_err)) => Err(StreamEnded.into()), Err(_elapsed) => bail!("send timeout"), } } } type UserEventSender = broadcast::Sender>; #[derive(Debug)] struct UserEventReceiver { rx: BroadcastStream>, st: Arc, uid: i64, } impl Drop for UserEventReceiver { fn drop(&mut self) { tracing::debug!(%self.uid, "user disconnected"); let mut map = self.st.event.user_listeners.lock(); if let Some(tx) = map.get_mut(&self.uid) { if tx.receiver_count() == 1 { map.remove(&self.uid); } } } } impl Stream for UserEventReceiver { type Item = Result, BroadcastStreamRecvError>; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { self.rx.poll_next_unpin(cx) } } pub async fn handle_ws(st: Arc, ws: &mut WebSocket) -> Result { let config = &st.config.ws; let (ws_tx, ws_rx) = ws.split(); let mut ws_rx = ws_rx.map(|ret| match ret { Ok(Message::Text(data)) => Ok(data), Ok(Message::Close(_)) | Err(_) => Err(StreamEnded.into()), _ => bail!("unexpected message type"), }); let mut ws_tx = WsSenderWrapper { inner: ws_tx, config, }; let uid = { let payload = tokio::time::timeout(config.auth_timeout_sec, ws_rx.next()) .await .context("authentication timeout")? .ok_or(StreamEnded)??; let auth = serde_json::from_str::>(&payload)?; st.verify_signed_data(&auth)?; let (uid, _) = st.db.with_read(|txn| txn.get_user(&auth.signee.user))?; uid }; tracing::debug!(%uid, "user connected"); let event_rx = { let rx = match st.event.user_listeners.lock().entry(uid) { Entry::Occupied(ent) => ent.get().subscribe(), Entry::Vacant(ent) => { let (tx, rx) = broadcast::channel(config.event_queue_len); ent.insert(tx); rx } }; UserEventReceiver { rx: rx.into(), st: st.clone(), uid, } }; let mut stream = stream_select!(ws_rx.map(Either::Left), event_rx.map(Either::Right)); loop { match stream.next().await.ok_or(StreamEnded)? { Either::Left(msg) => match serde_json::from_str::(&msg?)? {}, Either::Right(ret) => { let msg = match &ret { Ok(chat) => ServerEvent::Msg(chat), Err(BroadcastStreamRecvError::Lagged(_)) => ServerEvent::Lagged, }; // TODO: Concurrent send. ws_tx.send(&msg).await?; } } } }