use std::future::IntoFuture; use std::os::fd::{FromRawFd, OwnedFd}; use std::path::PathBuf; use std::sync::Arc; use anyhow::{bail, Context, Result}; use blahd::config::{Config, ListenConfig}; use blahd::{AppState, Database}; use tokio::net::TcpListener; use tokio::signal::unix::{signal, SignalKind}; /// Blah Chat Server #[derive(Debug, clap::Parser)] #[clap(about, version = env!("CFG_RELEASE"))] enum Cli { /// Run the server with given configuration. Serve { /// The path to the configuration file. #[arg(long, short)] config: PathBuf, }, /// Validate the configuration file and exit. Validate { /// The path to the configuration file. #[arg(long, short)] config: PathBuf, }, } fn main() -> Result<()> { tracing_subscriber::fmt::init(); let cli = ::parse(); fn parse_config(path: &std::path::Path) -> Result { let src = std::fs::read_to_string(path)?; let config = toml::from_str::(&src)?; Ok(config) } match cli { Cli::Serve { config } => { let config = parse_config(&config)?; let db = Database::open(&config.database).context("failed to open database")?; tokio::runtime::Builder::new_multi_thread() .enable_all() .build() .context("failed to initialize tokio runtime")? .block_on(main_serve(db, config)) } Cli::Validate { config } => { parse_config(&config)?; Ok(()) } } } async fn main_serve(db: Database, config: Config) -> Result<()> { let st = AppState::new(db, config.server); let mut listen_fds = sd_notify::listen_fds() .context("failed to get fds from sd_listen_fds(3)")? // SAFETY: `fd` is valid by sd_listen_fds(3) protocol. .map(|fd| unsafe { Some(OwnedFd::from_raw_fd(fd)) }) .collect::>(); let mut get_listener = move |name: &'static str, fd_idx: usize, listen: ListenConfig| { let fd = listen_fds.get_mut(0).and_then(|opt| opt.take()); async move { match listen { ListenConfig::Address(addr) => { tracing::info!("serve {name} on address {addr:?}"); TcpListener::bind(addr) .await .context("failed to listen on socket") } ListenConfig::Systemd(_) => { use rustix::net::{getsockname, SocketAddrAny}; let fd = fd.with_context(|| format!("missing LISTEN_FDS[{fd_idx}]"))?; let addr = getsockname(&fd).context("failed to getsockname")?; match addr { SocketAddrAny::V4(_) | SocketAddrAny::V6(_) => { let listener = std::net::TcpListener::from(fd); listener .set_nonblocking(true) .context("failed to set socket non-blocking")?; let listener = TcpListener::from_std(listener) .context("failed to register async socket")?; tracing::info!( "serve {name} on tcp socket {addr:?} from LISTEN_FDS[{fd_idx}" ); Ok(listener) } // Unix socket support for axum is currently overly complex. // WAIT: https://github.com/tokio-rs/axum/pull/2479 _ => bail!("unsupported socket type from LISTEN_FDS[{fd_idx}]: {addr:?}"), } } } } }; let router = blahd::router(Arc::new(st)); #[cfg(not(feature = "prometheus"))] anyhow::ensure!( config.metric.is_none(), "metric support is disabled at compile time", ); #[cfg(feature = "prometheus")] let router = if let Some(metric_config) = &config.metric { let blahd::config::MetricConfig::Prometheus(prom_config) = metric_config; let metric_listener = get_listener("metrics", 1, prom_config.listen.clone()).await?; let (metric_router, recorder, upkeeper) = blahd::metrics_router(metric_config); metrics::set_global_recorder(recorder).expect("only set once"); tokio::spawn(upkeeper); tokio::spawn(axum::serve(metric_listener, metric_router).into_future()); router.layer(axum_metrics::MetricLayer::default()) } else { router }; let mut sigterm = signal(SignalKind::terminate()).context("failed to listen on SIGTERM")?; let listener = get_listener("main", 0, config.listen.clone()).await?; let service = axum::serve(listener, router) .with_graceful_shutdown(async move { sigterm.recv().await; tracing::info!("received SIGTERM, shutting down gracefully"); }) .into_future(); let _ = sd_notify::notify(true, &[sd_notify::NotifyState::Ready]); service.await.context("failed to serve")?; Ok(()) }