From 2775068e49e64a448faa7a6c3c8061e7e7ac295e Mon Sep 17 00:00:00 2001 From: oxalica Date: Fri, 13 Sep 2024 07:20:48 -0400 Subject: [PATCH] refactor(config): split into subsections and verify on parsing --- blahd/config.example.toml | 8 ++-- blahd/src/bin/blahd.rs | 1 - blahd/src/config.rs | 91 +++++++++++++-------------------------- blahd/src/database.rs | 26 ++++++++++- blahd/src/event.rs | 41 +++++++++++++++--- blahd/src/lib.rs | 33 +++++++++++++- 6 files changed, 123 insertions(+), 77 deletions(-) diff --git a/blahd/config.example.toml b/blahd/config.example.toml index 2242e67..359e599 100644 --- a/blahd/config.example.toml +++ b/blahd/config.example.toml @@ -42,13 +42,15 @@ max_request_len = 4096 # The maximum timestamp tolerance in seconds for request validation. timestamp_tolerance_secs = 90 +[server.ws] + # The max waiting time for the first authentication message for websocket. -ws_auth_timeout_sec = 15 +auth_timeout_sec = 15 # The max waiting time for outgoing message to be received for websocket. -ws_send_timeout_sec = 15 +send_timeout_sec = 15 # Maximum number of pending events a single user can have. # If events overflow the pending buffer, older events will be dropped and # client will be notified. -ws_event_queue_len = 1024 +event_queue_len = 1024 diff --git a/blahd/src/bin/blahd.rs b/blahd/src/bin/blahd.rs index 9a5e1a8..c40cb42 100644 --- a/blahd/src/bin/blahd.rs +++ b/blahd/src/bin/blahd.rs @@ -33,7 +33,6 @@ fn main() -> Result<()> { fn parse_config(path: &std::path::Path) -> Result { let src = std::fs::read_to_string(path)?; let config = toml::from_str::(&src)?; - config.validate()?; Ok(config) } diff --git a/blahd/src/config.rs b/blahd/src/config.rs index 0c6aef4..644c787 100644 --- a/blahd/src/config.rs +++ b/blahd/src/config.rs @@ -1,33 +1,17 @@ -use std::num::NonZeroUsize; -use std::path::PathBuf; -use std::time::Duration; - -use anyhow::{ensure, Result}; -use serde::{Deserialize, Deserializer, Serialize}; +use serde::{Deserialize, Serialize}; use serde_constant::ConstBool; -use serde_inline_default::serde_inline_default; -use url::Url; + +use crate::{database, ServerConfig}; #[derive(Debug, Clone, Deserialize)] #[serde(deny_unknown_fields)] pub struct Config { - pub database: DatabaseConfig, + #[serde(default)] + pub database: database::Config, pub listen: ListenConfig, pub server: ServerConfig, } -#[serde_inline_default] -#[derive(Debug, Clone, Deserialize)] -#[serde(deny_unknown_fields)] -pub struct DatabaseConfig { - #[serde_inline_default(false)] - pub in_memory: bool, - #[serde_inline_default("/var/lib/blahd/db.sqlite".into())] - pub path: PathBuf, - #[serde_inline_default(true)] - pub create: bool, -} - #[derive(Debug, Clone, Deserialize, Serialize)] #[serde(rename_all = "snake_case")] pub enum ListenConfig { @@ -35,52 +19,35 @@ pub enum ListenConfig { Systemd(ConstBool), } -#[serde_inline_default] -#[derive(Debug, Clone, Deserialize)] -#[serde(deny_unknown_fields)] -pub struct ServerConfig { - pub base_url: Url, - - #[serde_inline_default(1024.try_into().expect("not zero"))] - pub max_page_len: NonZeroUsize, - #[serde_inline_default(4096)] // 4KiB - pub max_request_len: usize, - - #[serde_inline_default(90)] - pub timestamp_tolerance_secs: u64, - - #[serde_inline_default(Duration::from_secs(15))] - #[serde(deserialize_with = "de_duration_sec")] - pub ws_auth_timeout_sec: Duration, - #[serde_inline_default(Duration::from_secs(15))] - #[serde(deserialize_with = "de_duration_sec")] - pub ws_send_timeout_sec: Duration, - #[serde_inline_default(1024)] - pub ws_event_queue_len: usize, -} - -fn de_duration_sec<'de, D: Deserializer<'de>>(de: D) -> Result { - ::deserialize(de).map(Duration::from_secs) -} - -impl Config { - pub fn validate(&self) -> Result<()> { - ensure!( - !self.server.base_url.cannot_be_a_base(), - "base_url must be able to be a base", - ); - Ok(()) - } -} - #[cfg(test)] mod tests { use super::*; #[test] - fn example_config() { + fn example() { let src = std::fs::read_to_string("config.example.toml").unwrap(); - let config = toml::from_str::(&src).unwrap(); - config.validate().unwrap(); + let _config = toml::from_str::(&src).unwrap(); + } + + #[test] + fn minimal_address() { + let src = r#" +[server] +base_url = "http://localhost" +[listen] +address = "localhost:8080" + "#; + let _config = toml::from_str::(src).unwrap(); + } + + #[test] + fn minimal_systemd() { + let src = r#" +[server] +base_url = "http://localhost" +[listen] +systemd = true + "#; + let _config = toml::from_str::(src).unwrap(); } } diff --git a/blahd/src/database.rs b/blahd/src/database.rs index 5cbdd93..5c83a81 100644 --- a/blahd/src/database.rs +++ b/blahd/src/database.rs @@ -1,10 +1,13 @@ use std::ops::DerefMut; +use std::path::PathBuf; use anyhow::{ensure, Context, Result}; use parking_lot::Mutex; use rusqlite::{params, Connection, OpenFlags}; +use serde::Deserialize; +use serde_inline_default::serde_inline_default; -use crate::config::DatabaseConfig; +const DEFAULT_DATABASE_PATH: &str = "/var/lib/blahd/db.sqlite"; static INIT_SQL: &str = include_str!("../schema.sql"); @@ -12,6 +15,25 @@ static INIT_SQL: &str = include_str!("../schema.sql"); // `echo -n 'blahd-database-0' | sha256sum | head -c5` || version const APPLICATION_ID: i32 = 0xd9e_8404; +#[serde_inline_default] +#[derive(Debug, Clone, Deserialize)] +#[serde(default, deny_unknown_fields)] +pub struct Config { + pub in_memory: bool, + pub path: PathBuf, + pub create: bool, +} + +impl Default for Config { + fn default() -> Self { + Self { + in_memory: false, + path: DEFAULT_DATABASE_PATH.into(), + create: true, + } + } +} + #[derive(Debug)] pub struct Database { conn: Mutex, @@ -25,7 +47,7 @@ impl Database { Ok(Self { conn: conn.into() }) } - pub fn open(config: &DatabaseConfig) -> Result { + pub fn open(config: &Config) -> Result { let mut conn = if config.in_memory { Connection::open_in_memory().context("failed to open in-memory database")? } else { diff --git a/blahd/src/event.rs b/blahd/src/event.rs index db8a126..d787201 100644 --- a/blahd/src/event.rs +++ b/blahd/src/event.rs @@ -5,6 +5,7 @@ use std::fmt; use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; +use std::time::Duration; use anyhow::{bail, Context as _, Result}; use axum::extract::ws::{Message, WebSocket}; @@ -14,12 +15,12 @@ use futures_util::stream::SplitSink; use futures_util::{stream_select, SinkExt as _, Stream, StreamExt}; use parking_lot::Mutex; use rusqlite::{params, OptionalExtension}; -use serde::{Deserialize, Serialize}; +use serde::{de, Deserialize, Serialize}; +use serde_inline_default::serde_inline_default; use tokio::sync::broadcast; use tokio_stream::wrappers::errors::BroadcastStreamRecvError; use tokio_stream::wrappers::BroadcastStream; -use crate::config::ServerConfig; use crate::AppState; #[derive(Debug, Deserialize)] @@ -35,6 +36,31 @@ pub enum Outgoing<'a> { Lagged, } +#[serde_inline_default] +#[derive(Debug, Clone, Deserialize)] +#[serde(default, deny_unknown_fields)] +pub struct Config { + #[serde(deserialize_with = "de_duration_sec")] + pub auth_timeout_sec: Duration, + #[serde(deserialize_with = "de_duration_sec")] + pub send_timeout_sec: Duration, + pub event_queue_len: usize, +} + +impl Default for Config { + fn default() -> Self { + Self { + auth_timeout_sec: Duration::from_secs(15), + send_timeout_sec: Duration::from_secs(15), + event_queue_len: 1024, + } + } +} + +fn de_duration_sec<'de, D: de::Deserializer<'de>>(de: D) -> Result { + ::deserialize(de).map(Duration::from_secs) +} + #[derive(Debug, Default)] pub struct State { pub user_listeners: Mutex>, @@ -53,14 +79,14 @@ impl std::error::Error for StreamEnded {} struct WsSenderWrapper<'ws, 'c> { inner: SplitSink<&'ws mut WebSocket, Message>, - config: &'c ServerConfig, + config: &'c Config, } impl WsSenderWrapper<'_, '_> { async fn send(&mut self, msg: &Outgoing<'_>) -> Result<()> { let data = serde_json::to_string(&msg).expect("serialization cannot fail"); let fut = tokio::time::timeout( - self.config.ws_send_timeout_sec, + self.config.send_timeout_sec, self.inner.send(Message::Text(data)), ); match fut.await { @@ -101,15 +127,16 @@ impl Stream for UserEventReceiver { } pub async fn handle_ws(st: Arc, ws: &mut WebSocket) -> Result { + let config = &st.config.ws; let (ws_tx, ws_rx) = ws.split(); let mut ws_rx = ws_rx.map(|ret| ret.and_then(|msg| msg.into_text()).map_err(|_| StreamEnded)); let mut ws_tx = WsSenderWrapper { inner: ws_tx, - config: &st.config, + config, }; let uid = { - let payload = tokio::time::timeout(st.config.ws_auth_timeout_sec, ws_rx.next()) + let payload = tokio::time::timeout(config.auth_timeout_sec, ws_rx.next()) .await .context("authentication timeout")? .ok_or(StreamEnded)??; @@ -137,7 +164,7 @@ pub async fn handle_ws(st: Arc, ws: &mut WebSocket) -> Result ent.get().subscribe(), Entry::Vacant(ent) => { - let (tx, rx) = broadcast::channel(st.config.ws_event_queue_len); + let (tx, rx) = broadcast::channel(config.event_queue_len); ent.insert(tx); rx } diff --git a/blahd/src/lib.rs b/blahd/src/lib.rs index edcfcd3..60f45fe 100644 --- a/blahd/src/lib.rs +++ b/blahd/src/lib.rs @@ -15,13 +15,13 @@ use blah_types::{ RoomAdminPayload, RoomAttrs, RoomMetadata, ServerPermission, Signed, SignedChatMsg, Signee, UserKey, WithMsgId, }; -use config::ServerConfig; use ed25519_dalek::SIGNATURE_LENGTH; use id::IdExt; use middleware::{Auth, MaybeAuth, ResultExt as _, SignedJson}; use parking_lot::Mutex; use rusqlite::{named_params, params, Connection, OptionalExtension, Row, ToSql}; -use serde::{Deserialize, Serialize}; +use serde::{Deserialize, Deserializer, Serialize}; +use serde_inline_default::serde_inline_default; use url::Url; use utils::ExpiringSet; @@ -36,6 +36,35 @@ mod utils; pub use database::Database; pub use middleware::ApiError; +#[serde_inline_default] +#[derive(Debug, Clone, Deserialize)] +#[serde(deny_unknown_fields)] +pub struct ServerConfig { + #[serde(deserialize_with = "de_base_url")] + pub base_url: Url, + + #[serde_inline_default(1024.try_into().expect("not zero"))] + pub max_page_len: NonZeroUsize, + #[serde_inline_default(4096)] // 4KiB + pub max_request_len: usize, + + #[serde_inline_default(90)] + pub timestamp_tolerance_secs: u64, + + #[serde(default)] + pub ws: event::Config, +} + +fn de_base_url<'de, D: Deserializer<'de>>(de: D) -> Result { + let url = Url::deserialize(de)?; + if url.cannot_be_a_base() { + return Err(serde::de::Error::custom( + "base_url must be able to be a base", + )); + } + Ok(url) +} + // Locks must be grabbed in the field order. #[derive(Debug)] pub struct AppState {