Switch from event stream to WebSocket for events

This commit is contained in:
oxalica 2024-09-02 21:33:05 -04:00
parent 5fadffef4d
commit 77216aa0f8
9 changed files with 386 additions and 123 deletions

88
Cargo.lock generated
View file

@ -130,6 +130,7 @@ checksum = "3a6c9af12842a67734c9a2e355436e5d03b22383ed60cf13cd0c18fbfe3dcbcf"
dependencies = [ dependencies = [
"async-trait", "async-trait",
"axum-core", "axum-core",
"base64 0.21.7",
"bytes", "bytes",
"futures-util", "futures-util",
"http", "http",
@ -148,8 +149,10 @@ dependencies = [
"serde_json", "serde_json",
"serde_path_to_error", "serde_path_to_error",
"serde_urlencoded", "serde_urlencoded",
"sha1",
"sync_wrapper 1.0.1", "sync_wrapper 1.0.1",
"tokio", "tokio",
"tokio-tungstenite",
"tower", "tower",
"tower-layer", "tower-layer",
"tower-service", "tower-service",
@ -214,6 +217,12 @@ dependencies = [
"rustc-demangle", "rustc-demangle",
] ]
[[package]]
name = "base64"
version = "0.21.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567"
[[package]] [[package]]
name = "base64" name = "base64"
version = "0.22.1" version = "0.22.1"
@ -479,6 +488,12 @@ dependencies = [
"syn", "syn",
] ]
[[package]]
name = "data-encoding"
version = "2.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e8566979429cf69b49a5c740c60791108e86440e8be149bbea4fe54d2c32d6e2"
[[package]] [[package]]
name = "der" name = "der"
version = "0.7.9" version = "0.7.9"
@ -650,6 +665,7 @@ checksum = "3d6401deb83407ab3da39eba7e33987a73c3df0c82b4bb5813ee871c19c41d48"
dependencies = [ dependencies = [
"futures-core", "futures-core",
"futures-macro", "futures-macro",
"futures-sink",
"futures-task", "futures-task",
"pin-project-lite", "pin-project-lite",
"pin-utils", "pin-utils",
@ -1246,7 +1262,7 @@ version = "0.12.7"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f8f4955649ef5c38cc7f9e8aa41761d48fb9677197daea9984dc54f56aad5e63" checksum = "f8f4955649ef5c38cc7f9e8aa41761d48fb9677197daea9984dc54f56aad5e63"
dependencies = [ dependencies = [
"base64", "base64 0.22.1",
"bytes", "bytes",
"encoding_rs", "encoding_rs",
"futures-core", "futures-core",
@ -1360,7 +1376,7 @@ version = "2.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "196fe16b00e106300d3e45ecfcb764fa292a535d7326a29a5875c579c7417425" checksum = "196fe16b00e106300d3e45ecfcb764fa292a535d7326a29a5875c579c7417425"
dependencies = [ dependencies = [
"base64", "base64 0.22.1",
"rustls-pki-types", "rustls-pki-types",
] ]
@ -1530,6 +1546,17 @@ dependencies = [
"serde", "serde",
] ]
[[package]]
name = "sha1"
version = "0.10.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e3bf829a2d51ab4a5ddf1352d8470c140cadc8301b2ae1789db023f01cedd6ba"
dependencies = [
"cfg-if",
"cpufeatures",
"digest",
]
[[package]] [[package]]
name = "sha2" name = "sha2"
version = "0.10.8" version = "0.10.8"
@ -1678,6 +1705,26 @@ dependencies = [
"windows-sys 0.59.0", "windows-sys 0.59.0",
] ]
[[package]]
name = "thiserror"
version = "1.0.63"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c0342370b38b6a11b6cc11d6a805569958d54cfa061a29969c3b5ce2ea405724"
dependencies = [
"thiserror-impl",
]
[[package]]
name = "thiserror-impl"
version = "1.0.63"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a4558b58466b9ad7ca0f102865eccc95938dca1a74a856f2b57b6629050da261"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]] [[package]]
name = "thread_local" name = "thread_local"
version = "1.1.8" version = "1.1.8"
@ -1763,6 +1810,18 @@ dependencies = [
"tokio-util", "tokio-util",
] ]
[[package]]
name = "tokio-tungstenite"
version = "0.21.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c83b561d025642014097b66e6c1bb422783339e0909e4429cde4749d1990bc38"
dependencies = [
"futures-util",
"log",
"tokio",
"tungstenite",
]
[[package]] [[package]]
name = "tokio-util" name = "tokio-util"
version = "0.7.11" version = "0.7.11"
@ -1884,6 +1943,25 @@ version = "0.2.5"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b"
[[package]]
name = "tungstenite"
version = "0.21.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9ef1a641ea34f399a848dea702823bbecfb4c486f911735368f1f137cb8257e1"
dependencies = [
"byteorder",
"bytes",
"data-encoding",
"http",
"httparse",
"log",
"rand",
"sha1",
"thiserror",
"url",
"utf-8",
]
[[package]] [[package]]
name = "typenum" name = "typenum"
version = "1.17.0" version = "1.17.0"
@ -1928,6 +2006,12 @@ dependencies = [
"percent-encoding", "percent-encoding",
] ]
[[package]]
name = "utf-8"
version = "0.7.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9"
[[package]] [[package]]
name = "utf8-width" name = "utf8-width"
version = "0.1.7" version = "0.1.7"

View file

@ -5,7 +5,7 @@ edition = "2021"
[dependencies] [dependencies]
anyhow = "1" anyhow = "1"
axum = { version = "0.7", features = ["tokio"] } axum = { version = "0.7", features = ["ws"] }
axum-extra = "0.9" axum-extra = "0.9"
clap = { version = "4", features = ["derive"] } clap = { version = "4", features = ["derive"] }
ed25519-dalek = "2" ed25519-dalek = "2"
@ -17,7 +17,7 @@ sd-notify = "0.4"
serde = { version = "1", features = ["derive"] } serde = { version = "1", features = ["derive"] }
serde-aux = "4" serde-aux = "4"
serde_json = "1" serde_json = "1"
tokio = { version = "1", features = ["macros", "rt-multi-thread", "sync"] } tokio = { version = "1", features = ["macros", "rt-multi-thread", "sync", "time"] }
tokio-stream = { version = "0.1", features = ["sync"] } tokio-stream = { version = "0.1", features = ["sync"] }
tower-http = { version = "0.5", features = ["cors", "limit"] } tower-http = { version = "0.5", features = ["cors", "limit"] }
tracing = "0.1" tracing = "0.1"

View file

@ -26,8 +26,16 @@ max_page_len = 1024
# Maximum request body length in bytes. # Maximum request body length in bytes.
max_request_len = 4096 max_request_len = 4096
# Maximum length of a single event queue.
event_queue_len = 1024
# The maximum timestamp tolerance in seconds for request validation. # The maximum timestamp tolerance in seconds for request validation.
timestamp_tolerance_secs = 90 timestamp_tolerance_secs = 90
# The max waiting time for the first authentication message for websocket.
ws_auth_timeout_sec = 15
# The max waiting time for outgoing message to be received for websocket.
ws_send_timeout_sec = 15
# Maximum number of pending events a single user can have.
# If events overflow the pending buffer, older events will be dropped and
# client will be notified.
ws_event_queue_len = 1024

View file

@ -4,6 +4,18 @@ info:
version: 0.0.1 version: 0.0.1
paths: paths:
/ws:
get:
summary: WebSocket endpoint.
description: |
Once connection, client must send a JSON text message of type
`WithSig<AuthPayload>` for authentication.
If server does not close it immediately, it means success.
Then server will send JSON text messages on events that user are
interested in (eg. chat from joined rooms).
The message has type `Outgoing` in `blahd/src/ws.rs`.
/room: /room:
get: get:
summary: Get room metadata summary: Get room metadata
@ -133,34 +145,6 @@ paths:
application/json: application/json:
$ref: '#/components/schemas/ApiError' $ref: '#/components/schemas/ApiError'
/room/{ruuid}/event:
get:
summary: Get an event stream for future new items.
description: |
This is a temporary interface, before a better notification system
(post notifications? websocket?) is implemented.
headers:
Authorization:
description: Proof of membership for private rooms.
required: false
schema:
$ret: WithSig<AuthPayload>
responses:
200:
content:
text/event-stream:
x-description: An event stream, each event is a JSON with type WithSig<ChatPayload>
400:
description: Body is invalid or fails the verification.
content:
application/json:
$ref: '#/components/schemas/ApiError'
404:
description: Room not found.
content:
application/json:
$ref: '#/components/schemas/ApiError'
/room/{ruuid}/admin: /room/{ruuid}/admin:
post: post:
summary: Room management summary: Room management

View file

@ -1,7 +1,8 @@
use std::path::PathBuf; use std::path::PathBuf;
use std::time::Duration;
use anyhow::{ensure, Result}; use anyhow::{ensure, Result};
use serde::Deserialize; use serde::{Deserialize, Deserializer};
use serde_inline_default::serde_inline_default; use serde_inline_default::serde_inline_default;
#[derive(Debug, Clone, Deserialize)] #[derive(Debug, Clone, Deserialize)]
@ -30,11 +31,22 @@ pub struct ServerConfig {
pub max_page_len: usize, pub max_page_len: usize,
#[serde_inline_default(4096)] // 4KiB #[serde_inline_default(4096)] // 4KiB
pub max_request_len: usize, pub max_request_len: usize,
#[serde_inline_default(1024)]
pub event_queue_len: usize,
#[serde_inline_default(90)] #[serde_inline_default(90)]
pub timestamp_tolerance_secs: u64, pub timestamp_tolerance_secs: u64,
#[serde_inline_default(Duration::from_secs(15))]
#[serde(deserialize_with = "de_duration_sec")]
pub ws_auth_timeout_sec: Duration,
#[serde_inline_default(Duration::from_secs(15))]
#[serde(deserialize_with = "de_duration_sec")]
pub ws_send_timeout_sec: Duration,
#[serde_inline_default(1024)]
pub ws_event_queue_len: usize,
}
fn de_duration_sec<'de, D: Deserializer<'de>>(de: D) -> Result<Duration, D::Error> {
<u64>::deserialize(de).map(Duration::from_secs)
} }
impl Config { impl Config {

166
blahd/src/event.rs Normal file
View file

@ -0,0 +1,166 @@
use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::convert::Infallible;
use std::fmt;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll};
use anyhow::{bail, Context as _, Result};
use axum::extract::ws::{Message, WebSocket};
use blah::types::{AuthPayload, ChatItem, WithSig};
use futures_util::future::Either;
use futures_util::stream::SplitSink;
use futures_util::{stream_select, SinkExt as _, Stream, StreamExt};
use rusqlite::{params, OptionalExtension};
use serde::{Deserialize, Serialize};
use tokio::sync::broadcast;
use tokio_stream::wrappers::errors::BroadcastStreamRecvError;
use tokio_stream::wrappers::BroadcastStream;
use crate::AppState;
#[derive(Debug, Deserialize)]
pub enum Incoming {}
#[derive(Debug, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum Outgoing<'a> {
/// A chat message from a joined room.
Chat(&'a ChatItem),
/// The receiver is too slow to receive and some events and are dropped.
// FIXME: Should we indefinitely buffer them or just disconnect the client instead?
Lagged,
}
#[derive(Debug, Default)]
pub struct State {
pub user_listeners: Mutex<HashMap<u64, UserEventSender>>,
}
#[derive(Debug)]
pub struct StreamEnded;
impl fmt::Display for StreamEnded {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str("stream unexpectedly ended")
}
}
impl std::error::Error for StreamEnded {}
struct WsSenderWrapper<'ws, 'c> {
inner: SplitSink<&'ws mut WebSocket, Message>,
config: &'c crate::config::Config,
}
impl WsSenderWrapper<'_, '_> {
async fn send(&mut self, msg: &Outgoing<'_>) -> Result<()> {
let data = serde_json::to_string(&msg).expect("serialization cannot fail");
let fut = tokio::time::timeout(
self.config.server.ws_send_timeout_sec,
self.inner.send(Message::Text(data)),
);
match fut.await {
Ok(Ok(())) => Ok(()),
Ok(Err(_send_err)) => Err(StreamEnded.into()),
Err(_elapsed) => bail!("send timeout"),
}
}
}
type UserEventSender = broadcast::Sender<Arc<ChatItem>>;
#[derive(Debug)]
struct UserEventReceiver {
rx: BroadcastStream<Arc<ChatItem>>,
st: Arc<AppState>,
uid: u64,
}
impl Drop for UserEventReceiver {
fn drop(&mut self) {
tracing::debug!(%self.uid, "user disconnected");
if let Ok(mut map) = self.st.event.user_listeners.lock() {
if let Some(tx) = map.get_mut(&self.uid) {
if tx.receiver_count() == 1 {
map.remove(&self.uid);
}
}
}
}
}
impl Stream for UserEventReceiver {
type Item = Result<Arc<ChatItem>, BroadcastStreamRecvError>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
self.rx.poll_next_unpin(cx)
}
}
pub async fn handle_ws(st: Arc<AppState>, ws: &mut WebSocket) -> Result<Infallible> {
let (ws_tx, ws_rx) = ws.split();
let mut ws_rx = ws_rx.map(|ret| ret.and_then(|msg| msg.into_text()).map_err(|_| StreamEnded));
let mut ws_tx = WsSenderWrapper {
inner: ws_tx,
config: &st.config,
};
let uid = {
let payload = tokio::time::timeout(st.config.server.ws_auth_timeout_sec, ws_rx.next())
.await
.context("authentication timeout")?
.ok_or(StreamEnded)??;
let auth = serde_json::from_str::<WithSig<AuthPayload>>(&payload)?;
st.verify_signed_data(&auth)?;
st.conn
.lock()
.unwrap()
.query_row(
r"
SELECT `uid`
FROM `user`
WHERE `userkey` = ?
",
params![auth.signee.user],
|row| row.get::<_, u64>(0),
)
.optional()?
.context("invalid user")?
};
tracing::debug!(%uid, "user connected");
let event_rx = {
let rx = match st.event.user_listeners.lock().unwrap().entry(uid) {
Entry::Occupied(ent) => ent.get().subscribe(),
Entry::Vacant(ent) => {
let (tx, rx) = broadcast::channel(st.config.server.ws_event_queue_len);
ent.insert(tx);
rx
}
};
UserEventReceiver {
rx: rx.into(),
st: st.clone(),
uid,
}
};
let mut stream = stream_select!(ws_rx.map(Either::Left), event_rx.map(Either::Right));
loop {
match stream.next().await.ok_or(StreamEnded)? {
Either::Left(msg) => match serde_json::from_str::<Incoming>(&msg?)? {},
Either::Right(ret) => {
let msg = match &ret {
Ok(chat) => Outgoing::Chat(chat),
Err(BroadcastStreamRecvError::Lagged(_)) => Outgoing::Lagged,
};
// TODO: Concurrent send.
ws_tx.send(&msg).await?;
}
}
}
}

