refactor(database)!: decouple SQLs from backend logic and cache stmts

This decouples SQLs from handler logic, makes it easier for auditing and
caching. It also enables the possibility to switch or support multiple
database backends.
This commit is contained in:
oxalica 2024-09-21 07:28:41 -04:00
parent b955d32099
commit fafd2de2e3
11 changed files with 769 additions and 669 deletions

90
Cargo.lock generated
View file

@ -1254,6 +1254,12 @@ version = "0.4.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "78b3ae25bc7c8c38cec158d1f2757ee79e9b3740fbc7ccf0e59e4b08d793fa89"
[[package]]
name = "litrs"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b4ce301924b7887e9d637144fdade93f9dfff9b60981d4ac161db09720d39aa5"
[[package]]
name = "lock_api"
version = "0.4.12"
@ -1472,6 +1478,45 @@ version = "2.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e"
[[package]]
name = "phf"
version = "0.11.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ade2d8b8f33c7333b51bcf0428d37e217e9f32192ae4772156f65063b8ce03dc"
dependencies = [
"phf_shared",
]
[[package]]
name = "phf_codegen"
version = "0.11.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e8d39688d359e6b34654d328e262234662d16cc0f60ec8dcbe5e718709342a5a"
dependencies = [
"phf_generator",
"phf_shared",
]
[[package]]
name = "phf_generator"
version = "0.11.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "48e4cc64c2ad9ebe670cb8fd69dd50ae301650392e81c05f9bfcb2d5bdbc24b0"
dependencies = [
"phf_shared",
"rand",
]
[[package]]
name = "phf_shared"
version = "0.11.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "90fcb95eef784c2ac79119d1dd819e162b5da872ce6f3c3abe1e8ca1c082f72b"
dependencies = [
"siphasher",
"uncased",
]
[[package]]
name = "pin-project"
version = "1.1.5"
@ -1771,9 +1816,21 @@ dependencies = [
"fallible-streaming-iterator",
"hashlink",
"libsqlite3-sys",
"rusqlite-macros",
"smallvec",
]
[[package]]
name = "rusqlite-macros"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ecdc5e5d64f172916dfc8a0b0f7876de19b899e7a5f1d5b2c04c722cc78e0e45"
dependencies = [
"fallible-iterator",
"litrs",
"sqlite3-parser",
]
[[package]]
name = "rustc-demangle"
version = "0.1.24"
@ -2089,6 +2146,12 @@ dependencies = [
"rand_core",
]
[[package]]
name = "siphasher"
version = "0.3.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "38b58827f4464d87d377d175e90bf58eb00fd8716ff0a62f80356b5e61555d0d"
[[package]]
name = "slab"
version = "0.4.9"
@ -2130,6 +2193,24 @@ dependencies = [
"der",
]
[[package]]
name = "sqlite3-parser"
version = "0.13.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eb5307dad6cb84730ce8bdefde56ff4cf95fe516972d52e2bbdc8a8cd8f2520b"
dependencies = [
"bitflags",
"cc",
"fallible-iterator",
"indexmap 2.4.0",
"log",
"memchr",
"phf",
"phf_codegen",
"phf_shared",
"uncased",
]
[[package]]
name = "strsim"
version = "0.11.1"
@ -2540,6 +2621,15 @@ version = "1.17.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825"
[[package]]
name = "uncased"
version = "0.9.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e1b88fcfe09e89d3866a5c11019378088af2d24c3fbd4f0543f96b479ec90697"
dependencies = [
"version_check",
]
[[package]]
name = "unicode-bidi"
version = "0.3.15"

View file

