mirror of
https://github.com/Blah-IM/blahrs.git
synced 2025-07-02 12:35:33 +00:00
refactor(config): split into subsections and verify on parsing
This commit is contained in:
parent
93d1589730
commit
2775068e49
6 changed files with 123 additions and 77 deletions
|
@ -5,6 +5,7 @@ use std::fmt;
|
|||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::task::{Context, Poll};
|
||||
use std::time::Duration;
|
||||
|
||||
use anyhow::{bail, Context as _, Result};
|
||||
use axum::extract::ws::{Message, WebSocket};
|
||||
|
@ -14,12 +15,12 @@ use futures_util::stream::SplitSink;
|
|||
use futures_util::{stream_select, SinkExt as _, Stream, StreamExt};
|
||||
use parking_lot::Mutex;
|
||||
use rusqlite::{params, OptionalExtension};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde::{de, Deserialize, Serialize};
|
||||
use serde_inline_default::serde_inline_default;
|
||||
use tokio::sync::broadcast;
|
||||
use tokio_stream::wrappers::errors::BroadcastStreamRecvError;
|
||||
use tokio_stream::wrappers::BroadcastStream;
|
||||
|
||||
use crate::config::ServerConfig;
|
||||
use crate::AppState;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
|
@ -35,6 +36,31 @@ pub enum Outgoing<'a> {
|
|||
Lagged,
|
||||
}
|
||||
|
||||
#[serde_inline_default]
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
#[serde(default, deny_unknown_fields)]
|
||||
pub struct Config {
|
||||
#[serde(deserialize_with = "de_duration_sec")]
|
||||
pub auth_timeout_sec: Duration,
|
||||
#[serde(deserialize_with = "de_duration_sec")]
|
||||
pub send_timeout_sec: Duration,
|
||||
pub event_queue_len: usize,
|
||||
}
|
||||
|
||||
impl Default for Config {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
auth_timeout_sec: Duration::from_secs(15),
|
||||
send_timeout_sec: Duration::from_secs(15),
|
||||
event_queue_len: 1024,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn de_duration_sec<'de, D: de::Deserializer<'de>>(de: D) -> Result<Duration, D::Error> {
|
||||
<u64>::deserialize(de).map(Duration::from_secs)
|
||||
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
pub struct State {
|
||||
pub user_listeners: Mutex<HashMap<u64, UserEventSender>>,
|
||||
|
@ -53,14 +79,14 @@ impl std::error::Error for StreamEnded {}
|
|||
|
||||
struct WsSenderWrapper<'ws, 'c> {
|
||||
inner: SplitSink<&'ws mut WebSocket, Message>,
|
||||
config: &'c ServerConfig,
|
||||
config: &'c Config,
|
||||
}
|
||||
|
||||
impl WsSenderWrapper<'_, '_> {
|
||||
async fn send(&mut self, msg: &Outgoing<'_>) -> Result<()> {
|
||||
let data = serde_json::to_string(&msg).expect("serialization cannot fail");
|
||||
let fut = tokio::time::timeout(
|
||||
self.config.ws_send_timeout_sec,
|
||||
self.config.send_timeout_sec,
|
||||
self.inner.send(Message::Text(data)),
|
||||
);
|
||||
match fut.await {
|
||||
|
@ -101,15 +127,16 @@ impl Stream for UserEventReceiver {
|
|||
}
|
||||
|
||||
pub async fn handle_ws(st: Arc<AppState>, ws: &mut WebSocket) -> Result<Infallible> {
|
||||
let config = &st.config.ws;
|
||||
let (ws_tx, ws_rx) = ws.split();
|
||||
let mut ws_rx = ws_rx.map(|ret| ret.and_then(|msg| msg.into_text()).map_err(|_| StreamEnded));
|
||||
let mut ws_tx = WsSenderWrapper {
|
||||
inner: ws_tx,
|
||||
config: &st.config,
|
||||
config,
|
||||
};
|
||||
|
||||
let uid = {
|
||||
let payload = tokio::time::timeout(st.config.ws_auth_timeout_sec, ws_rx.next())
|
||||
let payload = tokio::time::timeout(config.auth_timeout_sec, ws_rx.next())
|
||||
.await
|
||||
.context("authentication timeout")?
|
||||
.ok_or(StreamEnded)??;
|
||||
|
@ -137,7 +164,7 @@ pub async fn handle_ws(st: Arc<AppState>, ws: &mut WebSocket) -> Result<Infallib
|
|||
let rx = match st.event.user_listeners.lock().entry(uid) {
|
||||
Entry::Occupied(ent) => ent.get().subscribe(),
|
||||
Entry::Vacant(ent) => {
|
||||
let (tx, rx) = broadcast::channel(st.config.ws_event_queue_len);
|
||||
let (tx, rx) = broadcast::channel(config.event_queue_len);
|
||||
ent.insert(tx);
|
||||
rx
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue