mirror of
https://github.com/Blah-IM/blahrs.git
synced 2025-05-01 00:31:09 +00:00
Switch from event stream to WebSocket for events
This commit is contained in:
parent
5fadffef4d
commit
77216aa0f8
9 changed files with 386 additions and 123 deletions
88
Cargo.lock
generated
88
Cargo.lock
generated
|
@ -130,6 +130,7 @@ checksum = "3a6c9af12842a67734c9a2e355436e5d03b22383ed60cf13cd0c18fbfe3dcbcf"
|
|||
dependencies = [
|
||||
"async-trait",
|
||||
"axum-core",
|
||||
"base64 0.21.7",
|
||||
"bytes",
|
||||
"futures-util",
|
||||
"http",
|
||||
|
@ -148,8 +149,10 @@ dependencies = [
|
|||
"serde_json",
|
||||
"serde_path_to_error",
|
||||
"serde_urlencoded",
|
||||
"sha1",
|
||||
"sync_wrapper 1.0.1",
|
||||
"tokio",
|
||||
"tokio-tungstenite",
|
||||
"tower",
|
||||
"tower-layer",
|
||||
"tower-service",
|
||||
|
@ -214,6 +217,12 @@ dependencies = [
|
|||
"rustc-demangle",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "base64"
|
||||
version = "0.21.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567"
|
||||
|
||||
[[package]]
|
||||
name = "base64"
|
||||
version = "0.22.1"
|
||||
|
@ -479,6 +488,12 @@ dependencies = [
|
|||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "data-encoding"
|
||||
version = "2.6.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e8566979429cf69b49a5c740c60791108e86440e8be149bbea4fe54d2c32d6e2"
|
||||
|
||||
[[package]]
|
||||
name = "der"
|
||||
version = "0.7.9"
|
||||
|
@ -650,6 +665,7 @@ checksum = "3d6401deb83407ab3da39eba7e33987a73c3df0c82b4bb5813ee871c19c41d48"
|
|||
dependencies = [
|
||||
"futures-core",
|
||||
"futures-macro",
|
||||
"futures-sink",
|
||||
"futures-task",
|
||||
"pin-project-lite",
|
||||
"pin-utils",
|
||||
|
@ -1246,7 +1262,7 @@ version = "0.12.7"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f8f4955649ef5c38cc7f9e8aa41761d48fb9677197daea9984dc54f56aad5e63"
|
||||
dependencies = [
|
||||
"base64",
|
||||
"base64 0.22.1",
|
||||
"bytes",
|
||||
"encoding_rs",
|
||||
"futures-core",
|
||||
|
@ -1360,7 +1376,7 @@ version = "2.1.3"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "196fe16b00e106300d3e45ecfcb764fa292a535d7326a29a5875c579c7417425"
|
||||
dependencies = [
|
||||
"base64",
|
||||
"base64 0.22.1",
|
||||
"rustls-pki-types",
|
||||
]
|
||||
|
||||
|
@ -1530,6 +1546,17 @@ dependencies = [
|
|||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sha1"
|
||||
version = "0.10.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e3bf829a2d51ab4a5ddf1352d8470c140cadc8301b2ae1789db023f01cedd6ba"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"cpufeatures",
|
||||
"digest",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sha2"
|
||||
version = "0.10.8"
|
||||
|
@ -1678,6 +1705,26 @@ dependencies = [
|
|||
"windows-sys 0.59.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "thiserror"
|
||||
version = "1.0.63"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c0342370b38b6a11b6cc11d6a805569958d54cfa061a29969c3b5ce2ea405724"
|
||||
dependencies = [
|
||||
"thiserror-impl",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "thiserror-impl"
|
||||
version = "1.0.63"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a4558b58466b9ad7ca0f102865eccc95938dca1a74a856f2b57b6629050da261"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "thread_local"
|
||||
version = "1.1.8"
|
||||
|
@ -1763,6 +1810,18 @@ dependencies = [
|
|||
"tokio-util",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokio-tungstenite"
|
||||
version = "0.21.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c83b561d025642014097b66e6c1bb422783339e0909e4429cde4749d1990bc38"
|
||||
dependencies = [
|
||||
"futures-util",
|
||||
"log",
|
||||
"tokio",
|
||||
"tungstenite",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokio-util"
|
||||
version = "0.7.11"
|
||||
|
@ -1884,6 +1943,25 @@ version = "0.2.5"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b"
|
||||
|
||||
[[package]]
|
||||
name = "tungstenite"
|
||||
version = "0.21.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9ef1a641ea34f399a848dea702823bbecfb4c486f911735368f1f137cb8257e1"
|
||||
dependencies = [
|
||||
"byteorder",
|
||||
"bytes",
|
||||
"data-encoding",
|
||||
"http",
|
||||
"httparse",
|
||||
"log",
|
||||
"rand",
|
||||
"sha1",
|
||||
"thiserror",
|
||||
"url",
|
||||
"utf-8",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "typenum"
|
||||
version = "1.17.0"
|
||||
|
@ -1928,6 +2006,12 @@ dependencies = [
|
|||
"percent-encoding",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "utf-8"
|
||||
version = "0.7.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9"
|
||||
|
||||
[[package]]
|
||||
name = "utf8-width"
|
||||
version = "0.1.7"
|
||||
|
|
|
@ -5,7 +5,7 @@ edition = "2021"
|
|||
|
||||
[dependencies]
|
||||
anyhow = "1"
|
||||
axum = { version = "0.7", features = ["tokio"] }
|
||||
axum = { version = "0.7", features = ["ws"] }
|
||||
axum-extra = "0.9"
|
||||
clap = { version = "4", features = ["derive"] }
|
||||
ed25519-dalek = "2"
|
||||
|
@ -17,7 +17,7 @@ sd-notify = "0.4"
|
|||
serde = { version = "1", features = ["derive"] }
|
||||
serde-aux = "4"
|
||||
serde_json = "1"
|
||||
tokio = { version = "1", features = ["macros", "rt-multi-thread", "sync"] }
|
||||
tokio = { version = "1", features = ["macros", "rt-multi-thread", "sync", "time"] }
|
||||
tokio-stream = { version = "0.1", features = ["sync"] }
|
||||
tower-http = { version = "0.5", features = ["cors", "limit"] }
|
||||
tracing = "0.1"
|
||||
|
|
|
@ -26,8 +26,16 @@ max_page_len = 1024
|
|||
# Maximum request body length in bytes.
|
||||
max_request_len = 4096
|
||||
|
||||
# Maximum length of a single event queue.
|
||||
event_queue_len = 1024
|
||||
|
||||
# The maximum timestamp tolerance in seconds for request validation.
|
||||
timestamp_tolerance_secs = 90
|
||||
|
||||
# The max waiting time for the first authentication message for websocket.
|
||||
ws_auth_timeout_sec = 15
|
||||
|
||||
# The max waiting time for outgoing message to be received for websocket.
|
||||
ws_send_timeout_sec = 15
|
||||
|
||||
# Maximum number of pending events a single user can have.
|
||||
# If events overflow the pending buffer, older events will be dropped and
|
||||
# client will be notified.
|
||||
ws_event_queue_len = 1024
|
||||
|
|
|
@ -4,6 +4,18 @@ info:
|
|||
version: 0.0.1
|
||||
|
||||
paths:
|
||||
/ws:
|
||||
get:
|
||||
summary: WebSocket endpoint.
|
||||
description: |
|
||||
Once connection, client must send a JSON text message of type
|
||||
`WithSig<AuthPayload>` for authentication.
|
||||
If server does not close it immediately, it means success.
|
||||
|
||||
Then server will send JSON text messages on events that user are
|
||||
interested in (eg. chat from joined rooms).
|
||||
The message has type `Outgoing` in `blahd/src/ws.rs`.
|
||||
|
||||
/room:
|
||||
get:
|
||||
summary: Get room metadata
|
||||
|
@ -133,34 +145,6 @@ paths:
|
|||
application/json:
|
||||
$ref: '#/components/schemas/ApiError'
|
||||
|
||||
/room/{ruuid}/event:
|
||||
get:
|
||||
summary: Get an event stream for future new items.
|
||||
description: |
|
||||
This is a temporary interface, before a better notification system
|
||||
(post notifications? websocket?) is implemented.
|
||||
headers:
|
||||
Authorization:
|
||||
description: Proof of membership for private rooms.
|
||||
required: false
|
||||
schema:
|
||||
$ret: WithSig<AuthPayload>
|
||||
responses:
|
||||
200:
|
||||
content:
|
||||
text/event-stream:
|
||||
x-description: An event stream, each event is a JSON with type WithSig<ChatPayload>
|
||||
400:
|
||||
description: Body is invalid or fails the verification.
|
||||
content:
|
||||
application/json:
|
||||
$ref: '#/components/schemas/ApiError'
|
||||
404:
|
||||
description: Room not found.
|
||||
content:
|
||||
application/json:
|
||||
$ref: '#/components/schemas/ApiError'
|
||||
|
||||
/room/{ruuid}/admin:
|
||||
post:
|
||||
summary: Room management
|
||||
|
|
|
@ -1,7 +1,8 @@
|
|||
use std::path::PathBuf;
|
||||
use std::time::Duration;
|
||||
|
||||
use anyhow::{ensure, Result};
|
||||
use serde::Deserialize;
|
||||
use serde::{Deserialize, Deserializer};
|
||||
use serde_inline_default::serde_inline_default;
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
|
@ -30,11 +31,22 @@ pub struct ServerConfig {
|
|||
pub max_page_len: usize,
|
||||
#[serde_inline_default(4096)] // 4KiB
|
||||
pub max_request_len: usize,
|
||||
#[serde_inline_default(1024)]
|
||||
pub event_queue_len: usize,
|
||||
|
||||
#[serde_inline_default(90)]
|
||||
pub timestamp_tolerance_secs: u64,
|
||||
|
||||
#[serde_inline_default(Duration::from_secs(15))]
|
||||
#[serde(deserialize_with = "de_duration_sec")]
|
||||
pub ws_auth_timeout_sec: Duration,
|
||||
#[serde_inline_default(Duration::from_secs(15))]
|
||||
#[serde(deserialize_with = "de_duration_sec")]
|
||||
pub ws_send_timeout_sec: Duration,
|
||||
#[serde_inline_default(1024)]
|
||||
pub ws_event_queue_len: usize,
|
||||
}
|
||||
|
||||
fn de_duration_sec<'de, D: Deserializer<'de>>(de: D) -> Result<Duration, D::Error> {
|
||||
<u64>::deserialize(de).map(Duration::from_secs)
|
||||
}
|
||||
|
||||
impl Config {
|
||||
|
|
166
blahd/src/event.rs
Normal file
166
blahd/src/event.rs
Normal file
|
@ -0,0 +1,166 @@
|
|||
use std::collections::hash_map::Entry;
|
||||
use std::collections::HashMap;
|
||||
use std::convert::Infallible;
|
||||
use std::fmt;
|
||||
use std::pin::Pin;
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::task::{Context, Poll};
|
||||
|
||||
use anyhow::{bail, Context as _, Result};
|
||||
use axum::extract::ws::{Message, WebSocket};
|
||||
use blah::types::{AuthPayload, ChatItem, WithSig};
|
||||
use futures_util::future::Either;
|
||||
use futures_util::stream::SplitSink;
|
||||
use futures_util::{stream_select, SinkExt as _, Stream, StreamExt};
|
||||
use rusqlite::{params, OptionalExtension};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tokio::sync::broadcast;
|
||||
use tokio_stream::wrappers::errors::BroadcastStreamRecvError;
|
||||
use tokio_stream::wrappers::BroadcastStream;
|
||||
|
||||
use crate::AppState;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub enum Incoming {}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum Outgoing<'a> {
|
||||
/// A chat message from a joined room.
|
||||
Chat(&'a ChatItem),
|
||||
/// The receiver is too slow to receive and some events and are dropped.
|
||||
// FIXME: Should we indefinitely buffer them or just disconnect the client instead?
|
||||
Lagged,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
pub struct State {
|
||||
pub user_listeners: Mutex<HashMap<u64, UserEventSender>>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct StreamEnded;
|
||||
|
||||
impl fmt::Display for StreamEnded {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.write_str("stream unexpectedly ended")
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for StreamEnded {}
|
||||
|
||||
struct WsSenderWrapper<'ws, 'c> {
|
||||
inner: SplitSink<&'ws mut WebSocket, Message>,
|
||||
config: &'c crate::config::Config,
|
||||
}
|
||||
|
||||
impl WsSenderWrapper<'_, '_> {
|
||||
async fn send(&mut self, msg: &Outgoing<'_>) -> Result<()> {
|
||||
let data = serde_json::to_string(&msg).expect("serialization cannot fail");
|
||||
let fut = tokio::time::timeout(
|
||||
self.config.server.ws_send_timeout_sec,
|
||||
self.inner.send(Message::Text(data)),
|
||||
);
|
||||
match fut.await {
|
||||
Ok(Ok(())) => Ok(()),
|
||||
Ok(Err(_send_err)) => Err(StreamEnded.into()),
|
||||
Err(_elapsed) => bail!("send timeout"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type UserEventSender = broadcast::Sender<Arc<ChatItem>>;
|
||||
|
||||
#[derive(Debug)]
|
||||
struct UserEventReceiver {
|
||||
rx: BroadcastStream<Arc<ChatItem>>,
|
||||
st: Arc<AppState>,
|
||||
uid: u64,
|
||||
}
|
||||
|
||||
impl Drop for UserEventReceiver {
|
||||
fn drop(&mut self) {
|
||||
tracing::debug!(%self.uid, "user disconnected");
|
||||
if let Ok(mut map) = self.st.event.user_listeners.lock() {
|
||||
if let Some(tx) = map.get_mut(&self.uid) {
|
||||
if tx.receiver_count() == 1 {
|
||||
map.remove(&self.uid);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Stream for UserEventReceiver {
|
||||
type Item = Result<Arc<ChatItem>, BroadcastStreamRecvError>;
|
||||
|
||||
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||
self.rx.poll_next_unpin(cx)
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn handle_ws(st: Arc<AppState>, ws: &mut WebSocket) -> Result<Infallible> {
|
||||
let (ws_tx, ws_rx) = ws.split();
|
||||
let mut ws_rx = ws_rx.map(|ret| ret.and_then(|msg| msg.into_text()).map_err(|_| StreamEnded));
|
||||
let mut ws_tx = WsSenderWrapper {
|
||||
inner: ws_tx,
|
||||
config: &st.config,
|
||||
};
|
||||
|
||||
let uid = {
|
||||
let payload = tokio::time::timeout(st.config.server.ws_auth_timeout_sec, ws_rx.next())
|
||||
.await
|
||||
.context("authentication timeout")?
|
||||
.ok_or(StreamEnded)??;
|
||||
let auth = serde_json::from_str::<WithSig<AuthPayload>>(&payload)?;
|
||||
st.verify_signed_data(&auth)?;
|
||||
|
||||
st.conn
|
||||
.lock()
|
||||
.unwrap()
|
||||
.query_row(
|
||||
r"
|
||||
SELECT `uid`
|
||||
FROM `user`
|
||||
WHERE `userkey` = ?
|
||||
",
|
||||
params![auth.signee.user],
|
||||
|row| row.get::<_, u64>(0),
|
||||
)
|
||||
.optional()?
|
||||
.context("invalid user")?
|
||||
};
|
||||
|
||||
tracing::debug!(%uid, "user connected");
|
||||
|
||||
let event_rx = {
|
||||
let rx = match st.event.user_listeners.lock().unwrap().entry(uid) {
|
||||
Entry::Occupied(ent) => ent.get().subscribe(),
|
||||
Entry::Vacant(ent) => {
|
||||
let (tx, rx) = broadcast::channel(st.config.server.ws_event_queue_len);
|
||||
ent.insert(tx);
|
||||
rx
|
||||
}
|
||||
};
|
||||
UserEventReceiver {
|
||||
rx: rx.into(),
|
||||
st: st.clone(),
|
||||
uid,
|
||||
}
|
||||
};
|
||||
|
||||
let mut stream = stream_select!(ws_rx.map(Either::Left), event_rx.map(Either::Right));
|
||||
loop {
|
||||
match stream.next().await.ok_or(StreamEnded)? {
|
||||
Either::Left(msg) => match serde_json::from_str::<Incoming>(&msg?)? {},
|
||||
Either::Right(ret) => {
|
||||
let msg = match &ret {
|
||||
Ok(chat) => Outgoing::Chat(chat),
|
||||
Err(BroadcastStreamRecvError::Lagged(_)) => Outgoing::Lagged,
|
||||
};
|
||||
// TODO: Concurrent send.
|
||||
ws_tx.send(&msg).await?;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,14 +1,12 @@
|
|||
use std::collections::hash_map::Entry;
|
||||
use std::collections::HashMap;
|
||||
use std::convert::Infallible;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::time::{Duration, SystemTime};
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use axum::extract::{Path, Query, State};
|
||||
use axum::extract::ws;
|
||||
use axum::extract::{Path, Query, State, WebSocketUpgrade};
|
||||
use axum::http::{header, StatusCode};
|
||||
use axum::response::{sse, IntoResponse};
|
||||
use axum::response::{IntoResponse, Response};
|
||||
use axum::routing::{get, post};
|
||||
use axum::{Json, Router};
|
||||
use axum_extra::extract::WithRejection;
|
||||
|
@ -21,14 +19,13 @@ use ed25519_dalek::SIGNATURE_LENGTH;
|
|||
use middleware::{ApiError, OptionalAuth, SignedJson};
|
||||
use rusqlite::{named_params, params, OptionalExtension, Row};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tokio::sync::broadcast;
|
||||
use tokio_stream::StreamExt;
|
||||
use utils::ExpiringSet;
|
||||
use uuid::Uuid;
|
||||
|
||||
#[macro_use]
|
||||
mod middleware;
|
||||
mod config;
|
||||
mod event;
|
||||
mod utils;
|
||||
|
||||
/// Blah Chat Server
|
||||
|
@ -84,8 +81,8 @@ fn main() -> Result<()> {
|
|||
#[derive(Debug)]
|
||||
struct AppState {
|
||||
conn: Mutex<rusqlite::Connection>,
|
||||
room_listeners: Mutex<HashMap<u64, broadcast::Sender<Arc<ChatItem>>>>,
|
||||
used_nonces: Mutex<ExpiringSet<u32>>,
|
||||
event: event::State,
|
||||
|
||||
config: Config,
|
||||
}
|
||||
|
@ -101,10 +98,10 @@ impl AppState {
|
|||
.context("failed to initialize database")?;
|
||||
Ok(Self {
|
||||
conn: Mutex::new(conn),
|
||||
room_listeners: Mutex::new(HashMap::new()),
|
||||
used_nonces: Mutex::new(ExpiringSet::new(Duration::from_secs(
|
||||
config.server.timestamp_tolerance_secs,
|
||||
))),
|
||||
event: event::State::default(),
|
||||
|
||||
config,
|
||||
})
|
||||
|
@ -152,11 +149,11 @@ async fn main_async(st: AppState) -> Result<()> {
|
|||
let st = Arc::new(st);
|
||||
|
||||
let app = Router::new()
|
||||
.route("/ws", get(handle_ws))
|
||||
.route("/room/create", post(room_create))
|
||||
.route("/room/:ruuid", get(room_get_metadata))
|
||||
// NB. Sync with `feed_url` and `next_url` generation.
|
||||
.route("/room/:ruuid/feed.json", get(room_get_feed))
|
||||
.route("/room/:ruuid/event", get(room_event))
|
||||
.route("/room/:ruuid/item", get(room_get_item).post(room_post_item))
|
||||
.route("/room/:ruuid/admin", post(room_admin))
|
||||
.with_state(st.clone())
|
||||
|
@ -179,6 +176,24 @@ async fn main_async(st: AppState) -> Result<()> {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
async fn handle_ws(State(st): ArcState, ws: WebSocketUpgrade) -> Response {
|
||||
ws.on_upgrade(move |mut socket| async move {
|
||||
match event::handle_ws(st, &mut socket).await {
|
||||
Ok(never) => match never {},
|
||||
Err(err) if err.is::<event::StreamEnded>() => {}
|
||||
Err(err) => {
|
||||
tracing::debug!(%err, "ws error");
|
||||
let _: Result<_, _> = socket
|
||||
.send(ws::Message::Close(Some(ws::CloseFrame {
|
||||
code: ws::close_code::ERROR,
|
||||
reason: err.to_string().into(),
|
||||
})))
|
||||
.await;
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
async fn room_create(
|
||||
st: ArcState,
|
||||
SignedJson(params): SignedJson<CreateRoomPayload>,
|
||||
|
@ -505,7 +520,7 @@ async fn room_post_item(
|
|||
));
|
||||
}
|
||||
|
||||
let (rid, cid) = {
|
||||
let (cid, txs) = {
|
||||
let conn = st.conn.lock().unwrap();
|
||||
let Some((rid, uid)) = conn
|
||||
.query_row(
|
||||
|
@ -550,75 +565,40 @@ async fn room_post_item(
|
|||
},
|
||||
|row| row.get::<_, u64>(0),
|
||||
)?;
|
||||
(rid, cid)
|
||||
|
||||
// FIXME: Optimize this to not traverses over all members.
|
||||
let mut stmt = conn.prepare(
|
||||
r"
|
||||
SELECT `uid`
|
||||
FROM `room_member`
|
||||
WHERE `rid` = :rid
|
||||
",
|
||||
)?;
|
||||
let listeners = st.event.user_listeners.lock().unwrap();
|
||||
let txs = stmt
|
||||
.query_map(params![rid], |row| row.get::<_, u64>(0))?
|
||||
.filter_map(|ret| match ret {
|
||||
Ok(uid) => listeners.get(&uid).map(|tx| Ok(tx.clone())),
|
||||
Err(err) => Some(Err(err)),
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
|
||||
(cid, txs)
|
||||
};
|
||||
|
||||
let mut listeners = st.room_listeners.lock().unwrap();
|
||||
if let Some(tx) = listeners.get(&rid) {
|
||||
if tx.send(Arc::new(chat)).is_err() {
|
||||
// Clean up because all receivers died.
|
||||
listeners.remove(&rid);
|
||||
if !txs.is_empty() {
|
||||
tracing::debug!("broadcasting event to {} clients", txs.len());
|
||||
let chat = Arc::new(chat);
|
||||
for tx in txs {
|
||||
if let Err(err) = tx.send(chat.clone()) {
|
||||
tracing::debug!(%err, "failed to broadcast event");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Json(cid))
|
||||
}
|
||||
|
||||
async fn room_event(
|
||||
st: ArcState,
|
||||
Path(ruuid): Path<Uuid>,
|
||||
// TODO: There is actually no way to add headers via `EventSource` in client side.
|
||||
// But this API is kinda temporary and need a better replacement anyway.
|
||||
// So just only support public room for now.
|
||||
OptionalAuth(user): OptionalAuth,
|
||||
) -> Result<impl IntoResponse, ApiError> {
|
||||
let rid = get_room_if_readable(&st.conn.lock().unwrap(), ruuid, user.as_ref(), |row| {
|
||||
row.get::<_, u64>(0)
|
||||
})?;
|
||||
|
||||
let rx = match st.room_listeners.lock().unwrap().entry(rid) {
|
||||
Entry::Occupied(ent) => ent.get().subscribe(),
|
||||
Entry::Vacant(ent) => {
|
||||
let (tx, rx) = broadcast::channel(st.config.server.event_queue_len);
|
||||
ent.insert(tx);
|
||||
rx
|
||||
}
|
||||
};
|
||||
|
||||
// Do clean up when this stream is closed.
|
||||
struct CleanOnDrop {
|
||||
st: Arc<AppState>,
|
||||
rid: u64,
|
||||
}
|
||||
impl Drop for CleanOnDrop {
|
||||
fn drop(&mut self) {
|
||||
if let Ok(mut listeners) = self.st.room_listeners.lock() {
|
||||
if let Some(tx) = listeners.get(&self.rid) {
|
||||
if tx.receiver_count() == 0 {
|
||||
listeners.remove(&self.rid);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let _guard = CleanOnDrop { st: st.0, rid };
|
||||
|
||||
let stream = tokio_stream::wrappers::BroadcastStream::new(rx).filter_map(move |ret| {
|
||||
let _guard = &_guard;
|
||||
// On stream closure or lagging, close the current stream so client can retry.
|
||||
let item = ret.ok()?;
|
||||
let evt = sse::Event::default()
|
||||
.json_data(&*item)
|
||||
.expect("serialization cannot fail");
|
||||
Some(Ok::<_, Infallible>(evt))
|
||||
});
|
||||
// NB. Send an empty event immediately to trigger client ready event.
|
||||
let first_event = sse::Event::default().comment("");
|
||||
let stream = futures_util::stream::iter(Some(Ok(first_event))).chain(stream);
|
||||
Ok(sse::Sse::new(stream).keep_alive(sse::KeepAlive::default()))
|
||||
}
|
||||
|
||||
async fn room_admin(
|
||||
st: ArcState,
|
||||
Path(ruuid): Path<Uuid>,
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
use std::fmt;
|
||||
use std::sync::Arc;
|
||||
|
||||
use axum::extract::rejection::{JsonRejection, PathRejection, QueryRejection};
|
||||
|
@ -22,6 +23,14 @@ pub struct ApiError {
|
|||
pub message: String,
|
||||
}
|
||||
|
||||
impl fmt::Display for ApiError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.write_str(&self.message)
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for ApiError {}
|
||||
|
||||
macro_rules! error_response {
|
||||
($status:expr, $code:literal, $msg:literal $(, $msg_args:expr)* $(,)?) => {
|
||||
$crate::middleware::ApiError {
|
||||
|
|
|
@ -7,7 +7,7 @@ const joinRoomBtn = document.querySelector('#join-room');
|
|||
|
||||
let roomUrl = '';
|
||||
let roomUuid = null;
|
||||
let feed = null;
|
||||
let ws = null;
|
||||
let keypair = null;
|
||||
let defaultConfig = {};
|
||||
|
||||
|
@ -166,16 +166,13 @@ function escapeHtml(text) {
|
|||
}
|
||||
|
||||
async function connectRoom(url) {
|
||||
if (url === '' || url == roomUrl) return;
|
||||
if (url === '' || url == roomUrl || keypair === null) return;
|
||||
const match = url.match(/^https?:\/\/.*\/([a-z0-9]{8}-[a-z0-9]{4}-[a-z0-9]{4}-[a-z0-9]{4}-[a-z0-9]{12})\/?/);
|
||||
if (match === null) {
|
||||
log('invalid room url');
|
||||
return;
|
||||
}
|
||||
|
||||
if (feed !== null) {
|
||||
feed.close();
|
||||
}
|
||||
roomUrl = url;
|
||||
roomUuid = match[1];
|
||||
|
||||
|
@ -210,18 +207,40 @@ async function connectRoom(url) {
|
|||
|
||||
// TODO: There is a time window where events would be lost.
|
||||
|
||||
feed = new EventSource(`${url}/event`);
|
||||
feed.onopen = (_) => {
|
||||
await connectWs();
|
||||
}
|
||||
|
||||
async function connectWs() {
|
||||
if (ws !== null) {
|
||||
ws.close();
|
||||
}
|
||||
const wsUrl = new URL(roomUrl);
|
||||
wsUrl.protocol = wsUrl.protocol == 'http:' ? 'ws:' : 'wss:';
|
||||
wsUrl.pathname = '/ws';
|
||||
ws = new WebSocket(wsUrl);
|
||||
ws.onopen = async (_) => {
|
||||
const auth = await signData({ typ: 'auth' });
|
||||
await ws.send(auth);
|
||||
log('listening on events');
|
||||
}
|
||||
feed.onerror = (e) => {
|
||||
ws.onclose = (e) => {
|
||||
console.error(e);
|
||||
log('event listener error');
|
||||
log(`ws closed (code=${e.code}): ${e.reason}`);
|
||||
};
|
||||
feed.onmessage = async (e) => {
|
||||
console.log('feed event', e.data);
|
||||
const chat = JSON.parse(e.data);
|
||||
showChatMsg(chat);
|
||||
ws.onerror = (e) => {
|
||||
console.error(e);
|
||||
log(`ws error: ${e.error}`);
|
||||
};
|
||||
ws.onmessage = async (e) => {
|
||||
console.log('ws event', e.data);
|
||||
const msg = JSON.parse(e.data);
|
||||
if (msg.chat !== undefined) {
|
||||
showChatMsg(msg.chat);
|
||||
} else if (msg.lagged !== undefined) {
|
||||
log('some events are dropped because of queue overflow')
|
||||
} else {
|
||||
log(`unknown ws message: ${e.data}`);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
|
@ -236,6 +255,7 @@ async function joinRoom() {
|
|||
user: await getUserPubkey(),
|
||||
});
|
||||
log('joined room');
|
||||
await connectWs();
|
||||
} catch (e) {
|
||||
console.error(e);
|
||||
log(`failed to join room: ${e}`);
|
||||
|
|
Loading…
Add table
Reference in a new issue