// Mirror of https://github.com/Blah-IM/blahrs.git (synced 2025-05-01 08:41:09 +00:00).
// 193 lines, 5.1 KiB, Rust.
use std::backtrace::Backtrace;
|
|
use std::fmt;
|
|
use std::sync::Arc;
|
|
|
|
use axum::extract::rejection::{JsonRejection, PathRejection, QueryRejection};
|
|
use axum::extract::{FromRef, FromRequest, FromRequestParts, Request};
|
|
use axum::http::{header, request, StatusCode};
|
|
use axum::response::{IntoResponse, Response};
|
|
use axum::{async_trait, Json};
|
|
use blah_types::{AuthPayload, Signed, UserKey};
|
|
use serde::de::DeserializeOwned;
|
|
use serde::{Deserialize, Serialize};
|
|
|
|
use crate::AppState;
|
|
|
|
/// Error response body for json endpoints.
|
|
///
|
|
/// Mostly following: <https://learn.microsoft.com/en-us/graph/errors>
|
|
#[derive(Debug, Serialize, Deserialize)]
|
|
pub struct ApiError {
|
|
#[serde(skip, default)]
|
|
pub status: StatusCode,
|
|
pub code: String,
|
|
pub message: String,
|
|
}
|
|
|
|
impl fmt::Display for ApiError {
|
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
|
write!(
|
|
f,
|
|
"api error status={} code={}: {}",
|
|
self.status, self.code, self.message,
|
|
)
|
|
}
|
|
}
|
|
|
|
// Lets `ApiError` be used wherever a `std::error::Error` is expected; no
// trait methods are overridden.
impl std::error::Error for ApiError {}
|
|
|
|
/// Builds an [`ApiError`](crate::middleware::ApiError) value.
///
/// Arguments, in order:
/// - `$status`: expression stored in the `status` field;
/// - `$code`: literal stored (as an owned `String`) in the `code` field;
/// - `$msg`: format-string literal, followed by optional `format!` arguments,
///   rendered into the `message` field.
///
/// A trailing comma is accepted.
macro_rules! error_response {
    ($status:expr, $code:literal, $msg:literal $(, $arg:expr)* $(,)?) => {
        $crate::middleware::ApiError {
            status: $status,
            code: ::std::string::String::from($code),
            message: ::std::format!($msg $(, $arg)*),
        }
    };
}
|
|
|
|
impl IntoResponse for ApiError {
|
|
fn into_response(self) -> Response {
|
|
#[derive(Serialize)]
|
|
struct Resp<'a> {
|
|
error: &'a ApiError,
|
|
}
|
|
let mut resp = Json(Resp { error: &self }).into_response();
|
|
*resp.status_mut() = self.status;
|
|
resp
|
|
}
|
|
}
|
|
|
|
/// Generates `impl From<Rejection> for ApiError` for a list of
/// `RejectionType, "name";` pairs.
///
/// Each generated conversion logs the rejection at debug level and returns a
/// `400 Bad Request` error with code `"deserialization"` whose message is
/// `invalid <name>: <rejection>`.
macro_rules! define_from_deser_rejection {
    ($($rejection:ty, $what:literal;)*) => {
        $(
            impl From<$rejection> for ApiError {
                fn from(rej: $rejection) -> Self {
                    // The binding name doubles as the tracing field key.
                    tracing::debug!(?rej, "rejected");
                    error_response!(
                        StatusCode::BAD_REQUEST,
                        "deserialization",
                        "invalid {}: {}",
                        $what,
                        rej,
                    )
                }
            }
        )*
    };
}
|
|
|
|
// Conversions for the axum extractor rejections imported by this module. The
// string names the request part and ends up in the error message.
define_from_deser_rejection! {
    JsonRejection, "json";
    QueryRejection, "query";
    PathRejection, "path";
}
|
|
|
|
impl From<rusqlite::Error> for ApiError {
|
|
fn from(err: rusqlite::Error) -> Self {
|
|
tracing::error!(%err, backtrace = %Backtrace::force_capture(), "database error");
|
|
error_response!(
|
|
StatusCode::INTERNAL_SERVER_ERROR,
|
|
"server_error",
|
|
"internal server error",
|
|
)
|
|
}
|
|
}
|
|
|
|
/// Extractor for verified JSON payload.
|
|
#[derive(Debug)]
|
|
pub struct SignedJson<T>(pub Signed<T>);
|
|
|
|
#[async_trait]
|
|
impl<S, T> FromRequest<S> for SignedJson<T>
|
|
where
|
|
S: Send + Sync,
|
|
T: Serialize + DeserializeOwned,
|
|
Arc<AppState>: FromRef<S>,
|
|
{
|
|
type Rejection = ApiError;
|
|
|
|
async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
|
|
let Json(data) = <Json<Signed<T>> as FromRequest<S>>::from_request(req, state).await?;
|
|
let st = <Arc<AppState>>::from_ref(state);
|
|
st.verify_signed_data(&data)?;
|
|
Ok(Self(data))
|
|
}
|
|
}
|
|
|
|
#[derive(Debug)]
|
|
pub enum AuthRejection {
|
|
None,
|
|
Invalid(ApiError),
|
|
}
|
|
|
|
impl From<AuthRejection> for ApiError {
|
|
fn from(rej: AuthRejection) -> Self {
|
|
match rej {
|
|
AuthRejection::None => error_response!(
|
|
StatusCode::UNAUTHORIZED,
|
|
"unauthorized",
|
|
"missing authorization header"
|
|
),
|
|
AuthRejection::Invalid(err) => err,
|
|
}
|
|
}
|
|
}
|
|
|
|
impl IntoResponse for AuthRejection {
|
|
fn into_response(self) -> Response {
|
|
ApiError::from(self).into_response()
|
|
}
|
|
}
|
|
|
|
pub trait ResultExt {
|
|
fn into_optional(self) -> Result<Option<UserKey>, ApiError>;
|
|
}
|
|
|
|
impl ResultExt for Result<Auth, AuthRejection> {
|
|
fn into_optional(self) -> Result<Option<UserKey>, ApiError> {
|
|
match self {
|
|
Ok(auth) => Ok(Some(auth.0)),
|
|
Err(AuthRejection::None) => Ok(None),
|
|
Err(AuthRejection::Invalid(err)) => Err(err),
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Alias for the `Result` produced by the [`Auth`] extractor;
/// [`ResultExt::into_optional`] turns it into an optional user key.
pub type MaybeAuth = Result<Auth, AuthRejection>;
|
|
|
|
/// Extractor for verified JSON authorization header.
|
|
#[derive(Debug)]
|
|
pub struct Auth(pub UserKey);
|
|
|
|
#[async_trait]
|
|
impl<S> FromRequestParts<S> for Auth
|
|
where
|
|
S: Send + Sync,
|
|
Arc<AppState>: FromRef<S>,
|
|
{
|
|
type Rejection = AuthRejection;
|
|
|
|
async fn from_request_parts(
|
|
parts: &mut request::Parts,
|
|
state: &S,
|
|
) -> Result<Self, Self::Rejection> {
|
|
let auth = parts
|
|
.headers
|
|
.get(header::AUTHORIZATION)
|
|
.ok_or(AuthRejection::None)?;
|
|
|
|
let st = <Arc<AppState>>::from_ref(state);
|
|
let data =
|
|
serde_json::from_slice::<Signed<AuthPayload>>(auth.as_bytes()).map_err(|err| {
|
|
AuthRejection::Invalid(error_response!(
|
|
StatusCode::BAD_REQUEST,
|
|
"deserialization",
|
|
"invalid authorization header: {err}",
|
|
))
|
|
})?;
|
|
st.verify_signed_data(&data)
|
|
.map_err(AuthRejection::Invalid)?;
|
|
Ok(Self(data.signee.user))
|
|
}
|
|
}
|