test: add for WS

This commit is contained in:
oxalica 2024-09-22 04:42:52 -04:00
parent 4bca196df3
commit 883fac02ae
3 changed files with 142 additions and 6 deletions

35
Cargo.lock generated
View file

@ -167,7 +167,7 @@ dependencies = [
"sha1",
"sync_wrapper 1.0.1",
"tokio",
"tokio-tungstenite",
"tokio-tungstenite 0.21.0",
"tower",
"tower-layer",
"tower-service",
@ -337,6 +337,7 @@ dependencies = [
"tempfile",
"tokio",
"tokio-stream",
"tokio-tungstenite 0.24.0",
"toml",
"tower-http",
"tracing",
@ -2438,7 +2439,19 @@ dependencies = [
"futures-util",
"log",
"tokio",
"tungstenite",
"tungstenite 0.21.0",
]
[[package]]
name = "tokio-tungstenite"
version = "0.24.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "edc5f74e248dc973e0dbb7b74c7e0d6fcc301c694ff50049504004ef4d0cdcd9"
dependencies = [
"futures-util",
"log",
"tokio",
"tungstenite 0.24.0",
]
[[package]]
@ -2615,6 +2628,24 @@ dependencies = [
"utf-8",
]
[[package]]
name = "tungstenite"
version = "0.24.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "18e5b8366ee7a95b16d32197d0b2604b43a0be89dc5fac9f8e96ccafbaedda8a"
dependencies = [
"byteorder",
"bytes",
"data-encoding",
"http",
"httparse",
"log",
"rand",
"sha1",
"thiserror",
"utf-8",
]
[[package]]
name = "typenum"
version = "1.17.0"

View file

@ -47,6 +47,7 @@ reqwest = { version = "0.12.7", features = ["json"] }
rstest = { version = "0.22.0", default-features = false }
scopeguard = "1.2.0"
tempfile = "3.12.0"
tokio-tungstenite = "0.24.0"
[lints]
workspace = true

View file

@ -19,7 +19,7 @@ use blah_types::{
use blahd::{ApiError, AppState, Database, RoomList, RoomMsgs};
use ed25519_dalek::SigningKey;
use futures_util::future::BoxFuture;
use futures_util::TryFutureExt;
use futures_util::{SinkExt, Stream, StreamExt, TryFutureExt};
use parking_lot::Mutex;
use reqwest::{header, Method, StatusCode};
use rstest::{fixture, rstest};
@ -34,12 +34,16 @@ const LOCALHOST: &str = "localhost";
const REGISTER_DIFFICULTY: u8 = 1;
const TIME_TOLERANCE: Duration = Duration::from_millis(100);
const WS_CONNECT_TIMEOUT: Duration = Duration::from_millis(1500);
const CONFIG: fn(u16) -> String = |port| {
format!(
r#"
base_url="http://{LOCALHOST}:{port}"
[ws]
auth_timeout_sec = 1
[register]
enable_public = true
difficulty = {REGISTER_DIFFICULTY}
@ -114,6 +118,14 @@ impl fmt::Display for ApiErrorWithHeaders {
impl std::error::Error for ApiErrorWithHeaders {}
// TODO: Hoist this into types crate.
#[derive(Debug, Clone, PartialEq, Eq, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum WsEvent {
// TODO: Include cid?
Msg(SignedChatMsg),
}
#[derive(Debug)]
struct Server {
port: u16,
@ -122,11 +134,34 @@ struct Server {
impl Server {
fn url(&self, rhs: impl fmt::Display) -> String {
format!("{}/_blah{}", self.domain(), rhs)
format!("http://{}/_blah{}", self.domain(), rhs)
}
fn domain(&self) -> String {
format!("http://{}:{}", LOCALHOST, self.port)
format!("{}:{}", LOCALHOST, self.port)
}
async fn connect_ws(
&self,
auth_user: Option<&User>,
) -> Result<impl Stream<Item = Result<WsEvent>> + Unpin> {
let url = format!("ws://{}/_blah/ws", self.domain());
let (mut ws, _) = tokio_tungstenite::connect_async(url).await.unwrap();
if let Some(user) = auth_user {
ws.send(tokio_tungstenite::tungstenite::Message::Text(auth(user)))
.await
.unwrap();
}
Ok(ws
.map(|ret| {
let wsmsg = ret?;
if wsmsg.is_close() {
return Ok(None);
}
let event = serde_json::from_slice::<WsEvent>(&wsmsg.into_data())?;
Ok(Some(event))
})
.filter_map(|ret| std::future::ready(ret.transpose())))
}
fn request<Req: Serialize, Resp: DeserializeOwned>(
@ -943,7 +978,7 @@ async fn register(server: Server) {
register_fast(&req)
.await
.expect_api_err(StatusCode::BAD_REQUEST, "invalid_server_url");
req.server_url = server.domain().parse().unwrap();
req.server_url = format!("http://{}", server.domain()).parse().unwrap();
register_fast(&req)
.await
@ -1078,3 +1113,72 @@ async fn register(server: Server) {
.await
.unwrap();
}
#[rstest]
#[tokio::test]
async fn event(server: Server) {
let rid1 = server
.create_room(&ALICE, RoomAttrs::PUBLIC_JOINABLE, "room1")
.await
.unwrap();
{
let mut ws = server.connect_ws(None).await.unwrap();
let msg = tokio::time::timeout(WS_CONNECT_TIMEOUT, ws.next())
.await
.unwrap();
assert!(msg.is_none(), "auth should timeout");
}
{
let mut ws = server.connect_ws(Some(&CAROL)).await.unwrap();
assert!(
ws.next().await.is_none(),
"should close unauthorized connection",
);
}
// Ok.
let mut ws = server.connect_ws(Some(&ALICE)).await.unwrap();
// TODO: Synchronize with the server so that following msgs will be received.
// Should receive msgs from self-post.
{
let chat = server.post_chat(rid1, &ALICE, "alice1").await.unwrap();
let got = ws.next().await.unwrap().unwrap();
assert_eq!(got, WsEvent::Msg(chat.msg));
}
// Should receive msgs from other user.
{
server
.join_room(rid1, &BOB, MemberPermission::MAX_SELF_ADD)
.await
.unwrap();
let chat = server.post_chat(rid1, &BOB, "bob1").await.unwrap();
let got = ws.next().await.unwrap().unwrap();
assert_eq!(got, WsEvent::Msg(chat.msg));
}
// Should receive msgs from new room.
let rid2 = server
.create_room(&ALICE, RoomAttrs::PUBLIC_JOINABLE, "room2")
.await
.unwrap();
{
let chat = server.post_chat(rid2, &ALICE, "alice2").await.unwrap();
let got = ws.next().await.unwrap().unwrap();
assert_eq!(got, WsEvent::Msg(chat.msg));
}
// Each streams should receive each message once.
{
let mut ws2 = server.connect_ws(Some(&ALICE)).await.unwrap();
let chat = server.post_chat(rid1, &ALICE, "alice1").await.unwrap();
let got1 = ws.next().await.unwrap().unwrap();
assert_eq!(got1, WsEvent::Msg(chat.msg.clone()));
let got2 = ws2.next().await.unwrap().unwrap();
assert_eq!(got2, WsEvent::Msg(chat.msg));
}
}