mirror of
https://github.com/Blah-IM/blahrs.git
synced 2025-05-01 08:41:09 +00:00
Define error response format and refactor error handling
This commit is contained in:
parent
4ceffe3f31
commit
4937502d4c
6 changed files with 341 additions and 195 deletions
23
Cargo.lock
generated
23
Cargo.lock
generated
|
@ -177,6 +177,28 @@ dependencies = [
|
||||||
"tracing",
|
"tracing",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "axum-extra"
|
||||||
|
version = "0.9.3"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "0be6ea09c9b96cb5076af0de2e383bd2bc0c18f827cf1967bdd353e0b910d733"
|
||||||
|
dependencies = [
|
||||||
|
"axum",
|
||||||
|
"axum-core",
|
||||||
|
"bytes",
|
||||||
|
"futures-util",
|
||||||
|
"http",
|
||||||
|
"http-body",
|
||||||
|
"http-body-util",
|
||||||
|
"mime",
|
||||||
|
"pin-project-lite",
|
||||||
|
"serde",
|
||||||
|
"tower",
|
||||||
|
"tower-layer",
|
||||||
|
"tower-service",
|
||||||
|
"tracing",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "backtrace"
|
name = "backtrace"
|
||||||
version = "0.3.73"
|
version = "0.3.73"
|
||||||
|
@ -257,6 +279,7 @@ version = "0.0.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"anyhow",
|
"anyhow",
|
||||||
"axum",
|
"axum",
|
||||||
|
"axum-extra",
|
||||||
"blah",
|
"blah",
|
||||||
"clap",
|
"clap",
|
||||||
"ed25519-dalek",
|
"ed25519-dalek",
|
||||||
|
|
|
@ -6,6 +6,7 @@ edition = "2021"
|
||||||
[dependencies]
|
[dependencies]
|
||||||
anyhow = "1"
|
anyhow = "1"
|
||||||
axum = { version = "0.7", features = ["tokio"] }
|
axum = { version = "0.7", features = ["tokio"] }
|
||||||
|
axum-extra = "0.9"
|
||||||
clap = { version = "4", features = ["derive"] }
|
clap = { version = "4", features = ["derive"] }
|
||||||
ed25519-dalek = "2"
|
ed25519-dalek = "2"
|
||||||
futures-util = "0.3"
|
futures-util = "0.3"
|
||||||
|
|
|
@ -33,6 +33,9 @@ paths:
|
||||||
description: UUID of the newly created room (ruuid).
|
description: UUID of the newly created room (ruuid).
|
||||||
403:
|
403:
|
||||||
description: The user does not have permission to create room.
|
description: The user does not have permission to create room.
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
$ref: '#/components/schemas/ApiError'
|
||||||
|
|
||||||
/room/{ruuid}/feed.json:
|
/room/{ruuid}/feed.json:
|
||||||
get:
|
get:
|
||||||
|
@ -45,6 +48,9 @@ paths:
|
||||||
$ref: 'https://www.jsonfeed.org/version/1.1/'
|
$ref: 'https://www.jsonfeed.org/version/1.1/'
|
||||||
404:
|
404:
|
||||||
description: Room does not exist or is private.
|
description: Room does not exist or is private.
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
$ref: '#/components/schemas/ApiError'
|
||||||
|
|
||||||
/room/{ruuid}/item:
|
/room/{ruuid}/item:
|
||||||
get:
|
get:
|
||||||
|
@ -68,6 +74,12 @@ paths:
|
||||||
content:
|
content:
|
||||||
application/json:
|
application/json:
|
||||||
x-description: TODO
|
x-description: TODO
|
||||||
|
404:
|
||||||
|
description: |
|
||||||
|
Room does not exist or the user does not have permission to read it.
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
$ref: '#/components/schemas/ApiError'
|
||||||
|
|
||||||
post:
|
post:
|
||||||
summary: Post a chat in room {ruuid}
|
summary: Post a chat in room {ruuid}
|
||||||
|
@ -94,10 +106,15 @@ paths:
|
||||||
description: Created chat id (cid).
|
description: Created chat id (cid).
|
||||||
400:
|
400:
|
||||||
description: Body is invalid or fails the verification.
|
description: Body is invalid or fails the verification.
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
$ref: '#/components/schemas/ApiError'
|
||||||
403:
|
403:
|
||||||
description: The user does not have permission to post in this room.
|
description: |
|
||||||
404:
|
The user does not have permission to post in this room, or the room does not exist.
|
||||||
description: Room not found.
|
content:
|
||||||
|
application/json:
|
||||||
|
$ref: '#/components/schemas/ApiError'
|
||||||
|
|
||||||
/room/{ruuid}/event:
|
/room/{ruuid}/event:
|
||||||
get:
|
get:
|
||||||
|
@ -118,5 +135,24 @@ paths:
|
||||||
x-description: An event stream, each event is a JSON with type WithSig<ChatPayload>
|
x-description: An event stream, each event is a JSON with type WithSig<ChatPayload>
|
||||||
400:
|
400:
|
||||||
description: Body is invalid or fails the verification.
|
description: Body is invalid or fails the verification.
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
$ref: '#/components/schemas/ApiError'
|
||||||
404:
|
404:
|
||||||
description: Room not found.
|
description: Room not found.
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
$ref: '#/components/schemas/ApiError'
|
||||||
|
|
||||||
|
components:
|
||||||
|
schemas:
|
||||||
|
ApiError:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
error:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
code:
|
||||||
|
type: string
|
||||||
|
message:
|
||||||
|
type: string
|
||||||
|
|
|
@ -6,18 +6,19 @@ use std::sync::{Arc, Mutex};
|
||||||
use std::time::{Duration, SystemTime};
|
use std::time::{Duration, SystemTime};
|
||||||
|
|
||||||
use anyhow::{ensure, Context, Result};
|
use anyhow::{ensure, Context, Result};
|
||||||
use axum::extract::{FromRef, FromRequest, FromRequestParts, Path, Query, Request, State};
|
use axum::extract::{Path, Query, State};
|
||||||
use axum::http::{header, request, StatusCode};
|
use axum::http::{header, StatusCode};
|
||||||
use axum::response::{sse, IntoResponse, Response};
|
use axum::response::{sse, IntoResponse};
|
||||||
use axum::routing::{get, post};
|
use axum::routing::{get, post};
|
||||||
use axum::{async_trait, Json, Router};
|
use axum::{Json, Router};
|
||||||
|
use axum_extra::extract::WithRejection;
|
||||||
use blah::types::{
|
use blah::types::{
|
||||||
AuthPayload, ChatItem, ChatPayload, CreateRoomPayload, MemberPermission, RoomAttrs,
|
ChatItem, ChatPayload, CreateRoomPayload, MemberPermission, RoomAttrs, ServerPermission,
|
||||||
ServerPermission, Signee, UserKey, WithSig,
|
Signee, UserKey, WithSig,
|
||||||
};
|
};
|
||||||
use ed25519_dalek::SIGNATURE_LENGTH;
|
use ed25519_dalek::SIGNATURE_LENGTH;
|
||||||
|
use middleware::{ApiError, OptionalAuth, SignedJson};
|
||||||
use rusqlite::{named_params, params, OptionalExtension, Row};
|
use rusqlite::{named_params, params, OptionalExtension, Row};
|
||||||
use serde::de::DeserializeOwned;
|
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use tokio::sync::broadcast;
|
use tokio::sync::broadcast;
|
||||||
use tokio_stream::StreamExt;
|
use tokio_stream::StreamExt;
|
||||||
|
@ -29,6 +30,8 @@ const EVENT_QUEUE_LEN: usize = 1024;
|
||||||
const MAX_BODY_LEN: usize = 4 << 10; // 4KiB
|
const MAX_BODY_LEN: usize = 4 << 10; // 4KiB
|
||||||
const TIMESTAMP_TOLERENCE: u64 = 90;
|
const TIMESTAMP_TOLERENCE: u64 = 90;
|
||||||
|
|
||||||
|
#[macro_use]
|
||||||
|
mod middleware;
|
||||||
mod utils;
|
mod utils;
|
||||||
|
|
||||||
#[derive(Debug, clap::Parser)]
|
#[derive(Debug, clap::Parser)]
|
||||||
|
@ -93,24 +96,38 @@ impl AppState {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn verify_signed_data<T: Serialize>(&self, data: &WithSig<T>) -> Result<()> {
|
fn verify_signed_data<T: Serialize>(&self, data: &WithSig<T>) -> Result<(), ApiError> {
|
||||||
data.verify().context("unsigned payload")?;
|
let Ok(()) = data.verify() else {
|
||||||
|
return Err(error_response!(
|
||||||
|
StatusCode::BAD_REQUEST,
|
||||||
|
"invalid_signature",
|
||||||
|
"signature verification failed"
|
||||||
|
));
|
||||||
|
};
|
||||||
let timestamp_diff = SystemTime::now()
|
let timestamp_diff = SystemTime::now()
|
||||||
.duration_since(SystemTime::UNIX_EPOCH)
|
.duration_since(SystemTime::UNIX_EPOCH)
|
||||||
.expect("after UNIX epoch")
|
.expect("after UNIX epoch")
|
||||||
.as_secs()
|
.as_secs()
|
||||||
.abs_diff(data.signee.timestamp);
|
.abs_diff(data.signee.timestamp);
|
||||||
ensure!(
|
if timestamp_diff > TIMESTAMP_TOLERENCE {
|
||||||
timestamp_diff <= TIMESTAMP_TOLERENCE,
|
return Err(error_response!(
|
||||||
"invalid timestamp, off by {timestamp_diff}s"
|
StatusCode::BAD_REQUEST,
|
||||||
);
|
"invalid_timestamp",
|
||||||
ensure!(
|
"invalid timestamp, off by {timestamp_diff}s"
|
||||||
self.used_nonces
|
));
|
||||||
.lock()
|
}
|
||||||
.unwrap()
|
if !self
|
||||||
.try_insert(data.signee.nonce),
|
.used_nonces
|
||||||
"duplicated nonce",
|
.lock()
|
||||||
);
|
.unwrap()
|
||||||
|
.try_insert(data.signee.nonce)
|
||||||
|
{
|
||||||
|
return Err(error_response!(
|
||||||
|
StatusCode::BAD_REQUEST,
|
||||||
|
"duplicated_nonce",
|
||||||
|
"duplicated nonce",
|
||||||
|
));
|
||||||
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -125,9 +142,9 @@ async fn main_async(opt: Cli, st: AppState) -> Result<()> {
|
||||||
.route("/room/:ruuid/event", get(room_event))
|
.route("/room/:ruuid/event", get(room_event))
|
||||||
.route("/room/:ruuid/item", get(room_get_item).post(room_post_item))
|
.route("/room/:ruuid/item", get(room_get_item).post(room_post_item))
|
||||||
.with_state(Arc::new(st))
|
.with_state(Arc::new(st))
|
||||||
.layer(tower_http::limit::RequestBodyLimitLayer::new(MAX_BODY_LEN))
|
|
||||||
// NB. This comes at last (outmost layer), so inner errors will still be wraped with
|
// NB. This comes at last (outmost layer), so inner errors will still be wraped with
|
||||||
// correct CORS headers.
|
// correct CORS headers.
|
||||||
|
.layer(tower_http::limit::RequestBodyLimitLayer::new(MAX_BODY_LEN))
|
||||||
.layer(tower_http::cors::CorsLayer::permissive());
|
.layer(tower_http::cors::CorsLayer::permissive());
|
||||||
|
|
||||||
let listener = tokio::net::TcpListener::bind(&opt.listen)
|
let listener = tokio::net::TcpListener::bind(&opt.listen)
|
||||||
|
@ -143,26 +160,20 @@ async fn main_async(opt: Cli, st: AppState) -> Result<()> {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn from_db_error(err: rusqlite::Error) -> StatusCode {
|
|
||||||
match err {
|
|
||||||
rusqlite::Error::QueryReturnedNoRows => StatusCode::NOT_FOUND,
|
|
||||||
err => {
|
|
||||||
tracing::error!(%err, "database error");
|
|
||||||
StatusCode::INTERNAL_SERVER_ERROR
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn room_create(
|
async fn room_create(
|
||||||
st: ArcState,
|
st: ArcState,
|
||||||
SignedJson(params): SignedJson<CreateRoomPayload>,
|
SignedJson(params): SignedJson<CreateRoomPayload>,
|
||||||
) -> Result<Json<Uuid>, StatusCode> {
|
) -> Result<Json<Uuid>, ApiError> {
|
||||||
let members = ¶ms.signee.payload.members.0;
|
let members = ¶ms.signee.payload.members.0;
|
||||||
if !members
|
if !members
|
||||||
.iter()
|
.iter()
|
||||||
.any(|m| m.user == params.signee.user && m.permission == MemberPermission::ALL)
|
.any(|m| m.user == params.signee.user && m.permission == MemberPermission::ALL)
|
||||||
{
|
{
|
||||||
return Err(StatusCode::BAD_REQUEST);
|
return Err(error_response!(
|
||||||
|
StatusCode::BAD_REQUEST,
|
||||||
|
"deserialization",
|
||||||
|
"invalid initial members",
|
||||||
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut conn = st.conn.lock().unwrap();
|
let mut conn = st.conn.lock().unwrap();
|
||||||
|
@ -179,57 +190,56 @@ async fn room_create(
|
||||||
Ok(perm.contains(ServerPermission::CREATE_ROOM))
|
Ok(perm.contains(ServerPermission::CREATE_ROOM))
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
.optional()
|
.optional()?
|
||||||
.map_err(from_db_error)?
|
|
||||||
else {
|
else {
|
||||||
return Err(StatusCode::FORBIDDEN);
|
return Err(error_response!(
|
||||||
|
StatusCode::FORBIDDEN,
|
||||||
|
"permission_denied",
|
||||||
|
"user does not have permission to create room",
|
||||||
|
));
|
||||||
};
|
};
|
||||||
|
|
||||||
let ruuid = Uuid::new_v4();
|
let ruuid = Uuid::new_v4();
|
||||||
|
|
||||||
(|| {
|
let txn = conn.transaction()?;
|
||||||
let txn = conn.transaction()?;
|
let rid = txn.query_row(
|
||||||
let rid = txn.query_row(
|
r"
|
||||||
r"
|
INSERT INTO `room` (`ruuid`, `title`)
|
||||||
INSERT INTO `room` (`ruuid`, `title`)
|
VALUES (:ruuid, :title)
|
||||||
VALUES (:ruuid, :title)
|
RETURNING `rid`
|
||||||
RETURNING `rid`
|
",
|
||||||
",
|
named_params! {
|
||||||
named_params! {
|
":ruuid": ruuid,
|
||||||
":ruuid": ruuid,
|
":title": params.signee.payload.title,
|
||||||
":title": params.signee.payload.title,
|
},
|
||||||
},
|
|row| row.get::<_, u64>(0),
|
||||||
|row| row.get::<_, u64>(0),
|
)?;
|
||||||
)?;
|
let mut insert_user = txn.prepare(
|
||||||
let mut insert_user = txn.prepare(
|
r"
|
||||||
r"
|
INSERT INTO `user` (`userkey`)
|
||||||
INSERT INTO `user` (`userkey`)
|
VALUES (?)
|
||||||
VALUES (?)
|
ON CONFLICT (`userkey`) DO NOTHING
|
||||||
ON CONFLICT (`userkey`) DO NOTHING
|
",
|
||||||
",
|
)?;
|
||||||
)?;
|
let mut insert_member = txn.prepare(
|
||||||
let mut insert_member = txn.prepare(
|
r"
|
||||||
r"
|
INSERT INTO `room_member` (`rid`, `uid`, `permission`)
|
||||||
INSERT INTO `room_member` (`rid`, `uid`, `permission`)
|
SELECT :rid, `uid`, :permission
|
||||||
SELECT :rid, `uid`, :permission
|
FROM `user`
|
||||||
FROM `user`
|
WHERE `userkey` = :userkey
|
||||||
WHERE `userkey` = :userkey
|
",
|
||||||
",
|
)?;
|
||||||
)?;
|
for member in members {
|
||||||
for member in members {
|
insert_user.execute(params![member.user])?;
|
||||||
insert_user.execute(params![member.user])?;
|
insert_member.execute(named_params! {
|
||||||
insert_member.execute(named_params! {
|
":rid": rid,
|
||||||
":rid": rid,
|
":userkey": member.user,
|
||||||
":userkey": member.user,
|
":permission": member.permission,
|
||||||
":permission": member.permission,
|
})?;
|
||||||
})?;
|
}
|
||||||
}
|
drop(insert_member);
|
||||||
drop(insert_member);
|
drop(insert_user);
|
||||||
drop(insert_user);
|
txn.commit()?;
|
||||||
txn.commit()?;
|
|
||||||
Ok(())
|
|
||||||
})()
|
|
||||||
.map_err(from_db_error)?;
|
|
||||||
|
|
||||||
Ok(Json(ruuid))
|
Ok(Json(ruuid))
|
||||||
}
|
}
|
||||||
|
@ -246,13 +256,12 @@ struct GetRoomItemParams {
|
||||||
|
|
||||||
async fn room_get_item(
|
async fn room_get_item(
|
||||||
st: ArcState,
|
st: ArcState,
|
||||||
Path(ruuid): Path<Uuid>,
|
WithRejection(Path(ruuid), _): WithRejection<Path<Uuid>, ApiError>,
|
||||||
params: Query<GetRoomItemParams>,
|
WithRejection(params, _): WithRejection<Query<GetRoomItemParams>, ApiError>,
|
||||||
OptionalAuth(user): OptionalAuth,
|
OptionalAuth(user): OptionalAuth,
|
||||||
) -> Result<impl IntoResponse, StatusCode> {
|
) -> Result<impl IntoResponse, ApiError> {
|
||||||
let (room_meta, items) =
|
let (room_meta, items) =
|
||||||
query_room_items(&st.conn.lock().unwrap(), ruuid, user.as_ref(), ¶ms)
|
query_room_items(&st.conn.lock().unwrap(), ruuid, user.as_ref(), ¶ms)?;
|
||||||
.map_err(from_db_error)?;
|
|
||||||
|
|
||||||
// TODO: This format is to-be-decided. Or do we even need this interface other than
|
// TODO: This format is to-be-decided. Or do we even need this interface other than
|
||||||
// `feed.json`?
|
// `feed.json`?
|
||||||
|
@ -261,11 +270,10 @@ async fn room_get_item(
|
||||||
|
|
||||||
async fn room_get_feed(
|
async fn room_get_feed(
|
||||||
st: ArcState,
|
st: ArcState,
|
||||||
Path(ruuid): Path<Uuid>,
|
WithRejection(Path(ruuid), _): WithRejection<Path<Uuid>, ApiError>,
|
||||||
params: Query<GetRoomItemParams>,
|
params: Query<GetRoomItemParams>,
|
||||||
) -> Result<impl IntoResponse, StatusCode> {
|
) -> Result<impl IntoResponse, ApiError> {
|
||||||
let (room_meta, items) =
|
let (room_meta, items) = query_room_items(&st.conn.lock().unwrap(), ruuid, None, ¶ms)?;
|
||||||
query_room_items(&st.conn.lock().unwrap(), ruuid, None, ¶ms).map_err(from_db_error)?;
|
|
||||||
|
|
||||||
let items = items
|
let items = items
|
||||||
.into_iter()
|
.into_iter()
|
||||||
|
@ -352,7 +360,7 @@ fn get_room_if_readable<T>(
|
||||||
ruuid: Uuid,
|
ruuid: Uuid,
|
||||||
user: Option<&UserKey>,
|
user: Option<&UserKey>,
|
||||||
f: impl FnOnce(&Row<'_>) -> rusqlite::Result<T>,
|
f: impl FnOnce(&Row<'_>) -> rusqlite::Result<T>,
|
||||||
) -> rusqlite::Result<T> {
|
) -> Result<T, ApiError> {
|
||||||
conn.query_row(
|
conn.query_row(
|
||||||
r"
|
r"
|
||||||
SELECT `rid`, `title`, `attrs`
|
SELECT `rid`, `title`, `attrs`
|
||||||
|
@ -372,6 +380,8 @@ fn get_room_if_readable<T>(
|
||||||
},
|
},
|
||||||
f,
|
f,
|
||||||
)
|
)
|
||||||
|
.optional()?
|
||||||
|
.ok_or_else(|| error_response!(StatusCode::NOT_FOUND, "not_found", "room not found"))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn query_room_items(
|
fn query_room_items(
|
||||||
|
@ -379,7 +389,7 @@ fn query_room_items(
|
||||||
ruuid: Uuid,
|
ruuid: Uuid,
|
||||||
user: Option<&UserKey>,
|
user: Option<&UserKey>,
|
||||||
params: &GetRoomItemParams,
|
params: &GetRoomItemParams,
|
||||||
) -> rusqlite::Result<(RoomMetadata, Vec<(u64, ChatItem)>)> {
|
) -> Result<(RoomMetadata, Vec<(u64, ChatItem)>), ApiError> {
|
||||||
let (rid, title, attrs) = get_room_if_readable(conn, ruuid, user, |row| {
|
let (rid, title, attrs) = get_room_if_readable(conn, ruuid, user, |row| {
|
||||||
Ok((
|
Ok((
|
||||||
row.get::<_, u64>("rid")?,
|
row.get::<_, u64>("rid")?,
|
||||||
|
@ -430,76 +440,17 @@ fn query_room_items(
|
||||||
Ok((room_meta, items))
|
Ok((room_meta, items))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Extractor for verified JSON payload.
|
|
||||||
#[derive(Debug)]
|
|
||||||
struct SignedJson<T>(WithSig<T>);
|
|
||||||
|
|
||||||
#[async_trait]
|
|
||||||
impl<S, T> FromRequest<S> for SignedJson<T>
|
|
||||||
where
|
|
||||||
S: Send + Sync,
|
|
||||||
T: Serialize + DeserializeOwned,
|
|
||||||
Arc<AppState>: FromRef<S>,
|
|
||||||
{
|
|
||||||
type Rejection = Response;
|
|
||||||
|
|
||||||
async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
|
|
||||||
let Json(data) = <Json<WithSig<T>> as FromRequest<S>>::from_request(req, state)
|
|
||||||
.await
|
|
||||||
.map_err(|err| err.into_response())?;
|
|
||||||
let st = <Arc<AppState>>::from_ref(state);
|
|
||||||
st.verify_signed_data(&data).map_err(|err| {
|
|
||||||
tracing::debug!(%err, "unsigned payload");
|
|
||||||
StatusCode::BAD_REQUEST.into_response()
|
|
||||||
})?;
|
|
||||||
Ok(Self(data))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Extractor for optional verified JSON authorization header.
|
|
||||||
#[derive(Debug)]
|
|
||||||
struct OptionalAuth(Option<UserKey>);
|
|
||||||
|
|
||||||
#[async_trait]
|
|
||||||
impl<S> FromRequestParts<S> for OptionalAuth
|
|
||||||
where
|
|
||||||
S: Send + Sync,
|
|
||||||
Arc<AppState>: FromRef<S>,
|
|
||||||
{
|
|
||||||
type Rejection = StatusCode;
|
|
||||||
|
|
||||||
async fn from_request_parts(
|
|
||||||
parts: &mut request::Parts,
|
|
||||||
state: &S,
|
|
||||||
) -> Result<Self, Self::Rejection> {
|
|
||||||
let Some(auth) = parts.headers.get(header::AUTHORIZATION) else {
|
|
||||||
return Ok(Self(None));
|
|
||||||
};
|
|
||||||
|
|
||||||
let st = <Arc<AppState>>::from_ref(state);
|
|
||||||
let ret = serde_json::from_slice::<WithSig<AuthPayload>>(auth.as_bytes())
|
|
||||||
.context("invalid JSON")
|
|
||||||
.and_then(|data| {
|
|
||||||
st.verify_signed_data(&data)?;
|
|
||||||
Ok(data.signee.user)
|
|
||||||
});
|
|
||||||
match ret {
|
|
||||||
Ok(user) => Ok(Self(Some(user))),
|
|
||||||
Err(err) => {
|
|
||||||
tracing::debug!(%err, "invalid authorization");
|
|
||||||
Err(StatusCode::BAD_REQUEST)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn room_post_item(
|
async fn room_post_item(
|
||||||
st: ArcState,
|
st: ArcState,
|
||||||
Path(ruuid): Path<Uuid>,
|
Path(ruuid): Path<Uuid>,
|
||||||
SignedJson(chat): SignedJson<ChatPayload>,
|
SignedJson(chat): SignedJson<ChatPayload>,
|
||||||
) -> Result<Json<u64>, StatusCode> {
|
) -> Result<Json<u64>, ApiError> {
|
||||||
if ruuid != chat.signee.payload.room {
|
if ruuid != chat.signee.payload.room {
|
||||||
return Err(StatusCode::BAD_REQUEST);
|
return Err(error_response!(
|
||||||
|
StatusCode::BAD_REQUEST,
|
||||||
|
"invalid_request",
|
||||||
|
"URI and payload room id mismatch",
|
||||||
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
let (rid, cid) = {
|
let (rid, cid) = {
|
||||||
|
@ -522,41 +473,39 @@ async fn room_post_item(
|
||||||
},
|
},
|
||||||
|row| Ok((row.get::<_, u64>("rid")?, row.get::<_, u64>("uid")?)),
|
|row| Ok((row.get::<_, u64>("rid")?, row.get::<_, u64>("uid")?)),
|
||||||
)
|
)
|
||||||
.optional()
|
.optional()?
|
||||||
.map_err(from_db_error)?
|
|
||||||
else {
|
else {
|
||||||
tracing::debug!("rejected post: unpermitted user {}", chat.signee.user);
|
return Err(error_response!(
|
||||||
return Err(StatusCode::FORBIDDEN);
|
StatusCode::FORBIDDEN,
|
||||||
|
"permission_denied",
|
||||||
|
"the user does not have permission to post in this room",
|
||||||
|
));
|
||||||
};
|
};
|
||||||
|
|
||||||
let cid = conn
|
let cid = conn.query_row(
|
||||||
.query_row(
|
r"
|
||||||
r"
|
INSERT INTO `room_item` (`rid`, `uid`, `timestamp`, `nonce`, `sig`, `rich_text`)
|
||||||
INSERT INTO `room_item` (`rid`, `uid`, `timestamp`, `nonce`, `sig`, `rich_text`)
|
VALUES (:rid, :uid, :timestamp, :nonce, :sig, :rich_text)
|
||||||
VALUES (:rid, :uid, :timestamp, :nonce, :sig, :rich_text)
|
RETURNING `cid`
|
||||||
RETURNING `cid`
|
",
|
||||||
",
|
named_params! {
|
||||||
named_params! {
|
":rid": rid,
|
||||||
":rid": rid,
|
":uid": uid,
|
||||||
":uid": uid,
|
":timestamp": chat.signee.timestamp,
|
||||||
":timestamp": chat.signee.timestamp,
|
":nonce": chat.signee.nonce,
|
||||||
":nonce": chat.signee.nonce,
|
":rich_text": &chat.signee.payload.rich_text,
|
||||||
":rich_text": &chat.signee.payload.rich_text,
|
":sig": chat.sig,
|
||||||
":sig": chat.sig,
|
},
|
||||||
},
|
|row| row.get::<_, u64>(0),
|
||||||
|row| row.get::<_, u64>(0),
|
)?;
|
||||||
)
|
|
||||||
.map_err(from_db_error)?;
|
|
||||||
(rid, cid)
|
(rid, cid)
|
||||||
};
|
};
|
||||||
|
|
||||||
{
|
let mut listeners = st.room_listeners.lock().unwrap();
|
||||||
let mut listeners = st.room_listeners.lock().unwrap();
|
if let Some(tx) = listeners.get(&rid) {
|
||||||
if let Some(tx) = listeners.get(&rid) {
|
if tx.send(Arc::new(chat)).is_err() {
|
||||||
if tx.send(Arc::new(chat)).is_err() {
|
// Clean up because all receivers died.
|
||||||
// Clean up because all receivers died.
|
listeners.remove(&rid);
|
||||||
listeners.remove(&rid);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -570,11 +519,10 @@ async fn room_event(
|
||||||
// But this API is kinda temporary and need a better replacement anyway.
|
// But this API is kinda temporary and need a better replacement anyway.
|
||||||
// So just only support public room for now.
|
// So just only support public room for now.
|
||||||
OptionalAuth(user): OptionalAuth,
|
OptionalAuth(user): OptionalAuth,
|
||||||
) -> Result<impl IntoResponse, StatusCode> {
|
) -> Result<impl IntoResponse, ApiError> {
|
||||||
let rid = get_room_if_readable(&st.conn.lock().unwrap(), ruuid, user.as_ref(), |row| {
|
let rid = get_room_if_readable(&st.conn.lock().unwrap(), ruuid, user.as_ref(), |row| {
|
||||||
row.get::<_, u64>(0)
|
row.get::<_, u64>(0)
|
||||||
})
|
})?;
|
||||||
.map_err(from_db_error)?;
|
|
||||||
|
|
||||||
let rx = match st.room_listeners.lock().unwrap().entry(rid) {
|
let rx = match st.room_listeners.lock().unwrap().entry(rid) {
|
||||||
Entry::Occupied(ent) => ent.get().subscribe(),
|
Entry::Occupied(ent) => ent.get().subscribe(),
|
||||||
|
|
135
blahd/src/middleware.rs
Normal file
135
blahd/src/middleware.rs
Normal file
|
@ -0,0 +1,135 @@
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use axum::extract::rejection::{JsonRejection, PathRejection, QueryRejection};
|
||||||
|
use axum::extract::{FromRef, FromRequest, FromRequestParts, Request};
|
||||||
|
use axum::http::{header, request, StatusCode};
|
||||||
|
use axum::response::{IntoResponse, Response};
|
||||||
|
use axum::{async_trait, Json};
|
||||||
|
use blah::types::{AuthPayload, UserKey, WithSig};
|
||||||
|
use serde::de::DeserializeOwned;
|
||||||
|
use serde::Serialize;
|
||||||
|
|
||||||
|
use crate::AppState;
|
||||||
|
|
||||||
|
/// Error response body for json endpoints.
|
||||||
|
///
|
||||||
|
/// Mostly following: <https://learn.microsoft.com/en-us/graph/errors>
|
||||||
|
#[derive(Debug, Serialize)]
|
||||||
|
pub struct ApiError {
|
||||||
|
#[serde(skip)]
|
||||||
|
pub status: StatusCode,
|
||||||
|
pub code: &'static str,
|
||||||
|
pub message: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
macro_rules! error_response {
|
||||||
|
($status:expr, $code:literal, $msg:literal $(, $msg_args:expr)* $(,)?) => {
|
||||||
|
$crate::middleware::ApiError {
|
||||||
|
status: $status,
|
||||||
|
code: $code,
|
||||||
|
message: ::std::format!($msg $(, $msg_args)*),
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
impl IntoResponse for ApiError {
|
||||||
|
fn into_response(self) -> Response {
|
||||||
|
#[derive(Serialize)]
|
||||||
|
struct Resp<'a> {
|
||||||
|
error: &'a ApiError,
|
||||||
|
}
|
||||||
|
let mut resp = Json(Resp { error: &self }).into_response();
|
||||||
|
*resp.status_mut() = self.status;
|
||||||
|
resp
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
macro_rules! define_from_deser_rejection {
|
||||||
|
($($ty:ty, $name:literal;)*) => {
|
||||||
|
$(
|
||||||
|
impl From<$ty> for ApiError {
|
||||||
|
fn from(rej: $ty) -> Self {
|
||||||
|
error_response!(
|
||||||
|
StatusCode::BAD_REQUEST,
|
||||||
|
"deserialization",
|
||||||
|
"invalid {}: {}",
|
||||||
|
$name,
|
||||||
|
rej,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)*
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
define_from_deser_rejection! {
|
||||||
|
JsonRejection, "json";
|
||||||
|
QueryRejection, "query";
|
||||||
|
PathRejection, "path";
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<rusqlite::Error> for ApiError {
|
||||||
|
fn from(err: rusqlite::Error) -> Self {
|
||||||
|
tracing::error!(%err, "database error");
|
||||||
|
error_response!(
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
"server_error",
|
||||||
|
"internal server error",
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Extractor for verified JSON payload.
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct SignedJson<T>(pub WithSig<T>);
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl<S, T> FromRequest<S> for SignedJson<T>
|
||||||
|
where
|
||||||
|
S: Send + Sync,
|
||||||
|
T: Serialize + DeserializeOwned,
|
||||||
|
Arc<AppState>: FromRef<S>,
|
||||||
|
{
|
||||||
|
type Rejection = ApiError;
|
||||||
|
|
||||||
|
async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
|
||||||
|
let Json(data) = <Json<WithSig<T>> as FromRequest<S>>::from_request(req, state).await?;
|
||||||
|
let st = <Arc<AppState>>::from_ref(state);
|
||||||
|
st.verify_signed_data(&data)?;
|
||||||
|
Ok(Self(data))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Extractor for optional verified JSON authorization header.
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct OptionalAuth(pub Option<UserKey>);
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl<S> FromRequestParts<S> for OptionalAuth
|
||||||
|
where
|
||||||
|
S: Send + Sync,
|
||||||
|
Arc<AppState>: FromRef<S>,
|
||||||
|
{
|
||||||
|
type Rejection = ApiError;
|
||||||
|
|
||||||
|
async fn from_request_parts(
|
||||||
|
parts: &mut request::Parts,
|
||||||
|
state: &S,
|
||||||
|
) -> Result<Self, Self::Rejection> {
|
||||||
|
let Some(auth) = parts.headers.get(header::AUTHORIZATION) else {
|
||||||
|
return Ok(Self(None));
|
||||||
|
};
|
||||||
|
|
||||||
|
let st = <Arc<AppState>>::from_ref(state);
|
||||||
|
let data =
|
||||||
|
serde_json::from_slice::<WithSig<AuthPayload>>(auth.as_bytes()).map_err(|err| {
|
||||||
|
error_response!(
|
||||||
|
StatusCode::BAD_REQUEST,
|
||||||
|
"deserialization",
|
||||||
|
"invalid authorization header: {err}",
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
st.verify_signed_data(&data)?;
|
||||||
|
Ok(Self(Some(data.signee.user)))
|
||||||
|
}
|
||||||
|
}
|
|
@ -177,12 +177,12 @@ async function connectRoom(url) {
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
.then((resp) => {
|
.then(async (resp) => {
|
||||||
if (!resp.ok) throw new Error(`status ${resp.status} ${resp.statusText}`);
|
return [resp.status, await resp.json()];
|
||||||
return resp.json();
|
|
||||||
})
|
})
|
||||||
// TODO: This response format is to-be-decided.
|
// TODO: This response format is to-be-decided.
|
||||||
.then(async (json) => {
|
.then(async ([status, json]) => {
|
||||||
|
if (status !== 200) throw new Error(`status ${status}: ${json.error.message}`);
|
||||||
const [{ title }, items] = json
|
const [{ title }, items] = json
|
||||||
document.title = `room: ${title}`
|
document.title = `room: ${title}`
|
||||||
items.reverse();
|
items.reverse();
|
||||||
|
@ -256,7 +256,10 @@ async function postChat(text) {
|
||||||
'Content-Type': 'application/json',
|
'Content-Type': 'application/json',
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
if (!resp.ok) throw new Error(`status ${resp.status} ${resp.statusText}`);
|
if (!resp.ok) {
|
||||||
|
const errResp = await resp.json();
|
||||||
|
throw new Error(`status ${resp.status}: ${errResp.error.message}`);
|
||||||
|
}
|
||||||
chatInput.value = '';
|
chatInput.value = '';
|
||||||
} catch (e) {
|
} catch (e) {
|
||||||
console.error(e);
|
console.error(e);
|
||||||
|
|
Loading…
Add table
Reference in a new issue