mirror of
https://github.com/Blah-IM/blahrs.git
synced 2025-05-01 08:41:09 +00:00
fix(blahd): reject UNIX domain socket for now
It's too complex to bother with for the current `axum` API. Let's wait for axum 0.8 release. Ref: https://github.com/tokio-rs/axum/pull/2479
This commit is contained in:
parent
ec7f428519
commit
b955d32099
5 changed files with 68 additions and 29 deletions
2
Cargo.lock
generated
2
Cargo.lock
generated
|
@ -324,6 +324,7 @@ dependencies = [
|
||||||
"reqwest",
|
"reqwest",
|
||||||
"rstest",
|
"rstest",
|
||||||
"rusqlite",
|
"rusqlite",
|
||||||
|
"rustix",
|
||||||
"scopeguard",
|
"scopeguard",
|
||||||
"sd-notify",
|
"sd-notify",
|
||||||
"serde",
|
"serde",
|
||||||
|
@ -333,6 +334,7 @@ dependencies = [
|
||||||
"serde_json",
|
"serde_json",
|
||||||
"serde_urlencoded",
|
"serde_urlencoded",
|
||||||
"sha2",
|
"sha2",
|
||||||
|
"tempfile",
|
||||||
"tokio",
|
"tokio",
|
||||||
"tokio-stream",
|
"tokio-stream",
|
||||||
"toml",
|
"toml",
|
||||||
|
|
|
@ -21,6 +21,7 @@ parking_lot = "0.12" # Maybe no better performance, just that we hate poisoning.
|
||||||
rand = "0.8"
|
rand = "0.8"
|
||||||
reqwest = "0.12"
|
reqwest = "0.12"
|
||||||
rusqlite = "0.32"
|
rusqlite = "0.32"
|
||||||
|
rustix = { version = "0.38", features = ["net"] }
|
||||||
sd-notify = "0.4"
|
sd-notify = "0.4"
|
||||||
serde = { version = "1", features = ["derive"] }
|
serde = { version = "1", features = ["derive"] }
|
||||||
serde-constant = "0.1"
|
serde-constant = "0.1"
|
||||||
|
@ -45,6 +46,7 @@ nix = { version = "0.29.0", features = ["fs", "process", "signal"] }
|
||||||
reqwest = { version = "0.12.7", features = ["json"] }
|
reqwest = { version = "0.12.7", features = ["json"] }
|
||||||
rstest = { version = "0.22.0", default-features = false }
|
rstest = { version = "0.22.0", default-features = false }
|
||||||
scopeguard = "1.2.0"
|
scopeguard = "1.2.0"
|
||||||
|
tempfile = "3.12.0"
|
||||||
|
|
||||||
[lints]
|
[lints]
|
||||||
workspace = true
|
workspace = true
|
||||||
|
|
|
@ -24,6 +24,7 @@ address = "localhost:8080"
|
||||||
|
|
||||||
# Use systemd socket activation mechanism to get listener fd from envvars.
|
# Use systemd socket activation mechanism to get listener fd from envvars.
|
||||||
# See also sd_listen_fds(3) and systemd.socket(5).
|
# See also sd_listen_fds(3) and systemd.socket(5).
|
||||||
|
# NB. Currently only TCP sockets are supported. UNIX domain socket is TODO.
|
||||||
#systemd = true
|
#systemd = true
|
||||||
|
|
||||||
[server]
|
[server]
|
||||||
|
|
|
@ -1,9 +1,8 @@
|
||||||
use std::net::TcpListener;
|
use std::os::fd::{FromRawFd, OwnedFd};
|
||||||
use std::os::fd::FromRawFd;
|
|
||||||
use std::path::PathBuf;
|
use std::path::PathBuf;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
use anyhow::{anyhow, Context, Result};
|
use anyhow::{anyhow, bail, Context, Result};
|
||||||
use blahd::config::{Config, ListenConfig};
|
use blahd::config::{Config, ListenConfig};
|
||||||
use blahd::{AppState, Database};
|
use blahd::{AppState, Database};
|
||||||
|
|
||||||
|
@ -56,30 +55,44 @@ fn main() -> Result<()> {
|
||||||
async fn main_serve(db: Database, config: Config) -> Result<()> {
|
async fn main_serve(db: Database, config: Config) -> Result<()> {
|
||||||
let st = AppState::new(db, config.server);
|
let st = AppState::new(db, config.server);
|
||||||
|
|
||||||
let listener = match &config.listen {
|
let (listener_display, listener) = match &config.listen {
|
||||||
ListenConfig::Address(addr) => {
|
ListenConfig::Address(addr) => (
|
||||||
tracing::info!("listening on {addr:?}");
|
format!("address {addr:?}"),
|
||||||
tokio::net::TcpListener::bind(addr)
|
tokio::net::TcpListener::bind(addr)
|
||||||
.await
|
.await
|
||||||
.context("failed to listen on socket")?
|
.context("failed to listen on socket")?,
|
||||||
}
|
),
|
||||||
ListenConfig::Systemd(_) => {
|
ListenConfig::Systemd(_) => {
|
||||||
tracing::info!("listening on fd from environment");
|
use rustix::net::{getsockname, SocketAddrAny};
|
||||||
|
|
||||||
let [fd] = sd_notify::listen_fds()
|
let [fd] = sd_notify::listen_fds()
|
||||||
.context("failed to get fds from sd_listen_fds(3)")?
|
.context("failed to get fds from sd_listen_fds(3)")?
|
||||||
.collect::<Vec<_>>()
|
.collect::<Vec<_>>()
|
||||||
.try_into()
|
.try_into()
|
||||||
.map_err(|_| anyhow!("more than one fds available from sd_listen_fds(3)"))?;
|
.map_err(|_| anyhow!("expecting exactly one fd from LISTEN_FDS"))?;
|
||||||
// SAFETY: `fd` is valid by sd_listen_fds(3) protocol.
|
// SAFETY: `fd` is valid by sd_listen_fds(3) protocol.
|
||||||
let listener = unsafe { TcpListener::from_raw_fd(fd) };
|
let listener = unsafe { OwnedFd::from_raw_fd(fd) };
|
||||||
|
|
||||||
|
let addr = getsockname(&listener).context("failed to getsockname")?;
|
||||||
|
match addr {
|
||||||
|
SocketAddrAny::V4(_) | SocketAddrAny::V6(_) => {
|
||||||
|
let listener = std::net::TcpListener::from(listener);
|
||||||
listener
|
listener
|
||||||
.set_nonblocking(true)
|
.set_nonblocking(true)
|
||||||
.context("failed to set socket non-blocking")?;
|
.context("failed to set socket non-blocking")?;
|
||||||
tokio::net::TcpListener::from_std(listener)
|
let listener = tokio::net::TcpListener::from_std(listener)
|
||||||
.context("failed to register async socket")?
|
.context("failed to register async socket")?;
|
||||||
|
(format!("tcp socket {addr:?} from LISTEN_FDS"), listener)
|
||||||
|
}
|
||||||
|
// Unix socket support for axum is currently overly complex.
|
||||||
|
// WAIT: https://github.com/tokio-rs/axum/pull/2479
|
||||||
|
_ => bail!("unsupported socket type from LISTEN_FDS: {addr:?}"),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
tracing::info!("listening on {listener_display}");
|
||||||
|
|
||||||
let router = blahd::router(Arc::new(st));
|
let router = blahd::router(Arc::new(st));
|
||||||
let _ = sd_notify::notify(true, &[sd_notify::NotifyState::Ready]);
|
let _ = sd_notify::notify(true, &[sd_notify::NotifyState::Ready]);
|
||||||
|
|
||||||
|
|
|
@ -2,7 +2,8 @@ use std::ffi::CString;
|
||||||
use std::fs::File;
|
use std::fs::File;
|
||||||
use std::io::{Seek, Write};
|
use std::io::{Seek, Write};
|
||||||
use std::net::TcpListener;
|
use std::net::TcpListener;
|
||||||
use std::os::fd::{AsFd, AsRawFd};
|
use std::os::fd::{AsFd, AsRawFd, OwnedFd};
|
||||||
|
use std::os::unix::net::UnixListener;
|
||||||
use std::process::abort;
|
use std::process::abort;
|
||||||
use std::ptr::null;
|
use std::ptr::null;
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
|
@ -13,6 +14,7 @@ use nix::sys::memfd::{memfd_create, MemFdCreateFlag};
|
||||||
use nix::sys::signal::{kill, Signal};
|
use nix::sys::signal::{kill, Signal};
|
||||||
use nix::sys::wait::{waitpid, WaitStatus};
|
use nix::sys::wait::{waitpid, WaitStatus};
|
||||||
use nix::unistd::{alarm, dup2, fork, getpid, ForkResult};
|
use nix::unistd::{alarm, dup2, fork, getpid, ForkResult};
|
||||||
|
use rstest::rstest;
|
||||||
use tokio::io::stderr;
|
use tokio::io::stderr;
|
||||||
|
|
||||||
const TIMEOUT_SEC: u32 = 1;
|
const TIMEOUT_SEC: u32 = 1;
|
||||||
|
@ -28,10 +30,22 @@ systemd = true
|
||||||
base_url = "http://example.com"
|
base_url = "http://example.com"
|
||||||
"#;
|
"#;
|
||||||
|
|
||||||
#[test]
|
#[rstest]
|
||||||
fn socket_activate() {
|
#[case::tcp(false)]
|
||||||
|
#[case::unix(true)]
|
||||||
|
fn socket_activate(#[case] unix_socket: bool) {
|
||||||
|
let socket_dir;
|
||||||
|
let (local_port, listener) = if unix_socket {
|
||||||
|
socket_dir = tempfile::tempdir().unwrap();
|
||||||
|
let listener = UnixListener::bind(socket_dir.path().join("socket")).unwrap();
|
||||||
|
// Port is unused.
|
||||||
|
(0, OwnedFd::from(listener))
|
||||||
|
} else {
|
||||||
let listener = TcpListener::bind("127.0.0.1:0").unwrap();
|
let listener = TcpListener::bind("127.0.0.1:0").unwrap();
|
||||||
let local_port = listener.local_addr().unwrap().port();
|
let local_port = listener.local_addr().unwrap().port();
|
||||||
|
(local_port, OwnedFd::from(listener))
|
||||||
|
};
|
||||||
|
|
||||||
let rt = tokio::runtime::Builder::new_current_thread()
|
let rt = tokio::runtime::Builder::new_current_thread()
|
||||||
.enable_all()
|
.enable_all()
|
||||||
.build()
|
.build()
|
||||||
|
@ -89,11 +103,11 @@ fn socket_activate() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
ForkResult::Parent { child } => {
|
ForkResult::Parent { child } => {
|
||||||
{
|
let guard = scopeguard::guard((), |()| {
|
||||||
scopeguard::defer! {
|
|
||||||
let _ = kill(child, Signal::SIGKILL);
|
let _ = kill(child, Signal::SIGKILL);
|
||||||
}
|
});
|
||||||
|
|
||||||
|
if !unix_socket {
|
||||||
let resp = rt.block_on(async {
|
let resp = rt.block_on(async {
|
||||||
let url = format!("http://127.0.0.1:{local_port}/_blah/room?filter=public");
|
let url = format!("http://127.0.0.1:{local_port}/_blah/room?filter=public");
|
||||||
let fut = async {
|
let fut = async {
|
||||||
|
@ -111,13 +125,20 @@ fn socket_activate() {
|
||||||
.unwrap()
|
.unwrap()
|
||||||
});
|
});
|
||||||
assert_eq!(resp, r#"{"rooms":[]}"#);
|
assert_eq!(resp, r#"{"rooms":[]}"#);
|
||||||
|
// Trigger the killer.
|
||||||
|
drop(guard);
|
||||||
}
|
}
|
||||||
|
|
||||||
let st = waitpid(child, None).unwrap();
|
let st = waitpid(child, None).unwrap();
|
||||||
|
if unix_socket {
|
||||||
|
// Fail with unsupported error.
|
||||||
|
assert!(matches!(st, WaitStatus::Exited(_, 1)));
|
||||||
|
} else {
|
||||||
assert!(
|
assert!(
|
||||||
matches!(st, WaitStatus::Signaled(_, Signal::SIGKILL, _)),
|
matches!(st, WaitStatus::Signaled(_, Signal::SIGKILL, _)),
|
||||||
"unexpected exit status {st:?}",
|
"unexpected exit status {st:?}",
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Reference in a new issue