feat(blahd): support unix domain socket and rewrite tests

This commit is contained in:
oxalica 2025-03-19 01:57:37 -04:00
parent 2e0a878d56
commit eb8c56e688
4 changed files with 215 additions and 180 deletions

View file

@ -6,6 +6,7 @@ use std::sync::Arc;
use anyhow::{Context, Result, anyhow, bail};
use blahd::config::{Config, ListenConfig};
use blahd::{AppState, Database};
use futures_util::future::Either;
use tokio::signal::unix::{SignalKind, signal};
/// Blah Chat Server
@ -57,15 +58,15 @@ fn main() -> Result<()> {
async fn main_serve(db: Database, config: Config) -> Result<()> {
let st = AppState::new(db, config.server);
let (listener_display, listener) = match &config.listen {
ListenConfig::Address(addr) => (
format!("address {addr:?}"),
tokio::net::TcpListener::bind(addr)
let (listener, addr_display) = match &config.listen {
ListenConfig::Address(addr) => {
let tcp = tokio::net::TcpListener::bind(addr)
.await
.context("failed to listen on socket")?,
),
.context("failed to listen on socket")?;
(Either::Left(tcp), format!("tcp address {addr:?}"))
}
ListenConfig::Systemd(_) => {
use rustix::net::{SocketAddr, getsockname};
use rustix::net::{AddressFamily, getsockname};
let [fd] = sd_notify::listen_fds()
.context("failed to get fds from sd_listen_fds(3)")?
@ -73,37 +74,50 @@ async fn main_serve(db: Database, config: Config) -> Result<()> {
.try_into()
.map_err(|_| anyhow!("expecting exactly one fd from LISTEN_FDS"))?;
// SAFETY: `fd` is valid by sd_listen_fds(3) protocol.
let listener = unsafe { OwnedFd::from_raw_fd(fd) };
let fd = unsafe { OwnedFd::from_raw_fd(fd) };
let addr = getsockname(&listener).context("failed to getsockname")?;
if let Ok(addr) = SocketAddr::try_from(addr.clone()) {
let listener = std::net::TcpListener::from(listener);
listener
.set_nonblocking(true)
.context("failed to set socket non-blocking")?;
let listener = tokio::net::TcpListener::from_std(listener)
.context("failed to register async socket")?;
(format!("tcp socket {addr:?} from LISTEN_FDS"), listener)
} else {
// Unix socket support for axum is currently overly complex.
// WAIT: https://github.com/tokio-rs/axum/pull/2479
bail!("unsupported socket type from LISTEN_FDS: {addr:?}");
}
let addr = getsockname(&fd).context("failed to getsockname")?;
let listener = match addr.address_family() {
AddressFamily::INET | AddressFamily::INET6 => {
let listener = std::net::TcpListener::from(fd);
listener
.set_nonblocking(true)
.context("failed to set socket non-blocking")?;
let listener = tokio::net::TcpListener::from_std(listener)
.context("failed to register async socket")?;
Either::Left(listener)
}
AddressFamily::UNIX => {
let uds = std::os::unix::net::UnixListener::from(fd);
uds.set_nonblocking(true)
.context("failed to set socket non-blocking")?;
let uds = tokio::net::UnixListener::from_std(uds)
.context("failed to register async socket")?;
Either::Right(uds)
}
_ => bail!("unsupported socket type from LISTEN_FDS: {addr:?}"),
};
(listener, format!("socket {addr:?} from LISTEN_FDS"))
}
};
tracing::info!("listening on {listener_display}");
let router = blahd::router(Arc::new(st));
let mut sigterm = signal(SignalKind::terminate()).context("failed to listen on SIGTERM")?;
let service = axum::serve(listener, router)
.with_graceful_shutdown(async move {
sigterm.recv().await;
tracing::info!("received SIGTERM, shutting down gracefully");
})
.into_future();
let shutdown = async move {
sigterm.recv().await;
tracing::info!("received SIGTERM, shutting down gracefully");
};
let service = match listener {
Either::Left(tcp) => axum::serve(tcp, router)
.with_graceful_shutdown(shutdown)
.into_future(),
Either::Right(uds) => axum::serve(uds, router)
.with_graceful_shutdown(shutdown)
.into_future(),
};
tracing::info!("serving on {addr_display}");
let _ = sd_notify::notify(true, &[sd_notify::NotifyState::Ready]);
service.await.context("failed to serve")?;
Ok(())