diff --git a/blahd/src/event.rs b/blahd/src/event.rs index c7e3cd4..db38adb 100644 --- a/blahd/src/event.rs +++ b/blahd/src/event.rs @@ -8,7 +8,9 @@ use std::task::{Context, Poll}; use std::time::Duration; use anyhow::{bail, Context as _, Result}; -use axum::extract::ws::{Message, WebSocket}; +use axum::extract::ws::{close_code, CloseFrame, Message, WebSocket}; +use axum::extract::WebSocketUpgrade; +use axum::response::Response; use blah_types::msg::{AuthPayload, SignedChatMsg}; use blah_types::server::ClientEvent; use blah_types::Signed; @@ -23,7 +25,7 @@ use tokio_stream::wrappers::errors::BroadcastStreamRecvError; use tokio_stream::wrappers::BroadcastStream; use crate::database::TransactionOps; -use crate::AppState; +use crate::{AppState, ArcState}; // We a borrowed type rather than an owned type. // So redefine it. Not sure if there is a better way. @@ -127,7 +129,30 @@ impl Stream for UserEventReceiver { } } -pub async fn handle_ws(st: Arc, ws: &mut WebSocket) -> Result { +// TODO: Authenticate via HTTP query? +pub async fn get_ws(st: ArcState, ws: WebSocketUpgrade) -> Response { + ws.on_upgrade(move |mut socket| async move { + match handle_ws(st.0, &mut socket).await { + #[allow( + unreachable_patterns, + reason = "compatibility before min_exhaustive_patterns" + )] + Ok(never) => match never {}, + Err(err) if err.is::() => {} + Err(err) => { + tracing::debug!(%err, "ws error"); + let _: Result<_, _> = socket + .send(Message::Close(Some(CloseFrame { + code: close_code::ERROR, + reason: err.to_string().into(), + }))) + .await; + } + } + }) +} + +async fn handle_ws(st: Arc, ws: &mut WebSocket) -> Result { let config = &st.config.ws; let (ws_tx, ws_rx) = ws.split(); let mut ws_rx = ws_rx.map(|ret| match ret { diff --git a/blahd/src/feed.rs b/blahd/src/feed.rs index d359ef8..ff91471 100644 --- a/blahd/src/feed.rs +++ b/blahd/src/feed.rs @@ -3,15 +3,19 @@ use std::fmt; use std::num::NonZero; use std::time::Duration; -use axum::http::header; +use axum::extract::{OriginalUri, Path, Query}; +use axum::http::{header, StatusCode}; use axum::response::{IntoResponse, Response}; use axum::Json; -use blah_types::msg::{SignedChatMsgWithId, WithMsgId}; +use blah_types::msg::{RoomAttrs, SignedChatMsgWithId, WithMsgId}; use blah_types::Id; use serde::{Deserialize, Serialize}; use url::Url; +use crate::database::TransactionOps; use crate::id::timestamp_of_id; +use crate::middleware::ETag; +use crate::{query_room_msgs, ApiError, ArcState, Pagination, HEADER_PUBLIC_NO_CACHE}; const JSON_FEED_MIME: &str = "application/feed+json"; const ATOM_FEED_MIME: &str = "application/atom+xml"; @@ -214,3 +218,65 @@ impl FeedType for AtomFeed { ([(header::CONTENT_TYPE, ATOM_FEED_MIME)], body).into_response() } } + +pub async fn get_room_feed( + st: ArcState, + ETag(etag): ETag, + OriginalUri(req_uri): OriginalUri, + Path(rid): Path, + Query(mut pagination): Query, +) -> Result { + let self_url = st + .config + .base_url + .join(req_uri.path()) + .expect("base_url can be a base"); + + pagination.top = Some( + pagination + .effective_page_len(&st) + .min(st.config.feed.max_page_len), + ); + + let (title, msgs, skip_token) = st.db.with_read(|txn| { + let (attrs, title) = txn.get_room_having(rid, RoomAttrs::PUBLIC_READABLE)?; + // Sanity check. + assert!(!attrs.contains(RoomAttrs::PEER_CHAT)); + let title = title.expect("public room must have title"); + let (msgs, skip_token) = query_room_msgs(&st, txn, rid, pagination)?; + Ok((title, msgs, skip_token)) + })?; + + // Use `Id(0)` as the tag for an empty list. + let ret_etag = msgs.first().map_or(Id(0), |msg| msg.cid); + if etag == Some(ret_etag) { + return Ok(StatusCode::NOT_MODIFIED.into_response()); + } + + let next_url = skip_token.map(|skip_token| { + let next_params = Pagination { + skip_token: Some(skip_token), + top: pagination.top, + until_token: None, + }; + let mut next_url = self_url.clone(); + { + let mut query = next_url.query_pairs_mut(); + let ser = serde_urlencoded::Serializer::new(&mut query); + next_params + .serialize(ser) + .expect("serialization cannot fail"); + query.finish(); + } + next_url + }); + + let resp = FT::to_feed_response(FeedData { + rid, + title, + msgs, + self_url, + next_url, + }); + Ok(([HEADER_PUBLIC_NO_CACHE], ETag(Some(ret_etag)), resp).into_response()) +} diff --git a/blahd/src/lib.rs b/blahd/src/lib.rs index b7f2556..01a6154 100644 --- a/blahd/src/lib.rs +++ b/blahd/src/lib.rs @@ -4,8 +4,7 @@ use std::time::Duration; use anyhow::Result; use axum::body::Bytes; -use axum::extract::{ws, OriginalUri}; -use axum::extract::{Path, Query, State, WebSocketUpgrade}; +use axum::extract::{Path, Query, State}; use axum::http::{header, HeaderName, HeaderValue, StatusCode}; use axum::response::{IntoResponse, Response}; use axum::routing::{get, post}; @@ -14,7 +13,7 @@ use axum_extra::extract::WithRejection as R; use blah_types::msg::{ ChatPayload, CreateGroup, CreatePeerChat, CreateRoomPayload, DeleteRoomPayload, MemberPermission, RoomAdminOp, RoomAdminPayload, RoomAttrs, ServerPermission, - SignedChatMsgWithId, UserRegisterPayload, + SignedChatMsgWithId, }; use blah_types::server::{ ErrorResponseWithChallenge, RoomList, RoomMember, RoomMemberList, RoomMetadata, RoomMsgs, @@ -23,7 +22,6 @@ use blah_types::server::{ use blah_types::{get_timestamp, Id, Signed, UserKey}; use data_encoding::BASE64_NOPAD; use database::{Transaction, TransactionOps}; -use feed::FeedData; use id::IdExt; use middleware::{Auth, ETag, MaybeAuth, ResultExt as _, SignedJson}; use parking_lot::Mutex; @@ -150,19 +148,27 @@ impl AppState { type ArcState = State>; pub fn router(st: Arc) -> Router { + // NB. User consistent handler naming: `_[_
]`. + // Use prefix `list` for GET with pagination. + // + // One route per line. + #[rustfmt::skip] let router = Router::new() - .route("/server", get(handle_server_metadata)) - .route("/ws", get(handle_ws)) - .route("/user/me", get(user_get).post(user_register)) - .route("/room", get(room_list)) - .route("/room/create", post(room_create)) - .route("/room/:rid", get(room_get_metadata).delete(room_delete)) - .route("/room/:rid/feed.json", get(room_get_feed::)) - .route("/room/:rid/feed.atom", get(room_get_feed::)) - .route("/room/:rid/msg", get(room_msg_list).post(room_msg_post)) - .route("/room/:rid/msg/:cid/seen", post(room_msg_mark_seen)) - .route("/room/:rid/admin", post(room_admin)) - .route("/room/:rid/member", get(room_member_list)) + .route("/server", get(get_server_metadata)) + .route("/ws", get(event::get_ws)) + .route("/user/me", get(get_user).post(register::post_user)) + .route("/room", get(list_room)) + // TODO: Maybe just POST on `/room`? + .route("/room/create", post(post_room_create)) + .route("/room/:rid", get(get_room).delete(delete_room)) + .route("/room/:rid/feed.json", get(feed::get_room_feed::)) + .route("/room/:rid/feed.atom", get(feed::get_room_feed::)) + .route("/room/:rid/msg", get(list_room_msg).post(post_room_msg)) + .route("/room/:rid/msg/:cid/seen", post(post_room_msg_seen)) + .route("/room/:rid/admin", post(post_room_admin)) + .route("/room/:rid/member", get(list_room_member)); + + let router = router .layer(tower_http::limit::RequestBodyLimitLayer::new( st.config.max_request_len, )) @@ -184,7 +190,7 @@ pub fn router(st: Arc) -> Router { type RE = R; -async fn handle_server_metadata(State(st): ArcState) -> Response { +async fn get_server_metadata(State(st): ArcState) -> Response { let (json, etag) = st.server_metadata.clone(); let headers = [ ( @@ -196,29 +202,7 @@ async fn handle_server_metadata(State(st): ArcState) -> Response { (headers, etag, json).into_response() } -async fn handle_ws(State(st): ArcState, ws: WebSocketUpgrade) -> Response { - ws.on_upgrade(move |mut socket| async move { - match event::handle_ws(st, &mut socket).await { - #[allow( - unreachable_patterns, - reason = "compatibility before min_exhaustive_patterns" - )] - Ok(never) => match never {}, - Err(err) if err.is::() => {} - Err(err) => { - tracing::debug!(%err, "ws error"); - let _: Result<_, _> = socket - .send(ws::Message::Close(Some(ws::CloseFrame { - code: ws::close_code::ERROR, - reason: err.to_string().into(), - }))) - .await; - } - } - }) -} - -async fn user_get(State(st): ArcState, auth: MaybeAuth) -> Response { +async fn get_user(State(st): ArcState, auth: MaybeAuth) -> Response { let ret = (|| { match auth.into_optional()? { None => None, @@ -243,13 +227,6 @@ async fn user_get(State(st): ArcState, auth: MaybeAuth) -> Response { } } -async fn user_register( - State(st): ArcState, - SignedJson(msg): SignedJson, -) -> Result { - register::user_register(&st, msg).await -} - #[derive(Debug, Deserialize)] #[serde(deny_unknown_fields, rename_all = "camelCase")] struct ListRoomParams { @@ -272,7 +249,7 @@ enum ListRoomFilter { Unseen, } -async fn room_list( +async fn list_room( st: ArcState, params: RE>, auth: MaybeAuth, @@ -302,7 +279,7 @@ async fn room_list( Ok(Json(RoomList { rooms, skip_token })) } -async fn room_create( +async fn post_room_create( st: ArcState, SignedJson(params): SignedJson, ) -> Result, ApiError> { @@ -390,7 +367,7 @@ impl Pagination { } } -async fn room_msg_list( +async fn list_room_msg( st: ArcState, R(Path(rid), _): RE>, R(Query(pagination), _): RE>, @@ -407,7 +384,7 @@ async fn room_msg_list( Ok(Json(RoomMsgs { msgs, skip_token })) } -async fn room_get_metadata( +async fn get_room( st: ArcState, R(Path(rid), _): RE>, auth: MaybeAuth, @@ -438,68 +415,6 @@ async fn room_get_metadata( })) } -async fn room_get_feed( - st: ArcState, - ETag(etag): ETag, - R(OriginalUri(req_uri), _): RE, - R(Path(rid), _): RE>, - R(Query(mut pagination), _): RE>, -) -> Result { - let self_url = st - .config - .base_url - .join(req_uri.path()) - .expect("base_url can be a base"); - - pagination.top = Some( - pagination - .effective_page_len(&st) - .min(st.config.feed.max_page_len), - ); - - let (title, msgs, skip_token) = st.db.with_read(|txn| { - let (attrs, title) = txn.get_room_having(rid, RoomAttrs::PUBLIC_READABLE)?; - // Sanity check. - assert!(!attrs.contains(RoomAttrs::PEER_CHAT)); - let title = title.expect("public room must have title"); - let (msgs, skip_token) = query_room_msgs(&st, txn, rid, pagination)?; - Ok((title, msgs, skip_token)) - })?; - - // Use `Id(0)` as the tag for an empty list. - let ret_etag = msgs.first().map_or(Id(0), |msg| msg.cid); - if etag == Some(ret_etag) { - return Ok(StatusCode::NOT_MODIFIED.into_response()); - } - - let next_url = skip_token.map(|skip_token| { - let next_params = Pagination { - skip_token: Some(skip_token), - top: pagination.top, - until_token: None, - }; - let mut next_url = self_url.clone(); - { - let mut query = next_url.query_pairs_mut(); - let ser = serde_urlencoded::Serializer::new(&mut query); - next_params - .serialize(ser) - .expect("serialization cannot fail"); - query.finish(); - } - next_url - }); - - let resp = FT::to_feed_response(FeedData { - rid, - title, - msgs, - self_url, - next_url, - }); - Ok(([HEADER_PUBLIC_NO_CACHE], ETag(Some(ret_etag)), resp).into_response()) -} - /// Get room messages with pagination parameters, /// return a page of messages and the next `skip_token` if this is not the last page. fn query_room_msgs( @@ -520,7 +435,7 @@ fn query_room_msgs( Ok((msgs, skip_token)) } -async fn room_msg_post( +async fn post_room_msg( st: ArcState, R(Path(rid), _): RE>, SignedJson(chat): SignedJson, @@ -562,7 +477,7 @@ async fn room_msg_post( Ok(Json(cid)) } -async fn room_admin( +async fn post_room_admin( st: ArcState, R(Path(rid), _): RE>, SignedJson(op): SignedJson, @@ -622,7 +537,7 @@ async fn room_leave(st: &AppState, rid: Id, user: &UserKey) -> Result<(), ApiErr }) } -async fn room_delete( +async fn delete_room( st: ArcState, R(Path(rid), _): RE>, SignedJson(op): SignedJson, @@ -640,7 +555,7 @@ async fn room_delete( }) } -async fn room_msg_mark_seen( +async fn post_room_msg_seen( st: ArcState, R(Path((rid, cid)), _): RE>, Auth(user): Auth, @@ -655,7 +570,7 @@ async fn room_msg_mark_seen( Ok(StatusCode::NO_CONTENT) } -async fn room_member_list( +async fn list_room_member( st: ArcState, R(Path(rid), _): RE>, R(Query(pagination), _): RE>, diff --git a/blahd/src/register.rs b/blahd/src/register.rs index 915440a..83284af 100644 --- a/blahd/src/register.rs +++ b/blahd/src/register.rs @@ -3,10 +3,10 @@ use std::time::Duration; use anyhow::{anyhow, ensure}; use axum::http::StatusCode; +use blah_types::get_timestamp; use blah_types::identity::{IdUrl, UserIdentityDesc}; use blah_types::msg::{UserRegisterChallengeResponse, UserRegisterPayload}; use blah_types::server::UserRegisterChallenge; -use blah_types::{get_timestamp, Signed}; use http_body_util::BodyExt; use parking_lot::Mutex; use rand::rngs::OsRng; @@ -15,8 +15,9 @@ use serde::Deserialize; use sha2::{Digest, Sha256}; use crate::database::TransactionOps; +use crate::middleware::SignedJson; use crate::utils::Instant; -use crate::{ApiError, AppState, SERVER_AND_VERSION}; +use crate::{ApiError, ArcState, SERVER_AND_VERSION}; #[derive(Debug, Clone, PartialEq, Eq, Deserialize)] #[serde(default, deny_unknown_fields)] @@ -159,9 +160,9 @@ impl State { } } -pub async fn user_register( - st: &AppState, - msg: Signed, +pub async fn post_user( + axum::extract::State(st): ArcState, + SignedJson(msg): SignedJson, ) -> Result { if !st.config.register.enable_public { return Err(ApiError::Disabled("public registration is disabled"));