diff --git a/Cargo.lock b/Cargo.lock index b3dacc4..47f759c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -130,6 +130,7 @@ checksum = "3a6c9af12842a67734c9a2e355436e5d03b22383ed60cf13cd0c18fbfe3dcbcf" dependencies = [ "async-trait", "axum-core", + "base64 0.21.7", "bytes", "futures-util", "http", @@ -148,8 +149,10 @@ dependencies = [ "serde_json", "serde_path_to_error", "serde_urlencoded", + "sha1", "sync_wrapper 1.0.1", "tokio", + "tokio-tungstenite", "tower", "tower-layer", "tower-service", @@ -214,6 +217,12 @@ dependencies = [ "rustc-demangle", ] +[[package]] +name = "base64" +version = "0.21.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567" + [[package]] name = "base64" version = "0.22.1" @@ -479,6 +488,12 @@ dependencies = [ "syn", ] +[[package]] +name = "data-encoding" +version = "2.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e8566979429cf69b49a5c740c60791108e86440e8be149bbea4fe54d2c32d6e2" + [[package]] name = "der" version = "0.7.9" @@ -650,6 +665,7 @@ checksum = "3d6401deb83407ab3da39eba7e33987a73c3df0c82b4bb5813ee871c19c41d48" dependencies = [ "futures-core", "futures-macro", + "futures-sink", "futures-task", "pin-project-lite", "pin-utils", @@ -1246,7 +1262,7 @@ version = "0.12.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f8f4955649ef5c38cc7f9e8aa41761d48fb9677197daea9984dc54f56aad5e63" dependencies = [ - "base64", + "base64 0.22.1", "bytes", "encoding_rs", "futures-core", @@ -1360,7 +1376,7 @@ version = "2.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "196fe16b00e106300d3e45ecfcb764fa292a535d7326a29a5875c579c7417425" dependencies = [ - "base64", + "base64 0.22.1", "rustls-pki-types", ] @@ -1530,6 +1546,17 @@ dependencies = [ "serde", ] +[[package]] +name = "sha1" +version = "0.10.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3bf829a2d51ab4a5ddf1352d8470c140cadc8301b2ae1789db023f01cedd6ba" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest", +] + [[package]] name = "sha2" version = "0.10.8" @@ -1678,6 +1705,26 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "thiserror" +version = "1.0.63" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c0342370b38b6a11b6cc11d6a805569958d54cfa061a29969c3b5ce2ea405724" +dependencies = [ + "thiserror-impl", +] + +[[package]] +name = "thiserror-impl" +version = "1.0.63" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4558b58466b9ad7ca0f102865eccc95938dca1a74a856f2b57b6629050da261" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "thread_local" version = "1.1.8" @@ -1763,6 +1810,18 @@ dependencies = [ "tokio-util", ] +[[package]] +name = "tokio-tungstenite" +version = "0.21.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c83b561d025642014097b66e6c1bb422783339e0909e4429cde4749d1990bc38" +dependencies = [ + "futures-util", + "log", + "tokio", + "tungstenite", +] + [[package]] name = "tokio-util" version = "0.7.11" @@ -1884,6 +1943,25 @@ version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" +[[package]] +name = "tungstenite" +version = "0.21.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ef1a641ea34f399a848dea702823bbecfb4c486f911735368f1f137cb8257e1" +dependencies = [ + "byteorder", + "bytes", + "data-encoding", + "http", + "httparse", + "log", + "rand", + "sha1", + "thiserror", + "url", + "utf-8", +] + [[package]] name = "typenum" version = "1.17.0" @@ -1928,6 +2006,12 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "utf-8" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9" + [[package]] name = "utf8-width" version = "0.1.7" diff --git a/blahd/Cargo.toml b/blahd/Cargo.toml index 5ef3566..8f67a25 100644 --- a/blahd/Cargo.toml +++ b/blahd/Cargo.toml @@ -5,7 +5,7 @@ edition = "2021" [dependencies] anyhow = "1" -axum = { version = "0.7", features = ["tokio"] } +axum = { version = "0.7", features = ["ws"] } axum-extra = "0.9" clap = { version = "4", features = ["derive"] } ed25519-dalek = "2" @@ -17,7 +17,7 @@ sd-notify = "0.4" serde = { version = "1", features = ["derive"] } serde-aux = "4" serde_json = "1" -tokio = { version = "1", features = ["macros", "rt-multi-thread", "sync"] } +tokio = { version = "1", features = ["macros", "rt-multi-thread", "sync", "time"] } tokio-stream = { version = "0.1", features = ["sync"] } tower-http = { version = "0.5", features = ["cors", "limit"] } tracing = "0.1" diff --git a/blahd/config.example.toml b/blahd/config.example.toml index 4ec676b..fff3b72 100644 --- a/blahd/config.example.toml +++ b/blahd/config.example.toml @@ -26,8 +26,16 @@ max_page_len = 1024 # Maximum request body length in bytes. max_request_len = 4096 -# Maximum length of a single event queue. -event_queue_len = 1024 - # The maximum timestamp tolerance in seconds for request validation. timestamp_tolerance_secs = 90 + +# The max waiting time for the first authentication message for websocket. +ws_auth_timeout_sec = 15 + +# The max waiting time for outgoing message to be received for websocket. +ws_send_timeout_sec = 15 + +# Maximum number of pending events a single user can have. +# If events overflow the pending buffer, older events will be dropped and +# client will be notified. +ws_event_queue_len = 1024 diff --git a/blahd/docs/webapi.yaml b/blahd/docs/webapi.yaml index c666bf9..35ee808 100644 --- a/blahd/docs/webapi.yaml +++ b/blahd/docs/webapi.yaml @@ -4,6 +4,18 @@ info: version: 0.0.1 paths: + /ws: + get: + summary: WebSocket endpoint. + description: | + Once connection, client must send a JSON text message of type + `WithSig` for authentication. + If server does not close it immediately, it means success. + + Then server will send JSON text messages on events that user are + interested in (eg. chat from joined rooms). + The message has type `Outgoing` in `blahd/src/ws.rs`. + /room: get: summary: Get room metadata @@ -133,34 +145,6 @@ paths: application/json: $ref: '#/components/schemas/ApiError' - /room/{ruuid}/event: - get: - summary: Get an event stream for future new items. - description: | - This is a temporary interface, before a better notification system - (post notifications? websocket?) is implemented. - headers: - Authorization: - description: Proof of membership for private rooms. - required: false - schema: - $ret: WithSig - responses: - 200: - content: - text/event-stream: - x-description: An event stream, each event is a JSON with type WithSig - 400: - description: Body is invalid or fails the verification. - content: - application/json: - $ref: '#/components/schemas/ApiError' - 404: - description: Room not found. - content: - application/json: - $ref: '#/components/schemas/ApiError' - /room/{ruuid}/admin: post: summary: Room management diff --git a/blahd/src/config.rs b/blahd/src/config.rs index f6cbbeb..a4578a3 100644 --- a/blahd/src/config.rs +++ b/blahd/src/config.rs @@ -1,7 +1,8 @@ use std::path::PathBuf; +use std::time::Duration; use anyhow::{ensure, Result}; -use serde::Deserialize; +use serde::{Deserialize, Deserializer}; use serde_inline_default::serde_inline_default; #[derive(Debug, Clone, Deserialize)] @@ -30,11 +31,22 @@ pub struct ServerConfig { pub max_page_len: usize, #[serde_inline_default(4096)] // 4KiB pub max_request_len: usize, - #[serde_inline_default(1024)] - pub event_queue_len: usize, #[serde_inline_default(90)] pub timestamp_tolerance_secs: u64, + + #[serde_inline_default(Duration::from_secs(15))] + #[serde(deserialize_with = "de_duration_sec")] + pub ws_auth_timeout_sec: Duration, + #[serde_inline_default(Duration::from_secs(15))] + #[serde(deserialize_with = "de_duration_sec")] + pub ws_send_timeout_sec: Duration, + #[serde_inline_default(1024)] + pub ws_event_queue_len: usize, +} + +fn de_duration_sec<'de, D: Deserializer<'de>>(de: D) -> Result { + ::deserialize(de).map(Duration::from_secs) } impl Config { diff --git a/blahd/src/event.rs b/blahd/src/event.rs new file mode 100644 index 0000000..4911a5f --- /dev/null +++ b/blahd/src/event.rs @@ -0,0 +1,166 @@ +use std::collections::hash_map::Entry; +use std::collections::HashMap; +use std::convert::Infallible; +use std::fmt; +use std::pin::Pin; +use std::sync::{Arc, Mutex}; +use std::task::{Context, Poll}; + +use anyhow::{bail, Context as _, Result}; +use axum::extract::ws::{Message, WebSocket}; +use blah::types::{AuthPayload, ChatItem, WithSig}; +use futures_util::future::Either; +use futures_util::stream::SplitSink; +use futures_util::{stream_select, SinkExt as _, Stream, StreamExt}; +use rusqlite::{params, OptionalExtension}; +use serde::{Deserialize, Serialize}; +use tokio::sync::broadcast; +use tokio_stream::wrappers::errors::BroadcastStreamRecvError; +use tokio_stream::wrappers::BroadcastStream; + +use crate::AppState; + +#[derive(Debug, Deserialize)] +pub enum Incoming {} + +#[derive(Debug, Serialize)] +#[serde(rename_all = "snake_case")] +pub enum Outgoing<'a> { + /// A chat message from a joined room. + Chat(&'a ChatItem), + /// The receiver is too slow to receive and some events and are dropped. + // FIXME: Should we indefinitely buffer them or just disconnect the client instead? + Lagged, +} + +#[derive(Debug, Default)] +pub struct State { + pub user_listeners: Mutex>, +} + +#[derive(Debug)] +pub struct StreamEnded; + +impl fmt::Display for StreamEnded { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("stream unexpectedly ended") + } +} + +impl std::error::Error for StreamEnded {} + +struct WsSenderWrapper<'ws, 'c> { + inner: SplitSink<&'ws mut WebSocket, Message>, + config: &'c crate::config::Config, +} + +impl WsSenderWrapper<'_, '_> { + async fn send(&mut self, msg: &Outgoing<'_>) -> Result<()> { + let data = serde_json::to_string(&msg).expect("serialization cannot fail"); + let fut = tokio::time::timeout( + self.config.server.ws_send_timeout_sec, + self.inner.send(Message::Text(data)), + ); + match fut.await { + Ok(Ok(())) => Ok(()), + Ok(Err(_send_err)) => Err(StreamEnded.into()), + Err(_elapsed) => bail!("send timeout"), + } + } +} + +type UserEventSender = broadcast::Sender>; + +#[derive(Debug)] +struct UserEventReceiver { + rx: BroadcastStream>, + st: Arc, + uid: u64, +} + +impl Drop for UserEventReceiver { + fn drop(&mut self) { + tracing::debug!(%self.uid, "user disconnected"); + if let Ok(mut map) = self.st.event.user_listeners.lock() { + if let Some(tx) = map.get_mut(&self.uid) { + if tx.receiver_count() == 1 { + map.remove(&self.uid); + } + } + } + } +} + +impl Stream for UserEventReceiver { + type Item = Result, BroadcastStreamRecvError>; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.rx.poll_next_unpin(cx) + } +} + +pub async fn handle_ws(st: Arc, ws: &mut WebSocket) -> Result { + let (ws_tx, ws_rx) = ws.split(); + let mut ws_rx = ws_rx.map(|ret| ret.and_then(|msg| msg.into_text()).map_err(|_| StreamEnded)); + let mut ws_tx = WsSenderWrapper { + inner: ws_tx, + config: &st.config, + }; + + let uid = { + let payload = tokio::time::timeout(st.config.server.ws_auth_timeout_sec, ws_rx.next()) + .await + .context("authentication timeout")? + .ok_or(StreamEnded)??; + let auth = serde_json::from_str::>(&payload)?; + st.verify_signed_data(&auth)?; + + st.conn + .lock() + .unwrap() + .query_row( + r" + SELECT `uid` + FROM `user` + WHERE `userkey` = ? + ", + params![auth.signee.user], + |row| row.get::<_, u64>(0), + ) + .optional()? + .context("invalid user")? + }; + + tracing::debug!(%uid, "user connected"); + + let event_rx = { + let rx = match st.event.user_listeners.lock().unwrap().entry(uid) { + Entry::Occupied(ent) => ent.get().subscribe(), + Entry::Vacant(ent) => { + let (tx, rx) = broadcast::channel(st.config.server.ws_event_queue_len); + ent.insert(tx); + rx + } + }; + UserEventReceiver { + rx: rx.into(), + st: st.clone(), + uid, + } + }; + + let mut stream = stream_select!(ws_rx.map(Either::Left), event_rx.map(Either::Right)); + loop { + match stream.next().await.ok_or(StreamEnded)? { + Either::Left(msg) => match serde_json::from_str::(&msg?)? {}, + Either::Right(ret) => { + let msg = match &ret { + Ok(chat) => Outgoing::Chat(chat), + Err(BroadcastStreamRecvError::Lagged(_)) => Outgoing::Lagged, + }; + // TODO: Concurrent send. + ws_tx.send(&msg).await?; + } + } + } +} diff --git a/blahd/src/main.rs b/blahd/src/main.rs index 9ae898f..fb158dc 100644 --- a/blahd/src/main.rs +++ b/blahd/src/main.rs @@ -1,14 +1,12 @@ -use std::collections::hash_map::Entry; -use std::collections::HashMap; -use std::convert::Infallible; use std::path::PathBuf; use std::sync::{Arc, Mutex}; use std::time::{Duration, SystemTime}; use anyhow::{Context, Result}; -use axum::extract::{Path, Query, State}; +use axum::extract::ws; +use axum::extract::{Path, Query, State, WebSocketUpgrade}; use axum::http::{header, StatusCode}; -use axum::response::{sse, IntoResponse}; +use axum::response::{IntoResponse, Response}; use axum::routing::{get, post}; use axum::{Json, Router}; use axum_extra::extract::WithRejection; @@ -21,14 +19,13 @@ use ed25519_dalek::SIGNATURE_LENGTH; use middleware::{ApiError, OptionalAuth, SignedJson}; use rusqlite::{named_params, params, OptionalExtension, Row}; use serde::{Deserialize, Serialize}; -use tokio::sync::broadcast; -use tokio_stream::StreamExt; use utils::ExpiringSet; use uuid::Uuid; #[macro_use] mod middleware; mod config; +mod event; mod utils; /// Blah Chat Server @@ -84,8 +81,8 @@ fn main() -> Result<()> { #[derive(Debug)] struct AppState { conn: Mutex, - room_listeners: Mutex>>>, used_nonces: Mutex>, + event: event::State, config: Config, } @@ -101,10 +98,10 @@ impl AppState { .context("failed to initialize database")?; Ok(Self { conn: Mutex::new(conn), - room_listeners: Mutex::new(HashMap::new()), used_nonces: Mutex::new(ExpiringSet::new(Duration::from_secs( config.server.timestamp_tolerance_secs, ))), + event: event::State::default(), config, }) @@ -152,11 +149,11 @@ async fn main_async(st: AppState) -> Result<()> { let st = Arc::new(st); let app = Router::new() + .route("/ws", get(handle_ws)) .route("/room/create", post(room_create)) .route("/room/:ruuid", get(room_get_metadata)) // NB. Sync with `feed_url` and `next_url` generation. .route("/room/:ruuid/feed.json", get(room_get_feed)) - .route("/room/:ruuid/event", get(room_event)) .route("/room/:ruuid/item", get(room_get_item).post(room_post_item)) .route("/room/:ruuid/admin", post(room_admin)) .with_state(st.clone()) @@ -179,6 +176,24 @@ async fn main_async(st: AppState) -> Result<()> { Ok(()) } +async fn handle_ws(State(st): ArcState, ws: WebSocketUpgrade) -> Response { + ws.on_upgrade(move |mut socket| async move { + match event::handle_ws(st, &mut socket).await { + Ok(never) => match never {}, + Err(err) if err.is::() => {} + Err(err) => { + tracing::debug!(%err, "ws error"); + let _: Result<_, _> = socket + .send(ws::Message::Close(Some(ws::CloseFrame { + code: ws::close_code::ERROR, + reason: err.to_string().into(), + }))) + .await; + } + } + }) +} + async fn room_create( st: ArcState, SignedJson(params): SignedJson, @@ -505,7 +520,7 @@ async fn room_post_item( )); } - let (rid, cid) = { + let (cid, txs) = { let conn = st.conn.lock().unwrap(); let Some((rid, uid)) = conn .query_row( @@ -550,73 +565,38 @@ async fn room_post_item( }, |row| row.get::<_, u64>(0), )?; - (rid, cid) + + // FIXME: Optimize this to not traverses over all members. + let mut stmt = conn.prepare( + r" + SELECT `uid` + FROM `room_member` + WHERE `rid` = :rid + ", + )?; + let listeners = st.event.user_listeners.lock().unwrap(); + let txs = stmt + .query_map(params![rid], |row| row.get::<_, u64>(0))? + .filter_map(|ret| match ret { + Ok(uid) => listeners.get(&uid).map(|tx| Ok(tx.clone())), + Err(err) => Some(Err(err)), + }) + .collect::, _>>()?; + + (cid, txs) }; - let mut listeners = st.room_listeners.lock().unwrap(); - if let Some(tx) = listeners.get(&rid) { - if tx.send(Arc::new(chat)).is_err() { - // Clean up because all receivers died. - listeners.remove(&rid); - } - } - - Ok(Json(cid)) -} - -async fn room_event( - st: ArcState, - Path(ruuid): Path, - // TODO: There is actually no way to add headers via `EventSource` in client side. - // But this API is kinda temporary and need a better replacement anyway. - // So just only support public room for now. - OptionalAuth(user): OptionalAuth, -) -> Result { - let rid = get_room_if_readable(&st.conn.lock().unwrap(), ruuid, user.as_ref(), |row| { - row.get::<_, u64>(0) - })?; - - let rx = match st.room_listeners.lock().unwrap().entry(rid) { - Entry::Occupied(ent) => ent.get().subscribe(), - Entry::Vacant(ent) => { - let (tx, rx) = broadcast::channel(st.config.server.event_queue_len); - ent.insert(tx); - rx - } - }; - - // Do clean up when this stream is closed. - struct CleanOnDrop { - st: Arc, - rid: u64, - } - impl Drop for CleanOnDrop { - fn drop(&mut self) { - if let Ok(mut listeners) = self.st.room_listeners.lock() { - if let Some(tx) = listeners.get(&self.rid) { - if tx.receiver_count() == 0 { - listeners.remove(&self.rid); - } - } + if !txs.is_empty() { + tracing::debug!("broadcasting event to {} clients", txs.len()); + let chat = Arc::new(chat); + for tx in txs { + if let Err(err) = tx.send(chat.clone()) { + tracing::debug!(%err, "failed to broadcast event"); } } } - let _guard = CleanOnDrop { st: st.0, rid }; - - let stream = tokio_stream::wrappers::BroadcastStream::new(rx).filter_map(move |ret| { - let _guard = &_guard; - // On stream closure or lagging, close the current stream so client can retry. - let item = ret.ok()?; - let evt = sse::Event::default() - .json_data(&*item) - .expect("serialization cannot fail"); - Some(Ok::<_, Infallible>(evt)) - }); - // NB. Send an empty event immediately to trigger client ready event. - let first_event = sse::Event::default().comment(""); - let stream = futures_util::stream::iter(Some(Ok(first_event))).chain(stream); - Ok(sse::Sse::new(stream).keep_alive(sse::KeepAlive::default())) + Ok(Json(cid)) } async fn room_admin( diff --git a/blahd/src/middleware.rs b/blahd/src/middleware.rs index b1ecd3a..d9a49ff 100644 --- a/blahd/src/middleware.rs +++ b/blahd/src/middleware.rs @@ -1,3 +1,4 @@ +use std::fmt; use std::sync::Arc; use axum::extract::rejection::{JsonRejection, PathRejection, QueryRejection}; @@ -22,6 +23,14 @@ pub struct ApiError { pub message: String, } +impl fmt::Display for ApiError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(&self.message) + } +} + +impl std::error::Error for ApiError {} + macro_rules! error_response { ($status:expr, $code:literal, $msg:literal $(, $msg_args:expr)* $(,)?) => { $crate::middleware::ApiError { diff --git a/pages/main.js b/pages/main.js index 55db477..15db9e0 100644 --- a/pages/main.js +++ b/pages/main.js @@ -7,7 +7,7 @@ const joinRoomBtn = document.querySelector('#join-room'); let roomUrl = ''; let roomUuid = null; -let feed = null; +let ws = null; let keypair = null; let defaultConfig = {}; @@ -166,16 +166,13 @@ function escapeHtml(text) { } async function connectRoom(url) { - if (url === '' || url == roomUrl) return; + if (url === '' || url == roomUrl || keypair === null) return; const match = url.match(/^https?:\/\/.*\/([a-z0-9]{8}-[a-z0-9]{4}-[a-z0-9]{4}-[a-z0-9]{4}-[a-z0-9]{12})\/?/); if (match === null) { log('invalid room url'); return; } - if (feed !== null) { - feed.close(); - } roomUrl = url; roomUuid = match[1]; @@ -210,18 +207,40 @@ async function connectRoom(url) { // TODO: There is a time window where events would be lost. - feed = new EventSource(`${url}/event`); - feed.onopen = (_) => { + await connectWs(); +} + +async function connectWs() { + if (ws !== null) { + ws.close(); + } + const wsUrl = new URL(roomUrl); + wsUrl.protocol = wsUrl.protocol == 'http:' ? 'ws:' : 'wss:'; + wsUrl.pathname = '/ws'; + ws = new WebSocket(wsUrl); + ws.onopen = async (_) => { + const auth = await signData({ typ: 'auth' }); + await ws.send(auth); log('listening on events'); } - feed.onerror = (e) => { + ws.onclose = (e) => { console.error(e); - log('event listener error'); + log(`ws closed (code=${e.code}): ${e.reason}`); }; - feed.onmessage = async (e) => { - console.log('feed event', e.data); - const chat = JSON.parse(e.data); - showChatMsg(chat); + ws.onerror = (e) => { + console.error(e); + log(`ws error: ${e.error}`); + }; + ws.onmessage = async (e) => { + console.log('ws event', e.data); + const msg = JSON.parse(e.data); + if (msg.chat !== undefined) { + showChatMsg(msg.chat); + } else if (msg.lagged !== undefined) { + log('some events are dropped because of queue overflow') + } else { + log(`unknown ws message: ${e.data}`); + } }; } @@ -236,6 +255,7 @@ async function joinRoom() { user: await getUserPubkey(), }); log('joined room'); + await connectWs(); } catch (e) { console.error(e); log(`failed to join room: ${e}`);