refactor(webapi,types)!: make challenge type extensive

We may allow more challenge types other than PoW in the future, eg.
captcha. So make the relevent types more generic.

Now the challenge is returned in JSON response as a individual top-level
field `register_challenge` instead of in HTTP headers.
This commit is contained in:
oxalica 2024-10-01 05:26:00 -04:00
parent 364e517b7d
commit bc6e6c2056
11 changed files with 206 additions and 130 deletions

View file

@ -8,14 +8,14 @@ use std::sync::{Arc, LazyLock};
use std::time::{Duration, Instant};
use anyhow::{Context, Result};
use axum::http::HeaderMap;
use blah_types::identity::{IdUrl, UserActKeyDesc, UserIdentityDesc, UserProfile};
use blah_types::msg::{
AuthPayload, ChatPayload, CreateGroup, CreatePeerChat, CreateRoomPayload, DeleteRoomPayload,
MemberPermission, RichText, RoomAdminOp, RoomAdminPayload, RoomAttrs, ServerPermission,
SignedChatMsg, SignedChatMsgWithId, UserRegisterPayload, WithMsgId,
SignedChatMsg, SignedChatMsgWithId, UserRegisterChallengeResponse, UserRegisterPayload,
WithMsgId,
};
use blah_types::server::{RoomMetadata, ServerMetadata, X_BLAH_DIFFICULTY, X_BLAH_NONCE};
use blah_types::server::{RoomMetadata, ServerMetadata, UserRegisterChallenge};
use blah_types::{Id, SignExt, Signed, UserKey};
use blahd::{AppState, Database, RoomList, RoomMsgs};
use ed25519_dalek::SigningKey;
@ -52,11 +52,14 @@ max_page_len = 2
[register]
enable_public = true
difficulty = {REGISTER_DIFFICULTY}
request_timeout_secs = 1
unsafe_allow_id_url_http = true
unsafe_allow_id_url_custom_port = true
unsafe_allow_id_url_single_label = true
[register.challenge.pow]
difficulty = {REGISTER_DIFFICULTY}
nonce_rotate_secs = 60
"#
)
};
@ -98,7 +101,7 @@ trait ResultExt {
impl<T: fmt::Debug> ResultExt for Result<T> {
#[track_caller]
fn expect_api_err(self, status: StatusCode, code: &str) {
let err = self.unwrap_err().downcast::<ApiErrorWithHeaders>().unwrap();
let err = self.unwrap_err().downcast::<ApiError>().unwrap();
assert_eq!(
(err.status, &*err.code),
(status, code),
@ -108,7 +111,7 @@ impl<T: fmt::Debug> ResultExt for Result<T> {
#[track_caller]
fn expect_invalid_request(self, message: &str) {
let err = self.unwrap_err().downcast::<ApiErrorWithHeaders>().unwrap();
let err = self.unwrap_err().downcast::<ApiError>().unwrap();
assert_eq!(
(err.status, &*err.code, &*err.message),
(StatusCode::BAD_REQUEST, "invalid_request", message),
@ -118,14 +121,14 @@ impl<T: fmt::Debug> ResultExt for Result<T> {
}
#[derive(Debug)]
pub struct ApiErrorWithHeaders {
pub struct ApiError {
status: StatusCode,
code: String,
message: String,
headers: HeaderMap,
register_challenge: Option<UserRegisterChallenge>,
}
impl fmt::Display for ApiErrorWithHeaders {
impl fmt::Display for ApiError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
f,
@ -135,7 +138,7 @@ impl fmt::Display for ApiErrorWithHeaders {
}
}
impl std::error::Error for ApiErrorWithHeaders {}
impl std::error::Error for ApiError {}
// TODO: Hoist this into types crate.
#[derive(Debug, Clone, PartialEq, Eq, Deserialize)]
@ -201,13 +204,14 @@ impl Server {
async move {
let resp = b.send().await?;
let status = resp.status();
let headers = resp.headers().clone();
let resp_str = resp.text().await?;
if !status.is_success() {
#[derive(Deserialize)]
struct Resp {
error: RespErr,
#[serde(default)]
register_challenge: Option<UserRegisterChallenge>,
}
#[derive(Deserialize)]
struct RespErr {
@ -217,11 +221,11 @@ impl Server {
let resp = serde_json::from_str::<Resp>(&resp_str)
.with_context(|| format!("failed to parse response {resp_str:?}"))?;
Err(ApiErrorWithHeaders {
Err(ApiError {
status,
code: resp.error.code,
message: resp.error.message,
headers,
register_challenge: resp.register_challenge,
}
.into())
} else if resp_str.is_empty() {
@ -364,22 +368,15 @@ impl Server {
{
Ok(None) => Ok(()),
Err(err) => {
let err = err.downcast::<ApiErrorWithHeaders>().unwrap();
let err = err.downcast::<ApiError>().unwrap();
assert_eq!(err.status, StatusCode::NOT_FOUND);
if !err.headers.contains_key(X_BLAH_NONCE) {
return Err(None);
}
let challenge_nonce = err.headers[X_BLAH_NONCE]
.to_str()
.unwrap()
.parse::<u32>()
.unwrap();
let difficulty = err.headers[X_BLAH_DIFFICULTY]
.to_str()
.unwrap()
.parse::<u8>()
.unwrap();
Err(Some((challenge_nonce, difficulty)))
Err(match err.register_challenge {
Some(UserRegisterChallenge::Pow { nonce, difficulty }) => {
Some((nonce, difficulty))
}
Some(UserRegisterChallenge::Unknown) => unreachable!(),
None => None,
})
}
}
}
@ -1197,7 +1194,9 @@ async fn register_flow(server: Server) {
// Invalid values.
server_url: "http://localhost".parse().unwrap(),
id_url: "http://com.".parse().unwrap(),
challenge_nonce: challenge_nonce - 1,
challenge: Some(UserRegisterChallengeResponse::Pow {
nonce: challenge_nonce - 1,
}),
};
let register = |req: Signed<UserRegisterPayload>| {
server
@ -1262,7 +1261,9 @@ async fn register_flow(server: Server) {
register_fast(&req)
.await
.expect_invalid_request("invalid challenge nonce");
req.challenge_nonce += 1;
req.challenge = Some(UserRegisterChallengeResponse::Pow {
nonce: challenge_nonce,
});
register(sign_with_difficulty(&req, false))
.await
@ -1408,7 +1409,7 @@ unsafe_allow_id_url_single_label = {allow_single_label}
// Unused values.
id_url: server_url.parse().unwrap(),
server_url: server_url.parse().unwrap(),
challenge_nonce: 0,
challenge: None,
},
);
let ret = server
@ -1433,9 +1434,10 @@ async fn register_nonce() {
base_url="{BASE_URL}"
[register]
enable_public = true
unsafe_allow_id_url_http = true
[register.challenge.pow]
difficulty = 64 # Should fail the challenge if nonce matches.
nonce_rotate_secs = 10
unsafe_allow_id_url_http = true
"#
)
};
@ -1455,7 +1457,7 @@ unsafe_allow_id_url_http = true
id_key: CAROL.pubkeys.id_key.clone(),
server_url: BASE_URL.parse().unwrap(),
id_url: BASE_URL.parse().unwrap(),
challenge_nonce: nonce,
challenge: Some(UserRegisterChallengeResponse::Pow { nonce }),
}
.sign_msg(&CAROL.pubkeys.id_key, &CAROL.act_priv)
.unwrap();