feat(blahd): impl socket activation integration

This also allows arbitrary listening fd include UNIX domain sockets, via
environment variables as sd_listen_fds(3).
This commit is contained in:
oxalica 2024-09-11 13:14:06 -04:00
parent 87c8acd5b2
commit 98b2df2fdf
8 changed files with 206 additions and 13 deletions

34
Cargo.lock generated
View file

@ -306,13 +306,16 @@ dependencies = [
"futures-util",
"hex",
"humantime",
"nix",
"parking_lot",
"rand",
"reqwest",
"rstest",
"rusqlite",
"scopeguard",
"sd-notify",
"serde",
"serde-constant",
"serde-inline-default",
"serde_json",
"serde_urlencoded",
@ -367,6 +370,12 @@ version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
[[package]]
name = "cfg_aliases"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724"
[[package]]
name = "chrono"
version = "0.4.38"
@ -1159,6 +1168,18 @@ dependencies = [
"tempfile",
]
[[package]]
name = "nix"
version = "0.29.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "71e2746dc3a24dd78b3cfcb7be93368c6de9963d30f43a6a73998a9cf4b17b46"
dependencies = [
"bitflags",
"cfg-if",
"cfg_aliases",
"libc",
]
[[package]]
name = "nu-ansi-term"
version = "0.46.0"
@ -1558,9 +1579,9 @@ dependencies = [
[[package]]
name = "rustix"
version = "0.38.34"
version = "0.38.36"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "70dc5ec042f7a43c4a73241207cecc9873a06d45debb38b329f8541d85c2730f"
checksum = "3f55e80d50763938498dd5ebb18647174e0c76dc38c5505294bb224624f30f36"
dependencies = [
"bitflags",
"errno",
@ -1686,6 +1707,15 @@ dependencies = [
"serde_derive",
]
[[package]]
name = "serde-constant"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "301d0de95fe56444b1ea1ee62363452b756fe715fb4eab505be16b906599b47e"
dependencies = [
"serde",
]
[[package]]
name = "serde-inline-default"
version = "0.2.0"

View file

@ -16,6 +16,7 @@ parking_lot = "0.12" # Maybe no better performance, just that we hate poisoning.
rusqlite = "0.32"
sd-notify = "0.4"
serde = { version = "1", features = ["derive"] }
serde-constant = "0.1"
serde-inline-default = "0.2.0"
serde_json = "1"
serde_urlencoded = "0.7.1"
@ -30,9 +31,11 @@ url = { version = "2.5.2", features = ["serde"] }
blah-types = { path = "../blah-types", features = ["rusqlite"] }
[dev-dependencies]
nix = { version = "0.29.0", features = ["fs", "process", "signal"] }
rand = "0.8.5"
reqwest = { version = "0.12.7", features = ["json"] }
rstest = { version = "0.22.0", default-features = false }
scopeguard = "1.2.0"
[lints]
workspace = true

View file

@ -1,6 +1,6 @@
# The example configuration file, required options are documented as
# `(Required)`, other options are optional and the example value given here is
# the default value.
# the default value, or example values for commented lines.
[database]
# If enabled, a in-memory non-persistent database is used instead. Options
@ -16,11 +16,16 @@ path = "/var/lib/blahd/db.sqlite"
# Note that parent directory will never be created and must already exist.
create = true
# Listener socket configuration. (Required)
# There must be exact one option under this section being set.
[listen]
# (Required)
# The local address to listen on.
# Listen on an address.
address = "localhost:8080"
# Use systemd socket activation mechanism to get listener fd from envvars.
# See also sd_listen_fds(3) and systemd.socket(5).
#systemd = true
[server]
# (Required)
# The global absolute URL prefix where this service is hosted.

View file

@ -1,7 +1,9 @@
use std::net::TcpListener;
use std::os::fd::FromRawFd;
use std::path::PathBuf;
use std::sync::Arc;
use anyhow::{Context, Result};
use anyhow::{anyhow, Context, Result};
use blahd::config::{Config, ListenConfig};
use blahd::{AppState, Database};
@ -53,14 +55,32 @@ fn main() -> Result<()> {
}
async fn main_serve(db: Database, config: Config) -> Result<()> {
let listener = match &config.listen {
ListenConfig::Address(addr) => tokio::net::TcpListener::bind(addr)
.await
.context("failed to listen on socket")?,
};
let st = AppState::new(db, config.server);
tracing::info!("listening on {:?}", config.listen);
let listener = match &config.listen {
ListenConfig::Address(addr) => {
tracing::info!("listening on {addr:?}");
tokio::net::TcpListener::bind(addr)
.await
.context("failed to listen on socket")?
}
ListenConfig::Systemd(_) => {
tracing::info!("listening on fd from environment");
let [fd] = sd_notify::listen_fds()
.context("failed to get fds from sd_listen_fds(3)")?
.collect::<Vec<_>>()
.try_into()
.map_err(|_| anyhow!("more than one fds available from sd_listen_fds(3)"))?;
// SAFETY: `fd` is valid by sd_listen_fds(3) protocol.
let listener = unsafe { TcpListener::from_raw_fd(fd) };
listener
.set_nonblocking(true)
.context("failed to set socket non-blocking")?;
tokio::net::TcpListener::from_std(listener)
.context("failed to register async socket")?
}
};
let router = blahd::router(Arc::new(st));
let _ = sd_notify::notify(true, &[sd_notify::NotifyState::Ready]);

View file

@ -4,6 +4,7 @@ use std::time::Duration;
use anyhow::{ensure, Result};
use serde::{Deserialize, Deserializer, Serialize};
use serde_constant::ConstBool;
use serde_inline_default::serde_inline_default;
use url::Url;
@ -31,7 +32,7 @@ pub struct DatabaseConfig {
#[serde(rename_all = "snake_case")]
pub enum ListenConfig {
Address(String),
// TODO: Unix socket.
Systemd(ConstBool<true>),
}
#[serde_inline_default]

View file

@ -0,0 +1,122 @@
use std::ffi::CString;
use std::fs::File;
use std::io::{Seek, Write};
use std::net::TcpListener;
use std::os::fd::{AsFd, AsRawFd};
use std::process::abort;
use std::ptr::null;
use std::time::Duration;
use nix::fcntl::{fcntl, FcntlArg, FdFlag};
use nix::libc::execve;
use nix::sys::memfd::{memfd_create, MemFdCreateFlag};
use nix::sys::signal::{kill, Signal};
use nix::sys::wait::{waitpid, WaitStatus};
use nix::unistd::{alarm, dup2, fork, getpid, ForkResult};
use tokio::io::stderr;
const TIMEOUT_SEC: u32 = 1;
const SERVER_EXE_PATH: &str = env!("CARGO_BIN_EXE_blahd");
const CONFIG: &str = r#"
[database]
in_memory = true
[listen]
systemd = true
[server]
base_url = "http://example.com"
"#;
#[test]
fn socket_activate() {
let listener = TcpListener::bind("127.0.0.1:0").unwrap();
let local_port = listener.local_addr().unwrap().port();
let rt = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.unwrap();
// Remove `FD_CLOEXEC` since we want to send it to the child.
fn remove_cloexec(fd: &impl AsFd) {
let mut flags =
FdFlag::from_bits_retain(fcntl(fd.as_fd().as_raw_fd(), FcntlArg::F_GETFD).unwrap());
flags -= FdFlag::FD_CLOEXEC;
fcntl(fd.as_fd().as_raw_fd(), FcntlArg::F_SETFD(flags)).unwrap();
}
remove_cloexec(&listener);
remove_cloexec(&stderr());
let server_exe_c = CString::new(SERVER_EXE_PATH).unwrap();
// Intentionally no FD_CLOEXEC.
let mut memfd = File::from(memfd_create(c"test-config", MemFdCreateFlag::empty()).unwrap());
memfd.write_all(CONFIG.as_bytes()).unwrap();
memfd.rewind().unwrap();
// Unfortunately we have to deal with raw `fork(2)` here, because no library supports passing
// child PID in environment variables for child.
// SAFETY: Between `fork()` and `execve()`, all syscalls are async-signal-safe:
// no memory allocation, no panic unwinding (always abort).
match unsafe { fork().unwrap() } {
ForkResult::Child => {
// Ideally, we want `std::panic::always_abort()`, which is unstable yet.
// WAIT: https://github.com/rust-lang/rust/issues/84438
scopeguard::defer!(abort());
// Don't leave an orphan process if something does wrong.
alarm::set(TIMEOUT_SEC);
// Ignore all errors here to stay safe, and lazy.
let _ = dup2(2, 1); // stdout <- stderr
let _ = dup2(memfd.as_raw_fd(), 0); // stdin <- config memfd
let _ = dup2(listener.as_raw_fd(), 3); // listener fd
let args = [
c"blahd".as_ptr(),
c"serve".as_ptr(),
c"-c".as_ptr(),
c"/proc/self/fd/0".as_ptr(),
null(),
];
let mut buf = [0u8; 64];
let _ = write!(&mut buf[..], "LISTEN_PID={}\0", getpid().as_raw());
let envs = [c"LISTEN_FDS=1".as_ptr(), buf.as_ptr().cast(), null()];
// NB. Do raw libc call, not the wrapper fn that does allocation inside.
// SAFETY: Valid NULL-terminated array of NULL-terminated strings.
unsafe {
execve(server_exe_c.as_ptr(), args.as_ptr(), envs.as_ptr());
// If exec fail, the drop guard will abort the process anyway. Do nothing.
}
}
ForkResult::Parent { child } => {
{
scopeguard::defer! {
let _ = kill(child, Signal::SIGKILL);
}
let resp = rt.block_on(async {
let fut = async {
reqwest::get(format!("http://127.0.0.1:{local_port}/room?filter=public"))
.await
.unwrap()
.error_for_status()
.unwrap()
.text()
.await
.unwrap()
};
tokio::time::timeout(Duration::from_secs(TIMEOUT_SEC.into()), fut)
.await
.unwrap()
});
assert_eq!(resp, r#"{"rooms":[]}"#);
}
let st = waitpid(child, None).unwrap();
assert!(
matches!(st, WaitStatus::Signaled(_, Signal::SIGKILL, _)),
"unexpected exit status {st:?}",
);
}
}
}

View file

@ -0,0 +1,9 @@
[Unit]
Description=Blah Chat Server Socket
[Socket]
ListenStream=[::]:8080
BindIPv6Only=both
[Install]
WantedBy=sockets.target

View file

@ -65,6 +65,9 @@ rec {
"--package=blahctl"
];
# Intentionally omit the socket unit. It is trivial but
# highly configuration-specific. Users who want to use it almost
# always need customization.
postInstall = ''
mkdir -p $out/etc/systemd/system
substitute ./contrib/blahd.example.service $out/etc/systemd/system/blahd.service \