diff --git a/Cargo.lock b/Cargo.lock index b22c22a..8b31446 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2138,6 +2138,15 @@ version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" +[[package]] +name = "signal-hook-registry" +version = "1.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a9e9e0b4211b72e7b8b6e85c807d36c212bdb33ea8587f7569562a84df5465b1" +dependencies = [ + "libc", +] + [[package]] name = "signature" version = "2.2.0" @@ -2381,6 +2390,7 @@ dependencies = [ "libc", "mio", "pin-project-lite", + "signal-hook-registry", "socket2", "tokio-macros", "windows-sys 0.52.0", diff --git a/blahd/Cargo.toml b/blahd/Cargo.toml index 19a6b30..5d62147 100644 --- a/blahd/Cargo.toml +++ b/blahd/Cargo.toml @@ -30,7 +30,7 @@ serde_jcs = "0.1" serde_json = "1" serde_urlencoded = "0.7.1" sha2 = "0.10" -tokio = { version = "1", features = ["macros", "rt-multi-thread", "sync", "time"] } +tokio = { version = "1", features = ["macros", "rt-multi-thread", "signal", "sync", "time"] } tokio-stream = { version = "0.1", features = ["sync"] } toml = "0.8" tower-http = { version = "0.5", features = ["cors", "limit"] } diff --git a/blahd/src/bin/blahd.rs b/blahd/src/bin/blahd.rs index fe4d06c..c84dde2 100644 --- a/blahd/src/bin/blahd.rs +++ b/blahd/src/bin/blahd.rs @@ -1,3 +1,4 @@ +use std::future::IntoFuture; use std::os::fd::{FromRawFd, OwnedFd}; use std::path::PathBuf; use std::sync::Arc; @@ -5,6 +6,7 @@ use std::sync::Arc; use anyhow::{anyhow, bail, Context, Result}; use blahd::config::{Config, ListenConfig}; use blahd::{AppState, Database}; +use tokio::signal::unix::{signal, SignalKind}; /// Blah Chat Server #[derive(Debug, clap::Parser)] @@ -94,10 +96,16 @@ async fn main_serve(db: Database, config: Config) -> Result<()> { tracing::info!("listening on {listener_display}"); let router = blahd::router(Arc::new(st)); - let _ = sd_notify::notify(true, &[sd_notify::NotifyState::Ready]); - axum::serve(listener, router) - .await - .context("failed to serve")?; + let mut sigterm = signal(SignalKind::terminate()).context("failed to listen on SIGTERM")?; + let service = axum::serve(listener, router) + .with_graceful_shutdown(async move { + sigterm.recv().await; + tracing::info!("received SIGTERM, shutting down gracefully"); + }) + .into_future(); + + let _ = sd_notify::notify(true, &[sd_notify::NotifyState::Ready]); + service.await.context("failed to serve")?; Ok(()) } diff --git a/blahd/tests/socket_activate.rs b/blahd/tests/socket_activate.rs index fab971b..2307d10 100644 --- a/blahd/tests/socket_activate.rs +++ b/blahd/tests/socket_activate.rs @@ -104,7 +104,7 @@ fn socket_activate(#[case] unix_socket: bool) { } ForkResult::Parent { child } => { let guard = scopeguard::guard((), |()| { - let _ = kill(child, Signal::SIGKILL); + let _ = kill(child, Signal::SIGTERM); }); if !unix_socket { @@ -130,15 +130,17 @@ fn socket_activate(#[case] unix_socket: bool) { } let st = waitpid(child, None).unwrap(); - if unix_socket { + let expect_exit_code = if unix_socket { // Fail with unsupported error. - assert!(matches!(st, WaitStatus::Exited(_, 1))); + 1 } else { - assert!( - matches!(st, WaitStatus::Signaled(_, Signal::SIGKILL, _)), - "unexpected exit status {st:?}", - ); - } + // Graceful shutdown. + 0 + }; + assert!( + matches!(st, WaitStatus::Exited(_, code) if code == expect_exit_code), + "unexpected exit status {st:?}", + ); } } }