refactor(blahd): reorg and use consistent handler names

This commit is contained in:
oxalica 2024-10-04 23:10:06 -04:00
parent 4e8124cda6
commit 97c0cf5844
4 changed files with 135 additions and 128 deletions

View file

@ -8,7 +8,9 @@ use std::task::{Context, Poll};
use std::time::Duration;
use anyhow::{bail, Context as _, Result};
use axum::extract::ws::{Message, WebSocket};
use axum::extract::ws::{close_code, CloseFrame, Message, WebSocket};
use axum::extract::WebSocketUpgrade;
use axum::response::Response;
use blah_types::msg::{AuthPayload, SignedChatMsg};
use blah_types::server::ClientEvent;
use blah_types::Signed;
@ -23,7 +25,7 @@ use tokio_stream::wrappers::errors::BroadcastStreamRecvError;
use tokio_stream::wrappers::BroadcastStream;
use crate::database::TransactionOps;
use crate::AppState;
use crate::{AppState, ArcState};
// We a borrowed type rather than an owned type.
// So redefine it. Not sure if there is a better way.
@ -127,7 +129,30 @@ impl Stream for UserEventReceiver {
}
}
pub async fn handle_ws(st: Arc<AppState>, ws: &mut WebSocket) -> Result<Infallible> {
// TODO: Authenticate via HTTP query?
pub async fn get_ws(st: ArcState, ws: WebSocketUpgrade) -> Response {
ws.on_upgrade(move |mut socket| async move {
match handle_ws(st.0, &mut socket).await {
#[allow(
unreachable_patterns,
reason = "compatibility before min_exhaustive_patterns"
)]
Ok(never) => match never {},
Err(err) if err.is::<StreamEnded>() => {}
Err(err) => {
tracing::debug!(%err, "ws error");
let _: Result<_, _> = socket
.send(Message::Close(Some(CloseFrame {
code: close_code::ERROR,
reason: err.to_string().into(),
})))
.await;
}
}
})
}
async fn handle_ws(st: Arc<AppState>, ws: &mut WebSocket) -> Result<Infallible> {
let config = &st.config.ws;
let (ws_tx, ws_rx) = ws.split();
let mut ws_rx = ws_rx.map(|ret| match ret {

View file

@ -3,15 +3,19 @@ use std::fmt;
use std::num::NonZero;
use std::time::Duration;
use axum::http::header;
use axum::extract::{OriginalUri, Path, Query};
use axum::http::{header, StatusCode};
use axum::response::{IntoResponse, Response};
use axum::Json;
use blah_types::msg::{SignedChatMsgWithId, WithMsgId};
use blah_types::msg::{RoomAttrs, SignedChatMsgWithId, WithMsgId};
use blah_types::Id;
use serde::{Deserialize, Serialize};
use url::Url;
use crate::database::TransactionOps;
use crate::id::timestamp_of_id;
use crate::middleware::ETag;
use crate::{query_room_msgs, ApiError, ArcState, Pagination, HEADER_PUBLIC_NO_CACHE};
const JSON_FEED_MIME: &str = "application/feed+json";
const ATOM_FEED_MIME: &str = "application/atom+xml";
@ -214,3 +218,65 @@ impl FeedType for AtomFeed {
([(header::CONTENT_TYPE, ATOM_FEED_MIME)], body).into_response()
}
}
pub async fn get_room_feed<FT: FeedType>(
st: ArcState,
ETag(etag): ETag<Id>,
OriginalUri(req_uri): OriginalUri,
Path(rid): Path<Id>,
Query(mut pagination): Query<Pagination>,
) -> Result<Response, ApiError> {
let self_url = st
.config
.base_url
.join(req_uri.path())
.expect("base_url can be a base");
pagination.top = Some(
pagination
.effective_page_len(&st)
.min(st.config.feed.max_page_len),
);
let (title, msgs, skip_token) = st.db.with_read(|txn| {
let (attrs, title) = txn.get_room_having(rid, RoomAttrs::PUBLIC_READABLE)?;
// Sanity check.
assert!(!attrs.contains(RoomAttrs::PEER_CHAT));
let title = title.expect("public room must have title");
let (msgs, skip_token) = query_room_msgs(&st, txn, rid, pagination)?;
Ok((title, msgs, skip_token))
})?;
// Use `Id(0)` as the tag for an empty list.
let ret_etag = msgs.first().map_or(Id(0), |msg| msg.cid);
if etag == Some(ret_etag) {
return Ok(StatusCode::NOT_MODIFIED.into_response());
}
let next_url = skip_token.map(|skip_token| {
let next_params = Pagination {
skip_token: Some(skip_token),
top: pagination.top,
until_token: None,
};
let mut next_url = self_url.clone();
{
let mut query = next_url.query_pairs_mut();
let ser = serde_urlencoded::Serializer::new(&mut query);
next_params
.serialize(ser)
.expect("serialization cannot fail");
query.finish();
}
next_url
});
let resp = FT::to_feed_response(FeedData {
rid,
title,
msgs,
self_url,
next_url,
});
Ok(([HEADER_PUBLIC_NO_CACHE], ETag(Some(ret_etag)), resp).into_response())
}

View file

@ -4,8 +4,7 @@ use std::time::Duration;
use anyhow::Result;
use axum::body::Bytes;
use axum::extract::{ws, OriginalUri};
use axum::extract::{Path, Query, State, WebSocketUpgrade};
use axum::extract::{Path, Query, State};
use axum::http::{header, HeaderName, HeaderValue, StatusCode};
use axum::response::{IntoResponse, Response};
use axum::routing::{get, post};
@ -14,7 +13,7 @@ use axum_extra::extract::WithRejection as R;
use blah_types::msg::{
ChatPayload, CreateGroup, CreatePeerChat, CreateRoomPayload, DeleteRoomPayload,
MemberPermission, RoomAdminOp, RoomAdminPayload, RoomAttrs, ServerPermission,
SignedChatMsgWithId, UserRegisterPayload,
SignedChatMsgWithId,
};
use blah_types::server::{
ErrorResponseWithChallenge, RoomList, RoomMember, RoomMemberList, RoomMetadata, RoomMsgs,
@ -23,7 +22,6 @@ use blah_types::server::{
use blah_types::{get_timestamp, Id, Signed, UserKey};
use data_encoding::BASE64_NOPAD;
use database::{Transaction, TransactionOps};
use feed::FeedData;
use id::IdExt;
use middleware::{Auth, ETag, MaybeAuth, ResultExt as _, SignedJson};
use parking_lot::Mutex;
@ -150,19 +148,27 @@ impl AppState {
type ArcState = State<Arc<AppState>>;
pub fn router(st: Arc<AppState>) -> Router {
// NB. User consistent handler naming: `<method>_<path>[_<details>]`.
// Use prefix `list` for GET with pagination.
//
// One route per line.
#[rustfmt::skip]
let router = Router::new()
.route("/server", get(handle_server_metadata))
.route("/ws", get(handle_ws))
.route("/user/me", get(user_get).post(user_register))
.route("/room", get(room_list))
.route("/room/create", post(room_create))
.route("/room/:rid", get(room_get_metadata).delete(room_delete))
.route("/room/:rid/feed.json", get(room_get_feed::<feed::JsonFeed>))
.route("/room/:rid/feed.atom", get(room_get_feed::<feed::AtomFeed>))
.route("/room/:rid/msg", get(room_msg_list).post(room_msg_post))
.route("/room/:rid/msg/:cid/seen", post(room_msg_mark_seen))
.route("/room/:rid/admin", post(room_admin))
.route("/room/:rid/member", get(room_member_list))
.route("/server", get(get_server_metadata))
.route("/ws", get(event::get_ws))
.route("/user/me", get(get_user).post(register::post_user))
.route("/room", get(list_room))
// TODO: Maybe just POST on `/room`?
.route("/room/create", post(post_room_create))
.route("/room/:rid", get(get_room).delete(delete_room))
.route("/room/:rid/feed.json", get(feed::get_room_feed::<feed::JsonFeed>))
.route("/room/:rid/feed.atom", get(feed::get_room_feed::<feed::AtomFeed>))
.route("/room/:rid/msg", get(list_room_msg).post(post_room_msg))
.route("/room/:rid/msg/:cid/seen", post(post_room_msg_seen))
.route("/room/:rid/admin", post(post_room_admin))
.route("/room/:rid/member", get(list_room_member));
let router = router
.layer(tower_http::limit::RequestBodyLimitLayer::new(
st.config.max_request_len,
))
@ -184,7 +190,7 @@ pub fn router(st: Arc<AppState>) -> Router {
type RE<T> = R<T, ApiError>;
async fn handle_server_metadata(State(st): ArcState) -> Response {
async fn get_server_metadata(State(st): ArcState) -> Response {
let (json, etag) = st.server_metadata.clone();
let headers = [
(
@ -196,29 +202,7 @@ async fn handle_server_metadata(State(st): ArcState) -> Response {
(headers, etag, json).into_response()
}
async fn handle_ws(State(st): ArcState, ws: WebSocketUpgrade) -> Response {
ws.on_upgrade(move |mut socket| async move {
match event::handle_ws(st, &mut socket).await {
#[allow(
unreachable_patterns,
reason = "compatibility before min_exhaustive_patterns"
)]
Ok(never) => match never {},
Err(err) if err.is::<event::StreamEnded>() => {}
Err(err) => {
tracing::debug!(%err, "ws error");
let _: Result<_, _> = socket
.send(ws::Message::Close(Some(ws::CloseFrame {
code: ws::close_code::ERROR,
reason: err.to_string().into(),
})))
.await;
}
}
})
}
async fn user_get(State(st): ArcState, auth: MaybeAuth) -> Response {
async fn get_user(State(st): ArcState, auth: MaybeAuth) -> Response {
let ret = (|| {
match auth.into_optional()? {
None => None,
@ -243,13 +227,6 @@ async fn user_get(State(st): ArcState, auth: MaybeAuth) -> Response {
}
}
async fn user_register(
State(st): ArcState,
SignedJson(msg): SignedJson<UserRegisterPayload>,
) -> Result<StatusCode, ApiError> {
register::user_register(&st, msg).await
}
#[derive(Debug, Deserialize)]
#[serde(deny_unknown_fields, rename_all = "camelCase")]
struct ListRoomParams {
@ -272,7 +249,7 @@ enum ListRoomFilter {
Unseen,
}
async fn room_list(
async fn list_room(
st: ArcState,
params: RE<Query<ListRoomParams>>,
auth: MaybeAuth,
@ -302,7 +279,7 @@ async fn room_list(
Ok(Json(RoomList { rooms, skip_token }))
}
async fn room_create(
async fn post_room_create(
st: ArcState,
SignedJson(params): SignedJson<CreateRoomPayload>,
) -> Result<Json<Id>, ApiError> {
@ -390,7 +367,7 @@ impl Pagination {
}
}
async fn room_msg_list(
async fn list_room_msg(
st: ArcState,
R(Path(rid), _): RE<Path<Id>>,
R(Query(pagination), _): RE<Query<Pagination>>,
@ -407,7 +384,7 @@ async fn room_msg_list(
Ok(Json(RoomMsgs { msgs, skip_token }))
}
async fn room_get_metadata(
async fn get_room(
st: ArcState,
R(Path(rid), _): RE<Path<Id>>,
auth: MaybeAuth,
@ -438,68 +415,6 @@ async fn room_get_metadata(
}))
}
async fn room_get_feed<FT: feed::FeedType>(
st: ArcState,
ETag(etag): ETag<Id>,
R(OriginalUri(req_uri), _): RE<OriginalUri>,
R(Path(rid), _): RE<Path<Id>>,
R(Query(mut pagination), _): RE<Query<Pagination>>,
) -> Result<Response, ApiError> {
let self_url = st
.config
.base_url
.join(req_uri.path())
.expect("base_url can be a base");
pagination.top = Some(
pagination
.effective_page_len(&st)
.min(st.config.feed.max_page_len),
);
let (title, msgs, skip_token) = st.db.with_read(|txn| {
let (attrs, title) = txn.get_room_having(rid, RoomAttrs::PUBLIC_READABLE)?;
// Sanity check.
assert!(!attrs.contains(RoomAttrs::PEER_CHAT));
let title = title.expect("public room must have title");
let (msgs, skip_token) = query_room_msgs(&st, txn, rid, pagination)?;
Ok((title, msgs, skip_token))
})?;
// Use `Id(0)` as the tag for an empty list.
let ret_etag = msgs.first().map_or(Id(0), |msg| msg.cid);
if etag == Some(ret_etag) {
return Ok(StatusCode::NOT_MODIFIED.into_response());
}
let next_url = skip_token.map(|skip_token| {
let next_params = Pagination {
skip_token: Some(skip_token),
top: pagination.top,
until_token: None,
};
let mut next_url = self_url.clone();
{
let mut query = next_url.query_pairs_mut();
let ser = serde_urlencoded::Serializer::new(&mut query);
next_params
.serialize(ser)
.expect("serialization cannot fail");
query.finish();
}
next_url
});
let resp = FT::to_feed_response(FeedData {
rid,
title,
msgs,
self_url,
next_url,
});
Ok(([HEADER_PUBLIC_NO_CACHE], ETag(Some(ret_etag)), resp).into_response())
}
/// Get room messages with pagination parameters,
/// return a page of messages and the next `skip_token` if this is not the last page.
fn query_room_msgs(
@ -520,7 +435,7 @@ fn query_room_msgs(
Ok((msgs, skip_token))
}
async fn room_msg_post(
async fn post_room_msg(
st: ArcState,
R(Path(rid), _): RE<Path<Id>>,
SignedJson(chat): SignedJson<ChatPayload>,
@ -562,7 +477,7 @@ async fn room_msg_post(
Ok(Json(cid))
}
async fn room_admin(
async fn post_room_admin(
st: ArcState,
R(Path(rid), _): RE<Path<Id>>,
SignedJson(op): SignedJson<RoomAdminPayload>,
@ -622,7 +537,7 @@ async fn room_leave(st: &AppState, rid: Id, user: &UserKey) -> Result<(), ApiErr
})
}
async fn room_delete(
async fn delete_room(
st: ArcState,
R(Path(rid), _): RE<Path<Id>>,
SignedJson(op): SignedJson<DeleteRoomPayload>,
@ -640,7 +555,7 @@ async fn room_delete(
})
}
async fn room_msg_mark_seen(
async fn post_room_msg_seen(
st: ArcState,
R(Path((rid, cid)), _): RE<Path<(Id, i64)>>,
Auth(user): Auth,
@ -655,7 +570,7 @@ async fn room_msg_mark_seen(
Ok(StatusCode::NO_CONTENT)
}
async fn room_member_list(
async fn list_room_member(
st: ArcState,
R(Path(rid), _): RE<Path<Id>>,
R(Query(pagination), _): RE<Query<Pagination>>,

View file

@ -3,10 +3,10 @@ use std::time::Duration;
use anyhow::{anyhow, ensure};
use axum::http::StatusCode;
use blah_types::get_timestamp;
use blah_types::identity::{IdUrl, UserIdentityDesc};
use blah_types::msg::{UserRegisterChallengeResponse, UserRegisterPayload};
use blah_types::server::UserRegisterChallenge;
use blah_types::{get_timestamp, Signed};
use http_body_util::BodyExt;
use parking_lot::Mutex;
use rand::rngs::OsRng;
@ -15,8 +15,9 @@ use serde::Deserialize;
use sha2::{Digest, Sha256};
use crate::database::TransactionOps;
use crate::middleware::SignedJson;
use crate::utils::Instant;
use crate::{ApiError, AppState, SERVER_AND_VERSION};
use crate::{ApiError, ArcState, SERVER_AND_VERSION};
#[derive(Debug, Clone, PartialEq, Eq, Deserialize)]
#[serde(default, deny_unknown_fields)]
@ -159,9 +160,9 @@ impl State {
}
}
pub async fn user_register(
st: &AppState,
msg: Signed<UserRegisterPayload>,
pub async fn post_user(
axum::extract::State(st): ArcState,
SignedJson(msg): SignedJson<UserRegisterPayload>,
) -> Result<StatusCode, ApiError> {
if !st.config.register.enable_public {
return Err(ApiError::Disabled("public registration is disabled"));