diff --git a/blahctl/src/main.rs b/blahctl/src/main.rs index 5c1cf1c..3907291 100644 --- a/blahctl/src/main.rs +++ b/blahctl/src/main.rs @@ -57,7 +57,10 @@ enum Command { } #[derive(Debug, clap::Subcommand)] +#[allow(clippy::large_enum_variant)] enum DbCommand { + /// Create and initialize database. + Init, /// Set user property, possibly adding new users. SetUser { #[command(flatten)] @@ -144,8 +147,6 @@ impl User { } } -static INIT_SQL: &str = include_str!("../../blahd/init.sql"); - fn main() -> Result<()> { let cli = ::parse(); @@ -157,9 +158,15 @@ fn main() -> Result<()> { io::stdout().write_all(pubkey_doc.as_bytes())?; } Command::Database { database, command } => { - let conn = Connection::open(database).context("failed to open database")?; - conn.execute_batch(INIT_SQL) - .context("failed to initialize database")?; + use rusqlite::OpenFlags; + + let mut flags = OpenFlags::SQLITE_OPEN_READ_WRITE | OpenFlags::SQLITE_OPEN_NO_MUTEX; + flags.set( + OpenFlags::SQLITE_OPEN_CREATE, + matches!(command, DbCommand::Init), + ); + let conn = + Connection::open_with_flags(database, flags).context("failed to open database")?; main_db(conn, command)?; } Command::Api { url, command } => build_rt()?.block_on(main_api(url, command))?, @@ -177,6 +184,7 @@ fn build_rt() -> Result { fn main_db(conn: Connection, command: DbCommand) -> Result<()> { match command { + DbCommand::Init => {} DbCommand::SetUser { user, permission } => { let userkey = build_rt()?.block_on(user.fetch_key())?; diff --git a/blahd/config.example.toml b/blahd/config.example.toml index fff3b72..516a29b 100644 --- a/blahd/config.example.toml +++ b/blahd/config.example.toml @@ -3,11 +3,16 @@ # the default value. [database] +# (Required) # The path to the main SQLite database. # The file will be created and initialized if not exist, but missing directory # will not. path = "/var/lib/blahd/db.sqlite" +# Whether to create and initialize the database if `path` does not exist. +# Note that parent directory will never be created and must already exist. +create = true + [server] # (Required) diff --git a/blahd/init.sql b/blahd/schema.sql similarity index 91% rename from blahd/init.sql rename to blahd/schema.sql index 2c6f92e..15c4058 100644 --- a/blahd/init.sql +++ b/blahd/schema.sql @@ -1,5 +1,5 @@ -PRAGMA journal_mode=WAL; -PRAGMA foreign_keys=TRUE; +-- TODO: We are still in prototyping phase. Database migration is not +-- implemented and layout can change at any time. CREATE TABLE IF NOT EXISTS `user` ( `uid` INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, diff --git a/blahd/src/config.rs b/blahd/src/config.rs index 64ee923..0758688 100644 --- a/blahd/src/config.rs +++ b/blahd/src/config.rs @@ -20,6 +20,8 @@ pub struct Config { pub struct DatabaseConfig { #[serde_inline_default("/var/lib/blahd/db.sqlite".into())] pub path: PathBuf, + #[serde_inline_default(true)] + pub create: bool, } #[serde_inline_default] diff --git a/blahd/src/database.rs b/blahd/src/database.rs new file mode 100644 index 0000000..2ae8a5a --- /dev/null +++ b/blahd/src/database.rs @@ -0,0 +1,68 @@ +use std::ops::DerefMut; +use std::sync::Mutex; + +use anyhow::{ensure, Context, Result}; +use rusqlite::{params, Connection, OpenFlags}; + +use crate::config::DatabaseConfig; + +static INIT_SQL: &str = include_str!("../schema.sql"); + +// Simple and stupid version check for now. +// `echo -n 'blahd-database-0' | sha256sum | head -c5` || version +const APPLICATION_ID: i32 = 0xd9e_8400; + +#[derive(Debug)] +pub struct Database { + conn: Mutex, +} + +impl Database { + pub fn open(config: &DatabaseConfig) -> Result { + let mut flags = OpenFlags::SQLITE_OPEN_READ_WRITE | OpenFlags::SQLITE_OPEN_NO_MUTEX; + if !config.path.try_exists()? { + flags.set(OpenFlags::SQLITE_OPEN_CREATE, config.create); + } + + let mut conn = Connection::open_with_flags(&config.path, flags) + .context("failed to connect database")?; + // Connection-specific pragmas. + conn.pragma_update(None, "journal_mode", "WAL")?; + conn.pragma_update(None, "foreign_keys", "TRUE")?; + + if conn.query_row(r"SELECT COUNT(*) FROM sqlite_schema", params![], |row| { + row.get::<_, u64>(0) + })? != 0 + { + let cur_app_id = + conn.pragma_query_value(None, "application_id", |row| row.get::<_, i32>(0))?; + ensure!( + cur_app_id == (APPLICATION_ID), + "database is non-empty with a different application_id. \ + migration is not implemented yet. \ + got: {cur_app_id:#x}, expect: {APPLICATION_ID:#x} \ + ", + ); + } + + let txn = conn.transaction()?; + txn.execute_batch(INIT_SQL) + .context("failed to initialize database")?; + txn.pragma_update(None, "application_id", APPLICATION_ID)?; + txn.commit()?; + + Ok(Self { + conn: Mutex::new(conn), + }) + } + + pub fn get(&self) -> impl DerefMut + '_ { + self.conn.lock().unwrap() + } +} + +#[test] +fn init_sql_valid() { + let conn = Connection::open_in_memory().unwrap(); + conn.execute_batch(INIT_SQL).unwrap(); +} diff --git a/blahd/src/event.rs b/blahd/src/event.rs index 4911a5f..a0ac9b1 100644 --- a/blahd/src/event.rs +++ b/blahd/src/event.rs @@ -115,9 +115,8 @@ pub async fn handle_ws(st: Arc, ws: &mut WebSocket) -> Result>(&payload)?; st.verify_signed_data(&auth)?; - st.conn - .lock() - .unwrap() + st.db + .get() .query_row( r" SELECT `uid` diff --git a/blahd/src/main.rs b/blahd/src/main.rs index 5f412f5..666bdb2 100644 --- a/blahd/src/main.rs +++ b/blahd/src/main.rs @@ -16,6 +16,7 @@ use blah::types::{ RoomAttrs, ServerPermission, Signee, UserKey, WithSig, }; use config::Config; +use database::Database; use ed25519_dalek::SIGNATURE_LENGTH; use middleware::{ApiError, OptionalAuth, SignedJson}; use rusqlite::{named_params, params, Connection, OptionalExtension, Row, ToSql}; @@ -27,6 +28,7 @@ use uuid::Uuid; #[macro_use] mod middleware; mod config; +mod database; mod event; mod utils; @@ -63,9 +65,7 @@ fn main() -> Result<()> { match cli { Cli::Serve { config } => { let config = parse_config(&config)?; - let db = rusqlite::Connection::open(&config.database.path) - .context("failed to open database")?; - let st = AppState::init(config, db).context("failed to initialize state")?; + let st = AppState::init(config).context("failed to initialize state")?; tokio::runtime::Builder::new_multi_thread() .enable_all() .build() @@ -82,7 +82,7 @@ fn main() -> Result<()> { // Locks must be grabbed in the field order. #[derive(Debug)] struct AppState { - conn: Mutex, + db: Database, used_nonces: Mutex>, event: event::State, @@ -90,13 +90,9 @@ struct AppState { } impl AppState { - fn init(config: Config, conn: rusqlite::Connection) -> Result { - static INIT_SQL: &str = include_str!("../init.sql"); - - conn.execute_batch(INIT_SQL) - .context("failed to initialize database")?; + fn init(config: Config) -> Result { Ok(Self { - conn: Mutex::new(conn), + db: Database::open(&config.database).context("failed to open database")?, used_nonces: Mutex::new(ExpiringSet::new(Duration::from_secs( config.server.timestamp_tolerance_secs, ))), @@ -236,9 +232,8 @@ async fn room_list( let query = |sql: &str, params: &[(&str, &dyn ToSql)]| -> Result { let mut last_rid = None; let rooms = st - .conn - .lock() - .unwrap() + .db + .get() .prepare(sql)? .query_map(params, |row| { // TODO: Extract this into a function. @@ -346,7 +341,7 @@ async fn room_create( )); } - let mut conn = st.conn.lock().unwrap(); + let mut conn = st.db.get(); let Some(true) = conn .query_row( r" @@ -450,7 +445,7 @@ async fn room_get_item( OptionalAuth(user): OptionalAuth, ) -> Result, ApiError> { let (items, skip_token) = { - let conn = st.conn.lock().unwrap(); + let conn = st.db.get(); get_room_if_readable(&conn, ruuid, user.as_ref(), |_row| Ok(()))?; query_room_items(&st, &conn, ruuid, pagination)? }; @@ -466,13 +461,12 @@ async fn room_get_metadata( WithRejection(Path(ruuid), _): WithRejection, ApiError>, OptionalAuth(user): OptionalAuth, ) -> Result, ApiError> { - let (title, attrs) = - get_room_if_readable(&st.conn.lock().unwrap(), ruuid, user.as_ref(), |row| { - Ok(( - row.get::<_, String>("title")?, - row.get::<_, RoomAttrs>("attrs")?, - )) - })?; + let (title, attrs) = get_room_if_readable(&st.db.get(), ruuid, user.as_ref(), |row| { + Ok(( + row.get::<_, String>("title")?, + row.get::<_, RoomAttrs>("attrs")?, + )) + })?; Ok(Json(RoomMetadata { ruuid, @@ -489,7 +483,7 @@ async fn room_get_feed( ) -> Result { let title; let (items, skip_token) = { - let conn = st.conn.lock().unwrap(); + let conn = st.db.get(); title = get_room_if_readable(&conn, ruuid, None, |row| row.get::<_, String>("title"))?; query_room_items(&st, &conn, ruuid, pagination)? }; @@ -690,7 +684,7 @@ async fn room_post_item( } let (cid, txs) = { - let conn = st.conn.lock().unwrap(); + let conn = st.db.get(); let Some((rid, uid)) = conn .query_row( r" @@ -820,7 +814,7 @@ async fn room_join( user: UserKey, permission: MemberPermission, ) -> Result<(), ApiError> { - let mut conn = st.conn.lock().unwrap(); + let mut conn = st.db.get(); let txn = conn.transaction()?; let Some(rid) = txn .query_row( @@ -872,7 +866,7 @@ async fn room_join( } async fn room_leave(st: &AppState, ruuid: Uuid, user: UserKey) -> Result<(), ApiError> { - let mut conn = st.conn.lock().unwrap(); + let mut conn = st.db.get(); let txn = conn.transaction()?; let Some((rid, uid)) = txn