View file

@ -1,14 +1,12 @@
use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::convert::Infallible;
use std::path::PathBuf; use std::path::PathBuf;
use std::sync::{Arc, Mutex}; use std::sync::{Arc, Mutex};
use std::time::{Duration, SystemTime}; use std::time::{Duration, SystemTime};
use anyhow::{Context, Result}; use anyhow::{Context, Result};
use axum::extract::{Path, Query, State}; use axum::extract::ws;
use axum::extract::{Path, Query, State, WebSocketUpgrade};
use axum::http::{header, StatusCode}; use axum::http::{header, StatusCode};
use axum::response::{sse, IntoResponse}; use axum::response::{IntoResponse, Response};
use axum::routing::{get, post}; use axum::routing::{get, post};
use axum::{Json, Router}; use axum::{Json, Router};
use axum_extra::extract::WithRejection; use axum_extra::extract::WithRejection;
@ -21,14 +19,13 @@ use ed25519_dalek::SIGNATURE_LENGTH;
use middleware::{ApiError, OptionalAuth, SignedJson}; use middleware::{ApiError, OptionalAuth, SignedJson};
use rusqlite::{named_params, params, OptionalExtension, Row}; use rusqlite::{named_params, params, OptionalExtension, Row};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use tokio::sync::broadcast;
use tokio_stream::StreamExt;
use utils::ExpiringSet; use utils::ExpiringSet;
use uuid::Uuid; use uuid::Uuid;
#[macro_use] #[macro_use]
mod middleware; mod middleware;
mod config; mod config;
mod event;
mod utils; mod utils;
/// Blah Chat Server /// Blah Chat Server
@ -84,8 +81,8 @@ fn main() -> Result<()> {
#[derive(Debug)] #[derive(Debug)]
struct AppState { struct AppState {
conn: Mutex<rusqlite::Connection>, conn: Mutex<rusqlite::Connection>,
room_listeners: Mutex<HashMap<u64, broadcast::Sender<Arc<ChatItem>>>>,
used_nonces: Mutex<ExpiringSet<u32>>, used_nonces: Mutex<ExpiringSet<u32>>,
event: event::State,
config: Config, config: Config,
} }
@ -101,10 +98,10 @@ impl AppState {
.context("failed to initialize database")?; .context("failed to initialize database")?;
Ok(Self { Ok(Self {
conn: Mutex::new(conn), conn: Mutex::new(conn),
room_listeners: Mutex::new(HashMap::new()),
used_nonces: Mutex::new(ExpiringSet::new(Duration::from_secs( used_nonces: Mutex::new(ExpiringSet::new(Duration::from_secs(
config.server.timestamp_tolerance_secs, config.server.timestamp_tolerance_secs,
))), ))),
event: event::State::default(),
config, config,
}) })
@ -152,11 +149,11 @@ async fn main_async(st: AppState) -> Result<()> {
let st = Arc::new(st); let st = Arc::new(st);
let app = Router::new() let app = Router::new()
.route("/ws", get(handle_ws))
.route("/room/create", post(room_create)) .route("/room/create", post(room_create))
.route("/room/:ruuid", get(room_get_metadata)) .route("/room/:ruuid", get(room_get_metadata))
// NB. Sync with `feed_url` and `next_url` generation. // NB. Sync with `feed_url` and `next_url` generation.
.route("/room/:ruuid/feed.json", get(room_get_feed)) .route("/room/:ruuid/feed.json", get(room_get_feed))
.route("/room/:ruuid/event", get(room_event))
.route("/room/:ruuid/item", get(room_get_item).post(room_post_item)) .route("/room/:ruuid/item", get(room_get_item).post(room_post_item))
.route("/room/:ruuid/admin", post(room_admin)) .route("/room/:ruuid/admin", post(room_admin))
.with_state(st.clone()) .with_state(st.clone())
@ -179,6 +176,24 @@ async fn main_async(st: AppState) -> Result<()> {
Ok(()) Ok(())
} }
async fn handle_ws(State(st): ArcState, ws: WebSocketUpgrade) -> Response {
ws.on_upgrade(move |mut socket| async move {
match event::handle_ws(st, &mut socket).await {
Ok(never) => match never {},
Err(err) if err.is::<event::StreamEnded>() => {}
Err(err) => {
tracing::debug!(%err, "ws error");
let _: Result<_, _> = socket
.send(ws::Message::Close(Some(ws::CloseFrame {
code: ws::close_code::ERROR,
reason: err.to_string().into(),
})))
.await;
}
}
})
}
async fn room_create( async fn room_create(
st: ArcState, st: ArcState,
SignedJson(params): SignedJson<CreateRoomPayload>, SignedJson(params): SignedJson<CreateRoomPayload>,
@ -505,7 +520,7 @@ async fn room_post_item(
)); ));
} }
let (rid, cid) = { let (cid, txs) = {
let conn = st.conn.lock().unwrap(); let conn = st.conn.lock().unwrap();
let Some((rid, uid)) = conn let Some((rid, uid)) = conn
.query_row( .query_row(
@ -550,75 +565,40 @@ async fn room_post_item(
}, },
|row| row.get::<_, u64>(0), |row| row.get::<_, u64>(0),
)?; )?;
(rid, cid)
// FIXME: Optimize this to not traverses over all members.
let mut stmt = conn.prepare(
r"
SELECT `uid`
FROM `room_member`
WHERE `rid` = :rid
",
)?;
let listeners = st.event.user_listeners.lock().unwrap();
let txs = stmt
.query_map(params![rid], |row| row.get::<_, u64>(0))?
.filter_map(|ret| match ret {
Ok(uid) => listeners.get(&uid).map(|tx| Ok(tx.clone())),
Err(err) => Some(Err(err)),
})
.collect::<Result<Vec<_>, _>>()?;
(cid, txs)
}; };
let mut listeners = st.room_listeners.lock().unwrap(); if !txs.is_empty() {
if let Some(tx) = listeners.get(&rid) { tracing::debug!("broadcasting event to {} clients", txs.len());
if tx.send(Arc::new(chat)).is_err() { let chat = Arc::new(chat);
// Clean up because all receivers died. for tx in txs {
listeners.remove(&rid); if let Err(err) = tx.send(chat.clone()) {
tracing::debug!(%err, "failed to broadcast event");
}
} }
} }
Ok(Json(cid)) Ok(Json(cid))
} }
async fn room_event(
st: ArcState,
Path(ruuid): Path<Uuid>,
// TODO: There is actually no way to add headers via `EventSource` in client side.
// But this API is kinda temporary and need a better replacement anyway.
// So just only support public room for now.
OptionalAuth(user): OptionalAuth,
) -> Result<impl IntoResponse, ApiError> {
let rid = get_room_if_readable(&st.conn.lock().unwrap(), ruuid, user.as_ref(), |row| {
row.get::<_, u64>(0)
})?;
let rx = match st.room_listeners.lock().unwrap().entry(rid) {
Entry::Occupied(ent) => ent.get().subscribe(),
Entry::Vacant(ent) => {
let (tx, rx) = broadcast::channel(st.config.server.event_queue_len);
ent.insert(tx);
rx
}
};
// Do clean up when this stream is closed.
struct CleanOnDrop {
st: Arc<AppState>,
rid: u64,
}
impl Drop for CleanOnDrop {
fn drop(&mut self) {
if let Ok(mut listeners) = self.st.room_listeners.lock() {
if let Some(tx) = listeners.get(&self.rid) {
if tx.receiver_count() == 0 {
listeners.remove(&self.rid);
}
}
}
}
}
let _guard = CleanOnDrop { st: st.0, rid };
let stream = tokio_stream::wrappers::BroadcastStream::new(rx).filter_map(move |ret| {
let _guard = &_guard;
// On stream closure or lagging, close the current stream so client can retry.
let item = ret.ok()?;
let evt = sse::Event::default()
.json_data(&*item)
.expect("serialization cannot fail");
Some(Ok::<_, Infallible>(evt))
});
// NB. Send an empty event immediately to trigger client ready event.
let first_event = sse::Event::default().comment("");
let stream = futures_util::stream::iter(Some(Ok(first_event))).chain(stream);
Ok(sse::Sse::new(stream).keep_alive(sse::KeepAlive::default()))
}
async fn room_admin( async fn room_admin(
st: ArcState, st: ArcState,
Path(ruuid): Path<Uuid>, Path(ruuid): Path<Uuid>,

View file

@ -1,3 +1,4 @@
use std::fmt;
use std::sync::Arc; use std::sync::Arc;
use axum::extract::rejection::{JsonRejection, PathRejection, QueryRejection}; use axum::extract::rejection::{JsonRejection, PathRejection, QueryRejection};
@ -22,6 +23,14 @@ pub struct ApiError {
pub message: String, pub message: String,
} }
impl fmt::Display for ApiError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str(&self.message)
}
}
impl std::error::Error for ApiError {}
macro_rules! error_response { macro_rules! error_response {
($status:expr, $code:literal, $msg:literal $(, $msg_args:expr)* $(,)?) => { ($status:expr, $code:literal, $msg:literal $(, $msg_args:expr)* $(,)?) => {
$crate::middleware::ApiError { $crate::middleware::ApiError {

View file

@ -7,7 +7,7 @@ const joinRoomBtn = document.querySelector('#join-room');
let roomUrl = ''; let roomUrl = '';
let roomUuid = null; let roomUuid = null;
let feed = null; let ws = null;
let keypair = null; let keypair = null;
let defaultConfig = {}; let defaultConfig = {};
@ -166,16 +166,13 @@ function escapeHtml(text) {
} }
async function connectRoom(url) { async function connectRoom(url) {
if (url === '' || url == roomUrl) return; if (url === '' || url == roomUrl || keypair === null) return;
const match = url.match(/^https?:\/\/.*\/([a-z0-9]{8}-[a-z0-9]{4}-[a-z0-9]{4}-[a-z0-9]{4}-[a-z0-9]{12})\/?/); const match = url.match(/^https?:\/\/.*\/([a-z0-9]{8}-[a-z0-9]{4}-[a-z0-9]{4}-[a-z0-9]{4}-[a-z0-9]{12})\/?/);
if (match === null) { if (match === null) {
log('invalid room url'); log('invalid room url');
return; return;
} }
if (feed !== null) {
feed.close();
}
roomUrl = url; roomUrl = url;
roomUuid = match[1]; roomUuid = match[1];
@ -210,18 +207,40 @@ async function connectRoom(url) {
// TODO: There is a time window where events would be lost. // TODO: There is a time window where events would be lost.
feed = new EventSource(`${url}/event`); await connectWs();
feed.onopen = (_) => { }
async function connectWs() {
if (ws !== null) {
ws.close();
}
const wsUrl = new URL(roomUrl);
wsUrl.protocol = wsUrl.protocol == 'http:' ? 'ws:' : 'wss:';
wsUrl.pathname = '/ws';
ws = new WebSocket(wsUrl);
ws.onopen = async (_) => {
const auth = await signData({ typ: 'auth' });
await ws.send(auth);
log('listening on events'); log('listening on events');
} }
feed.onerror = (e) => { ws.onclose = (e) => {
console.error(e); console.error(e);
log('event listener error'); log(`ws closed (code=${e.code}): ${e.reason}`);
}; };
feed.onmessage = async (e) => { ws.onerror = (e) => {
console.log('feed event', e.data); console.error(e);
const chat = JSON.parse(e.data); log(`ws error: ${e.error}`);
showChatMsg(chat); };
ws.onmessage = async (e) => {
console.log('ws event', e.data);
const msg = JSON.parse(e.data);
if (msg.chat !== undefined) {
showChatMsg(msg.chat);
} else if (msg.lagged !== undefined) {
log('some events are dropped because of queue overflow')
} else {
log(`unknown ws message: ${e.data}`);
}
}; };
} }
@ -236,6 +255,7 @@ async function joinRoom() {
user: await getUserPubkey(), user: await getUserPubkey(),
}); });
log('joined room'); log('joined room');
await connectWs();
} catch (e) { } catch (e) {
console.error(e); console.error(e);
log(`failed to join room: ${e}`); log(`failed to join room: ${e}`);