feat(blahd): impl optional prometheus metrics

This commit is contained in:
oxalica 2024-10-14 18:59:04 -04:00
parent 9baf47963d
commit fe587f057f
9 changed files with 284 additions and 43 deletions

View file

@ -4,8 +4,13 @@ version = "0.0.0"
edition = "2021"
[features]
default = []
default = ["prometheus"]
unsafe_use_mock_instant_for_testing = ["dep:mock_instant", "blah-types/unsafe_use_mock_instant_for_testing"]
prometheus = [
"dep:axum-metrics",
"dep:metrics",
"dep:metrics-exporter-prometheus",
]
[dependencies]
anyhow = "1"
@ -44,6 +49,10 @@ url = { version = "2", features = ["serde"] }
blah-types = { path = "../blah-types", features = ["rusqlite"] }
axum-metrics = { version = "0.1", optional = true }
metrics = { version = "0.24.0", optional = true }
metrics-exporter-prometheus = { version = "0.16", optional = true, default-features = false }
[build-dependencies]
url = "2"

View file

@ -94,3 +94,18 @@ difficulty = 16
# The challenge nonce rotation period in seconds.
nonce_rotate_secs = 60
# If this section is present, Prometheus metrics tracking is enabled.
# The server must be built with the "prometheus" feature enabled to support this.
#[metric.prometheus]
# The listen address for the metrics server, in the same format as `[listen]`.
# GET on route `/metrics` responds in Prometheus Exposition Format.
# See more: https://prometheus.io/docs/instrumenting/exposition_formats/#text-based-format
#listen.systemd = true
# Upkeep interval in seconds.
#upkeep_period_secs = 5
# The bucket width in seconds for histograms.
#bucket_duration_secs = 20

View file

