Define error response format and refactor error handling

This commit is contained in:
oxalica 2024-08-30 23:54:16 -04:00
parent 4ceffe3f31
commit 4937502d4c
6 changed files with 341 additions and 195 deletions

23
Cargo.lock generated
View file

@ -177,6 +177,28 @@ dependencies = [
"tracing", "tracing",
] ]
[[package]]
name = "axum-extra"
version = "0.9.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0be6ea09c9b96cb5076af0de2e383bd2bc0c18f827cf1967bdd353e0b910d733"
dependencies = [
"axum",
"axum-core",
"bytes",
"futures-util",
"http",
"http-body",
"http-body-util",
"mime",
"pin-project-lite",
"serde",
"tower",
"tower-layer",
"tower-service",
"tracing",
]
[[package]] [[package]]
name = "backtrace" name = "backtrace"
version = "0.3.73" version = "0.3.73"
@ -257,6 +279,7 @@ version = "0.0.0"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"axum", "axum",
"axum-extra",
"blah", "blah",
"clap", "clap",
"ed25519-dalek", "ed25519-dalek",

View file

@ -6,6 +6,7 @@ edition = "2021"
[dependencies] [dependencies]
anyhow = "1" anyhow = "1"
axum = { version = "0.7", features = ["tokio"] } axum = { version = "0.7", features = ["tokio"] }
axum-extra = "0.9"
clap = { version = "4", features = ["derive"] } clap = { version = "4", features = ["derive"] }
ed25519-dalek = "2" ed25519-dalek = "2"
futures-util = "0.3" futures-util = "0.3"

View file

@ -33,6 +33,9 @@ paths:
description: UUID of the newly created room (ruuid). description: UUID of the newly created room (ruuid).
403: 403:
description: The user does not have permission to create room. description: The user does not have permission to create room.
content:
application/json:
$ref: '#/components/schemas/ApiError'
/room/{ruuid}/feed.json: /room/{ruuid}/feed.json:
get: get:
@ -45,6 +48,9 @@ paths:
$ref: 'https://www.jsonfeed.org/version/1.1/' $ref: 'https://www.jsonfeed.org/version/1.1/'
404: 404:
description: Room does not exist or is private. description: Room does not exist or is private.
content:
application/json:
$ref: '#/components/schemas/ApiError'
/room/{ruuid}/item: /room/{ruuid}/item:
get: get:
@ -68,6 +74,12 @@ paths:
content: content:
application/json: application/json:
x-description: TODO x-description: TODO
404:
description: |
Room does not exist or the user does not have permission to read it.
content:
application/json:
$ref: '#/components/schemas/ApiError'
post: post:
summary: Post a chat in room {ruuid} summary: Post a chat in room {ruuid}
@ -94,10 +106,15 @@ paths:
description: Created chat id (cid). description: Created chat id (cid).
400: 400:
description: Body is invalid or fails the verification. description: Body is invalid or fails the verification.
content:
application/json:
$ref: '#/components/schemas/ApiError'
403: 403:
description: The user does not have permission to post in this room. description: |
404: The user does not have permission to post in this room, or the room does not exist.
description: Room not found. content:
application/json:
$ref: '#/components/schemas/ApiError'
/room/{ruuid}/event: /room/{ruuid}/event:
get: get:
@ -118,5 +135,24 @@ paths:
x-description: An event stream, each event is a JSON with type WithSig<ChatPayload> x-description: An event stream, each event is a JSON with type WithSig<ChatPayload>
400: 400:
description: Body is invalid or fails the verification. description: Body is invalid or fails the verification.
content:
application/json:
$ref: '#/components/schemas/ApiError'
404: 404:
description: Room not found. description: Room not found.
content:
application/json:
$ref: '#/components/schemas/ApiError'
components:
schemas:
ApiError:
type: object
properties:
error:
type: object
properties:
code:
type: string
message:
type: string

View file

