Use configuration file to simplify CLI

This commit is contained in:
oxalica 2024-08-31 00:53:53 -04:00
parent 4937502d4c
commit abdc32b51f
6 changed files with 123 additions and 35 deletions

2
.gitignore vendored
View file

@ -1,3 +1,5 @@
/target
*.sqlite*
*.key
config.toml

10
Cargo.lock generated
View file

@ -226,6 +226,15 @@ version = "1.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b"
[[package]]
name = "basic-toml"
version = "0.1.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "823388e228f614e9558c6804262db37960ec8821856535f5c3f59913140558f8"
dependencies = [
"serde",
]
[[package]]
name = "bitflags"
version = "2.6.0"
@ -280,6 +289,7 @@ dependencies = [
"anyhow",
"axum",
"axum-extra",
"basic-toml",
"blah",
"clap",
"ed25519-dalek",

View file

@ -25,3 +25,4 @@ tracing-subscriber = "0.3"
uuid = { version = "1", features = ["v4"] }
blah = { path = "..", features = ["rusqlite"] }
basic-toml = "0.1.9"

13
blahd/config.example.toml Normal file
View file

@ -0,0 +1,13 @@
[database]
# The path to the main SQLite database.
# It will be created and initialized if not exist.
path = "/path/to/db.sqlite"
[server]
# The socket address to listen on.
listen = "localhost:8080"
# The global absolute URL prefix where this service is hosted.
# It is for link generation and must not have trailing slash.
base_url = "http://localhost:8080"

46
blahd/src/config.rs Normal file
View file

@ -0,0 +1,46 @@
use std::path::PathBuf;
use anyhow::{ensure, Result};
use serde::Deserialize;
#[derive(Debug, Clone, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct Config {
pub database: DatabaseConfig,
pub server: ServerConfig,
}
#[derive(Debug, Clone, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct DatabaseConfig {
pub path: PathBuf,
}
#[derive(Debug, Clone, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct ServerConfig {
pub listen: String,
pub base_url: String,
}
impl Config {
pub fn validate(&self) -> Result<()> {
ensure!(
!self.server.base_url.ends_with("/"),
"base_url must not have trailing slash",
);
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn example_config() {
let src = std::fs::read_to_string("config.example.toml").unwrap();
let config = basic_toml::from_str::<Config>(&src).unwrap();
config.validate().unwrap();
}
}

View file

@ -5,7 +5,7 @@ use std::path::PathBuf;
use std::sync::{Arc, Mutex};
use std::time::{Duration, SystemTime};
use anyhow::{ensure, Context, Result};
use anyhow::{Context, Result};
use axum::extract::{Path, Query, State};
use axum::http::{header, StatusCode};
use axum::response::{sse, IntoResponse};
@ -16,6 +16,7 @@ use blah::types::{
ChatItem, ChatPayload, CreateRoomPayload, MemberPermission, RoomAttrs, ServerPermission,
Signee, UserKey, WithSig,
};
use config::Config;
use ed25519_dalek::SIGNATURE_LENGTH;
use middleware::{ApiError, OptionalAuth, SignedJson};
use rusqlite::{named_params, params, OptionalExtension, Row};
@ -32,38 +33,55 @@ const TIMESTAMP_TOLERENCE: u64 = 90;
#[macro_use]
mod middleware;
mod config;
mod utils;
#[derive(Debug, clap::Parser)]
struct Cli {
/// Address to listen on.
#[arg(long)]
listen: String,
enum Cli {
/// Run the server with given configuration.
Serve {
/// The path to the configuration file.
#[arg(long, short)]
config: PathBuf,
},
/// Path to the SQLite database.
#[arg(long)]
database: PathBuf,
/// The global absolute URL prefix where this service is hosted.
/// It is for link generation and must not have trailing slash.
#[arg(long)]
base_url: String,
/// Validate the configuration file and exit.
Validate {
/// The path to the configuration file.
#[arg(long, short)]
config: PathBuf,
},
}
fn main() -> Result<()> {
tracing_subscriber::fmt::init();
let cli = <Cli as clap::Parser>::parse();
let db = rusqlite::Connection::open(&cli.database).context("failed to open database")?;
let st = AppState::init(&*cli.base_url, db).context("failed to initialize state")?;
fn parse_config(path: &std::path::Path) -> Result<Config> {
let src = std::fs::read_to_string(path)?;
let config = basic_toml::from_str::<Config>(&src)?;
config.validate()?;
Ok(config)
}
match cli {
Cli::Serve { config } => {
let config = parse_config(&config)?;
let db = rusqlite::Connection::open(&config.database.path)
.context("failed to open database")?;
let st = AppState::init(config, db).context("failed to initialize state")?;
tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.context("failed to initialize tokio runtime")?
.block_on(main_async(cli, st))?;
.block_on(main_async(st))
}
Cli::Validate { config } => {
parse_config(&config)?;
Ok(())
}
}
}
// Locks must be grabbed in the field order.
#[derive(Debug)]
@ -72,18 +90,15 @@ struct AppState {
room_listeners: Mutex<HashMap<u64, broadcast::Sender<Arc<ChatItem>>>>,
used_nonces: Mutex<ExpiringSet<u32>>,
base_url: Box<str>,
config: Config,
}
impl AppState {
fn init(base_url: impl Into<Box<str>>, conn: rusqlite::Connection) -> Result<Self> {
fn init(config: Config, conn: rusqlite::Connection) -> Result<Self> {
static INIT_SQL: &str = include_str!("../init.sql");
let base_url = base_url.into();
ensure!(
!base_url.ends_with('/'),
"base_url must not has trailing slash",
);
// Should be validated by `Config`.
assert!(!config.server.base_url.ends_with('/'));
conn.execute_batch(INIT_SQL)
.context("failed to initialize database")?;
@ -92,7 +107,7 @@ impl AppState {
room_listeners: Mutex::new(HashMap::new()),
used_nonces: Mutex::new(ExpiringSet::new(Duration::from_secs(TIMESTAMP_TOLERENCE))),
base_url,
config,
})
}
@ -134,24 +149,25 @@ impl AppState {
type ArcState = State<Arc<AppState>>;
async fn main_async(opt: Cli, st: AppState) -> Result<()> {
async fn main_async(st: AppState) -> Result<()> {
let st = Arc::new(st);
let app = Router::new()
.route("/room/create", post(room_create))
// NB. Sync with `feed_url` and `next_url` generation.
.route("/room/:ruuid/feed.json", get(room_get_feed))
.route("/room/:ruuid/event", get(room_event))
.route("/room/:ruuid/item", get(room_get_item).post(room_post_item))
.with_state(Arc::new(st))
.with_state(st.clone())
// NB. This comes at last (outmost layer), so inner errors will still be wraped with
// correct CORS headers.
.layer(tower_http::limit::RequestBodyLimitLayer::new(MAX_BODY_LEN))
.layer(tower_http::cors::CorsLayer::permissive());
let listener = tokio::net::TcpListener::bind(&opt.listen)
let listener = tokio::net::TcpListener::bind(&st.config.server.listen)
.await
.context("failed to listen on socket")?;
tracing::info!("listening on {}", opt.listen);
tracing::info!("listening on {}", st.config.server.listen);
let _ = sd_notify::notify(true, &[sd_notify::NotifyState::Ready]);
axum::serve(listener, app)
@ -296,7 +312,7 @@ async fn room_get_feed(
})
.collect::<Vec<_>>();
let base_url = &st.base_url;
let base_url = &st.config.server.base_url;
let feed_url = format!("{base_url}/room/{ruuid}/feed.json");
let next_url = (items.len() == PAGE_LEN).then(|| {
let last_id = &items.last().expect("page size is not 0").id;