refactor(blahd,types): hoist more types into types crate

This commit is contained in:
oxalica 2024-10-04 22:31:42 -04:00
parent 719c19dc64
commit 4e8124cda6
5 changed files with 129 additions and 74 deletions

View file

@ -1,11 +1,49 @@
//! Data types and constants for Chat Server interaction. //! Data types and constants for Chat Server interaction.
use std::fmt;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use url::Url; use url::Url;
use crate::msg::{Id, MemberPermission, RoomAttrs, SignedChatMsgWithId}; use crate::msg::{Id, MemberPermission, RoomAttrs, SignedChatMsg, SignedChatMsgWithId};
use crate::PubKey; use crate::PubKey;
/// The response object returned as body on HTTP error status.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ErrorResponse<S = String> {
/// The error object.
pub error: ErrorObject<S>,
}
/// The response object of `/_blah/user/me` endpoint on HTTP error status.
/// It contains additional registration information.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ErrorResponseWithChallenge<S> {
/// The error object.
pub error: ErrorObject<S>,
/// The challenge metadata returned by the `/_blah/user/me` endpoint for registration.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub register_challenge: Option<UserRegisterChallenge>,
}
/// The error object.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ErrorObject<S = String> {
/// A machine-readable error code string.
pub code: S,
/// A human-readable error message.
pub message: S,
}
impl<S: fmt::Display> fmt::Display for ErrorObject<S> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "api error ({}): {}", self.code, self.message)
}
}
impl<S: fmt::Display + fmt::Debug> std::error::Error for ErrorObject<S> {}
/// Metadata about the version and capabilities of a Chat Server. /// Metadata about the version and capabilities of a Chat Server.
/// ///
/// It should be relatively stable and do not change very often. /// It should be relatively stable and do not change very often.
@ -46,6 +84,17 @@ pub enum UserRegisterChallenge {
Unknown, Unknown,
} }
/// Response to list rooms.
#[derive(Debug, Default, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct RoomList {
/// Result list of rooms.
pub rooms: Vec<RoomMetadata>,
/// The skip-token to fetch the next page.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub skip_token: Option<Id>,
}
/// The metadata of a room.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct RoomMetadata { pub struct RoomMetadata {
/// Room id. /// Room id.
@ -75,3 +124,52 @@ pub struct RoomMetadata {
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub peer_user: Option<PubKey>, pub peer_user: Option<PubKey>,
} }
/// Response to list room msgs.
#[derive(Debug, Default, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct RoomMsgs {
/// Result list of msgs.
pub msgs: Vec<SignedChatMsgWithId>,
/// The skip-token to fetch the next page.
#[serde(skip_serializing_if = "Option::is_none")]
pub skip_token: Option<Id>,
}
/// Response to list room members.
#[derive(Debug, Default, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct RoomMemberList {
/// Result list of members.
pub members: Vec<RoomMember>,
/// The skip-token to fetch the next page.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub skip_token: Option<Id>,
}
/// The description of a room member.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct RoomMember {
/// The identity key of the member user.
pub id_key: PubKey,
/// The user permission in the room.
pub permission: MemberPermission,
/// The user's last seen message `cid` in the room.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub last_seen_cid: Option<Id>,
}
/// A server-to-client event.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ServerEvent {
/// A message from a joined room.
// FIXME: Include cid.
Msg(SignedChatMsg),
/// The receiver is too slow to receive and some events and are dropped.
// FIXME: Should we indefinitely buffer them or just disconnect the client instead?
Lagged,
}
/// A client-to-server event.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ClientEvent {}

View file

