Split out timestamp check and actually check nonce

This commit is contained in:
oxalica 2024-08-30 20:32:48 -04:00
parent 668b873b07
commit 4ceffe3f31
5 changed files with 104 additions and 21 deletions

1
Cargo.lock generated
View file

@ -223,7 +223,6 @@ dependencies = [
name = "blah" name = "blah"
version = "0.0.0" version = "0.0.0"
dependencies = [ dependencies = [
"anyhow",
"bitflags", "bitflags",
"bitflags_serde_shim", "bitflags_serde_shim",
"ed25519-dalek", "ed25519-dalek",

View file

@ -13,7 +13,6 @@ version = "0.0.0"
edition = "2021" edition = "2021"
[dependencies] [dependencies]
anyhow = "1"
bitflags = "2" bitflags = "2"
bitflags_serde_shim = "0.2" bitflags_serde_shim = "0.2"
ed25519-dalek = "2.1" ed25519-dalek = "2.1"

View file

@ -6,7 +6,7 @@ use std::sync::{Arc, Mutex};
use std::time::{Duration, SystemTime}; use std::time::{Duration, SystemTime};
use anyhow::{ensure, Context, Result}; use anyhow::{ensure, Context, Result};
use axum::extract::{FromRequest, FromRequestParts, Path, Query, Request, State}; use axum::extract::{FromRef, FromRequest, FromRequestParts, Path, Query, Request, State};
use axum::http::{header, request, StatusCode}; use axum::http::{header, request, StatusCode};
use axum::response::{sse, IntoResponse, Response}; use axum::response::{sse, IntoResponse, Response};
use axum::routing::{get, post}; use axum::routing::{get, post};
@ -21,11 +21,15 @@ use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use tokio::sync::broadcast; use tokio::sync::broadcast;
use tokio_stream::StreamExt; use tokio_stream::StreamExt;
use utils::ExpiringSet;
use uuid::Uuid; use uuid::Uuid;
const PAGE_LEN: usize = 64; const PAGE_LEN: usize = 64;
const EVENT_QUEUE_LEN: usize = 1024; const EVENT_QUEUE_LEN: usize = 1024;
const MAX_BODY_LEN: usize = 4 << 10; // 4KiB const MAX_BODY_LEN: usize = 4 << 10; // 4KiB
const TIMESTAMP_TOLERENCE: u64 = 90;
mod utils;
#[derive(Debug, clap::Parser)] #[derive(Debug, clap::Parser)]
struct Cli { struct Cli {
@ -58,10 +62,12 @@ fn main() -> Result<()> {
Ok(()) Ok(())
} }
// Locks must be grabbed in the field order.
#[derive(Debug)] #[derive(Debug)]
struct AppState { struct AppState {
conn: Mutex<rusqlite::Connection>, conn: Mutex<rusqlite::Connection>,
room_listeners: Mutex<HashMap<u64, broadcast::Sender<Arc<ChatItem>>>>, room_listeners: Mutex<HashMap<u64, broadcast::Sender<Arc<ChatItem>>>>,
used_nonces: Mutex<ExpiringSet<u32>>,
base_url: Box<str>, base_url: Box<str>,
} }
@ -81,9 +87,32 @@ impl AppState {
Ok(Self { Ok(Self {
conn: Mutex::new(conn), conn: Mutex::new(conn),
room_listeners: Mutex::new(HashMap::new()), room_listeners: Mutex::new(HashMap::new()),
used_nonces: Mutex::new(ExpiringSet::new(Duration::from_secs(TIMESTAMP_TOLERENCE))),
base_url, base_url,
}) })
} }
fn verify_signed_data<T: Serialize>(&self, data: &WithSig<T>) -> Result<()> {
data.verify().context("unsigned payload")?;
let timestamp_diff = SystemTime::now()
.duration_since(SystemTime::UNIX_EPOCH)
.expect("after UNIX epoch")
.as_secs()
.abs_diff(data.signee.timestamp);
ensure!(
timestamp_diff <= TIMESTAMP_TOLERENCE,
"invalid timestamp, off by {timestamp_diff}s"
);
ensure!(
self.used_nonces
.lock()
.unwrap()
.try_insert(data.signee.nonce),
"duplicated nonce",
);
Ok(())
}
} }
type ArcState = State<Arc<AppState>>; type ArcState = State<Arc<AppState>>;
@ -406,14 +435,20 @@ fn query_room_items(
struct SignedJson<T>(WithSig<T>); struct SignedJson<T>(WithSig<T>);
#[async_trait] #[async_trait]
impl<S: Send + Sync, T: Serialize + DeserializeOwned> FromRequest<S> for SignedJson<T> { impl<S, T> FromRequest<S> for SignedJson<T>
where
S: Send + Sync,
T: Serialize + DeserializeOwned,
Arc<AppState>: FromRef<S>,
{
type Rejection = Response; type Rejection = Response;
async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> { async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
let Json(data) = <Json<WithSig<T>> as FromRequest<S>>::from_request(req, state) let Json(data) = <Json<WithSig<T>> as FromRequest<S>>::from_request(req, state)
.await .await
.map_err(|err| err.into_response())?; .map_err(|err| err.into_response())?;
data.verify().map_err(|err| { let st = <Arc<AppState>>::from_ref(state);
st.verify_signed_data(&data).map_err(|err| {
tracing::debug!(%err, "unsigned payload"); tracing::debug!(%err, "unsigned payload");
StatusCode::BAD_REQUEST.into_response() StatusCode::BAD_REQUEST.into_response()
})?; })?;
@ -426,21 +461,26 @@ impl<S: Send + Sync, T: Serialize + DeserializeOwned> FromRequest<S> for SignedJ
struct OptionalAuth(Option<UserKey>); struct OptionalAuth(Option<UserKey>);
#[async_trait] #[async_trait]
impl<S: Send + Sync> FromRequestParts<S> for OptionalAuth { impl<S> FromRequestParts<S> for OptionalAuth
where
S: Send + Sync,
Arc<AppState>: FromRef<S>,
{
type Rejection = StatusCode; type Rejection = StatusCode;
async fn from_request_parts( async fn from_request_parts(
parts: &mut request::Parts, parts: &mut request::Parts,
_state: &S, state: &S,
) -> Result<Self, Self::Rejection> { ) -> Result<Self, Self::Rejection> {
let Some(auth) = parts.headers.get(header::AUTHORIZATION) else { let Some(auth) = parts.headers.get(header::AUTHORIZATION) else {
return Ok(Self(None)); return Ok(Self(None));
}; };
let st = <Arc<AppState>>::from_ref(state);
let ret = serde_json::from_slice::<WithSig<AuthPayload>>(auth.as_bytes()) let ret = serde_json::from_slice::<WithSig<AuthPayload>>(auth.as_bytes())
.context("invalid JSON") .context("invalid JSON")
.and_then(|data| { .and_then(|data| {
data.verify()?; st.verify_signed_data(&data)?;
Ok(data.signee.user) Ok(data.signee.user)
}); });
match ret { match ret {

44
blahd/src/utils.rs Normal file
View file

@ -0,0 +1,44 @@
use std::collections::{HashSet, VecDeque};
use std::hash::Hash;
use std::time::{Duration, Instant};
#[derive(Debug)]
pub struct ExpiringSet<T> {
set: HashSet<T>,
expire_queue: VecDeque<(Instant, T)>,
expire_delay: Duration,
}
impl<T: Hash + Eq + Clone + Copy> ExpiringSet<T> {
pub fn new(expire_delay: Duration) -> Self {
Self {
set: HashSet::new(),
expire_queue: VecDeque::new(),
expire_delay,
}
}
fn maintain(&mut self, now: Instant) {
while let Some((_, x)) = self
.expire_queue
.front()
.filter(|(deadline, _)| *deadline < now)
{
self.set.remove(x);
self.expire_queue.pop_front();
}
// TODO: Reclaim space after instant heavy load.
}
pub fn try_insert(&mut self, v: T) -> bool {
let now = Instant::now();
self.maintain(now);
if self.set.insert(v) {
self.expire_queue.push_back((now + self.expire_delay, v));
true
} else {
false
}
}
}

View file

@ -5,18 +5,16 @@
use std::fmt; use std::fmt;
use std::time::SystemTime; use std::time::SystemTime;
use anyhow::{ensure, Context};
use bitflags::bitflags; use bitflags::bitflags;
use bitflags_serde_shim::impl_serde_for_bitflags; use bitflags_serde_shim::impl_serde_for_bitflags;
use ed25519_dalek::{ use ed25519_dalek::{
Signature, Signer, SigningKey, VerifyingKey, PUBLIC_KEY_LENGTH, SIGNATURE_LENGTH, Signature, SignatureError, Signer, SigningKey, VerifyingKey, PUBLIC_KEY_LENGTH,
SIGNATURE_LENGTH,
}; };
use rand_core::RngCore; use rand_core::RngCore;
use serde::{de, Deserialize, Deserializer, Serialize, Serializer}; use serde::{de, Deserialize, Deserializer, Serialize, Serializer};
use uuid::Uuid; use uuid::Uuid;
const TIMESTAMP_TOLERENCE: u64 = 90;
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(transparent)] #[serde(transparent)]
pub struct UserKey(#[serde(with = "hex::serde")] pub [u8; PUBLIC_KEY_LENGTH]); pub struct UserKey(#[serde(with = "hex::serde")] pub [u8; PUBLIC_KEY_LENGTH]);
@ -54,25 +52,28 @@ fn get_timestamp() -> u64 {
} }
impl<T: Serialize> WithSig<T> { impl<T: Serialize> WithSig<T> {
pub fn sign(key: &SigningKey, rng: &mut impl RngCore, payload: T) -> anyhow::Result<Self> { /// Sign the payload with the given `key`.
pub fn sign(
key: &SigningKey,
rng: &mut impl RngCore,
payload: T,
) -> Result<Self, SignatureError> {
let signee = Signee { let signee = Signee {
nonce: rng.next_u32(), nonce: rng.next_u32(),
payload, payload,
timestamp: get_timestamp(), timestamp: get_timestamp(),
user: UserKey(key.verifying_key().to_bytes()), user: UserKey(key.verifying_key().to_bytes()),
}; };
let canonical_signee = serde_json::to_vec(&signee).context("failed to serialize")?; let canonical_signee = serde_json::to_vec(&signee).expect("serialization cannot fail");
let sig = key.try_sign(&canonical_signee)?.to_bytes(); let sig = key.try_sign(&canonical_signee)?.to_bytes();
Ok(Self { sig, signee }) Ok(Self { sig, signee })
} }
pub fn verify(&self) -> anyhow::Result<()> { /// Verify `sig` is valid for `signee`.
ensure!( ///
self.signee.timestamp.abs_diff(get_timestamp()) < TIMESTAMP_TOLERENCE, /// Note that this does nott check validity of timestamp and other data.
"invalid timestamp" pub fn verify(&self) -> Result<(), SignatureError> {
); let canonical_signee = serde_json::to_vec(&self.signee).expect("serialization cannot fail");
let canonical_signee = serde_json::to_vec(&self.signee).context("failed to serialize")?;
let sig = Signature::from_bytes(&self.sig); let sig = Signature::from_bytes(&self.sig);
VerifyingKey::from_bytes(&self.signee.user.0)?.verify_strict(&canonical_signee, &sig)?; VerifyingKey::from_bytes(&self.signee.user.0)?.verify_strict(&canonical_signee, &sig)?;
Ok(()) Ok(())