feat(blahd): support unix domain socket and rewrite tests

This commit is contained in:
oxalica 2025-03-19 01:57:37 -04:00
parent 2e0a878d56
commit eb8c56e688
4 changed files with 215 additions and 180 deletions

View file

@ -1,24 +1,22 @@
use std::ffi::CString;
use std::fs::File;
use std::io::{Seek, Write};
#![expect(clippy::unwrap_used, reason = "allowed in tests")]
use std::env;
use std::mem::ManuallyDrop;
use std::net::TcpListener;
use std::os::fd::{AsFd, AsRawFd, OwnedFd};
use std::os::unix::net::UnixListener;
use std::process::abort;
use std::ptr::null;
use std::os::fd::{AsFd, BorrowedFd, FromRawFd, OwnedFd};
use std::os::unix::{net::UnixListener, process::CommandExt};
use std::process::{Command, ExitCode, Stdio};
use std::time::Duration;
use nix::fcntl::{FcntlArg, FdFlag, fcntl};
use nix::libc::execve;
use nix::sys::memfd::{MemFdCreateFlag, memfd_create};
use nix::sys::signal::{Signal, kill};
use nix::sys::wait::{WaitStatus, waitpid};
use nix::unistd::{ForkResult, alarm, dup2, fork, getpid};
use rstest::rstest;
use tokio::io::stderr;
const TIMEOUT_SEC: u32 = 1;
use futures_util::future::Either;
use http_body_util::BodyExt;
use hyper::StatusCode;
use libtest_mimic::{Arguments, Trial};
use rustix::io::{FdFlags, fcntl_getfd, fcntl_setfd};
use rustix::process::{Pid, Signal};
use tokio::io::{AsyncRead, AsyncWrite};
const EXEC_HELPER_SENTINEL: &str = "--__exec_helper";
const WAIT_TIMEOUT: Duration = Duration::from_secs(3);
const SERVER_EXE_PATH: &str = env!("CARGO_BIN_EXE_blahd");
const CONFIG: &str = r#"
@ -30,130 +28,144 @@ systemd = true
base_url = "http://example.com"
"#;
#[rstest]
#[case::tcp(false)]
#[case::unix(true)]
fn socket_activate(#[case] unix_socket: bool) {
let socket_dir;
fn main() -> ExitCode {
if env::args()
.nth(1)
.is_some_and(|s| s == EXEC_HELPER_SENTINEL)
{
exec_helper();
}
let args = Arguments::from_args();
let tests = vec![
Trial::test("tcp", || test_socket_activate(false)),
Trial::test("unix", || test_socket_activate(true)),
];
libtest_mimic::run(&args, tests).exit_code()
}
fn exec_helper() -> ! {
// Don't leave an orphan process if something goes wrong.
unsafe { libc::alarm(WAIT_TIMEOUT.as_secs() as u32 + 1) };
let pid = rustix::process::getpid().as_raw_nonzero();
let err = Command::new(SERVER_EXE_PATH)
.args(env::args().skip(2))
.stdin(Stdio::inherit())
.stdout(Stdio::inherit())
.stderr(Stdio::inherit())
.env("LISTEN_PID", pid.to_string())
.env("LISTEN_FDS", "1")
.exec();
panic!("failed to exec: {err}");
}
fn test_socket_activate(unix_socket: bool) -> Result<(), libtest_mimic::Failed> {
let temp_dir = tempfile::tempdir().unwrap();
let config_path = temp_dir.path().join("config.toml");
std::fs::write(&config_path, CONFIG).unwrap();
let socket_path = temp_dir.path().join("socket");
let (local_port, listener) = if unix_socket {
socket_dir = tempfile::tempdir().unwrap();
let listener = UnixListener::bind(socket_dir.path().join("socket")).unwrap();
let listener = UnixListener::bind(&socket_path).unwrap();
// Port is unused.
(0, OwnedFd::from(listener))
(0, Either::Left(listener))
} else {
let listener = TcpListener::bind("127.0.0.1:0").unwrap();
let local_port = listener.local_addr().unwrap().port();
(local_port, OwnedFd::from(listener))
(local_port, Either::Right(listener))
};
let listener_fd = match &listener {
Either::Left(s) => s.as_fd(),
Either::Right(s) => s.as_fd(),
};
// Remove CLOEXEC.
{
let flag = fcntl_getfd(listener_fd).unwrap();
fcntl_setfd(listener_fd, flag - FdFlags::CLOEXEC).unwrap();
}
let mut cmd = Command::new(env::current_exe().unwrap());
cmd.arg(EXEC_HELPER_SENTINEL)
.args(["serve", "-c"])
.arg(&config_path)
.stdin(Stdio::null())
.stdout(Stdio::inherit())
.stderr(Stdio::inherit());
// SAFETY: dup2 is a syscall thus is async-signal safe.
// `listener_fd` is alive during the pre_exec hook.
// Fd 3 is created and not closed.
unsafe {
let listener_fd = std::mem::transmute::<BorrowedFd<'_>, BorrowedFd<'static>>(listener_fd);
cmd.pre_exec(move || {
let mut tgt_fd = ManuallyDrop::new(OwnedFd::from_raw_fd(3));
rustix::io::dup2(listener_fd, &mut tgt_fd)?;
Ok(())
})
};
let mut child = cmd.spawn().unwrap();
let uri = "/_blah/room?filter=public".to_owned();
let rt = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.unwrap();
// Remove `FD_CLOEXEC` since we want to send it to the child.
fn remove_cloexec(fd: &impl AsFd) {
let mut flags =
FdFlag::from_bits_retain(fcntl(fd.as_fd().as_raw_fd(), FcntlArg::F_GETFD).unwrap());
flags -= FdFlag::FD_CLOEXEC;
fcntl(fd.as_fd().as_raw_fd(), FcntlArg::F_SETFD(flags)).unwrap();
}
remove_cloexec(&listener);
remove_cloexec(&stderr());
let server_exe_c = CString::new(SERVER_EXE_PATH).unwrap();
// Intentionally no FD_CLOEXEC.
let mut memfd = File::from(memfd_create(c"test-config", MemFdCreateFlag::empty()).unwrap());
memfd.write_all(CONFIG.as_bytes()).unwrap();
memfd.rewind().unwrap();
// Inherit environment variables.
let envs = std::env::vars()
.filter(|(name, _)| !name.starts_with("LISTEN_"))
.map(|(name, value)| CString::new(format!("{name}={value}")).unwrap())
.collect::<Vec<_>>();
let mut env_ptrs = envs
.iter()
.map(|s| s.as_ptr())
.chain([c"LISTEN_FDS=1".as_ptr(), null(), null()])
.collect::<Vec<_>>();
// Unfortunately we have to deal with raw `fork(2)` here, because no library supports passing
// child PID in environment variables for child.
// SAFETY: Between `fork()` and `execve()`, all syscalls are async-signal-safe:
// no memory allocation, no panic unwinding (always abort).
match unsafe { fork().unwrap() } {
ForkResult::Child => {
// Ideally, we want `std::panic::always_abort()`, which is unstable yet.
// WAIT: https://github.com/rust-lang/rust/issues/84438
scopeguard::defer!(abort());
// Don't leave an orphan process if something does wrong.
alarm::set(TIMEOUT_SEC);
// Ignore all errors here to stay safe, and lazy.
let _ = dup2(2, 1); // stdout <- stderr
let _ = dup2(memfd.as_raw_fd(), 0); // stdin <- config memfd
let _ = dup2(listener.as_raw_fd(), 3); // listener fd
let args = [
c"blahd".as_ptr(),
c"serve".as_ptr(),
c"-c".as_ptr(),
c"/proc/self/fd/0".as_ptr(),
null(),
];
let mut buf = [0u8; 64];
let _ = write!(&mut buf[..], "LISTEN_PID={}\0", getpid().as_raw());
let pos = env_ptrs.len() - 2;
env_ptrs[pos] = buf.as_ptr().cast();
// NB. Do raw libc call, not the wrapper fn that does allocation inside.
// SAFETY: Valid NULL-terminated array of NULL-terminated strings.
unsafe {
execve(server_exe_c.as_ptr(), args.as_ptr(), env_ptrs.as_ptr());
// If exec fail, the drop guard will abort the process anyway. Do nothing.
}
}
ForkResult::Parent { child } => {
let guard = scopeguard::guard((), |()| {
let _ = kill(child, Signal::SIGTERM);
});
if !unix_socket {
let resp = rt.block_on(async {
let url = format!("http://127.0.0.1:{local_port}/_blah/room?filter=public");
let fut = async {
reqwest::get(url)
.await
.unwrap()
.error_for_status()
.unwrap()
.text()
.await
.unwrap()
};
tokio::time::timeout(Duration::from_secs(TIMEOUT_SEC.into()), fut)
let (st, resp) = rt.block_on(async {
let fut = async move {
match &listener {
Either::Left(_) => {
let sock = tokio::net::UnixStream::connect(&socket_path).await.unwrap();
send_get_request(sock, uri).await
}
Either::Right(_) => {
let sock = tokio::net::TcpStream::connect(("127.0.0.1", local_port))
.await
.unwrap()
});
assert_eq!(resp, r#"{"rooms":[]}"#);
// Trigger the killer.
drop(guard);
.unwrap();
send_get_request(sock, uri).await
}
}
};
tokio::time::timeout(WAIT_TIMEOUT, fut)
.await
.unwrap()
.unwrap()
});
assert_eq!(st, StatusCode::OK);
assert_eq!(resp, r#"{"rooms":[]}"#);
let st = waitpid(child, None).unwrap();
let expect_exit_code = if unix_socket {
// Fail with unsupported error.
1
} else {
// Graceful shutdown.
0
};
assert!(
matches!(st, WaitStatus::Exited(_, code) if code == expect_exit_code),
"unexpected exit status {st:?}",
);
}
}
rustix::process::kill_process(Pid::from_child(&child), Signal::TERM).unwrap();
let st = child.wait().unwrap();
assert!(st.success(), "unexpected exit status: {st:?}");
Ok(())
}
// Ref: <https://github.com/seanmonstar/reqwest/issues/39#issuecomment-778716774>
async fn send_get_request(
stream: impl AsyncRead + AsyncWrite + Unpin + Send + 'static,
uri: String,
) -> anyhow::Result<(StatusCode, String)> {
let stream = hyper_util::rt::TokioIo::new(stream);
let (mut request_sender, connection) = hyper::client::conn::http1::Builder::new()
.handshake(stream)
.await?;
tokio::task::spawn(connection);
let request = hyper::Request::builder()
.method("GET")
.uri(uri)
.header("Host", "example.com")
.body(String::new())?;
let response = request_sender.send_request(request).await?;
let status = response.status();
let body = response.into_body().collect().await?.to_bytes();
let body = String::from_utf8(body.to_vec())?;
Ok((status, body))
}