Split out timestamp check and actually check nonce

This commit is contained in:
oxalica 2024-08-30 20:32:48 -04:00
parent 668b873b07
commit 4ceffe3f31
5 changed files with 104 additions and 21 deletions

1
Cargo.lock generated
View file

@ -223,7 +223,6 @@ dependencies = [
name = "blah"
version = "0.0.0"
dependencies = [
"anyhow",
"bitflags",
"bitflags_serde_shim",
"ed25519-dalek",

View file

@ -13,7 +13,6 @@ version = "0.0.0"
edition = "2021"
[dependencies]
anyhow = "1"
bitflags = "2"
bitflags_serde_shim = "0.2"
ed25519-dalek = "2.1"

View file

@ -6,7 +6,7 @@ use std::sync::{Arc, Mutex};
use std::time::{Duration, SystemTime};
use anyhow::{ensure, Context, Result};
use axum::extract::{FromRequest, FromRequestParts, Path, Query, Request, State};
use axum::extract::{FromRef, FromRequest, FromRequestParts, Path, Query, Request, State};
use axum::http::{header, request, StatusCode};
use axum::response::{sse, IntoResponse, Response};
use axum::routing::{get, post};
@ -21,11 +21,15 @@ use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use tokio::sync::broadcast;
use tokio_stream::StreamExt;
use utils::ExpiringSet;
use uuid::Uuid;
const PAGE_LEN: usize = 64;
const EVENT_QUEUE_LEN: usize = 1024;
const MAX_BODY_LEN: usize = 4 << 10; // 4KiB
const TIMESTAMP_TOLERENCE: u64 = 90;
mod utils;
#[derive(Debug, clap::Parser)]
struct Cli {
@ -58,10 +62,12 @@ fn main() -> Result<()> {
Ok(())
}
// Locks must be grabbed in the field order.
#[derive(Debug)]
struct AppState {
conn: Mutex<rusqlite::Connection>,
room_listeners: Mutex<HashMap<u64, broadcast::Sender<Arc<ChatItem>>>>,
used_nonces: Mutex<ExpiringSet<u32>>,
base_url: Box<str>,
}
@ -81,9 +87,32 @@ impl AppState {
Ok(Self {
conn: Mutex::new(conn),
room_listeners: Mutex::new(HashMap::new()),
used_nonces: Mutex::new(ExpiringSet::new(Duration::from_secs(TIMESTAMP_TOLERENCE))),
base_url,
})
}
fn verify_signed_data<T: Serialize>(&self, data: &WithSig<T>) -> Result<()> {
data.verify().context("unsigned payload")?;
let timestamp_diff = SystemTime::now()
.duration_since(SystemTime::UNIX_EPOCH)
.expect("after UNIX epoch")
.as_secs()
.abs_diff(data.signee.timestamp);
ensure!(
timestamp_diff <= TIMESTAMP_TOLERENCE,
"invalid timestamp, off by {timestamp_diff}s"
);
ensure!(
self.used_nonces
.lock()
.unwrap()
.try_insert(data.signee.nonce),
"duplicated nonce",
);
Ok(())
}
}
type ArcState = State<Arc<AppState>>;
@ -406,14 +435,20 @@ fn query_room_items(
struct SignedJson<T>(WithSig<T>);
#[async_trait]
impl<S: Send + Sync, T: Serialize + DeserializeOwned> FromRequest<S> for SignedJson<T> {
impl<S, T> FromRequest<S> for SignedJson<T>
where
S: Send + Sync,
T: Serialize + DeserializeOwned,
Arc<AppState>: FromRef<S>,
{
type Rejection = Response;
async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
let Json(data) = <Json<WithSig<T>> as FromRequest<S>>::from_request(req, state)
.await
.map_err(|err| err.into_response())?;
data.verify().map_err(|err| {
let st = <Arc<AppState>>::from_ref(state);
st.verify_signed_data(&data).map_err(|err| {
tracing::debug!(%err, "unsigned payload");
StatusCode::BAD_REQUEST.into_response()
})?;
@ -426,21 +461,26 @@ impl<S: Send + Sync, T: Serialize + DeserializeOwned> FromRequest<S> for SignedJ
struct OptionalAuth(Option<UserKey>);
#[async_trait]
impl<S: Send + Sync> FromRequestParts<S> for OptionalAuth {
impl<S> FromRequestParts<S> for OptionalAuth
where
S: Send + Sync,
Arc<AppState>: FromRef<S>,
{
type Rejection = StatusCode;
async fn from_request_parts(
parts: &mut request::Parts,
_state: &S,
state: &S,
) -> Result<Self, Self::Rejection> {
let Some(auth) = parts.headers.get(header::AUTHORIZATION) else {
return Ok(Self(None));
};
let st = <Arc<AppState>>::from_ref(state);
let ret = serde_json::from_slice::<WithSig<AuthPayload>>(auth.as_bytes())
.context("invalid JSON")
.and_then(|data| {
data.verify()?;
st.verify_signed_data(&data)?;
Ok(data.signee.user)
});
match ret {

44
blahd/src/utils.rs Normal file
View file

@ -0,0 +1,44 @@
use std::collections::{HashSet, VecDeque};
use std::hash::Hash;
use std::time::{Duration, Instant};
#[derive(Debug)]
pub struct ExpiringSet<T> {
set: HashSet<T>,
expire_queue: VecDeque<(Instant, T)>,
expire_delay: Duration,
}
impl<T: Hash + Eq + Clone + Copy> ExpiringSet<T> {
pub fn new(expire_delay: Duration) -> Self {
Self {
set: HashSet::new(),
expire_queue: VecDeque::new(),
expire_delay,
}
}
fn maintain(&mut self, now: Instant) {
while let Some((_, x)) = self
.expire_queue
.front()
.filter(|(deadline, _)| *deadline < now)
{
self.set.remove(x);
self.expire_queue.pop_front();
}
// TODO: Reclaim space after instant heavy load.
}
pub fn try_insert(&mut self, v: T) -> bool {
let now = Instant::now();
self.maintain(now);
if self.set.insert(v) {
self.expire_queue.push_back((now + self.expire_delay, v));
true
} else {
false
}
}
}

View file

@ -5,18 +5,16 @@
use std::fmt;
use std::time::SystemTime;
use anyhow::{ensure, Context};
use bitflags::bitflags;
use bitflags_serde_shim::impl_serde_for_bitflags;
use ed25519_dalek::{
Signature, Signer, SigningKey, VerifyingKey, PUBLIC_KEY_LENGTH, SIGNATURE_LENGTH,
Signature, SignatureError, Signer, SigningKey, VerifyingKey, PUBLIC_KEY_LENGTH,
SIGNATURE_LENGTH,
};
use rand_core::RngCore;
use serde::{de, Deserialize, Deserializer, Serialize, Serializer};
use uuid::Uuid;
const TIMESTAMP_TOLERENCE: u64 = 90;
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(transparent)]
pub struct UserKey(#[serde(with = "hex::serde")] pub [u8; PUBLIC_KEY_LENGTH]);
@ -54,25 +52,28 @@ fn get_timestamp() -> u64 {
}
impl<T: Serialize> WithSig<T> {
pub fn sign(key: &SigningKey, rng: &mut impl RngCore, payload: T) -> anyhow::Result<Self> {
/// Sign the payload with the given `key`.
pub fn sign(
key: &SigningKey,
rng: &mut impl RngCore,
payload: T,
) -> Result<Self, SignatureError> {
let signee = Signee {
nonce: rng.next_u32(),
payload,
timestamp: get_timestamp(),
user: UserKey(key.verifying_key().to_bytes()),
};
let canonical_signee = serde_json::to_vec(&signee).context("failed to serialize")?;
let canonical_signee = serde_json::to_vec(&signee).expect("serialization cannot fail");
let sig = key.try_sign(&canonical_signee)?.to_bytes();
Ok(Self { sig, signee })
}
pub fn verify(&self) -> anyhow::Result<()> {
ensure!(
self.signee.timestamp.abs_diff(get_timestamp()) < TIMESTAMP_TOLERENCE,
"invalid timestamp"
);
let canonical_signee = serde_json::to_vec(&self.signee).context("failed to serialize")?;
/// Verify `sig` is valid for `signee`.
///
/// Note that this does nott check validity of timestamp and other data.
pub fn verify(&self) -> Result<(), SignatureError> {
let canonical_signee = serde_json::to_vec(&self.signee).expect("serialization cannot fail");
let sig = Signature::from_bytes(&self.sig);
VerifyingKey::from_bytes(&self.signee.user.0)?.verify_strict(&canonical_signee, &sig)?;
Ok(())