diff --git a/Cargo.lock b/Cargo.lock index 0c7b421..b6ecae8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1254,6 +1254,12 @@ version = "0.4.14" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "78b3ae25bc7c8c38cec158d1f2757ee79e9b3740fbc7ccf0e59e4b08d793fa89" +[[package]] +name = "litrs" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b4ce301924b7887e9d637144fdade93f9dfff9b60981d4ac161db09720d39aa5" + [[package]] name = "lock_api" version = "0.4.12" @@ -1472,6 +1478,45 @@ version = "2.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" +[[package]] +name = "phf" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ade2d8b8f33c7333b51bcf0428d37e217e9f32192ae4772156f65063b8ce03dc" +dependencies = [ + "phf_shared", +] + +[[package]] +name = "phf_codegen" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e8d39688d359e6b34654d328e262234662d16cc0f60ec8dcbe5e718709342a5a" +dependencies = [ + "phf_generator", + "phf_shared", +] + +[[package]] +name = "phf_generator" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "48e4cc64c2ad9ebe670cb8fd69dd50ae301650392e81c05f9bfcb2d5bdbc24b0" +dependencies = [ + "phf_shared", + "rand", +] + +[[package]] +name = "phf_shared" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "90fcb95eef784c2ac79119d1dd819e162b5da872ce6f3c3abe1e8ca1c082f72b" +dependencies = [ + "siphasher", + "uncased", +] + [[package]] name = "pin-project" version = "1.1.5" @@ -1771,9 +1816,21 @@ dependencies = [ "fallible-streaming-iterator", "hashlink", "libsqlite3-sys", + "rusqlite-macros", "smallvec", ] +[[package]] +name = "rusqlite-macros" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ecdc5e5d64f172916dfc8a0b0f7876de19b899e7a5f1d5b2c04c722cc78e0e45" +dependencies = [ + "fallible-iterator", + "litrs", + "sqlite3-parser", +] + [[package]] name = "rustc-demangle" version = "0.1.24" @@ -2089,6 +2146,12 @@ dependencies = [ "rand_core", ] +[[package]] +name = "siphasher" +version = "0.3.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38b58827f4464d87d377d175e90bf58eb00fd8716ff0a62f80356b5e61555d0d" + [[package]] name = "slab" version = "0.4.9" @@ -2130,6 +2193,24 @@ dependencies = [ "der", ] +[[package]] +name = "sqlite3-parser" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eb5307dad6cb84730ce8bdefde56ff4cf95fe516972d52e2bbdc8a8cd8f2520b" +dependencies = [ + "bitflags", + "cc", + "fallible-iterator", + "indexmap 2.4.0", + "log", + "memchr", + "phf", + "phf_codegen", + "phf_shared", + "uncased", +] + [[package]] name = "strsim" version = "0.11.1" @@ -2540,6 +2621,15 @@ version = "1.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" +[[package]] +name = "uncased" +version = "0.9.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e1b88fcfe09e89d3866a5c11019378088af2d24c3fbd4f0543f96b479ec90697" +dependencies = [ + "version_check", +] + [[package]] name = "unicode-bidi" version = "0.3.15" diff --git a/blah-types/src/lib.rs b/blah-types/src/lib.rs index 297a180..4d3b5db 100644 --- a/blah-types/src/lib.rs +++ b/blah-types/src/lib.rs @@ -490,6 +490,7 @@ bitflags::bitflags! { #[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] pub struct RoomAttrs: u64 { + // NB. Used by schema. const PUBLIC_READABLE = 1 << 0; const PUBLIC_JOINABLE = 1 << 1; diff --git a/blahd/Cargo.toml b/blahd/Cargo.toml index 6521e62..cb4eabf 100644 --- a/blahd/Cargo.toml +++ b/blahd/Cargo.toml @@ -20,7 +20,7 @@ humantime = "2" parking_lot = "0.12" # Maybe no better performance, just that we hate poisoning. ¯\_(ツ)_/¯ rand = "0.8" reqwest = "0.12" -rusqlite = "0.32" +rusqlite = { version = "0.32", features = ["rusqlite-macros"] } rustix = { version = "0.38", features = ["net"] } sd-notify = "0.4" serde = { version = "1", features = ["derive"] } diff --git a/blahd/schema.sql b/blahd/schema.sql index 1dd4098..9562137 100644 --- a/blahd/schema.sql +++ b/blahd/schema.sql @@ -17,12 +17,6 @@ CREATE TABLE IF NOT EXISTS `user_act_key` ( PRIMARY KEY (`uid`, `act_key`) ) STRICT, WITHOUT ROWID; -CREATE VIEW IF NOT EXISTS `valid_user_act_key` AS - SELECT `act_key`, `user`.* - FROM `user_act_key` - JOIN `user` USING (`uid`) - WHERE unixepoch() < `expire_time`; - -- The highest bit of `rid` will be set for peer chat room. -- So simply comparing it against 0 can filter them out. CREATE TABLE IF NOT EXISTS `room` ( @@ -43,6 +37,10 @@ CREATE UNIQUE INDEX IF NOT EXISTS `ix_peer_chat` ON `room` (`peer1`, `peer2`) WHERE `rid` < 0; +-- RoomAttrs::PUBLIC_READABLE +CREATE INDEX IF NOT EXISTS `ix_public_room` ON `room` (`rid`) + WHERE `attrs` & 1 != 0; + CREATE TABLE IF NOT EXISTS `room_member` ( `rid` INTEGER NOT NULL REFERENCES `room` ON DELETE CASCADE, `uid` INTEGER NOT NULL REFERENCES `user` ON DELETE RESTRICT, @@ -68,3 +66,11 @@ CREATE TABLE IF NOT EXISTS `msg` ( ) STRICT; CREATE INDEX IF NOT EXISTS `room_latest_msg` ON `msg` (`rid` ASC, `cid` DESC); + +-- Temporary views. + +CREATE TEMP VIEW `valid_user_act_key` AS + SELECT `act_key`, `user`.* + FROM `user_act_key` + JOIN `user` USING (`uid`) + WHERE unixepoch() < `expire_time`; diff --git a/blahd/src/database.rs b/blahd/src/database.rs index 5354c8c..de511ff 100644 --- a/blahd/src/database.rs +++ b/blahd/src/database.rs @@ -1,21 +1,29 @@ -use std::borrow::Borrow; -use std::ops::DerefMut; use std::path::PathBuf; -use anyhow::{ensure, Context, Result}; +use anyhow::{ensure, Context}; use axum::http::StatusCode; -use blah_types::{ServerPermission, UserKey}; +use blah_types::identity::UserIdentityDesc; +use blah_types::{ + ChatPayload, Id, MemberPermission, PubKey, RoomAttrs, RoomMetadata, ServerPermission, + SignedChatMsg, Signee, UserKey, WithMsgId, +}; use parking_lot::Mutex; -use rusqlite::{params, Connection, OpenFlags, OptionalExtension}; +use rusqlite::{named_params, params, prepare_cached_and_bind, Connection, OpenFlags, Row}; use serde::Deserialize; use serde_inline_default::serde_inline_default; use crate::ApiError; +#[cfg(test)] +mod tests; + const DEFAULT_DATABASE_PATH: &str = "/var/lib/blahd/db.sqlite"; +const STMT_CACHE_CAPACITY: usize = 24; static INIT_SQL: &str = include_str!("../schema.sql"); +type Result<T, E = ApiError> = std::result::Result<T, E>; + // Simple and stupid version check for now. // `echo -n 'blahd-database-0' | sha256sum | head -c5` || version const APPLICATION_ID: i32 = 0xd9e_8405; @@ -44,15 +52,17 @@ pub struct Database { conn: Mutex<Connection>, } +pub struct Transaction<'db>(rusqlite::Transaction<'db>); + impl Database { /// Use an existing database connection and do no initialization or schema checking. /// This should only be used for testing purpose. - pub fn from_raw(conn: Connection) -> Result<Self> { + pub fn from_raw(conn: Connection) -> anyhow::Result<Self> { conn.pragma_update(None, "foreign_keys", "TRUE")?; Ok(Self { conn: conn.into() }) } - pub fn open(config: &Config) -> Result<Self> { + pub fn open(config: &Config) -> anyhow::Result<Self> { let mut conn = if config.in_memory { Connection::open_in_memory().context("failed to open in-memory database")? } else { @@ -63,13 +73,14 @@ impl Database { Connection::open_with_flags(&config.path, flags) .context("failed to connect database")? }; + conn.set_prepared_statement_cache_capacity(STMT_CACHE_CAPACITY); Self::maybe_init(&mut conn)?; Ok(Self { conn: Mutex::new(conn), }) } - pub fn maybe_init(conn: &mut Connection) -> Result<()> { + pub fn maybe_init(conn: &mut Connection) -> anyhow::Result<()> { // Connection-specific pragmas. conn.pragma_update(None, "journal_mode", "WAL")?; conn.pragma_update(None, "foreign_keys", "TRUE")?; @@ -97,49 +108,491 @@ impl Database { Ok(()) } - pub fn get(&self) -> impl DerefMut<Target = Connection> + '_ { - self.conn.lock() + pub fn with_read<T>(&self, f: impl FnOnce(&Transaction<'_>) -> Result<T>) -> Result<T> { + // TODO: Currently no concurrency is implemented. + self.with_write(f) + } + + pub fn with_write<T>(&self, f: impl FnOnce(&Transaction<'_>) -> Result<T>) -> Result<T> { + let mut conn = self.conn.lock(); + let txn = Transaction(conn.transaction()?); + match f(&txn) { + Ok(v) => { + txn.0.commit()?; + Ok(v) + } + Err(e) => Err(e), + } } } -pub trait ConnectionExt: Borrow<Connection> { - fn get_user(&self, user: &UserKey) -> Result<(i64, ServerPermission), ApiError> { - self.borrow() - .query_row( - r" - SELECT `uid`, `permission` - FROM `valid_user_act_key` - WHERE (`id_key`, `act_key`) = (?, ?) - ", - params![user.id_key, user.act_key], - |row| Ok((row.get(0)?, row.get(1)?)), - ) - .optional()? - .ok_or_else(|| { - error_response!( - StatusCode::NOT_FOUND, - "not_found", - "the user does not exist", - ) - }) - } +fn parse_msg(rid: Id, row: &Row<'_>) -> Result<WithMsgId<SignedChatMsg>> { + Ok(WithMsgId { + cid: row.get("cid")?, + msg: SignedChatMsg { + sig: row.get("sig")?, + signee: Signee { + nonce: row.get("nonce")?, + timestamp: row.get("timestamp")?, + user: UserKey { + id_key: row.get("id_key")?, + act_key: row.get("act_key")?, + }, + payload: ChatPayload { + room: rid, + rich_text: row.get("rich_text")?, + }, + }, + }, + }) } -impl ConnectionExt for Connection {} +fn parse_room_metadata(row: &Row<'_>) -> Result<RoomMetadata> { + use rusqlite::types::ValueRef; -#[test] -fn init_sql_valid() { - let conn = Connection::open_in_memory().unwrap(); - conn.execute_batch(INIT_SQL).unwrap(); + let rid = row.get("rid")?; + let last_msg = (matches!(row.get_ref("cid"), Ok(ValueRef::Integer(_)))) + .then(|| parse_msg(rid, row)) + .transpose()?; + Ok(RoomMetadata { + rid, + title: row.get("title")?, + attrs: row.get("attrs")?, + last_msg, + last_seen_cid: row.get("last_seen_cid").ok().filter(|&cid| cid != Id(0)), + unseen_cnt: row.get("unseen_cnt").ok().filter(|&n| n != 0), + member_permission: row.get("member_perm").ok(), + peer_user: row.get("peer_id_key").ok(), + }) +} - // Instantiate view to check syntax and availability of `unixepoch()`. - // It requires sqlite >= 3.38.0 (2022-02-22) which is not available by default on GitHub CI. - let ret = conn - .query_row( - "SELECT COUNT(*) FROM `valid_user_act_key`", - params![], - |row| row.get::<_, i64>(0), +pub trait TransactionOps { + fn conn(&self) -> &Connection; + + fn get_user(&self, UserKey { id_key, act_key }: &UserKey) -> Result<(i64, ServerPermission)> { + prepare_cached_and_bind!( + self.conn(), + r" + SELECT `uid`, `permission` + FROM `valid_user_act_key` + WHERE (`id_key`, `act_key`) = (:id_key, :act_key) + " ) - .unwrap(); - assert_eq!(ret, 0); + .raw_query() + .next()? + .ok_or_else(|| { + error_response!( + StatusCode::NOT_FOUND, + "not_found", + "the user does not exist", + ) + }) + .and_then(|row| Ok((row.get(0)?, row.get(1)?))) + } + + fn get_user_by_id_key(&self, id_key: &PubKey) -> Result<(i64, ServerPermission)> { + prepare_cached_and_bind!( + self.conn(), + r" + SELECT `uid`, `permission` + FROM `user` + WHERE `id_key` = :id_key + " + ) + .raw_query() + .next()? + .ok_or_else(|| { + error_response!( + StatusCode::NOT_FOUND, + "user_not_found", + "the user does not exists", + ) + }) + .and_then(|row| Ok((row.get(0)?, row.get(1)?))) + } + + fn get_room_member( + &self, + rid: Id, + UserKey { id_key, act_key }: &UserKey, + ) -> Result<(i64, MemberPermission, Id)> { + prepare_cached_and_bind!( + self.conn(), + r" + SELECT `uid`, `room_member`.`permission`, `last_seen_cid` + FROM `room_member` + JOIN `valid_user_act_key` USING (`uid`) + WHERE (`rid`, `id_key`, `act_key`) = (:rid, :id_key, :act_key) + " + ) + .raw_query() + .next()? + .ok_or_else(|| { + error_response!( + StatusCode::NOT_FOUND, + "room_not_found", + "the room does not exist or user is not a room member", + ) + }) + .and_then(|row| Ok((row.get(0)?, row.get(1)?, row.get(2)?))) + } + + fn get_room_having(&self, rid: Id, filter: RoomAttrs) -> Result<(RoomAttrs, Option<String>)> { + prepare_cached_and_bind!( + self.conn(), + r" + SELECT `attrs`, `title` + FROM `room` + WHERE `rid` = :rid + " + ) + .raw_query() + .next()? + .map(|row| { + Ok::<_, rusqlite::Error>(( + row.get::<_, RoomAttrs>(0)?, + row.get::<_, Option<String>>(1)?, + )) + }) + .transpose()? + .filter(|(attrs, _)| attrs.contains(filter)) + .ok_or_else(|| { + error_response!( + StatusCode::NOT_FOUND, + "room_not_found", + "the room does not exist" + ) + }) + } + + // FIXME: Eliminate this. + // Currently broadcasting msgs requires traversing over all members. + fn list_room_members(&self, rid: Id) -> Result<Vec<i64>> { + prepare_cached_and_bind!( + self.conn(), + r" + SELECT `uid` + FROM `room_member` + WHERE `rid` = :rid + " + ) + .raw_query() + .mapped(|row| row.get::<_, i64>(0)) + .collect::<rusqlite::Result<Vec<_>>>() + .map_err(Into::into) + } + + fn list_public_rooms(&self, start_rid: Id, page_len: usize) -> Result<Vec<RoomMetadata>> { + // Attribute check must be written in the SQL literal so the query planer + // can successfully pick the conditional index. + const _: () = assert!(RoomAttrs::PUBLIC_READABLE.bits() == 1); + prepare_cached_and_bind!( + self.conn(), + r" + SELECT `rid`, `title`, `attrs`, + MAX(`cid`) AS `cid`, `timestamp`, `nonce`, `sig`, `rich_text`, + `last_author`.`id_key`, `msg`.`act_key` + FROM `room` INDEXED BY `ix_public_room` + LEFT JOIN `msg` USING (`rid`) + LEFT JOIN `user` AS `last_author` USING (`uid`) + WHERE `attrs` & 1 != 0 AND + `rid` > :start_rid + GROUP BY `rid` + ORDER BY `rid` ASC + LIMIT :page_len + " + ) + .raw_query() + .and_then(parse_room_metadata) + .collect() + } + + fn list_joined_rooms( + &self, + uid: i64, + start_rid: Id, + page_len: usize, + ) -> Result<Vec<RoomMetadata>> { + prepare_cached_and_bind!( + self.conn(), + r" + SELECT + `rid`, `title`, `attrs`, `last_seen_cid`, `room_member`.`permission` AS `member_perm`, + MAX(`cid`) AS `cid`, `timestamp`, `nonce`, `sig`, `rich_text`, + `last_author`.`id_key`, `msg`.`act_key`, + `peer_user`.`id_key` AS `peer_id_key` + FROM `room_member` INDEXED BY `ix_member_room` + JOIN `room` USING (`rid`) + LEFT JOIN `msg` USING (`rid`) + LEFT JOIN `user` AS `last_author` ON (`last_author`.`uid` = `msg`.`uid`) + LEFT JOIN `user` AS `peer_user` ON + (`peer_user`.`uid` = `room`.`peer1` + `room`.`peer2` - :uid) + WHERE `room_member`.`uid` = :uid AND + `rid` > :start_rid + GROUP BY `rid` + ORDER BY `rid` ASC + LIMIT :page_len + " + ) + .raw_query() + .and_then(parse_room_metadata) + .collect() + } + + fn list_unseen_rooms( + &self, + uid: i64, + start_rid: Id, + page_len: usize, + ) -> Result<Vec<RoomMetadata>> { + // FIXME: Limit `unseen_cnt` counting. + prepare_cached_and_bind!( + self.conn(), + r" + SELECT + `rid`, `title`, `attrs`, `last_seen_cid`, `room_member`.`permission` AS `member_perm`, + `cid`, `timestamp`, `nonce`, `sig`, `rich_text`, + `last_author`.`id_key`, `msg`.`act_key`, + `peer_user`.`id_key` AS `peer_id_key`, + (SELECT COUNT(*) + FROM `msg` AS `unseen_msg` + WHERE `unseen_msg`.`rid` = `room`.`rid` AND + `last_seen_cid` < `unseen_msg`.`cid`) AS `unseen_cnt` + FROM `room_member` INDEXED BY `ix_member_room` + JOIN `room` USING (`rid`) + LEFT JOIN `msg` USING (`rid`) + LEFT JOIN `user` AS `last_author` ON (`last_author`.`uid` = `msg`.`uid`) + LEFT JOIN `user` AS `peer_user` ON + (`peer_user`.`uid` = `room`.`peer1` + `room`.`peer2` - :uid) + WHERE `room_member`.`uid` = :uid AND + `rid` > :start_rid AND + `cid` > `last_seen_cid` + GROUP BY `rid` HAVING `cid` IS MAX(`cid`) + ORDER BY `rid` ASC + LIMIT :page_len + " + ) + .raw_query() + .and_then(parse_room_metadata) + .collect() + } + + fn list_room_msgs( + &self, + rid: Id, + after_cid: Id, + before_cid: Id, + page_len: usize, + ) -> Result<Vec<WithMsgId<SignedChatMsg>>> { + prepare_cached_and_bind!( + self.conn(), + r" + SELECT `cid`, `timestamp`, `nonce`, `sig`, `id_key`, `act_key`, `sig`, `rich_text` + FROM `msg` + JOIN `user` USING (`uid`) + WHERE `rid` = :rid AND + :after_cid < `cid` AND + `cid` < :before_cid + ORDER BY `cid` DESC + LIMIT :page_len + " + ) + .raw_query() + .and_then(|row| parse_msg(rid, row)) + .collect() + } + + fn create_user( + &self, + id_desc: &UserIdentityDesc, + id_desc_json: &str, + fetch_time: u64, + ) -> Result<i64> { + let conn = self.conn(); + let id_key = &id_desc.id_key; + let uid = prepare_cached_and_bind!( + conn, + r" + INSERT INTO `user` (`id_key`, `last_fetch_time`, `id_desc`) + VALUES (:id_key, :fetch_time, :id_desc_json) + ON CONFLICT (`id_key`) DO UPDATE SET + `last_fetch_time` = excluded.`last_fetch_time`, + `id_desc` = excluded.`id_desc` + WHERE `last_fetch_time` < :fetch_time + RETURNING `uid` + " + ) + .raw_query() + .next()? + .ok_or_else(|| { + error_response!( + StatusCode::CONFLICT, + "conflict", + "racing register, please try again later", + ) + }) + .and_then(|row| Ok(row.get::<_, i64>(0)?))?; + + // Delete existing act_keys. + prepare_cached_and_bind!( + conn, + r" + DELETE FROM `user_act_key` + WHERE `uid` = :uid + " + ) + .raw_execute()?; + + let mut stmt = conn.prepare_cached( + r" + INSERT INTO `user_act_key` (`uid`, `act_key`, `expire_time`) + VALUES (:uid, :act_key, :expire_time) + ", + )?; + for kdesc in &id_desc.act_keys { + stmt.execute(named_params! { + ":uid": uid, + ":act_key": kdesc.signee.payload.act_key, + // FIXME: Other `u64` that will be stored in database should also be range checked. + ":expire_time": kdesc.signee.payload.expire_time.min(i64::MAX as _), + })?; + } + + Ok(uid) + } + + fn create_group(&self, rid: Id, title: &str, attrs: RoomAttrs) -> Result<()> { + prepare_cached_and_bind!( + self.conn(), + r" + INSERT INTO `room` (`rid`, `title`, `attrs`) + VALUES (:rid, :title, :attrs) + " + ) + .raw_execute()?; + Ok(()) + } + + fn create_peer_room_with_members( + &self, + rid: Id, + attrs: RoomAttrs, + src_uid: i64, + tgt_uid: i64, + ) -> Result<()> { + assert!(attrs.contains(RoomAttrs::PEER_CHAT)); + let conn = self.conn(); + let (p1, p2) = if src_uid <= tgt_uid { + (src_uid, tgt_uid) + } else { + (tgt_uid, src_uid) + }; + let updated = prepare_cached_and_bind!( + conn, + r" + INSERT INTO `room` (`rid`, `attrs`, `peer1`, `peer2`) + VALUES (:rid, :attrs, :p1, :p2) + ON CONFLICT (`peer1`, `peer2`) WHERE `rid` < 0 DO NOTHING + " + ) + .raw_execute()?; + if updated == 0 { + return Err(error_response!( + StatusCode::CONFLICT, + "exists", + "room already exists" + )); + } + + // TODO: Limit permission of the src user? + let perm = MemberPermission::MAX_PEER_CHAT; + prepare_cached_and_bind!( + conn, + r" + INSERT INTO `room_member` (`rid`, `uid`, `permission`) + VALUES (:rid, :src_uid, :perm), (:rid, :tgt_uid, :perm) + " + ) + .raw_execute()?; + Ok(()) + } + + fn add_room_member(&self, rid: Id, uid: i64, perm: MemberPermission) -> Result<()> { + let updated = prepare_cached_and_bind!( + self.conn(), + r" + INSERT INTO `room_member` (`rid`, `uid`, `permission`) + VALUES (:rid, :uid, :perm) + ON CONFLICT (`rid`, `uid`) DO NOTHING + " + ) + .raw_execute()?; + if updated != 1 { + return Err(error_response!( + StatusCode::CONFLICT, + "exists", + "the user already joined the room", + )); + } + Ok(()) + } + + fn remove_room_member(&self, rid: Id, uid: i64) -> Result<bool> { + // TODO: Check if it is the last member? + let updated = prepare_cached_and_bind!( + self.conn(), + r" + DELETE FROM `room_member` + WHERE (`rid`, `uid`) = (:rid, :uid) + " + ) + .raw_execute()?; + Ok(updated == 1) + } + + fn add_room_chat_msg(&self, rid: Id, uid: i64, cid: Id, chat: &SignedChatMsg) -> Result<()> { + let conn = self.conn(); + let act_key = &chat.signee.user.act_key; + let timestamp = chat.signee.timestamp; + let nonce = chat.signee.nonce; + let rich_text = &chat.signee.payload.rich_text; + let sig = &chat.sig; + prepare_cached_and_bind!( + conn, + r" + INSERT INTO `msg` (`cid`, `rid`, `uid`, `act_key`, `timestamp`, `nonce`, `sig`, `rich_text`) + VALUES (:cid, :rid, :uid, :act_key, :timestamp, :nonce, :sig, :rich_text) + " + ) + .raw_execute()?; + Ok(()) + } + + fn mark_room_msg_seen(&self, rid: Id, uid: i64, cid: Id) -> Result<()> { + // TODO: Validate `cid`? + let updated = prepare_cached_and_bind!( + self.conn(), + r" + UPDATE `room_member` + SET `last_seen_cid` = MAX(`last_seen_cid`, :cid) + WHERE (`rid`, `uid`) = (:rid, :uid) + " + ) + .raw_execute()?; + if updated != 1 { + return Err(error_response!( + StatusCode::NOT_FOUND, + "room_not_found", + "the room does not exist or the user is not a room member", + )); + } + + Ok(()) + } +} + +impl TransactionOps for Transaction<'_> { + fn conn(&self) -> &Connection { + &self.0 + } } diff --git a/blahd/src/database/tests.rs b/blahd/src/database/tests.rs new file mode 100644 index 0000000..64bc2fc --- /dev/null +++ b/blahd/src/database/tests.rs @@ -0,0 +1,31 @@ +#![expect(clippy::print_stdout, reason = "allowed in tests for debugging")] +use super::*; + +#[test] +fn init_sql_valid() { + let conn = Connection::open_in_memory().unwrap(); + conn.execute_batch(INIT_SQL).unwrap(); + + // Instantiate view to check syntax and availability of `unixepoch()`. + // It requires sqlite >= 3.38.0 (2022-02-22) which is not available by default on GitHub CI. + let ret = conn + .query_row( + "SELECT COUNT(*) FROM `valid_user_act_key`", + params![], + |row| row.get::<_, i64>(0), + ) + .unwrap(); + assert_eq!(ret, 0); +} + +#[test] +fn stmt_cache_capacity() { + let src = std::fs::read_to_string("src/database.rs").unwrap(); + let sql_cnt = src.matches("prepare_cached_and_bind!").count(); + println!("found {sql_cnt} SQLs"); + assert_ne!(sql_cnt, 0); + assert!( + sql_cnt <= STMT_CACHE_CAPACITY, + "stmt cache capacity {STMT_CACHE_CAPACITY} is too small, found {sql_cnt} SQLs", + ); +} diff --git a/blahd/src/event.rs b/blahd/src/event.rs index ce8ca81..f178db8 100644 --- a/blahd/src/event.rs +++ b/blahd/src/event.rs @@ -20,7 +20,7 @@ use tokio::sync::broadcast; use tokio_stream::wrappers::errors::BroadcastStreamRecvError; use tokio_stream::wrappers::BroadcastStream; -use crate::database::ConnectionExt; +use crate::database::TransactionOps; use crate::AppState; #[derive(Debug, Deserialize)] @@ -145,8 +145,7 @@ pub async fn handle_ws(st: Arc<AppState>, ws: &mut WebSocket) -> Result<Infallib let (uid, _) = st .db - .get() - .get_user(&auth.signee.user) + .with_read(|txn| txn.get_user(&auth.signee.user)) .map_err(|err| anyhow!("{}", err.message))?; // FIXME: Consistency of id's sign. uid as u64 diff --git a/blahd/src/lib.rs b/blahd/src/lib.rs index fa71ccf..f901de0 100644 --- a/blahd/src/lib.rs +++ b/blahd/src/lib.rs @@ -12,15 +12,14 @@ use axum::{Json, Router}; use axum_extra::extract::WithRejection as R; use blah_types::{ ChatPayload, CreateGroup, CreatePeerChat, CreateRoomPayload, Id, MemberPermission, RoomAdminOp, - RoomAdminPayload, RoomAttrs, RoomMetadata, ServerPermission, Signed, SignedChatMsg, Signee, - UserKey, UserRegisterPayload, WithMsgId, X_BLAH_DIFFICULTY, X_BLAH_NONCE, + RoomAdminPayload, RoomAttrs, RoomMetadata, ServerPermission, Signed, SignedChatMsg, UserKey, + UserRegisterPayload, WithMsgId, X_BLAH_DIFFICULTY, X_BLAH_NONCE, }; -use database::ConnectionExt; +use database::{Transaction, TransactionOps}; use ed25519_dalek::SIGNATURE_LENGTH; use id::IdExt; use middleware::{Auth, MaybeAuth, ResultExt as _, SignedJson}; use parking_lot::Mutex; -use rusqlite::{named_params, params, Connection, OptionalExtension, Row, ToSql}; use serde::{Deserialize, Deserializer, Serialize}; use serde_inline_default::serde_inline_default; use url::Url; @@ -188,25 +187,13 @@ async fn user_get( let ret = (|| { match auth.into_optional()? { None => None, - Some(user) => st - .db - .get() - .query_row( - " - SELECT 1 - FROM `valid_user_act_key` - WHERE (`id_key`, `act_key`) = (?, ?) - ", - params![user.id_key, user.act_key], - |_| Ok(()), - ) - .optional()?, + Some(user) => st.db.with_read(|txn| txn.get_user(&user)).ok(), } .ok_or_else(|| error_response!(StatusCode::NOT_FOUND, "not_found", "user does not exist")) })(); match ret { - Ok(()) => Ok(StatusCode::NO_CONTENT), + Ok(_) => Ok(StatusCode::NO_CONTENT), Err(err) => Err((st.register.challenge_headers(), err)), } } @@ -235,7 +222,7 @@ struct ListRoomParams { top: Option<NonZeroUsize>, } -#[derive(Debug, Deserialize)] +#[derive(Debug, Clone, Copy, Deserialize)] #[serde(rename_all = "snake_case")] enum ListRoomFilter { /// List all public rooms. @@ -243,6 +230,7 @@ enum ListRoomFilter { /// List joined rooms (authentication required). Joined, /// List all joined rooms with unseen messages (authentication required). + // TODO: Is this really useful, given most user keep messages unread forever? Unseen, } @@ -259,143 +247,21 @@ async fn room_list( let page_len = pagination.effective_page_len(&st); let start_rid = pagination.skip_token.unwrap_or(Id::MIN); - let query = |sql: &str, params: &[(&str, &dyn ToSql)]| -> Result<RoomList, ApiError> { - let rooms = st - .db - .get() - .prepare(sql)? - .query_map(params, |row| { - // TODO: Extract this into a function. - let rid = row.get("rid")?; - let last_msg = row - .get::<_, Option<Id>>("cid")? - .map(|cid| { - Ok::<_, rusqlite::Error>(WithMsgId { - cid, - msg: SignedChatMsg { - sig: row.get("sig")?, - signee: Signee { - nonce: row.get("nonce")?, - timestamp: row.get("timestamp")?, - user: UserKey { - act_key: row.get("act_key")?, - id_key: row.get("id_key")?, - }, - payload: ChatPayload { - rich_text: row.get("rich_text")?, - room: rid, - }, - }, - }, - }) - }) - .transpose()?; - Ok(RoomMetadata { - rid, - title: row.get("title")?, - attrs: row.get("attrs")?, - last_msg, - last_seen_cid: Some(row.get::<_, Id>("last_seen_cid")?) - .filter(|cid| cid.0 != 0), - unseen_cnt: row.get("unseen_cnt").ok(), - member_permission: row.get("member_perm").ok(), - peer_user: row.get("peer_id_key").ok(), - }) - })? - .collect::<Result<Vec<_>, _>>()?; - let skip_token = - (rooms.len() == page_len).then(|| rooms.last().expect("page must not be empty").rid); - Ok(RoomList { rooms, skip_token }) - }; - - match params.filter { - ListRoomFilter::Public => query( - r" - SELECT `rid`, `title`, `attrs`, 0 AS `last_seen_cid`, - `cid`, `timestamp`, `nonce`, `sig`, `rich_text`, - `last_author`.`id_key`, `msg`.`act_key` - FROM `room` - LEFT JOIN `msg` USING (`rid`) - LEFT JOIN `user` AS `last_author` USING (`uid`) - WHERE `rid` > :start_rid AND - (`attrs` & :perm) = :perm - GROUP BY `rid` HAVING `cid` IS MAX(`cid`) - ORDER BY `rid` ASC - LIMIT :page_len - ", - named_params! { - ":start_rid": start_rid, - ":page_len": page_len, - ":perm": RoomAttrs::PUBLIC_READABLE, - }, - ), + let rooms = st.db.with_read(|txn| match params.filter { + ListRoomFilter::Public => txn.list_public_rooms(start_rid, page_len), ListRoomFilter::Joined => { - let user = auth?.0; - query( - r" - SELECT - `rid`, `title`, `attrs`, `last_seen_cid`, `room_member`.`permission` AS `member_perm`, - `cid`, `timestamp`, `nonce`, `sig`, `rich_text`, - `last_author`.`id_key`, `msg`.`act_key`, - `peer_user`.`id_key` AS `peer_id_key` - FROM `valid_user_act_key` AS `me` - JOIN `room_member` USING (`uid`) - JOIN `room` USING (`rid`) - LEFT JOIN `msg` USING (`rid`) - LEFT JOIN `user` AS `last_author` ON (`last_author`.`uid` = `msg`.`uid`) - LEFT JOIN `user` AS `peer_user` ON - (`peer_user`.`uid` = `room`.`peer1` + `room`.`peer2` - `me`.`uid`) - WHERE (`me`.`id_key`, `me`.`act_key`) = (:id_key, :act_key) AND - `rid` > :start_rid - GROUP BY `rid` HAVING `cid` IS MAX(`cid`) - ORDER BY `rid` ASC - LIMIT :page_len - ", - named_params! { - ":start_rid": start_rid, - ":page_len": page_len, - ":id_key": user.id_key, - ":act_key": user.act_key, - }, - ) + let (uid, _) = txn.get_user(&auth?.0)?; + txn.list_joined_rooms(uid, start_rid, page_len) } ListRoomFilter::Unseen => { - let user = auth?.0; - query( - r" - SELECT - `rid`, `title`, `attrs`, `last_seen_cid`, `room_member`.`permission` AS `member_perm`, - `cid`, `timestamp`, `nonce`, `sig`, `rich_text`, - `last_author`.`id_key`, `msg`.`act_key`, - `peer_user`.`id_key` AS `peer_id_key`, - (SELECT COUNT(*) - FROM `msg` AS `unseen_msg` - WHERE `unseen_msg`.`rid` = `room`.`rid` AND - `last_seen_cid` < `unseen_msg`.`cid`) AS `unseen_cnt` - FROM `valid_user_act_key` AS `me` - JOIN `room_member` USING (`uid`) - JOIN `room` USING (`rid`) - LEFT JOIN `msg` USING (`rid`) - LEFT JOIN `user` AS `last_author` ON (`last_author`.`uid` = `msg`.`uid`) - LEFT JOIN `user` AS `peer_user` ON - (`peer_user`.`uid` = `room`.`peer1` + `room`.`peer2` - `me`.`uid`) - WHERE (`me`.`id_key`, `me`.`act_key`) = (:id_key, :act_key) AND - `rid` > :start_rid AND - `cid` > `last_seen_cid` - GROUP BY `rid` HAVING `cid` IS MAX(`cid`) - ORDER BY `rid` ASC - LIMIT :page_len - ", - named_params! { - ":start_rid": start_rid, - ":page_len": page_len, - ":id_key": user.id_key, - ":act_key": user.act_key, - }, - ) + let (uid, _) = txn.get_user(&auth?.0)?; + txn.list_unseen_rooms(uid, start_rid, page_len) } - } - .map(Json) + })?; + + let skip_token = + (rooms.len() == page_len).then(|| rooms.last().expect("page must not be empty").rid); + Ok(Json(RoomList { rooms, skip_token })) } async fn room_create( @@ -423,55 +289,20 @@ async fn room_create_group( )); } - let conn = st.db.get(); - let (uid, _perm) = conn - .query_row( - r" - SELECT `uid`, `permission` - FROM `valid_user_act_key` - WHERE (`id_key`, `act_key`) = (?, ?) - ", - params![user.id_key, user.act_key], - |row| { - Ok(( - row.get::<_, i64>("uid")?, - row.get::<_, ServerPermission>("permission")?, - )) - }, - ) - .optional()? - .filter(|(_, perm)| perm.contains(ServerPermission::CREATE_ROOM)) - .ok_or_else(|| { - error_response!( + let rid = st.db.with_write(|conn| { + let (uid, perm) = conn.get_user(user)?; + if !perm.contains(ServerPermission::CREATE_ROOM) { + return Err(error_response!( StatusCode::FORBIDDEN, "permission_denied", - "the user does not exist or does not have permission to create room", - ) - })?; - - let rid = Id::gen(); - conn.execute( - r" - INSERT INTO `room` (`rid`, `title`, `attrs`) - VALUES (:rid, :title, :attrs) - ", - named_params! { - ":rid": rid, - ":title": op.title, - ":attrs": op.attrs, - }, - )?; - conn.execute( - r" - INSERT INTO `room_member` (`rid`, `uid`, `permission`) - VALUES (:rid, :uid, :perm) - ", - named_params! { - ":rid": rid, - ":uid": uid, - ":perm": MemberPermission::ALL, - }, - )?; + "the user does not have permission to create room", + )); + } + let rid = Id::gen(); + conn.create_group(rid, &op.title, op.attrs)?; + conn.add_room_member(rid, uid, MemberPermission::ALL)?; + Ok(rid) + })?; Ok(Json(rid)) } @@ -491,72 +322,24 @@ async fn room_create_peer_chat( } // TODO: Access control and throttling. - - let mut conn = st.db.get(); - let txn = conn.transaction()?; - let (src_uid, _) = txn.get_user(src_user)?; - let (tgt_uid, _) = txn - .query_row( - r" - SELECT `uid`, `permission` - FROM `user` - WHERE `id_key` = ? - ", - params![tgt_user_id_key], - |row| Ok((row.get::<_, i64>(0)?, row.get::<_, ServerPermission>(1)?)), - ) - .optional()? - .filter(|(_, perm)| perm.contains(ServerPermission::ACCEPT_PEER_CHAT)) - .ok_or_else(|| { - error_response!( - StatusCode::NOT_FOUND, - "not_found", - "peer user does not exist or disallows peer chat", - ) - })?; - - let mut peers = [src_uid, tgt_uid]; - peers.sort(); - let rid = Id::gen_peer_chat_rid(); - let updated = txn.execute( - r" - INSERT INTO `room` (`rid`, `attrs`, `peer1`, `peer2`) - VALUES (:rid, :attrs, :peer1, :peer2) - ON CONFLICT (`peer1`, `peer2`) WHERE `rid` < 0 DO NOTHING - ", - named_params! { - ":rid": rid, - ":attrs": RoomAttrs::PEER_CHAT, - ":peer1": peers[0], - ":peer2": peers[1], - }, - )?; - if updated == 0 { - return Err(error_response!( - StatusCode::CONFLICT, - "exists", - "room already exists" - )); - } - - { - let mut stmt = txn.prepare( - r" - INSERT INTO `room_member` (`rid`, `uid`, `permission`) - VALUES (:rid, :uid, :perm) - ", - )?; - // TODO: Limit permission of the src user? - for uid in peers { - stmt.execute(named_params! { - ":rid": rid, - ":uid": uid, - ":perm": MemberPermission::MAX_PEER_CHAT, + let rid = st.db.with_write(|txn| { + let (src_uid, _) = txn.get_user(src_user)?; + let (tgt_uid, _) = txn + .get_user_by_id_key(&tgt_user_id_key) + .ok() + .filter(|(_, perm)| perm.contains(ServerPermission::ACCEPT_PEER_CHAT)) + .ok_or_else(|| { + error_response!( + StatusCode::NOT_FOUND, + "peer_user_not_found", + "peer user does not exist or disallows peer chat", + ) })?; - } - } + let rid = Id::gen_peer_chat_rid(); + txn.create_peer_room_with_members(rid, RoomAttrs::PEER_CHAT, src_uid, tgt_uid)?; + Ok(rid) + })?; - txn.commit()?; Ok(Json(rid)) } @@ -598,11 +381,14 @@ async fn room_msg_list( R(Query(pagination), _): RE<Query<Pagination>>, auth: MaybeAuth, ) -> Result<Json<RoomMsgs>, ApiError> { - let (msgs, skip_token) = { - let conn = st.db.get(); - get_room_if_readable(&conn, rid, auth.into_optional()?.as_ref(), |_row| Ok(()))?; - query_room_msgs(&st, &conn, rid, pagination)? - }; + let (msgs, skip_token) = st.db.with_read(|txn| { + if let Some(user) = auth.into_optional()? { + txn.get_room_member(rid, &user)?; + } else { + txn.get_room_having(rid, RoomAttrs::PUBLIC_READABLE)?; + } + query_room_msgs(&st, txn, rid, pagination) + })?; Ok(Json(RoomMsgs { msgs, skip_token })) } @@ -611,12 +397,16 @@ async fn room_get_metadata( R(Path(rid), _): RE<Path<Id>>, auth: MaybeAuth, ) -> Result<Json<RoomMetadata>, ApiError> { - let conn = st.db.get(); - let (title, attrs) = get_room_if_readable(&conn, rid, auth.into_optional()?.as_ref(), |row| { - Ok(( - row.get::<_, Option<String>>("title")?, - row.get::<_, RoomAttrs>("attrs")?, - )) + let (attrs, title) = st.db.with_read(|txn| { + let filter = if auth + .into_optional()? + .is_some_and(|user| txn.get_room_member(rid, &user).is_ok()) + { + RoomAttrs::empty() + } else { + RoomAttrs::PUBLIC_READABLE + }; + txn.get_room_having(rid, filter) })?; Ok(Json(RoomMetadata { @@ -638,12 +428,14 @@ async fn room_get_feed( R(Path(rid), _): RE<Path<Id>>, R(Query(pagination), _): RE<Query<Pagination>>, ) -> Result<impl IntoResponse, ApiError> { - let title; - let (msgs, skip_token) = { - let conn = st.db.get(); - title = get_room_if_readable(&conn, rid, None, |row| row.get::<_, String>("title"))?; - query_room_msgs(&st, &conn, rid, pagination)? - }; + let (title, msgs, skip_token) = st.db.with_read(|txn| { + let (attrs, title) = txn.get_room_having(rid, RoomAttrs::PUBLIC_READABLE)?; + // Sanity check. + assert!(!attrs.contains(RoomAttrs::PEER_CHAT)); + let title = title.expect("public room must have title"); + let (msgs, skip_token) = query_room_msgs(&st, txn, rid, pagination)?; + Ok((title, msgs, skip_token)) + })?; let items = msgs .into_iter() @@ -736,101 +528,23 @@ struct FeedItemExtra { sig: [u8; SIGNATURE_LENGTH], } -fn get_room_if_readable<T>( - conn: &rusqlite::Connection, - rid: Id, - user: Option<&UserKey>, - f: impl FnOnce(&Row<'_>) -> rusqlite::Result<T>, -) -> Result<T, ApiError> { - let (id_key, act_key) = match user { - Some(keys) => (Some(&keys.id_key), Some(&keys.act_key)), - None => (None, None), - }; - - conn.query_row( - r" - SELECT `title`, `attrs` - FROM `room` - WHERE `rid` = :rid AND - ((`attrs` & :perm) = :perm OR - EXISTS(SELECT 1 - FROM `room_member` - JOIN `valid_user_act_key` USING (`uid`) - WHERE `room_member`.`rid` = `room`.`rid` AND - (`id_key`, `act_key`) = (:id_key, :act_key))) - ", - named_params! { - ":rid": rid, - ":perm": RoomAttrs::PUBLIC_READABLE, - ":id_key": id_key, - ":act_key": act_key, - }, - f, - ) - .optional()? - .ok_or_else(|| { - error_response!( - StatusCode::NOT_FOUND, - "not_found", - "the room does not exist or the user is not a room member", - ) - }) -} - /// Get room messages with pagination parameters, /// return a page of messages and the next `skip_token` if this is not the last page. fn query_room_msgs( st: &AppState, - conn: &Connection, + txn: &Transaction<'_>, rid: Id, pagination: Pagination, ) -> Result<(Vec<WithMsgId<SignedChatMsg>>, Option<Id>), ApiError> { let page_len = pagination.effective_page_len(st); - let mut stmt = conn.prepare( - r" - SELECT `cid`, `timestamp`, `nonce`, `sig`, `id_key`, `act_key`, `sig`, `rich_text` - FROM `msg` - JOIN `user` USING (`uid`) - WHERE `rid` = :rid AND - :after_cid < `cid` AND - `cid` < :before_cid - ORDER BY `cid` DESC - LIMIT :limit - ", + let msgs = txn.list_room_msgs( + rid, + pagination.until_token.unwrap_or(Id::MIN), + pagination.skip_token.unwrap_or(Id::MAX), + page_len, )?; - let msgs = stmt - .query_and_then( - named_params! { - ":rid": rid, - ":after_cid": pagination.until_token.unwrap_or(Id::MIN), - ":before_cid": pagination.skip_token.unwrap_or(Id::MAX), - ":limit": page_len, - }, - |row| { - Ok(WithMsgId { - cid: row.get("cid")?, - msg: SignedChatMsg { - sig: row.get("sig")?, - signee: Signee { - nonce: row.get("nonce")?, - timestamp: row.get("timestamp")?, - user: UserKey { - id_key: row.get("id_key")?, - act_key: row.get("act_key")?, - }, - payload: ChatPayload { - room: rid, - rich_text: row.get("rich_text")?, - }, - }, - }, - }) - }, - )? - .collect::<rusqlite::Result<Vec<_>>>()?; let skip_token = (msgs.len() == page_len).then(|| msgs.last().expect("page must not be empty").cid); - Ok((msgs, skip_token)) } @@ -847,38 +561,8 @@ async fn room_msg_post( )); } - let (cid, txs) = { - let conn = st.db.get(); - let (uid, perm) = conn - .query_row( - r" - SELECT `uid`, `room_member`.`permission` - FROM `room_member` - JOIN `valid_user_act_key` USING (`uid`) - WHERE `rid` = :rid AND - (`id_key`, `act_key`) = (:id_key, :act_key) - ", - named_params! { - ":rid": rid, - ":id_key": &chat.signee.user.id_key, - ":act_key": &chat.signee.user.act_key, - }, - |row| { - Ok(( - row.get::<_, u64>("uid")?, - row.get::<_, MemberPermission>("permission")?, - )) - }, - ) - .optional()? - .ok_or_else(|| { - error_response!( - StatusCode::NOT_FOUND, - "not_found", - "the room does not exist or the user is not a room member", - ) - })?; - + let (cid, members) = st.db.with_write(|txn| { + let (uid, perm, ..) = txn.get_room_member(rid, &chat.signee.user)?; if !perm.contains(MemberPermission::POST_CHAT) { return Err(error_response!( StatusCode::FORBIDDEN, @@ -888,52 +572,26 @@ async fn room_msg_post( } let cid = Id::gen(); - conn.execute( - r" - INSERT INTO `msg` (`cid`, `rid`, `uid`, `act_key`, `timestamp`, `nonce`, `sig`, `rich_text`) - VALUES (:cid, :rid, :uid, :act_key, :timestamp, :nonce, :sig, :rich_text) - ", - named_params! { - ":cid": cid, - ":rid": rid, - ":uid": uid, - ":act_key": chat.signee.user.act_key, - ":timestamp": chat.signee.timestamp, - ":nonce": chat.signee.nonce, - ":rich_text": &chat.signee.payload.rich_text, - ":sig": chat.sig, - }, - )?; + txn.add_room_chat_msg(rid, uid, cid, &chat)?; + let members = txn.list_room_members(rid)?; + Ok((cid, members)) + })?; - // FIXME: Optimize this to not traverses over all members. - let mut stmt = conn.prepare( - r" - SELECT `uid` - FROM `room_member` - WHERE `rid` = :rid - ", - )?; - let listeners = st.event.user_listeners.lock(); - let txs = stmt - .query_map(params![rid], |row| row.get::<_, u64>(0))? - .filter_map(|ret| match ret { - Ok(uid) => listeners.get(&uid).map(|tx| Ok(tx.clone())), - Err(err) => Some(Err(err)), - }) - .collect::<Result<Vec<_>, _>>()?; - - (cid, txs) - }; - - if !txs.is_empty() { - tracing::debug!("broadcasting event to {} clients", txs.len()); - let chat = Arc::new(chat); - for tx in txs { - if let Err(err) = tx.send(chat.clone()) { - tracing::debug!(%err, "failed to broadcast event"); + let chat = Arc::new(chat); + // FIXME: Optimize this to not traverses over all members. + let listeners = st.event.user_listeners.lock(); + let mut cnt = 0usize; + for uid in members { + // FIXME: u64 vs i64. + if let Some(tx) = listeners.get(&(uid as u64)) { + if tx.send(chat.clone()).is_ok() { + cnt += 1; } } } + if cnt != 0 { + tracing::debug!("broadcasted event to {cnt} clients"); + } Ok(Json(cid)) } @@ -997,124 +655,36 @@ async fn room_join( user: &UserKey, permission: MemberPermission, ) -> Result<(), ApiError> { - let mut conn = st.db.get(); - let txn = conn.transaction()?; - let (uid, _) = txn.get_user(user)?; - txn.query_row( - r" - SELECT `attrs` - FROM `room` - WHERE `rid` = ? - ", - params![rid], - |row| row.get::<_, RoomAttrs>(0), - ) - .optional()? - .filter(|attrs| attrs.contains(RoomAttrs::PUBLIC_JOINABLE)) - .ok_or_else(|| { - error_response!( - StatusCode::NOT_FOUND, - "not_found", - "the room does not exist or the user is not allowed to join the room", - ) - })?; - - let updated = txn.execute( - r" - INSERT INTO `room_member` (`rid`, `uid`, `permission`) - SELECT :rid, :uid, :perm - ON CONFLICT (`rid`, `uid`) DO NOTHING - ", - named_params! { - ":rid": rid, - ":uid": uid, - ":perm": permission, - }, - )?; - if updated == 0 { - return Err(error_response!( - StatusCode::CONFLICT, - "exists", - "the user is already in the room", - )); - } - txn.commit()?; - Ok(()) + st.db.with_write(|txn| { + let (uid, _perm) = txn.get_user(user)?; + let (attrs, _) = txn.get_room_having(rid, RoomAttrs::PUBLIC_JOINABLE)?; + // Sanity check. + assert!(!attrs.contains(RoomAttrs::PEER_CHAT)); + txn.add_room_member(rid, uid, permission)?; + Ok(()) + }) } async fn room_leave(st: &AppState, rid: Id, user: &UserKey) -> Result<(), ApiError> { - let mut conn = st.db.get(); - let txn = conn.transaction()?; - - let uid = txn - .query_row( - r" - SELECT `uid` - FROM `room_member` - JOIN `valid_user_act_key` USING (`uid`) - WHERE (`rid`, `id_key`, `act_key`) = (:rid, :id_key, :act_key) - ", - named_params! { - ":rid": rid, - ":id_key": user.id_key, - ":act_key": user.act_key, - }, - |row| row.get::<_, u64>("uid"), - ) - .optional()? - .ok_or_else(|| { - error_response!( - StatusCode::NOT_FOUND, - "not_found", - "the room does not exist or user is not a room member", - ) - })?; - - txn.execute( - r" - DELETE FROM `room_member` - WHERE `rid` = :rid AND - `uid` = :uid - ", - named_params! { - ":rid": rid, - ":uid": uid, - }, - )?; - - txn.commit()?; - Ok(()) + st.db.with_write(|txn| { + // FIXME: Handle peer chat room? + let (uid, ..) = txn.get_room_member(rid, user)?; + txn.remove_room_member(rid, uid)?; + Ok(()) + }) } async fn room_msg_mark_seen( st: ArcState, - R(Path((rid, cid)), _): RE<Path<(Id, u64)>>, + R(Path((rid, cid)), _): RE<Path<(Id, i64)>>, Auth(user): Auth, ) -> Result<StatusCode, ApiError> { - let changed = st.db.get().execute( - r" - UPDATE `room_member` - SET `last_seen_cid` = MAX(`last_seen_cid`, :cid) - WHERE - `rid` = :rid AND - `uid` = (SELECT `uid` - FROM `valid_user_act_key` - WHERE (`id_key`, `act_key`) = (:id_key, :act_key)) - ", - named_params! { - ":cid": cid, - ":rid": rid, - ":id_key": user.id_key, - ":act_key": user.act_key, - }, - )?; - - if changed != 1 { - return Err(error_response!( - StatusCode::NOT_FOUND, - "not_found", - "the room does not exist or the user is not a room member", - )); - } + st.db.with_write(|txn| { + let (uid, _perm, prev_seen_cid) = txn.get_room_member(rid, &user)?; + if cid < prev_seen_cid.0 { + return Ok(()); + } + txn.mark_room_msg_seen(rid, uid, Id(cid as _)) + })?; Ok(StatusCode::NO_CONTENT) } diff --git a/blahd/src/middleware.rs b/blahd/src/middleware.rs index dc5fa13..f3b838e 100644 --- a/blahd/src/middleware.rs +++ b/blahd/src/middleware.rs @@ -17,6 +17,7 @@ use crate::AppState; /// /// Mostly following: <https://learn.microsoft.com/en-us/graph/errors> #[derive(Debug, Serialize, Deserialize)] +#[must_use] pub struct ApiError { #[serde(skip, default)] pub status: StatusCode, diff --git a/blahd/src/register.rs b/blahd/src/register.rs index f7c9969..15a99f6 100644 --- a/blahd/src/register.rs +++ b/blahd/src/register.rs @@ -9,10 +9,10 @@ use http_body_util::BodyExt; use parking_lot::Mutex; use rand::rngs::OsRng; use rand::RngCore; -use rusqlite::{named_params, params, OptionalExtension}; use serde::Deserialize; use sha2::{Digest, Sha256}; +use crate::database::TransactionOps; use crate::{ApiError, AppState}; const USER_AGENT: &str = concat!("blahd/", env!("CARGO_PKG_VERSION")); @@ -260,59 +260,8 @@ pub async fn user_register( // Now the identity is verified. let id_desc_json = serde_jcs::to_string(&id_desc).expect("serialization cannot fail"); - - let mut conn = st.db.get(); - let txn = conn.transaction()?; - let uid = txn - .query_row( - r" - INSERT INTO `user` (`id_key`, `last_fetch_time`, `id_desc`) - VALUES (:id_key, :last_fetch_time, :id_desc) - ON CONFLICT (`id_key`) DO UPDATE SET - `last_fetch_time` = :last_fetch_time, - `id_desc` = :id_desc - WHERE `last_fetch_time` < :last_fetch_time - RETURNING `uid` - ", - named_params! { - ":id_key": reg.id_key, - ":id_desc": id_desc_json, - ":last_fetch_time": fetch_time, - }, - |row| row.get::<_, i64>(0), - ) - .optional()? - .ok_or_else(|| { - error_response!( - StatusCode::CONFLICT, - "conflict", - "racing register, please try again later", - ) - })?; - { - txn.execute( - r" - DELETE FROM `user_act_key` - WHERE `uid` = ? - ", - params![uid], - )?; - let mut stmt = txn.prepare( - r" - INSERT INTO `user_act_key` (`uid`, `act_key`, `expire_time`) - VALUES (:uid, :act_key, :expire_time) - ", - )?; - for kdesc in &id_desc.act_keys { - stmt.execute(named_params! { - ":uid": uid, - ":act_key": kdesc.signee.payload.act_key, - // FIXME: Other `u64` that will be stored in database should also be range checked. - ":expire_time": kdesc.signee.payload.expire_time.min(i64::MAX as _), - })?; - } - } - txn.commit()?; + st.db + .with_write(|txn| txn.create_user(&id_desc, &id_desc_json, fetch_time))?; Ok(StatusCode::NO_CONTENT) } diff --git a/blahd/tests/webapi.rs b/blahd/tests/webapi.rs index 2f15aa4..3fefd8f 100644 --- a/blahd/tests/webapi.rs +++ b/blahd/tests/webapi.rs @@ -408,7 +408,7 @@ async fn room_create_get(server: Server, ref mut rng: impl RngCore, #[case] publ if public { assert_eq!(resp.unwrap(), room_meta); } else { - resp.expect_api_err(StatusCode::NOT_FOUND, "not_found"); + resp.expect_api_err(StatusCode::NOT_FOUND, "room_not_found"); } } @@ -473,11 +473,11 @@ async fn room_join_leave(server: Server, ref mut rng: impl RngCore) { // Not permitted. join(rid_priv, &BOB) .await - .expect_api_err(StatusCode::NOT_FOUND, "not_found"); + .expect_api_err(StatusCode::NOT_FOUND, "room_not_found"); // Not exists. join(Id::INVALID, &BOB) .await - .expect_api_err(StatusCode::NOT_FOUND, "not_found"); + .expect_api_err(StatusCode::NOT_FOUND, "room_not_found"); // Overly high permission. server .join_room(rid_priv, &BOB, MemberPermission::ALL) @@ -502,15 +502,15 @@ async fn room_join_leave(server: Server, ref mut rng: impl RngCore) { // Already left. leave(rid_pub, &BOB) .await - .expect_api_err(StatusCode::NOT_FOUND, "not_found"); + .expect_api_err(StatusCode::NOT_FOUND, "room_not_found"); // Unpermitted and not inside. leave(rid_priv, &BOB) .await - .expect_api_err(StatusCode::NOT_FOUND, "not_found"); + .expect_api_err(StatusCode::NOT_FOUND, "room_not_found"); // Invalid room. leave(Id::INVALID, &BOB) .await - .expect_api_err(StatusCode::NOT_FOUND, "not_found"); + .expect_api_err(StatusCode::NOT_FOUND, "room_not_found"); } #[rstest] @@ -563,7 +563,7 @@ async fn room_chat_post_read(server: Server, ref mut rng: impl RngCore) { // Not a member. post(rid_pub, chat(rid_pub, &BOB, "not a member")) .await - .expect_api_err(StatusCode::NOT_FOUND, "not_found"); + .expect_api_err(StatusCode::NOT_FOUND, "room_not_found"); // Is a member but without permission. server @@ -577,7 +577,7 @@ async fn room_chat_post_read(server: Server, ref mut rng: impl RngCore) { // Room not exists. post(Id::INVALID, chat(Id::INVALID, &ALICE, "not permitted")) .await - .expect_api_err(StatusCode::NOT_FOUND, "not_found"); + .expect_api_err(StatusCode::NOT_FOUND, "room_not_found"); //// Msgs listing //// @@ -636,13 +636,13 @@ async fn room_chat_post_read(server: Server, ref mut rng: impl RngCore) { server .get::<RoomMsgs>(&format!("/room/{rid_priv}/msg"), None) .await - .expect_api_err(StatusCode::NOT_FOUND, "not_found"); + .expect_api_err(StatusCode::NOT_FOUND, "room_not_found"); // Not a member. server .get::<RoomMsgs>(&format!("/room/{rid_priv}/msg"), Some(&auth(&BOB, rng))) .await - .expect_api_err(StatusCode::NOT_FOUND, "not_found"); + .expect_api_err(StatusCode::NOT_FOUND, "room_not_found"); // Ok. let msgs = server @@ -758,7 +758,7 @@ async fn peer_chat(server: Server, ref mut rng: impl RngCore) { // Bob disallows peer chat. create_chat(&ALICE, &BOB) .await - .expect_api_err(StatusCode::NOT_FOUND, "not_found"); + .expect_api_err(StatusCode::NOT_FOUND, "peer_user_not_found"); // Alice accepts bob. let rid = create_chat(&BOB, &ALICE).await.unwrap(); @@ -777,7 +777,7 @@ async fn peer_chat(server: Server, ref mut rng: impl RngCore) { server .get::<RoomMetadata>(&format!("/room/{rid}"), None) .await - .expect_api_err(StatusCode::NOT_FOUND, "not_found"); + .expect_api_err(StatusCode::NOT_FOUND, "room_not_found"); // Both alice and bob are in the room. for (key, peer) in [(&*ALICE, &*BOB), (&*BOB, &*ALICE)] {