diff --git a/blah-types/src/msg.rs b/blah-types/src/msg.rs index c172639..3e1394c 100644 --- a/blah-types/src/msg.rs +++ b/blah-types/src/msg.rs @@ -1,5 +1,7 @@ //! Core message subtypes. use std::fmt; +use std::num::ParseIntError; +use std::str::FromStr; use bitflags_serde_shim::impl_serde_for_bitflags; use serde::{de, ser, Deserialize, Serialize}; @@ -22,6 +24,14 @@ impl fmt::Display for Id { } } +impl FromStr for Id { + type Err = ParseIntError; + + fn from_str(s: &str) -> Result { + i64::from_str(s).map(Self) + } +} + impl Id { pub const MIN: Self = Id(i64::MIN); pub const MAX: Self = Id(i64::MAX); diff --git a/blahd/src/lib.rs b/blahd/src/lib.rs index f968792..7b22884 100644 --- a/blahd/src/lib.rs +++ b/blahd/src/lib.rs @@ -23,7 +23,7 @@ use blah_types::{get_timestamp, Id, Signed, UserKey}; use database::{Transaction, TransactionOps}; use feed::FeedData; use id::IdExt; -use middleware::{Auth, MaybeAuth, ResultExt as _, SignedJson}; +use middleware::{Auth, ETag, MaybeAuth, ResultExt as _, SignedJson}; use parking_lot::Mutex; use serde::{Deserialize, Deserializer, Serialize}; use serde_inline_default::serde_inline_default; @@ -434,11 +434,11 @@ async fn room_get_metadata( async fn room_get_feed( st: ArcState, + ETag(etag): ETag, R(OriginalUri(req_uri), _): RE, R(Path(rid), _): RE>, R(Query(mut pagination), _): RE>, ) -> Result { - // TODO: If-None-Match. let self_url = st .config .base_url @@ -460,6 +460,12 @@ async fn room_get_feed( Ok((title, msgs, skip_token)) })?; + // Use `Id(0)` as the tag for an empty list. + let ret_etag = msgs.first().map_or(Id(0), |msg| msg.cid); + if etag == Some(ret_etag) { + return Ok(StatusCode::NOT_MODIFIED.into_response()); + } + let next_url = skip_token.map(|skip_token| { let next_params = Pagination { skip_token: Some(skip_token), @@ -478,13 +484,14 @@ async fn room_get_feed( next_url }); - Ok(FT::to_feed_response(FeedData { + let resp = FT::to_feed_response(FeedData { rid, title, msgs, self_url, next_url, - })) + }); + Ok((ETag(Some(ret_etag)), resp).into_response()) } /// Get room messages with pagination parameters, diff --git a/blahd/src/middleware.rs b/blahd/src/middleware.rs index 0faa5e5..9461071 100644 --- a/blahd/src/middleware.rs +++ b/blahd/src/middleware.rs @@ -1,12 +1,13 @@ use std::backtrace::Backtrace; use std::convert::Infallible; use std::fmt; +use std::str::FromStr; use std::sync::Arc; use axum::extract::rejection::{JsonRejection, PathRejection, QueryRejection}; use axum::extract::{FromRef, FromRequest, FromRequestParts, Request}; -use axum::http::{header, request, StatusCode}; -use axum::response::{IntoResponse, Response}; +use axum::http::{header, request, HeaderValue, StatusCode}; +use axum::response::{IntoResponse, IntoResponseParts, Response, ResponseParts}; use axum::{async_trait, Json}; use blah_types::msg::AuthPayload; use blah_types::{Signed, UserKey}; @@ -244,3 +245,42 @@ where Ok(Self(data.signee.user)) } } + +#[derive(Debug, Clone)] +pub struct ETag(pub Option); + +#[async_trait] +impl FromRequestParts for ETag +where + S: Send + Sync, +{ + type Rejection = Infallible; + + async fn from_request_parts( + parts: &mut request::Parts, + _state: &S, + ) -> Result { + let tag = parts + .headers + .get(header::IF_NONE_MATCH) + .and_then(|v| v.to_str().ok()?.strip_prefix('"')?.strip_suffix('"')) + .filter(|s| !s.is_empty()) + .and_then(|s| s.parse::().ok()); + Ok(Self(tag)) + } +} + +impl IntoResponseParts for ETag { + type Error = Infallible; + + fn into_response_parts(self, mut res: ResponseParts) -> Result { + if let Some(tag) = &self.0 { + res.headers_mut().insert( + header::ETAG, + HeaderValue::from_str(&format!("\"{tag}\"")) + .expect("ETag must be a valid header value"), + ); + } + Ok(res) + } +} diff --git a/blahd/tests/webapi.rs b/blahd/tests/webapi.rs index 166d157..c693582 100644 --- a/blahd/tests/webapi.rs +++ b/blahd/tests/webapi.rs @@ -790,15 +790,60 @@ async fn room_feed(server: Server, #[case] typ: &'static str) { .join_room(rid, &BOB, MemberPermission::POST_CHAT) .await .unwrap(); - server.post_chat(rid, &ALICE, "a").await.unwrap(); + let feed_url = server.url(format!("/room/{rid}/feed.{typ}")); + let get_feed = |etag: Option<&str>| { + let mut req = server.client.get(&feed_url); + if let Some(etag) = etag { + req = req.header(header::IF_NONE_MATCH, etag); + } + async move { + let resp = req.send().await.unwrap().error_for_status().unwrap(); + if resp.status() == StatusCode::NOT_MODIFIED { + None + } else { + let etag = resp.headers()[header::ETAG].to_str().unwrap().to_owned(); + Some((etag, resp.text().await.unwrap())) + } + } + }; + + // Empty yet. + let etag_zero = "\"0\""; + assert_eq!(get_feed(None).await.unwrap().0, etag_zero); + // ETag should track from empty -> empty. + assert_eq!(get_feed(Some(etag_zero)).await, None); + + // Post some chats. + let cid1 = server.post_chat(rid, &ALICE, "a").await.unwrap().cid; + // Got some response. + let etag_one = format!("\"{cid1}\""); + { + let resp1 = get_feed(None).await.unwrap(); + // ETag should track from empty -> non-empty. + let resp2 = get_feed(Some(etag_zero)).await.unwrap(); + // Idempotent. + assert_eq!(resp1, resp2); + assert_eq!(resp1.0, etag_one); + } + + // Post more chats. let cid2 = server.post_chat(rid, &BOB, "b1").await.unwrap().cid; - server.post_chat(rid, &BOB, "b2").await.unwrap(); + let cid3 = server.post_chat(rid, &BOB, "b2").await.unwrap().cid; + + let etag_last = format!("\"{cid3}\""); + let resp = { + let resp1 = get_feed(None).await.unwrap(); + // ETag should track from non-empty -> non-empty. + let resp2 = get_feed(Some(&etag_one)).await.unwrap(); + // Idempotent. + assert_eq!(resp1, resp2); + assert_eq!(resp1.0, etag_last); + assert_eq!(get_feed(Some(&etag_last)).await, None); + resp1.1 + }; if typ == "json" { - let feed = server - .get::(&format!("/room/{rid}/feed.json"), None) - .await - .unwrap(); + let feed = serde_json::from_str::(&resp).unwrap(); // TODO: Ideally we should assert on the result, but it contains time and random id currently. assert_eq!(feed["title"].as_str().unwrap(), "public"); assert_eq!(feed["items"].as_array().unwrap().len(), 2); @@ -820,17 +865,6 @@ async fn room_feed(server: Server, #[case] typ: &'static str) { assert_eq!(items.len(), 1); assert_eq!(items[0]["content_html"].as_str().unwrap(), "a"); } else { - let resp = server - .client - .get(server.url(format!("/room/{rid}/feed.atom"))) - .send() - .await - .unwrap() - .error_for_status() - .unwrap() - .text() - .await - .unwrap(); assert!(resp.starts_with(r#""#)); assert_eq!(resp.matches("").count(), 2); }