refactor(config): split into subsections and verify on parsing

This commit is contained in:
oxalica 2024-09-13 07:20:48 -04:00
parent 93d1589730
commit 2775068e49
6 changed files with 123 additions and 77 deletions

View file

@ -42,13 +42,15 @@ max_request_len = 4096
# The maximum timestamp tolerance in seconds for request validation.
timestamp_tolerance_secs = 90
[server.ws]
# The max waiting time for the first authentication message for websocket.
ws_auth_timeout_sec = 15
auth_timeout_sec = 15
# The max waiting time for outgoing message to be received for websocket.
ws_send_timeout_sec = 15
send_timeout_sec = 15
# Maximum number of pending events a single user can have.
# If events overflow the pending buffer, older events will be dropped and
# client will be notified.
ws_event_queue_len = 1024
event_queue_len = 1024

View file

@ -33,7 +33,6 @@ fn main() -> Result<()> {
fn parse_config(path: &std::path::Path) -> Result<Config> {
let src = std::fs::read_to_string(path)?;
let config = toml::from_str::<Config>(&src)?;
config.validate()?;
Ok(config)
}

View file

@ -1,33 +1,17 @@
use std::num::NonZeroUsize;
use std::path::PathBuf;
use std::time::Duration;
use anyhow::{ensure, Result};
use serde::{Deserialize, Deserializer, Serialize};
use serde::{Deserialize, Serialize};
use serde_constant::ConstBool;
use serde_inline_default::serde_inline_default;
use url::Url;
use crate::{database, ServerConfig};
#[derive(Debug, Clone, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct Config {
pub database: DatabaseConfig,
#[serde(default)]
pub database: database::Config,
pub listen: ListenConfig,
pub server: ServerConfig,
}
#[serde_inline_default]
#[derive(Debug, Clone, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct DatabaseConfig {
#[serde_inline_default(false)]
pub in_memory: bool,
#[serde_inline_default("/var/lib/blahd/db.sqlite".into())]
pub path: PathBuf,
#[serde_inline_default(true)]
pub create: bool,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum ListenConfig {
@ -35,52 +19,35 @@ pub enum ListenConfig {
Systemd(ConstBool<true>),
}
#[serde_inline_default]
#[derive(Debug, Clone, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct ServerConfig {
pub base_url: Url,
#[serde_inline_default(1024.try_into().expect("not zero"))]
pub max_page_len: NonZeroUsize,
#[serde_inline_default(4096)] // 4KiB
pub max_request_len: usize,
#[serde_inline_default(90)]
pub timestamp_tolerance_secs: u64,
#[serde_inline_default(Duration::from_secs(15))]
#[serde(deserialize_with = "de_duration_sec")]
pub ws_auth_timeout_sec: Duration,
#[serde_inline_default(Duration::from_secs(15))]
#[serde(deserialize_with = "de_duration_sec")]
pub ws_send_timeout_sec: Duration,
#[serde_inline_default(1024)]
pub ws_event_queue_len: usize,
}
fn de_duration_sec<'de, D: Deserializer<'de>>(de: D) -> Result<Duration, D::Error> {
<u64>::deserialize(de).map(Duration::from_secs)
}
impl Config {
pub fn validate(&self) -> Result<()> {
ensure!(
!self.server.base_url.cannot_be_a_base(),
"base_url must be able to be a base",
);
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn example_config() {
fn example() {
let src = std::fs::read_to_string("config.example.toml").unwrap();
let config = toml::from_str::<Config>(&src).unwrap();
config.validate().unwrap();
let _config = toml::from_str::<Config>(&src).unwrap();
}
#[test]
fn minimal_address() {
let src = r#"
[server]
base_url = "http://localhost"
[listen]
address = "localhost:8080"
"#;
let _config = toml::from_str::<Config>(src).unwrap();
}
#[test]
fn minimal_systemd() {
let src = r#"
[server]
base_url = "http://localhost"
[listen]
systemd = true
"#;
let _config = toml::from_str::<Config>(src).unwrap();
}
}

View file

@ -1,10 +1,13 @@
use std::ops::DerefMut;
use std::path::PathBuf;
use anyhow::{ensure, Context, Result};
use parking_lot::Mutex;
use rusqlite::{params, Connection, OpenFlags};
use serde::Deserialize;
use serde_inline_default::serde_inline_default;
use crate::config::DatabaseConfig;
const DEFAULT_DATABASE_PATH: &str = "/var/lib/blahd/db.sqlite";
static INIT_SQL: &str = include_str!("../schema.sql");
@ -12,6 +15,25 @@ static INIT_SQL: &str = include_str!("../schema.sql");
// `echo -n 'blahd-database-0' | sha256sum | head -c5` || version
const APPLICATION_ID: i32 = 0xd9e_8404;
#[serde_inline_default]
#[derive(Debug, Clone, Deserialize)]
#[serde(default, deny_unknown_fields)]
pub struct Config {
pub in_memory: bool,
pub path: PathBuf,
pub create: bool,
}
impl Default for Config {
fn default() -> Self {
Self {
in_memory: false,
path: DEFAULT_DATABASE_PATH.into(),
create: true,
}
}
}
#[derive(Debug)]
pub struct Database {
conn: Mutex<Connection>,
@ -25,7 +47,7 @@ impl Database {
Ok(Self { conn: conn.into() })
}
pub fn open(config: &DatabaseConfig) -> Result<Self> {
pub fn open(config: &Config) -> Result<Self> {
let mut conn = if config.in_memory {
Connection::open_in_memory().context("failed to open in-memory database")?
} else {

View file

@ -5,6 +5,7 @@ use std::fmt;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Duration;
use anyhow::{bail, Context as _, Result};
use axum::extract::ws::{Message, WebSocket};
@ -14,12 +15,12 @@ use futures_util::stream::SplitSink;
use futures_util::{stream_select, SinkExt as _, Stream, StreamExt};
use parking_lot::Mutex;
use rusqlite::{params, OptionalExtension};
use serde::{Deserialize, Serialize};
use serde::{de, Deserialize, Serialize};
use serde_inline_default::serde_inline_default;
use tokio::sync::broadcast;
use tokio_stream::wrappers::errors::BroadcastStreamRecvError;
use tokio_stream::wrappers::BroadcastStream;
use crate::config::ServerConfig;
use crate::AppState;
#[derive(Debug, Deserialize)]
@ -35,6 +36,31 @@ pub enum Outgoing<'a> {
Lagged,
}
#[serde_inline_default]
#[derive(Debug, Clone, Deserialize)]
#[serde(default, deny_unknown_fields)]
pub struct Config {
#[serde(deserialize_with = "de_duration_sec")]
pub auth_timeout_sec: Duration,
#[serde(deserialize_with = "de_duration_sec")]
pub send_timeout_sec: Duration,
pub event_queue_len: usize,
}
impl Default for Config {
fn default() -> Self {
Self {
auth_timeout_sec: Duration::from_secs(15),
send_timeout_sec: Duration::from_secs(15),
event_queue_len: 1024,
}
}
}
fn de_duration_sec<'de, D: de::Deserializer<'de>>(de: D) -> Result<Duration, D::Error> {
<u64>::deserialize(de).map(Duration::from_secs)
}
#[derive(Debug, Default)]
pub struct State {
pub user_listeners: Mutex<HashMap<u64, UserEventSender>>,
@ -53,14 +79,14 @@ impl std::error::Error for StreamEnded {}
struct WsSenderWrapper<'ws, 'c> {
inner: SplitSink<&'ws mut WebSocket, Message>,
config: &'c ServerConfig,
config: &'c Config,
}
impl WsSenderWrapper<'_, '_> {
async fn send(&mut self, msg: &Outgoing<'_>) -> Result<()> {
let data = serde_json::to_string(&msg).expect("serialization cannot fail");
let fut = tokio::time::timeout(
self.config.ws_send_timeout_sec,
self.config.send_timeout_sec,
self.inner.send(Message::Text(data)),
);
match fut.await {
@ -101,15 +127,16 @@ impl Stream for UserEventReceiver {
}
pub async fn handle_ws(st: Arc<AppState>, ws: &mut WebSocket) -> Result<Infallible> {
let config = &st.config.ws;
let (ws_tx, ws_rx) = ws.split();
let mut ws_rx = ws_rx.map(|ret| ret.and_then(|msg| msg.into_text()).map_err(|_| StreamEnded));
let mut ws_tx = WsSenderWrapper {
inner: ws_tx,
config: &st.config,
config,
};
let uid = {
let payload = tokio::time::timeout(st.config.ws_auth_timeout_sec, ws_rx.next())
let payload = tokio::time::timeout(config.auth_timeout_sec, ws_rx.next())
.await
.context("authentication timeout")?
.ok_or(StreamEnded)??;
@ -137,7 +164,7 @@ pub async fn handle_ws(st: Arc<AppState>, ws: &mut WebSocket) -> Result<Infallib
let rx = match st.event.user_listeners.lock().entry(uid) {
Entry::Occupied(ent) => ent.get().subscribe(),
Entry::Vacant(ent) => {
let (tx, rx) = broadcast::channel(st.config.ws_event_queue_len);
let (tx, rx) = broadcast::channel(config.event_queue_len);
ent.insert(tx);
rx
}

View file

@ -15,13 +15,13 @@ use blah_types::{
RoomAdminPayload, RoomAttrs, RoomMetadata, ServerPermission, Signed, SignedChatMsg, Signee,
UserKey, WithMsgId,
};
use config::ServerConfig;
use ed25519_dalek::SIGNATURE_LENGTH;
use id::IdExt;
use middleware::{Auth, MaybeAuth, ResultExt as _, SignedJson};
use parking_lot::Mutex;
use rusqlite::{named_params, params, Connection, OptionalExtension, Row, ToSql};
use serde::{Deserialize, Serialize};
use serde::{Deserialize, Deserializer, Serialize};
use serde_inline_default::serde_inline_default;
use url::Url;
use utils::ExpiringSet;
@ -36,6 +36,35 @@ mod utils;
pub use database::Database;
pub use middleware::ApiError;
#[serde_inline_default]
#[derive(Debug, Clone, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct ServerConfig {
#[serde(deserialize_with = "de_base_url")]
pub base_url: Url,
#[serde_inline_default(1024.try_into().expect("not zero"))]
pub max_page_len: NonZeroUsize,
#[serde_inline_default(4096)] // 4KiB
pub max_request_len: usize,
#[serde_inline_default(90)]
pub timestamp_tolerance_secs: u64,
#[serde(default)]
pub ws: event::Config,
}
fn de_base_url<'de, D: Deserializer<'de>>(de: D) -> Result<Url, D::Error> {
let url = Url::deserialize(de)?;
if url.cannot_be_a_base() {
return Err(serde::de::Error::custom(
"base_url must be able to be a base",
));
}
Ok(url)
}
// Locks must be grabbed in the field order.
#[derive(Debug)]
pub struct AppState {