fix(blahd): reject UNIX domain socket for now

It's too complex to bother with for the current `axum` API. Let's wait
for axum 0.8 release.

Ref: https://github.com/tokio-rs/axum/pull/2479
This commit is contained in:
oxalica 2024-09-19 09:04:50 -04:00
parent ec7f428519
commit b955d32099
5 changed files with 68 additions and 29 deletions

2
Cargo.lock generated
View file

@ -324,6 +324,7 @@ dependencies = [
"reqwest", "reqwest",
"rstest", "rstest",
"rusqlite", "rusqlite",
"rustix",
"scopeguard", "scopeguard",
"sd-notify", "sd-notify",
"serde", "serde",
@ -333,6 +334,7 @@ dependencies = [
"serde_json", "serde_json",
"serde_urlencoded", "serde_urlencoded",
"sha2", "sha2",
"tempfile",
"tokio", "tokio",
"tokio-stream", "tokio-stream",
"toml", "toml",

View file

@ -21,6 +21,7 @@ parking_lot = "0.12" # Maybe no better performance, just that we hate poisoning.
rand = "0.8" rand = "0.8"
reqwest = "0.12" reqwest = "0.12"
rusqlite = "0.32" rusqlite = "0.32"
rustix = { version = "0.38", features = ["net"] }
sd-notify = "0.4" sd-notify = "0.4"
serde = { version = "1", features = ["derive"] } serde = { version = "1", features = ["derive"] }
serde-constant = "0.1" serde-constant = "0.1"
@ -45,6 +46,7 @@ nix = { version = "0.29.0", features = ["fs", "process", "signal"] }
reqwest = { version = "0.12.7", features = ["json"] } reqwest = { version = "0.12.7", features = ["json"] }
rstest = { version = "0.22.0", default-features = false } rstest = { version = "0.22.0", default-features = false }
scopeguard = "1.2.0" scopeguard = "1.2.0"
tempfile = "3.12.0"
[lints] [lints]
workspace = true workspace = true

View file

@ -24,6 +24,7 @@ address = "localhost:8080"
# Use systemd socket activation mechanism to get listener fd from envvars. # Use systemd socket activation mechanism to get listener fd from envvars.
# See also sd_listen_fds(3) and systemd.socket(5). # See also sd_listen_fds(3) and systemd.socket(5).
# NB. Currently only TCP sockets are supported. UNIX domain socket is TODO.
#systemd = true #systemd = true
[server] [server]

View file

@ -1,9 +1,8 @@
use std::net::TcpListener; use std::os::fd::{FromRawFd, OwnedFd};
use std::os::fd::FromRawFd;
use std::path::PathBuf; use std::path::PathBuf;
use std::sync::Arc; use std::sync::Arc;
use anyhow::{anyhow, Context, Result}; use anyhow::{anyhow, bail, Context, Result};
use blahd::config::{Config, ListenConfig}; use blahd::config::{Config, ListenConfig};
use blahd::{AppState, Database}; use blahd::{AppState, Database};
@ -56,30 +55,44 @@ fn main() -> Result<()> {
async fn main_serve(db: Database, config: Config) -> Result<()> { async fn main_serve(db: Database, config: Config) -> Result<()> {
let st = AppState::new(db, config.server); let st = AppState::new(db, config.server);
let listener = match &config.listen { let (listener_display, listener) = match &config.listen {
ListenConfig::Address(addr) => { ListenConfig::Address(addr) => (
tracing::info!("listening on {addr:?}"); format!("address {addr:?}"),
tokio::net::TcpListener::bind(addr) tokio::net::TcpListener::bind(addr)
.await .await
.context("failed to listen on socket")? .context("failed to listen on socket")?,
} ),
ListenConfig::Systemd(_) => { ListenConfig::Systemd(_) => {
tracing::info!("listening on fd from environment"); use rustix::net::{getsockname, SocketAddrAny};
let [fd] = sd_notify::listen_fds() let [fd] = sd_notify::listen_fds()
.context("failed to get fds from sd_listen_fds(3)")? .context("failed to get fds from sd_listen_fds(3)")?
.collect::<Vec<_>>() .collect::<Vec<_>>()
.try_into() .try_into()
.map_err(|_| anyhow!("more than one fds available from sd_listen_fds(3)"))?; .map_err(|_| anyhow!("expecting exactly one fd from LISTEN_FDS"))?;
// SAFETY: `fd` is valid by sd_listen_fds(3) protocol. // SAFETY: `fd` is valid by sd_listen_fds(3) protocol.
let listener = unsafe { TcpListener::from_raw_fd(fd) }; let listener = unsafe { OwnedFd::from_raw_fd(fd) };
let addr = getsockname(&listener).context("failed to getsockname")?;
match addr {
SocketAddrAny::V4(_) | SocketAddrAny::V6(_) => {
let listener = std::net::TcpListener::from(listener);
listener listener
.set_nonblocking(true) .set_nonblocking(true)
.context("failed to set socket non-blocking")?; .context("failed to set socket non-blocking")?;
tokio::net::TcpListener::from_std(listener) let listener = tokio::net::TcpListener::from_std(listener)
.context("failed to register async socket")? .context("failed to register async socket")?;
(format!("tcp socket {addr:?} from LISTEN_FDS"), listener)
}
// Unix socket support for axum is currently overly complex.
// WAIT: https://github.com/tokio-rs/axum/pull/2479
_ => bail!("unsupported socket type from LISTEN_FDS: {addr:?}"),
}
} }
}; };
tracing::info!("listening on {listener_display}");
let router = blahd::router(Arc::new(st)); let router = blahd::router(Arc::new(st));
let _ = sd_notify::notify(true, &[sd_notify::NotifyState::Ready]); let _ = sd_notify::notify(true, &[sd_notify::NotifyState::Ready]);

View file

@ -2,7 +2,8 @@ use std::ffi::CString;
use std::fs::File; use std::fs::File;
use std::io::{Seek, Write}; use std::io::{Seek, Write};
use std::net::TcpListener; use std::net::TcpListener;
use std::os::fd::{AsFd, AsRawFd}; use std::os::fd::{AsFd, AsRawFd, OwnedFd};
use std::os::unix::net::UnixListener;
use std::process::abort; use std::process::abort;
use std::ptr::null; use std::ptr::null;
use std::time::Duration; use std::time::Duration;
@ -13,6 +14,7 @@ use nix::sys::memfd::{memfd_create, MemFdCreateFlag};
use nix::sys::signal::{kill, Signal}; use nix::sys::signal::{kill, Signal};
use nix::sys::wait::{waitpid, WaitStatus}; use nix::sys::wait::{waitpid, WaitStatus};
use nix::unistd::{alarm, dup2, fork, getpid, ForkResult}; use nix::unistd::{alarm, dup2, fork, getpid, ForkResult};
use rstest::rstest;
use tokio::io::stderr; use tokio::io::stderr;
const TIMEOUT_SEC: u32 = 1; const TIMEOUT_SEC: u32 = 1;
@ -28,10 +30,22 @@ systemd = true
base_url = "http://example.com" base_url = "http://example.com"
"#; "#;
#[test] #[rstest]
fn socket_activate() { #[case::tcp(false)]
#[case::unix(true)]
fn socket_activate(#[case] unix_socket: bool) {
let socket_dir;
let (local_port, listener) = if unix_socket {
socket_dir = tempfile::tempdir().unwrap();
let listener = UnixListener::bind(socket_dir.path().join("socket")).unwrap();
// Port is unused.
(0, OwnedFd::from(listener))
} else {
let listener = TcpListener::bind("127.0.0.1:0").unwrap(); let listener = TcpListener::bind("127.0.0.1:0").unwrap();
let local_port = listener.local_addr().unwrap().port(); let local_port = listener.local_addr().unwrap().port();
(local_port, OwnedFd::from(listener))
};
let rt = tokio::runtime::Builder::new_current_thread() let rt = tokio::runtime::Builder::new_current_thread()
.enable_all() .enable_all()
.build() .build()
@ -89,11 +103,11 @@ fn socket_activate() {
} }
} }
ForkResult::Parent { child } => { ForkResult::Parent { child } => {
{ let guard = scopeguard::guard((), |()| {
scopeguard::defer! {
let _ = kill(child, Signal::SIGKILL); let _ = kill(child, Signal::SIGKILL);
} });
if !unix_socket {
let resp = rt.block_on(async { let resp = rt.block_on(async {
let url = format!("http://127.0.0.1:{local_port}/_blah/room?filter=public"); let url = format!("http://127.0.0.1:{local_port}/_blah/room?filter=public");
let fut = async { let fut = async {
@ -111,9 +125,15 @@ fn socket_activate() {
.unwrap() .unwrap()
}); });
assert_eq!(resp, r#"{"rooms":[]}"#); assert_eq!(resp, r#"{"rooms":[]}"#);
// Trigger the killer.
drop(guard);
} }
let st = waitpid(child, None).unwrap(); let st = waitpid(child, None).unwrap();
if unix_socket {
// Fail with unsupported error.
assert!(matches!(st, WaitStatus::Exited(_, 1)));
} else {
assert!( assert!(
matches!(st, WaitStatus::Signaled(_, Signal::SIGKILL, _)), matches!(st, WaitStatus::Signaled(_, Signal::SIGKILL, _)),
"unexpected exit status {st:?}", "unexpected exit status {st:?}",
@ -121,3 +141,4 @@ fn socket_activate() {
} }
} }
} }
}