mirror of
https://github.com/Blah-IM/blahrs.git
synced 2025-05-01 08:41:09 +00:00
Use proper rejection type for Auth
middleware
This commit is contained in:
parent
99d1311d63
commit
d3c3961298
2 changed files with 67 additions and 29 deletions
|
@ -18,7 +18,7 @@ use blah::types::{
|
||||||
use config::Config;
|
use config::Config;
|
||||||
use database::Database;
|
use database::Database;
|
||||||
use ed25519_dalek::SIGNATURE_LENGTH;
|
use ed25519_dalek::SIGNATURE_LENGTH;
|
||||||
use middleware::{ApiError, OptionalAuth, SignedJson};
|
use middleware::{ApiError, MaybeAuth, ResultExt as _, SignedJson};
|
||||||
use parking_lot::Mutex;
|
use parking_lot::Mutex;
|
||||||
use rusqlite::{named_params, params, Connection, OptionalExtension, Row, ToSql};
|
use rusqlite::{named_params, params, Connection, OptionalExtension, Row, ToSql};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
@ -216,7 +216,7 @@ enum ListRoomFilter {
|
||||||
async fn room_list(
|
async fn room_list(
|
||||||
st: ArcState,
|
st: ArcState,
|
||||||
WithRejection(params, _): WithRejection<Query<ListRoomParams>, ApiError>,
|
WithRejection(params, _): WithRejection<Query<ListRoomParams>, ApiError>,
|
||||||
OptionalAuth(user): OptionalAuth,
|
auth: MaybeAuth,
|
||||||
) -> Result<Json<RoomList>, ApiError> {
|
) -> Result<Json<RoomList>, ApiError> {
|
||||||
let pagination = Pagination {
|
let pagination = Pagination {
|
||||||
skip_token: params.skip_token,
|
skip_token: params.skip_token,
|
||||||
|
@ -287,13 +287,7 @@ async fn room_list(
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
ListRoomFilter::Joined => {
|
ListRoomFilter::Joined => {
|
||||||
let Some(user) = user else {
|
let user = auth?.0;
|
||||||
return Err(error_response!(
|
|
||||||
StatusCode::UNAUTHORIZED,
|
|
||||||
"unauthorized",
|
|
||||||
"missing Authorization header for listing joined rooms",
|
|
||||||
));
|
|
||||||
};
|
|
||||||
query(
|
query(
|
||||||
r"
|
r"
|
||||||
SELECT
|
SELECT
|
||||||
|
@ -438,11 +432,11 @@ async fn room_get_item(
|
||||||
st: ArcState,
|
st: ArcState,
|
||||||
WithRejection(Path(ruuid), _): WithRejection<Path<Uuid>, ApiError>,
|
WithRejection(Path(ruuid), _): WithRejection<Path<Uuid>, ApiError>,
|
||||||
WithRejection(Query(pagination), _): WithRejection<Query<Pagination>, ApiError>,
|
WithRejection(Query(pagination), _): WithRejection<Query<Pagination>, ApiError>,
|
||||||
OptionalAuth(user): OptionalAuth,
|
auth: MaybeAuth,
|
||||||
) -> Result<Json<RoomItems>, ApiError> {
|
) -> Result<Json<RoomItems>, ApiError> {
|
||||||
let (items, skip_token) = {
|
let (items, skip_token) = {
|
||||||
let conn = st.db.get();
|
let conn = st.db.get();
|
||||||
get_room_if_readable(&conn, ruuid, user.as_ref(), |_row| Ok(()))?;
|
get_room_if_readable(&conn, ruuid, auth.into_optional()?.as_ref(), |_row| Ok(()))?;
|
||||||
query_room_items(&st, &conn, ruuid, pagination)?
|
query_room_items(&st, &conn, ruuid, pagination)?
|
||||||
};
|
};
|
||||||
let items = items.into_iter().map(|(_, item)| item).collect();
|
let items = items.into_iter().map(|(_, item)| item).collect();
|
||||||
|
@ -455,14 +449,15 @@ async fn room_get_item(
|
||||||
async fn room_get_metadata(
|
async fn room_get_metadata(
|
||||||
st: ArcState,
|
st: ArcState,
|
||||||
WithRejection(Path(ruuid), _): WithRejection<Path<Uuid>, ApiError>,
|
WithRejection(Path(ruuid), _): WithRejection<Path<Uuid>, ApiError>,
|
||||||
OptionalAuth(user): OptionalAuth,
|
auth: MaybeAuth,
|
||||||
) -> Result<Json<RoomMetadata>, ApiError> {
|
) -> Result<Json<RoomMetadata>, ApiError> {
|
||||||
let (title, attrs) = get_room_if_readable(&st.db.get(), ruuid, user.as_ref(), |row| {
|
let (title, attrs) =
|
||||||
Ok((
|
get_room_if_readable(&st.db.get(), ruuid, auth.into_optional()?.as_ref(), |row| {
|
||||||
row.get::<_, String>("title")?,
|
Ok((
|
||||||
row.get::<_, RoomAttrs>("attrs")?,
|
row.get::<_, String>("title")?,
|
||||||
))
|
row.get::<_, RoomAttrs>("attrs")?,
|
||||||
})?;
|
))
|
||||||
|
})?;
|
||||||
|
|
||||||
Ok(Json(RoomMetadata {
|
Ok(Json(RoomMetadata {
|
||||||
ruuid,
|
ruuid,
|
||||||
|
|
|
@ -109,36 +109,79 @@ where
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Extractor for optional verified JSON authorization header.
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct OptionalAuth(pub Option<UserKey>);
|
pub enum AuthRejection {
|
||||||
|
None,
|
||||||
|
Invalid(ApiError),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<AuthRejection> for ApiError {
|
||||||
|
fn from(rej: AuthRejection) -> Self {
|
||||||
|
match rej {
|
||||||
|
AuthRejection::None => error_response!(
|
||||||
|
StatusCode::UNAUTHORIZED,
|
||||||
|
"unauthorized",
|
||||||
|
"missing authorization header"
|
||||||
|
),
|
||||||
|
AuthRejection::Invalid(err) => err,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl IntoResponse for AuthRejection {
|
||||||
|
fn into_response(self) -> Response {
|
||||||
|
ApiError::from(self).into_response()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub trait ResultExt {
|
||||||
|
fn into_optional(self) -> Result<Option<UserKey>, ApiError>;
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ResultExt for Result<Auth, AuthRejection> {
|
||||||
|
fn into_optional(self) -> Result<Option<UserKey>, ApiError> {
|
||||||
|
match self {
|
||||||
|
Ok(auth) => Ok(Some(auth.0)),
|
||||||
|
Err(AuthRejection::None) => Ok(None),
|
||||||
|
Err(AuthRejection::Invalid(err)) => Err(err),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub type MaybeAuth = Result<Auth, AuthRejection>;
|
||||||
|
|
||||||
|
/// Extractor for verified JSON authorization header.
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct Auth(pub UserKey);
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl<S> FromRequestParts<S> for OptionalAuth
|
impl<S> FromRequestParts<S> for Auth
|
||||||
where
|
where
|
||||||
S: Send + Sync,
|
S: Send + Sync,
|
||||||
Arc<AppState>: FromRef<S>,
|
Arc<AppState>: FromRef<S>,
|
||||||
{
|
{
|
||||||
type Rejection = ApiError;
|
type Rejection = AuthRejection;
|
||||||
|
|
||||||
async fn from_request_parts(
|
async fn from_request_parts(
|
||||||
parts: &mut request::Parts,
|
parts: &mut request::Parts,
|
||||||
state: &S,
|
state: &S,
|
||||||
) -> Result<Self, Self::Rejection> {
|
) -> Result<Self, Self::Rejection> {
|
||||||
let Some(auth) = parts.headers.get(header::AUTHORIZATION) else {
|
let auth = parts
|
||||||
return Ok(Self(None));
|
.headers
|
||||||
};
|
.get(header::AUTHORIZATION)
|
||||||
|
.ok_or(AuthRejection::None)?;
|
||||||
|
|
||||||
let st = <Arc<AppState>>::from_ref(state);
|
let st = <Arc<AppState>>::from_ref(state);
|
||||||
let data =
|
let data =
|
||||||
serde_json::from_slice::<WithSig<AuthPayload>>(auth.as_bytes()).map_err(|err| {
|
serde_json::from_slice::<WithSig<AuthPayload>>(auth.as_bytes()).map_err(|err| {
|
||||||
error_response!(
|
AuthRejection::Invalid(error_response!(
|
||||||
StatusCode::BAD_REQUEST,
|
StatusCode::BAD_REQUEST,
|
||||||
"deserialization",
|
"deserialization",
|
||||||
"invalid authorization header: {err}",
|
"invalid authorization header: {err}",
|
||||||
)
|
))
|
||||||
})?;
|
})?;
|
||||||
st.verify_signed_data(&data)?;
|
st.verify_signed_data(&data)
|
||||||
Ok(Self(Some(data.signee.user)))
|
.map_err(AuthRejection::Invalid)?;
|
||||||
|
Ok(Self(data.signee.user))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Reference in a new issue