test: add more tests for register verification

This commit is contained in:
oxalica 2024-09-22 10:54:34 -04:00
parent 7ab1d4a935
commit 2fe8dfdab7
3 changed files with 260 additions and 45 deletions

View file

@ -36,7 +36,7 @@ mod id;
mod register;
mod utils;
pub use database::Database;
pub use database::{Config as DatabaseConfig, Database};
pub use middleware::ApiError;
#[serde_inline_default]
@ -202,7 +202,7 @@ async fn user_get(
async fn user_register(
State(st): ArcState,
SignedJson(msg): SignedJson<UserRegisterPayload>,
) -> impl IntoResponse {
) -> Result<StatusCode, ApiError> {
register::user_register(&st, msg).await
}

View file

@ -7,7 +7,7 @@ use std::future::{Future, IntoFuture};
use std::sync::{Arc, LazyLock};
use std::time::{Duration, Instant};
use anyhow::Result;
use anyhow::{Context, Result};
use axum::http::HeaderMap;
use blah_types::identity::{IdUrl, UserActKeyDesc, UserIdentityDesc, UserProfile};
use blah_types::{
@ -190,7 +190,8 @@ impl Server {
struct Resp {
error: ApiError,
}
let Resp { mut error } = serde_json::from_str(&resp_str)?;
let Resp { mut error } = serde_json::from_str(&resp_str)
.with_context(|| format!("failed to parse response {resp_str:?}"))?;
error.status = status;
Err(ApiErrorWithHeaders { error, headers }.into())
} else if resp_str.is_empty() {
@ -320,6 +321,34 @@ impl Server {
Ok(WithMsgId { cid, msg })
}
}
async fn get_me(&self, auth_user: Option<&User>) -> Result<(), Option<(u32, u8)>> {
let auth = auth_user.map(auth);
match self
.request::<(), NoContent>(Method::GET, "/user/me", auth.as_deref(), None)
.await
{
Ok(None) => Ok(()),
Err(err) => {
let err = err.downcast::<ApiErrorWithHeaders>().unwrap();
assert_eq!(err.error.status, StatusCode::NOT_FOUND);
if !err.headers.contains_key(X_BLAH_NONCE) {
return Err(None);
}
let challenge_nonce = err.headers[X_BLAH_NONCE]
.to_str()
.unwrap()
.parse::<u32>()
.unwrap();
let difficulty = err.headers[X_BLAH_DIFFICULTY]
.to_str()
.unwrap()
.parse::<u8>()
.unwrap();
Err(Some((challenge_nonce, difficulty)))
}
}
}
}
#[fixture]
@ -359,15 +388,19 @@ fn server() -> Server {
}
}
let db = Database::from_raw(conn).unwrap();
server_with(db, &CONFIG)
}
// TODO: Testing config is hard to build because it does have a `Default` impl.
#[track_caller]
fn server_with(db: Database, config: &dyn Fn(u16) -> String) -> Server {
// Use std's to avoid async, since we need no name resolution.
let listener = std::net::TcpListener::bind(format!("{LOCALHOST}:0")).unwrap();
listener.set_nonblocking(true).unwrap();
let port = listener.local_addr().unwrap().port();
let listener = TcpListener::from_std(listener).unwrap();
// TODO: Testing config is hard to build because it does have a `Default` impl.
let config = toml::from_str(&CONFIG(port)).unwrap();
let config = toml::from_str(&config(port)).unwrap();
let st = AppState::new(db, config);
let router = blahd::router(Arc::new(st));
@ -909,7 +942,7 @@ async fn delete_room(server: Server, #[case] peer_chat: bool, #[case] alice_dele
#[rstest]
#[tokio::test]
async fn register(server: Server) {
async fn register_flow(server: Server) {
let rid = server
.create_room(
&ALICE,
@ -919,44 +952,22 @@ async fn register(server: Server) {
.await
.unwrap();
let get_me = |user: Option<&User>| {
let auth = user.map(auth);
server
.request::<(), ()>(Method::GET, "/user/me", auth.as_deref(), None)
.map_ok(|_| ())
.map_err(|err| {
let err = err.downcast::<ApiErrorWithHeaders>().unwrap();
assert_eq!(err.error.status, StatusCode::NOT_FOUND);
let challenge_nonce = err.headers[X_BLAH_NONCE]
.to_str()
.unwrap()
.parse::<u32>()
.unwrap();
let difficulty = err.headers[X_BLAH_DIFFICULTY]
.to_str()
.unwrap()
.parse::<u8>()
.unwrap();
(challenge_nonce, difficulty)
})
};
// Alice is registered.
get_me(Some(&ALICE)).await.unwrap();
server.get_me(Some(&ALICE)).await.unwrap();
// Carol is not registered.
let (challenge_nonce, diff) = get_me(Some(&CAROL)).await.unwrap_err();
let (challenge_nonce, diff) = server.get_me(Some(&CAROL)).await.unwrap_err().unwrap();
assert_eq!(diff, REGISTER_DIFFICULTY);
// Without token.
let ret2 = get_me(None).await.unwrap_err();
let ret2 = server.get_me(None).await.unwrap_err().unwrap();
assert_eq!(ret2, (challenge_nonce, diff));
let mut req = UserRegisterPayload {
id_key: CAROL.pubkeys.id_key.clone(),
// Invalid values.
server_url: "http://localhost".parse().unwrap(),
id_url: "http://.".parse().unwrap(),
id_url: "http://com.".parse().unwrap(),
challenge_nonce: challenge_nonce - 1,
};
let register = |req: Signed<UserRegisterPayload>| {
@ -1080,6 +1091,7 @@ async fn register(server: Server) {
}
.sign_msg(&CAROL.pubkeys.id_key, &CAROL.id_priv)
.unwrap();
// Incorrect URL, without port.
let profile = sign_profile("https://localhost".parse().unwrap());
UserIdentityDesc {
id_key: CAROL.pubkeys.id_key.clone(),
@ -1095,7 +1107,7 @@ async fn register(server: Server) {
.expect_api_err(StatusCode::UNAUTHORIZED, "invalid_id_description");
// Still not registered.
get_me(Some(&CAROL)).await.unwrap_err();
server.get_me(Some(&CAROL)).await.unwrap_err();
server
.join_room(rid, &CAROL, MemberPermission::MAX_SELF_ADD)
.await
@ -1107,13 +1119,73 @@ async fn register(server: Server) {
register(sign_with_difficulty(&req, true)).await.unwrap();
// Registered now.
get_me(Some(&CAROL)).await.unwrap();
server.get_me(Some(&CAROL)).await.unwrap();
server
.join_room(rid, &CAROL, MemberPermission::MAX_SELF_ADD)
.await
.unwrap();
}
#[rstest]
#[case::disabled(false, true, true, true)]
#[case::no_http(true, false, true, true)]
#[case::no_port(true, true, false, true)]
#[case::no_single_label(true, true, true, false)]
#[tokio::test]
async fn register_config(
#[case] enabled: bool,
#[case] allow_http: bool,
#[case] allow_port: bool,
#[case] allow_single_label: bool,
) {
let config = |port| {
format!(
r#"
base_url="http://{LOCALHOST}:{port}"
[register]
enable_public = {enabled}
unsafe_allow_id_url_http = {allow_http}
unsafe_allow_id_url_custom_port = {allow_port}
unsafe_allow_id_url_single_label = {allow_single_label}
"#
)
};
let db_config = blahd::DatabaseConfig {
in_memory: true,
..Default::default()
};
let server = server_with(Database::open(&db_config).unwrap(), &config);
// Returns challenge headers only if registration is enabled.
let hdrs = server.get_me(Some(&CAROL)).await.unwrap_err();
if enabled {
hdrs.unwrap();
} else {
assert_eq!(hdrs, None);
}
let server_url = format!("http://{}", server.domain());
let req = server.sign(
&CAROL,
UserRegisterPayload {
id_key: CAROL.pubkeys.id_key.clone(),
// Unused values.
id_url: server_url.parse().unwrap(),
server_url: server_url.parse().unwrap(),
challenge_nonce: 0,
},
);
let ret = server
.request::<_, ()>(Method::POST, "/user/me", None, Some(req))
.await;
if !enabled {
ret.expect_api_err(StatusCode::FORBIDDEN, "disabled");
} else {
ret.expect_api_err(StatusCode::BAD_REQUEST, "invalid_id_url");
}
}
#[rstest]
#[tokio::test]
async fn event(server: Server) {