refactor(config): split into subsections and verify on parsing

This commit is contained in:
oxalica 2024-09-13 07:20:48 -04:00
parent 93d1589730
commit 2775068e49
6 changed files with 123 additions and 77 deletions

View file

@ -5,6 +5,7 @@ use std::fmt;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Duration;
use anyhow::{bail, Context as _, Result};
use axum::extract::ws::{Message, WebSocket};
@ -14,12 +15,12 @@ use futures_util::stream::SplitSink;
use futures_util::{stream_select, SinkExt as _, Stream, StreamExt};
use parking_lot::Mutex;
use rusqlite::{params, OptionalExtension};
use serde::{Deserialize, Serialize};
use serde::{de, Deserialize, Serialize};
use serde_inline_default::serde_inline_default;
use tokio::sync::broadcast;
use tokio_stream::wrappers::errors::BroadcastStreamRecvError;
use tokio_stream::wrappers::BroadcastStream;
use crate::config::ServerConfig;
use crate::AppState;
#[derive(Debug, Deserialize)]
@ -35,6 +36,31 @@ pub enum Outgoing<'a> {
Lagged,
}
#[serde_inline_default]
#[derive(Debug, Clone, Deserialize)]
#[serde(default, deny_unknown_fields)]
pub struct Config {
#[serde(deserialize_with = "de_duration_sec")]
pub auth_timeout_sec: Duration,
#[serde(deserialize_with = "de_duration_sec")]
pub send_timeout_sec: Duration,
pub event_queue_len: usize,
}
impl Default for Config {
fn default() -> Self {
Self {
auth_timeout_sec: Duration::from_secs(15),
send_timeout_sec: Duration::from_secs(15),
event_queue_len: 1024,
}
}
}
fn de_duration_sec<'de, D: de::Deserializer<'de>>(de: D) -> Result<Duration, D::Error> {
<u64>::deserialize(de).map(Duration::from_secs)
}
#[derive(Debug, Default)]
pub struct State {
pub user_listeners: Mutex<HashMap<u64, UserEventSender>>,
@ -53,14 +79,14 @@ impl std::error::Error for StreamEnded {}
struct WsSenderWrapper<'ws, 'c> {
inner: SplitSink<&'ws mut WebSocket, Message>,
config: &'c ServerConfig,
config: &'c Config,
}
impl WsSenderWrapper<'_, '_> {
async fn send(&mut self, msg: &Outgoing<'_>) -> Result<()> {
let data = serde_json::to_string(&msg).expect("serialization cannot fail");
let fut = tokio::time::timeout(
self.config.ws_send_timeout_sec,
self.config.send_timeout_sec,
self.inner.send(Message::Text(data)),
);
match fut.await {
@ -101,15 +127,16 @@ impl Stream for UserEventReceiver {
}
pub async fn handle_ws(st: Arc<AppState>, ws: &mut WebSocket) -> Result<Infallible> {
let config = &st.config.ws;
let (ws_tx, ws_rx) = ws.split();
let mut ws_rx = ws_rx.map(|ret| ret.and_then(|msg| msg.into_text()).map_err(|_| StreamEnded));
let mut ws_tx = WsSenderWrapper {
inner: ws_tx,
config: &st.config,
config,
};
let uid = {
let payload = tokio::time::timeout(st.config.ws_auth_timeout_sec, ws_rx.next())
let payload = tokio::time::timeout(config.auth_timeout_sec, ws_rx.next())
.await
.context("authentication timeout")?
.ok_or(StreamEnded)??;
@ -137,7 +164,7 @@ pub async fn handle_ws(st: Arc<AppState>, ws: &mut WebSocket) -> Result<Infallib
let rx = match st.event.user_listeners.lock().entry(uid) {
Entry::Occupied(ent) => ent.get().subscribe(),
Entry::Vacant(ent) => {
let (tx, rx) = broadcast::channel(st.config.ws_event_queue_len);
let (tx, rx) = broadcast::channel(config.event_queue_len);
ent.insert(tx);
rx
}