mirror of
				https://github.com/Blah-IM/blahrs.git
				synced 2025-10-29 18:11:36 +00:00 
			
		
		
		
	Use proper rejection type for Auth middleware
				
					
				
			This commit is contained in:
		
							parent
							
								
									99d1311d63
								
							
						
					
					
						commit
						d3c3961298
					
				
					 2 changed files with 67 additions and 29 deletions
				
			
		|  | @ -18,7 +18,7 @@ use blah::types::{ | |||
| use config::Config; | ||||
| use database::Database; | ||||
| use ed25519_dalek::SIGNATURE_LENGTH; | ||||
| use middleware::{ApiError, OptionalAuth, SignedJson}; | ||||
| use middleware::{ApiError, MaybeAuth, ResultExt as _, SignedJson}; | ||||
| use parking_lot::Mutex; | ||||
| use rusqlite::{named_params, params, Connection, OptionalExtension, Row, ToSql}; | ||||
| use serde::{Deserialize, Serialize}; | ||||
|  | @ -216,7 +216,7 @@ enum ListRoomFilter { | |||
| async fn room_list( | ||||
|     st: ArcState, | ||||
|     WithRejection(params, _): WithRejection<Query<ListRoomParams>, ApiError>, | ||||
|     OptionalAuth(user): OptionalAuth, | ||||
|     auth: MaybeAuth, | ||||
| ) -> Result<Json<RoomList>, ApiError> { | ||||
|     let pagination = Pagination { | ||||
|         skip_token: params.skip_token, | ||||
|  | @ -287,13 +287,7 @@ async fn room_list( | |||
|             }, | ||||
|         ), | ||||
|         ListRoomFilter::Joined => { | ||||
|             let Some(user) = user else { | ||||
|                 return Err(error_response!( | ||||
|                     StatusCode::UNAUTHORIZED, | ||||
|                     "unauthorized", | ||||
|                     "missing Authorization header for listing joined rooms", | ||||
|                 )); | ||||
|             }; | ||||
|             let user = auth?.0; | ||||
|             query( | ||||
|                 r" | ||||
|                 SELECT | ||||
|  | @ -438,11 +432,11 @@ async fn room_get_item( | |||
|     st: ArcState, | ||||
|     WithRejection(Path(ruuid), _): WithRejection<Path<Uuid>, ApiError>, | ||||
|     WithRejection(Query(pagination), _): WithRejection<Query<Pagination>, ApiError>, | ||||
|     OptionalAuth(user): OptionalAuth, | ||||
|     auth: MaybeAuth, | ||||
| ) -> Result<Json<RoomItems>, ApiError> { | ||||
|     let (items, skip_token) = { | ||||
|         let conn = st.db.get(); | ||||
|         get_room_if_readable(&conn, ruuid, user.as_ref(), |_row| Ok(()))?; | ||||
|         get_room_if_readable(&conn, ruuid, auth.into_optional()?.as_ref(), |_row| Ok(()))?; | ||||
|         query_room_items(&st, &conn, ruuid, pagination)? | ||||
|     }; | ||||
|     let items = items.into_iter().map(|(_, item)| item).collect(); | ||||
|  | @ -455,14 +449,15 @@ async fn room_get_item( | |||
| async fn room_get_metadata( | ||||
|     st: ArcState, | ||||
|     WithRejection(Path(ruuid), _): WithRejection<Path<Uuid>, ApiError>, | ||||
|     OptionalAuth(user): OptionalAuth, | ||||
|     auth: MaybeAuth, | ||||
| ) -> Result<Json<RoomMetadata>, ApiError> { | ||||
|     let (title, attrs) = get_room_if_readable(&st.db.get(), ruuid, user.as_ref(), |row| { | ||||
|         Ok(( | ||||
|             row.get::<_, String>("title")?, | ||||
|             row.get::<_, RoomAttrs>("attrs")?, | ||||
|         )) | ||||
|     })?; | ||||
|     let (title, attrs) = | ||||
|         get_room_if_readable(&st.db.get(), ruuid, auth.into_optional()?.as_ref(), |row| { | ||||
|             Ok(( | ||||
|                 row.get::<_, String>("title")?, | ||||
|                 row.get::<_, RoomAttrs>("attrs")?, | ||||
|             )) | ||||
|         })?; | ||||
| 
 | ||||
|     Ok(Json(RoomMetadata { | ||||
|         ruuid, | ||||
|  |  | |||
|  | @ -109,36 +109,79 @@ where | |||
|     } | ||||
| } | ||||
| 
 | ||||
| /// Extractor for optional verified JSON authorization header.
 | ||||
| #[derive(Debug)] | ||||
| pub struct OptionalAuth(pub Option<UserKey>); | ||||
| pub enum AuthRejection { | ||||
|     None, | ||||
|     Invalid(ApiError), | ||||
| } | ||||
| 
 | ||||
| impl From<AuthRejection> for ApiError { | ||||
|     fn from(rej: AuthRejection) -> Self { | ||||
|         match rej { | ||||
|             AuthRejection::None => error_response!( | ||||
|                 StatusCode::UNAUTHORIZED, | ||||
|                 "unauthorized", | ||||
|                 "missing authorization header" | ||||
|             ), | ||||
|             AuthRejection::Invalid(err) => err, | ||||
|         } | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| impl IntoResponse for AuthRejection { | ||||
|     fn into_response(self) -> Response { | ||||
|         ApiError::from(self).into_response() | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| pub trait ResultExt { | ||||
|     fn into_optional(self) -> Result<Option<UserKey>, ApiError>; | ||||
| } | ||||
| 
 | ||||
| impl ResultExt for Result<Auth, AuthRejection> { | ||||
|     fn into_optional(self) -> Result<Option<UserKey>, ApiError> { | ||||
|         match self { | ||||
|             Ok(auth) => Ok(Some(auth.0)), | ||||
|             Err(AuthRejection::None) => Ok(None), | ||||
|             Err(AuthRejection::Invalid(err)) => Err(err), | ||||
|         } | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| pub type MaybeAuth = Result<Auth, AuthRejection>; | ||||
| 
 | ||||
| /// Extractor for verified JSON authorization header.
 | ||||
| #[derive(Debug)] | ||||
| pub struct Auth(pub UserKey); | ||||
| 
 | ||||
| #[async_trait] | ||||
| impl<S> FromRequestParts<S> for OptionalAuth | ||||
| impl<S> FromRequestParts<S> for Auth | ||||
| where | ||||
|     S: Send + Sync, | ||||
|     Arc<AppState>: FromRef<S>, | ||||
| { | ||||
|     type Rejection = ApiError; | ||||
|     type Rejection = AuthRejection; | ||||
| 
 | ||||
|     async fn from_request_parts( | ||||
|         parts: &mut request::Parts, | ||||
|         state: &S, | ||||
|     ) -> Result<Self, Self::Rejection> { | ||||
|         let Some(auth) = parts.headers.get(header::AUTHORIZATION) else { | ||||
|             return Ok(Self(None)); | ||||
|         }; | ||||
|         let auth = parts | ||||
|             .headers | ||||
|             .get(header::AUTHORIZATION) | ||||
|             .ok_or(AuthRejection::None)?; | ||||
| 
 | ||||
|         let st = <Arc<AppState>>::from_ref(state); | ||||
|         let data = | ||||
|             serde_json::from_slice::<WithSig<AuthPayload>>(auth.as_bytes()).map_err(|err| { | ||||
|                 error_response!( | ||||
|                 AuthRejection::Invalid(error_response!( | ||||
|                     StatusCode::BAD_REQUEST, | ||||
|                     "deserialization", | ||||
|                     "invalid authorization header: {err}", | ||||
|                 ) | ||||
|                 )) | ||||
|             })?; | ||||
|         st.verify_signed_data(&data)?; | ||||
|         Ok(Self(Some(data.signee.user))) | ||||
|         st.verify_signed_data(&data) | ||||
|             .map_err(AuthRejection::Invalid)?; | ||||
|         Ok(Self(data.signee.user)) | ||||
|     } | ||||
| } | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue
	
	 oxalica
						oxalica