Switch from event stream to WebSocket for events

This commit is contained in:
oxalica 2024-09-02 21:33:05 -04:00
parent 5fadffef4d
commit 77216aa0f8
9 changed files with 386 additions and 123 deletions

View file

@ -1,7 +1,8 @@
use std::path::PathBuf;
use std::time::Duration;
use anyhow::{ensure, Result};
use serde::Deserialize;
use serde::{Deserialize, Deserializer};
use serde_inline_default::serde_inline_default;
#[derive(Debug, Clone, Deserialize)]
@ -30,11 +31,22 @@ pub struct ServerConfig {
pub max_page_len: usize,
#[serde_inline_default(4096)] // 4KiB
pub max_request_len: usize,
#[serde_inline_default(1024)]
pub event_queue_len: usize,
#[serde_inline_default(90)]
pub timestamp_tolerance_secs: u64,
#[serde_inline_default(Duration::from_secs(15))]
#[serde(deserialize_with = "de_duration_sec")]
pub ws_auth_timeout_sec: Duration,
#[serde_inline_default(Duration::from_secs(15))]
#[serde(deserialize_with = "de_duration_sec")]
pub ws_send_timeout_sec: Duration,
#[serde_inline_default(1024)]
pub ws_event_queue_len: usize,
}
fn de_duration_sec<'de, D: Deserializer<'de>>(de: D) -> Result<Duration, D::Error> {
<u64>::deserialize(de).map(Duration::from_secs)
}
impl Config {

166
blahd/src/event.rs Normal file
View file

@ -0,0 +1,166 @@
use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::convert::Infallible;
use std::fmt;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll};
use anyhow::{bail, Context as _, Result};
use axum::extract::ws::{Message, WebSocket};
use blah::types::{AuthPayload, ChatItem, WithSig};
use futures_util::future::Either;
use futures_util::stream::SplitSink;
use futures_util::{stream_select, SinkExt as _, Stream, StreamExt};
use rusqlite::{params, OptionalExtension};
use serde::{Deserialize, Serialize};
use tokio::sync::broadcast;
use tokio_stream::wrappers::errors::BroadcastStreamRecvError;
use tokio_stream::wrappers::BroadcastStream;
use crate::AppState;
#[derive(Debug, Deserialize)]
pub enum Incoming {}
#[derive(Debug, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum Outgoing<'a> {
/// A chat message from a joined room.
Chat(&'a ChatItem),
/// The receiver is too slow to receive and some events and are dropped.
// FIXME: Should we indefinitely buffer them or just disconnect the client instead?
Lagged,
}
#[derive(Debug, Default)]
pub struct State {
pub user_listeners: Mutex<HashMap<u64, UserEventSender>>,
}
#[derive(Debug)]
pub struct StreamEnded;
impl fmt::Display for StreamEnded {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str("stream unexpectedly ended")
}
}
impl std::error::Error for StreamEnded {}
struct WsSenderWrapper<'ws, 'c> {
inner: SplitSink<&'ws mut WebSocket, Message>,
config: &'c crate::config::Config,
}
impl WsSenderWrapper<'_, '_> {
async fn send(&mut self, msg: &Outgoing<'_>) -> Result<()> {
let data = serde_json::to_string(&msg).expect("serialization cannot fail");
let fut = tokio::time::timeout(
self.config.server.ws_send_timeout_sec,
self.inner.send(Message::Text(data)),
);
match fut.await {
Ok(Ok(())) => Ok(()),
Ok(Err(_send_err)) => Err(StreamEnded.into()),
Err(_elapsed) => bail!("send timeout"),
}
}
}
type UserEventSender = broadcast::Sender<Arc<ChatItem>>;
#[derive(Debug)]
struct UserEventReceiver {
rx: BroadcastStream<Arc<ChatItem>>,
st: Arc<AppState>,
uid: u64,
}
impl Drop for UserEventReceiver {
fn drop(&mut self) {
tracing::debug!(%self.uid, "user disconnected");
if let Ok(mut map) = self.st.event.user_listeners.lock() {
if let Some(tx) = map.get_mut(&self.uid) {
if tx.receiver_count() == 1 {
map.remove(&self.uid);
}
}
}
}
}
impl Stream for UserEventReceiver {
type Item = Result<Arc<ChatItem>, BroadcastStreamRecvError>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
self.rx.poll_next_unpin(cx)
}
}
pub async fn handle_ws(st: Arc<AppState>, ws: &mut WebSocket) -> Result<Infallible> {
let (ws_tx, ws_rx) = ws.split();
let mut ws_rx = ws_rx.map(|ret| ret.and_then(|msg| msg.into_text()).map_err(|_| StreamEnded));
let mut ws_tx = WsSenderWrapper {
inner: ws_tx,
config: &st.config,
};
let uid = {
let payload = tokio::time::timeout(st.config.server.ws_auth_timeout_sec, ws_rx.next())
.await
.context("authentication timeout")?
.ok_or(StreamEnded)??;
let auth = serde_json::from_str::<WithSig<AuthPayload>>(&payload)?;
st.verify_signed_data(&auth)?;
st.conn
.lock()
.unwrap()
.query_row(
r"
SELECT `uid`
FROM `user`
WHERE `userkey` = ?
",
params![auth.signee.user],
|row| row.get::<_, u64>(0),
)
.optional()?
.context("invalid user")?
};
tracing::debug!(%uid, "user connected");
let event_rx = {
let rx = match st.event.user_listeners.lock().unwrap().entry(uid) {
Entry::Occupied(ent) => ent.get().subscribe(),
Entry::Vacant(ent) => {
let (tx, rx) = broadcast::channel(st.config.server.ws_event_queue_len);
ent.insert(tx);
rx
}
};
UserEventReceiver {
rx: rx.into(),
st: st.clone(),
uid,
}
};
let mut stream = stream_select!(ws_rx.map(Either::Left), event_rx.map(Either::Right));
loop {
match stream.next().await.ok_or(StreamEnded)? {
Either::Left(msg) => match serde_json::from_str::<Incoming>(&msg?)? {},
Either::Right(ret) => {
let msg = match &ret {
Ok(chat) => Outgoing::Chat(chat),
Err(BroadcastStreamRecvError::Lagged(_)) => Outgoing::Lagged,
};
// TODO: Concurrent send.
ws_tx.send(&msg).await?;
}
}
}
}

