Use configuration file to simplify CLI

This commit is contained in:
oxalica 2024-08-31 00:53:53 -04:00
parent 4937502d4c
commit abdc32b51f
6 changed files with 123 additions and 35 deletions

2
.gitignore vendored
View file

@ -1,3 +1,5 @@
/target /target
*.sqlite* *.sqlite*
*.key *.key
config.toml

10
Cargo.lock generated
View file

@ -226,6 +226,15 @@ version = "1.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b" checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b"
[[package]]
name = "basic-toml"
version = "0.1.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "823388e228f614e9558c6804262db37960ec8821856535f5c3f59913140558f8"
dependencies = [
"serde",
]
[[package]] [[package]]
name = "bitflags" name = "bitflags"
version = "2.6.0" version = "2.6.0"
@ -280,6 +289,7 @@ dependencies = [
"anyhow", "anyhow",
"axum", "axum",
"axum-extra", "axum-extra",
"basic-toml",
"blah", "blah",
"clap", "clap",
"ed25519-dalek", "ed25519-dalek",

View file

@ -25,3 +25,4 @@ tracing-subscriber = "0.3"
uuid = { version = "1", features = ["v4"] } uuid = { version = "1", features = ["v4"] }
blah = { path = "..", features = ["rusqlite"] } blah = { path = "..", features = ["rusqlite"] }
basic-toml = "0.1.9"

13
blahd/config.example.toml Normal file
View file

@ -0,0 +1,13 @@
[database]
# The path to the main SQLite database.
# It will be created and initialized if not exist.
path = "/path/to/db.sqlite"
[server]
# The socket address to listen on.
listen = "localhost:8080"
# The global absolute URL prefix where this service is hosted.
# It is for link generation and must not have trailing slash.
base_url = "http://localhost:8080"

46
blahd/src/config.rs Normal file
View file

@ -0,0 +1,46 @@
use std::path::PathBuf;
use anyhow::{ensure, Result};
use serde::Deserialize;
#[derive(Debug, Clone, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct Config {
pub database: DatabaseConfig,
pub server: ServerConfig,
}
#[derive(Debug, Clone, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct DatabaseConfig {
pub path: PathBuf,
}
#[derive(Debug, Clone, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct ServerConfig {
pub listen: String,
pub base_url: String,
}
impl Config {
pub fn validate(&self) -> Result<()> {
ensure!(
!self.server.base_url.ends_with("/"),
"base_url must not have trailing slash",
);
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn example_config() {
let src = std::fs::read_to_string("config.example.toml").unwrap();
let config = basic_toml::from_str::<Config>(&src).unwrap();
config.validate().unwrap();
}
}

View file

@ -5,7 +5,7 @@ use std::path::PathBuf;
use std::sync::{Arc, Mutex}; use std::sync::{Arc, Mutex};
use std::time::{Duration, SystemTime}; use std::time::{Duration, SystemTime};
use anyhow::{ensure, Context, Result}; use anyhow::{Context, Result};
use axum::extract::{Path, Query, State}; use axum::extract::{Path, Query, State};
use axum::http::{header, StatusCode}; use axum::http::{header, StatusCode};
use axum::response::{sse, IntoResponse}; use axum::response::{sse, IntoResponse};
@ -16,6 +16,7 @@ use blah::types::{
ChatItem, ChatPayload, CreateRoomPayload, MemberPermission, RoomAttrs, ServerPermission, ChatItem, ChatPayload, CreateRoomPayload, MemberPermission, RoomAttrs, ServerPermission,
Signee, UserKey, WithSig, Signee, UserKey, WithSig,
}; };
use config::Config;
use ed25519_dalek::SIGNATURE_LENGTH; use ed25519_dalek::SIGNATURE_LENGTH;
use middleware::{ApiError, OptionalAuth, SignedJson}; use middleware::{ApiError, OptionalAuth, SignedJson};
use rusqlite::{named_params, params, OptionalExtension, Row}; use rusqlite::{named_params, params, OptionalExtension, Row};
@ -32,37 +33,54 @@ const TIMESTAMP_TOLERENCE: u64 = 90;
#[macro_use] #[macro_use]
mod middleware; mod middleware;
mod config;
mod utils; mod utils;
#[derive(Debug, clap::Parser)] #[derive(Debug, clap::Parser)]
struct Cli { enum Cli {
/// Address to listen on. /// Run the server with given configuration.
#[arg(long)] Serve {
listen: String, /// The path to the configuration file.
#[arg(long, short)]
config: PathBuf,
},
/// Path to the SQLite database. /// Validate the configuration file and exit.
#[arg(long)] Validate {
database: PathBuf, /// The path to the configuration file.
#[arg(long, short)]
/// The global absolute URL prefix where this service is hosted. config: PathBuf,
/// It is for link generation and must not have trailing slash. },
#[arg(long)]
base_url: String,
} }
fn main() -> Result<()> { fn main() -> Result<()> {
tracing_subscriber::fmt::init(); tracing_subscriber::fmt::init();
let cli = <Cli as clap::Parser>::parse(); let cli = <Cli as clap::Parser>::parse();
let db = rusqlite::Connection::open(&cli.database).context("failed to open database")?; fn parse_config(path: &std::path::Path) -> Result<Config> {
let st = AppState::init(&*cli.base_url, db).context("failed to initialize state")?; let src = std::fs::read_to_string(path)?;
let config = basic_toml::from_str::<Config>(&src)?;
config.validate()?;
Ok(config)
}
tokio::runtime::Builder::new_multi_thread() match cli {
.enable_all() Cli::Serve { config } => {
.build() let config = parse_config(&config)?;
.context("failed to initialize tokio runtime")? let db = rusqlite::Connection::open(&config.database.path)
.block_on(main_async(cli, st))?; .context("failed to open database")?;
Ok(()) let st = AppState::init(config, db).context("failed to initialize state")?;
tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.context("failed to initialize tokio runtime")?
.block_on(main_async(st))
}
Cli::Validate { config } => {
parse_config(&config)?;
Ok(())
}
}
} }
// Locks must be grabbed in the field order. // Locks must be grabbed in the field order.
@ -72,18 +90,15 @@ struct AppState {
room_listeners: Mutex<HashMap<u64, broadcast::Sender<Arc<ChatItem>>>>, room_listeners: Mutex<HashMap<u64, broadcast::Sender<Arc<ChatItem>>>>,
used_nonces: Mutex<ExpiringSet<u32>>, used_nonces: Mutex<ExpiringSet<u32>>,
base_url: Box<str>, config: Config,
} }
impl AppState { impl AppState {
fn init(base_url: impl Into<Box<str>>, conn: rusqlite::Connection) -> Result<Self> { fn init(config: Config, conn: rusqlite::Connection) -> Result<Self> {
static INIT_SQL: &str = include_str!("../init.sql"); static INIT_SQL: &str = include_str!("../init.sql");
let base_url = base_url.into(); // Should be validated by `Config`.
ensure!( assert!(!config.server.base_url.ends_with('/'));
!base_url.ends_with('/'),
"base_url must not has trailing slash",
);
conn.execute_batch(INIT_SQL) conn.execute_batch(INIT_SQL)
.context("failed to initialize database")?; .context("failed to initialize database")?;
@ -92,7 +107,7 @@ impl AppState {
room_listeners: Mutex::new(HashMap::new()), room_listeners: Mutex::new(HashMap::new()),
used_nonces: Mutex::new(ExpiringSet::new(Duration::from_secs(TIMESTAMP_TOLERENCE))), used_nonces: Mutex::new(ExpiringSet::new(Duration::from_secs(TIMESTAMP_TOLERENCE))),
base_url, config,
}) })
} }
@ -134,24 +149,25 @@ impl AppState {
type ArcState = State<Arc<AppState>>; type ArcState = State<Arc<AppState>>;
async fn main_async(opt: Cli, st: AppState) -> Result<()> { async fn main_async(st: AppState) -> Result<()> {
let st = Arc::new(st);
let app = Router::new() let app = Router::new()
.route("/room/create", post(room_create)) .route("/room/create", post(room_create))
// NB. Sync with `feed_url` and `next_url` generation. // NB. Sync with `feed_url` and `next_url` generation.
.route("/room/:ruuid/feed.json", get(room_get_feed)) .route("/room/:ruuid/feed.json", get(room_get_feed))
.route("/room/:ruuid/event", get(room_event)) .route("/room/:ruuid/event", get(room_event))
.route("/room/:ruuid/item", get(room_get_item).post(room_post_item)) .route("/room/:ruuid/item", get(room_get_item).post(room_post_item))
.with_state(Arc::new(st)) .with_state(st.clone())
// NB. This comes at last (outmost layer), so inner errors will still be wraped with // NB. This comes at last (outmost layer), so inner errors will still be wraped with
// correct CORS headers. // correct CORS headers.
.layer(tower_http::limit::RequestBodyLimitLayer::new(MAX_BODY_LEN)) .layer(tower_http::limit::RequestBodyLimitLayer::new(MAX_BODY_LEN))
.layer(tower_http::cors::CorsLayer::permissive()); .layer(tower_http::cors::CorsLayer::permissive());
let listener = tokio::net::TcpListener::bind(&opt.listen) let listener = tokio::net::TcpListener::bind(&st.config.server.listen)
.await .await
.context("failed to listen on socket")?; .context("failed to listen on socket")?;
tracing::info!("listening on {}", st.config.server.listen);
tracing::info!("listening on {}", opt.listen);
let _ = sd_notify::notify(true, &[sd_notify::NotifyState::Ready]); let _ = sd_notify::notify(true, &[sd_notify::NotifyState::Ready]);
axum::serve(listener, app) axum::serve(listener, app)
@ -296,7 +312,7 @@ async fn room_get_feed(
}) })
.collect::<Vec<_>>(); .collect::<Vec<_>>();
let base_url = &st.base_url; let base_url = &st.config.server.base_url;
let feed_url = format!("{base_url}/room/{ruuid}/feed.json"); let feed_url = format!("{base_url}/room/{ruuid}/feed.json");
let next_url = (items.len() == PAGE_LEN).then(|| { let next_url = (items.len() == PAGE_LEN).then(|| {
let last_id = &items.last().expect("page size is not 0").id; let last_id = &items.last().expect("page size is not 0").id;