@ -490,6 +490,7 @@ bitflags::bitflags! {
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct RoomAttrs: u64 {
// NB. Used by schema.
const PUBLIC_READABLE = 1 << 0;
const PUBLIC_JOINABLE = 1 << 1;

View file

@ -20,7 +20,7 @@ humantime = "2"
parking_lot = "0.12" # Maybe no better performance, just that we hate poisoning. ¯\_(ツ)_/¯
rand = "0.8"
reqwest = "0.12"
rusqlite = "0.32"
rusqlite = { version = "0.32", features = ["rusqlite-macros"] }
rustix = { version = "0.38", features = ["net"] }
sd-notify = "0.4"
serde = { version = "1", features = ["derive"] }

View file

@ -17,12 +17,6 @@ CREATE TABLE IF NOT EXISTS `user_act_key` (
PRIMARY KEY (`uid`, `act_key`)
) STRICT, WITHOUT ROWID;
CREATE VIEW IF NOT EXISTS `valid_user_act_key` AS
SELECT `act_key`, `user`.*
FROM `user_act_key`
JOIN `user` USING (`uid`)
WHERE unixepoch() < `expire_time`;
-- The highest bit of `rid` will be set for peer chat room.
-- So simply comparing it against 0 can filter them out.
CREATE TABLE IF NOT EXISTS `room` (
@ -43,6 +37,10 @@ CREATE UNIQUE INDEX IF NOT EXISTS `ix_peer_chat` ON `room`
(`peer1`, `peer2`)
WHERE `rid` < 0;
-- RoomAttrs::PUBLIC_READABLE
CREATE INDEX IF NOT EXISTS `ix_public_room` ON `room` (`rid`)
WHERE `attrs` & 1 != 0;
CREATE TABLE IF NOT EXISTS `room_member` (
`rid` INTEGER NOT NULL REFERENCES `room` ON DELETE CASCADE,
`uid` INTEGER NOT NULL REFERENCES `user` ON DELETE RESTRICT,
@ -68,3 +66,11 @@ CREATE TABLE IF NOT EXISTS `msg` (
) STRICT;
CREATE INDEX IF NOT EXISTS `room_latest_msg` ON `msg` (`rid` ASC, `cid` DESC);
-- Temporary views.
CREATE TEMP VIEW `valid_user_act_key` AS
SELECT `act_key`, `user`.*
FROM `user_act_key`
JOIN `user` USING (`uid`)
WHERE unixepoch() < `expire_time`;

View file

@ -1,21 +1,29 @@
use std::borrow::Borrow;
use std::ops::DerefMut;
use std::path::PathBuf;
use anyhow::{ensure, Context, Result};
use anyhow::{ensure, Context};
use axum::http::StatusCode;
use blah_types::{ServerPermission, UserKey};
use blah_types::identity::UserIdentityDesc;
use blah_types::{
ChatPayload, Id, MemberPermission, PubKey, RoomAttrs, RoomMetadata, ServerPermission,
SignedChatMsg, Signee, UserKey, WithMsgId,
};
use parking_lot::Mutex;
use rusqlite::{params, Connection, OpenFlags, OptionalExtension};
use rusqlite::{named_params, params, prepare_cached_and_bind, Connection, OpenFlags, Row};
use serde::Deserialize;
use serde_inline_default::serde_inline_default;
use crate::ApiError;
#[cfg(test)]
mod tests;
const DEFAULT_DATABASE_PATH: &str = "/var/lib/blahd/db.sqlite";
const STMT_CACHE_CAPACITY: usize = 24;
static INIT_SQL: &str = include_str!("../schema.sql");
type Result<T, E = ApiError> = std::result::Result<T, E>;
// Simple and stupid version check for now.
// `echo -n 'blahd-database-0' | sha256sum | head -c5` || version
const APPLICATION_ID: i32 = 0xd9e_8405;
@ -44,15 +52,17 @@ pub struct Database {
conn: Mutex<Connection>,
}
pub struct Transaction<'db>(rusqlite::Transaction<'db>);
impl Database {
/// Use an existing database connection and do no initialization or schema checking.
/// This should only be used for testing purpose.
pub fn from_raw(conn: Connection) -> Result<Self> {
pub fn from_raw(conn: Connection) -> anyhow::Result<Self> {
conn.pragma_update(None, "foreign_keys", "TRUE")?;
Ok(Self { conn: conn.into() })
}
pub fn open(config: &Config) -> Result<Self> {
pub fn open(config: &Config) -> anyhow::Result<Self> {
let mut conn = if config.in_memory {
Connection::open_in_memory().context("failed to open in-memory database")?
} else {
@ -63,13 +73,14 @@ impl Database {
Connection::open_with_flags(&config.path, flags)
.context("failed to connect database")?
};
conn.set_prepared_statement_cache_capacity(STMT_CACHE_CAPACITY);
Self::maybe_init(&mut conn)?;
Ok(Self {
conn: Mutex::new(conn),
})
}
pub fn maybe_init(conn: &mut Connection) -> Result<()> {
pub fn maybe_init(conn: &mut Connection) -> anyhow::Result<()> {
// Connection-specific pragmas.
conn.pragma_update(None, "journal_mode", "WAL")?;
conn.pragma_update(None, "foreign_keys", "TRUE")?;
@ -97,49 +108,491 @@ impl Database {
Ok(())
}
pub fn get(&self) -> impl DerefMut<Target = Connection> + '_ {
self.conn.lock()
pub fn with_read<T>(&self, f: impl FnOnce(&Transaction<'_>) -> Result<T>) -> Result<T> {
// TODO: Currently no concurrency is implemented.
self.with_write(f)
}
pub fn with_write<T>(&self, f: impl FnOnce(&Transaction<'_>) -> Result<T>) -> Result<T> {
let mut conn = self.conn.lock();
let txn = Transaction(conn.transaction()?);
match f(&txn) {
Ok(v) => {
txn.0.commit()?;
Ok(v)
}
Err(e) => Err(e),
}
}
}
pub trait ConnectionExt: Borrow<Connection> {
fn get_user(&self, user: &UserKey) -> Result<(i64, ServerPermission), ApiError> {
self.borrow()
.query_row(
r"
SELECT `uid`, `permission`
FROM `valid_user_act_key`
WHERE (`id_key`, `act_key`) = (?, ?)
",
params![user.id_key, user.act_key],
|row| Ok((row.get(0)?, row.get(1)?)),
)
.optional()?
.ok_or_else(|| {
error_response!(
StatusCode::NOT_FOUND,
"not_found",
"the user does not exist",
)
})
}
fn parse_msg(rid: Id, row: &Row<'_>) -> Result<WithMsgId<SignedChatMsg>> {
Ok(WithMsgId {
cid: row.get("cid")?,
msg: SignedChatMsg {
sig: row.get("sig")?,
signee: Signee {
nonce: row.get("nonce")?,
timestamp: row.get("timestamp")?,
user: UserKey {
id_key: row.get("id_key")?,
act_key: row.get("act_key")?,
},
payload: ChatPayload {
room: rid,
rich_text: row.get("rich_text")?,
},
},
},
})
}
impl ConnectionExt for Connection {}
fn parse_room_metadata(row: &Row<'_>) -> Result<RoomMetadata> {
use rusqlite::types::ValueRef;
#[test]
fn init_sql_valid() {
let conn = Connection::open_in_memory().unwrap();
conn.execute_batch(INIT_SQL).unwrap();
let rid = row.get("rid")?;
let last_msg = (matches!(row.get_ref("cid"), Ok(ValueRef::Integer(_))))
.then(|| parse_msg(rid, row))
.transpose()?;
Ok(RoomMetadata {
rid,
title: row.get("title")?,
attrs: row.get("attrs")?,
last_msg,
last_seen_cid: row.get("last_seen_cid").ok().filter(|&cid| cid != Id(0)),
unseen_cnt: row.get("unseen_cnt").ok().filter(|&n| n != 0),
member_permission: row.get("member_perm").ok(),
peer_user: row.get("peer_id_key").ok(),
})
}
// Instantiate view to check syntax and availability of `unixepoch()`.
// It requires sqlite >= 3.38.0 (2022-02-22) which is not available by default on GitHub CI.
let ret = conn
.query_row(
"SELECT COUNT(*) FROM `valid_user_act_key`",
params![],
|row| row.get::<_, i64>(0),
pub trait TransactionOps {
fn conn(&self) -> &Connection;
fn get_user(&self, UserKey { id_key, act_key }: &UserKey) -> Result<(i64, ServerPermission)> {
prepare_cached_and_bind!(
self.conn(),
r"
SELECT `uid`, `permission`
FROM `valid_user_act_key`
WHERE (`id_key`, `act_key`) = (:id_key, :act_key)
"
)
.unwrap();
assert_eq!(ret, 0);
.raw_query()
.next()?
.ok_or_else(|| {
error_response!(
StatusCode::NOT_FOUND,
"not_found",
"the user does not exist",
)
})
.and_then(|row| Ok((row.get(0)?, row.get(1)?)))
}
fn get_user_by_id_key(&self, id_key: &PubKey) -> Result<(i64, ServerPermission)> {
prepare_cached_and_bind!(
self.conn(),
r"
SELECT `uid`, `permission`
FROM `user`
WHERE `id_key` = :id_key
"
)
.raw_query()
.next()?
.ok_or_else(|| {
error_response!(
StatusCode::NOT_FOUND,
"user_not_found",
"the user does not exists",
)
})
.and_then(|row| Ok((row.get(0)?, row.get(1)?)))
}
fn get_room_member(
&self,
rid: Id,
UserKey { id_key, act_key }: &UserKey,
) -> Result<(i64, MemberPermission, Id)> {
prepare_cached_and_bind!(
self.conn(),
r"
SELECT `uid`, `room_member`.`permission`, `last_seen_cid`
FROM `room_member`
JOIN `valid_user_act_key` USING (`uid`)
WHERE (`rid`, `id_key`, `act_key`) = (:rid, :id_key, :act_key)
"
)
.raw_query()
.next()?
.ok_or_else(|| {
error_response!(
StatusCode::NOT_FOUND,
"room_not_found",
"the room does not exist or user is not a room member",
)
})
.and_then(|row| Ok((row.get(0)?, row.get(1)?, row.get(2)?)))
}
fn get_room_having(&self, rid: Id, filter: RoomAttrs) -> Result<(RoomAttrs, Option<String>)> {
prepare_cached_and_bind!(
self.conn(),
r"
SELECT `attrs`, `title`
FROM `room`
WHERE `rid` = :rid
"
)
.raw_query()
.next()?
.map(|row| {
Ok::<_, rusqlite::Error>((
row.get::<_, RoomAttrs>(0)?,
row.get::<_, Option<String>>(1)?,
))
})
.transpose()?
.filter(|(attrs, _)| attrs.contains(filter))
.ok_or_else(|| {
error_response!(
StatusCode::NOT_FOUND,
"room_not_found",
"the room does not exist"
)
})
}
// FIXME: Eliminate this.
// Currently broadcasting msgs requires traversing over all members.
fn list_room_members(&self, rid: Id) -> Result<Vec<i64>> {
prepare_cached_and_bind!(
self.conn(),
r"
SELECT `uid`
FROM `room_member`
WHERE `rid` = :rid
"
)
.raw_query()
.mapped(|row| row.get::<_, i64>(0))
.collect::<rusqlite::Result<Vec<_>>>()
.map_err(Into::into)
}
fn list_public_rooms(&self, start_rid: Id, page_len: usize) -> Result<Vec<RoomMetadata>> {
// Attribute check must be written in the SQL literal so the query planer
// can successfully pick the conditional index.
const _: () = assert!(RoomAttrs::PUBLIC_READABLE.bits() == 1);
prepare_cached_and_bind!(
self.conn(),
r"
SELECT `rid`, `title`, `attrs`,
MAX(`cid`) AS `cid`, `timestamp`, `nonce`, `sig`, `rich_text`,
`last_author`.`id_key`, `msg`.`act_key`
FROM `room` INDEXED BY `ix_public_room`
LEFT JOIN `msg` USING (`rid`)
LEFT JOIN `user` AS `last_author` USING (`uid`)
WHERE `attrs` & 1 != 0 AND
`rid` > :start_rid
GROUP BY `rid`
ORDER BY `rid` ASC
LIMIT :page_len
"
)
.raw_query()
.and_then(parse_room_metadata)
.collect()
}
fn list_joined_rooms(
&self,
uid: i64,
start_rid: Id,
page_len: usize,
) -> Result<Vec<RoomMetadata>> {
prepare_cached_and_bind!(
self.conn(),
r"
SELECT
`rid`, `title`, `attrs`, `last_seen_cid`, `room_member`.`permission` AS `member_perm`,
MAX(`cid`) AS `cid`, `timestamp`, `nonce`, `sig`, `rich_text`,
`last_author`.`id_key`, `msg`.`act_key`,
`peer_user`.`id_key` AS `peer_id_key`
FROM `room_member` INDEXED BY `ix_member_room`
JOIN `room` USING (`rid`)
LEFT JOIN `msg` USING (`rid`)
LEFT JOIN `user` AS `last_author` ON (`last_author`.`uid` = `msg`.`uid`)
LEFT JOIN `user` AS `peer_user` ON
(`peer_user`.`uid` = `room`.`peer1` + `room`.`peer2` - :uid)
WHERE `room_member`.`uid` = :uid AND
`rid` > :start_rid
GROUP BY `rid`
ORDER BY `rid` ASC
LIMIT :page_len
"
)
.raw_query()
.and_then(parse_room_metadata)
.collect()
}
fn list_unseen_rooms(
&self,
uid: i64,
start_rid: Id,
page_len: usize,
) -> Result<Vec<RoomMetadata>> {
// FIXME: Limit `unseen_cnt` counting.
prepare_cached_and_bind!(
self.conn(),
r"
SELECT
`rid`, `title`, `attrs`, `last_seen_cid`, `room_member`.`permission` AS `member_perm`,
`cid`, `timestamp`, `nonce`, `sig`, `rich_text`,
`last_author`.`id_key`, `msg`.`act_key`,
`peer_user`.`id_key` AS `peer_id_key`,
(SELECT COUNT(*)
FROM `msg` AS `unseen_msg`
WHERE `unseen_msg`.`rid` = `room`.`rid` AND
`last_seen_cid` < `unseen_msg`.`cid`) AS `unseen_cnt`
FROM `room_member` INDEXED BY `ix_member_room`
JOIN `room` USING (`rid`)
LEFT JOIN `msg` USING (`rid`)
LEFT JOIN `user` AS `last_author` ON (`last_author`.`uid` = `msg`.`uid`)
LEFT JOIN `user` AS `peer_user` ON
(`peer_user`.`uid` = `room`.`peer1` + `room`.`peer2` - :uid)
WHERE `room_member`.`uid` = :uid AND
`rid` > :start_rid AND
`cid` > `last_seen_cid`
GROUP BY `rid` HAVING `cid` IS MAX(`cid`)
ORDER BY `rid` ASC
LIMIT :page_len
"
)
.raw_query()
.and_then(parse_room_metadata)
.collect()
}
fn list_room_msgs(
&self,
rid: Id,
after_cid: Id,
before_cid: Id,
page_len: usize,
) -> Result<Vec<WithMsgId<SignedChatMsg>>> {
prepare_cached_and_bind!(
self.conn(),
r"
SELECT `cid`, `timestamp`, `nonce`, `sig`, `id_key`, `act_key`, `sig`, `rich_text`
FROM `msg`
JOIN `user` USING (`uid`)
WHERE `rid` = :rid AND
:after_cid < `cid` AND
`cid` < :before_cid
ORDER BY `cid` DESC
LIMIT :page_len
"
)
.raw_query()
.and_then(|row| parse_msg(rid, row))
.collect()
}
fn create_user(
&self,
id_desc: &UserIdentityDesc,
id_desc_json: &str,
fetch_time: u64,
) -> Result<i64> {
let conn = self.conn();
let id_key = &id_desc.id_key;
let uid = prepare_cached_and_bind!(
conn,
r"
INSERT INTO `user` (`id_key`, `last_fetch_time`, `id_desc`)
VALUES (:id_key, :fetch_time, :id_desc_json)
ON CONFLICT (`id_key`) DO UPDATE SET
`last_fetch_time` = excluded.`last_fetch_time`,
`id_desc` = excluded.`id_desc`
WHERE `last_fetch_time` < :fetch_time
RETURNING `uid`
"
)
.raw_query()
.next()?
.ok_or_else(|| {
error_response!(
StatusCode::CONFLICT,
"conflict",
"racing register, please try again later",
)
})
.and_then(|row| Ok(row.get::<_, i64>(0)?))?;
// Delete existing act_keys.
prepare_cached_and_bind!(
conn,
r"
DELETE FROM `user_act_key`
WHERE `uid` = :uid
"
)
.raw_execute()?;
let mut stmt = conn.prepare_cached(
r"
INSERT INTO `user_act_key` (`uid`, `act_key`, `expire_time`)
VALUES (:uid, :act_key, :expire_time)
",
)?;
for kdesc in &id_desc.act_keys {
stmt.execute(named_params! {
":uid": uid,
":act_key": kdesc.signee.payload.act_key,
// FIXME: Other `u64` that will be stored in database should also be range checked.
":expire_time": kdesc.signee.payload.expire_time.min(i64::MAX as _),
})?;
}
Ok(uid)
}
fn create_group(&self, rid: Id, title: &str, attrs: RoomAttrs) -> Result<()> {
prepare_cached_and_bind!(
self.conn(),
r"
INSERT INTO `room` (`rid`, `title`, `attrs`)
VALUES (:rid, :title, :attrs)
"
)
.raw_execute()?;
Ok(())
}
fn create_peer_room_with_members(
&self,
rid: Id,
attrs: RoomAttrs,
src_uid: i64,
tgt_uid: i64,
) -> Result<()> {
assert!(attrs.contains(RoomAttrs::PEER_CHAT));
let conn = self.conn();
let (p1, p2) = if src_uid <= tgt_uid {
(src_uid, tgt_uid)
} else {
(tgt_uid, src_uid)
};
let updated = prepare_cached_and_bind!(
conn,
r"
INSERT INTO `room` (`rid`, `attrs`, `peer1`, `peer2`)
VALUES (:rid, :attrs, :p1, :p2)
ON CONFLICT (`peer1`, `peer2`) WHERE `rid` < 0 DO NOTHING
"
)
.raw_execute()?;
if updated == 0 {
return Err(error_response!(
StatusCode::CONFLICT,
"exists",
"room already exists"
));
}
// TODO: Limit permission of the src user?
let perm = MemberPermission::MAX_PEER_CHAT;
prepare_cached_and_bind!(
conn,
r"
INSERT INTO `room_member` (`rid`, `uid`, `permission`)
VALUES (:rid, :src_uid, :perm), (:rid, :tgt_uid, :perm)
"
)
.raw_execute()?;
Ok(())
}
fn add_room_member(&self, rid: Id, uid: i64, perm: MemberPermission) -> Result<()> {
let updated = prepare_cached_and_bind!(
self.conn(),
r"
INSERT INTO `room_member` (`rid`, `uid`, `permission`)
VALUES (:rid, :uid, :perm)
ON CONFLICT (`rid`, `uid`) DO NOTHING
"
)
.raw_execute()?;
if updated != 1 {
return Err(error_response!(
StatusCode::CONFLICT,
"exists",
"the user already joined the room",
));
}
Ok(())
}
fn remove_room_member(&self, rid: Id, uid: i64) -> Result<bool> {
// TODO: Check if it is the last member?
let updated = prepare_cached_and_bind!(
self.conn(),
r"
DELETE FROM `room_member`
WHERE (`rid`, `uid`) = (:rid, :uid)
"
)
.raw_execute()?;
Ok(updated == 1)
}
fn add_room_chat_msg(&self, rid: Id, uid: i64, cid: Id, chat: &SignedChatMsg) -> Result<()> {
let conn = self.conn();
let act_key = &chat.signee.user.act_key;
let timestamp = chat.signee.timestamp;
let nonce = chat.signee.nonce;
let rich_text = &chat.signee.payload.rich_text;
let sig = &chat.sig;
prepare_cached_and_bind!(
conn,
r"
INSERT INTO `msg` (`cid`, `rid`, `uid`, `act_key`, `timestamp`, `nonce`, `sig`, `rich_text`)
VALUES (:cid, :rid, :uid, :act_key, :timestamp, :nonce, :sig, :rich_text)
"
)
.raw_execute()?;
Ok(())
}
fn mark_room_msg_seen(&self, rid: Id, uid: i64, cid: Id) -> Result<()> {
// TODO: Validate `cid`?
let updated = prepare_cached_and_bind!(
self.conn(),
r"
UPDATE `room_member`
SET `last_seen_cid` = MAX(`last_seen_cid`, :cid)
WHERE (`rid`, `uid`) = (:rid, :uid)
"
)
.raw_execute()?;
if updated != 1 {
return Err(error_response!(
StatusCode::NOT_FOUND,
"room_not_found",
"the room does not exist or the user is not a room member",
));
}
Ok(())
}
}
impl TransactionOps for Transaction<'_> {
fn conn(&self) -> &Connection {
&self.0
}
}

View file

@ -0,0 +1,31 @@
#![expect(clippy::print_stdout, reason = "allowed in tests for debugging")]
use super::*;
#[test]
fn init_sql_valid() {
let conn = Connection::open_in_memory().unwrap();
conn.execute_batch(INIT_SQL).unwrap();
// Instantiate view to check syntax and availability of `unixepoch()`.
// It requires sqlite >= 3.38.0 (2022-02-22) which is not available by default on GitHub CI.
let ret = conn
.query_row(
"SELECT COUNT(*) FROM `valid_user_act_key`",
params![],
|row| row.get::<_, i64>(0),
)
.unwrap();
assert_eq!(ret, 0);
}
#[test]
fn stmt_cache_capacity() {
let src = std::fs::read_to_string("src/database.rs").unwrap();
let sql_cnt = src.matches("prepare_cached_and_bind!").count();
println!("found {sql_cnt} SQLs");
assert_ne!(sql_cnt, 0);
assert!(
sql_cnt <= STMT_CACHE_CAPACITY,
"stmt cache capacity {STMT_CACHE_CAPACITY} is too small, found {sql_cnt} SQLs",
);
}

View file

@ -20,7 +20,7 @@ use tokio::sync::broadcast;
use tokio_stream::wrappers::errors::BroadcastStreamRecvError;
use tokio_stream::wrappers::BroadcastStream;
use crate::database::ConnectionExt;
use crate::database::TransactionOps;
use crate::AppState;
#[derive(Debug, Deserialize)]
@ -145,8 +145,7 @@ pub async fn handle_ws(st: Arc<AppState>, ws: &mut WebSocket) -> Result<Infallib
let (uid, _) = st
.db
.get()
.get_user(&auth.signee.user)
.with_read(|txn| txn.get_user(&auth.signee.user))
.map_err(|err| anyhow!("{}", err.message))?;
// FIXME: Consistency of id's sign.
uid as u64

View file

@ -12,15 +12,14 @@ use axum::{Json, Router};
use axum_extra::extract::WithRejection as R;
use blah_types::{
ChatPayload, CreateGroup, CreatePeerChat, CreateRoomPayload, Id, MemberPermission, RoomAdminOp,
RoomAdminPayload, RoomAttrs, RoomMetadata, ServerPermission, Signed, SignedChatMsg, Signee,
UserKey, UserRegisterPayload, WithMsgId, X_BLAH_DIFFICULTY, X_BLAH_NONCE,
RoomAdminPayload, RoomAttrs, RoomMetadata, ServerPermission, Signed, SignedChatMsg, UserKey,
UserRegisterPayload, WithMsgId, X_BLAH_DIFFICULTY, X_BLAH_NONCE,
};
use database::ConnectionExt;
use database::{Transaction, TransactionOps};
use ed25519_dalek::SIGNATURE_LENGTH;
use id::IdExt;
use middleware::{Auth, MaybeAuth, ResultExt as _, SignedJson};
use parking_lot::Mutex;
use rusqlite::{named_params, params, Connection, OptionalExtension, Row, ToSql};
use serde::{Deserialize, Deserializer, Serialize};
use serde_inline_default::serde_inline_default;
use url::Url;
@ -188,25 +187,13 @@ async fn user_get(
let ret = (|| {
match auth.into_optional()? {
None => None,
Some(user) => st
.db
.get()
.query_row(
"
SELECT 1
FROM `valid_user_act_key`
WHERE (`id_key`, `act_key`) = (?, ?)
",
params![user.id_key, user.act_key],
|_| Ok(()),
)
.optional()?,
Some(user) => st.db.with_read(|txn| txn.get_user(&user)).ok(),
}
.ok_or_else(|| error_response!(StatusCode::NOT_FOUND, "not_found", "user does not exist"))
})();
match ret {
Ok(()) => Ok(StatusCode::NO_CONTENT),
Ok(_) => Ok(StatusCode::NO_CONTENT),
Err(err) => Err((st.register.challenge_headers(), err)),
}
}
@ -235,7 +222,7 @@ struct ListRoomParams {
top: Option<NonZeroUsize>,
}
#[derive(Debug, Deserialize)]
#[derive(Debug, Clone, Copy, Deserialize)]
#[serde(rename_all = "snake_case")]
enum ListRoomFilter {
/// List all public rooms.
@ -243,6 +230,7 @@ enum ListRoomFilter {
/// List joined rooms (authentication required).
Joined,
/// List all joined rooms with unseen messages (authentication required).
// TODO: Is this really useful, given most user keep messages unread forever?
Unseen,
}
@ -259,143 +247,21 @@ async fn room_list(
let page_len = pagination.effective_page_len(&st);
let start_rid = pagination.skip_token.unwrap_or(Id::MIN);
let query = |sql: &str, params: &[(&str, &dyn ToSql)]| -> Result<RoomList, ApiError> {
let rooms = st
.db
.get()
.prepare(sql)?
.query_map(params, |row| {
// TODO: Extract this into a function.
let rid = row.get("rid")?;
let last_msg = row
.get::<_, Option<Id>>("cid")?
.map(|cid| {
Ok::<_, rusqlite::Error>(WithMsgId {
cid,
msg: SignedChatMsg {
sig: row.get("sig")?,
signee: Signee {
nonce: row.get("nonce")?,
timestamp: row.get("timestamp")?,
user: UserKey {
act_key: row.get("act_key")?,
id_key: row.get("id_key")?,
},
payload: ChatPayload {
rich_text: row.get("rich_text")?,
room: rid,
},
},
},
})
})
.transpose()?;
Ok(RoomMetadata {
rid,
title: row.get("title")?,
attrs: row.get("attrs")?,
last_msg,
last_seen_cid: Some(row.get::<_, Id>("last_seen_cid")?)
.filter(|cid| cid.0 != 0),
unseen_cnt: row.get("unseen_cnt").ok(),
member_permission: row.get("member_perm").ok(),
peer_user: row.get("peer_id_key").ok(),
})
})?
.collect::<Result<Vec<_>, _>>()?;
let skip_token =
(rooms.len() == page_len).then(|| rooms.last().expect("page must not be empty").rid);
Ok(RoomList { rooms, skip_token })
};
match params.filter {
ListRoomFilter::Public => query(
r"
SELECT `rid`, `title`, `attrs`, 0 AS `last_seen_cid`,
`cid`, `timestamp`, `nonce`, `sig`, `rich_text`,
`last_author`.`id_key`, `msg`.`act_key`
FROM `room`
LEFT JOIN `msg` USING (`rid`)
LEFT JOIN `user` AS `last_author` USING (`uid`)
WHERE `rid` > :start_rid AND
(`attrs` & :perm) = :perm
GROUP BY `rid` HAVING `cid` IS MAX(`cid`)
ORDER BY `rid` ASC
LIMIT :page_len
",
named_params! {
":start_rid": start_rid,
":page_len": page_len,
":perm": RoomAttrs::PUBLIC_READABLE,
},
),
let rooms = st.db.with_read(|txn| match params.filter {
ListRoomFilter::Public => txn.list_public_rooms(start_rid, page_len),
ListRoomFilter::Joined => {
let user = auth?.0;
query(
r"
SELECT
`rid`, `title`, `attrs`, `last_seen_cid`, `room_member`.`permission` AS `member_perm`,
`cid`, `timestamp`, `nonce`, `sig`, `rich_text`,
`last_author`.`id_key`, `msg`.`act_key`,
`peer_user`.`id_key` AS `peer_id_key`
FROM `valid_user_act_key` AS `me`
JOIN `room_member` USING (`uid`)
JOIN `room` USING (`rid`)
LEFT JOIN `msg` USING (`rid`)
LEFT JOIN `user` AS `last_author` ON (`last_author`.`uid` = `msg`.`uid`)
LEFT JOIN `user` AS `peer_user` ON
(`peer_user`.`uid` = `room`.`peer1` + `room`.`peer2` - `me`.`uid`)
WHERE (`me`.`id_key`, `me`.`act_key`) = (:id_key, :act_key) AND
`rid` > :start_rid
GROUP BY `rid` HAVING `cid` IS MAX(`cid`)
ORDER BY `rid` ASC
LIMIT :page_len
",
named_params! {
":start_rid": start_rid,
":page_len": page_len,
":id_key": user.id_key,
":act_key": user.act_key,
},
)
let (uid, _) = txn.get_user(&auth?.0)?;
txn.list_joined_rooms(uid, start_rid, page_len)
}
ListRoomFilter::Unseen => {
let user = auth?.0;
query(
r"
SELECT
`rid`, `title`, `attrs`, `last_seen_cid`, `room_member`.`permission` AS `member_perm`,
`cid`, `timestamp`, `nonce`, `sig`, `rich_text`,
`last_author`.`id_key`, `msg`.`act_key`,
`peer_user`.`id_key` AS `peer_id_key`,
(SELECT COUNT(*)
FROM `msg` AS `unseen_msg`
WHERE `unseen_msg`.`rid` = `room`.`rid` AND
`last_seen_cid` < `unseen_msg`.`cid`) AS `unseen_cnt`
FROM `valid_user_act_key` AS `me`
JOIN `room_member` USING (`uid`)
JOIN `room` USING (`rid`)
LEFT JOIN `msg` USING (`rid`)
LEFT JOIN `user` AS `last_author` ON (`last_author`.`uid` = `msg`.`uid`)
LEFT JOIN `user` AS `peer_user` ON
(`peer_user`.`uid` = `room`.`peer1` + `room`.`peer2` - `me`.`uid`)
WHERE (`me`.`id_key`, `me`.`act_key`) = (:id_key, :act_key) AND
`rid` > :start_rid AND
`cid` > `last_seen_cid`
GROUP BY `rid` HAVING `cid` IS MAX(`cid`)
ORDER BY `rid` ASC
LIMIT :page_len
",
named_params! {
":start_rid": start_rid,
":page_len": page_len,
":id_key": user.id_key,
":act_key": user.act_key,
},
)
let (uid, _) = txn.get_user(&auth?.0)?;
txn.list_unseen_rooms(uid, start_rid, page_len)
}
}
.map(Json)
})?;
let skip_token =
(rooms.len() == page_len).then(|| rooms.last().expect("page must not be empty").rid);
Ok(Json(RoomList { rooms, skip_token }))
}
async fn room_create(
@ -423,55 +289,20 @@ async fn room_create_group(
));
}
let conn = st.db.get();
let (uid, _perm) = conn
.query_row(
r"
SELECT `uid`, `permission`
FROM `valid_user_act_key`
WHERE (`id_key`, `act_key`) = (?, ?)
",
params![user.id_key, user.act_key],
|row| {
Ok((
row.get::<_, i64>("uid")?,
row.get::<_, ServerPermission>("permission")?,
))
},
)
.optional()?
.filter(|(_, perm)| perm.contains(ServerPermission::CREATE_ROOM))
.ok_or_else(|| {
error_response!(
let rid = st.db.with_write(|conn| {
let (uid, perm) = conn.get_user(user)?;
if !perm.contains(ServerPermission::CREATE_ROOM) {
return Err(error_response!(
StatusCode::FORBIDDEN,
"permission_denied",
"the user does not exist or does not have permission to create room",
)
})?;
let rid = Id::gen();
conn.execute(
r"
INSERT INTO `room` (`rid`, `title`, `attrs`)
VALUES (:rid, :title, :attrs)
",
named_params! {
":rid": rid,
":title": op.title,
":attrs": op.attrs,
},
)?;
conn.execute(
r"
INSERT INTO `room_member` (`rid`, `uid`, `permission`)
VALUES (:rid, :uid, :perm)
",
named_params! {
":rid": rid,
":uid": uid,
":perm": MemberPermission::ALL,
},
)?;
"the user does not have permission to create room",
));
}
let rid = Id::gen();
conn.create_group(rid, &op.title, op.attrs)?;
conn.add_room_member(rid, uid, MemberPermission::ALL)?;
Ok(rid)
})?;
Ok(Json(rid))
}
@ -491,72 +322,24 @@ async fn room_create_peer_chat(
}
// TODO: Access control and throttling.
let mut conn = st.db.get();
let txn = conn.transaction()?;
let (src_uid, _) = txn.get_user(src_user)?;
let (tgt_uid, _) = txn
.query_row(
r"
SELECT `uid`, `permission`
FROM `user`
WHERE `id_key` = ?
",
params![tgt_user_id_key],
|row| Ok((row.get::<_, i64>(0)?, row.get::<_, ServerPermission>(1)?)),
)
.optional()?
.filter(|(_, perm)| perm.contains(ServerPermission::ACCEPT_PEER_CHAT))
.ok_or_else(|| {
error_response!(
StatusCode::NOT_FOUND,
"not_found",
"peer user does not exist or disallows peer chat",
)
})?;
let mut peers = [src_uid, tgt_uid];
peers.sort();
let rid = Id::gen_peer_chat_rid();
let updated = txn.execute(
r"
INSERT INTO `room` (`rid`, `attrs`, `peer1`, `peer2`)
VALUES (:rid, :attrs, :peer1, :peer2)
ON CONFLICT (`peer1`, `peer2`) WHERE `rid` < 0 DO NOTHING
",
named_params! {
":rid": rid,
":attrs": RoomAttrs::PEER_CHAT,
":peer1": peers[0],
":peer2": peers[1],
},
)?;
if updated == 0 {
return Err(error_response!(
StatusCode::CONFLICT,
"exists",
"room already exists"
));
}
{
let mut stmt = txn.prepare(
r"
INSERT INTO `room_member` (`rid`, `uid`, `permission`)
VALUES (:rid, :uid, :perm)
",
)?;
// TODO: Limit permission of the src user?
for uid in peers {
stmt.execute(named_params! {
":rid": rid,
":uid": uid,
":perm": MemberPermission::MAX_PEER_CHAT,
let rid = st.db.with_write(|txn| {
let (src_uid, _) = txn.get_user(src_user)?;
let (tgt_uid, _) = txn
.get_user_by_id_key(&tgt_user_id_key)
.ok()
.filter(|(_, perm)| perm.contains(ServerPermission::ACCEPT_PEER_CHAT))
.ok_or_else(|| {
error_response!(
StatusCode::NOT_FOUND,
"peer_user_not_found",
"peer user does not exist or disallows peer chat",
)
})?;
}
}
let rid = Id::gen_peer_chat_rid();
txn.create_peer_room_with_members(rid, RoomAttrs::PEER_CHAT, src_uid, tgt_uid)?;
Ok(rid)
})?;
txn.commit()?;
Ok(Json(rid))
}
@ -598,11 +381,14 @@ async fn room_msg_list(
R(Query(pagination), _): RE<Query<Pagination>>,
auth: MaybeAuth,
) -> Result<Json<RoomMsgs>, ApiError> {
let (msgs, skip_token) = {
let conn = st.db.get();
get_room_if_readable(&conn, rid, auth.into_optional()?.as_ref(), |_row| Ok(()))?;
query_room_msgs(&st, &conn, rid, pagination)?
};
let (msgs, skip_token) = st.db.with_read(|txn| {
if let Some(user) = auth.into_optional()? {
txn.get_room_member(rid, &user)?;
} else {
txn.get_room_having(rid, RoomAttrs::PUBLIC_READABLE)?;
}
query_room_msgs(&st, txn, rid, pagination)
})?;
Ok(Json(RoomMsgs { msgs, skip_token }))
}
@ -611,12 +397,16 @@ async fn room_get_metadata(
R(Path(rid), _): RE<Path<Id>>,
auth: MaybeAuth,
) -> Result<Json<RoomMetadata>, ApiError> {
let conn = st.db.get();
let (title, attrs) = get_room_if_readable(&conn, rid, auth.into_optional()?.as_ref(), |row| {
Ok((
row.get::<_, Option<String>>("title")?,
row.get::<_, RoomAttrs>("attrs")?,
))
let (attrs, title) = st.db.with_read(|txn| {
let filter = if auth
.into_optional()?
.is_some_and(|user| txn.get_room_member(rid, &user).is_ok())
{
RoomAttrs::empty()
} else {
RoomAttrs::PUBLIC_READABLE
};
txn.get_room_having(rid, filter)
})?;
Ok(Json(RoomMetadata {
@ -638,12 +428,14 @@ async fn room_get_feed(
R(Path(rid), _): RE<Path<Id>>,
R(Query(pagination), _): RE<Query<Pagination>>,
) -> Result<impl IntoResponse, ApiError> {
let title;
let (msgs, skip_token) = {
let conn = st.db.get();
title = get_room_if_readable(&conn, rid, None, |row| row.get::<_, String>("title"))?;
query_room_msgs(&st, &conn, rid, pagination)?
};
let (title, msgs, skip_token) = st.db.with_read(|txn| {
let (attrs, title) = txn.get_room_having(rid, RoomAttrs::PUBLIC_READABLE)?;
// Sanity check.
assert!(!attrs.contains(RoomAttrs::PEER_CHAT));
let title = title.expect("public room must have title");
let (msgs, skip_token) = query_room_msgs(&st, txn, rid, pagination)?;
Ok((title, msgs, skip_token))
})?;
let items = msgs
.into_iter()
@ -736,101 +528,23 @@ struct FeedItemExtra {
sig: [u8; SIGNATURE_LENGTH],
}
fn get_room_if_readable<T>(
conn: &rusqlite::Connection,
rid: Id,
user: Option<&UserKey>,
f: impl FnOnce(&Row<'_>) -> rusqlite::Result<T>,
) -> Result<T, ApiError> {
let (id_key, act_key) = match user {
Some(keys) => (Some(&keys.id_key), Some(&keys.act_key)),
None => (None, None),
};
conn.query_row(
r"
SELECT `title`, `attrs`
FROM `room`
WHERE `rid` = :rid AND
((`attrs` & :perm) = :perm OR
EXISTS(SELECT 1
FROM `room_member`
JOIN `valid_user_act_key` USING (`uid`)
WHERE `room_member`.`rid` = `room`.`rid` AND
(`id_key`, `act_key`) = (:id_key, :act_key)))
",
named_params! {
":rid": rid,
":perm": RoomAttrs::PUBLIC_READABLE,
":id_key": id_key,
":act_key": act_key,
},
f,
)
.optional()?
.ok_or_else(|| {
error_response!(
StatusCode::NOT_FOUND,
"not_found",
"the room does not exist or the user is not a room member",
)
})
}
/// Get room messages with pagination parameters,
/// return a page of messages and the next `skip_token` if this is not the last page.
fn query_room_msgs(
st: &AppState,
conn: &Connection,
txn: &Transaction<'_>,
rid: Id,
pagination: Pagination,
) -> Result<(Vec<WithMsgId<SignedChatMsg>>, Option<Id>), ApiError> {
let page_len = pagination.effective_page_len(st);
let mut stmt = conn.prepare(
r"
SELECT `cid`, `timestamp`, `nonce`, `sig`, `id_key`, `act_key`, `sig`, `rich_text`
FROM `msg`
JOIN `user` USING (`uid`)
WHERE `rid` = :rid AND
:after_cid < `cid` AND
`cid` < :before_cid
ORDER BY `cid` DESC
LIMIT :limit
",
let msgs = txn.list_room_msgs(
rid,
pagination.until_token.unwrap_or(Id::MIN),
pagination.skip_token.unwrap_or(Id::MAX),
page_len,
)?;
let msgs = stmt
.query_and_then(
named_params! {
":rid": rid,
":after_cid": pagination.until_token.unwrap_or(Id::MIN),
":before_cid": pagination.skip_token.unwrap_or(Id::MAX),
":limit": page_len,
},
|row| {
Ok(WithMsgId {
cid: row.get("cid")?,
msg: SignedChatMsg {
sig: row.get("sig")?,
signee: Signee {
nonce: row.get("nonce")?,
timestamp: row.get("timestamp")?,
user: UserKey {
id_key: row.get("id_key")?,
act_key: row.get("act_key")?,
},
payload: ChatPayload {
room: rid,
rich_text: row.get("rich_text")?,
},
},
},
})
},
)?
.collect::<rusqlite::Result<Vec<_>>>()?;
let skip_token =
(msgs.len() == page_len).then(|| msgs.last().expect("page must not be empty").cid);
Ok((msgs, skip_token))
}
@ -847,38 +561,8 @@ async fn room_msg_post(
));
}
let (cid, txs) = {
let conn = st.db.get();
let (uid, perm) = conn
.query_row(
r"
SELECT `uid`, `room_member`.`permission`
FROM `room_member`
JOIN `valid_user_act_key` USING (`uid`)
WHERE `rid` = :rid AND
(`id_key`, `act_key`) = (:id_key, :act_key)
",
named_params! {
":rid": rid,
":id_key": &chat.signee.user.id_key,
":act_key": &chat.signee.user.act_key,
},
|row| {
Ok((
row.get::<_, u64>("uid")?,
row.get::<_, MemberPermission>("permission")?,
))
},
)
.optional()?
.ok_or_else(|| {
error_response!(
StatusCode::NOT_FOUND,
"not_found",
"the room does not exist or the user is not a room member",
)
})?;
let (cid, members) = st.db.with_write(|txn| {
let (uid, perm, ..) = txn.get_room_member(rid, &chat.signee.user)?;
if !perm.contains(MemberPermission::POST_CHAT) {
return Err(error_response!(
StatusCode::FORBIDDEN,
@ -888,52 +572,26 @@ async fn room_msg_post(
}
let cid = Id::gen();
conn.execute(
r"
INSERT INTO `msg` (`cid`, `rid`, `uid`, `act_key`, `timestamp`, `nonce`, `sig`, `rich_text`)
VALUES (:cid, :rid, :uid, :act_key, :timestamp, :nonce, :sig, :rich_text)
",
named_params! {
":cid": cid,
":rid": rid,
":uid": uid,
":act_key": chat.signee.user.act_key,
":timestamp": chat.signee.timestamp,
":nonce": chat.signee.nonce,
":rich_text": &chat.signee.payload.rich_text,
":sig": chat.sig,
},
)?;
txn.add_room_chat_msg(rid, uid, cid, &chat)?;
let members = txn.list_room_members(rid)?;
Ok((cid, members))
})?;
// FIXME: Optimize this to not traverses over all members.
let mut stmt = conn.prepare(
r"
SELECT `uid`
FROM `room_member`
WHERE `rid` = :rid
",
)?;
let listeners = st.event.user_listeners.lock();
let txs = stmt
.query_map(params![rid], |row| row.get::<_, u64>(0))?
.filter_map(|ret| match ret {
Ok(uid) => listeners.get(&uid).map(|tx| Ok(tx.clone())),
Err(err) => Some(Err(err)),
})
.collect::<Result<Vec<_>, _>>()?;
(cid, txs)
};
if !txs.is_empty() {
tracing::debug!("broadcasting event to {} clients", txs.len());
let chat = Arc::new(chat);
for tx in txs {
if let Err(err) = tx.send(chat.clone()) {
tracing::debug!(%err, "failed to broadcast event");
let chat = Arc::new(chat);
// FIXME: Optimize this to not traverses over all members.
let listeners = st.event.user_listeners.lock();
let mut cnt = 0usize;
for uid in members {
// FIXME: u64 vs i64.
if let Some(tx) = listeners.get(&(uid as u64)) {
if tx.send(chat.clone()).is_ok() {
cnt += 1;
}
}
}
if cnt != 0 {
tracing::debug!("broadcasted event to {cnt} clients");
}
Ok(Json(cid))
}
@ -997,124 +655,36 @@ async fn room_join(
user: &UserKey,
permission: MemberPermission,
) -> Result<(), ApiError> {
let mut conn = st.db.get();
let txn = conn.transaction()?;
let (uid, _) = txn.get_user(user)?;
txn.query_row(
r"
SELECT `attrs`
FROM `room`
WHERE `rid` = ?
",
params![rid],
|row| row.get::<_, RoomAttrs>(0),
)
.optional()?
.filter(|attrs| attrs.contains(RoomAttrs::PUBLIC_JOINABLE))
.ok_or_else(|| {
error_response!(
StatusCode::NOT_FOUND,
"not_found",
"the room does not exist or the user is not allowed to join the room",
)
})?;
let updated = txn.execute(
r"
INSERT INTO `room_member` (`rid`, `uid`, `permission`)
SELECT :rid, :uid, :perm
ON CONFLICT (`rid`, `uid`) DO NOTHING
",
named_params! {
":rid": rid,
":uid": uid,
":perm": permission,
},
)?;
if updated == 0 {
return Err(error_response!(
StatusCode::CONFLICT,
"exists",
"the user is already in the room",
));
}
txn.commit()?;
Ok(())
st.db.with_write(|txn| {
let (uid, _perm) = txn.get_user(user)?;
let (attrs, _) = txn.get_room_having(rid, RoomAttrs::PUBLIC_JOINABLE)?;
// Sanity check.
assert!(!attrs.contains(RoomAttrs::PEER_CHAT));
txn.add_room_member(rid, uid, permission)?;
Ok(())
})
}
async fn room_leave(st: &AppState, rid: Id, user: &UserKey) -> Result<(), ApiError> {
let mut conn = st.db.get();
let txn = conn.transaction()?;
let uid = txn
.query_row(
r"
SELECT `uid`
FROM `room_member`
JOIN `valid_user_act_key` USING (`uid`)
WHERE (`rid`, `id_key`, `act_key`) = (:rid, :id_key, :act_key)
",
named_params! {
":rid": rid,
":id_key": user.id_key,
":act_key": user.act_key,
},
|row| row.get::<_, u64>("uid"),
)
.optional()?
.ok_or_else(|| {
error_response!(
StatusCode::NOT_FOUND,
"not_found",
"the room does not exist or user is not a room member",
)
})?;
txn.execute(
r"
DELETE FROM `room_member`
WHERE `rid` = :rid AND
`uid` = :uid
",
named_params! {
":rid": rid,
":uid": uid,
},
)?;
txn.commit()?;
Ok(())
st.db.with_write(|txn| {
// FIXME: Handle peer chat room?
let (uid, ..) = txn.get_room_member(rid, user)?;
txn.remove_room_member(rid, uid)?;
Ok(())
})
}
async fn room_msg_mark_seen(
st: ArcState,
R(Path((rid, cid)), _): RE<Path<(Id, u64)>>,
R(Path((rid, cid)), _): RE<Path<(Id, i64)>>,
Auth(user): Auth,
) -> Result<StatusCode, ApiError> {
let changed = st.db.get().execute(
r"
UPDATE `room_member`
SET `last_seen_cid` = MAX(`last_seen_cid`, :cid)
WHERE
`rid` = :rid AND
`uid` = (SELECT `uid`
FROM `valid_user_act_key`
WHERE (`id_key`, `act_key`) = (:id_key, :act_key))
",
named_params! {
":cid": cid,
":rid": rid,
":id_key": user.id_key,
":act_key": user.act_key,
},
)?;
if changed != 1 {
return Err(error_response!(
StatusCode::NOT_FOUND,
"not_found",
"the room does not exist or the user is not a room member",
));
}
st.db.with_write(|txn| {
let (uid, _perm, prev_seen_cid) = txn.get_room_member(rid, &user)?;
if cid < prev_seen_cid.0 {
return Ok(());
}
txn.mark_room_msg_seen(rid, uid, Id(cid as _))
})?;
Ok(StatusCode::NO_CONTENT)
}

View file

@ -17,6 +17,7 @@ use crate::AppState;
///
/// Mostly following: <https://learn.microsoft.com/en-us/graph/errors>
#[derive(Debug, Serialize, Deserialize)]
#[must_use]
pub struct ApiError {
#[serde(skip, default)]
pub status: StatusCode,

View file

@ -9,10 +9,10 @@ use http_body_util::BodyExt;
use parking_lot::Mutex;
use rand::rngs::OsRng;
use rand::RngCore;
use rusqlite::{named_params, params, OptionalExtension};
use serde::Deserialize;
use sha2::{Digest, Sha256};
use crate::database::TransactionOps;
use crate::{ApiError, AppState};
const USER_AGENT: &str = concat!("blahd/", env!("CARGO_PKG_VERSION"));
@ -260,59 +260,8 @@ pub async fn user_register(
// Now the identity is verified.
let id_desc_json = serde_jcs::to_string(&id_desc).expect("serialization cannot fail");
let mut conn = st.db.get();
let txn = conn.transaction()?;
let uid = txn
.query_row(
r"
INSERT INTO `user` (`id_key`, `last_fetch_time`, `id_desc`)
VALUES (:id_key, :last_fetch_time, :id_desc)
ON CONFLICT (`id_key`) DO UPDATE SET
`last_fetch_time` = :last_fetch_time,
`id_desc` = :id_desc
WHERE `last_fetch_time` < :last_fetch_time
RETURNING `uid`
",
named_params! {
":id_key": reg.id_key,
":id_desc": id_desc_json,
":last_fetch_time": fetch_time,
},
|row| row.get::<_, i64>(0),
)
.optional()?
.ok_or_else(|| {
error_response!(
StatusCode::CONFLICT,
"conflict",
"racing register, please try again later",
)
})?;
{
txn.execute(
r"
DELETE FROM `user_act_key`
WHERE `uid` = ?
",
params![uid],
)?;
let mut stmt = txn.prepare(
r"
INSERT INTO `user_act_key` (`uid`, `act_key`, `expire_time`)
VALUES (:uid, :act_key, :expire_time)
",
)?;
for kdesc in &id_desc.act_keys {
stmt.execute(named_params! {
":uid": uid,
":act_key": kdesc.signee.payload.act_key,
// FIXME: Other `u64` that will be stored in database should also be range checked.
":expire_time": kdesc.signee.payload.expire_time.min(i64::MAX as _),
})?;
}
}
txn.commit()?;
st.db
.with_write(|txn| txn.create_user(&id_desc, &id_desc_json, fetch_time))?;
Ok(StatusCode::NO_CONTENT)
}

View file

@ -408,7 +408,7 @@ async fn room_create_get(server: Server, ref mut rng: impl RngCore, #[case] publ
if public {
assert_eq!(resp.unwrap(), room_meta);
} else {
resp.expect_api_err(StatusCode::NOT_FOUND, "not_found");
resp.expect_api_err(StatusCode::NOT_FOUND, "room_not_found");
}
}
@ -473,11 +473,11 @@ async fn room_join_leave(server: Server, ref mut rng: impl RngCore) {
// Not permitted.
join(rid_priv, &BOB)
.await
.expect_api_err(StatusCode::NOT_FOUND, "not_found");
.expect_api_err(StatusCode::NOT_FOUND, "room_not_found");
// Not exists.
join(Id::INVALID, &BOB)
.await
.expect_api_err(StatusCode::NOT_FOUND, "not_found");
.expect_api_err(StatusCode::NOT_FOUND, "room_not_found");
// Overly high permission.
server
.join_room(rid_priv, &BOB, MemberPermission::ALL)
@ -502,15 +502,15 @@ async fn room_join_leave(server: Server, ref mut rng: impl RngCore) {
// Already left.
leave(rid_pub, &BOB)
.await
.expect_api_err(StatusCode::NOT_FOUND, "not_found");
.expect_api_err(StatusCode::NOT_FOUND, "room_not_found");
// Unpermitted and not inside.
leave(rid_priv, &BOB)
.await
.expect_api_err(StatusCode::NOT_FOUND, "not_found");
.expect_api_err(StatusCode::NOT_FOUND, "room_not_found");
// Invalid room.
leave(Id::INVALID, &BOB)
.await
.expect_api_err(StatusCode::NOT_FOUND, "not_found");
.expect_api_err(StatusCode::NOT_FOUND, "room_not_found");
}
#[rstest]
@ -563,7 +563,7 @@ async fn room_chat_post_read(server: Server, ref mut rng: impl RngCore) {
// Not a member.
post(rid_pub, chat(rid_pub, &BOB, "not a member"))
.await
.expect_api_err(StatusCode::NOT_FOUND, "not_found");
.expect_api_err(StatusCode::NOT_FOUND, "room_not_found");
// Is a member but without permission.
server
@ -577,7 +577,7 @@ async fn room_chat_post_read(server: Server, ref mut rng: impl RngCore) {
// Room not exists.
post(Id::INVALID, chat(Id::INVALID, &ALICE, "not permitted"))
.await
.expect_api_err(StatusCode::NOT_FOUND, "not_found");
.expect_api_err(StatusCode::NOT_FOUND, "room_not_found");
//// Msgs listing ////
@ -636,13 +636,13 @@ async fn room_chat_post_read(server: Server, ref mut rng: impl RngCore) {
server
.get::<RoomMsgs>(&format!("/room/{rid_priv}/msg"), None)
.await
.expect_api_err(StatusCode::NOT_FOUND, "not_found");
.expect_api_err(StatusCode::NOT_FOUND, "room_not_found");
// Not a member.
server
.get::<RoomMsgs>(&format!("/room/{rid_priv}/msg"), Some(&auth(&BOB, rng)))
.await
.expect_api_err(StatusCode::NOT_FOUND, "not_found");
.expect_api_err(StatusCode::NOT_FOUND, "room_not_found");
// Ok.
let msgs = server
@ -758,7 +758,7 @@ async fn peer_chat(server: Server, ref mut rng: impl RngCore) {
// Bob disallows peer chat.
create_chat(&ALICE, &BOB)
.await
.expect_api_err(StatusCode::NOT_FOUND, "not_found");
.expect_api_err(StatusCode::NOT_FOUND, "peer_user_not_found");
// Alice accepts bob.
let rid = create_chat(&BOB, &ALICE).await.unwrap();
@ -777,7 +777,7 @@ async fn peer_chat(server: Server, ref mut rng: impl RngCore) {
server
.get::<RoomMetadata>(&format!("/room/{rid}"), None)
.await
.expect_api_err(StatusCode::NOT_FOUND, "not_found");
.expect_api_err(StatusCode::NOT_FOUND, "room_not_found");
// Both alice and bob are in the room.
for (key, peer) in [(&*ALICE, &*BOB), (&*BOB, &*ALICE)] {