mirror of
https://github.com/Blah-IM/blahrs.git
synced 2025-05-01 00:31:09 +00:00
Use proper rejection type for Auth
middleware
This commit is contained in:
parent
99d1311d63
commit
d3c3961298
2 changed files with 67 additions and 29 deletions
|
@ -18,7 +18,7 @@ use blah::types::{
|
|||
use config::Config;
|
||||
use database::Database;
|
||||
use ed25519_dalek::SIGNATURE_LENGTH;
|
||||
use middleware::{ApiError, OptionalAuth, SignedJson};
|
||||
use middleware::{ApiError, MaybeAuth, ResultExt as _, SignedJson};
|
||||
use parking_lot::Mutex;
|
||||
use rusqlite::{named_params, params, Connection, OptionalExtension, Row, ToSql};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
@ -216,7 +216,7 @@ enum ListRoomFilter {
|
|||
async fn room_list(
|
||||
st: ArcState,
|
||||
WithRejection(params, _): WithRejection<Query<ListRoomParams>, ApiError>,
|
||||
OptionalAuth(user): OptionalAuth,
|
||||
auth: MaybeAuth,
|
||||
) -> Result<Json<RoomList>, ApiError> {
|
||||
let pagination = Pagination {
|
||||
skip_token: params.skip_token,
|
||||
|
@ -287,13 +287,7 @@ async fn room_list(
|
|||
},
|
||||
),
|
||||
ListRoomFilter::Joined => {
|
||||
let Some(user) = user else {
|
||||
return Err(error_response!(
|
||||
StatusCode::UNAUTHORIZED,
|
||||
"unauthorized",
|
||||
"missing Authorization header for listing joined rooms",
|
||||
));
|
||||
};
|
||||
let user = auth?.0;
|
||||
query(
|
||||
r"
|
||||
SELECT
|
||||
|
@ -438,11 +432,11 @@ async fn room_get_item(
|
|||
st: ArcState,
|
||||
WithRejection(Path(ruuid), _): WithRejection<Path<Uuid>, ApiError>,
|
||||
WithRejection(Query(pagination), _): WithRejection<Query<Pagination>, ApiError>,
|
||||
OptionalAuth(user): OptionalAuth,
|
||||
auth: MaybeAuth,
|
||||
) -> Result<Json<RoomItems>, ApiError> {
|
||||
let (items, skip_token) = {
|
||||
let conn = st.db.get();
|
||||
get_room_if_readable(&conn, ruuid, user.as_ref(), |_row| Ok(()))?;
|
||||
get_room_if_readable(&conn, ruuid, auth.into_optional()?.as_ref(), |_row| Ok(()))?;
|
||||
query_room_items(&st, &conn, ruuid, pagination)?
|
||||
};
|
||||
let items = items.into_iter().map(|(_, item)| item).collect();
|
||||
|
@ -455,9 +449,10 @@ async fn room_get_item(
|
|||
async fn room_get_metadata(
|
||||
st: ArcState,
|
||||
WithRejection(Path(ruuid), _): WithRejection<Path<Uuid>, ApiError>,
|
||||
OptionalAuth(user): OptionalAuth,
|
||||
auth: MaybeAuth,
|
||||
) -> Result<Json<RoomMetadata>, ApiError> {
|
||||
let (title, attrs) = get_room_if_readable(&st.db.get(), ruuid, user.as_ref(), |row| {
|
||||
let (title, attrs) =
|
||||
get_room_if_readable(&st.db.get(), ruuid, auth.into_optional()?.as_ref(), |row| {
|
||||
Ok((
|
||||
row.get::<_, String>("title")?,
|
||||
row.get::<_, RoomAttrs>("attrs")?,
|
||||
|
|
|
@ -109,36 +109,79 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
/// Extractor for optional verified JSON authorization header.
|
||||
#[derive(Debug)]
|
||||
pub struct OptionalAuth(pub Option<UserKey>);
|
||||
pub enum AuthRejection {
|
||||
None,
|
||||
Invalid(ApiError),
|
||||
}
|
||||
|
||||
impl From<AuthRejection> for ApiError {
|
||||
fn from(rej: AuthRejection) -> Self {
|
||||
match rej {
|
||||
AuthRejection::None => error_response!(
|
||||
StatusCode::UNAUTHORIZED,
|
||||
"unauthorized",
|
||||
"missing authorization header"
|
||||
),
|
||||
AuthRejection::Invalid(err) => err,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoResponse for AuthRejection {
|
||||
fn into_response(self) -> Response {
|
||||
ApiError::from(self).into_response()
|
||||
}
|
||||
}
|
||||
|
||||
pub trait ResultExt {
|
||||
fn into_optional(self) -> Result<Option<UserKey>, ApiError>;
|
||||
}
|
||||
|
||||
impl ResultExt for Result<Auth, AuthRejection> {
|
||||
fn into_optional(self) -> Result<Option<UserKey>, ApiError> {
|
||||
match self {
|
||||
Ok(auth) => Ok(Some(auth.0)),
|
||||
Err(AuthRejection::None) => Ok(None),
|
||||
Err(AuthRejection::Invalid(err)) => Err(err),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub type MaybeAuth = Result<Auth, AuthRejection>;
|
||||
|
||||
/// Extractor for verified JSON authorization header.
|
||||
#[derive(Debug)]
|
||||
pub struct Auth(pub UserKey);
|
||||
|
||||
#[async_trait]
|
||||
impl<S> FromRequestParts<S> for OptionalAuth
|
||||
impl<S> FromRequestParts<S> for Auth
|
||||
where
|
||||
S: Send + Sync,
|
||||
Arc<AppState>: FromRef<S>,
|
||||
{
|
||||
type Rejection = ApiError;
|
||||
type Rejection = AuthRejection;
|
||||
|
||||
async fn from_request_parts(
|
||||
parts: &mut request::Parts,
|
||||
state: &S,
|
||||
) -> Result<Self, Self::Rejection> {
|
||||
let Some(auth) = parts.headers.get(header::AUTHORIZATION) else {
|
||||
return Ok(Self(None));
|
||||
};
|
||||
let auth = parts
|
||||
.headers
|
||||
.get(header::AUTHORIZATION)
|
||||
.ok_or(AuthRejection::None)?;
|
||||
|
||||
let st = <Arc<AppState>>::from_ref(state);
|
||||
let data =
|
||||
serde_json::from_slice::<WithSig<AuthPayload>>(auth.as_bytes()).map_err(|err| {
|
||||
error_response!(
|
||||
AuthRejection::Invalid(error_response!(
|
||||
StatusCode::BAD_REQUEST,
|
||||
"deserialization",
|
||||
"invalid authorization header: {err}",
|
||||
)
|
||||
))
|
||||
})?;
|
||||
st.verify_signed_data(&data)?;
|
||||
Ok(Self(Some(data.signee.user)))
|
||||
st.verify_signed_data(&data)
|
||||
.map_err(AuthRejection::Invalid)?;
|
||||
Ok(Self(data.signee.user))
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue