diff --git a/.gitignore b/.gitignore index aed3a5a..2e1278f 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,5 @@ /target *.sqlite* *.key + +config.toml diff --git a/Cargo.lock b/Cargo.lock index a3ebc91..eb641a2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -226,6 +226,15 @@ version = "1.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b" +[[package]] +name = "basic-toml" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "823388e228f614e9558c6804262db37960ec8821856535f5c3f59913140558f8" +dependencies = [ + "serde", +] + [[package]] name = "bitflags" version = "2.6.0" @@ -280,6 +289,7 @@ dependencies = [ "anyhow", "axum", "axum-extra", + "basic-toml", "blah", "clap", "ed25519-dalek", diff --git a/blahd/Cargo.toml b/blahd/Cargo.toml index 2493dc8..f4ab712 100644 --- a/blahd/Cargo.toml +++ b/blahd/Cargo.toml @@ -25,3 +25,4 @@ tracing-subscriber = "0.3" uuid = { version = "1", features = ["v4"] } blah = { path = "..", features = ["rusqlite"] } +basic-toml = "0.1.9" diff --git a/blahd/config.example.toml b/blahd/config.example.toml new file mode 100644 index 0000000..a38667f --- /dev/null +++ b/blahd/config.example.toml @@ -0,0 +1,13 @@ +[database] +# The path to the main SQLite database. +# It will be created and initialized if not exist. +path = "/path/to/db.sqlite" + +[server] + +# The socket address to listen on. +listen = "localhost:8080" + +# The global absolute URL prefix where this service is hosted. +# It is for link generation and must not have trailing slash. +base_url = "http://localhost:8080" diff --git a/blahd/src/config.rs b/blahd/src/config.rs new file mode 100644 index 0000000..bd5f028 --- /dev/null +++ b/blahd/src/config.rs @@ -0,0 +1,46 @@ +use std::path::PathBuf; + +use anyhow::{ensure, Result}; +use serde::Deserialize; + +#[derive(Debug, Clone, Deserialize)] +#[serde(deny_unknown_fields)] +pub struct Config { + pub database: DatabaseConfig, + pub server: ServerConfig, +} + +#[derive(Debug, Clone, Deserialize)] +#[serde(deny_unknown_fields)] +pub struct DatabaseConfig { + pub path: PathBuf, +} + +#[derive(Debug, Clone, Deserialize)] +#[serde(deny_unknown_fields)] +pub struct ServerConfig { + pub listen: String, + pub base_url: String, +} + +impl Config { + pub fn validate(&self) -> Result<()> { + ensure!( + !self.server.base_url.ends_with("/"), + "base_url must not have trailing slash", + ); + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn example_config() { + let src = std::fs::read_to_string("config.example.toml").unwrap(); + let config = basic_toml::from_str::(&src).unwrap(); + config.validate().unwrap(); + } +} diff --git a/blahd/src/main.rs b/blahd/src/main.rs index c5ad394..20963a1 100644 --- a/blahd/src/main.rs +++ b/blahd/src/main.rs @@ -5,7 +5,7 @@ use std::path::PathBuf; use std::sync::{Arc, Mutex}; use std::time::{Duration, SystemTime}; -use anyhow::{ensure, Context, Result}; +use anyhow::{Context, Result}; use axum::extract::{Path, Query, State}; use axum::http::{header, StatusCode}; use axum::response::{sse, IntoResponse}; @@ -16,6 +16,7 @@ use blah::types::{ ChatItem, ChatPayload, CreateRoomPayload, MemberPermission, RoomAttrs, ServerPermission, Signee, UserKey, WithSig, }; +use config::Config; use ed25519_dalek::SIGNATURE_LENGTH; use middleware::{ApiError, OptionalAuth, SignedJson}; use rusqlite::{named_params, params, OptionalExtension, Row}; @@ -32,37 +33,54 @@ const TIMESTAMP_TOLERENCE: u64 = 90; #[macro_use] mod middleware; +mod config; mod utils; #[derive(Debug, clap::Parser)] -struct Cli { - /// Address to listen on. - #[arg(long)] - listen: String, +enum Cli { + /// Run the server with given configuration. + Serve { + /// The path to the configuration file. + #[arg(long, short)] + config: PathBuf, + }, - /// Path to the SQLite database. - #[arg(long)] - database: PathBuf, - - /// The global absolute URL prefix where this service is hosted. - /// It is for link generation and must not have trailing slash. - #[arg(long)] - base_url: String, + /// Validate the configuration file and exit. + Validate { + /// The path to the configuration file. + #[arg(long, short)] + config: PathBuf, + }, } fn main() -> Result<()> { tracing_subscriber::fmt::init(); let cli = ::parse(); - let db = rusqlite::Connection::open(&cli.database).context("failed to open database")?; - let st = AppState::init(&*cli.base_url, db).context("failed to initialize state")?; + fn parse_config(path: &std::path::Path) -> Result { + let src = std::fs::read_to_string(path)?; + let config = basic_toml::from_str::(&src)?; + config.validate()?; + Ok(config) + } - tokio::runtime::Builder::new_multi_thread() - .enable_all() - .build() - .context("failed to initialize tokio runtime")? - .block_on(main_async(cli, st))?; - Ok(()) + match cli { + Cli::Serve { config } => { + let config = parse_config(&config)?; + let db = rusqlite::Connection::open(&config.database.path) + .context("failed to open database")?; + let st = AppState::init(config, db).context("failed to initialize state")?; + tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build() + .context("failed to initialize tokio runtime")? + .block_on(main_async(st)) + } + Cli::Validate { config } => { + parse_config(&config)?; + Ok(()) + } + } } // Locks must be grabbed in the field order. @@ -72,18 +90,15 @@ struct AppState { room_listeners: Mutex>>>, used_nonces: Mutex>, - base_url: Box, + config: Config, } impl AppState { - fn init(base_url: impl Into>, conn: rusqlite::Connection) -> Result { + fn init(config: Config, conn: rusqlite::Connection) -> Result { static INIT_SQL: &str = include_str!("../init.sql"); - let base_url = base_url.into(); - ensure!( - !base_url.ends_with('/'), - "base_url must not has trailing slash", - ); + // Should be validated by `Config`. + assert!(!config.server.base_url.ends_with('/')); conn.execute_batch(INIT_SQL) .context("failed to initialize database")?; @@ -92,7 +107,7 @@ impl AppState { room_listeners: Mutex::new(HashMap::new()), used_nonces: Mutex::new(ExpiringSet::new(Duration::from_secs(TIMESTAMP_TOLERENCE))), - base_url, + config, }) } @@ -134,24 +149,25 @@ impl AppState { type ArcState = State>; -async fn main_async(opt: Cli, st: AppState) -> Result<()> { +async fn main_async(st: AppState) -> Result<()> { + let st = Arc::new(st); + let app = Router::new() .route("/room/create", post(room_create)) // NB. Sync with `feed_url` and `next_url` generation. .route("/room/:ruuid/feed.json", get(room_get_feed)) .route("/room/:ruuid/event", get(room_event)) .route("/room/:ruuid/item", get(room_get_item).post(room_post_item)) - .with_state(Arc::new(st)) + .with_state(st.clone()) // NB. This comes at last (outmost layer), so inner errors will still be wraped with // correct CORS headers. .layer(tower_http::limit::RequestBodyLimitLayer::new(MAX_BODY_LEN)) .layer(tower_http::cors::CorsLayer::permissive()); - let listener = tokio::net::TcpListener::bind(&opt.listen) + let listener = tokio::net::TcpListener::bind(&st.config.server.listen) .await .context("failed to listen on socket")?; - - tracing::info!("listening on {}", opt.listen); + tracing::info!("listening on {}", st.config.server.listen); let _ = sd_notify::notify(true, &[sd_notify::NotifyState::Ready]); axum::serve(listener, app) @@ -296,7 +312,7 @@ async fn room_get_feed( }) .collect::>(); - let base_url = &st.base_url; + let base_url = &st.config.server.base_url; let feed_url = format!("{base_url}/room/{ruuid}/feed.json"); let next_url = (items.len() == PAGE_LEN).then(|| { let last_id = &items.last().expect("page size is not 0").id;