From 98b2df2fdf32225ca50b874c67bf2069bc40d648 Mon Sep 17 00:00:00 2001 From: oxalica Date: Wed, 11 Sep 2024 13:14:06 -0400 Subject: [PATCH] feat(blahd): impl socket activation integration This also allows arbitrary listening fd include UNIX domain sockets, via environment variables as sd_listen_fds(3). --- Cargo.lock | 34 ++++++++- blahd/Cargo.toml | 3 + blahd/config.example.toml | 11 ++- blahd/src/bin/blahd.rs | 34 +++++++-- blahd/src/config.rs | 3 +- blahd/tests/socket_activate.rs | 122 +++++++++++++++++++++++++++++++++ contrib/blahd.example.socket | 9 +++ flake.nix | 3 + 8 files changed, 206 insertions(+), 13 deletions(-) create mode 100644 blahd/tests/socket_activate.rs create mode 100644 contrib/blahd.example.socket diff --git a/Cargo.lock b/Cargo.lock index b8ab707..3138013 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -306,13 +306,16 @@ dependencies = [ "futures-util", "hex", "humantime", + "nix", "parking_lot", "rand", "reqwest", "rstest", "rusqlite", + "scopeguard", "sd-notify", "serde", + "serde-constant", "serde-inline-default", "serde_json", "serde_urlencoded", @@ -367,6 +370,12 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +[[package]] +name = "cfg_aliases" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" + [[package]] name = "chrono" version = "0.4.38" @@ -1159,6 +1168,18 @@ dependencies = [ "tempfile", ] +[[package]] +name = "nix" +version = "0.29.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "71e2746dc3a24dd78b3cfcb7be93368c6de9963d30f43a6a73998a9cf4b17b46" +dependencies = [ + "bitflags", + "cfg-if", + "cfg_aliases", + "libc", +] + [[package]] name = "nu-ansi-term" version = "0.46.0" @@ -1558,9 +1579,9 @@ dependencies = [ [[package]] name = "rustix" -version = "0.38.34" +version = 
"0.38.36" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "70dc5ec042f7a43c4a73241207cecc9873a06d45debb38b329f8541d85c2730f" +checksum = "3f55e80d50763938498dd5ebb18647174e0c76dc38c5505294bb224624f30f36" dependencies = [ "bitflags", "errno", @@ -1686,6 +1707,15 @@ dependencies = [ "serde_derive", ] +[[package]] +name = "serde-constant" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "301d0de95fe56444b1ea1ee62363452b756fe715fb4eab505be16b906599b47e" +dependencies = [ + "serde", +] + [[package]] name = "serde-inline-default" version = "0.2.0" diff --git a/blahd/Cargo.toml b/blahd/Cargo.toml index aef1968..6f03720 100644 --- a/blahd/Cargo.toml +++ b/blahd/Cargo.toml @@ -16,6 +16,7 @@ parking_lot = "0.12" # Maybe no better performance, just that we hate poisoning. rusqlite = "0.32" sd-notify = "0.4" serde = { version = "1", features = ["derive"] } +serde-constant = "0.1" serde-inline-default = "0.2.0" serde_json = "1" serde_urlencoded = "0.7.1" @@ -30,9 +31,11 @@ url = { version = "2.5.2", features = ["serde"] } blah-types = { path = "../blah-types", features = ["rusqlite"] } [dev-dependencies] +nix = { version = "0.29.0", features = ["fs", "process", "signal"] } rand = "0.8.5" reqwest = { version = "0.12.7", features = ["json"] } rstest = { version = "0.22.0", default-features = false } +scopeguard = "1.2.0" [lints] workspace = true diff --git a/blahd/config.example.toml b/blahd/config.example.toml index 78ef411..2242e67 100644 --- a/blahd/config.example.toml +++ b/blahd/config.example.toml @@ -1,6 +1,6 @@ # The example configuration file, required options are documented as # `(Required)`, other options are optional and the example value given here is -# the default value. +# the default value, or example values for commented lines. [database] # If enabled, a in-memory non-persistent database is used instead. 
Options @@ -16,11 +16,16 @@ path = "/var/lib/blahd/db.sqlite" # Note that parent directory will never be created and must already exist. create = true +# Listener socket configuration. (Required) +# Exactly one option under this section must be set. [listen] -# (Required) -# The local address to listen on. +# Listen on an address. address = "localhost:8080" +# Use the systemd socket activation mechanism to get the listener fd from envvars. +# See also sd_listen_fds(3) and systemd.socket(5). +#systemd = true + [server] # (Required) # The global absolute URL prefix where this service is hosted. diff --git a/blahd/src/bin/blahd.rs b/blahd/src/bin/blahd.rs index 46a9bb0..9a5e1a8 100644 --- a/blahd/src/bin/blahd.rs +++ b/blahd/src/bin/blahd.rs @@ -1,7 +1,9 @@ +use std::net::TcpListener; +use std::os::fd::FromRawFd; use std::path::PathBuf; use std::sync::Arc; -use anyhow::{Context, Result}; +use anyhow::{anyhow, Context, Result}; use blahd::config::{Config, ListenConfig}; use blahd::{AppState, Database}; @@ -53,14 +55,32 @@ fn main() -> Result<()> { } async fn main_serve(db: Database, config: Config) -> Result<()> { - let listener = match &config.listen { - ListenConfig::Address(addr) => tokio::net::TcpListener::bind(addr) - .await - .context("failed to listen on socket")?, - }; let st = AppState::new(db, config.server); - tracing::info!("listening on {:?}", config.listen); + let listener = match &config.listen { + ListenConfig::Address(addr) => { + tracing::info!("listening on {addr:?}"); + tokio::net::TcpListener::bind(addr) + .await + .context("failed to listen on socket")? + } + ListenConfig::Systemd(_) => { + tracing::info!("listening on fd from environment"); + let [fd] = sd_notify::listen_fds() + .context("failed to get fds from sd_listen_fds(3)")? + .collect::<Vec<_>>() + .try_into() + .map_err(|_| anyhow!("more than one fds available from sd_listen_fds(3)"))?; + // SAFETY: `fd` is valid by sd_listen_fds(3) protocol.
+ let listener = unsafe { TcpListener::from_raw_fd(fd) }; + listener + .set_nonblocking(true) + .context("failed to set socket non-blocking")?; + tokio::net::TcpListener::from_std(listener) + .context("failed to register async socket")? + } + }; + let router = blahd::router(Arc::new(st)); let _ = sd_notify::notify(true, &[sd_notify::NotifyState::Ready]); diff --git a/blahd/src/config.rs b/blahd/src/config.rs index 2e4786f..0c6aef4 100644 --- a/blahd/src/config.rs +++ b/blahd/src/config.rs @@ -4,6 +4,7 @@ use std::time::Duration; use anyhow::{ensure, Result}; use serde::{Deserialize, Deserializer, Serialize}; +use serde_constant::ConstBool; use serde_inline_default::serde_inline_default; use url::Url; @@ -31,7 +32,7 @@ pub struct DatabaseConfig { #[serde(rename_all = "snake_case")] pub enum ListenConfig { Address(String), - // TODO: Unix socket. + Systemd(ConstBool), } #[serde_inline_default] diff --git a/blahd/tests/socket_activate.rs b/blahd/tests/socket_activate.rs new file mode 100644 index 0000000..ce843bb --- /dev/null +++ b/blahd/tests/socket_activate.rs @@ -0,0 +1,122 @@ +use std::ffi::CString; +use std::fs::File; +use std::io::{Seek, Write}; +use std::net::TcpListener; +use std::os::fd::{AsFd, AsRawFd}; +use std::process::abort; +use std::ptr::null; +use std::time::Duration; + +use nix::fcntl::{fcntl, FcntlArg, FdFlag}; +use nix::libc::execve; +use nix::sys::memfd::{memfd_create, MemFdCreateFlag}; +use nix::sys::signal::{kill, Signal}; +use nix::sys::wait::{waitpid, WaitStatus}; +use nix::unistd::{alarm, dup2, fork, getpid, ForkResult}; +use tokio::io::stderr; + +const TIMEOUT_SEC: u32 = 1; + +const SERVER_EXE_PATH: &str = env!("CARGO_BIN_EXE_blahd"); + +const CONFIG: &str = r#" +[database] +in_memory = true +[listen] +systemd = true +[server] +base_url = "http://example.com" +"#; + +#[test] +fn socket_activate() { + let listener = TcpListener::bind("127.0.0.1:0").unwrap(); + let local_port = listener.local_addr().unwrap().port(); + let rt = 
tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .unwrap(); + + // Remove `FD_CLOEXEC` since we want to send it to the child. + fn remove_cloexec(fd: &impl AsFd) { + let mut flags = + FdFlag::from_bits_retain(fcntl(fd.as_fd().as_raw_fd(), FcntlArg::F_GETFD).unwrap()); + flags -= FdFlag::FD_CLOEXEC; + fcntl(fd.as_fd().as_raw_fd(), FcntlArg::F_SETFD(flags)).unwrap(); + } + remove_cloexec(&listener); + remove_cloexec(&stderr()); + + let server_exe_c = CString::new(SERVER_EXE_PATH).unwrap(); + + // Intentionally no FD_CLOEXEC. + let mut memfd = File::from(memfd_create(c"test-config", MemFdCreateFlag::empty()).unwrap()); + memfd.write_all(CONFIG.as_bytes()).unwrap(); + memfd.rewind().unwrap(); + + // Unfortunately we have to deal with raw `fork(2)` here, because no library supports passing + // child PID in environment variables for child. + // SAFETY: Between `fork()` and `execve()`, all syscalls are async-signal-safe: + // no memory allocation, no panic unwinding (always abort). + match unsafe { fork().unwrap() } { + ForkResult::Child => { + // Ideally, we want `std::panic::always_abort()`, which is still unstable. + // WAIT: https://github.com/rust-lang/rust/issues/84438 + scopeguard::defer!(abort()); + + // Don't leave an orphan process if something goes wrong. + alarm::set(TIMEOUT_SEC); + + // Ignore all errors here to stay safe, and lazy. + let _ = dup2(2, 1); // stdout <- stderr + let _ = dup2(memfd.as_raw_fd(), 0); // stdin <- config memfd + let _ = dup2(listener.as_raw_fd(), 3); // listener fd + let args = [ + c"blahd".as_ptr(), + c"serve".as_ptr(), + c"-c".as_ptr(), + c"/proc/self/fd/0".as_ptr(), + null(), + ]; + let mut buf = [0u8; 64]; + let _ = write!(&mut buf[..], "LISTEN_PID={}\0", getpid().as_raw()); + let envs = [c"LISTEN_FDS=1".as_ptr(), buf.as_ptr().cast(), null()]; + // NB. Do raw libc call, not the wrapper fn that does allocation inside. + // SAFETY: Valid NULL-terminated array of NULL-terminated strings.
+ unsafe { + execve(server_exe_c.as_ptr(), args.as_ptr(), envs.as_ptr()); + // If exec fails, the drop guard will abort the process anyway. Do nothing. + } + } + ForkResult::Parent { child } => { + { + scopeguard::defer! { + let _ = kill(child, Signal::SIGKILL); + } + + let resp = rt.block_on(async { + let fut = async { + reqwest::get(format!("http://127.0.0.1:{local_port}/room?filter=public")) + .await + .unwrap() + .error_for_status() + .unwrap() + .text() + .await + .unwrap() + }; + tokio::time::timeout(Duration::from_secs(TIMEOUT_SEC.into()), fut) + .await + .unwrap() + }); + assert_eq!(resp, r#"{"rooms":[]}"#); + } + + let st = waitpid(child, None).unwrap(); + assert!( + matches!(st, WaitStatus::Signaled(_, Signal::SIGKILL, _)), + "unexpected exit status {st:?}", + ); + } + } +} diff --git a/contrib/blahd.example.socket b/contrib/blahd.example.socket new file mode 100644 index 0000000..ea99170 --- /dev/null +++ b/contrib/blahd.example.socket @@ -0,0 +1,9 @@ +[Unit] +Description=Blah Chat Server Socket + +[Socket] +ListenStream=[::]:8080 +BindIPv6Only=both + +[Install] +WantedBy=sockets.target diff --git a/flake.nix b/flake.nix index 678f28e..2f9dcf3 100644 --- a/flake.nix +++ b/flake.nix @@ -65,6 +65,9 @@ rec { "--package=blahctl" ]; + # Intentionally omit the socket unit. It is trivial but + # highly configuration-specific. Users who want to use it almost + # always need customization. postInstall = '' mkdir -p $out/etc/systemd/system substitute ./contrib/blahd.example.service $out/etc/systemd/system/blahd.service \