diff --git a/blahd/Cargo.toml b/blahd/Cargo.toml index 2dc4ec7..85d0725 100644 --- a/blahd/Cargo.toml +++ b/blahd/Cargo.toml @@ -32,7 +32,7 @@ sha2 = "0.10" tokio = { version = "1", features = ["macros", "rt-multi-thread", "signal", "sync", "time"] } tokio-stream = { version = "0.1", features = ["sync"] } toml = "0.8" -tower-http = { version = "0.6", features = ["cors", "limit"] } +tower-http = { version = "0.6", features = ["cors", "limit", "set-header"] } tracing = "0.1" tracing-subscriber = "0.3" url = { version = "2", features = ["serde"] } diff --git a/blahd/src/lib.rs b/blahd/src/lib.rs index 01ffccb..0c47d39 100644 --- a/blahd/src/lib.rs +++ b/blahd/src/lib.rs @@ -50,6 +50,12 @@ pub use middleware::ApiError; pub(crate) const SERVER_AND_VERSION: &str = concat!("blahd/", env!("CARGO_PKG_VERSION")); const SERVER_SRC_URL: Option<&str> = option_env!("CFG_SRC_URL"); +const HEADER_PUBLIC_NO_CACHE: (HeaderName, HeaderValue) = ( + header::CACHE_CONTROL, + HeaderValue::from_static("public, no-cache"), +); +const DEFAULT_CACHE_CONTROL: HeaderValue = HeaderValue::from_static("private, no-cache"); + #[serde_inline_default] #[derive(Debug, Clone, Deserialize)] #[serde(deny_unknown_fields)] @@ -158,6 +164,12 @@ pub fn router(st: Arc) -> Router { .layer(tower_http::limit::RequestBodyLimitLayer::new( st.config.max_request_len, )) + .layer( + tower_http::set_header::SetResponseHeaderLayer::if_not_present( + header::CACHE_CONTROL, + DEFAULT_CACHE_CONTROL, + ), + ) // NB. This comes at last (outmost layer), so inner errors will still be wrapped with // correct CORS headers. Also `Authorization` must be explicitly included besides `*`. .layer( @@ -176,15 +188,14 @@ type RE = R; async fn handle_server_metadata(State(st): ArcState) -> Response { let (json, etag) = st.server_metadata.clone(); - ( - [( + let headers = [ + ( header::CONTENT_TYPE, const { HeaderValue::from_static("application/json") }, - )], - etag, - json, - ) - .into_response() + ), + HEADER_PUBLIC_NO_CACHE, + ]; + (headers, etag, json).into_response() } async fn handle_ws(State(st): ArcState, ws: WebSocketUpgrade) -> Response { @@ -495,7 +506,7 @@ async fn room_get_feed( self_url, next_url, }); - Ok((ETag(Some(ret_etag)), resp).into_response()) + Ok(([HEADER_PUBLIC_NO_CACHE], ETag(Some(ret_etag)), resp).into_response()) } /// Get room messages with pagination parameters,