mirror of
https://github.com/Blah-IM/blahrs.git
synced 2025-05-01 08:41:09 +00:00
Switch from event stream to WebSocket for events
This commit is contained in:
parent
5fadffef4d
commit
77216aa0f8
9 changed files with 386 additions and 123 deletions
88
Cargo.lock
generated
88
Cargo.lock
generated
|
@ -130,6 +130,7 @@ checksum = "3a6c9af12842a67734c9a2e355436e5d03b22383ed60cf13cd0c18fbfe3dcbcf"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"async-trait",
|
"async-trait",
|
||||||
"axum-core",
|
"axum-core",
|
||||||
|
"base64 0.21.7",
|
||||||
"bytes",
|
"bytes",
|
||||||
"futures-util",
|
"futures-util",
|
||||||
"http",
|
"http",
|
||||||
|
@ -148,8 +149,10 @@ dependencies = [
|
||||||
"serde_json",
|
"serde_json",
|
||||||
"serde_path_to_error",
|
"serde_path_to_error",
|
||||||
"serde_urlencoded",
|
"serde_urlencoded",
|
||||||
|
"sha1",
|
||||||
"sync_wrapper 1.0.1",
|
"sync_wrapper 1.0.1",
|
||||||
"tokio",
|
"tokio",
|
||||||
|
"tokio-tungstenite",
|
||||||
"tower",
|
"tower",
|
||||||
"tower-layer",
|
"tower-layer",
|
||||||
"tower-service",
|
"tower-service",
|
||||||
|
@ -214,6 +217,12 @@ dependencies = [
|
||||||
"rustc-demangle",
|
"rustc-demangle",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "base64"
|
||||||
|
version = "0.21.7"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "base64"
|
name = "base64"
|
||||||
version = "0.22.1"
|
version = "0.22.1"
|
||||||
|
@ -479,6 +488,12 @@ dependencies = [
|
||||||
"syn",
|
"syn",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "data-encoding"
|
||||||
|
version = "2.6.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "e8566979429cf69b49a5c740c60791108e86440e8be149bbea4fe54d2c32d6e2"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "der"
|
name = "der"
|
||||||
version = "0.7.9"
|
version = "0.7.9"
|
||||||
|
@ -650,6 +665,7 @@ checksum = "3d6401deb83407ab3da39eba7e33987a73c3df0c82b4bb5813ee871c19c41d48"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"futures-core",
|
"futures-core",
|
||||||
"futures-macro",
|
"futures-macro",
|
||||||
|
"futures-sink",
|
||||||
"futures-task",
|
"futures-task",
|
||||||
"pin-project-lite",
|
"pin-project-lite",
|
||||||
"pin-utils",
|
"pin-utils",
|
||||||
|
@ -1246,7 +1262,7 @@ version = "0.12.7"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "f8f4955649ef5c38cc7f9e8aa41761d48fb9677197daea9984dc54f56aad5e63"
|
checksum = "f8f4955649ef5c38cc7f9e8aa41761d48fb9677197daea9984dc54f56aad5e63"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"base64",
|
"base64 0.22.1",
|
||||||
"bytes",
|
"bytes",
|
||||||
"encoding_rs",
|
"encoding_rs",
|
||||||
"futures-core",
|
"futures-core",
|
||||||
|
@ -1360,7 +1376,7 @@ version = "2.1.3"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "196fe16b00e106300d3e45ecfcb764fa292a535d7326a29a5875c579c7417425"
|
checksum = "196fe16b00e106300d3e45ecfcb764fa292a535d7326a29a5875c579c7417425"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"base64",
|
"base64 0.22.1",
|
||||||
"rustls-pki-types",
|
"rustls-pki-types",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -1530,6 +1546,17 @@ dependencies = [
|
||||||
"serde",
|
"serde",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "sha1"
|
||||||
|
version = "0.10.6"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "e3bf829a2d51ab4a5ddf1352d8470c140cadc8301b2ae1789db023f01cedd6ba"
|
||||||
|
dependencies = [
|
||||||
|
"cfg-if",
|
||||||
|
"cpufeatures",
|
||||||
|
"digest",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "sha2"
|
name = "sha2"
|
||||||
version = "0.10.8"
|
version = "0.10.8"
|
||||||
|
@ -1678,6 +1705,26 @@ dependencies = [
|
||||||
"windows-sys 0.59.0",
|
"windows-sys 0.59.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "thiserror"
|
||||||
|
version = "1.0.63"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "c0342370b38b6a11b6cc11d6a805569958d54cfa061a29969c3b5ce2ea405724"
|
||||||
|
dependencies = [
|
||||||
|
"thiserror-impl",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "thiserror-impl"
|
||||||
|
version = "1.0.63"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "a4558b58466b9ad7ca0f102865eccc95938dca1a74a856f2b57b6629050da261"
|
||||||
|
dependencies = [
|
||||||
|
"proc-macro2",
|
||||||
|
"quote",
|
||||||
|
"syn",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "thread_local"
|
name = "thread_local"
|
||||||
version = "1.1.8"
|
version = "1.1.8"
|
||||||
|
@ -1763,6 +1810,18 @@ dependencies = [
|
||||||
"tokio-util",
|
"tokio-util",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "tokio-tungstenite"
|
||||||
|
version = "0.21.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "c83b561d025642014097b66e6c1bb422783339e0909e4429cde4749d1990bc38"
|
||||||
|
dependencies = [
|
||||||
|
"futures-util",
|
||||||
|
"log",
|
||||||
|
"tokio",
|
||||||
|
"tungstenite",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "tokio-util"
|
name = "tokio-util"
|
||||||
version = "0.7.11"
|
version = "0.7.11"
|
||||||
|
@ -1884,6 +1943,25 @@ version = "0.2.5"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b"
|
checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "tungstenite"
|
||||||
|
version = "0.21.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "9ef1a641ea34f399a848dea702823bbecfb4c486f911735368f1f137cb8257e1"
|
||||||
|
dependencies = [
|
||||||
|
"byteorder",
|
||||||
|
"bytes",
|
||||||
|
"data-encoding",
|
||||||
|
"http",
|
||||||
|
"httparse",
|
||||||
|
"log",
|
||||||
|
"rand",
|
||||||
|
"sha1",
|
||||||
|
"thiserror",
|
||||||
|
"url",
|
||||||
|
"utf-8",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "typenum"
|
name = "typenum"
|
||||||
version = "1.17.0"
|
version = "1.17.0"
|
||||||
|
@ -1928,6 +2006,12 @@ dependencies = [
|
||||||
"percent-encoding",
|
"percent-encoding",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "utf-8"
|
||||||
|
version = "0.7.6"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "utf8-width"
|
name = "utf8-width"
|
||||||
version = "0.1.7"
|
version = "0.1.7"
|
||||||
|
|
|
@ -5,7 +5,7 @@ edition = "2021"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
anyhow = "1"
|
anyhow = "1"
|
||||||
axum = { version = "0.7", features = ["tokio"] }
|
axum = { version = "0.7", features = ["ws"] }
|
||||||
axum-extra = "0.9"
|
axum-extra = "0.9"
|
||||||
clap = { version = "4", features = ["derive"] }
|
clap = { version = "4", features = ["derive"] }
|
||||||
ed25519-dalek = "2"
|
ed25519-dalek = "2"
|
||||||
|
@ -17,7 +17,7 @@ sd-notify = "0.4"
|
||||||
serde = { version = "1", features = ["derive"] }
|
serde = { version = "1", features = ["derive"] }
|
||||||
serde-aux = "4"
|
serde-aux = "4"
|
||||||
serde_json = "1"
|
serde_json = "1"
|
||||||
tokio = { version = "1", features = ["macros", "rt-multi-thread", "sync"] }
|
tokio = { version = "1", features = ["macros", "rt-multi-thread", "sync", "time"] }
|
||||||
tokio-stream = { version = "0.1", features = ["sync"] }
|
tokio-stream = { version = "0.1", features = ["sync"] }
|
||||||
tower-http = { version = "0.5", features = ["cors", "limit"] }
|
tower-http = { version = "0.5", features = ["cors", "limit"] }
|
||||||
tracing = "0.1"
|
tracing = "0.1"
|
||||||
|
|
|
@ -26,8 +26,16 @@ max_page_len = 1024
|
||||||
# Maximum request body length in bytes.
|
# Maximum request body length in bytes.
|
||||||
max_request_len = 4096
|
max_request_len = 4096
|
||||||
|
|
||||||
# Maximum length of a single event queue.
|
|
||||||
event_queue_len = 1024
|
|
||||||
|
|
||||||
# The maximum timestamp tolerance in seconds for request validation.
|
# The maximum timestamp tolerance in seconds for request validation.
|
||||||
timestamp_tolerance_secs = 90
|
timestamp_tolerance_secs = 90
|
||||||
|
|
||||||
|
# The max waiting time for the first authentication message for websocket.
|
||||||
|
ws_auth_timeout_sec = 15
|
||||||
|
|
||||||
|
# The max waiting time for outgoing message to be received for websocket.
|
||||||
|
ws_send_timeout_sec = 15
|
||||||
|
|
||||||
|
# Maximum number of pending events a single user can have.
|
||||||
|
# If events overflow the pending buffer, older events will be dropped and
|
||||||
|
# client will be notified.
|
||||||
|
ws_event_queue_len = 1024
|
||||||
|
|
|
@ -4,6 +4,18 @@ info:
|
||||||
version: 0.0.1
|
version: 0.0.1
|
||||||
|
|
||||||
paths:
|
paths:
|
||||||
|
/ws:
|
||||||
|
get:
|
||||||
|
summary: WebSocket endpoint.
|
||||||
|
description: |
|
||||||
|
Once connection, client must send a JSON text message of type
|
||||||
|
`WithSig<AuthPayload>` for authentication.
|
||||||
|
If server does not close it immediately, it means success.
|
||||||
|
|
||||||
|
Then server will send JSON text messages on events that user are
|
||||||
|
interested in (eg. chat from joined rooms).
|
||||||
|
The message has type `Outgoing` in `blahd/src/ws.rs`.
|
||||||
|
|
||||||
/room:
|
/room:
|
||||||
get:
|
get:
|
||||||
summary: Get room metadata
|
summary: Get room metadata
|
||||||
|
@ -133,34 +145,6 @@ paths:
|
||||||
application/json:
|
application/json:
|
||||||
$ref: '#/components/schemas/ApiError'
|
$ref: '#/components/schemas/ApiError'
|
||||||
|
|
||||||
/room/{ruuid}/event:
|
|
||||||
get:
|
|
||||||
summary: Get an event stream for future new items.
|
|
||||||
description: |
|
|
||||||
This is a temporary interface, before a better notification system
|
|
||||||
(post notifications? websocket?) is implemented.
|
|
||||||
headers:
|
|
||||||
Authorization:
|
|
||||||
description: Proof of membership for private rooms.
|
|
||||||
required: false
|
|
||||||
schema:
|
|
||||||
$ret: WithSig<AuthPayload>
|
|
||||||
responses:
|
|
||||||
200:
|
|
||||||
content:
|
|
||||||
text/event-stream:
|
|
||||||
x-description: An event stream, each event is a JSON with type WithSig<ChatPayload>
|
|
||||||
400:
|
|
||||||
description: Body is invalid or fails the verification.
|
|
||||||
content:
|
|
||||||
application/json:
|
|
||||||
$ref: '#/components/schemas/ApiError'
|
|
||||||
404:
|
|
||||||
description: Room not found.
|
|
||||||
content:
|
|
||||||
application/json:
|
|
||||||
$ref: '#/components/schemas/ApiError'
|
|
||||||
|
|
||||||
/room/{ruuid}/admin:
|
/room/{ruuid}/admin:
|
||||||
post:
|
post:
|
||||||
summary: Room management
|
summary: Room management
|
||||||
|
|
|
@ -1,7 +1,8 @@
|
||||||
use std::path::PathBuf;
|
use std::path::PathBuf;
|
||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
use anyhow::{ensure, Result};
|
use anyhow::{ensure, Result};
|
||||||
use serde::Deserialize;
|
use serde::{Deserialize, Deserializer};
|
||||||
use serde_inline_default::serde_inline_default;
|
use serde_inline_default::serde_inline_default;
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize)]
|
#[derive(Debug, Clone, Deserialize)]
|
||||||
|
@ -30,11 +31,22 @@ pub struct ServerConfig {
|
||||||
pub max_page_len: usize,
|
pub max_page_len: usize,
|
||||||
#[serde_inline_default(4096)] // 4KiB
|
#[serde_inline_default(4096)] // 4KiB
|
||||||
pub max_request_len: usize,
|
pub max_request_len: usize,
|
||||||
#[serde_inline_default(1024)]
|
|
||||||
pub event_queue_len: usize,
|
|
||||||
|
|
||||||
#[serde_inline_default(90)]
|
#[serde_inline_default(90)]
|
||||||
pub timestamp_tolerance_secs: u64,
|
pub timestamp_tolerance_secs: u64,
|
||||||
|
|
||||||
|
#[serde_inline_default(Duration::from_secs(15))]
|
||||||
|
#[serde(deserialize_with = "de_duration_sec")]
|
||||||
|
pub ws_auth_timeout_sec: Duration,
|
||||||
|
#[serde_inline_default(Duration::from_secs(15))]
|
||||||
|
#[serde(deserialize_with = "de_duration_sec")]
|
||||||
|
pub ws_send_timeout_sec: Duration,
|
||||||
|
#[serde_inline_default(1024)]
|
||||||
|
pub ws_event_queue_len: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn de_duration_sec<'de, D: Deserializer<'de>>(de: D) -> Result<Duration, D::Error> {
|
||||||
|
<u64>::deserialize(de).map(Duration::from_secs)
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Config {
|
impl Config {
|
||||||
|
|
166
blahd/src/event.rs
Normal file
166
blahd/src/event.rs
Normal file
|
@ -0,0 +1,166 @@
|
||||||
|
use std::collections::hash_map::Entry;
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::convert::Infallible;
|
||||||
|
use std::fmt;
|
||||||
|
use std::pin::Pin;
|
||||||
|
use std::sync::{Arc, Mutex};
|
||||||
|
use std::task::{Context, Poll};
|
||||||
|
|
||||||
|
use anyhow::{bail, Context as _, Result};
|
||||||
|
use axum::extract::ws::{Message, WebSocket};
|
||||||
|
use blah::types::{AuthPayload, ChatItem, WithSig};
|
||||||
|
use futures_util::future::Either;
|
||||||
|
use futures_util::stream::SplitSink;
|
||||||
|
use futures_util::{stream_select, SinkExt as _, Stream, StreamExt};
|
||||||
|
use rusqlite::{params, OptionalExtension};
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use tokio::sync::broadcast;
|
||||||
|
use tokio_stream::wrappers::errors::BroadcastStreamRecvError;
|
||||||
|
use tokio_stream::wrappers::BroadcastStream;
|
||||||
|
|
||||||
|
use crate::AppState;
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
pub enum Incoming {}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize)]
|
||||||
|
#[serde(rename_all = "snake_case")]
|
||||||
|
pub enum Outgoing<'a> {
|
||||||
|
/// A chat message from a joined room.
|
||||||
|
Chat(&'a ChatItem),
|
||||||
|
/// The receiver is too slow to receive and some events and are dropped.
|
||||||
|
// FIXME: Should we indefinitely buffer them or just disconnect the client instead?
|
||||||
|
Lagged,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Default)]
|
||||||
|
pub struct State {
|
||||||
|
pub user_listeners: Mutex<HashMap<u64, UserEventSender>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct StreamEnded;
|
||||||
|
|
||||||
|
impl fmt::Display for StreamEnded {
|
||||||
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
|
f.write_str("stream unexpectedly ended")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::error::Error for StreamEnded {}
|
||||||
|
|
||||||
|
struct WsSenderWrapper<'ws, 'c> {
|
||||||
|
inner: SplitSink<&'ws mut WebSocket, Message>,
|
||||||
|
config: &'c crate::config::Config,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl WsSenderWrapper<'_, '_> {
|
||||||
|
async fn send(&mut self, msg: &Outgoing<'_>) -> Result<()> {
|
||||||
|
let data = serde_json::to_string(&msg).expect("serialization cannot fail");
|
||||||
|
let fut = tokio::time::timeout(
|
||||||
|
self.config.server.ws_send_timeout_sec,
|
||||||
|
self.inner.send(Message::Text(data)),
|
||||||
|
);
|
||||||
|
match fut.await {
|
||||||
|
Ok(Ok(())) => Ok(()),
|
||||||
|
Ok(Err(_send_err)) => Err(StreamEnded.into()),
|
||||||
|
Err(_elapsed) => bail!("send timeout"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type UserEventSender = broadcast::Sender<Arc<ChatItem>>;
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
struct UserEventReceiver {
|
||||||
|
rx: BroadcastStream<Arc<ChatItem>>,
|
||||||
|
st: Arc<AppState>,
|
||||||
|
uid: u64,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Drop for UserEventReceiver {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
tracing::debug!(%self.uid, "user disconnected");
|
||||||
|
if let Ok(mut map) = self.st.event.user_listeners.lock() {
|
||||||
|
if let Some(tx) = map.get_mut(&self.uid) {
|
||||||
|
if tx.receiver_count() == 1 {
|
||||||
|
map.remove(&self.uid);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Stream for UserEventReceiver {
|
||||||
|
type Item = Result<Arc<ChatItem>, BroadcastStreamRecvError>;
|
||||||
|
|
||||||
|
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||||
|
self.rx.poll_next_unpin(cx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn handle_ws(st: Arc<AppState>, ws: &mut WebSocket) -> Result<Infallible> {
|
||||||
|
let (ws_tx, ws_rx) = ws.split();
|
||||||
|
let mut ws_rx = ws_rx.map(|ret| ret.and_then(|msg| msg.into_text()).map_err(|_| StreamEnded));
|
||||||
|
let mut ws_tx = WsSenderWrapper {
|
||||||
|
inner: ws_tx,
|
||||||
|
config: &st.config,
|
||||||
|
};
|
||||||
|
|
||||||
|
let uid = {
|
||||||
|
let payload = tokio::time::timeout(st.config.server.ws_auth_timeout_sec, ws_rx.next())
|
||||||
|
.await
|
||||||
|
.context("authentication timeout")?
|
||||||
|
.ok_or(StreamEnded)??;
|
||||||
|
let auth = serde_json::from_str::<WithSig<AuthPayload>>(&payload)?;
|
||||||
|
st.verify_signed_data(&auth)?;
|
||||||
|
|
||||||
|
st.conn
|
||||||
|
.lock()
|
||||||
|
.unwrap()
|
||||||
|
.query_row(
|
||||||
|
r"
|
||||||
|
SELECT `uid`
|
||||||
|
FROM `user`
|
||||||
|
WHERE `userkey` = ?
|
||||||
|
",
|
||||||
|
params![auth.signee.user],
|
||||||
|
|row| row.get::<_, u64>(0),
|
||||||
|
)
|
||||||
|
.optional()?
|
||||||
|
.context("invalid user")?
|
||||||
|
};
|
||||||
|
|
||||||
|
tracing::debug!(%uid, "user connected");
|
||||||
|
|
||||||
|
let event_rx = {
|
||||||
|
let rx = match st.event.user_listeners.lock().unwrap().entry(uid) {
|
||||||
|
Entry::Occupied(ent) => ent.get().subscribe(),
|
||||||
|
Entry::Vacant(ent) => {
|
||||||
|
let (tx, rx) = broadcast::channel(st.config.server.ws_event_queue_len);
|
||||||
|
ent.insert(tx);
|
||||||
|
rx
|
||||||
|
}
|
||||||
|
};
|
||||||
|
UserEventReceiver {
|
||||||
|
rx: rx.into(),
|
||||||
|
st: st.clone(),
|
||||||
|
uid,
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut stream = stream_select!(ws_rx.map(Either::Left), event_rx.map(Either::Right));
|
||||||
|
loop {
|
||||||
|
match stream.next().await.ok_or(StreamEnded)? {
|
||||||
|
Either::Left(msg) => match serde_json::from_str::<Incoming>(&msg?)? {},
|
||||||
|
Either::Right(ret) => {
|
||||||
|
let msg = match &ret {
|
||||||
|
Ok(chat) => Outgoing::Chat(chat),
|
||||||
|
Err(BroadcastStreamRecvError::Lagged(_)) => Outgoing::Lagged,
|
||||||
|
};
|
||||||
|
// TODO: Concurrent send.
|
||||||
|
ws_tx.send(&msg).await?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,14 +1,12 @@
|
||||||
use std::collections::hash_map::Entry;
|
|
||||||
use std::collections::HashMap;
|
|
||||||
use std::convert::Infallible;
|
|
||||||
use std::path::PathBuf;
|
use std::path::PathBuf;
|
||||||
use std::sync::{Arc, Mutex};
|
use std::sync::{Arc, Mutex};
|
||||||
use std::time::{Duration, SystemTime};
|
use std::time::{Duration, SystemTime};
|
||||||
|
|
||||||
use anyhow::{Context, Result};
|
use anyhow::{Context, Result};
|
||||||
use axum::extract::{Path, Query, State};
|
use axum::extract::ws;
|
||||||
|
use axum::extract::{Path, Query, State, WebSocketUpgrade};
|
||||||
use axum::http::{header, StatusCode};
|
use axum::http::{header, StatusCode};
|
||||||
use axum::response::{sse, IntoResponse};
|
use axum::response::{IntoResponse, Response};
|
||||||
use axum::routing::{get, post};
|
use axum::routing::{get, post};
|
||||||
use axum::{Json, Router};
|
use axum::{Json, Router};
|
||||||
use axum_extra::extract::WithRejection;
|
use axum_extra::extract::WithRejection;
|
||||||
|
@ -21,14 +19,13 @@ use ed25519_dalek::SIGNATURE_LENGTH;
|
||||||
use middleware::{ApiError, OptionalAuth, SignedJson};
|
use middleware::{ApiError, OptionalAuth, SignedJson};
|
||||||
use rusqlite::{named_params, params, OptionalExtension, Row};
|
use rusqlite::{named_params, params, OptionalExtension, Row};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use tokio::sync::broadcast;
|
|
||||||
use tokio_stream::StreamExt;
|
|
||||||
use utils::ExpiringSet;
|
use utils::ExpiringSet;
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
#[macro_use]
|
#[macro_use]
|
||||||
mod middleware;
|
mod middleware;
|
||||||
mod config;
|
mod config;
|
||||||
|
mod event;
|
||||||
mod utils;
|
mod utils;
|
||||||
|
|
||||||
/// Blah Chat Server
|
/// Blah Chat Server
|
||||||
|
@ -84,8 +81,8 @@ fn main() -> Result<()> {
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
struct AppState {
|
struct AppState {
|
||||||
conn: Mutex<rusqlite::Connection>,
|
conn: Mutex<rusqlite::Connection>,
|
||||||
room_listeners: Mutex<HashMap<u64, broadcast::Sender<Arc<ChatItem>>>>,
|
|
||||||
used_nonces: Mutex<ExpiringSet<u32>>,
|
used_nonces: Mutex<ExpiringSet<u32>>,
|
||||||
|
event: event::State,
|
||||||
|
|
||||||
config: Config,
|
config: Config,
|
||||||
}
|
}
|
||||||
|
@ -101,10 +98,10 @@ impl AppState {
|
||||||
.context("failed to initialize database")?;
|
.context("failed to initialize database")?;
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
conn: Mutex::new(conn),
|
conn: Mutex::new(conn),
|
||||||
room_listeners: Mutex::new(HashMap::new()),
|
|
||||||
used_nonces: Mutex::new(ExpiringSet::new(Duration::from_secs(
|
used_nonces: Mutex::new(ExpiringSet::new(Duration::from_secs(
|
||||||
config.server.timestamp_tolerance_secs,
|
config.server.timestamp_tolerance_secs,
|
||||||
))),
|
))),
|
||||||
|
event: event::State::default(),
|
||||||
|
|
||||||
config,
|
config,
|
||||||
})
|
})
|
||||||
|
@ -152,11 +149,11 @@ async fn main_async(st: AppState) -> Result<()> {
|
||||||
let st = Arc::new(st);
|
let st = Arc::new(st);
|
||||||
|
|
||||||
let app = Router::new()
|
let app = Router::new()
|
||||||
|
.route("/ws", get(handle_ws))
|
||||||
.route("/room/create", post(room_create))
|
.route("/room/create", post(room_create))
|
||||||
.route("/room/:ruuid", get(room_get_metadata))
|
.route("/room/:ruuid", get(room_get_metadata))
|
||||||
// NB. Sync with `feed_url` and `next_url` generation.
|
// NB. Sync with `feed_url` and `next_url` generation.
|
||||||
.route("/room/:ruuid/feed.json", get(room_get_feed))
|
.route("/room/:ruuid/feed.json", get(room_get_feed))
|
||||||
.route("/room/:ruuid/event", get(room_event))
|
|
||||||
.route("/room/:ruuid/item", get(room_get_item).post(room_post_item))
|
.route("/room/:ruuid/item", get(room_get_item).post(room_post_item))
|
||||||
.route("/room/:ruuid/admin", post(room_admin))
|
.route("/room/:ruuid/admin", post(room_admin))
|
||||||
.with_state(st.clone())
|
.with_state(st.clone())
|
||||||
|
@ -179,6 +176,24 @@ async fn main_async(st: AppState) -> Result<()> {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn handle_ws(State(st): ArcState, ws: WebSocketUpgrade) -> Response {
|
||||||
|
ws.on_upgrade(move |mut socket| async move {
|
||||||
|
match event::handle_ws(st, &mut socket).await {
|
||||||
|
Ok(never) => match never {},
|
||||||
|
Err(err) if err.is::<event::StreamEnded>() => {}
|
||||||
|
Err(err) => {
|
||||||
|
tracing::debug!(%err, "ws error");
|
||||||
|
let _: Result<_, _> = socket
|
||||||
|
.send(ws::Message::Close(Some(ws::CloseFrame {
|
||||||
|
code: ws::close_code::ERROR,
|
||||||
|
reason: err.to_string().into(),
|
||||||
|
})))
|
||||||
|
.await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
async fn room_create(
|
async fn room_create(
|
||||||
st: ArcState,
|
st: ArcState,
|
||||||
SignedJson(params): SignedJson<CreateRoomPayload>,
|
SignedJson(params): SignedJson<CreateRoomPayload>,
|
||||||
|
@ -505,7 +520,7 @@ async fn room_post_item(
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
let (rid, cid) = {
|
let (cid, txs) = {
|
||||||
let conn = st.conn.lock().unwrap();
|
let conn = st.conn.lock().unwrap();
|
||||||
let Some((rid, uid)) = conn
|
let Some((rid, uid)) = conn
|
||||||
.query_row(
|
.query_row(
|
||||||
|
@ -550,75 +565,40 @@ async fn room_post_item(
|
||||||
},
|
},
|
||||||
|row| row.get::<_, u64>(0),
|
|row| row.get::<_, u64>(0),
|
||||||
)?;
|
)?;
|
||||||
(rid, cid)
|
|
||||||
|
// FIXME: Optimize this to not traverses over all members.
|
||||||
|
let mut stmt = conn.prepare(
|
||||||
|
r"
|
||||||
|
SELECT `uid`
|
||||||
|
FROM `room_member`
|
||||||
|
WHERE `rid` = :rid
|
||||||
|
",
|
||||||
|
)?;
|
||||||
|
let listeners = st.event.user_listeners.lock().unwrap();
|
||||||
|
let txs = stmt
|
||||||
|
.query_map(params![rid], |row| row.get::<_, u64>(0))?
|
||||||
|
.filter_map(|ret| match ret {
|
||||||
|
Ok(uid) => listeners.get(&uid).map(|tx| Ok(tx.clone())),
|
||||||
|
Err(err) => Some(Err(err)),
|
||||||
|
})
|
||||||
|
.collect::<Result<Vec<_>, _>>()?;
|
||||||
|
|
||||||
|
(cid, txs)
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut listeners = st.room_listeners.lock().unwrap();
|
if !txs.is_empty() {
|
||||||
if let Some(tx) = listeners.get(&rid) {
|
tracing::debug!("broadcasting event to {} clients", txs.len());
|
||||||
if tx.send(Arc::new(chat)).is_err() {
|
let chat = Arc::new(chat);
|
||||||
// Clean up because all receivers died.
|
for tx in txs {
|
||||||
listeners.remove(&rid);
|
if let Err(err) = tx.send(chat.clone()) {
|
||||||
|
tracing::debug!(%err, "failed to broadcast event");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(Json(cid))
|
Ok(Json(cid))
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn room_event(
|
|
||||||
st: ArcState,
|
|
||||||
Path(ruuid): Path<Uuid>,
|
|
||||||
// TODO: There is actually no way to add headers via `EventSource` in client side.
|
|
||||||
// But this API is kinda temporary and need a better replacement anyway.
|
|
||||||
// So just only support public room for now.
|
|
||||||
OptionalAuth(user): OptionalAuth,
|
|
||||||
) -> Result<impl IntoResponse, ApiError> {
|
|
||||||
let rid = get_room_if_readable(&st.conn.lock().unwrap(), ruuid, user.as_ref(), |row| {
|
|
||||||
row.get::<_, u64>(0)
|
|
||||||
})?;
|
|
||||||
|
|
||||||
let rx = match st.room_listeners.lock().unwrap().entry(rid) {
|
|
||||||
Entry::Occupied(ent) => ent.get().subscribe(),
|
|
||||||
Entry::Vacant(ent) => {
|
|
||||||
let (tx, rx) = broadcast::channel(st.config.server.event_queue_len);
|
|
||||||
ent.insert(tx);
|
|
||||||
rx
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// Do clean up when this stream is closed.
|
|
||||||
struct CleanOnDrop {
|
|
||||||
st: Arc<AppState>,
|
|
||||||
rid: u64,
|
|
||||||
}
|
|
||||||
impl Drop for CleanOnDrop {
|
|
||||||
fn drop(&mut self) {
|
|
||||||
if let Ok(mut listeners) = self.st.room_listeners.lock() {
|
|
||||||
if let Some(tx) = listeners.get(&self.rid) {
|
|
||||||
if tx.receiver_count() == 0 {
|
|
||||||
listeners.remove(&self.rid);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let _guard = CleanOnDrop { st: st.0, rid };
|
|
||||||
|
|
||||||
let stream = tokio_stream::wrappers::BroadcastStream::new(rx).filter_map(move |ret| {
|
|
||||||
let _guard = &_guard;
|
|
||||||
// On stream closure or lagging, close the current stream so client can retry.
|
|
||||||
let item = ret.ok()?;
|
|
||||||
let evt = sse::Event::default()
|
|
||||||
.json_data(&*item)
|
|
||||||
.expect("serialization cannot fail");
|
|
||||||
Some(Ok::<_, Infallible>(evt))
|
|
||||||
});
|
|
||||||
// NB. Send an empty event immediately to trigger client ready event.
|
|
||||||
let first_event = sse::Event::default().comment("");
|
|
||||||
let stream = futures_util::stream::iter(Some(Ok(first_event))).chain(stream);
|
|
||||||
Ok(sse::Sse::new(stream).keep_alive(sse::KeepAlive::default()))
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn room_admin(
|
async fn room_admin(
|
||||||
st: ArcState,
|
st: ArcState,
|
||||||
Path(ruuid): Path<Uuid>,
|
Path(ruuid): Path<Uuid>,
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
use std::fmt;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
use axum::extract::rejection::{JsonRejection, PathRejection, QueryRejection};
|
use axum::extract::rejection::{JsonRejection, PathRejection, QueryRejection};
|
||||||
|
@ -22,6 +23,14 @@ pub struct ApiError {
|
||||||
pub message: String,
|
pub message: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl fmt::Display for ApiError {
|
||||||
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
|
f.write_str(&self.message)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::error::Error for ApiError {}
|
||||||
|
|
||||||
macro_rules! error_response {
|
macro_rules! error_response {
|
||||||
($status:expr, $code:literal, $msg:literal $(, $msg_args:expr)* $(,)?) => {
|
($status:expr, $code:literal, $msg:literal $(, $msg_args:expr)* $(,)?) => {
|
||||||
$crate::middleware::ApiError {
|
$crate::middleware::ApiError {
|
||||||
|
|
|
@ -7,7 +7,7 @@ const joinRoomBtn = document.querySelector('#join-room');
|
||||||
|
|
||||||
let roomUrl = '';
|
let roomUrl = '';
|
||||||
let roomUuid = null;
|
let roomUuid = null;
|
||||||
let feed = null;
|
let ws = null;
|
||||||
let keypair = null;
|
let keypair = null;
|
||||||
let defaultConfig = {};
|
let defaultConfig = {};
|
||||||
|
|
||||||
|
@ -166,16 +166,13 @@ function escapeHtml(text) {
|
||||||
}
|
}
|
||||||
|
|
||||||
async function connectRoom(url) {
|
async function connectRoom(url) {
|
||||||
if (url === '' || url == roomUrl) return;
|
if (url === '' || url == roomUrl || keypair === null) return;
|
||||||
const match = url.match(/^https?:\/\/.*\/([a-z0-9]{8}-[a-z0-9]{4}-[a-z0-9]{4}-[a-z0-9]{4}-[a-z0-9]{12})\/?/);
|
const match = url.match(/^https?:\/\/.*\/([a-z0-9]{8}-[a-z0-9]{4}-[a-z0-9]{4}-[a-z0-9]{4}-[a-z0-9]{12})\/?/);
|
||||||
if (match === null) {
|
if (match === null) {
|
||||||
log('invalid room url');
|
log('invalid room url');
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (feed !== null) {
|
|
||||||
feed.close();
|
|
||||||
}
|
|
||||||
roomUrl = url;
|
roomUrl = url;
|
||||||
roomUuid = match[1];
|
roomUuid = match[1];
|
||||||
|
|
||||||
|
@ -210,18 +207,40 @@ async function connectRoom(url) {
|
||||||
|
|
||||||
// TODO: There is a time window where events would be lost.
|
// TODO: There is a time window where events would be lost.
|
||||||
|
|
||||||
feed = new EventSource(`${url}/event`);
|
await connectWs();
|
||||||
feed.onopen = (_) => {
|
}
|
||||||
|
|
||||||
|
async function connectWs() {
|
||||||
|
if (ws !== null) {
|
||||||
|
ws.close();
|
||||||
|
}
|
||||||
|
const wsUrl = new URL(roomUrl);
|
||||||
|
wsUrl.protocol = wsUrl.protocol == 'http:' ? 'ws:' : 'wss:';
|
||||||
|
wsUrl.pathname = '/ws';
|
||||||
|
ws = new WebSocket(wsUrl);
|
||||||
|
ws.onopen = async (_) => {
|
||||||
|
const auth = await signData({ typ: 'auth' });
|
||||||
|
await ws.send(auth);
|
||||||
log('listening on events');
|
log('listening on events');
|
||||||
}
|
}
|
||||||
feed.onerror = (e) => {
|
ws.onclose = (e) => {
|
||||||
console.error(e);
|
console.error(e);
|
||||||
log('event listener error');
|
log(`ws closed (code=${e.code}): ${e.reason}`);
|
||||||
};
|
};
|
||||||
feed.onmessage = async (e) => {
|
ws.onerror = (e) => {
|
||||||
console.log('feed event', e.data);
|
console.error(e);
|
||||||
const chat = JSON.parse(e.data);
|
log(`ws error: ${e.error}`);
|
||||||
showChatMsg(chat);
|
};
|
||||||
|
ws.onmessage = async (e) => {
|
||||||
|
console.log('ws event', e.data);
|
||||||
|
const msg = JSON.parse(e.data);
|
||||||
|
if (msg.chat !== undefined) {
|
||||||
|
showChatMsg(msg.chat);
|
||||||
|
} else if (msg.lagged !== undefined) {
|
||||||
|
log('some events are dropped because of queue overflow')
|
||||||
|
} else {
|
||||||
|
log(`unknown ws message: ${e.data}`);
|
||||||
|
}
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -236,6 +255,7 @@ async function joinRoom() {
|
||||||
user: await getUserPubkey(),
|
user: await getUserPubkey(),
|
||||||
});
|
});
|
||||||
log('joined room');
|
log('joined room');
|
||||||
|
await connectWs();
|
||||||
} catch (e) {
|
} catch (e) {
|
||||||
console.error(e);
|
console.error(e);
|
||||||
log(`failed to join room: ${e}`);
|
log(`failed to join room: ${e}`);
|
||||||
|
|
Loading…
Add table
Reference in a new issue