@ -6,18 +6,19 @@ use std::sync::{Arc, Mutex};
use std::time::{Duration, SystemTime}; use std::time::{Duration, SystemTime};
use anyhow::{ensure, Context, Result}; use anyhow::{ensure, Context, Result};
use axum::extract::{FromRef, FromRequest, FromRequestParts, Path, Query, Request, State}; use axum::extract::{Path, Query, State};
use axum::http::{header, request, StatusCode}; use axum::http::{header, StatusCode};
use axum::response::{sse, IntoResponse, Response}; use axum::response::{sse, IntoResponse};
use axum::routing::{get, post}; use axum::routing::{get, post};
use axum::{async_trait, Json, Router}; use axum::{Json, Router};
use axum_extra::extract::WithRejection;
use blah::types::{ use blah::types::{
AuthPayload, ChatItem, ChatPayload, CreateRoomPayload, MemberPermission, RoomAttrs, ChatItem, ChatPayload, CreateRoomPayload, MemberPermission, RoomAttrs, ServerPermission,
ServerPermission, Signee, UserKey, WithSig, Signee, UserKey, WithSig,
}; };
use ed25519_dalek::SIGNATURE_LENGTH; use ed25519_dalek::SIGNATURE_LENGTH;
use middleware::{ApiError, OptionalAuth, SignedJson};
use rusqlite::{named_params, params, OptionalExtension, Row}; use rusqlite::{named_params, params, OptionalExtension, Row};
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use tokio::sync::broadcast; use tokio::sync::broadcast;
use tokio_stream::StreamExt; use tokio_stream::StreamExt;
@ -29,6 +30,8 @@ const EVENT_QUEUE_LEN: usize = 1024;
const MAX_BODY_LEN: usize = 4 << 10; // 4KiB const MAX_BODY_LEN: usize = 4 << 10; // 4KiB
const TIMESTAMP_TOLERENCE: u64 = 90; const TIMESTAMP_TOLERENCE: u64 = 90;
#[macro_use]
mod middleware;
mod utils; mod utils;
#[derive(Debug, clap::Parser)] #[derive(Debug, clap::Parser)]
@ -93,24 +96,38 @@ impl AppState {
}) })
} }
fn verify_signed_data<T: Serialize>(&self, data: &WithSig<T>) -> Result<()> { fn verify_signed_data<T: Serialize>(&self, data: &WithSig<T>) -> Result<(), ApiError> {
data.verify().context("unsigned payload")?; let Ok(()) = data.verify() else {
return Err(error_response!(
StatusCode::BAD_REQUEST,
"invalid_signature",
"signature verification failed"
));
};
let timestamp_diff = SystemTime::now() let timestamp_diff = SystemTime::now()
.duration_since(SystemTime::UNIX_EPOCH) .duration_since(SystemTime::UNIX_EPOCH)
.expect("after UNIX epoch") .expect("after UNIX epoch")
.as_secs() .as_secs()
.abs_diff(data.signee.timestamp); .abs_diff(data.signee.timestamp);
ensure!( if timestamp_diff > TIMESTAMP_TOLERENCE {
timestamp_diff <= TIMESTAMP_TOLERENCE, return Err(error_response!(
StatusCode::BAD_REQUEST,
"invalid_timestamp",
"invalid timestamp, off by {timestamp_diff}s" "invalid timestamp, off by {timestamp_diff}s"
); ));
ensure!( }
self.used_nonces if !self
.used_nonces
.lock() .lock()
.unwrap() .unwrap()
.try_insert(data.signee.nonce), .try_insert(data.signee.nonce)
{
return Err(error_response!(
StatusCode::BAD_REQUEST,
"duplicated_nonce",
"duplicated nonce", "duplicated nonce",
); ));
}
Ok(()) Ok(())
} }
} }
@ -125,9 +142,9 @@ async fn main_async(opt: Cli, st: AppState) -> Result<()> {
.route("/room/:ruuid/event", get(room_event)) .route("/room/:ruuid/event", get(room_event))
.route("/room/:ruuid/item", get(room_get_item).post(room_post_item)) .route("/room/:ruuid/item", get(room_get_item).post(room_post_item))
.with_state(Arc::new(st)) .with_state(Arc::new(st))
.layer(tower_http::limit::RequestBodyLimitLayer::new(MAX_BODY_LEN))
// NB. This comes at last (outmost layer), so inner errors will still be wraped with // NB. This comes at last (outmost layer), so inner errors will still be wraped with
// correct CORS headers. // correct CORS headers.
.layer(tower_http::limit::RequestBodyLimitLayer::new(MAX_BODY_LEN))
.layer(tower_http::cors::CorsLayer::permissive()); .layer(tower_http::cors::CorsLayer::permissive());
let listener = tokio::net::TcpListener::bind(&opt.listen) let listener = tokio::net::TcpListener::bind(&opt.listen)
@ -143,26 +160,20 @@ async fn main_async(opt: Cli, st: AppState) -> Result<()> {
Ok(()) Ok(())
} }
fn from_db_error(err: rusqlite::Error) -> StatusCode {
match err {
rusqlite::Error::QueryReturnedNoRows => StatusCode::NOT_FOUND,
err => {
tracing::error!(%err, "database error");
StatusCode::INTERNAL_SERVER_ERROR
}
}
}
async fn room_create( async fn room_create(
st: ArcState, st: ArcState,
SignedJson(params): SignedJson<CreateRoomPayload>, SignedJson(params): SignedJson<CreateRoomPayload>,
) -> Result<Json<Uuid>, StatusCode> { ) -> Result<Json<Uuid>, ApiError> {
let members = &params.signee.payload.members.0; let members = &params.signee.payload.members.0;
if !members if !members
.iter() .iter()
.any(|m| m.user == params.signee.user && m.permission == MemberPermission::ALL) .any(|m| m.user == params.signee.user && m.permission == MemberPermission::ALL)
{ {
return Err(StatusCode::BAD_REQUEST); return Err(error_response!(
StatusCode::BAD_REQUEST,
"deserialization",
"invalid initial members",
));
} }
let mut conn = st.conn.lock().unwrap(); let mut conn = st.conn.lock().unwrap();
@ -179,15 +190,17 @@ async fn room_create(
Ok(perm.contains(ServerPermission::CREATE_ROOM)) Ok(perm.contains(ServerPermission::CREATE_ROOM))
}, },
) )
.optional() .optional()?
.map_err(from_db_error)?
else { else {
return Err(StatusCode::FORBIDDEN); return Err(error_response!(
StatusCode::FORBIDDEN,
"permission_denied",
"user does not have permission to create room",
));
}; };
let ruuid = Uuid::new_v4(); let ruuid = Uuid::new_v4();
(|| {
let txn = conn.transaction()?; let txn = conn.transaction()?;
let rid = txn.query_row( let rid = txn.query_row(
r" r"
@ -227,9 +240,6 @@ async fn room_create(
drop(insert_member); drop(insert_member);
drop(insert_user); drop(insert_user);
txn.commit()?; txn.commit()?;
Ok(())
})()
.map_err(from_db_error)?;
Ok(Json(ruuid)) Ok(Json(ruuid))
} }
@ -246,13 +256,12 @@ struct GetRoomItemParams {
async fn room_get_item( async fn room_get_item(
st: ArcState, st: ArcState,
Path(ruuid): Path<Uuid>, WithRejection(Path(ruuid), _): WithRejection<Path<Uuid>, ApiError>,
params: Query<GetRoomItemParams>, WithRejection(params, _): WithRejection<Query<GetRoomItemParams>, ApiError>,
OptionalAuth(user): OptionalAuth, OptionalAuth(user): OptionalAuth,
) -> Result<impl IntoResponse, StatusCode> { ) -> Result<impl IntoResponse, ApiError> {
let (room_meta, items) = let (room_meta, items) =
query_room_items(&st.conn.lock().unwrap(), ruuid, user.as_ref(), &params) query_room_items(&st.conn.lock().unwrap(), ruuid, user.as_ref(), &params)?;
.map_err(from_db_error)?;
// TODO: This format is to-be-decided. Or do we even need this interface other than // TODO: This format is to-be-decided. Or do we even need this interface other than
// `feed.json`? // `feed.json`?
@ -261,11 +270,10 @@ async fn room_get_item(
async fn room_get_feed( async fn room_get_feed(
st: ArcState, st: ArcState,
Path(ruuid): Path<Uuid>, WithRejection(Path(ruuid), _): WithRejection<Path<Uuid>, ApiError>,
params: Query<GetRoomItemParams>, params: Query<GetRoomItemParams>,
) -> Result<impl IntoResponse, StatusCode> { ) -> Result<impl IntoResponse, ApiError> {
let (room_meta, items) = let (room_meta, items) = query_room_items(&st.conn.lock().unwrap(), ruuid, None, &params)?;
query_room_items(&st.conn.lock().unwrap(), ruuid, None, &params).map_err(from_db_error)?;
let items = items let items = items
.into_iter() .into_iter()
@ -352,7 +360,7 @@ fn get_room_if_readable<T>(
ruuid: Uuid, ruuid: Uuid,
user: Option<&UserKey>, user: Option<&UserKey>,
f: impl FnOnce(&Row<'_>) -> rusqlite::Result<T>, f: impl FnOnce(&Row<'_>) -> rusqlite::Result<T>,
) -> rusqlite::Result<T> { ) -> Result<T, ApiError> {
conn.query_row( conn.query_row(
r" r"
SELECT `rid`, `title`, `attrs` SELECT `rid`, `title`, `attrs`
@ -372,6 +380,8 @@ fn get_room_if_readable<T>(
}, },
f, f,
) )
.optional()?
.ok_or_else(|| error_response!(StatusCode::NOT_FOUND, "not_found", "room not found"))
} }
fn query_room_items( fn query_room_items(
@ -379,7 +389,7 @@ fn query_room_items(
ruuid: Uuid, ruuid: Uuid,
user: Option<&UserKey>, user: Option<&UserKey>,
params: &GetRoomItemParams, params: &GetRoomItemParams,
) -> rusqlite::Result<(RoomMetadata, Vec<(u64, ChatItem)>)> { ) -> Result<(RoomMetadata, Vec<(u64, ChatItem)>), ApiError> {
let (rid, title, attrs) = get_room_if_readable(conn, ruuid, user, |row| { let (rid, title, attrs) = get_room_if_readable(conn, ruuid, user, |row| {
Ok(( Ok((
row.get::<_, u64>("rid")?, row.get::<_, u64>("rid")?,
@ -430,76 +440,17 @@ fn query_room_items(
Ok((room_meta, items)) Ok((room_meta, items))
} }
/// Extractor for verified JSON payload.
#[derive(Debug)]
struct SignedJson<T>(WithSig<T>);
#[async_trait]
impl<S, T> FromRequest<S> for SignedJson<T>
where
S: Send + Sync,
T: Serialize + DeserializeOwned,
Arc<AppState>: FromRef<S>,
{
type Rejection = Response;
async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
let Json(data) = <Json<WithSig<T>> as FromRequest<S>>::from_request(req, state)
.await
.map_err(|err| err.into_response())?;
let st = <Arc<AppState>>::from_ref(state);
st.verify_signed_data(&data).map_err(|err| {
tracing::debug!(%err, "unsigned payload");
StatusCode::BAD_REQUEST.into_response()
})?;
Ok(Self(data))
}
}
/// Extractor for optional verified JSON authorization header.
#[derive(Debug)]
struct OptionalAuth(Option<UserKey>);
#[async_trait]
impl<S> FromRequestParts<S> for OptionalAuth
where
S: Send + Sync,
Arc<AppState>: FromRef<S>,
{
type Rejection = StatusCode;
async fn from_request_parts(
parts: &mut request::Parts,
state: &S,
) -> Result<Self, Self::Rejection> {
let Some(auth) = parts.headers.get(header::AUTHORIZATION) else {
return Ok(Self(None));
};
let st = <Arc<AppState>>::from_ref(state);
let ret = serde_json::from_slice::<WithSig<AuthPayload>>(auth.as_bytes())
.context("invalid JSON")
.and_then(|data| {
st.verify_signed_data(&data)?;
Ok(data.signee.user)
});
match ret {
Ok(user) => Ok(Self(Some(user))),
Err(err) => {
tracing::debug!(%err, "invalid authorization");
Err(StatusCode::BAD_REQUEST)
}
}
}
}
async fn room_post_item( async fn room_post_item(
st: ArcState, st: ArcState,
Path(ruuid): Path<Uuid>, Path(ruuid): Path<Uuid>,
SignedJson(chat): SignedJson<ChatPayload>, SignedJson(chat): SignedJson<ChatPayload>,
) -> Result<Json<u64>, StatusCode> { ) -> Result<Json<u64>, ApiError> {
if ruuid != chat.signee.payload.room { if ruuid != chat.signee.payload.room {
return Err(StatusCode::BAD_REQUEST); return Err(error_response!(
StatusCode::BAD_REQUEST,
"invalid_request",
"URI and payload room id mismatch",
));
} }
let (rid, cid) = { let (rid, cid) = {
@ -522,15 +473,16 @@ async fn room_post_item(
}, },
|row| Ok((row.get::<_, u64>("rid")?, row.get::<_, u64>("uid")?)), |row| Ok((row.get::<_, u64>("rid")?, row.get::<_, u64>("uid")?)),
) )
.optional() .optional()?
.map_err(from_db_error)?
else { else {
tracing::debug!("rejected post: unpermitted user {}", chat.signee.user); return Err(error_response!(
return Err(StatusCode::FORBIDDEN); StatusCode::FORBIDDEN,
"permission_denied",
"the user does not have permission to post in this room",
));
}; };
let cid = conn let cid = conn.query_row(
.query_row(
r" r"
INSERT INTO `room_item` (`rid`, `uid`, `timestamp`, `nonce`, `sig`, `rich_text`) INSERT INTO `room_item` (`rid`, `uid`, `timestamp`, `nonce`, `sig`, `rich_text`)
VALUES (:rid, :uid, :timestamp, :nonce, :sig, :rich_text) VALUES (:rid, :uid, :timestamp, :nonce, :sig, :rich_text)
@ -545,12 +497,10 @@ async fn room_post_item(
":sig": chat.sig, ":sig": chat.sig,
}, },
|row| row.get::<_, u64>(0), |row| row.get::<_, u64>(0),
) )?;
.map_err(from_db_error)?;
(rid, cid) (rid, cid)
}; };
{
let mut listeners = st.room_listeners.lock().unwrap(); let mut listeners = st.room_listeners.lock().unwrap();
if let Some(tx) = listeners.get(&rid) { if let Some(tx) = listeners.get(&rid) {
if tx.send(Arc::new(chat)).is_err() { if tx.send(Arc::new(chat)).is_err() {
@ -558,7 +508,6 @@ async fn room_post_item(
listeners.remove(&rid); listeners.remove(&rid);
} }
} }
}
Ok(Json(cid)) Ok(Json(cid))
} }
@ -570,11 +519,10 @@ async fn room_event(
// But this API is kinda temporary and need a better replacement anyway. // But this API is kinda temporary and need a better replacement anyway.
// So just only support public room for now. // So just only support public room for now.
OptionalAuth(user): OptionalAuth, OptionalAuth(user): OptionalAuth,
) -> Result<impl IntoResponse, StatusCode> { ) -> Result<impl IntoResponse, ApiError> {
let rid = get_room_if_readable(&st.conn.lock().unwrap(), ruuid, user.as_ref(), |row| { let rid = get_room_if_readable(&st.conn.lock().unwrap(), ruuid, user.as_ref(), |row| {
row.get::<_, u64>(0) row.get::<_, u64>(0)
}) })?;
.map_err(from_db_error)?;
let rx = match st.room_listeners.lock().unwrap().entry(rid) { let rx = match st.room_listeners.lock().unwrap().entry(rid) {
Entry::Occupied(ent) => ent.get().subscribe(), Entry::Occupied(ent) => ent.get().subscribe(),

135
blahd/src/middleware.rs Normal file
View file

@ -0,0 +1,135 @@
use std::sync::Arc;
use axum::extract::rejection::{JsonRejection, PathRejection, QueryRejection};
use axum::extract::{FromRef, FromRequest, FromRequestParts, Request};
use axum::http::{header, request, StatusCode};
use axum::response::{IntoResponse, Response};
use axum::{async_trait, Json};
use blah::types::{AuthPayload, UserKey, WithSig};
use serde::de::DeserializeOwned;
use serde::Serialize;
use crate::AppState;
/// Error response body for json endpoints.
///
/// Mostly following: <https://learn.microsoft.com/en-us/graph/errors>
#[derive(Debug, Serialize)]
pub struct ApiError {
#[serde(skip)]
pub status: StatusCode,
pub code: &'static str,
pub message: String,
}
macro_rules! error_response {
($status:expr, $code:literal, $msg:literal $(, $msg_args:expr)* $(,)?) => {
$crate::middleware::ApiError {
status: $status,
code: $code,
message: ::std::format!($msg $(, $msg_args)*),
}
};
}
impl IntoResponse for ApiError {
fn into_response(self) -> Response {
#[derive(Serialize)]
struct Resp<'a> {
error: &'a ApiError,
}
let mut resp = Json(Resp { error: &self }).into_response();
*resp.status_mut() = self.status;
resp
}
}
macro_rules! define_from_deser_rejection {
($($ty:ty, $name:literal;)*) => {
$(
impl From<$ty> for ApiError {
fn from(rej: $ty) -> Self {
error_response!(
StatusCode::BAD_REQUEST,
"deserialization",
"invalid {}: {}",
$name,
rej,
)
}
}
)*
};
}
define_from_deser_rejection! {
JsonRejection, "json";
QueryRejection, "query";
PathRejection, "path";
}
impl From<rusqlite::Error> for ApiError {
fn from(err: rusqlite::Error) -> Self {
tracing::error!(%err, "database error");
error_response!(
StatusCode::INTERNAL_SERVER_ERROR,
"server_error",
"internal server error",
)
}
}
/// Extractor for verified JSON payload.
#[derive(Debug)]
pub struct SignedJson<T>(pub WithSig<T>);
#[async_trait]
impl<S, T> FromRequest<S> for SignedJson<T>
where
S: Send + Sync,
T: Serialize + DeserializeOwned,
Arc<AppState>: FromRef<S>,
{
type Rejection = ApiError;
async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
let Json(data) = <Json<WithSig<T>> as FromRequest<S>>::from_request(req, state).await?;
let st = <Arc<AppState>>::from_ref(state);
st.verify_signed_data(&data)?;
Ok(Self(data))
}
}
/// Extractor for optional verified JSON authorization header.
#[derive(Debug)]
pub struct OptionalAuth(pub Option<UserKey>);
#[async_trait]
impl<S> FromRequestParts<S> for OptionalAuth
where
S: Send + Sync,
Arc<AppState>: FromRef<S>,
{
type Rejection = ApiError;
async fn from_request_parts(
parts: &mut request::Parts,
state: &S,
) -> Result<Self, Self::Rejection> {
let Some(auth) = parts.headers.get(header::AUTHORIZATION) else {
return Ok(Self(None));
};
let st = <Arc<AppState>>::from_ref(state);
let data =
serde_json::from_slice::<WithSig<AuthPayload>>(auth.as_bytes()).map_err(|err| {
error_response!(
StatusCode::BAD_REQUEST,
"deserialization",
"invalid authorization header: {err}",
)
})?;
st.verify_signed_data(&data)?;
Ok(Self(Some(data.signee.user)))
}
}

View file

@ -177,12 +177,12 @@ async function connectRoom(url) {
}, },
}, },
) )
.then((resp) => { .then(async (resp) => {
if (!resp.ok) throw new Error(`status ${resp.status} ${resp.statusText}`); return [resp.status, await resp.json()];
return resp.json();
}) })
// TODO: This response format is to-be-decided. // TODO: This response format is to-be-decided.
.then(async (json) => { .then(async ([status, json]) => {
if (status !== 200) throw new Error(`status ${status}: ${json.error.message}`);
const [{ title }, items] = json const [{ title }, items] = json
document.title = `room: ${title}` document.title = `room: ${title}`
items.reverse(); items.reverse();
@ -256,7 +256,10 @@ async function postChat(text) {
'Content-Type': 'application/json', 'Content-Type': 'application/json',
}, },
}); });
if (!resp.ok) throw new Error(`status ${resp.status} ${resp.statusText}`); if (!resp.ok) {
const errResp = await resp.json();
throw new Error(`status ${resp.status}: ${errResp.error.message}`);
}
chatInput.value = ''; chatInput.value = '';
} catch (e) { } catch (e) {
console.error(e); console.error(e);