diff --git a/Cargo.lock b/Cargo.lock index 69c8bee..fec48e3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -307,6 +307,7 @@ dependencies = [ "hex", "humantime", "parking_lot", + "rand", "reqwest", "rstest", "rusqlite", diff --git a/blahd/Cargo.toml b/blahd/Cargo.toml index a9bb7fa..508e5b3 100644 --- a/blahd/Cargo.toml +++ b/blahd/Cargo.toml @@ -30,6 +30,7 @@ url = { version = "2.5.2", features = ["serde"] } blah = { path = "..", features = ["rusqlite"] } [dev-dependencies] +rand = "0.8.5" reqwest = { version = "0.12.7", features = ["json"] } rstest = { version = "0.22.0", default-features = false } diff --git a/blahd/schema.sql b/blahd/schema.sql index 64765b7..9f6e4e2 100644 --- a/blahd/schema.sql +++ b/blahd/schema.sql @@ -17,8 +17,8 @@ CREATE TABLE IF NOT EXISTS `room_member` ( `rid` INTEGER NOT NULL REFERENCES `room` ON DELETE CASCADE, `uid` INTEGER NOT NULL REFERENCES `user` ON DELETE RESTRICT, `permission` INTEGER NOT NULL, - `last_seen_cid` INTEGER NOT NULL REFERENCES `room_item` (`cid`) ON DELETE NO ACTION - DEFAULT 0, + -- Optionally references `room_item`(`cid`). + `last_seen_cid` INTEGER NOT NULL DEFAULT 0, PRIMARY KEY (`rid`, `uid`) ) STRICT; diff --git a/blahd/src/lib.rs b/blahd/src/lib.rs index b281d1a..c1fea34 100644 --- a/blahd/src/lib.rs +++ b/blahd/src/lib.rs @@ -17,7 +17,7 @@ use blah::types::{ use config::ServerConfig; use ed25519_dalek::SIGNATURE_LENGTH; use id::IdExt; -use middleware::{ApiError, Auth, MaybeAuth, ResultExt as _, SignedJson}; +use middleware::{Auth, MaybeAuth, ResultExt as _, SignedJson}; use parking_lot::Mutex; use rusqlite::{named_params, params, Connection, OptionalExtension, Row, ToSql}; use serde::{Deserialize, Serialize}; @@ -33,6 +33,7 @@ mod id; mod utils; pub use database::Database; +pub use middleware::ApiError; // Locks must be grabbed in the field order. #[derive(Debug)] diff --git a/blahd/src/middleware.rs b/blahd/src/middleware.rs index 325ef34..a60bbf9 100644 --- a/blahd/src/middleware.rs +++ b/blahd/src/middleware.rs @@ -8,24 +8,28 @@ use axum::response::{IntoResponse, Response}; use axum::{async_trait, Json}; use blah::types::{AuthPayload, UserKey, WithSig}; use serde::de::DeserializeOwned; -use serde::Serialize; +use serde::{Deserialize, Serialize}; use crate::AppState; /// Error response body for json endpoints. /// /// Mostly following: -#[derive(Debug, Serialize)] +#[derive(Debug, Serialize, Deserialize)] pub struct ApiError { - #[serde(skip)] + #[serde(skip, default)] pub status: StatusCode, - pub code: &'static str, + pub code: String, pub message: String, } impl fmt::Display for ApiError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str(&self.message) + write!( + f, + "api error status={} code={}: {}", + self.status, self.code, self.message, + ) } } @@ -35,7 +39,7 @@ macro_rules! error_response { ($status:expr, $code:literal, $msg:literal $(, $msg_args:expr)* $(,)?) => { $crate::middleware::ApiError { status: $status, - code: $code, + code: $code.to_owned(), message: ::std::format!($msg $(, $msg_args)*), } }; diff --git a/blahd/tests/basic.rs b/blahd/tests/basic.rs index 5bae4a6..f7b2712 100644 --- a/blahd/tests/basic.rs +++ b/blahd/tests/basic.rs @@ -2,32 +2,113 @@ #![allow(clippy::unwrap_used)] use std::fmt; use std::future::IntoFuture; -use std::sync::Arc; +use std::sync::{Arc, LazyLock}; -use blahd::{AppState, Database, RoomList}; -use reqwest::Client; +use anyhow::Result; +use blah::types::{ + get_timestamp, AuthPayload, CreateRoomPayload, Id, MemberPermission, RoomAttrs, RoomMember, + RoomMemberList, ServerPermission, UserKey, WithSig, +}; +use blahd::{ApiError, AppState, Database, RoomList, RoomMetadata}; +use ed25519_dalek::SigningKey; +use rand::RngCore; +use reqwest::{header, Method, StatusCode}; use rstest::{fixture, rstest}; -use rusqlite::Connection; +use rusqlite::{params, Connection}; +use serde::de::DeserializeOwned; +use serde::{Deserialize, Serialize}; use tokio::net::TcpListener; // Avoid name resolution. const LOCALHOST: &str = "127.0.0.1"; +static ALICE_PRIV: LazyLock = LazyLock::new(|| SigningKey::from_bytes(&[b'A'; 32])); +static ALICE: LazyLock = LazyLock::new(|| UserKey(ALICE_PRIV.verifying_key().to_bytes())); +static BOB_PRIV: LazyLock = LazyLock::new(|| SigningKey::from_bytes(&[b'B'; 32])); +// static BOB: LazyLock = LazyLock::new(|| UserKey(BOB_PRIV.verifying_key().to_bytes())); + +fn mock_rng() -> impl RngCore { + rand::rngs::mock::StepRng::new(9, 1) +} + +trait ResultExt { + fn expect_api_err(self, status: StatusCode, code: &str); +} + +impl ResultExt for Result { + #[track_caller] + fn expect_api_err(self, status: StatusCode, code: &str) { + let err = self.unwrap_err().downcast::().unwrap(); + assert_eq!(err.status, status); + assert_eq!(err.code, code); + } +} + #[derive(Debug)] struct Server { port: u16, + client: reqwest::Client, } impl Server { fn url(&self, rhs: impl fmt::Display) -> String { format!("http://{}:{}{}", LOCALHOST, self.port, rhs) } + + async fn request( + &self, + method: Method, + url: impl fmt::Display, + auth: Option<&str>, + body: Option, + ) -> Result> { + let mut b = self.client.request(method, self.url(url)); + if let Some(auth) = auth { + b = b.header(header::AUTHORIZATION, auth); + } + if let Some(body) = &body { + b = b.json(body); + } + let resp = b.send().await?; + let status = resp.status(); + let resp_str = resp.text().await?; + + if !status.is_success() { + #[derive(Deserialize)] + struct Resp { + error: ApiError, + } + let Resp { mut error } = serde_json::from_str(&resp_str)?; + error.status = status; + Err(error.into()) + } else if resp_str.is_empty() { + Ok(None) + } else { + Ok(Some(serde_json::from_str(&resp_str)?)) + } + } + + async fn get( + &self, + url: impl fmt::Display, + auth: Option<&str>, + ) -> Result { + Ok(self + .request(Method::GET, url, auth, None::<()>) + .await? + .unwrap()) + } } #[fixture] fn server() -> Server { let mut conn = Connection::open_in_memory().unwrap(); Database::maybe_init(&mut conn).unwrap(); + conn.execute( + "INSERT INTO `user` (`userkey`, `permission`) VALUES (?, ?)", + params![*ALICE, ServerPermission::ALL], + ) + .unwrap(); let db = Database::from_raw(conn).unwrap(); // Use std's to avoid async, since we need no name resolution. @@ -42,28 +123,136 @@ fn server() -> Server { let router = blahd::router(Arc::new(st)); tokio::spawn(axum::serve(listener, router).into_future()); - Server { port } -} - -#[fixture] -fn client() -> Client { - Client::new() + let client = reqwest::ClientBuilder::new().no_proxy().build().unwrap(); + Server { port, client } } #[rstest] #[tokio::test] -async fn smoke(client: Client, server: Server) { - let got = client - .get(server.url("/room?filter=public")) - .send() - .await - .unwrap() - .json::() - .await - .unwrap(); +async fn smoke(server: Server) { + let got: RoomList = server.get("/room?filter=public", None).await.unwrap(); let exp = RoomList { rooms: Vec::new(), skip_token: None, }; assert_eq!(got, exp); } + +fn sign(key: &SigningKey, rng: &mut impl RngCore, payload: T) -> WithSig { + WithSig::sign(key, get_timestamp(), rng, payload).unwrap() +} + +fn auth(key: &SigningKey, rng: &mut impl RngCore) -> String { + serde_json::to_string(&sign(key, rng, AuthPayload {})).unwrap() +} + +async fn create_room( + server: &Server, + key: &SigningKey, + rng: &mut impl RngCore, + attrs: RoomAttrs, + title: impl fmt::Display, +) -> Result { + let req = sign( + key, + rng, + CreateRoomPayload { + attrs, + members: RoomMemberList(vec![RoomMember { + permission: MemberPermission::ALL, + user: UserKey(key.verifying_key().to_bytes()), + }]), + title: title.to_string(), + }, + ); + Ok(server + .request(Method::POST, "/room/create", None, Some(&req)) + .await? + .unwrap()) +} + +#[rstest] +#[case::public(true)] +#[case::private(false)] +#[tokio::test] +async fn room_create_get(server: Server, #[case] public: bool) { + let rng = &mut mock_rng(); + let mut room_meta = RoomMetadata { + rid: Id(0), + title: "test room".into(), + attrs: if public { + RoomAttrs::PUBLIC_READABLE | RoomAttrs::PUBLIC_JOINABLE + } else { + RoomAttrs::empty() + }, + last_chat: None, + last_seen_cid: None, + unseen_cnt: None, + }; + + // Alice has permission. + let rid = create_room(&server, &ALICE_PRIV, rng, room_meta.attrs, &room_meta.title) + .await + .unwrap(); + room_meta.rid = rid; + + // Bob has no permission. + create_room( + &server, + &BOB_PRIV, + rng, + room_meta.attrs, + room_meta.title.clone(), + ) + .await + .expect_api_err(StatusCode::FORBIDDEN, "permission_denied"); + + // Alice can always access it. + let got_meta = server + .get::(format!("/room/{rid}"), Some(&auth(&ALICE_PRIV, rng))) + .await + .unwrap(); + assert_eq!(got_meta, room_meta); + + // Bob or public can access it when it is public. + for auth in [None, Some(auth(&BOB_PRIV, rng))] { + let resp = server + .get::(format!("/room/{rid}"), auth.as_deref()) + .await; + if public { + assert_eq!(resp.unwrap(), room_meta); + } else { + resp.expect_api_err(StatusCode::NOT_FOUND, "not_found"); + } + } + + // The room appears in public list only when it is public. + let expect_list = |has: bool| RoomList { + rooms: has.then(|| room_meta.clone()).into_iter().collect(), + skip_token: None, + }; + assert_eq!( + server + .get::("/room?filter=public", None) + .await + .unwrap(), + expect_list(public), + ); + + // Joined rooms endpoint always require authentication. + server + .get::("/room?filter=joined", None) + .await + .expect_api_err(StatusCode::UNAUTHORIZED, "unauthorized"); + let got_joined = server + .get::("/room?filter=joined", Some(&auth(&ALICE_PRIV, rng))) + .await + .unwrap(); + assert_eq!(got_joined, expect_list(true)); + + let got_joined = server + .get::("/room?filter=joined", Some(&auth(&BOB_PRIV, rng))) + .await + .unwrap(); + assert_eq!(got_joined, expect_list(false)); +}