diff --git a/Cargo.lock b/Cargo.lock index efa3175..4225066 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -223,7 +223,6 @@ dependencies = [ name = "blah" version = "0.0.0" dependencies = [ - "anyhow", "bitflags", "bitflags_serde_shim", "ed25519-dalek", diff --git a/Cargo.toml b/Cargo.toml index 4e0eed7..7695945 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,7 +13,6 @@ version = "0.0.0" edition = "2021" [dependencies] -anyhow = "1" bitflags = "2" bitflags_serde_shim = "0.2" ed25519-dalek = "2.1" diff --git a/blahd/src/main.rs b/blahd/src/main.rs index 9b81cd3..819e5ea 100644 --- a/blahd/src/main.rs +++ b/blahd/src/main.rs @@ -6,7 +6,7 @@ use std::sync::{Arc, Mutex}; use std::time::{Duration, SystemTime}; use anyhow::{ensure, Context, Result}; -use axum::extract::{FromRequest, FromRequestParts, Path, Query, Request, State}; +use axum::extract::{FromRef, FromRequest, FromRequestParts, Path, Query, Request, State}; use axum::http::{header, request, StatusCode}; use axum::response::{sse, IntoResponse, Response}; use axum::routing::{get, post}; @@ -21,11 +21,15 @@ use serde::de::DeserializeOwned; use serde::{Deserialize, Serialize}; use tokio::sync::broadcast; use tokio_stream::StreamExt; +use utils::ExpiringSet; use uuid::Uuid; const PAGE_LEN: usize = 64; const EVENT_QUEUE_LEN: usize = 1024; const MAX_BODY_LEN: usize = 4 << 10; // 4KiB +const TIMESTAMP_TOLERENCE: u64 = 90; + +mod utils; #[derive(Debug, clap::Parser)] struct Cli { @@ -58,10 +62,12 @@ fn main() -> Result<()> { Ok(()) } +// Locks must be grabbed in the field order. #[derive(Debug)] struct AppState { conn: Mutex, room_listeners: Mutex>>>, + used_nonces: Mutex>, base_url: Box, } @@ -81,9 +87,32 @@ impl AppState { Ok(Self { conn: Mutex::new(conn), room_listeners: Mutex::new(HashMap::new()), + used_nonces: Mutex::new(ExpiringSet::new(Duration::from_secs(TIMESTAMP_TOLERENCE))), + base_url, }) } + + fn verify_signed_data(&self, data: &WithSig) -> Result<()> { + data.verify().context("unsigned payload")?; + let timestamp_diff = SystemTime::now() + .duration_since(SystemTime::UNIX_EPOCH) + .expect("after UNIX epoch") + .as_secs() + .abs_diff(data.signee.timestamp); + ensure!( + timestamp_diff <= TIMESTAMP_TOLERENCE, + "invalid timestamp, off by {timestamp_diff}s" + ); + ensure!( + self.used_nonces + .lock() + .unwrap() + .try_insert(data.signee.nonce), + "duplicated nonce", + ); + Ok(()) + } } type ArcState = State>; @@ -406,14 +435,20 @@ fn query_room_items( struct SignedJson(WithSig); #[async_trait] -impl FromRequest for SignedJson { +impl FromRequest for SignedJson +where + S: Send + Sync, + T: Serialize + DeserializeOwned, + Arc: FromRef, +{ type Rejection = Response; async fn from_request(req: Request, state: &S) -> Result { let Json(data) = > as FromRequest>::from_request(req, state) .await .map_err(|err| err.into_response())?; - data.verify().map_err(|err| { + let st = >::from_ref(state); + st.verify_signed_data(&data).map_err(|err| { tracing::debug!(%err, "unsigned payload"); StatusCode::BAD_REQUEST.into_response() })?; @@ -426,21 +461,26 @@ impl FromRequest for SignedJ struct OptionalAuth(Option); #[async_trait] -impl FromRequestParts for OptionalAuth { +impl FromRequestParts for OptionalAuth +where + S: Send + Sync, + Arc: FromRef, +{ type Rejection = StatusCode; async fn from_request_parts( parts: &mut request::Parts, - _state: &S, + state: &S, ) -> Result { let Some(auth) = parts.headers.get(header::AUTHORIZATION) else { return Ok(Self(None)); }; + let st = >::from_ref(state); let ret = serde_json::from_slice::>(auth.as_bytes()) .context("invalid JSON") .and_then(|data| { - data.verify()?; + st.verify_signed_data(&data)?; Ok(data.signee.user) }); match ret { diff --git a/blahd/src/utils.rs b/blahd/src/utils.rs new file mode 100644 index 0000000..f3a1c0a --- /dev/null +++ b/blahd/src/utils.rs @@ -0,0 +1,44 @@ +use std::collections::{HashSet, VecDeque}; +use std::hash::Hash; +use std::time::{Duration, Instant}; + +#[derive(Debug)] +pub struct ExpiringSet { + set: HashSet, + expire_queue: VecDeque<(Instant, T)>, + expire_delay: Duration, +} + +impl ExpiringSet { + pub fn new(expire_delay: Duration) -> Self { + Self { + set: HashSet::new(), + expire_queue: VecDeque::new(), + expire_delay, + } + } + + fn maintain(&mut self, now: Instant) { + while let Some((_, x)) = self + .expire_queue + .front() + .filter(|(deadline, _)| *deadline < now) + { + self.set.remove(x); + self.expire_queue.pop_front(); + } + + // TODO: Reclaim space after instant heavy load. + } + + pub fn try_insert(&mut self, v: T) -> bool { + let now = Instant::now(); + self.maintain(now); + if self.set.insert(v) { + self.expire_queue.push_back((now + self.expire_delay, v)); + true + } else { + false + } + } +} diff --git a/src/types.rs b/src/types.rs index d551767..1b1dfb6 100644 --- a/src/types.rs +++ b/src/types.rs @@ -5,18 +5,16 @@ use std::fmt; use std::time::SystemTime; -use anyhow::{ensure, Context}; use bitflags::bitflags; use bitflags_serde_shim::impl_serde_for_bitflags; use ed25519_dalek::{ - Signature, Signer, SigningKey, VerifyingKey, PUBLIC_KEY_LENGTH, SIGNATURE_LENGTH, + Signature, SignatureError, Signer, SigningKey, VerifyingKey, PUBLIC_KEY_LENGTH, + SIGNATURE_LENGTH, }; use rand_core::RngCore; use serde::{de, Deserialize, Deserializer, Serialize, Serializer}; use uuid::Uuid; -const TIMESTAMP_TOLERENCE: u64 = 90; - #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] #[serde(transparent)] pub struct UserKey(#[serde(with = "hex::serde")] pub [u8; PUBLIC_KEY_LENGTH]); @@ -54,25 +52,28 @@ fn get_timestamp() -> u64 { } impl WithSig { - pub fn sign(key: &SigningKey, rng: &mut impl RngCore, payload: T) -> anyhow::Result { + /// Sign the payload with the given `key`. + pub fn sign( + key: &SigningKey, + rng: &mut impl RngCore, + payload: T, + ) -> Result { let signee = Signee { nonce: rng.next_u32(), payload, timestamp: get_timestamp(), user: UserKey(key.verifying_key().to_bytes()), }; - let canonical_signee = serde_json::to_vec(&signee).context("failed to serialize")?; + let canonical_signee = serde_json::to_vec(&signee).expect("serialization cannot fail"); let sig = key.try_sign(&canonical_signee)?.to_bytes(); Ok(Self { sig, signee }) } - pub fn verify(&self) -> anyhow::Result<()> { - ensure!( - self.signee.timestamp.abs_diff(get_timestamp()) < TIMESTAMP_TOLERENCE, - "invalid timestamp" - ); - - let canonical_signee = serde_json::to_vec(&self.signee).context("failed to serialize")?; + /// Verify `sig` is valid for `signee`. + /// + /// Note that this does nott check validity of timestamp and other data. + pub fn verify(&self) -> Result<(), SignatureError> { + let canonical_signee = serde_json::to_vec(&self.signee).expect("serialization cannot fail"); let sig = Signature::from_bytes(&self.sig); VerifyingKey::from_bytes(&self.signee.user.0)?.verify_strict(&canonical_signee, &sig)?; Ok(())