diff --git a/Cargo.lock b/Cargo.lock index 2aae64b..2a00fb6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -29,6 +29,15 @@ dependencies = [ "zerocopy", ] +[[package]] +name = "aho-corasick" +version = "1.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e60d3430d3a69478ad0993f19238d2df97c507009a52b3c10addcd7f6bcb916" +dependencies = [ + "memchr", +] + [[package]] name = "android-tzdata" version = "0.1.1" @@ -308,6 +317,8 @@ dependencies = [ "hex", "humantime", "parking_lot", + "reqwest", + "rstest", "rusqlite", "sd-notify", "serde", @@ -763,6 +774,12 @@ version = "0.29.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "40ecd4077b5ae9fd2e9e169b102c6c330d0605168eb0e8bf79952b256dbefffd" +[[package]] +name = "glob" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b" + [[package]] name = "h2" version = "0.4.6" @@ -1398,6 +1415,41 @@ dependencies = [ "bitflags", ] +[[package]] +name = "regex" +version = "1.10.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4219d74c6b67a3654a9fbebc4b419e22126d13d2f3c4a07ee0cb61ff79a79619" +dependencies = [ + "aho-corasick", + "memchr", + "regex-automata", + "regex-syntax", +] + +[[package]] +name = "regex-automata" +version = "0.4.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38caf58cc5ef2fed281f89292ef23f6365465ed9a41b7a7754eb4e26496c92df" +dependencies = [ + "aho-corasick", + "memchr", + "regex-syntax", +] + +[[package]] +name = "regex-syntax" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a66a03ae7c801facd77a29370b4faec201768915ac14a721ba36f20bc9c209b" + +[[package]] +name = "relative-path" +version = "1.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba39f3699c378cd8970968dcbff9c43159ea4cfbd88d43c00b22f2ef10a435d2" + [[package]] name = "reqwest" version = "0.12.7" @@ -1456,6 +1508,33 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "rstest" +version = "0.22.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b423f0e62bdd61734b67cd21ff50871dfaeb9cc74f869dcd6af974fbcb19936" +dependencies = [ + "rstest_macros", + "rustc_version", +] + +[[package]] +name = "rstest_macros" +version = "0.22.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c5e1711e7d14f74b12a58411c542185ef7fb7f2e7f8ee6e2940a883628522b42" +dependencies = [ + "cfg-if", + "glob", + "proc-macro2", + "quote", + "regex", + "relative-path", + "rustc_version", + "syn", + "unicode-ident", +] + [[package]] name = "rusqlite" version = "0.32.1" diff --git a/blahd/Cargo.toml b/blahd/Cargo.toml index 6276a9b..6de844e 100644 --- a/blahd/Cargo.toml +++ b/blahd/Cargo.toml @@ -29,5 +29,9 @@ url = { version = "2.5.2", features = ["serde"] } blah = { path = "..", features = ["rusqlite"] } +[dev-dependencies] +reqwest = { version = "0.12.7", features = ["json"] } +rstest = { version = "0.22.0", default-features = false } + [lints] workspace = true diff --git a/blahd/config.example.toml b/blahd/config.example.toml index 516a29b..594cfe6 100644 --- a/blahd/config.example.toml +++ b/blahd/config.example.toml @@ -3,7 +3,10 @@ # the default value. [database] -# (Required) +# If enabled, a in-memory non-persistent database is used instead. Options +# `path` and `create` are ignored. This should only be used for testing. +in_memory = false + # The path to the main SQLite database. # The file will be created and initialized if not exist, but missing directory # will not. diff --git a/blahd/src/bin/blahd.rs b/blahd/src/bin/blahd.rs new file mode 100644 index 0000000..2731e80 --- /dev/null +++ b/blahd/src/bin/blahd.rs @@ -0,0 +1,53 @@ +use std::path::PathBuf; + +use anyhow::{Context, Result}; +use blahd::config::Config; +use blahd::{AppState, Database}; + +/// Blah Chat Server +#[derive(Debug, clap::Parser)] +#[clap(about, version = option_env!("CFG_RELEASE").unwrap_or(env!("CARGO_PKG_VERSION")))] +enum Cli { + /// Run the server with given configuration. + Serve { + /// The path to the configuration file. + #[arg(long, short)] + config: PathBuf, + }, + + /// Validate the configuration file and exit. + Validate { + /// The path to the configuration file. + #[arg(long, short)] + config: PathBuf, + }, +} + +fn main() -> Result<()> { + tracing_subscriber::fmt::init(); + let cli = ::parse(); + + fn parse_config(path: &std::path::Path) -> Result { + let src = std::fs::read_to_string(path)?; + let config = basic_toml::from_str::(&src)?; + config.validate()?; + Ok(config) + } + + match cli { + Cli::Serve { config } => { + let config = parse_config(&config)?; + let db = Database::open(&config.database).context("failed to open database")?; + let st = AppState::new(db, config.server); + tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build() + .context("failed to initialize tokio runtime")? + .block_on(st.serve()) + } + Cli::Validate { config } => { + parse_config(&config)?; + Ok(()) + } + } +} diff --git a/blahd/src/config.rs b/blahd/src/config.rs index 6d1dec9..088f530 100644 --- a/blahd/src/config.rs +++ b/blahd/src/config.rs @@ -18,6 +18,7 @@ pub struct Config { #[derive(Debug, Clone, Deserialize)] #[serde(deny_unknown_fields)] pub struct DatabaseConfig { + pub in_memory: bool, #[serde_inline_default("/var/lib/blahd/db.sqlite".into())] pub path: PathBuf, #[serde_inline_default(true)] diff --git a/blahd/src/database.rs b/blahd/src/database.rs index 2ac5143..99323c2 100644 --- a/blahd/src/database.rs +++ b/blahd/src/database.rs @@ -18,14 +18,31 @@ pub struct Database { } impl Database { - pub fn open(config: &DatabaseConfig) -> Result { - let mut flags = OpenFlags::SQLITE_OPEN_READ_WRITE | OpenFlags::SQLITE_OPEN_NO_MUTEX; - if !config.path.try_exists()? { - flags.set(OpenFlags::SQLITE_OPEN_CREATE, config.create); - } + /// Use an existing database connection and do no initialization or schema checking. + /// This should only be used for testing purpose. + pub fn from_raw(conn: Connection) -> Result { + conn.pragma_update(None, "foreign_keys", "TRUE")?; + Ok(Self { conn: conn.into() }) + } - let mut conn = Connection::open_with_flags(&config.path, flags) - .context("failed to connect database")?; + pub fn open(config: &DatabaseConfig) -> Result { + let mut conn = if config.in_memory { + Connection::open_in_memory().context("failed to open in-memory database")? + } else { + let mut flags = OpenFlags::SQLITE_OPEN_READ_WRITE | OpenFlags::SQLITE_OPEN_NO_MUTEX; + if !config.path.try_exists()? { + flags.set(OpenFlags::SQLITE_OPEN_CREATE, config.create); + } + Connection::open_with_flags(&config.path, flags) + .context("failed to connect database")? + }; + Self::maybe_init(&mut conn)?; + Ok(Self { + conn: Mutex::new(conn), + }) + } + + pub fn maybe_init(conn: &mut Connection) -> Result<()> { // Connection-specific pragmas. conn.pragma_update(None, "journal_mode", "WAL")?; conn.pragma_update(None, "foreign_keys", "TRUE")?; @@ -50,10 +67,7 @@ impl Database { .context("failed to initialize database")?; txn.pragma_update(None, "application_id", APPLICATION_ID)?; txn.commit()?; - - Ok(Self { - conn: Mutex::new(conn), - }) + Ok(()) } pub fn get(&self) -> impl DerefMut + '_ { diff --git a/blahd/src/event.rs b/blahd/src/event.rs index b1be384..b9cb849 100644 --- a/blahd/src/event.rs +++ b/blahd/src/event.rs @@ -19,6 +19,7 @@ use tokio::sync::broadcast; use tokio_stream::wrappers::errors::BroadcastStreamRecvError; use tokio_stream::wrappers::BroadcastStream; +use crate::config::ServerConfig; use crate::AppState; #[derive(Debug, Deserialize)] @@ -52,14 +53,14 @@ impl std::error::Error for StreamEnded {} struct WsSenderWrapper<'ws, 'c> { inner: SplitSink<&'ws mut WebSocket, Message>, - config: &'c crate::config::Config, + config: &'c ServerConfig, } impl WsSenderWrapper<'_, '_> { async fn send(&mut self, msg: &Outgoing<'_>) -> Result<()> { let data = serde_json::to_string(&msg).expect("serialization cannot fail"); let fut = tokio::time::timeout( - self.config.server.ws_send_timeout_sec, + self.config.ws_send_timeout_sec, self.inner.send(Message::Text(data)), ); match fut.await { @@ -108,7 +109,7 @@ pub async fn handle_ws(st: Arc, ws: &mut WebSocket) -> Result, ws: &mut WebSocket) -> Result ent.get().subscribe(), Entry::Vacant(ent) => { - let (tx, rx) = broadcast::channel(st.config.server.ws_event_queue_len); + let (tx, rx) = broadcast::channel(st.config.ws_event_queue_len); ent.insert(tx); rx } diff --git a/blahd/src/main.rs b/blahd/src/lib.rs similarity index 92% rename from blahd/src/main.rs rename to blahd/src/lib.rs index 8469cd9..a6fccf5 100644 --- a/blahd/src/main.rs +++ b/blahd/src/lib.rs @@ -1,5 +1,4 @@ use std::num::NonZeroUsize; -use std::path::PathBuf; use std::sync::Arc; use std::time::{Duration, SystemTime}; @@ -15,8 +14,7 @@ use blah::types::{ ChatItem, ChatPayload, CreateRoomPayload, Id, MemberPermission, RoomAdminOp, RoomAdminPayload, RoomAttrs, ServerPermission, Signee, UserKey, WithItemId, WithSig, }; -use config::Config; -use database::Database; +use config::ServerConfig; use ed25519_dalek::SIGNATURE_LENGTH; use id::IdExt; use middleware::{ApiError, Auth, MaybeAuth, ResultExt as _, SignedJson}; @@ -28,80 +26,39 @@ use utils::ExpiringSet; #[macro_use] mod middleware; -mod config; +pub mod config; mod database; mod event; mod id; mod utils; -/// Blah Chat Server -#[derive(Debug, clap::Parser)] -#[clap(about, version = option_env!("CFG_RELEASE").unwrap_or(env!("CARGO_PKG_VERSION")))] -enum Cli { - /// Run the server with given configuration. - Serve { - /// The path to the configuration file. - #[arg(long, short)] - config: PathBuf, - }, - - /// Validate the configuration file and exit. - Validate { - /// The path to the configuration file. - #[arg(long, short)] - config: PathBuf, - }, -} - -fn main() -> Result<()> { - tracing_subscriber::fmt::init(); - let cli = ::parse(); - - fn parse_config(path: &std::path::Path) -> Result { - let src = std::fs::read_to_string(path)?; - let config = basic_toml::from_str::(&src)?; - config.validate()?; - Ok(config) - } - - match cli { - Cli::Serve { config } => { - let config = parse_config(&config)?; - let st = AppState::init(config).context("failed to initialize state")?; - tokio::runtime::Builder::new_multi_thread() - .enable_all() - .build() - .context("failed to initialize tokio runtime")? - .block_on(main_async(st)) - } - Cli::Validate { config } => { - parse_config(&config)?; - Ok(()) - } - } -} +pub use database::Database; // Locks must be grabbed in the field order. #[derive(Debug)] -struct AppState { +pub struct AppState { db: Database, used_nonces: Mutex>, event: event::State, - config: Config, + config: ServerConfig, } impl AppState { - fn init(config: Config) -> Result { - Ok(Self { - db: Database::open(&config.database).context("failed to open database")?, + pub fn new(db: Database, config: ServerConfig) -> Self { + Self { + db, used_nonces: Mutex::new(ExpiringSet::new(Duration::from_secs( - config.server.timestamp_tolerance_secs, + config.timestamp_tolerance_secs, ))), event: event::State::default(), config, - }) + } + } + + pub async fn serve(self) -> Result<()> { + serve(Arc::new(self)).await } fn verify_signed_data(&self, data: &WithSig) -> Result<(), ApiError> { @@ -117,7 +74,7 @@ impl AppState { .expect("after UNIX epoch") .as_secs() .abs_diff(data.signee.timestamp); - if timestamp_diff > self.config.server.timestamp_tolerance_secs { + if timestamp_diff > self.config.timestamp_tolerance_secs { return Err(error_response!( StatusCode::BAD_REQUEST, "invalid_timestamp", @@ -137,10 +94,22 @@ impl AppState { type ArcState = State>; -async fn main_async(st: AppState) -> Result<()> { - let st = Arc::new(st); +async fn serve(st: Arc) -> Result<()> { + let listener = tokio::net::TcpListener::bind(&st.config.listen) + .await + .context("failed to listen on socket")?; + tracing::info!("listening on {}", st.config.listen); + let router = router(st.clone()); + let _ = sd_notify::notify(true, &[sd_notify::NotifyState::Ready]); - let app = Router::new() + axum::serve(listener, router) + .await + .context("failed to serve")?; + Ok(()) +} + +pub fn router(st: Arc) -> Router { + Router::new() .route("/ws", get(handle_ws)) .route("/room", get(room_list)) .route("/room/create", post(room_create)) @@ -150,27 +119,16 @@ async fn main_async(st: AppState) -> Result<()> { .route("/room/:rid/item", get(room_item_list).post(room_item_post)) .route("/room/:rid/item/:cid/seen", post(room_item_mark_seen)) .route("/room/:rid/admin", post(room_admin)) - .with_state(st.clone()) .layer(tower_http::limit::RequestBodyLimitLayer::new( - st.config.server.max_request_len, + st.config.max_request_len, )) // NB. This comes at last (outmost layer), so inner errors will still be wrapped with // correct CORS headers. Also `Authorization` must be explicitly included besides `*`. .layer( tower_http::cors::CorsLayer::permissive() .allow_headers([header::HeaderName::from_static("*"), header::AUTHORIZATION]), - ); - - let listener = tokio::net::TcpListener::bind(&st.config.server.listen) - .await - .context("failed to listen on socket")?; - tracing::info!("listening on {}", st.config.server.listen); - let _ = sd_notify::notify(true, &[sd_notify::NotifyState::Ready]); - - axum::serve(listener, app) - .await - .context("failed to serve")?; - Ok(()) + ) + .with_state(st) } type RE = R; @@ -193,11 +151,11 @@ async fn handle_ws(State(st): ArcState, ws: WebSocketUpgrade) -> Response { }) } -#[derive(Debug, Serialize)] -struct RoomList { - rooms: Vec, +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct RoomList { + pub rooms: Vec, #[serde(skip_serializing_if = "Option::is_none")] - skip_token: Option, + pub skip_token: Option, } #[derive(Debug, Deserialize)] @@ -463,16 +421,16 @@ impl Pagination { fn effective_page_len(&self, st: &AppState) -> usize { self.top .unwrap_or(usize::MAX.try_into().expect("not zero")) - .min(st.config.server.max_page_len) + .min(st.config.max_page_len) .get() } } #[derive(Debug, Serialize)] -struct RoomItems { - items: Vec>, +pub struct RoomItems { + pub items: Vec>, #[serde(skip_serializing_if = "Option::is_none")] - skip_token: Option, + pub skip_token: Option, } async fn room_item_list( @@ -548,7 +506,6 @@ async fn room_get_feed( let feed_url = st .config - .server .base_url .join(&format!("/room/{rid}/feed.json")) .expect("base_url must be valid"); @@ -616,7 +573,7 @@ struct FeedItemExtra { sig: [u8; SIGNATURE_LENGTH], } -#[derive(Debug, Serialize)] +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct RoomMetadata { pub rid: Id, pub title: String, diff --git a/blahd/tests/basic.rs b/blahd/tests/basic.rs new file mode 100644 index 0000000..f6017e1 --- /dev/null +++ b/blahd/tests/basic.rs @@ -0,0 +1,75 @@ +// FIXME: False positive? +#![allow(clippy::unwrap_used)] +use std::fmt; +use std::future::IntoFuture; +use std::sync::Arc; + +use blahd::{AppState, Database, RoomList}; +use reqwest::Client; +use rstest::{fixture, rstest}; +use rusqlite::Connection; +use tokio::net::TcpListener; + +// Avoid name resolution. +const LOCALHOST: &str = "127.0.0.1"; + +#[derive(Debug)] +struct Server { + port: u16, +} + +impl Server { + fn url(&self, rhs: impl fmt::Display) -> String { + format!("http://{}:{}{}", LOCALHOST, self.port, rhs) + } +} + +#[fixture] +fn server() -> Server { + let mut conn = Connection::open_in_memory().unwrap(); + Database::maybe_init(&mut conn).unwrap(); + let db = Database::from_raw(conn).unwrap(); + + // Use std's to avoid async, since we need no name resolution. + let listener = std::net::TcpListener::bind(format!("{LOCALHOST}:0")).unwrap(); + listener.set_nonblocking(true).unwrap(); + let port = listener.local_addr().unwrap().port(); + let listener = TcpListener::from_std(listener).unwrap(); + + // TODO: Testing config is hard to build because it does have a `Default` impl. + let config = basic_toml::from_str(&format!( + r#" +listen = "" # TODO: unused +base_url = "http://{LOCALHOST}:{port}" + "# + )) + .unwrap(); + let st = AppState::new(db, config); + let router = blahd::router(Arc::new(st)); + + tokio::spawn(axum::serve(listener, router).into_future()); + Server { port } +} + +#[fixture] +fn client() -> Client { + Client::new() +} + +#[rstest] +#[tokio::test] +async fn smoke(client: Client, server: Server) { + let got = client + .get(server.url("/room?filter=public")) + .send() + .await + .unwrap() + .json::() + .await + .unwrap(); + let exp = RoomList { + rooms: Vec::new(), + skip_token: None, + }; + assert_eq!(got, exp); +}