diff --git a/blahd/src/main.rs b/blahd/src/main.rs index 75079ef..7ae5515 100644 --- a/blahd/src/main.rs +++ b/blahd/src/main.rs @@ -18,7 +18,7 @@ use blah::types::{ use config::Config; use database::Database; use ed25519_dalek::SIGNATURE_LENGTH; -use middleware::{ApiError, OptionalAuth, SignedJson}; +use middleware::{ApiError, MaybeAuth, ResultExt as _, SignedJson}; use parking_lot::Mutex; use rusqlite::{named_params, params, Connection, OptionalExtension, Row, ToSql}; use serde::{Deserialize, Serialize}; @@ -216,7 +216,7 @@ enum ListRoomFilter { async fn room_list( st: ArcState, WithRejection(params, _): WithRejection, ApiError>, - OptionalAuth(user): OptionalAuth, + auth: MaybeAuth, ) -> Result, ApiError> { let pagination = Pagination { skip_token: params.skip_token, @@ -287,13 +287,7 @@ async fn room_list( }, ), ListRoomFilter::Joined => { - let Some(user) = user else { - return Err(error_response!( - StatusCode::UNAUTHORIZED, - "unauthorized", - "missing Authorization header for listing joined rooms", - )); - }; + let user = auth?.0; query( r" SELECT @@ -438,11 +432,11 @@ async fn room_get_item( st: ArcState, WithRejection(Path(ruuid), _): WithRejection, ApiError>, WithRejection(Query(pagination), _): WithRejection, ApiError>, - OptionalAuth(user): OptionalAuth, + auth: MaybeAuth, ) -> Result, ApiError> { let (items, skip_token) = { let conn = st.db.get(); - get_room_if_readable(&conn, ruuid, user.as_ref(), |_row| Ok(()))?; + get_room_if_readable(&conn, ruuid, auth.into_optional()?.as_ref(), |_row| Ok(()))?; query_room_items(&st, &conn, ruuid, pagination)? }; let items = items.into_iter().map(|(_, item)| item).collect(); @@ -455,14 +449,15 @@ async fn room_get_item( async fn room_get_metadata( st: ArcState, WithRejection(Path(ruuid), _): WithRejection, ApiError>, - OptionalAuth(user): OptionalAuth, + auth: MaybeAuth, ) -> Result, ApiError> { - let (title, attrs) = get_room_if_readable(&st.db.get(), ruuid, user.as_ref(), |row| { - Ok(( - row.get::<_, String>("title")?, - row.get::<_, RoomAttrs>("attrs")?, - )) - })?; + let (title, attrs) = + get_room_if_readable(&st.db.get(), ruuid, auth.into_optional()?.as_ref(), |row| { + Ok(( + row.get::<_, String>("title")?, + row.get::<_, RoomAttrs>("attrs")?, + )) + })?; Ok(Json(RoomMetadata { ruuid, diff --git a/blahd/src/middleware.rs b/blahd/src/middleware.rs index d9a49ff..325ef34 100644 --- a/blahd/src/middleware.rs +++ b/blahd/src/middleware.rs @@ -109,36 +109,79 @@ where } } -/// Extractor for optional verified JSON authorization header. #[derive(Debug)] -pub struct OptionalAuth(pub Option); +pub enum AuthRejection { + None, + Invalid(ApiError), +} + +impl From for ApiError { + fn from(rej: AuthRejection) -> Self { + match rej { + AuthRejection::None => error_response!( + StatusCode::UNAUTHORIZED, + "unauthorized", + "missing authorization header" + ), + AuthRejection::Invalid(err) => err, + } + } +} + +impl IntoResponse for AuthRejection { + fn into_response(self) -> Response { + ApiError::from(self).into_response() + } +} + +pub trait ResultExt { + fn into_optional(self) -> Result, ApiError>; +} + +impl ResultExt for Result { + fn into_optional(self) -> Result, ApiError> { + match self { + Ok(auth) => Ok(Some(auth.0)), + Err(AuthRejection::None) => Ok(None), + Err(AuthRejection::Invalid(err)) => Err(err), + } + } +} + +pub type MaybeAuth = Result; + +/// Extractor for verified JSON authorization header. +#[derive(Debug)] +pub struct Auth(pub UserKey); #[async_trait] -impl FromRequestParts for OptionalAuth +impl FromRequestParts for Auth where S: Send + Sync, Arc: FromRef, { - type Rejection = ApiError; + type Rejection = AuthRejection; async fn from_request_parts( parts: &mut request::Parts, state: &S, ) -> Result { - let Some(auth) = parts.headers.get(header::AUTHORIZATION) else { - return Ok(Self(None)); - }; + let auth = parts + .headers + .get(header::AUTHORIZATION) + .ok_or(AuthRejection::None)?; let st = >::from_ref(state); let data = serde_json::from_slice::>(auth.as_bytes()).map_err(|err| { - error_response!( + AuthRejection::Invalid(error_response!( StatusCode::BAD_REQUEST, "deserialization", "invalid authorization header: {err}", - ) + )) })?; - st.verify_signed_data(&data)?; - Ok(Self(Some(data.signee.user))) + st.verify_signed_data(&data) + .map_err(AuthRejection::Invalid)?; + Ok(Self(data.signee.user)) } }