fix(blahd): reject UNIX domain socket for now

It's too complex to bother with for the current `axum` API. Let's wait
for axum 0.8 release.

Ref: https://github.com/tokio-rs/axum/pull/2479
This commit is contained in:
oxalica 2024-09-19 09:04:50 -04:00
parent ec7f428519
commit b955d32099
5 changed files with 68 additions and 29 deletions

View file

@ -1,9 +1,8 @@
use std::net::TcpListener;
use std::os::fd::FromRawFd;
use std::os::fd::{FromRawFd, OwnedFd};
use std::path::PathBuf;
use std::sync::Arc;
use anyhow::{anyhow, Context, Result};
use anyhow::{anyhow, bail, Context, Result};
use blahd::config::{Config, ListenConfig};
use blahd::{AppState, Database};
@ -56,30 +55,44 @@ fn main() -> Result<()> {
async fn main_serve(db: Database, config: Config) -> Result<()> {
let st = AppState::new(db, config.server);
let listener = match &config.listen {
ListenConfig::Address(addr) => {
tracing::info!("listening on {addr:?}");
let (listener_display, listener) = match &config.listen {
ListenConfig::Address(addr) => (
format!("address {addr:?}"),
tokio::net::TcpListener::bind(addr)
.await
.context("failed to listen on socket")?
}
.context("failed to listen on socket")?,
),
ListenConfig::Systemd(_) => {
tracing::info!("listening on fd from environment");
use rustix::net::{getsockname, SocketAddrAny};
let [fd] = sd_notify::listen_fds()
.context("failed to get fds from sd_listen_fds(3)")?
.collect::<Vec<_>>()
.try_into()
.map_err(|_| anyhow!("more than one fds available from sd_listen_fds(3)"))?;
.map_err(|_| anyhow!("expecting exactly one fd from LISTEN_FDS"))?;
// SAFETY: `fd` is valid by sd_listen_fds(3) protocol.
let listener = unsafe { TcpListener::from_raw_fd(fd) };
listener
.set_nonblocking(true)
.context("failed to set socket non-blocking")?;
tokio::net::TcpListener::from_std(listener)
.context("failed to register async socket")?
let listener = unsafe { OwnedFd::from_raw_fd(fd) };
let addr = getsockname(&listener).context("failed to getsockname")?;
match addr {
SocketAddrAny::V4(_) | SocketAddrAny::V6(_) => {
let listener = std::net::TcpListener::from(listener);
listener
.set_nonblocking(true)
.context("failed to set socket non-blocking")?;
let listener = tokio::net::TcpListener::from_std(listener)
.context("failed to register async socket")?;
(format!("tcp socket {addr:?} from LISTEN_FDS"), listener)
}
// Unix socket support for axum is currently overly complex.
// WAIT: https://github.com/tokio-rs/axum/pull/2479
_ => bail!("unsupported socket type from LISTEN_FDS: {addr:?}"),
}
}
};
tracing::info!("listening on {listener_display}");
let router = blahd::router(Arc::new(st));
let _ = sd_notify::notify(true, &[sd_notify::NotifyState::Ready]);