diff --git a/Cargo.lock b/Cargo.lock index 4241f7b..a5c077e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -266,8 +266,11 @@ dependencies = [ "html-escape", "http-body-util", "humantime", + "hyper", + "hyper-util", + "libc", + "libtest-mimic", "mock_instant", - "nix", "parking_lot", "paste", "rand 0.8.5", @@ -275,7 +278,6 @@ dependencies = [ "rstest", "rusqlite", "rustix", - "scopeguard", "sd-notify", "serde", "serde-constant", @@ -337,12 +339,6 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" -[[package]] -name = "cfg_aliases" -version = "0.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" - [[package]] name = "ciborium" version = "0.2.2" @@ -657,6 +653,12 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "escape8259" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5692dd7b5a1978a5aeb0ce83b7655c58ca8efdcb79d21036ea249da95afec2c6" + [[package]] name = "expect-test" version = "1.5.1" @@ -1246,6 +1248,18 @@ dependencies = [ "vcpkg", ] +[[package]] +name = "libtest-mimic" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5297962ef19edda4ce33aaa484386e0a5b3d7f2f4e037cbeee00503ef6b29d33" +dependencies = [ + "anstream", + "anstyle", + "clap", + "escape8259", +] + [[package]] name = "linux-raw-sys" version = "0.9.3" @@ -1341,18 +1355,6 @@ dependencies = [ "tempfile", ] -[[package]] -name = "nix" -version = "0.29.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "71e2746dc3a24dd78b3cfcb7be93368c6de9963d30f43a6a73998a9cf4b17b46" -dependencies = [ - "bitflags", - "cfg-if", - "cfg_aliases", - "libc", -] - [[package]] name = "nu-ansi-term" version = "0.46.0" diff --git a/blahd/Cargo.toml b/blahd/Cargo.toml index 70cf154..79c083a 100644 --- a/blahd/Cargo.toml +++ b/blahd/Cargo.toml @@ -52,12 +52,19 @@ url = "2" [dev-dependencies] expect-test = "1" -nix = { version = "0.29", features = ["fs", "process", "signal"] } +hyper = { version = "1", features = ["client", "http1"] } +hyper-util = { version = "0.1", features = ["tokio"] } +libc = "0.2" +libtest-mimic = "0.8" reqwest = { version = "0.12", features = ["json"] } rstest = { version = "0.24", default-features = false } -scopeguard = "1" +rustix = { version = "1", features = ["process"] } tempfile = "3" tokio-tungstenite = "0.26" +[[test]] +name = "socket_activate" +harness = false + [lints] workspace = true diff --git a/blahd/src/bin/blahd.rs b/blahd/src/bin/blahd.rs index dfdb285..2be16c9 100644 --- a/blahd/src/bin/blahd.rs +++ b/blahd/src/bin/blahd.rs @@ -6,6 +6,7 @@ use std::sync::Arc; use anyhow::{Context, Result, anyhow, bail}; use blahd::config::{Config, ListenConfig}; use blahd::{AppState, Database}; +use futures_util::future::Either; use tokio::signal::unix::{SignalKind, signal}; /// Blah Chat Server @@ -57,15 +58,15 @@ fn main() -> Result<()> { async fn main_serve(db: Database, config: Config) -> Result<()> { let st = AppState::new(db, config.server); - let (listener_display, listener) = match &config.listen { - ListenConfig::Address(addr) => ( - format!("address {addr:?}"), - tokio::net::TcpListener::bind(addr) + let (listener, addr_display) = match &config.listen { + ListenConfig::Address(addr) => { + let tcp = tokio::net::TcpListener::bind(addr) .await - .context("failed to listen on socket")?, - ), + .context("failed to listen on socket")?; + (Either::Left(tcp), format!("tcp address {addr:?}")) + } ListenConfig::Systemd(_) => { - use rustix::net::{SocketAddr, getsockname}; + use rustix::net::{AddressFamily, getsockname}; let [fd] = sd_notify::listen_fds() .context("failed to get fds from sd_listen_fds(3)")? @@ -73,37 +74,50 @@ async fn main_serve(db: Database, config: Config) -> Result<()> { .try_into() .map_err(|_| anyhow!("expecting exactly one fd from LISTEN_FDS"))?; // SAFETY: `fd` is valid by sd_listen_fds(3) protocol. - let listener = unsafe { OwnedFd::from_raw_fd(fd) }; + let fd = unsafe { OwnedFd::from_raw_fd(fd) }; - let addr = getsockname(&listener).context("failed to getsockname")?; - if let Ok(addr) = SocketAddr::try_from(addr.clone()) { - let listener = std::net::TcpListener::from(listener); - listener - .set_nonblocking(true) - .context("failed to set socket non-blocking")?; - let listener = tokio::net::TcpListener::from_std(listener) - .context("failed to register async socket")?; - (format!("tcp socket {addr:?} from LISTEN_FDS"), listener) - } else { - // Unix socket support for axum is currently overly complex. - // WAIT: https://github.com/tokio-rs/axum/pull/2479 - bail!("unsupported socket type from LISTEN_FDS: {addr:?}"); - } + let addr = getsockname(&fd).context("failed to getsockname")?; + let listener = match addr.address_family() { + AddressFamily::INET | AddressFamily::INET6 => { + let listener = std::net::TcpListener::from(fd); + listener + .set_nonblocking(true) + .context("failed to set socket non-blocking")?; + let listener = tokio::net::TcpListener::from_std(listener) + .context("failed to register async socket")?; + Either::Left(listener) + } + AddressFamily::UNIX => { + let uds = std::os::unix::net::UnixListener::from(fd); + uds.set_nonblocking(true) + .context("failed to set socket non-blocking")?; + let uds = tokio::net::UnixListener::from_std(uds) + .context("failed to register async socket")?; + Either::Right(uds) + } + _ => bail!("unsupported socket type from LISTEN_FDS: {addr:?}"), + }; + (listener, format!("socket {addr:?} from LISTEN_FDS")) } }; - tracing::info!("listening on {listener_display}"); - let router = blahd::router(Arc::new(st)); let mut sigterm = signal(SignalKind::terminate()).context("failed to listen on SIGTERM")?; - let service = axum::serve(listener, router) - .with_graceful_shutdown(async move { - sigterm.recv().await; - tracing::info!("received SIGTERM, shutting down gracefully"); - }) - .into_future(); + let shutdown = async move { + sigterm.recv().await; + tracing::info!("received SIGTERM, shutting down gracefully"); + }; + let service = match listener { + Either::Left(tcp) => axum::serve(tcp, router) + .with_graceful_shutdown(shutdown) + .into_future(), + Either::Right(uds) => axum::serve(uds, router) + .with_graceful_shutdown(shutdown) + .into_future(), + }; + tracing::info!("serving on {addr_display}"); let _ = sd_notify::notify(true, &[sd_notify::NotifyState::Ready]); service.await.context("failed to serve")?; Ok(()) diff --git a/blahd/tests/socket_activate.rs b/blahd/tests/socket_activate.rs index 2ff665b..691ab61 100644 --- a/blahd/tests/socket_activate.rs +++ b/blahd/tests/socket_activate.rs @@ -1,24 +1,22 @@ -use std::ffi::CString; -use std::fs::File; -use std::io::{Seek, Write}; +#![expect(clippy::unwrap_used, reason = "allowed in tests")] +use std::env; +use std::mem::ManuallyDrop; use std::net::TcpListener; -use std::os::fd::{AsFd, AsRawFd, OwnedFd}; -use std::os::unix::net::UnixListener; -use std::process::abort; -use std::ptr::null; +use std::os::fd::{AsFd, BorrowedFd, FromRawFd, OwnedFd}; +use std::os::unix::{net::UnixListener, process::CommandExt}; +use std::process::{Command, ExitCode, Stdio}; use std::time::Duration; -use nix::fcntl::{FcntlArg, FdFlag, fcntl}; -use nix::libc::execve; -use nix::sys::memfd::{MemFdCreateFlag, memfd_create}; -use nix::sys::signal::{Signal, kill}; -use nix::sys::wait::{WaitStatus, waitpid}; -use nix::unistd::{ForkResult, alarm, dup2, fork, getpid}; -use rstest::rstest; -use tokio::io::stderr; - -const TIMEOUT_SEC: u32 = 1; +use futures_util::future::Either; +use http_body_util::BodyExt; +use hyper::StatusCode; +use libtest_mimic::{Arguments, Trial}; +use rustix::io::{FdFlags, fcntl_getfd, fcntl_setfd}; +use rustix::process::{Pid, Signal}; +use tokio::io::{AsyncRead, AsyncWrite}; +const EXEC_HELPER_SENTINEL: &str = "--__exec_helper"; +const WAIT_TIMEOUT: Duration = Duration::from_secs(3); const SERVER_EXE_PATH: &str = env!("CARGO_BIN_EXE_blahd"); const CONFIG: &str = r#" @@ -30,130 +28,144 @@ systemd = true base_url = "http://example.com" "#; -#[rstest] -#[case::tcp(false)] -#[case::unix(true)] -fn socket_activate(#[case] unix_socket: bool) { - let socket_dir; +fn main() -> ExitCode { + if env::args() + .nth(1) + .is_some_and(|s| s == EXEC_HELPER_SENTINEL) + { + exec_helper(); + } + + let args = Arguments::from_args(); + + let tests = vec![ + Trial::test("tcp", || test_socket_activate(false)), + Trial::test("unix", || test_socket_activate(true)), + ]; + + libtest_mimic::run(&args, tests).exit_code() +} + +fn exec_helper() -> ! { + // Don't leave an orphan process if something goes wrong. + unsafe { libc::alarm(WAIT_TIMEOUT.as_secs() as u32 + 1) }; + + let pid = rustix::process::getpid().as_raw_nonzero(); + let err = Command::new(SERVER_EXE_PATH) + .args(env::args().skip(2)) + .stdin(Stdio::inherit()) + .stdout(Stdio::inherit()) + .stderr(Stdio::inherit()) + .env("LISTEN_PID", pid.to_string()) + .env("LISTEN_FDS", "1") + .exec(); + panic!("failed to exec: {err}"); +} + +fn test_socket_activate(unix_socket: bool) -> Result<(), libtest_mimic::Failed> { + let temp_dir = tempfile::tempdir().unwrap(); + + let config_path = temp_dir.path().join("config.toml"); + std::fs::write(&config_path, CONFIG).unwrap(); + + let socket_path = temp_dir.path().join("socket"); let (local_port, listener) = if unix_socket { - socket_dir = tempfile::tempdir().unwrap(); - let listener = UnixListener::bind(socket_dir.path().join("socket")).unwrap(); + let listener = UnixListener::bind(&socket_path).unwrap(); // Port is unused. - (0, OwnedFd::from(listener)) + (0, Either::Left(listener)) } else { let listener = TcpListener::bind("127.0.0.1:0").unwrap(); let local_port = listener.local_addr().unwrap().port(); - (local_port, OwnedFd::from(listener)) + (local_port, Either::Right(listener)) }; + let listener_fd = match &listener { + Either::Left(s) => s.as_fd(), + Either::Right(s) => s.as_fd(), + }; + + // Remove CLOEXEC. + { + let flag = fcntl_getfd(listener_fd).unwrap(); + fcntl_setfd(listener_fd, flag - FdFlags::CLOEXEC).unwrap(); + } + + let mut cmd = Command::new(env::current_exe().unwrap()); + cmd.arg(EXEC_HELPER_SENTINEL) + .args(["serve", "-c"]) + .arg(&config_path) + .stdin(Stdio::null()) + .stdout(Stdio::inherit()) + .stderr(Stdio::inherit()); + // SAFETY: dup2 is a syscall thus is async-signal safe. + // `listener_fd` is alive during the pre_exec hook. + // Fd 3 is created and not closed. + unsafe { + let listener_fd = std::mem::transmute::, BorrowedFd<'static>>(listener_fd); + cmd.pre_exec(move || { + let mut tgt_fd = ManuallyDrop::new(OwnedFd::from_raw_fd(3)); + rustix::io::dup2(listener_fd, &mut tgt_fd)?; + Ok(()) + }) + }; + let mut child = cmd.spawn().unwrap(); + + let uri = "/_blah/room?filter=public".to_owned(); let rt = tokio::runtime::Builder::new_current_thread() .enable_all() .build() .unwrap(); - // Remove `FD_CLOEXEC` since we want to send it to the child. - fn remove_cloexec(fd: &impl AsFd) { - let mut flags = - FdFlag::from_bits_retain(fcntl(fd.as_fd().as_raw_fd(), FcntlArg::F_GETFD).unwrap()); - flags -= FdFlag::FD_CLOEXEC; - fcntl(fd.as_fd().as_raw_fd(), FcntlArg::F_SETFD(flags)).unwrap(); - } - remove_cloexec(&listener); - remove_cloexec(&stderr()); - - let server_exe_c = CString::new(SERVER_EXE_PATH).unwrap(); - - // Intentionally no FD_CLOEXEC. - let mut memfd = File::from(memfd_create(c"test-config", MemFdCreateFlag::empty()).unwrap()); - memfd.write_all(CONFIG.as_bytes()).unwrap(); - memfd.rewind().unwrap(); - - // Inherit environment variables. - let envs = std::env::vars() - .filter(|(name, _)| !name.starts_with("LISTEN_")) - .map(|(name, value)| CString::new(format!("{name}={value}")).unwrap()) - .collect::>(); - let mut env_ptrs = envs - .iter() - .map(|s| s.as_ptr()) - .chain([c"LISTEN_FDS=1".as_ptr(), null(), null()]) - .collect::>(); - - // Unfortunately we have to deal with raw `fork(2)` here, because no library supports passing - // child PID in environment variables for child. - // SAFETY: Between `fork()` and `execve()`, all syscalls are async-signal-safe: - // no memory allocation, no panic unwinding (always abort). - match unsafe { fork().unwrap() } { - ForkResult::Child => { - // Ideally, we want `std::panic::always_abort()`, which is unstable yet. - // WAIT: https://github.com/rust-lang/rust/issues/84438 - scopeguard::defer!(abort()); - - // Don't leave an orphan process if something does wrong. - alarm::set(TIMEOUT_SEC); - - // Ignore all errors here to stay safe, and lazy. - let _ = dup2(2, 1); // stdout <- stderr - let _ = dup2(memfd.as_raw_fd(), 0); // stdin <- config memfd - let _ = dup2(listener.as_raw_fd(), 3); // listener fd - let args = [ - c"blahd".as_ptr(), - c"serve".as_ptr(), - c"-c".as_ptr(), - c"/proc/self/fd/0".as_ptr(), - null(), - ]; - let mut buf = [0u8; 64]; - let _ = write!(&mut buf[..], "LISTEN_PID={}\0", getpid().as_raw()); - let pos = env_ptrs.len() - 2; - env_ptrs[pos] = buf.as_ptr().cast(); - - // NB. Do raw libc call, not the wrapper fn that does allocation inside. - // SAFETY: Valid NULL-terminated array of NULL-terminated strings. - unsafe { - execve(server_exe_c.as_ptr(), args.as_ptr(), env_ptrs.as_ptr()); - // If exec fail, the drop guard will abort the process anyway. Do nothing. - } - } - ForkResult::Parent { child } => { - let guard = scopeguard::guard((), |()| { - let _ = kill(child, Signal::SIGTERM); - }); - - if !unix_socket { - let resp = rt.block_on(async { - let url = format!("http://127.0.0.1:{local_port}/_blah/room?filter=public"); - let fut = async { - reqwest::get(url) - .await - .unwrap() - .error_for_status() - .unwrap() - .text() - .await - .unwrap() - }; - tokio::time::timeout(Duration::from_secs(TIMEOUT_SEC.into()), fut) + let (st, resp) = rt.block_on(async { + let fut = async move { + match &listener { + Either::Left(_) => { + let sock = tokio::net::UnixStream::connect(&socket_path).await.unwrap(); + send_get_request(sock, uri).await + } + Either::Right(_) => { + let sock = tokio::net::TcpStream::connect(("127.0.0.1", local_port)) .await - .unwrap() - }); - assert_eq!(resp, r#"{"rooms":[]}"#); - // Trigger the killer. - drop(guard); + .unwrap(); + send_get_request(sock, uri).await + } } + }; + tokio::time::timeout(WAIT_TIMEOUT, fut) + .await + .unwrap() + .unwrap() + }); + assert_eq!(st, StatusCode::OK); + assert_eq!(resp, r#"{"rooms":[]}"#); - let st = waitpid(child, None).unwrap(); - let expect_exit_code = if unix_socket { - // Fail with unsupported error. - 1 - } else { - // Graceful shutdown. - 0 - }; - assert!( - matches!(st, WaitStatus::Exited(_, code) if code == expect_exit_code), - "unexpected exit status {st:?}", - ); - } - } + rustix::process::kill_process(Pid::from_child(&child), Signal::TERM).unwrap(); + let st = child.wait().unwrap(); + assert!(st.success(), "unexpected exit status: {st:?}"); + + Ok(()) +} + +// Ref: +async fn send_get_request( + stream: impl AsyncRead + AsyncWrite + Unpin + Send + 'static, + uri: String, +) -> anyhow::Result<(StatusCode, String)> { + let stream = hyper_util::rt::TokioIo::new(stream); + let (mut request_sender, connection) = hyper::client::conn::http1::Builder::new() + .handshake(stream) + .await?; + tokio::task::spawn(connection); + + let request = hyper::Request::builder() + .method("GET") + .uri(uri) + .header("Host", "example.com") + .body(String::new())?; + + let response = request_sender.send_request(request).await?; + let status = response.status(); + let body = response.into_body().collect().await?.to_bytes(); + let body = String::from_utf8(body.to_vec())?; + Ok((status, body)) }