View file

@ -1,14 +1,12 @@
use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::convert::Infallible;
use std::path::PathBuf;
use std::sync::{Arc, Mutex};
use std::time::{Duration, SystemTime};
use anyhow::{Context, Result};
use axum::extract::{Path, Query, State};
use axum::extract::ws;
use axum::extract::{Path, Query, State, WebSocketUpgrade};
use axum::http::{header, StatusCode};
use axum::response::{sse, IntoResponse};
use axum::response::{IntoResponse, Response};
use axum::routing::{get, post};
use axum::{Json, Router};
use axum_extra::extract::WithRejection;
@ -21,14 +19,13 @@ use ed25519_dalek::SIGNATURE_LENGTH;
use middleware::{ApiError, OptionalAuth, SignedJson};
use rusqlite::{named_params, params, OptionalExtension, Row};
use serde::{Deserialize, Serialize};
use tokio::sync::broadcast;
use tokio_stream::StreamExt;
use utils::ExpiringSet;
use uuid::Uuid;
#[macro_use]
mod middleware;
mod config;
mod event;
mod utils;
/// Blah Chat Server
@ -84,8 +81,8 @@ fn main() -> Result<()> {
#[derive(Debug)]
struct AppState {
conn: Mutex<rusqlite::Connection>,
room_listeners: Mutex<HashMap<u64, broadcast::Sender<Arc<ChatItem>>>>,
used_nonces: Mutex<ExpiringSet<u32>>,
event: event::State,
config: Config,
}
@ -101,10 +98,10 @@ impl AppState {
.context("failed to initialize database")?;
Ok(Self {
conn: Mutex::new(conn),
room_listeners: Mutex::new(HashMap::new()),
used_nonces: Mutex::new(ExpiringSet::new(Duration::from_secs(
config.server.timestamp_tolerance_secs,
))),
event: event::State::default(),
config,
})
@ -152,11 +149,11 @@ async fn main_async(st: AppState) -> Result<()> {
let st = Arc::new(st);
let app = Router::new()
.route("/ws", get(handle_ws))
.route("/room/create", post(room_create))
.route("/room/:ruuid", get(room_get_metadata))
// NB. Sync with `feed_url` and `next_url` generation.
.route("/room/:ruuid/feed.json", get(room_get_feed))
.route("/room/:ruuid/event", get(room_event))
.route("/room/:ruuid/item", get(room_get_item).post(room_post_item))
.route("/room/:ruuid/admin", post(room_admin))
.with_state(st.clone())
@ -179,6 +176,24 @@ async fn main_async(st: AppState) -> Result<()> {
Ok(())
}
async fn handle_ws(State(st): ArcState, ws: WebSocketUpgrade) -> Response {
ws.on_upgrade(move |mut socket| async move {
match event::handle_ws(st, &mut socket).await {
Ok(never) => match never {},
Err(err) if err.is::<event::StreamEnded>() => {}
Err(err) => {
tracing::debug!(%err, "ws error");
let _: Result<_, _> = socket
.send(ws::Message::Close(Some(ws::CloseFrame {
code: ws::close_code::ERROR,
reason: err.to_string().into(),
})))
.await;
}
}
})
}
async fn room_create(
st: ArcState,
SignedJson(params): SignedJson<CreateRoomPayload>,
@ -505,7 +520,7 @@ async fn room_post_item(
));
}
let (rid, cid) = {
let (cid, txs) = {
let conn = st.conn.lock().unwrap();
let Some((rid, uid)) = conn
.query_row(
@ -550,73 +565,38 @@ async fn room_post_item(
},
|row| row.get::<_, u64>(0),
)?;
(rid, cid)
// FIXME: Optimize this to not traverses over all members.
let mut stmt = conn.prepare(
r"
SELECT `uid`
FROM `room_member`
WHERE `rid` = :rid
",
)?;
let listeners = st.event.user_listeners.lock().unwrap();
let txs = stmt
.query_map(params![rid], |row| row.get::<_, u64>(0))?
.filter_map(|ret| match ret {
Ok(uid) => listeners.get(&uid).map(|tx| Ok(tx.clone())),
Err(err) => Some(Err(err)),
})
.collect::<Result<Vec<_>, _>>()?;
(cid, txs)
};
let mut listeners = st.room_listeners.lock().unwrap();
if let Some(tx) = listeners.get(&rid) {
if tx.send(Arc::new(chat)).is_err() {
// Clean up because all receivers died.
listeners.remove(&rid);
}
}
Ok(Json(cid))
}
async fn room_event(
st: ArcState,
Path(ruuid): Path<Uuid>,
// TODO: There is actually no way to add headers via `EventSource` in client side.
// But this API is kinda temporary and need a better replacement anyway.
// So just only support public room for now.
OptionalAuth(user): OptionalAuth,
) -> Result<impl IntoResponse, ApiError> {
let rid = get_room_if_readable(&st.conn.lock().unwrap(), ruuid, user.as_ref(), |row| {
row.get::<_, u64>(0)
})?;
let rx = match st.room_listeners.lock().unwrap().entry(rid) {
Entry::Occupied(ent) => ent.get().subscribe(),
Entry::Vacant(ent) => {
let (tx, rx) = broadcast::channel(st.config.server.event_queue_len);
ent.insert(tx);
rx
}
};
// Do clean up when this stream is closed.
struct CleanOnDrop {
st: Arc<AppState>,
rid: u64,
}
impl Drop for CleanOnDrop {
fn drop(&mut self) {
if let Ok(mut listeners) = self.st.room_listeners.lock() {
if let Some(tx) = listeners.get(&self.rid) {
if tx.receiver_count() == 0 {
listeners.remove(&self.rid);
}
}
if !txs.is_empty() {
tracing::debug!("broadcasting event to {} clients", txs.len());
let chat = Arc::new(chat);
for tx in txs {
if let Err(err) = tx.send(chat.clone()) {
tracing::debug!(%err, "failed to broadcast event");
}
}
}
let _guard = CleanOnDrop { st: st.0, rid };
let stream = tokio_stream::wrappers::BroadcastStream::new(rx).filter_map(move |ret| {
let _guard = &_guard;
// On stream closure or lagging, close the current stream so client can retry.
let item = ret.ok()?;
let evt = sse::Event::default()
.json_data(&*item)
.expect("serialization cannot fail");
Some(Ok::<_, Infallible>(evt))
});
// NB. Send an empty event immediately to trigger client ready event.
let first_event = sse::Event::default().comment("");
let stream = futures_util::stream::iter(Some(Ok(first_event))).chain(stream);
Ok(sse::Sse::new(stream).keep_alive(sse::KeepAlive::default()))
Ok(Json(cid))
}
async fn room_admin(

View file

@ -1,3 +1,4 @@
use std::fmt;
use std::sync::Arc;
use axum::extract::rejection::{JsonRejection, PathRejection, QueryRejection};
@ -22,6 +23,14 @@ pub struct ApiError {
pub message: String,
}
impl fmt::Display for ApiError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str(&self.message)
}
}
impl std::error::Error for ApiError {}
macro_rules! error_response {
($status:expr, $code:literal, $msg:literal $(, $msg_args:expr)* $(,)?) => {
$crate::middleware::ApiError {