blahrs/blahd/src/bin/blahd.rs

137 lines
5.1 KiB
Rust

use std::future::IntoFuture;
use std::os::fd::{FromRawFd, OwnedFd};
use std::path::PathBuf;
use std::sync::Arc;
use anyhow::{bail, Context, Result};
use blahd::config::{Config, ListenConfig};
use blahd::{AppState, Database};
use tokio::net::TcpListener;
use tokio::signal::unix::{signal, SignalKind};
/// Blah Chat Server
// NOTE: clap's derive renders the `///` doc comments on this enum, its variants
// and their fields as `--help` text, so editing them changes CLI output.
// Plain `//` comments such as this one are ignored by clap.
#[derive(Debug, clap::Parser)]
// `version` comes from the `CFG_RELEASE` environment variable at compile time;
// `env!` makes the build fail if it is not set.
#[clap(about, version = env!("CFG_RELEASE"))]
enum Cli {
    /// Run the server with given configuration.
    // Parses the config, opens the database and serves until SIGTERM
    // (see `main` and `main_serve`).
    Serve {
        /// The path to the configuration file.
        #[arg(long, short)]
        config: PathBuf,
    },
    /// Validate the configuration file and exit.
    // Only reads and deserializes the TOML config; the database is not opened.
    Validate {
        /// The path to the configuration file.
        #[arg(long, short)]
        config: PathBuf,
    },
}
fn main() -> Result<()> {
tracing_subscriber::fmt::init();
let cli = <Cli as clap::Parser>::parse();
fn parse_config(path: &std::path::Path) -> Result<Config> {
let src = std::fs::read_to_string(path)?;
let config = toml::from_str::<Config>(&src)?;
Ok(config)
}
match cli {
Cli::Serve { config } => {
let config = parse_config(&config)?;
let db = Database::open(&config.database).context("failed to open database")?;
tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.context("failed to initialize tokio runtime")?
.block_on(main_serve(db, config))
}
Cli::Validate { config } => {
parse_config(&config)?;
Ok(())
}
}
}
async fn main_serve(db: Database, config: Config) -> Result<()> {
let st = AppState::new(db, config.server);
let mut listen_fds = sd_notify::listen_fds()
.context("failed to get fds from sd_listen_fds(3)")?
// SAFETY: `fd` is valid by sd_listen_fds(3) protocol.
.map(|fd| unsafe { Some(OwnedFd::from_raw_fd(fd)) })
.collect::<Vec<_>>();
let mut get_listener = move |name: &'static str, fd_idx: usize, listen: ListenConfig| {
let fd = listen_fds.get_mut(0).and_then(|opt| opt.take());
async move {
match listen {
ListenConfig::Address(addr) => {
tracing::info!("serve {name} on address {addr:?}");
TcpListener::bind(addr)
.await
.context("failed to listen on socket")
}
ListenConfig::Systemd(_) => {
use rustix::net::{getsockname, SocketAddrAny};
let fd = fd.with_context(|| format!("missing LISTEN_FDS[{fd_idx}]"))?;
let addr = getsockname(&fd).context("failed to getsockname")?;
match addr {
SocketAddrAny::V4(_) | SocketAddrAny::V6(_) => {
let listener = std::net::TcpListener::from(fd);
listener
.set_nonblocking(true)
.context("failed to set socket non-blocking")?;
let listener = TcpListener::from_std(listener)
.context("failed to register async socket")?;
tracing::info!(
"serve {name} on tcp socket {addr:?} from LISTEN_FDS[{fd_idx}"
);
Ok(listener)
}
// Unix socket support for axum is currently overly complex.
// WAIT: https://github.com/tokio-rs/axum/pull/2479
_ => bail!("unsupported socket type from LISTEN_FDS[{fd_idx}]: {addr:?}"),
}
}
}
}
};
let router = blahd::router(Arc::new(st));
#[cfg(not(feature = "prometheus"))]
anyhow::ensure!(
config.metric.is_none(),
"metric support is disabled at compile time",
);
#[cfg(feature = "prometheus")]
let router = if let Some(metric_config) = &config.metric {
let blahd::config::MetricConfig::Prometheus(prom_config) = metric_config;
let metric_listener = get_listener("metrics", 1, prom_config.listen.clone()).await?;
let (metric_router, recorder, upkeeper) = blahd::metrics_router(metric_config);
metrics::set_global_recorder(recorder).expect("only set once");
tokio::spawn(upkeeper);
tokio::spawn(axum::serve(metric_listener, metric_router).into_future());
router.layer(axum_metrics::MetricLayer::default())
} else {
router
};
let mut sigterm = signal(SignalKind::terminate()).context("failed to listen on SIGTERM")?;
let listener = get_listener("main", 0, config.listen.clone()).await?;
let service = axum::serve(listener, router)
.with_graceful_shutdown(async move {
sigterm.recv().await;
tracing::info!("received SIGTERM, shutting down gracefully");
})
.into_future();
let _ = sd_notify::notify(true, &[sd_notify::NotifyState::Ready]);
service.await.context("failed to serve")?;
Ok(())
}