Move database logic into submodule and do simple version check

This commit is contained in:
oxalica 2024-09-03 17:52:02 -04:00
parent a37bc3f81e
commit 81a566a097
7 changed files with 112 additions and 36 deletions

View file

@ -57,7 +57,10 @@ enum Command {
}
#[derive(Debug, clap::Subcommand)]
#[allow(clippy::large_enum_variant)]
enum DbCommand {
/// Create and initialize database.
Init,
/// Set user property, possibly adding new users.
SetUser {
#[command(flatten)]
@ -144,8 +147,6 @@ impl User {
}
}
static INIT_SQL: &str = include_str!("../../blahd/init.sql");
fn main() -> Result<()> {
let cli = <Cli as clap::Parser>::parse();
@ -157,9 +158,15 @@ fn main() -> Result<()> {
io::stdout().write_all(pubkey_doc.as_bytes())?;
}
Command::Database { database, command } => {
let conn = Connection::open(database).context("failed to open database")?;
conn.execute_batch(INIT_SQL)
.context("failed to initialize database")?;
use rusqlite::OpenFlags;
let mut flags = OpenFlags::SQLITE_OPEN_READ_WRITE | OpenFlags::SQLITE_OPEN_NO_MUTEX;
flags.set(
OpenFlags::SQLITE_OPEN_CREATE,
matches!(command, DbCommand::Init),
);
let conn =
Connection::open_with_flags(database, flags).context("failed to open database")?;
main_db(conn, command)?;
}
Command::Api { url, command } => build_rt()?.block_on(main_api(url, command))?,
@ -177,6 +184,7 @@ fn build_rt() -> Result<Runtime> {
fn main_db(conn: Connection, command: DbCommand) -> Result<()> {
match command {
DbCommand::Init => {}
DbCommand::SetUser { user, permission } => {
let userkey = build_rt()?.block_on(user.fetch_key())?;

View file

@ -3,11 +3,16 @@
# the default value.
[database]
# (Required)
# The path to the main SQLite database.
# The file will be created and initialized if not exist, but missing directory
# will not.
path = "/var/lib/blahd/db.sqlite"
# Whether to create and initialize the database if `path` does not exist.
# Note that parent directory will never be created and must already exist.
create = true
[server]
# (Required)

View file

@ -1,5 +1,5 @@
PRAGMA journal_mode=WAL;
PRAGMA foreign_keys=TRUE;
-- TODO: We are still in prototyping phase. Database migration is not
-- implemented and layout can change at any time.
CREATE TABLE IF NOT EXISTS `user` (
`uid` INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,

View file

@ -20,6 +20,8 @@ pub struct Config {
pub struct DatabaseConfig {
#[serde_inline_default("/var/lib/blahd/db.sqlite".into())]
pub path: PathBuf,
#[serde_inline_default(true)]
pub create: bool,
}
#[serde_inline_default]

68
blahd/src/database.rs Normal file
View file

@ -0,0 +1,68 @@
use std::ops::DerefMut;
use std::sync::Mutex;
use anyhow::{ensure, Context, Result};
use rusqlite::{params, Connection, OpenFlags};
use crate::config::DatabaseConfig;
static INIT_SQL: &str = include_str!("../schema.sql");
// Simple and stupid version check for now.
// `echo -n 'blahd-database-0' | sha256sum | head -c5` || version
const APPLICATION_ID: i32 = 0xd9e_8400;
#[derive(Debug)]
pub struct Database {
conn: Mutex<Connection>,
}
impl Database {
pub fn open(config: &DatabaseConfig) -> Result<Self> {
let mut flags = OpenFlags::SQLITE_OPEN_READ_WRITE | OpenFlags::SQLITE_OPEN_NO_MUTEX;
if !config.path.try_exists()? {
flags.set(OpenFlags::SQLITE_OPEN_CREATE, config.create);
}
let mut conn = Connection::open_with_flags(&config.path, flags)
.context("failed to connect database")?;
// Connection-specific pragmas.
conn.pragma_update(None, "journal_mode", "WAL")?;
conn.pragma_update(None, "foreign_keys", "TRUE")?;
if conn.query_row(r"SELECT COUNT(*) FROM sqlite_schema", params![], |row| {
row.get::<_, u64>(0)
})? != 0
{
let cur_app_id =
conn.pragma_query_value(None, "application_id", |row| row.get::<_, i32>(0))?;
ensure!(
cur_app_id == (APPLICATION_ID),
"database is non-empty with a different application_id. \
migration is not implemented yet. \
got: {cur_app_id:#x}, expect: {APPLICATION_ID:#x} \
",
);
}
let txn = conn.transaction()?;
txn.execute_batch(INIT_SQL)
.context("failed to initialize database")?;
txn.pragma_update(None, "application_id", APPLICATION_ID)?;
txn.commit()?;
Ok(Self {
conn: Mutex::new(conn),
})
}
pub fn get(&self) -> impl DerefMut<Target = Connection> + '_ {
self.conn.lock().unwrap()
}
}
#[test]
fn init_sql_valid() {
let conn = Connection::open_in_memory().unwrap();
conn.execute_batch(INIT_SQL).unwrap();
}

View file

@ -115,9 +115,8 @@ pub async fn handle_ws(st: Arc<AppState>, ws: &mut WebSocket) -> Result<Infallib
let auth = serde_json::from_str::<WithSig<AuthPayload>>(&payload)?;
st.verify_signed_data(&auth)?;
st.conn
.lock()
.unwrap()
st.db
.get()
.query_row(
r"
SELECT `uid`

View file

@ -16,6 +16,7 @@ use blah::types::{
RoomAttrs, ServerPermission, Signee, UserKey, WithSig,
};
use config::Config;
use database::Database;
use ed25519_dalek::SIGNATURE_LENGTH;
use middleware::{ApiError, OptionalAuth, SignedJson};
use rusqlite::{named_params, params, Connection, OptionalExtension, Row, ToSql};
@ -27,6 +28,7 @@ use uuid::Uuid;
#[macro_use]
mod middleware;
mod config;
mod database;
mod event;
mod utils;
@ -63,9 +65,7 @@ fn main() -> Result<()> {
match cli {
Cli::Serve { config } => {
let config = parse_config(&config)?;
let db = rusqlite::Connection::open(&config.database.path)
.context("failed to open database")?;
let st = AppState::init(config, db).context("failed to initialize state")?;
let st = AppState::init(config).context("failed to initialize state")?;
tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
@ -82,7 +82,7 @@ fn main() -> Result<()> {
// Locks must be grabbed in the field order.
#[derive(Debug)]
struct AppState {
conn: Mutex<rusqlite::Connection>,
db: Database,
used_nonces: Mutex<ExpiringSet<u32>>,
event: event::State,
@ -90,13 +90,9 @@ struct AppState {
}
impl AppState {
fn init(config: Config, conn: rusqlite::Connection) -> Result<Self> {
static INIT_SQL: &str = include_str!("../init.sql");
conn.execute_batch(INIT_SQL)
.context("failed to initialize database")?;
fn init(config: Config) -> Result<Self> {
Ok(Self {
conn: Mutex::new(conn),
db: Database::open(&config.database).context("failed to open database")?,
used_nonces: Mutex::new(ExpiringSet::new(Duration::from_secs(
config.server.timestamp_tolerance_secs,
))),
@ -236,9 +232,8 @@ async fn room_list(
let query = |sql: &str, params: &[(&str, &dyn ToSql)]| -> Result<RoomList, ApiError> {
let mut last_rid = None;
let rooms = st
.conn
.lock()
.unwrap()
.db
.get()
.prepare(sql)?
.query_map(params, |row| {
// TODO: Extract this into a function.
@ -346,7 +341,7 @@ async fn room_create(
));
}
let mut conn = st.conn.lock().unwrap();
let mut conn = st.db.get();
let Some(true) = conn
.query_row(
r"
@ -450,7 +445,7 @@ async fn room_get_item(
OptionalAuth(user): OptionalAuth,
) -> Result<Json<RoomItems>, ApiError> {
let (items, skip_token) = {
let conn = st.conn.lock().unwrap();
let conn = st.db.get();
get_room_if_readable(&conn, ruuid, user.as_ref(), |_row| Ok(()))?;
query_room_items(&st, &conn, ruuid, pagination)?
};
@ -466,13 +461,12 @@ async fn room_get_metadata(
WithRejection(Path(ruuid), _): WithRejection<Path<Uuid>, ApiError>,
OptionalAuth(user): OptionalAuth,
) -> Result<Json<RoomMetadata>, ApiError> {
let (title, attrs) =
get_room_if_readable(&st.conn.lock().unwrap(), ruuid, user.as_ref(), |row| {
Ok((
row.get::<_, String>("title")?,
row.get::<_, RoomAttrs>("attrs")?,
))
})?;
let (title, attrs) = get_room_if_readable(&st.db.get(), ruuid, user.as_ref(), |row| {
Ok((
row.get::<_, String>("title")?,
row.get::<_, RoomAttrs>("attrs")?,
))
})?;
Ok(Json(RoomMetadata {
ruuid,
@ -489,7 +483,7 @@ async fn room_get_feed(
) -> Result<impl IntoResponse, ApiError> {
let title;
let (items, skip_token) = {
let conn = st.conn.lock().unwrap();
let conn = st.db.get();
title = get_room_if_readable(&conn, ruuid, None, |row| row.get::<_, String>("title"))?;
query_room_items(&st, &conn, ruuid, pagination)?
};
@ -690,7 +684,7 @@ async fn room_post_item(
}
let (cid, txs) = {
let conn = st.conn.lock().unwrap();
let conn = st.db.get();
let Some((rid, uid)) = conn
.query_row(
r"
@ -820,7 +814,7 @@ async fn room_join(
user: UserKey,
permission: MemberPermission,
) -> Result<(), ApiError> {
let mut conn = st.conn.lock().unwrap();
let mut conn = st.db.get();
let txn = conn.transaction()?;
let Some(rid) = txn
.query_row(
@ -872,7 +866,7 @@ async fn room_join(
}
async fn room_leave(st: &AppState, ruuid: Uuid, user: UserKey) -> Result<(), ApiError> {
let mut conn = st.conn.lock().unwrap();
let mut conn = st.db.get();
let txn = conn.transaction()?;
let Some((rid, uid)) = txn