mirror of
https://github.com/Blah-IM/blahrs.git
synced 2025-05-01 08:41:09 +00:00
Add tests for room create/read and fix incorrect foreign keys
This commit is contained in:
parent
c5263c607c
commit
ff5b7e60a7
6 changed files with 224 additions and 28 deletions
1
Cargo.lock
generated
1
Cargo.lock
generated
|
@ -307,6 +307,7 @@ dependencies = [
|
||||||
"hex",
|
"hex",
|
||||||
"humantime",
|
"humantime",
|
||||||
"parking_lot",
|
"parking_lot",
|
||||||
|
"rand",
|
||||||
"reqwest",
|
"reqwest",
|
||||||
"rstest",
|
"rstest",
|
||||||
"rusqlite",
|
"rusqlite",
|
||||||
|
|
|
@ -30,6 +30,7 @@ url = { version = "2.5.2", features = ["serde"] }
|
||||||
blah = { path = "..", features = ["rusqlite"] }
|
blah = { path = "..", features = ["rusqlite"] }
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
|
rand = "0.8.5"
|
||||||
reqwest = { version = "0.12.7", features = ["json"] }
|
reqwest = { version = "0.12.7", features = ["json"] }
|
||||||
rstest = { version = "0.22.0", default-features = false }
|
rstest = { version = "0.22.0", default-features = false }
|
||||||
|
|
||||||
|
|
|
@ -17,8 +17,8 @@ CREATE TABLE IF NOT EXISTS `room_member` (
|
||||||
`rid` INTEGER NOT NULL REFERENCES `room` ON DELETE CASCADE,
|
`rid` INTEGER NOT NULL REFERENCES `room` ON DELETE CASCADE,
|
||||||
`uid` INTEGER NOT NULL REFERENCES `user` ON DELETE RESTRICT,
|
`uid` INTEGER NOT NULL REFERENCES `user` ON DELETE RESTRICT,
|
||||||
`permission` INTEGER NOT NULL,
|
`permission` INTEGER NOT NULL,
|
||||||
`last_seen_cid` INTEGER NOT NULL REFERENCES `room_item` (`cid`) ON DELETE NO ACTION
|
-- Optionally references `room_item`(`cid`).
|
||||||
DEFAULT 0,
|
`last_seen_cid` INTEGER NOT NULL DEFAULT 0,
|
||||||
PRIMARY KEY (`rid`, `uid`)
|
PRIMARY KEY (`rid`, `uid`)
|
||||||
) STRICT;
|
) STRICT;
|
||||||
|
|
||||||
|
|
|
@ -17,7 +17,7 @@ use blah::types::{
|
||||||
use config::ServerConfig;
|
use config::ServerConfig;
|
||||||
use ed25519_dalek::SIGNATURE_LENGTH;
|
use ed25519_dalek::SIGNATURE_LENGTH;
|
||||||
use id::IdExt;
|
use id::IdExt;
|
||||||
use middleware::{ApiError, Auth, MaybeAuth, ResultExt as _, SignedJson};
|
use middleware::{Auth, MaybeAuth, ResultExt as _, SignedJson};
|
||||||
use parking_lot::Mutex;
|
use parking_lot::Mutex;
|
||||||
use rusqlite::{named_params, params, Connection, OptionalExtension, Row, ToSql};
|
use rusqlite::{named_params, params, Connection, OptionalExtension, Row, ToSql};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
@ -33,6 +33,7 @@ mod id;
|
||||||
mod utils;
|
mod utils;
|
||||||
|
|
||||||
pub use database::Database;
|
pub use database::Database;
|
||||||
|
pub use middleware::ApiError;
|
||||||
|
|
||||||
// Locks must be grabbed in the field order.
|
// Locks must be grabbed in the field order.
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
|
|
|
@ -8,24 +8,28 @@ use axum::response::{IntoResponse, Response};
|
||||||
use axum::{async_trait, Json};
|
use axum::{async_trait, Json};
|
||||||
use blah::types::{AuthPayload, UserKey, WithSig};
|
use blah::types::{AuthPayload, UserKey, WithSig};
|
||||||
use serde::de::DeserializeOwned;
|
use serde::de::DeserializeOwned;
|
||||||
use serde::Serialize;
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::AppState;
|
use crate::AppState;
|
||||||
|
|
||||||
/// Error response body for json endpoints.
|
/// Error response body for json endpoints.
|
||||||
///
|
///
|
||||||
/// Mostly following: <https://learn.microsoft.com/en-us/graph/errors>
|
/// Mostly following: <https://learn.microsoft.com/en-us/graph/errors>
|
||||||
#[derive(Debug, Serialize)]
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
pub struct ApiError {
|
pub struct ApiError {
|
||||||
#[serde(skip)]
|
#[serde(skip, default)]
|
||||||
pub status: StatusCode,
|
pub status: StatusCode,
|
||||||
pub code: &'static str,
|
pub code: String,
|
||||||
pub message: String,
|
pub message: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl fmt::Display for ApiError {
|
impl fmt::Display for ApiError {
|
||||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
f.write_str(&self.message)
|
write!(
|
||||||
|
f,
|
||||||
|
"api error status={} code={}: {}",
|
||||||
|
self.status, self.code, self.message,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -35,7 +39,7 @@ macro_rules! error_response {
|
||||||
($status:expr, $code:literal, $msg:literal $(, $msg_args:expr)* $(,)?) => {
|
($status:expr, $code:literal, $msg:literal $(, $msg_args:expr)* $(,)?) => {
|
||||||
$crate::middleware::ApiError {
|
$crate::middleware::ApiError {
|
||||||
status: $status,
|
status: $status,
|
||||||
code: $code,
|
code: $code.to_owned(),
|
||||||
message: ::std::format!($msg $(, $msg_args)*),
|
message: ::std::format!($msg $(, $msg_args)*),
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
|
@ -2,32 +2,113 @@
|
||||||
#![allow(clippy::unwrap_used)]
|
#![allow(clippy::unwrap_used)]
|
||||||
use std::fmt;
|
use std::fmt;
|
||||||
use std::future::IntoFuture;
|
use std::future::IntoFuture;
|
||||||
use std::sync::Arc;
|
use std::sync::{Arc, LazyLock};
|
||||||
|
|
||||||
use blahd::{AppState, Database, RoomList};
|
use anyhow::Result;
|
||||||
use reqwest::Client;
|
use blah::types::{
|
||||||
|
get_timestamp, AuthPayload, CreateRoomPayload, Id, MemberPermission, RoomAttrs, RoomMember,
|
||||||
|
RoomMemberList, ServerPermission, UserKey, WithSig,
|
||||||
|
};
|
||||||
|
use blahd::{ApiError, AppState, Database, RoomList, RoomMetadata};
|
||||||
|
use ed25519_dalek::SigningKey;
|
||||||
|
use rand::RngCore;
|
||||||
|
use reqwest::{header, Method, StatusCode};
|
||||||
use rstest::{fixture, rstest};
|
use rstest::{fixture, rstest};
|
||||||
use rusqlite::Connection;
|
use rusqlite::{params, Connection};
|
||||||
|
use serde::de::DeserializeOwned;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
use tokio::net::TcpListener;
|
use tokio::net::TcpListener;
|
||||||
|
|
||||||
// Avoid name resolution.
|
// Avoid name resolution.
|
||||||
const LOCALHOST: &str = "127.0.0.1";
|
const LOCALHOST: &str = "127.0.0.1";
|
||||||
|
|
||||||
|
static ALICE_PRIV: LazyLock<SigningKey> = LazyLock::new(|| SigningKey::from_bytes(&[b'A'; 32]));
|
||||||
|
static ALICE: LazyLock<UserKey> = LazyLock::new(|| UserKey(ALICE_PRIV.verifying_key().to_bytes()));
|
||||||
|
static BOB_PRIV: LazyLock<SigningKey> = LazyLock::new(|| SigningKey::from_bytes(&[b'B'; 32]));
|
||||||
|
// static BOB: LazyLock<UserKey> = LazyLock::new(|| UserKey(BOB_PRIV.verifying_key().to_bytes()));
|
||||||
|
|
||||||
|
fn mock_rng() -> impl RngCore {
|
||||||
|
rand::rngs::mock::StepRng::new(9, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
trait ResultExt {
|
||||||
|
fn expect_api_err(self, status: StatusCode, code: &str);
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: fmt::Debug> ResultExt for Result<T> {
|
||||||
|
#[track_caller]
|
||||||
|
fn expect_api_err(self, status: StatusCode, code: &str) {
|
||||||
|
let err = self.unwrap_err().downcast::<ApiError>().unwrap();
|
||||||
|
assert_eq!(err.status, status);
|
||||||
|
assert_eq!(err.code, code);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
struct Server {
|
struct Server {
|
||||||
port: u16,
|
port: u16,
|
||||||
|
client: reqwest::Client,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Server {
|
impl Server {
|
||||||
fn url(&self, rhs: impl fmt::Display) -> String {
|
fn url(&self, rhs: impl fmt::Display) -> String {
|
||||||
format!("http://{}:{}{}", LOCALHOST, self.port, rhs)
|
format!("http://{}:{}{}", LOCALHOST, self.port, rhs)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn request<Req: Serialize, Resp: DeserializeOwned>(
|
||||||
|
&self,
|
||||||
|
method: Method,
|
||||||
|
url: impl fmt::Display,
|
||||||
|
auth: Option<&str>,
|
||||||
|
body: Option<Req>,
|
||||||
|
) -> Result<Option<Resp>> {
|
||||||
|
let mut b = self.client.request(method, self.url(url));
|
||||||
|
if let Some(auth) = auth {
|
||||||
|
b = b.header(header::AUTHORIZATION, auth);
|
||||||
|
}
|
||||||
|
if let Some(body) = &body {
|
||||||
|
b = b.json(body);
|
||||||
|
}
|
||||||
|
let resp = b.send().await?;
|
||||||
|
let status = resp.status();
|
||||||
|
let resp_str = resp.text().await?;
|
||||||
|
|
||||||
|
if !status.is_success() {
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
struct Resp {
|
||||||
|
error: ApiError,
|
||||||
|
}
|
||||||
|
let Resp { mut error } = serde_json::from_str(&resp_str)?;
|
||||||
|
error.status = status;
|
||||||
|
Err(error.into())
|
||||||
|
} else if resp_str.is_empty() {
|
||||||
|
Ok(None)
|
||||||
|
} else {
|
||||||
|
Ok(Some(serde_json::from_str(&resp_str)?))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn get<Resp: DeserializeOwned>(
|
||||||
|
&self,
|
||||||
|
url: impl fmt::Display,
|
||||||
|
auth: Option<&str>,
|
||||||
|
) -> Result<Resp> {
|
||||||
|
Ok(self
|
||||||
|
.request(Method::GET, url, auth, None::<()>)
|
||||||
|
.await?
|
||||||
|
.unwrap())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[fixture]
|
#[fixture]
|
||||||
fn server() -> Server {
|
fn server() -> Server {
|
||||||
let mut conn = Connection::open_in_memory().unwrap();
|
let mut conn = Connection::open_in_memory().unwrap();
|
||||||
Database::maybe_init(&mut conn).unwrap();
|
Database::maybe_init(&mut conn).unwrap();
|
||||||
|
conn.execute(
|
||||||
|
"INSERT INTO `user` (`userkey`, `permission`) VALUES (?, ?)",
|
||||||
|
params![*ALICE, ServerPermission::ALL],
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
let db = Database::from_raw(conn).unwrap();
|
let db = Database::from_raw(conn).unwrap();
|
||||||
|
|
||||||
// Use std's to avoid async, since we need no name resolution.
|
// Use std's to avoid async, since we need no name resolution.
|
||||||
|
@ -42,28 +123,136 @@ fn server() -> Server {
|
||||||
let router = blahd::router(Arc::new(st));
|
let router = blahd::router(Arc::new(st));
|
||||||
|
|
||||||
tokio::spawn(axum::serve(listener, router).into_future());
|
tokio::spawn(axum::serve(listener, router).into_future());
|
||||||
Server { port }
|
let client = reqwest::ClientBuilder::new().no_proxy().build().unwrap();
|
||||||
}
|
Server { port, client }
|
||||||
|
|
||||||
#[fixture]
|
|
||||||
fn client() -> Client {
|
|
||||||
Client::new()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[rstest]
|
#[rstest]
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn smoke(client: Client, server: Server) {
|
async fn smoke(server: Server) {
|
||||||
let got = client
|
let got: RoomList = server.get("/room?filter=public", None).await.unwrap();
|
||||||
.get(server.url("/room?filter=public"))
|
|
||||||
.send()
|
|
||||||
.await
|
|
||||||
.unwrap()
|
|
||||||
.json::<RoomList>()
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
let exp = RoomList {
|
let exp = RoomList {
|
||||||
rooms: Vec::new(),
|
rooms: Vec::new(),
|
||||||
skip_token: None,
|
skip_token: None,
|
||||||
};
|
};
|
||||||
assert_eq!(got, exp);
|
assert_eq!(got, exp);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn sign<T: Serialize>(key: &SigningKey, rng: &mut impl RngCore, payload: T) -> WithSig<T> {
|
||||||
|
WithSig::sign(key, get_timestamp(), rng, payload).unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn auth(key: &SigningKey, rng: &mut impl RngCore) -> String {
|
||||||
|
serde_json::to_string(&sign(key, rng, AuthPayload {})).unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn create_room(
|
||||||
|
server: &Server,
|
||||||
|
key: &SigningKey,
|
||||||
|
rng: &mut impl RngCore,
|
||||||
|
attrs: RoomAttrs,
|
||||||
|
title: impl fmt::Display,
|
||||||
|
) -> Result<Id> {
|
||||||
|
let req = sign(
|
||||||
|
key,
|
||||||
|
rng,
|
||||||
|
CreateRoomPayload {
|
||||||
|
attrs,
|
||||||
|
members: RoomMemberList(vec![RoomMember {
|
||||||
|
permission: MemberPermission::ALL,
|
||||||
|
user: UserKey(key.verifying_key().to_bytes()),
|
||||||
|
}]),
|
||||||
|
title: title.to_string(),
|
||||||
|
},
|
||||||
|
);
|
||||||
|
Ok(server
|
||||||
|
.request(Method::POST, "/room/create", None, Some(&req))
|
||||||
|
.await?
|
||||||
|
.unwrap())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[rstest]
|
||||||
|
#[case::public(true)]
|
||||||
|
#[case::private(false)]
|
||||||
|
#[tokio::test]
|
||||||
|
async fn room_create_get(server: Server, #[case] public: bool) {
|
||||||
|
let rng = &mut mock_rng();
|
||||||
|
let mut room_meta = RoomMetadata {
|
||||||
|
rid: Id(0),
|
||||||
|
title: "test room".into(),
|
||||||
|
attrs: if public {
|
||||||
|
RoomAttrs::PUBLIC_READABLE | RoomAttrs::PUBLIC_JOINABLE
|
||||||
|
} else {
|
||||||
|
RoomAttrs::empty()
|
||||||
|
},
|
||||||
|
last_chat: None,
|
||||||
|
last_seen_cid: None,
|
||||||
|
unseen_cnt: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
// Alice has permission.
|
||||||
|
let rid = create_room(&server, &ALICE_PRIV, rng, room_meta.attrs, &room_meta.title)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
room_meta.rid = rid;
|
||||||
|
|
||||||
|
// Bob has no permission.
|
||||||
|
create_room(
|
||||||
|
&server,
|
||||||
|
&BOB_PRIV,
|
||||||
|
rng,
|
||||||
|
room_meta.attrs,
|
||||||
|
room_meta.title.clone(),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.expect_api_err(StatusCode::FORBIDDEN, "permission_denied");
|
||||||
|
|
||||||
|
// Alice can always access it.
|
||||||
|
let got_meta = server
|
||||||
|
.get::<RoomMetadata>(format!("/room/{rid}"), Some(&auth(&ALICE_PRIV, rng)))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(got_meta, room_meta);
|
||||||
|
|
||||||
|
// Bob or public can access it when it is public.
|
||||||
|
for auth in [None, Some(auth(&BOB_PRIV, rng))] {
|
||||||
|
let resp = server
|
||||||
|
.get::<RoomMetadata>(format!("/room/{rid}"), auth.as_deref())
|
||||||
|
.await;
|
||||||
|
if public {
|
||||||
|
assert_eq!(resp.unwrap(), room_meta);
|
||||||
|
} else {
|
||||||
|
resp.expect_api_err(StatusCode::NOT_FOUND, "not_found");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// The room appears in public list only when it is public.
|
||||||
|
let expect_list = |has: bool| RoomList {
|
||||||
|
rooms: has.then(|| room_meta.clone()).into_iter().collect(),
|
||||||
|
skip_token: None,
|
||||||
|
};
|
||||||
|
assert_eq!(
|
||||||
|
server
|
||||||
|
.get::<RoomList>("/room?filter=public", None)
|
||||||
|
.await
|
||||||
|
.unwrap(),
|
||||||
|
expect_list(public),
|
||||||
|
);
|
||||||
|
|
||||||
|
// Joined rooms endpoint always require authentication.
|
||||||
|
server
|
||||||
|
.get::<RoomList>("/room?filter=joined", None)
|
||||||
|
.await
|
||||||
|
.expect_api_err(StatusCode::UNAUTHORIZED, "unauthorized");
|
||||||
|
let got_joined = server
|
||||||
|
.get::<RoomList>("/room?filter=joined", Some(&auth(&ALICE_PRIV, rng)))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(got_joined, expect_list(true));
|
||||||
|
|
||||||
|
let got_joined = server
|
||||||
|
.get::<RoomList>("/room?filter=joined", Some(&auth(&BOB_PRIV, rng)))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(got_joined, expect_list(false));
|
||||||
|
}
|
||||||
|
|
Loading…
Add table
Reference in a new issue