Enforce sorted fields for signed payloads

This commit is contained in:
oxalica 2024-08-29 18:56:52 -04:00
parent 74bd0d42e2
commit cf5d648315
3 changed files with 45 additions and 6 deletions

1
Cargo.lock generated
View file

@ -239,6 +239,7 @@ dependencies = [
"serde-aux",
"serde-constant",
"serde_json",
"syn",
"tokio",
"tokio-stream",
"tower-http",

View file

@ -28,5 +28,8 @@ tracing = "0.1.40"
tracing-subscriber = "0.3.18"
uuid = { version = "1.10.0", features = ["serde", "v4"] }
[dev-dependencies]
syn = { version = "2.0.76", features = ["full", "visit"] }
[workspace]
members = [ "./blahctl" ]

View file

@ -1,5 +1,8 @@
//! NB. All structs here that are part of signee must be lexically sorted, as RFC8785.
//! This is tested by `canonical_fields_sorted`.
//! See: https://www.rfc-editor.org/rfc/rfc8785
//! FIXME: `typ` is still always the first field because of `serde`'s implementation.
use std::fmt;
// NB. All structs here that are part of signee must be lexically sorted, as RFC8785.
use std::time::SystemTime;
use anyhow::{ensure, Context};
@ -29,7 +32,6 @@ impl fmt::Display for UserKey {
#[derive(Debug, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct WithSig<T> {
// sorted
#[serde(with = "hex::serde")]
pub sig: [u8; SIGNATURE_LENGTH],
pub signee: Signee<T>,
@ -38,7 +40,6 @@ pub struct WithSig<T> {
#[derive(Debug, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct Signee<T> {
// sorted
pub nonce: u32,
pub payload: T,
pub timestamp: u64,
@ -82,7 +83,6 @@ impl<T: Serialize> WithSig<T> {
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "typ", rename = "chat")]
pub struct ChatPayload {
// sorted
pub room: Uuid,
pub text: String,
}
@ -92,8 +92,8 @@ pub type ChatItem = WithSig<ChatPayload>;
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "typ", rename = "create_room")]
pub struct CreateRoomPayload {
pub title: String,
pub attrs: RoomAttrs,
pub title: String,
}
/// Proof of room membership for read-access.
@ -107,7 +107,6 @@ pub struct AuthPayload {}
#[serde(deny_unknown_fields, tag = "typ", rename_all = "snake_case")]
pub enum RoomAdminPayload {
AddMember {
// sorted
permission: RoomPermission,
room: Uuid,
user: UserKey,
@ -187,3 +186,39 @@ mod sql_impl {
impl_u64_flag!(ServerPermission, RoomPermission, RoomAttrs);
}
#[cfg(test)]
mod tests {
use std::fmt::Write;
#[derive(Default)]
struct Visitor {
errors: String,
}
impl<'ast> syn::visit::Visit<'ast> for Visitor {
fn visit_fields_named(&mut self, i: &'ast syn::FieldsNamed) {
let fields = i
.named
.iter()
.flat_map(|f| f.ident.clone())
.map(|i| i.to_string())
.collect::<Vec<_>>();
if !fields.windows(2).all(|w| w[0] < w[1]) {
writeln!(self.errors, "unsorted fields: {fields:?}").unwrap();
}
}
}
#[test]
fn canonical_fields_sorted() {
let src = std::fs::read_to_string(file!()).unwrap();
let file = syn::parse_file(&src).unwrap();
let mut v = Visitor::default();
syn::visit::visit_file(&mut v, &file);
if !v.errors.is_empty() {
panic!("{}", v.errors);
}
}
}