mirror of
https://github.com/Blah-IM/blahrs.git
synced 2025-05-01 00:31:09 +00:00
refactor(blahd): reorg and use consistent handler names
This commit is contained in:
parent
4e8124cda6
commit
97c0cf5844
4 changed files with 135 additions and 128 deletions
|
@ -8,7 +8,9 @@ use std::task::{Context, Poll};
|
|||
use std::time::Duration;
|
||||
|
||||
use anyhow::{bail, Context as _, Result};
|
||||
use axum::extract::ws::{Message, WebSocket};
|
||||
use axum::extract::ws::{close_code, CloseFrame, Message, WebSocket};
|
||||
use axum::extract::WebSocketUpgrade;
|
||||
use axum::response::Response;
|
||||
use blah_types::msg::{AuthPayload, SignedChatMsg};
|
||||
use blah_types::server::ClientEvent;
|
||||
use blah_types::Signed;
|
||||
|
@ -23,7 +25,7 @@ use tokio_stream::wrappers::errors::BroadcastStreamRecvError;
|
|||
use tokio_stream::wrappers::BroadcastStream;
|
||||
|
||||
use crate::database::TransactionOps;
|
||||
use crate::AppState;
|
||||
use crate::{AppState, ArcState};
|
||||
|
||||
// We a borrowed type rather than an owned type.
|
||||
// So redefine it. Not sure if there is a better way.
|
||||
|
@ -127,7 +129,30 @@ impl Stream for UserEventReceiver {
|
|||
}
|
||||
}
|
||||
|
||||
pub async fn handle_ws(st: Arc<AppState>, ws: &mut WebSocket) -> Result<Infallible> {
|
||||
// TODO: Authenticate via HTTP query?
|
||||
pub async fn get_ws(st: ArcState, ws: WebSocketUpgrade) -> Response {
|
||||
ws.on_upgrade(move |mut socket| async move {
|
||||
match handle_ws(st.0, &mut socket).await {
|
||||
#[allow(
|
||||
unreachable_patterns,
|
||||
reason = "compatibility before min_exhaustive_patterns"
|
||||
)]
|
||||
Ok(never) => match never {},
|
||||
Err(err) if err.is::<StreamEnded>() => {}
|
||||
Err(err) => {
|
||||
tracing::debug!(%err, "ws error");
|
||||
let _: Result<_, _> = socket
|
||||
.send(Message::Close(Some(CloseFrame {
|
||||
code: close_code::ERROR,
|
||||
reason: err.to_string().into(),
|
||||
})))
|
||||
.await;
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
async fn handle_ws(st: Arc<AppState>, ws: &mut WebSocket) -> Result<Infallible> {
|
||||
let config = &st.config.ws;
|
||||
let (ws_tx, ws_rx) = ws.split();
|
||||
let mut ws_rx = ws_rx.map(|ret| match ret {
|
||||
|
|
|
@ -3,15 +3,19 @@ use std::fmt;
|
|||
use std::num::NonZero;
|
||||
use std::time::Duration;
|
||||
|
||||
use axum::http::header;
|
||||
use axum::extract::{OriginalUri, Path, Query};
|
||||
use axum::http::{header, StatusCode};
|
||||
use axum::response::{IntoResponse, Response};
|
||||
use axum::Json;
|
||||
use blah_types::msg::{SignedChatMsgWithId, WithMsgId};
|
||||
use blah_types::msg::{RoomAttrs, SignedChatMsgWithId, WithMsgId};
|
||||
use blah_types::Id;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use url::Url;
|
||||
|
||||
use crate::database::TransactionOps;
|
||||
use crate::id::timestamp_of_id;
|
||||
use crate::middleware::ETag;
|
||||
use crate::{query_room_msgs, ApiError, ArcState, Pagination, HEADER_PUBLIC_NO_CACHE};
|
||||
|
||||
const JSON_FEED_MIME: &str = "application/feed+json";
|
||||
const ATOM_FEED_MIME: &str = "application/atom+xml";
|
||||
|
@ -214,3 +218,65 @@ impl FeedType for AtomFeed {
|
|||
([(header::CONTENT_TYPE, ATOM_FEED_MIME)], body).into_response()
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn get_room_feed<FT: FeedType>(
|
||||
st: ArcState,
|
||||
ETag(etag): ETag<Id>,
|
||||
OriginalUri(req_uri): OriginalUri,
|
||||
Path(rid): Path<Id>,
|
||||
Query(mut pagination): Query<Pagination>,
|
||||
) -> Result<Response, ApiError> {
|
||||
let self_url = st
|
||||
.config
|
||||
.base_url
|
||||
.join(req_uri.path())
|
||||
.expect("base_url can be a base");
|
||||
|
||||
pagination.top = Some(
|
||||
pagination
|
||||
.effective_page_len(&st)
|
||||
.min(st.config.feed.max_page_len),
|
||||
);
|
||||
|
||||
let (title, msgs, skip_token) = st.db.with_read(|txn| {
|
||||
let (attrs, title) = txn.get_room_having(rid, RoomAttrs::PUBLIC_READABLE)?;
|
||||
// Sanity check.
|
||||
assert!(!attrs.contains(RoomAttrs::PEER_CHAT));
|
||||
let title = title.expect("public room must have title");
|
||||
let (msgs, skip_token) = query_room_msgs(&st, txn, rid, pagination)?;
|
||||
Ok((title, msgs, skip_token))
|
||||
})?;
|
||||
|
||||
// Use `Id(0)` as the tag for an empty list.
|
||||
let ret_etag = msgs.first().map_or(Id(0), |msg| msg.cid);
|
||||
if etag == Some(ret_etag) {
|
||||
return Ok(StatusCode::NOT_MODIFIED.into_response());
|
||||
}
|
||||
|
||||
let next_url = skip_token.map(|skip_token| {
|
||||
let next_params = Pagination {
|
||||
skip_token: Some(skip_token),
|
||||
top: pagination.top,
|
||||
until_token: None,
|
||||
};
|
||||
let mut next_url = self_url.clone();
|
||||
{
|
||||
let mut query = next_url.query_pairs_mut();
|
||||
let ser = serde_urlencoded::Serializer::new(&mut query);
|
||||
next_params
|
||||
.serialize(ser)
|
||||
.expect("serialization cannot fail");
|
||||
query.finish();
|
||||
}
|
||||
next_url
|
||||
});
|
||||
|
||||
let resp = FT::to_feed_response(FeedData {
|
||||
rid,
|
||||
title,
|
||||
msgs,
|
||||
self_url,
|
||||
next_url,
|
||||
});
|
||||
Ok(([HEADER_PUBLIC_NO_CACHE], ETag(Some(ret_etag)), resp).into_response())
|
||||
}
|
||||
|
|
151
blahd/src/lib.rs
151
blahd/src/lib.rs
|
@ -4,8 +4,7 @@ use std::time::Duration;
|
|||
|
||||
use anyhow::Result;
|
||||
use axum::body::Bytes;
|
||||
use axum::extract::{ws, OriginalUri};
|
||||
use axum::extract::{Path, Query, State, WebSocketUpgrade};
|
||||
use axum::extract::{Path, Query, State};
|
||||
use axum::http::{header, HeaderName, HeaderValue, StatusCode};
|
||||
use axum::response::{IntoResponse, Response};
|
||||
use axum::routing::{get, post};
|
||||
|
@ -14,7 +13,7 @@ use axum_extra::extract::WithRejection as R;
|
|||
use blah_types::msg::{
|
||||
ChatPayload, CreateGroup, CreatePeerChat, CreateRoomPayload, DeleteRoomPayload,
|
||||
MemberPermission, RoomAdminOp, RoomAdminPayload, RoomAttrs, ServerPermission,
|
||||
SignedChatMsgWithId, UserRegisterPayload,
|
||||
SignedChatMsgWithId,
|
||||
};
|
||||
use blah_types::server::{
|
||||
ErrorResponseWithChallenge, RoomList, RoomMember, RoomMemberList, RoomMetadata, RoomMsgs,
|
||||
|
@ -23,7 +22,6 @@ use blah_types::server::{
|
|||
use blah_types::{get_timestamp, Id, Signed, UserKey};
|
||||
use data_encoding::BASE64_NOPAD;
|
||||
use database::{Transaction, TransactionOps};
|
||||
use feed::FeedData;
|
||||
use id::IdExt;
|
||||
use middleware::{Auth, ETag, MaybeAuth, ResultExt as _, SignedJson};
|
||||
use parking_lot::Mutex;
|
||||
|
@ -150,19 +148,27 @@ impl AppState {
|
|||
type ArcState = State<Arc<AppState>>;
|
||||
|
||||
pub fn router(st: Arc<AppState>) -> Router {
|
||||
// NB. User consistent handler naming: `<method>_<path>[_<details>]`.
|
||||
// Use prefix `list` for GET with pagination.
|
||||
//
|
||||
// One route per line.
|
||||
#[rustfmt::skip]
|
||||
let router = Router::new()
|
||||
.route("/server", get(handle_server_metadata))
|
||||
.route("/ws", get(handle_ws))
|
||||
.route("/user/me", get(user_get).post(user_register))
|
||||
.route("/room", get(room_list))
|
||||
.route("/room/create", post(room_create))
|
||||
.route("/room/:rid", get(room_get_metadata).delete(room_delete))
|
||||
.route("/room/:rid/feed.json", get(room_get_feed::<feed::JsonFeed>))
|
||||
.route("/room/:rid/feed.atom", get(room_get_feed::<feed::AtomFeed>))
|
||||
.route("/room/:rid/msg", get(room_msg_list).post(room_msg_post))
|
||||
.route("/room/:rid/msg/:cid/seen", post(room_msg_mark_seen))
|
||||
.route("/room/:rid/admin", post(room_admin))
|
||||
.route("/room/:rid/member", get(room_member_list))
|
||||
.route("/server", get(get_server_metadata))
|
||||
.route("/ws", get(event::get_ws))
|
||||
.route("/user/me", get(get_user).post(register::post_user))
|
||||
.route("/room", get(list_room))
|
||||
// TODO: Maybe just POST on `/room`?
|
||||
.route("/room/create", post(post_room_create))
|
||||
.route("/room/:rid", get(get_room).delete(delete_room))
|
||||
.route("/room/:rid/feed.json", get(feed::get_room_feed::<feed::JsonFeed>))
|
||||
.route("/room/:rid/feed.atom", get(feed::get_room_feed::<feed::AtomFeed>))
|
||||
.route("/room/:rid/msg", get(list_room_msg).post(post_room_msg))
|
||||
.route("/room/:rid/msg/:cid/seen", post(post_room_msg_seen))
|
||||
.route("/room/:rid/admin", post(post_room_admin))
|
||||
.route("/room/:rid/member", get(list_room_member));
|
||||
|
||||
let router = router
|
||||
.layer(tower_http::limit::RequestBodyLimitLayer::new(
|
||||
st.config.max_request_len,
|
||||
))
|
||||
|
@ -184,7 +190,7 @@ pub fn router(st: Arc<AppState>) -> Router {
|
|||
|
||||
type RE<T> = R<T, ApiError>;
|
||||
|
||||
async fn handle_server_metadata(State(st): ArcState) -> Response {
|
||||
async fn get_server_metadata(State(st): ArcState) -> Response {
|
||||
let (json, etag) = st.server_metadata.clone();
|
||||
let headers = [
|
||||
(
|
||||
|
@ -196,29 +202,7 @@ async fn handle_server_metadata(State(st): ArcState) -> Response {
|
|||
(headers, etag, json).into_response()
|
||||
}
|
||||
|
||||
async fn handle_ws(State(st): ArcState, ws: WebSocketUpgrade) -> Response {
|
||||
ws.on_upgrade(move |mut socket| async move {
|
||||
match event::handle_ws(st, &mut socket).await {
|
||||
#[allow(
|
||||
unreachable_patterns,
|
||||
reason = "compatibility before min_exhaustive_patterns"
|
||||
)]
|
||||
Ok(never) => match never {},
|
||||
Err(err) if err.is::<event::StreamEnded>() => {}
|
||||
Err(err) => {
|
||||
tracing::debug!(%err, "ws error");
|
||||
let _: Result<_, _> = socket
|
||||
.send(ws::Message::Close(Some(ws::CloseFrame {
|
||||
code: ws::close_code::ERROR,
|
||||
reason: err.to_string().into(),
|
||||
})))
|
||||
.await;
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
async fn user_get(State(st): ArcState, auth: MaybeAuth) -> Response {
|
||||
async fn get_user(State(st): ArcState, auth: MaybeAuth) -> Response {
|
||||
let ret = (|| {
|
||||
match auth.into_optional()? {
|
||||
None => None,
|
||||
|
@ -243,13 +227,6 @@ async fn user_get(State(st): ArcState, auth: MaybeAuth) -> Response {
|
|||
}
|
||||
}
|
||||
|
||||
async fn user_register(
|
||||
State(st): ArcState,
|
||||
SignedJson(msg): SignedJson<UserRegisterPayload>,
|
||||
) -> Result<StatusCode, ApiError> {
|
||||
register::user_register(&st, msg).await
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[serde(deny_unknown_fields, rename_all = "camelCase")]
|
||||
struct ListRoomParams {
|
||||
|
@ -272,7 +249,7 @@ enum ListRoomFilter {
|
|||
Unseen,
|
||||
}
|
||||
|
||||
async fn room_list(
|
||||
async fn list_room(
|
||||
st: ArcState,
|
||||
params: RE<Query<ListRoomParams>>,
|
||||
auth: MaybeAuth,
|
||||
|
@ -302,7 +279,7 @@ async fn room_list(
|
|||
Ok(Json(RoomList { rooms, skip_token }))
|
||||
}
|
||||
|
||||
async fn room_create(
|
||||
async fn post_room_create(
|
||||
st: ArcState,
|
||||
SignedJson(params): SignedJson<CreateRoomPayload>,
|
||||
) -> Result<Json<Id>, ApiError> {
|
||||
|
@ -390,7 +367,7 @@ impl Pagination {
|
|||
}
|
||||
}
|
||||
|
||||
async fn room_msg_list(
|
||||
async fn list_room_msg(
|
||||
st: ArcState,
|
||||
R(Path(rid), _): RE<Path<Id>>,
|
||||
R(Query(pagination), _): RE<Query<Pagination>>,
|
||||
|
@ -407,7 +384,7 @@ async fn room_msg_list(
|
|||
Ok(Json(RoomMsgs { msgs, skip_token }))
|
||||
}
|
||||
|
||||
async fn room_get_metadata(
|
||||
async fn get_room(
|
||||
st: ArcState,
|
||||
R(Path(rid), _): RE<Path<Id>>,
|
||||
auth: MaybeAuth,
|
||||
|
@ -438,68 +415,6 @@ async fn room_get_metadata(
|
|||
}))
|
||||
}
|
||||
|
||||
async fn room_get_feed<FT: feed::FeedType>(
|
||||
st: ArcState,
|
||||
ETag(etag): ETag<Id>,
|
||||
R(OriginalUri(req_uri), _): RE<OriginalUri>,
|
||||
R(Path(rid), _): RE<Path<Id>>,
|
||||
R(Query(mut pagination), _): RE<Query<Pagination>>,
|
||||
) -> Result<Response, ApiError> {
|
||||
let self_url = st
|
||||
.config
|
||||
.base_url
|
||||
.join(req_uri.path())
|
||||
.expect("base_url can be a base");
|
||||
|
||||
pagination.top = Some(
|
||||
pagination
|
||||
.effective_page_len(&st)
|
||||
.min(st.config.feed.max_page_len),
|
||||
);
|
||||
|
||||
let (title, msgs, skip_token) = st.db.with_read(|txn| {
|
||||
let (attrs, title) = txn.get_room_having(rid, RoomAttrs::PUBLIC_READABLE)?;
|
||||
// Sanity check.
|
||||
assert!(!attrs.contains(RoomAttrs::PEER_CHAT));
|
||||
let title = title.expect("public room must have title");
|
||||
let (msgs, skip_token) = query_room_msgs(&st, txn, rid, pagination)?;
|
||||
Ok((title, msgs, skip_token))
|
||||
})?;
|
||||
|
||||
// Use `Id(0)` as the tag for an empty list.
|
||||
let ret_etag = msgs.first().map_or(Id(0), |msg| msg.cid);
|
||||
if etag == Some(ret_etag) {
|
||||
return Ok(StatusCode::NOT_MODIFIED.into_response());
|
||||
}
|
||||
|
||||
let next_url = skip_token.map(|skip_token| {
|
||||
let next_params = Pagination {
|
||||
skip_token: Some(skip_token),
|
||||
top: pagination.top,
|
||||
until_token: None,
|
||||
};
|
||||
let mut next_url = self_url.clone();
|
||||
{
|
||||
let mut query = next_url.query_pairs_mut();
|
||||
let ser = serde_urlencoded::Serializer::new(&mut query);
|
||||
next_params
|
||||
.serialize(ser)
|
||||
.expect("serialization cannot fail");
|
||||
query.finish();
|
||||
}
|
||||
next_url
|
||||
});
|
||||
|
||||
let resp = FT::to_feed_response(FeedData {
|
||||
rid,
|
||||
title,
|
||||
msgs,
|
||||
self_url,
|
||||
next_url,
|
||||
});
|
||||
Ok(([HEADER_PUBLIC_NO_CACHE], ETag(Some(ret_etag)), resp).into_response())
|
||||
}
|
||||
|
||||
/// Get room messages with pagination parameters,
|
||||
/// return a page of messages and the next `skip_token` if this is not the last page.
|
||||
fn query_room_msgs(
|
||||
|
@ -520,7 +435,7 @@ fn query_room_msgs(
|
|||
Ok((msgs, skip_token))
|
||||
}
|
||||
|
||||
async fn room_msg_post(
|
||||
async fn post_room_msg(
|
||||
st: ArcState,
|
||||
R(Path(rid), _): RE<Path<Id>>,
|
||||
SignedJson(chat): SignedJson<ChatPayload>,
|
||||
|
@ -562,7 +477,7 @@ async fn room_msg_post(
|
|||
Ok(Json(cid))
|
||||
}
|
||||
|
||||
async fn room_admin(
|
||||
async fn post_room_admin(
|
||||
st: ArcState,
|
||||
R(Path(rid), _): RE<Path<Id>>,
|
||||
SignedJson(op): SignedJson<RoomAdminPayload>,
|
||||
|
@ -622,7 +537,7 @@ async fn room_leave(st: &AppState, rid: Id, user: &UserKey) -> Result<(), ApiErr
|
|||
})
|
||||
}
|
||||
|
||||
async fn room_delete(
|
||||
async fn delete_room(
|
||||
st: ArcState,
|
||||
R(Path(rid), _): RE<Path<Id>>,
|
||||
SignedJson(op): SignedJson<DeleteRoomPayload>,
|
||||
|
@ -640,7 +555,7 @@ async fn room_delete(
|
|||
})
|
||||
}
|
||||
|
||||
async fn room_msg_mark_seen(
|
||||
async fn post_room_msg_seen(
|
||||
st: ArcState,
|
||||
R(Path((rid, cid)), _): RE<Path<(Id, i64)>>,
|
||||
Auth(user): Auth,
|
||||
|
@ -655,7 +570,7 @@ async fn room_msg_mark_seen(
|
|||
Ok(StatusCode::NO_CONTENT)
|
||||
}
|
||||
|
||||
async fn room_member_list(
|
||||
async fn list_room_member(
|
||||
st: ArcState,
|
||||
R(Path(rid), _): RE<Path<Id>>,
|
||||
R(Query(pagination), _): RE<Query<Pagination>>,
|
||||
|
|
|
@ -3,10 +3,10 @@ use std::time::Duration;
|
|||
|
||||
use anyhow::{anyhow, ensure};
|
||||
use axum::http::StatusCode;
|
||||
use blah_types::get_timestamp;
|
||||
use blah_types::identity::{IdUrl, UserIdentityDesc};
|
||||
use blah_types::msg::{UserRegisterChallengeResponse, UserRegisterPayload};
|
||||
use blah_types::server::UserRegisterChallenge;
|
||||
use blah_types::{get_timestamp, Signed};
|
||||
use http_body_util::BodyExt;
|
||||
use parking_lot::Mutex;
|
||||
use rand::rngs::OsRng;
|
||||
|
@ -15,8 +15,9 @@ use serde::Deserialize;
|
|||
use sha2::{Digest, Sha256};
|
||||
|
||||
use crate::database::TransactionOps;
|
||||
use crate::middleware::SignedJson;
|
||||
use crate::utils::Instant;
|
||||
use crate::{ApiError, AppState, SERVER_AND_VERSION};
|
||||
use crate::{ApiError, ArcState, SERVER_AND_VERSION};
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Deserialize)]
|
||||
#[serde(default, deny_unknown_fields)]
|
||||
|
@ -159,9 +160,9 @@ impl State {
|
|||
}
|
||||
}
|
||||
|
||||
pub async fn user_register(
|
||||
st: &AppState,
|
||||
msg: Signed<UserRegisterPayload>,
|
||||
pub async fn post_user(
|
||||
axum::extract::State(st): ArcState,
|
||||
SignedJson(msg): SignedJson<UserRegisterPayload>,
|
||||
) -> Result<StatusCode, ApiError> {
|
||||
if !st.config.register.enable_public {
|
||||
return Err(ApiError::Disabled("public registration is disabled"));
|
||||
|
|
Loading…
Add table
Reference in a new issue