@ -3,9 +3,10 @@ use std::os::fd::{FromRawFd, OwnedFd};
use std::path::PathBuf;
use std::sync::Arc;
use anyhow::{anyhow, bail, Context, Result};
use anyhow::{bail, Context, Result};
use blahd::config::{Config, ListenConfig};
use blahd::{AppState, Database};
use tokio::net::TcpListener;
use tokio::signal::unix::{signal, SignalKind};
/// Blah Chat Server
@ -57,47 +58,72 @@ fn main() -> Result<()> {
async fn main_serve(db: Database, config: Config) -> Result<()> {
let st = AppState::new(db, config.server);
let (listener_display, listener) = match &config.listen {
ListenConfig::Address(addr) => (
format!("address {addr:?}"),
tokio::net::TcpListener::bind(addr)
.await
.context("failed to listen on socket")?,
),
ListenConfig::Systemd(_) => {
use rustix::net::{getsockname, SocketAddrAny};
let mut listen_fds = sd_notify::listen_fds()
.context("failed to get fds from sd_listen_fds(3)")?
// SAFETY: `fd` is valid by sd_listen_fds(3) protocol.
.map(|fd| unsafe { Some(OwnedFd::from_raw_fd(fd)) })
.collect::<Vec<_>>();
let [fd] = sd_notify::listen_fds()
.context("failed to get fds from sd_listen_fds(3)")?
.collect::<Vec<_>>()
.try_into()
.map_err(|_| anyhow!("expecting exactly one fd from LISTEN_FDS"))?;
// SAFETY: `fd` is valid by sd_listen_fds(3) protocol.
let listener = unsafe { OwnedFd::from_raw_fd(fd) };
let addr = getsockname(&listener).context("failed to getsockname")?;
match addr {
SocketAddrAny::V4(_) | SocketAddrAny::V6(_) => {
let listener = std::net::TcpListener::from(listener);
listener
.set_nonblocking(true)
.context("failed to set socket non-blocking")?;
let listener = tokio::net::TcpListener::from_std(listener)
.context("failed to register async socket")?;
(format!("tcp socket {addr:?} from LISTEN_FDS"), listener)
let mut get_listener = move |name: &'static str, fd_idx: usize, listen: ListenConfig| {
// Take ownership of the fd in the slot this listener asked for, so that the
// main (index 0) and metrics (index 1) listeners each get their own socket.
// A slot that is out of range or already taken yields `None`.
let fd = listen_fds.get_mut(fd_idx).and_then(Option::take);
async move {
match listen {
ListenConfig::Address(addr) => {
tracing::info!("serve {name} on address {addr:?}");
TcpListener::bind(addr)
.await
.context("failed to listen on socket")
}
ListenConfig::Systemd(_) => {
use rustix::net::{getsockname, SocketAddrAny};
let fd = fd.with_context(|| format!("missing LISTEN_FDS[{fd_idx}]"))?;
let addr = getsockname(&fd).context("failed to getsockname")?;
match addr {
SocketAddrAny::V4(_) | SocketAddrAny::V6(_) => {
let listener = std::net::TcpListener::from(fd);
listener
.set_nonblocking(true)
.context("failed to set socket non-blocking")?;
let listener = TcpListener::from_std(listener)
.context("failed to register async socket")?;
// Report which inherited fd slot this listener ended up serving on.
tracing::info!(
    "serve {name} on tcp socket {addr:?} from LISTEN_FDS[{fd_idx}]"
);
Ok(listener)
}
// Unix socket support for axum is currently overly complex.
// WAIT: https://github.com/tokio-rs/axum/pull/2479
_ => bail!("unsupported socket type from LISTEN_FDS[{fd_idx}]: {addr:?}"),
}
}
// Unix socket support for axum is currently overly complex.
// WAIT: https://github.com/tokio-rs/axum/pull/2479
_ => bail!("unsupported socket type from LISTEN_FDS: {addr:?}"),
}
}
};
tracing::info!("listening on {listener_display}");
let router = blahd::router(Arc::new(st));
#[cfg(not(feature = "prometheus"))]
anyhow::ensure!(
config.metric.is_none(),
"metric support is disabled at compile time",
);
#[cfg(feature = "prometheus")]
let router = if let Some(metric_config) = &config.metric {
let blahd::config::MetricConfig::Prometheus(prom_config) = metric_config;
let metric_listener = get_listener("metrics", 1, prom_config.listen.clone()).await?;
let (metric_router, recorder, upkeeper) = blahd::metrics_router(metric_config);
metrics::set_global_recorder(recorder).expect("only set once");
tokio::spawn(upkeeper);
tokio::spawn(axum::serve(metric_listener, metric_router).into_future());
router.layer(axum_metrics::MetricLayer::default())
} else {
router
};
let mut sigterm = signal(SignalKind::terminate()).context("failed to listen on SIGTERM")?;
let listener = get_listener("main", 0, config.listen.clone()).await?;
let service = axum::serve(listener, router)
.with_graceful_shutdown(async move {
sigterm.recv().await;

View file

@ -1,5 +1,8 @@
use serde::{Deserialize, Serialize};
use std::num::NonZero;
use serde::Deserialize;
use serde_constant::ConstBool;
use serde_inline_default::serde_inline_default;
use crate::{database, ServerConfig};
@ -10,15 +13,34 @@ pub struct Config {
pub database: database::Config,
pub listen: ListenConfig,
pub server: ServerConfig,
#[serde(default)]
pub metric: Option<MetricConfig>,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
/// Where a server socket comes from.
///
/// Deserialized with snake_case variant names, i.e. `address` or `systemd`.
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ListenConfig {
/// An address string that a TCP listener is bound to.
Address(String),
/// Use a socket inherited through `LISTEN_FDS` instead of binding one.
// NOTE(review): `ConstBool<true>` presumably only accepts the literal `true` — confirm
// against the `serde_constant` documentation.
Systemd(ConstBool<true>),
}
/// Selects and configures the metrics exporter.
///
/// Deserialized with snake_case variant names, so the only accepted key is `prometheus`.
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum MetricConfig {
/// Expose metrics through a Prometheus endpoint, configured by [`PrometheusConfig`].
// NOTE(review): `PrometheusConfig` has a required `listen` field and no visible `Default`
// impl; confirm whether `#[serde(default)]` has any effect on this newtype variant.
Prometheus(#[serde(default)] PrometheusConfig),
}
/// Settings for the Prometheus metrics endpoint.
///
/// Unknown keys are rejected during deserialization.
#[serde_inline_default]
#[derive(Debug, Clone, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct PrometheusConfig {
    /// Where the metrics server listens. Required; there is no default.
    pub listen: ListenConfig,
    /// Seconds between two upkeep runs of the recorder. Defaults to 5.
    #[serde_inline_default(NonZero::new(5).expect("not zero"))]
    pub upkeep_period_secs: NonZero<u32>,
    /// Bucket duration in seconds handed to the Prometheus recorder. Defaults to 20.
    #[serde_inline_default(NonZero::new(20).expect("not zero"))]
    pub bucket_duration_secs: NonZero<u32>,
}
#[cfg(test)]
mod tests {
use super::*;

View file

@ -138,6 +138,20 @@ impl Stream for UserEventReceiver {
// TODO: Authenticate via HTTP query?
pub async fn get_ws(st: ArcState, ws: WebSocketUpgrade) -> Response {
ws.on_upgrade(move |mut socket| async move {
// Track this connection in the `ws_connections_in_flight` gauge: it is bumped
// right here and lowered again whenever `_guard` is dropped.
#[cfg(feature = "prometheus")]
let _guard = {
    /// Holds a gauge that has been incremented once and undoes that on drop.
    struct InFlightGuard(metrics::Gauge);

    impl InFlightGuard {
        /// Increment `gauge` and wrap it so the increment is reverted on drop.
        fn enter(gauge: metrics::Gauge) -> Self {
            gauge.increment(1);
            Self(gauge)
        }
    }

    impl Drop for InFlightGuard {
        fn drop(&mut self) {
            self.0.decrement(1);
        }
    }

    InFlightGuard::enter(metrics::gauge!("ws_connections_in_flight"))
};
match handle_ws(st.0, &mut socket).await {
#[allow(
unreachable_patterns,

View file

@ -44,6 +44,11 @@ mod id;
mod register;
mod utils;
#[cfg(feature = "prometheus")]
mod metric;
#[cfg(feature = "prometheus")]
pub use metric::metrics_router;
pub use database::{Config as DatabaseConfig, Database};
pub use middleware::ApiError;

42
blahd/src/metric.rs Normal file
View file

@ -0,0 +1,42 @@
use std::time::Duration;
use axum::Router;
use futures_util::future::BoxFuture;
use metrics::Recorder;
use metrics_exporter_prometheus::PrometheusBuilder;
use tokio::time::{interval, MissedTickBehavior};
use crate::config::MetricConfig;
type DynRecorder = Box<dyn Recorder + Send + Sync>;
type UpkeeperTask = BoxFuture<'static, ()>;
/// Build everything needed to expose Prometheus metrics.
///
/// Returns, in order:
/// - a [`Router`] whose `GET /metrics` route responds with the rendered recorder state,
/// - the boxed recorder, left for the caller to install,
/// - a future that never completes and runs recorder upkeep once per
///   `upkeep_period_secs`.
///
/// The upkeep future creates its timer only when first polled, so it must be
/// driven from within a Tokio runtime.
pub fn metrics_router(config: &MetricConfig) -> (Router, DynRecorder, UpkeeperTask) {
    let MetricConfig::Prometheus(prom) = config;

    let bucket_duration = Duration::from_secs(u64::from(prom.bucket_duration_secs.get()));
    let upkeep_period = Duration::from_secs(u64::from(prom.upkeep_period_secs.get()));

    let recorder = PrometheusBuilder::new()
        .set_bucket_duration(bucket_duration)
        .expect("not zero")
        .build_recorder();

    // Separate handles: one is moved into the HTTP handler, the other into the upkeep task.
    let render_handle = recorder.handle();
    let upkeep_handle = recorder.handle();

    let upkeeper: UpkeeperTask = Box::pin(async move {
        let mut ticker = interval(upkeep_period);
        // After a late tick, schedule the next one a full period later instead of bursting.
        ticker.set_missed_tick_behavior(MissedTickBehavior::Delay);
        loop {
            ticker.tick().await;
            upkeep_handle.run_upkeep();
        }
    });

    let router = Router::new().route(
        "/metrics",
        axum::routing::get(|| async move { render_handle.render() }),
    );

    let recorder: DynRecorder = Box::new(recorder);
    (router, recorder, upkeeper)
}