Use proper rejection type for Auth middleware

This commit is contained in:
oxalica 2024-09-03 19:48:50 -04:00
parent 99d1311d63
commit d3c3961298
2 changed files with 67 additions and 29 deletions

View file

@ -18,7 +18,7 @@ use blah::types::{
use config::Config; use config::Config;
use database::Database; use database::Database;
use ed25519_dalek::SIGNATURE_LENGTH; use ed25519_dalek::SIGNATURE_LENGTH;
use middleware::{ApiError, OptionalAuth, SignedJson}; use middleware::{ApiError, MaybeAuth, ResultExt as _, SignedJson};
use parking_lot::Mutex; use parking_lot::Mutex;
use rusqlite::{named_params, params, Connection, OptionalExtension, Row, ToSql}; use rusqlite::{named_params, params, Connection, OptionalExtension, Row, ToSql};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@ -216,7 +216,7 @@ enum ListRoomFilter {
async fn room_list( async fn room_list(
st: ArcState, st: ArcState,
WithRejection(params, _): WithRejection<Query<ListRoomParams>, ApiError>, WithRejection(params, _): WithRejection<Query<ListRoomParams>, ApiError>,
OptionalAuth(user): OptionalAuth, auth: MaybeAuth,
) -> Result<Json<RoomList>, ApiError> { ) -> Result<Json<RoomList>, ApiError> {
let pagination = Pagination { let pagination = Pagination {
skip_token: params.skip_token, skip_token: params.skip_token,
@ -287,13 +287,7 @@ async fn room_list(
}, },
), ),
ListRoomFilter::Joined => { ListRoomFilter::Joined => {
let Some(user) = user else { let user = auth?.0;
return Err(error_response!(
StatusCode::UNAUTHORIZED,
"unauthorized",
"missing Authorization header for listing joined rooms",
));
};
query( query(
r" r"
SELECT SELECT
@ -438,11 +432,11 @@ async fn room_get_item(
st: ArcState, st: ArcState,
WithRejection(Path(ruuid), _): WithRejection<Path<Uuid>, ApiError>, WithRejection(Path(ruuid), _): WithRejection<Path<Uuid>, ApiError>,
WithRejection(Query(pagination), _): WithRejection<Query<Pagination>, ApiError>, WithRejection(Query(pagination), _): WithRejection<Query<Pagination>, ApiError>,
OptionalAuth(user): OptionalAuth, auth: MaybeAuth,
) -> Result<Json<RoomItems>, ApiError> { ) -> Result<Json<RoomItems>, ApiError> {
let (items, skip_token) = { let (items, skip_token) = {
let conn = st.db.get(); let conn = st.db.get();
get_room_if_readable(&conn, ruuid, user.as_ref(), |_row| Ok(()))?; get_room_if_readable(&conn, ruuid, auth.into_optional()?.as_ref(), |_row| Ok(()))?;
query_room_items(&st, &conn, ruuid, pagination)? query_room_items(&st, &conn, ruuid, pagination)?
}; };
let items = items.into_iter().map(|(_, item)| item).collect(); let items = items.into_iter().map(|(_, item)| item).collect();
@ -455,9 +449,10 @@ async fn room_get_item(
async fn room_get_metadata( async fn room_get_metadata(
st: ArcState, st: ArcState,
WithRejection(Path(ruuid), _): WithRejection<Path<Uuid>, ApiError>, WithRejection(Path(ruuid), _): WithRejection<Path<Uuid>, ApiError>,
OptionalAuth(user): OptionalAuth, auth: MaybeAuth,
) -> Result<Json<RoomMetadata>, ApiError> { ) -> Result<Json<RoomMetadata>, ApiError> {
let (title, attrs) = get_room_if_readable(&st.db.get(), ruuid, user.as_ref(), |row| { let (title, attrs) =
get_room_if_readable(&st.db.get(), ruuid, auth.into_optional()?.as_ref(), |row| {
Ok(( Ok((
row.get::<_, String>("title")?, row.get::<_, String>("title")?,
row.get::<_, RoomAttrs>("attrs")?, row.get::<_, RoomAttrs>("attrs")?,

View file

@ -109,36 +109,79 @@ where
} }
} }
/// Extractor for optional verified JSON authorization header.
#[derive(Debug)] #[derive(Debug)]
pub struct OptionalAuth(pub Option<UserKey>); pub enum AuthRejection {
None,
Invalid(ApiError),
}
impl From<AuthRejection> for ApiError {
fn from(rej: AuthRejection) -> Self {
match rej {
AuthRejection::None => error_response!(
StatusCode::UNAUTHORIZED,
"unauthorized",
"missing authorization header"
),
AuthRejection::Invalid(err) => err,
}
}
}
impl IntoResponse for AuthRejection {
fn into_response(self) -> Response {
ApiError::from(self).into_response()
}
}
pub trait ResultExt {
fn into_optional(self) -> Result<Option<UserKey>, ApiError>;
}
impl ResultExt for Result<Auth, AuthRejection> {
fn into_optional(self) -> Result<Option<UserKey>, ApiError> {
match self {
Ok(auth) => Ok(Some(auth.0)),
Err(AuthRejection::None) => Ok(None),
Err(AuthRejection::Invalid(err)) => Err(err),
}
}
}
pub type MaybeAuth = Result<Auth, AuthRejection>;
/// Extractor for verified JSON authorization header.
#[derive(Debug)]
pub struct Auth(pub UserKey);
#[async_trait] #[async_trait]
impl<S> FromRequestParts<S> for OptionalAuth impl<S> FromRequestParts<S> for Auth
where where
S: Send + Sync, S: Send + Sync,
Arc<AppState>: FromRef<S>, Arc<AppState>: FromRef<S>,
{ {
type Rejection = ApiError; type Rejection = AuthRejection;
async fn from_request_parts( async fn from_request_parts(
parts: &mut request::Parts, parts: &mut request::Parts,
state: &S, state: &S,
) -> Result<Self, Self::Rejection> { ) -> Result<Self, Self::Rejection> {
let Some(auth) = parts.headers.get(header::AUTHORIZATION) else { let auth = parts
return Ok(Self(None)); .headers
}; .get(header::AUTHORIZATION)
.ok_or(AuthRejection::None)?;
let st = <Arc<AppState>>::from_ref(state); let st = <Arc<AppState>>::from_ref(state);
let data = let data =
serde_json::from_slice::<WithSig<AuthPayload>>(auth.as_bytes()).map_err(|err| { serde_json::from_slice::<WithSig<AuthPayload>>(auth.as_bytes()).map_err(|err| {
error_response!( AuthRejection::Invalid(error_response!(
StatusCode::BAD_REQUEST, StatusCode::BAD_REQUEST,
"deserialization", "deserialization",
"invalid authorization header: {err}", "invalid authorization header: {err}",
) ))
})?; })?;
st.verify_signed_data(&data)?; st.verify_signed_data(&data)
Ok(Self(Some(data.signee.user))) .map_err(AuthRejection::Invalid)?;
Ok(Self(data.signee.user))
} }
} }