From 883fac02ae8c40e51eff54d38ff89eead7e5fdbd Mon Sep 17 00:00:00 2001 From: oxalica Date: Sun, 22 Sep 2024 04:42:52 -0400 Subject: [PATCH] test: add for WS --- Cargo.lock | 35 ++++++++++++- blahd/Cargo.toml | 1 + blahd/tests/webapi.rs | 112 ++++++++++++++++++++++++++++++++++++++++-- 3 files changed, 142 insertions(+), 6 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index b6ecae8..b22c22a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -167,7 +167,7 @@ dependencies = [ "sha1", "sync_wrapper 1.0.1", "tokio", - "tokio-tungstenite", + "tokio-tungstenite 0.21.0", "tower", "tower-layer", "tower-service", @@ -337,6 +337,7 @@ dependencies = [ "tempfile", "tokio", "tokio-stream", + "tokio-tungstenite 0.24.0", "toml", "tower-http", "tracing", @@ -2438,7 +2439,19 @@ dependencies = [ "futures-util", "log", "tokio", - "tungstenite", + "tungstenite 0.21.0", +] + +[[package]] +name = "tokio-tungstenite" +version = "0.24.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "edc5f74e248dc973e0dbb7b74c7e0d6fcc301c694ff50049504004ef4d0cdcd9" +dependencies = [ + "futures-util", + "log", + "tokio", + "tungstenite 0.24.0", ] [[package]] @@ -2615,6 +2628,24 @@ dependencies = [ "utf-8", ] +[[package]] +name = "tungstenite" +version = "0.24.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "18e5b8366ee7a95b16d32197d0b2604b43a0be89dc5fac9f8e96ccafbaedda8a" +dependencies = [ + "byteorder", + "bytes", + "data-encoding", + "http", + "httparse", + "log", + "rand", + "sha1", + "thiserror", + "utf-8", +] + [[package]] name = "typenum" version = "1.17.0" diff --git a/blahd/Cargo.toml b/blahd/Cargo.toml index cb4eabf..19a6b30 100644 --- a/blahd/Cargo.toml +++ b/blahd/Cargo.toml @@ -47,6 +47,7 @@ reqwest = { version = "0.12.7", features = ["json"] } rstest = { version = "0.22.0", default-features = false } scopeguard = "1.2.0" tempfile = "3.12.0" +tokio-tungstenite = "0.24.0" [lints] workspace = true diff --git a/blahd/tests/webapi.rs b/blahd/tests/webapi.rs index cb34ddc..ab40f05 100644 --- a/blahd/tests/webapi.rs +++ b/blahd/tests/webapi.rs @@ -19,7 +19,7 @@ use blah_types::{ use blahd::{ApiError, AppState, Database, RoomList, RoomMsgs}; use ed25519_dalek::SigningKey; use futures_util::future::BoxFuture; -use futures_util::TryFutureExt; +use futures_util::{SinkExt, Stream, StreamExt, TryFutureExt}; use parking_lot::Mutex; use reqwest::{header, Method, StatusCode}; use rstest::{fixture, rstest}; @@ -34,12 +34,16 @@ const LOCALHOST: &str = "localhost"; const REGISTER_DIFFICULTY: u8 = 1; const TIME_TOLERANCE: Duration = Duration::from_millis(100); +const WS_CONNECT_TIMEOUT: Duration = Duration::from_millis(1500); const CONFIG: fn(u16) -> String = |port| { format!( r#" base_url="http://{LOCALHOST}:{port}" +[ws] +auth_timeout_sec = 1 + [register] enable_public = true difficulty = {REGISTER_DIFFICULTY} @@ -114,6 +118,14 @@ impl fmt::Display for ApiErrorWithHeaders { impl std::error::Error for ApiErrorWithHeaders {} +// TODO: Hoist this into types crate. +#[derive(Debug, Clone, PartialEq, Eq, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum WsEvent { + // TODO: Include cid? + Msg(SignedChatMsg), +} + #[derive(Debug)] struct Server { port: u16, @@ -122,11 +134,34 @@ struct Server { impl Server { fn url(&self, rhs: impl fmt::Display) -> String { - format!("{}/_blah{}", self.domain(), rhs) + format!("http://{}/_blah{}", self.domain(), rhs) } fn domain(&self) -> String { - format!("http://{}:{}", LOCALHOST, self.port) + format!("{}:{}", LOCALHOST, self.port) + } + + async fn connect_ws( + &self, + auth_user: Option<&User>, + ) -> Result> + Unpin> { + let url = format!("ws://{}/_blah/ws", self.domain()); + let (mut ws, _) = tokio_tungstenite::connect_async(url).await.unwrap(); + if let Some(user) = auth_user { + ws.send(tokio_tungstenite::tungstenite::Message::Text(auth(user))) + .await + .unwrap(); + } + Ok(ws + .map(|ret| { + let wsmsg = ret?; + if wsmsg.is_close() { + return Ok(None); + } + let event = serde_json::from_slice::(&wsmsg.into_data())?; + Ok(Some(event)) + }) + .filter_map(|ret| std::future::ready(ret.transpose()))) } fn request( @@ -943,7 +978,7 @@ async fn register(server: Server) { register_fast(&req) .await .expect_api_err(StatusCode::BAD_REQUEST, "invalid_server_url"); - req.server_url = server.domain().parse().unwrap(); + req.server_url = format!("http://{}", server.domain()).parse().unwrap(); register_fast(&req) .await @@ -1078,3 +1113,72 @@ async fn register(server: Server) { .await .unwrap(); } + +#[rstest] +#[tokio::test] +async fn event(server: Server) { + let rid1 = server + .create_room(&ALICE, RoomAttrs::PUBLIC_JOINABLE, "room1") + .await + .unwrap(); + + { + let mut ws = server.connect_ws(None).await.unwrap(); + let msg = tokio::time::timeout(WS_CONNECT_TIMEOUT, ws.next()) + .await + .unwrap(); + assert!(msg.is_none(), "auth should timeout"); + } + + { + let mut ws = server.connect_ws(Some(&CAROL)).await.unwrap(); + assert!( + ws.next().await.is_none(), + "should close unauthorized connection", + ); + } + + // Ok. + let mut ws = server.connect_ws(Some(&ALICE)).await.unwrap(); + // TODO: Synchronize with the server so that following msgs will be received. + + // Should receive msgs from self-post. + { + let chat = server.post_chat(rid1, &ALICE, "alice1").await.unwrap(); + let got = ws.next().await.unwrap().unwrap(); + assert_eq!(got, WsEvent::Msg(chat.msg)); + } + + // Should receive msgs from other user. + { + server + .join_room(rid1, &BOB, MemberPermission::MAX_SELF_ADD) + .await + .unwrap(); + let chat = server.post_chat(rid1, &BOB, "bob1").await.unwrap(); + let got = ws.next().await.unwrap().unwrap(); + assert_eq!(got, WsEvent::Msg(chat.msg)); + } + + // Should receive msgs from new room. + let rid2 = server + .create_room(&ALICE, RoomAttrs::PUBLIC_JOINABLE, "room2") + .await + .unwrap(); + { + let chat = server.post_chat(rid2, &ALICE, "alice2").await.unwrap(); + let got = ws.next().await.unwrap().unwrap(); + assert_eq!(got, WsEvent::Msg(chat.msg)); + } + + // Each streams should receive each message once. + { + let mut ws2 = server.connect_ws(Some(&ALICE)).await.unwrap(); + + let chat = server.post_chat(rid1, &ALICE, "alice1").await.unwrap(); + let got1 = ws.next().await.unwrap().unwrap(); + assert_eq!(got1, WsEvent::Msg(chat.msg.clone())); + let got2 = ws2.next().await.unwrap().unwrap(); + assert_eq!(got2, WsEvent::Msg(chat.msg)); + } +}