@ -10,6 +10,7 @@ use std::time::Duration;
use anyhow::{bail, Context as _, Result}; use anyhow::{bail, Context as _, Result};
use axum::extract::ws::{Message, WebSocket}; use axum::extract::ws::{Message, WebSocket};
use blah_types::msg::{AuthPayload, SignedChatMsg}; use blah_types::msg::{AuthPayload, SignedChatMsg};
use blah_types::server::ClientEvent;
use blah_types::Signed; use blah_types::Signed;
use futures_util::future::Either; use futures_util::future::Either;
use futures_util::stream::SplitSink; use futures_util::stream::SplitSink;
@ -24,12 +25,11 @@ use tokio_stream::wrappers::BroadcastStream;
use crate::database::TransactionOps; use crate::database::TransactionOps;
use crate::AppState; use crate::AppState;
#[derive(Debug, Deserialize)] // We a borrowed type rather than an owned type.
pub enum Incoming {} // So redefine it. Not sure if there is a better way.
#[derive(Debug, Serialize)] #[derive(Debug, Serialize)]
#[serde(rename_all = "snake_case")] #[serde(rename_all = "snake_case")]
pub enum Outgoing<'a> { enum ServerEvent<'a> {
/// A message from a joined room. /// A message from a joined room.
Msg(&'a SignedChatMsg), Msg(&'a SignedChatMsg),
/// The receiver is too slow to receive and some events and are dropped. /// The receiver is too slow to receive and some events and are dropped.
@ -84,7 +84,7 @@ struct WsSenderWrapper<'ws, 'c> {
} }
impl WsSenderWrapper<'_, '_> { impl WsSenderWrapper<'_, '_> {
async fn send(&mut self, msg: &Outgoing<'_>) -> Result<()> { async fn send(&mut self, msg: &ServerEvent<'_>) -> Result<()> {
let data = serde_json::to_string(&msg).expect("serialization cannot fail"); let data = serde_json::to_string(&msg).expect("serialization cannot fail");
let fut = tokio::time::timeout( let fut = tokio::time::timeout(
self.config.send_timeout_sec, self.config.send_timeout_sec,
@ -173,11 +173,11 @@ pub async fn handle_ws(st: Arc<AppState>, ws: &mut WebSocket) -> Result<Infallib
let mut stream = stream_select!(ws_rx.map(Either::Left), event_rx.map(Either::Right)); let mut stream = stream_select!(ws_rx.map(Either::Left), event_rx.map(Either::Right));
loop { loop {
match stream.next().await.ok_or(StreamEnded)? { match stream.next().await.ok_or(StreamEnded)? {
Either::Left(msg) => match serde_json::from_str::<Incoming>(&msg?)? {}, Either::Left(msg) => match serde_json::from_str::<ClientEvent>(&msg?)? {},
Either::Right(ret) => { Either::Right(ret) => {
let msg = match &ret { let msg = match &ret {
Ok(chat) => Outgoing::Msg(chat), Ok(chat) => ServerEvent::Msg(chat),
Err(BroadcastStreamRecvError::Lagged(_)) => Outgoing::Lagged, Err(BroadcastStreamRecvError::Lagged(_)) => ServerEvent::Lagged,
}; };
// TODO: Concurrent send. // TODO: Concurrent send.
ws_tx.send(&msg).await?; ws_tx.send(&msg).await?;

View file

@ -16,13 +16,16 @@ use blah_types::msg::{
MemberPermission, RoomAdminOp, RoomAdminPayload, RoomAttrs, ServerPermission, MemberPermission, RoomAdminOp, RoomAdminPayload, RoomAttrs, ServerPermission,
SignedChatMsgWithId, UserRegisterPayload, SignedChatMsgWithId, UserRegisterPayload,
}; };
use blah_types::server::{RoomMetadata, ServerCapabilities, ServerMetadata, UserRegisterChallenge}; use blah_types::server::{
use blah_types::{get_timestamp, Id, PubKey, Signed, UserKey}; ErrorResponseWithChallenge, RoomList, RoomMember, RoomMemberList, RoomMetadata, RoomMsgs,
ServerCapabilities, ServerMetadata,
};
use blah_types::{get_timestamp, Id, Signed, UserKey};
use data_encoding::BASE64_NOPAD; use data_encoding::BASE64_NOPAD;
use database::{Transaction, TransactionOps}; use database::{Transaction, TransactionOps};
use feed::FeedData; use feed::FeedData;
use id::IdExt; use id::IdExt;
use middleware::{Auth, ETag, MaybeAuth, RawApiError, ResultExt as _, SignedJson}; use middleware::{Auth, ETag, MaybeAuth, ResultExt as _, SignedJson};
use parking_lot::Mutex; use parking_lot::Mutex;
use serde::{Deserialize, Deserializer, Serialize}; use serde::{Deserialize, Deserializer, Serialize};
use serde_inline_default::serde_inline_default; use serde_inline_default::serde_inline_default;
@ -224,14 +227,6 @@ async fn user_get(State(st): ArcState, auth: MaybeAuth) -> Response {
.ok_or(ApiError::UserNotFound) .ok_or(ApiError::UserNotFound)
})(); })();
// TODO: Hoist this into types crate.
#[derive(Serialize)]
struct ErrResp<'a> {
error: RawApiError<'a>,
#[serde(skip_serializing_if = "Option::is_none")]
register_challenge: Option<UserRegisterChallenge>,
}
match ret { match ret {
Ok(_) => StatusCode::NO_CONTENT.into_response(), Ok(_) => StatusCode::NO_CONTENT.into_response(),
Err(err) => { Err(err) => {
@ -239,7 +234,7 @@ async fn user_get(State(st): ArcState, auth: MaybeAuth) -> Response {
if status != StatusCode::NOT_FOUND { if status != StatusCode::NOT_FOUND {
return err.into_response(); return err.into_response();
} }
let resp = Json(ErrResp { let resp = Json(ErrorResponseWithChallenge {
error: raw_err, error: raw_err,
register_challenge: st.register.challenge(&st.config.register), register_challenge: st.register.challenge(&st.config.register),
}); });
@ -255,13 +250,6 @@ async fn user_register(
register::user_register(&st, msg).await register::user_register(&st, msg).await
} }
#[derive(Debug, Default, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct RoomList {
pub rooms: Vec<RoomMetadata>,
#[serde(skip_serializing_if = "Option::is_none")]
pub skip_token: Option<Id>,
}
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
#[serde(deny_unknown_fields, rename_all = "camelCase")] #[serde(deny_unknown_fields, rename_all = "camelCase")]
struct ListRoomParams { struct ListRoomParams {
@ -402,13 +390,6 @@ impl Pagination {
} }
} }
#[derive(Debug, Default, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct RoomMsgs {
pub msgs: Vec<SignedChatMsgWithId>,
#[serde(skip_serializing_if = "Option::is_none")]
pub skip_token: Option<Id>,
}
async fn room_msg_list( async fn room_msg_list(
st: ArcState, st: ArcState,
R(Path(rid), _): RE<Path<Id>>, R(Path(rid), _): RE<Path<Id>>,
@ -674,22 +655,6 @@ async fn room_msg_mark_seen(
Ok(StatusCode::NO_CONTENT) Ok(StatusCode::NO_CONTENT)
} }
// TODO: Hoist these into types crate.
#[derive(Debug, Default, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct RoomMemberList {
pub members: Vec<RoomMember>,
#[serde(skip_serializing_if = "Option::is_none")]
pub skip_token: Option<Id>,
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct RoomMember {
pub id_key: PubKey,
pub permission: MemberPermission,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub last_seen_cid: Option<Id>,
}
async fn room_member_list( async fn room_member_list(
st: ArcState, st: ArcState,
R(Path(rid), _): RE<Path<Id>>, R(Path(rid), _): RE<Path<Id>>,

View file

@ -10,6 +10,7 @@ use axum::http::{header, request, HeaderValue, StatusCode};
use axum::response::{IntoResponse, IntoResponseParts, Response, ResponseParts}; use axum::response::{IntoResponse, IntoResponseParts, Response, ResponseParts};
use axum::{async_trait, Json}; use axum::{async_trait, Json};
use blah_types::msg::AuthPayload; use blah_types::msg::AuthPayload;
use blah_types::server::ErrorObject;
use blah_types::{Signed, UserKey}; use blah_types::{Signed, UserKey};
use serde::de::DeserializeOwned; use serde::de::DeserializeOwned;
use serde::Serialize; use serde::Serialize;
@ -45,7 +46,7 @@ macro_rules! define_api_error {
)* )*
} }
}; };
(status, RawApiError { code, message }) (status, ErrorObject { code, message })
} }
} }
@ -76,11 +77,7 @@ pub enum ApiError {
} }
#[derive(Debug, Serialize)] pub type RawApiError<'a> = ErrorObject<&'a str>;
pub struct RawApiError<'a> {
pub code: &'a str,
pub message: &'a str,
}
macro_rules! api_ensure { macro_rules! api_ensure {
($assertion:expr, $msg:literal $(,)?) => { ($assertion:expr, $msg:literal $(,)?) => {

View file

@ -15,9 +15,12 @@ use blah_types::msg::{
SignedChatMsg, SignedChatMsgWithId, UserRegisterChallengeResponse, UserRegisterPayload, SignedChatMsg, SignedChatMsgWithId, UserRegisterChallengeResponse, UserRegisterPayload,
WithMsgId, WithMsgId,
}; };
use blah_types::server::{RoomMetadata, ServerMetadata, UserRegisterChallenge}; use blah_types::server::{
RoomList, RoomMember, RoomMemberList, RoomMetadata, RoomMsgs, ServerEvent, ServerMetadata,
UserRegisterChallenge,
};
use blah_types::{Id, SignExt, Signed, UserKey}; use blah_types::{Id, SignExt, Signed, UserKey};
use blahd::{AppState, Database, RoomList, RoomMember, RoomMemberList, RoomMsgs}; use blahd::{AppState, Database};
use ed25519_dalek::SigningKey; use ed25519_dalek::SigningKey;
use expect_test::expect; use expect_test::expect;
use futures_util::future::BoxFuture; use futures_util::future::BoxFuture;
@ -140,14 +143,6 @@ impl fmt::Display for ApiError {
impl std::error::Error for ApiError {} impl std::error::Error for ApiError {}
// TODO: Hoist this into types crate.
#[derive(Debug, Clone, PartialEq, Eq, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum WsEvent {
// TODO: Include cid?
Msg(SignedChatMsg),
}
#[derive(Debug)] #[derive(Debug)]
struct Server { struct Server {
port: u16, port: u16,
@ -166,7 +161,7 @@ impl Server {
async fn connect_ws( async fn connect_ws(
&self, &self,
auth_user: Option<&User>, auth_user: Option<&User>,
) -> Result<impl Stream<Item = Result<WsEvent>> + Unpin> { ) -> Result<impl Stream<Item = Result<ServerEvent>> + Unpin> {
let url = format!("ws://{}/_blah/ws", self.domain()); let url = format!("ws://{}/_blah/ws", self.domain());
let (mut ws, _) = tokio_tungstenite::connect_async(url).await.unwrap(); let (mut ws, _) = tokio_tungstenite::connect_async(url).await.unwrap();
if let Some(user) = auth_user { if let Some(user) = auth_user {
@ -180,7 +175,7 @@ impl Server {
if wsmsg.is_close() { if wsmsg.is_close() {
return Ok(None); return Ok(None);
} }
let event = serde_json::from_slice::<WsEvent>(&wsmsg.into_data())?; let event = serde_json::from_slice::<ServerEvent>(&wsmsg.into_data())?;
Ok(Some(event)) Ok(Some(event))
}) })
.filter_map(|ret| std::future::ready(ret.transpose()))) .filter_map(|ret| std::future::ready(ret.transpose())))
@ -1530,7 +1525,7 @@ async fn event(server: Server) {
{ {
let chat = server.post_chat(rid1, &ALICE, "alice1").await.unwrap(); let chat = server.post_chat(rid1, &ALICE, "alice1").await.unwrap();
let got = ws.next().await.unwrap().unwrap(); let got = ws.next().await.unwrap().unwrap();
assert_eq!(got, WsEvent::Msg(chat.msg)); assert_eq!(got, ServerEvent::Msg(chat.msg));
} }
// Should receive msgs from other user. // Should receive msgs from other user.
@ -1541,7 +1536,7 @@ async fn event(server: Server) {
.unwrap(); .unwrap();
let chat = server.post_chat(rid1, &BOB, "bob1").await.unwrap(); let chat = server.post_chat(rid1, &BOB, "bob1").await.unwrap();
let got = ws.next().await.unwrap().unwrap(); let got = ws.next().await.unwrap().unwrap();
assert_eq!(got, WsEvent::Msg(chat.msg)); assert_eq!(got, ServerEvent::Msg(chat.msg));
} }
// Should receive msgs from new room. // Should receive msgs from new room.
@ -1552,7 +1547,7 @@ async fn event(server: Server) {
{ {
let chat = server.post_chat(rid2, &ALICE, "alice2").await.unwrap(); let chat = server.post_chat(rid2, &ALICE, "alice2").await.unwrap();
let got = ws.next().await.unwrap().unwrap(); let got = ws.next().await.unwrap().unwrap();
assert_eq!(got, WsEvent::Msg(chat.msg)); assert_eq!(got, ServerEvent::Msg(chat.msg));
} }
// Each streams should receive each message once. // Each streams should receive each message once.
@ -1561,9 +1556,9 @@ async fn event(server: Server) {
let chat = server.post_chat(rid1, &ALICE, "alice1").await.unwrap(); let chat = server.post_chat(rid1, &ALICE, "alice1").await.unwrap();
let got1 = ws.next().await.unwrap().unwrap(); let got1 = ws.next().await.unwrap().unwrap();
assert_eq!(got1, WsEvent::Msg(chat.msg.clone())); assert_eq!(got1, ServerEvent::Msg(chat.msg.clone()));
let got2 = ws2.next().await.unwrap().unwrap(); let got2 = ws2.next().await.unwrap().unwrap();
assert_eq!(got2, WsEvent::Msg(chat.msg)); assert_eq!(got2, ServerEvent::Msg(chat.msg));
} }
} }