mirror of
https://github.com/Blah-IM/blahrs.git
synced 2025-05-01 00:31:09 +00:00
refactor(config): split into subsections and verify on parsing
This commit is contained in:
parent
93d1589730
commit
2775068e49
6 changed files with 123 additions and 77 deletions
|
@ -42,13 +42,15 @@ max_request_len = 4096
|
|||
# The maximum timestamp tolerance in seconds for request validation.
|
||||
timestamp_tolerance_secs = 90
|
||||
|
||||
[server.ws]
|
||||
|
||||
# The max waiting time for the first authentication message for websocket.
|
||||
ws_auth_timeout_sec = 15
|
||||
auth_timeout_sec = 15
|
||||
|
||||
# The max waiting time for outgoing message to be received for websocket.
|
||||
ws_send_timeout_sec = 15
|
||||
send_timeout_sec = 15
|
||||
|
||||
# Maximum number of pending events a single user can have.
|
||||
# If events overflow the pending buffer, older events will be dropped and
|
||||
# client will be notified.
|
||||
ws_event_queue_len = 1024
|
||||
event_queue_len = 1024
|
||||
|
|
|
@ -33,7 +33,6 @@ fn main() -> Result<()> {
|
|||
fn parse_config(path: &std::path::Path) -> Result<Config> {
|
||||
let src = std::fs::read_to_string(path)?;
|
||||
let config = toml::from_str::<Config>(&src)?;
|
||||
config.validate()?;
|
||||
Ok(config)
|
||||
}
|
||||
|
||||
|
|
|
@ -1,33 +1,17 @@
|
|||
use std::num::NonZeroUsize;
|
||||
use std::path::PathBuf;
|
||||
use std::time::Duration;
|
||||
|
||||
use anyhow::{ensure, Result};
|
||||
use serde::{Deserialize, Deserializer, Serialize};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_constant::ConstBool;
|
||||
use serde_inline_default::serde_inline_default;
|
||||
use url::Url;
|
||||
|
||||
use crate::{database, ServerConfig};
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
#[serde(deny_unknown_fields)]
|
||||
pub struct Config {
|
||||
pub database: DatabaseConfig,
|
||||
#[serde(default)]
|
||||
pub database: database::Config,
|
||||
pub listen: ListenConfig,
|
||||
pub server: ServerConfig,
|
||||
}
|
||||
|
||||
#[serde_inline_default]
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
#[serde(deny_unknown_fields)]
|
||||
pub struct DatabaseConfig {
|
||||
#[serde_inline_default(false)]
|
||||
pub in_memory: bool,
|
||||
#[serde_inline_default("/var/lib/blahd/db.sqlite".into())]
|
||||
pub path: PathBuf,
|
||||
#[serde_inline_default(true)]
|
||||
pub create: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum ListenConfig {
|
||||
|
@ -35,52 +19,35 @@ pub enum ListenConfig {
|
|||
Systemd(ConstBool<true>),
|
||||
}
|
||||
|
||||
#[serde_inline_default]
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
#[serde(deny_unknown_fields)]
|
||||
pub struct ServerConfig {
|
||||
pub base_url: Url,
|
||||
|
||||
#[serde_inline_default(1024.try_into().expect("not zero"))]
|
||||
pub max_page_len: NonZeroUsize,
|
||||
#[serde_inline_default(4096)] // 4KiB
|
||||
pub max_request_len: usize,
|
||||
|
||||
#[serde_inline_default(90)]
|
||||
pub timestamp_tolerance_secs: u64,
|
||||
|
||||
#[serde_inline_default(Duration::from_secs(15))]
|
||||
#[serde(deserialize_with = "de_duration_sec")]
|
||||
pub ws_auth_timeout_sec: Duration,
|
||||
#[serde_inline_default(Duration::from_secs(15))]
|
||||
#[serde(deserialize_with = "de_duration_sec")]
|
||||
pub ws_send_timeout_sec: Duration,
|
||||
#[serde_inline_default(1024)]
|
||||
pub ws_event_queue_len: usize,
|
||||
}
|
||||
|
||||
fn de_duration_sec<'de, D: Deserializer<'de>>(de: D) -> Result<Duration, D::Error> {
|
||||
<u64>::deserialize(de).map(Duration::from_secs)
|
||||
}
|
||||
|
||||
impl Config {
|
||||
pub fn validate(&self) -> Result<()> {
|
||||
ensure!(
|
||||
!self.server.base_url.cannot_be_a_base(),
|
||||
"base_url must be able to be a base",
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn example_config() {
|
||||
fn example() {
|
||||
let src = std::fs::read_to_string("config.example.toml").unwrap();
|
||||
let config = toml::from_str::<Config>(&src).unwrap();
|
||||
config.validate().unwrap();
|
||||
let _config = toml::from_str::<Config>(&src).unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn minimal_address() {
|
||||
let src = r#"
|
||||
[server]
|
||||
base_url = "http://localhost"
|
||||
[listen]
|
||||
address = "localhost:8080"
|
||||
"#;
|
||||
let _config = toml::from_str::<Config>(src).unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn minimal_systemd() {
|
||||
let src = r#"
|
||||
[server]
|
||||
base_url = "http://localhost"
|
||||
[listen]
|
||||
systemd = true
|
||||
"#;
|
||||
let _config = toml::from_str::<Config>(src).unwrap();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,10 +1,13 @@
|
|||
use std::ops::DerefMut;
|
||||
use std::path::PathBuf;
|
||||
|
||||
use anyhow::{ensure, Context, Result};
|
||||
use parking_lot::Mutex;
|
||||
use rusqlite::{params, Connection, OpenFlags};
|
||||
use serde::Deserialize;
|
||||
use serde_inline_default::serde_inline_default;
|
||||
|
||||
use crate::config::DatabaseConfig;
|
||||
const DEFAULT_DATABASE_PATH: &str = "/var/lib/blahd/db.sqlite";
|
||||
|
||||
static INIT_SQL: &str = include_str!("../schema.sql");
|
||||
|
||||
|
@ -12,6 +15,25 @@ static INIT_SQL: &str = include_str!("../schema.sql");
|
|||
// `echo -n 'blahd-database-0' | sha256sum | head -c5` || version
|
||||
const APPLICATION_ID: i32 = 0xd9e_8404;
|
||||
|
||||
#[serde_inline_default]
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
#[serde(default, deny_unknown_fields)]
|
||||
pub struct Config {
|
||||
pub in_memory: bool,
|
||||
pub path: PathBuf,
|
||||
pub create: bool,
|
||||
}
|
||||
|
||||
impl Default for Config {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
in_memory: false,
|
||||
path: DEFAULT_DATABASE_PATH.into(),
|
||||
create: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Database {
|
||||
conn: Mutex<Connection>,
|
||||
|
@ -25,7 +47,7 @@ impl Database {
|
|||
Ok(Self { conn: conn.into() })
|
||||
}
|
||||
|
||||
pub fn open(config: &DatabaseConfig) -> Result<Self> {
|
||||
pub fn open(config: &Config) -> Result<Self> {
|
||||
let mut conn = if config.in_memory {
|
||||
Connection::open_in_memory().context("failed to open in-memory database")?
|
||||
} else {
|
||||
|
|
|
@ -5,6 +5,7 @@ use std::fmt;
|
|||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::task::{Context, Poll};
|
||||
use std::time::Duration;
|
||||
|
||||
use anyhow::{bail, Context as _, Result};
|
||||
use axum::extract::ws::{Message, WebSocket};
|
||||
|
@ -14,12 +15,12 @@ use futures_util::stream::SplitSink;
|
|||
use futures_util::{stream_select, SinkExt as _, Stream, StreamExt};
|
||||
use parking_lot::Mutex;
|
||||
use rusqlite::{params, OptionalExtension};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde::{de, Deserialize, Serialize};
|
||||
use serde_inline_default::serde_inline_default;
|
||||
use tokio::sync::broadcast;
|
||||
use tokio_stream::wrappers::errors::BroadcastStreamRecvError;
|
||||
use tokio_stream::wrappers::BroadcastStream;
|
||||
|
||||
use crate::config::ServerConfig;
|
||||
use crate::AppState;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
|
@ -35,6 +36,31 @@ pub enum Outgoing<'a> {
|
|||
Lagged,
|
||||
}
|
||||
|
||||
#[serde_inline_default]
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
#[serde(default, deny_unknown_fields)]
|
||||
pub struct Config {
|
||||
#[serde(deserialize_with = "de_duration_sec")]
|
||||
pub auth_timeout_sec: Duration,
|
||||
#[serde(deserialize_with = "de_duration_sec")]
|
||||
pub send_timeout_sec: Duration,
|
||||
pub event_queue_len: usize,
|
||||
}
|
||||
|
||||
impl Default for Config {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
auth_timeout_sec: Duration::from_secs(15),
|
||||
send_timeout_sec: Duration::from_secs(15),
|
||||
event_queue_len: 1024,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn de_duration_sec<'de, D: de::Deserializer<'de>>(de: D) -> Result<Duration, D::Error> {
|
||||
<u64>::deserialize(de).map(Duration::from_secs)
|
||||
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
pub struct State {
|
||||
pub user_listeners: Mutex<HashMap<u64, UserEventSender>>,
|
||||
|
@ -53,14 +79,14 @@ impl std::error::Error for StreamEnded {}
|
|||
|
||||
struct WsSenderWrapper<'ws, 'c> {
|
||||
inner: SplitSink<&'ws mut WebSocket, Message>,
|
||||
config: &'c ServerConfig,
|
||||
config: &'c Config,
|
||||
}
|
||||
|
||||
impl WsSenderWrapper<'_, '_> {
|
||||
async fn send(&mut self, msg: &Outgoing<'_>) -> Result<()> {
|
||||
let data = serde_json::to_string(&msg).expect("serialization cannot fail");
|
||||
let fut = tokio::time::timeout(
|
||||
self.config.ws_send_timeout_sec,
|
||||
self.config.send_timeout_sec,
|
||||
self.inner.send(Message::Text(data)),
|
||||
);
|
||||
match fut.await {
|
||||
|
@ -101,15 +127,16 @@ impl Stream for UserEventReceiver {
|
|||
}
|
||||
|
||||
pub async fn handle_ws(st: Arc<AppState>, ws: &mut WebSocket) -> Result<Infallible> {
|
||||
let config = &st.config.ws;
|
||||
let (ws_tx, ws_rx) = ws.split();
|
||||
let mut ws_rx = ws_rx.map(|ret| ret.and_then(|msg| msg.into_text()).map_err(|_| StreamEnded));
|
||||
let mut ws_tx = WsSenderWrapper {
|
||||
inner: ws_tx,
|
||||
config: &st.config,
|
||||
config,
|
||||
};
|
||||
|
||||
let uid = {
|
||||
let payload = tokio::time::timeout(st.config.ws_auth_timeout_sec, ws_rx.next())
|
||||
let payload = tokio::time::timeout(config.auth_timeout_sec, ws_rx.next())
|
||||
.await
|
||||
.context("authentication timeout")?
|
||||
.ok_or(StreamEnded)??;
|
||||
|
@ -137,7 +164,7 @@ pub async fn handle_ws(st: Arc<AppState>, ws: &mut WebSocket) -> Result<Infallib
|
|||
let rx = match st.event.user_listeners.lock().entry(uid) {
|
||||
Entry::Occupied(ent) => ent.get().subscribe(),
|
||||
Entry::Vacant(ent) => {
|
||||
let (tx, rx) = broadcast::channel(st.config.ws_event_queue_len);
|
||||
let (tx, rx) = broadcast::channel(config.event_queue_len);
|
||||
ent.insert(tx);
|
||||
rx
|
||||
}
|
||||
|
|
|
@ -15,13 +15,13 @@ use blah_types::{
|
|||
RoomAdminPayload, RoomAttrs, RoomMetadata, ServerPermission, Signed, SignedChatMsg, Signee,
|
||||
UserKey, WithMsgId,
|
||||
};
|
||||
use config::ServerConfig;
|
||||
use ed25519_dalek::SIGNATURE_LENGTH;
|
||||
use id::IdExt;
|
||||
use middleware::{Auth, MaybeAuth, ResultExt as _, SignedJson};
|
||||
use parking_lot::Mutex;
|
||||
use rusqlite::{named_params, params, Connection, OptionalExtension, Row, ToSql};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde::{Deserialize, Deserializer, Serialize};
|
||||
use serde_inline_default::serde_inline_default;
|
||||
use url::Url;
|
||||
use utils::ExpiringSet;
|
||||
|
||||
|
@ -36,6 +36,35 @@ mod utils;
|
|||
pub use database::Database;
|
||||
pub use middleware::ApiError;
|
||||
|
||||
#[serde_inline_default]
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
#[serde(deny_unknown_fields)]
|
||||
pub struct ServerConfig {
|
||||
#[serde(deserialize_with = "de_base_url")]
|
||||
pub base_url: Url,
|
||||
|
||||
#[serde_inline_default(1024.try_into().expect("not zero"))]
|
||||
pub max_page_len: NonZeroUsize,
|
||||
#[serde_inline_default(4096)] // 4KiB
|
||||
pub max_request_len: usize,
|
||||
|
||||
#[serde_inline_default(90)]
|
||||
pub timestamp_tolerance_secs: u64,
|
||||
|
||||
#[serde(default)]
|
||||
pub ws: event::Config,
|
||||
}
|
||||
|
||||
fn de_base_url<'de, D: Deserializer<'de>>(de: D) -> Result<Url, D::Error> {
|
||||
let url = Url::deserialize(de)?;
|
||||
if url.cannot_be_a_base() {
|
||||
return Err(serde::de::Error::custom(
|
||||
"base_url must be able to be a base",
|
||||
));
|
||||
}
|
||||
Ok(url)
|
||||
}
|
||||
|
||||
// Locks must be grabbed in the field order.
|
||||
#[derive(Debug)]
|
||||
pub struct AppState {
|
||||
|
|
Loading…
Add table
Reference in a new issue