mirror of
https://github.com/Blah-IM/blahrs.git
synced 2025-05-01 00:31:09 +00:00
Move database logic into submodule and do simple version check
This commit is contained in:
parent
a37bc3f81e
commit
81a566a097
7 changed files with 112 additions and 36 deletions
|
@ -57,7 +57,10 @@ enum Command {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, clap::Subcommand)]
|
#[derive(Debug, clap::Subcommand)]
|
||||||
|
#[allow(clippy::large_enum_variant)]
|
||||||
enum DbCommand {
|
enum DbCommand {
|
||||||
|
/// Create and initialize database.
|
||||||
|
Init,
|
||||||
/// Set user property, possibly adding new users.
|
/// Set user property, possibly adding new users.
|
||||||
SetUser {
|
SetUser {
|
||||||
#[command(flatten)]
|
#[command(flatten)]
|
||||||
|
@ -144,8 +147,6 @@ impl User {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static INIT_SQL: &str = include_str!("../../blahd/init.sql");
|
|
||||||
|
|
||||||
fn main() -> Result<()> {
|
fn main() -> Result<()> {
|
||||||
let cli = <Cli as clap::Parser>::parse();
|
let cli = <Cli as clap::Parser>::parse();
|
||||||
|
|
||||||
|
@ -157,9 +158,15 @@ fn main() -> Result<()> {
|
||||||
io::stdout().write_all(pubkey_doc.as_bytes())?;
|
io::stdout().write_all(pubkey_doc.as_bytes())?;
|
||||||
}
|
}
|
||||||
Command::Database { database, command } => {
|
Command::Database { database, command } => {
|
||||||
let conn = Connection::open(database).context("failed to open database")?;
|
use rusqlite::OpenFlags;
|
||||||
conn.execute_batch(INIT_SQL)
|
|
||||||
.context("failed to initialize database")?;
|
let mut flags = OpenFlags::SQLITE_OPEN_READ_WRITE | OpenFlags::SQLITE_OPEN_NO_MUTEX;
|
||||||
|
flags.set(
|
||||||
|
OpenFlags::SQLITE_OPEN_CREATE,
|
||||||
|
matches!(command, DbCommand::Init),
|
||||||
|
);
|
||||||
|
let conn =
|
||||||
|
Connection::open_with_flags(database, flags).context("failed to open database")?;
|
||||||
main_db(conn, command)?;
|
main_db(conn, command)?;
|
||||||
}
|
}
|
||||||
Command::Api { url, command } => build_rt()?.block_on(main_api(url, command))?,
|
Command::Api { url, command } => build_rt()?.block_on(main_api(url, command))?,
|
||||||
|
@ -177,6 +184,7 @@ fn build_rt() -> Result<Runtime> {
|
||||||
|
|
||||||
fn main_db(conn: Connection, command: DbCommand) -> Result<()> {
|
fn main_db(conn: Connection, command: DbCommand) -> Result<()> {
|
||||||
match command {
|
match command {
|
||||||
|
DbCommand::Init => {}
|
||||||
DbCommand::SetUser { user, permission } => {
|
DbCommand::SetUser { user, permission } => {
|
||||||
let userkey = build_rt()?.block_on(user.fetch_key())?;
|
let userkey = build_rt()?.block_on(user.fetch_key())?;
|
||||||
|
|
||||||
|
|
|
@ -3,11 +3,16 @@
|
||||||
# the default value.
|
# the default value.
|
||||||
|
|
||||||
[database]
|
[database]
|
||||||
|
# (Required)
|
||||||
# The path to the main SQLite database.
|
# The path to the main SQLite database.
|
||||||
# The file will be created and initialized if not exist, but missing directory
|
# The file will be created and initialized if not exist, but missing directory
|
||||||
# will not.
|
# will not.
|
||||||
path = "/var/lib/blahd/db.sqlite"
|
path = "/var/lib/blahd/db.sqlite"
|
||||||
|
|
||||||
|
# Whether to create and initialize the database if `path` does not exist.
|
||||||
|
# Note that parent directory will never be created and must already exist.
|
||||||
|
create = true
|
||||||
|
|
||||||
[server]
|
[server]
|
||||||
|
|
||||||
# (Required)
|
# (Required)
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
PRAGMA journal_mode=WAL;
|
-- TODO: We are still in prototyping phase. Database migration is not
|
||||||
PRAGMA foreign_keys=TRUE;
|
-- implemented and layout can change at any time.
|
||||||
|
|
||||||
CREATE TABLE IF NOT EXISTS `user` (
|
CREATE TABLE IF NOT EXISTS `user` (
|
||||||
`uid` INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
|
`uid` INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
|
|
@ -20,6 +20,8 @@ pub struct Config {
|
||||||
pub struct DatabaseConfig {
|
pub struct DatabaseConfig {
|
||||||
#[serde_inline_default("/var/lib/blahd/db.sqlite".into())]
|
#[serde_inline_default("/var/lib/blahd/db.sqlite".into())]
|
||||||
pub path: PathBuf,
|
pub path: PathBuf,
|
||||||
|
#[serde_inline_default(true)]
|
||||||
|
pub create: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[serde_inline_default]
|
#[serde_inline_default]
|
||||||
|
|
68
blahd/src/database.rs
Normal file
68
blahd/src/database.rs
Normal file
|
@ -0,0 +1,68 @@
|
||||||
|
use std::ops::DerefMut;
|
||||||
|
use std::sync::Mutex;
|
||||||
|
|
||||||
|
use anyhow::{ensure, Context, Result};
|
||||||
|
use rusqlite::{params, Connection, OpenFlags};
|
||||||
|
|
||||||
|
use crate::config::DatabaseConfig;
|
||||||
|
|
||||||
|
static INIT_SQL: &str = include_str!("../schema.sql");
|
||||||
|
|
||||||
|
// Simple and stupid version check for now.
|
||||||
|
// `echo -n 'blahd-database-0' | sha256sum | head -c5` || version
|
||||||
|
const APPLICATION_ID: i32 = 0xd9e_8400;
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct Database {
|
||||||
|
conn: Mutex<Connection>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Database {
|
||||||
|
pub fn open(config: &DatabaseConfig) -> Result<Self> {
|
||||||
|
let mut flags = OpenFlags::SQLITE_OPEN_READ_WRITE | OpenFlags::SQLITE_OPEN_NO_MUTEX;
|
||||||
|
if !config.path.try_exists()? {
|
||||||
|
flags.set(OpenFlags::SQLITE_OPEN_CREATE, config.create);
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut conn = Connection::open_with_flags(&config.path, flags)
|
||||||
|
.context("failed to connect database")?;
|
||||||
|
// Connection-specific pragmas.
|
||||||
|
conn.pragma_update(None, "journal_mode", "WAL")?;
|
||||||
|
conn.pragma_update(None, "foreign_keys", "TRUE")?;
|
||||||
|
|
||||||
|
if conn.query_row(r"SELECT COUNT(*) FROM sqlite_schema", params![], |row| {
|
||||||
|
row.get::<_, u64>(0)
|
||||||
|
})? != 0
|
||||||
|
{
|
||||||
|
let cur_app_id =
|
||||||
|
conn.pragma_query_value(None, "application_id", |row| row.get::<_, i32>(0))?;
|
||||||
|
ensure!(
|
||||||
|
cur_app_id == (APPLICATION_ID),
|
||||||
|
"database is non-empty with a different application_id. \
|
||||||
|
migration is not implemented yet. \
|
||||||
|
got: {cur_app_id:#x}, expect: {APPLICATION_ID:#x} \
|
||||||
|
",
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
let txn = conn.transaction()?;
|
||||||
|
txn.execute_batch(INIT_SQL)
|
||||||
|
.context("failed to initialize database")?;
|
||||||
|
txn.pragma_update(None, "application_id", APPLICATION_ID)?;
|
||||||
|
txn.commit()?;
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
conn: Mutex::new(conn),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get(&self) -> impl DerefMut<Target = Connection> + '_ {
|
||||||
|
self.conn.lock().unwrap()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn init_sql_valid() {
|
||||||
|
let conn = Connection::open_in_memory().unwrap();
|
||||||
|
conn.execute_batch(INIT_SQL).unwrap();
|
||||||
|
}
|
|
@ -115,9 +115,8 @@ pub async fn handle_ws(st: Arc<AppState>, ws: &mut WebSocket) -> Result<Infallib
|
||||||
let auth = serde_json::from_str::<WithSig<AuthPayload>>(&payload)?;
|
let auth = serde_json::from_str::<WithSig<AuthPayload>>(&payload)?;
|
||||||
st.verify_signed_data(&auth)?;
|
st.verify_signed_data(&auth)?;
|
||||||
|
|
||||||
st.conn
|
st.db
|
||||||
.lock()
|
.get()
|
||||||
.unwrap()
|
|
||||||
.query_row(
|
.query_row(
|
||||||
r"
|
r"
|
||||||
SELECT `uid`
|
SELECT `uid`
|
||||||
|
|
|
@ -16,6 +16,7 @@ use blah::types::{
|
||||||
RoomAttrs, ServerPermission, Signee, UserKey, WithSig,
|
RoomAttrs, ServerPermission, Signee, UserKey, WithSig,
|
||||||
};
|
};
|
||||||
use config::Config;
|
use config::Config;
|
||||||
|
use database::Database;
|
||||||
use ed25519_dalek::SIGNATURE_LENGTH;
|
use ed25519_dalek::SIGNATURE_LENGTH;
|
||||||
use middleware::{ApiError, OptionalAuth, SignedJson};
|
use middleware::{ApiError, OptionalAuth, SignedJson};
|
||||||
use rusqlite::{named_params, params, Connection, OptionalExtension, Row, ToSql};
|
use rusqlite::{named_params, params, Connection, OptionalExtension, Row, ToSql};
|
||||||
|
@ -27,6 +28,7 @@ use uuid::Uuid;
|
||||||
#[macro_use]
|
#[macro_use]
|
||||||
mod middleware;
|
mod middleware;
|
||||||
mod config;
|
mod config;
|
||||||
|
mod database;
|
||||||
mod event;
|
mod event;
|
||||||
mod utils;
|
mod utils;
|
||||||
|
|
||||||
|
@ -63,9 +65,7 @@ fn main() -> Result<()> {
|
||||||
match cli {
|
match cli {
|
||||||
Cli::Serve { config } => {
|
Cli::Serve { config } => {
|
||||||
let config = parse_config(&config)?;
|
let config = parse_config(&config)?;
|
||||||
let db = rusqlite::Connection::open(&config.database.path)
|
let st = AppState::init(config).context("failed to initialize state")?;
|
||||||
.context("failed to open database")?;
|
|
||||||
let st = AppState::init(config, db).context("failed to initialize state")?;
|
|
||||||
tokio::runtime::Builder::new_multi_thread()
|
tokio::runtime::Builder::new_multi_thread()
|
||||||
.enable_all()
|
.enable_all()
|
||||||
.build()
|
.build()
|
||||||
|
@ -82,7 +82,7 @@ fn main() -> Result<()> {
|
||||||
// Locks must be grabbed in the field order.
|
// Locks must be grabbed in the field order.
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
struct AppState {
|
struct AppState {
|
||||||
conn: Mutex<rusqlite::Connection>,
|
db: Database,
|
||||||
used_nonces: Mutex<ExpiringSet<u32>>,
|
used_nonces: Mutex<ExpiringSet<u32>>,
|
||||||
event: event::State,
|
event: event::State,
|
||||||
|
|
||||||
|
@ -90,13 +90,9 @@ struct AppState {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl AppState {
|
impl AppState {
|
||||||
fn init(config: Config, conn: rusqlite::Connection) -> Result<Self> {
|
fn init(config: Config) -> Result<Self> {
|
||||||
static INIT_SQL: &str = include_str!("../init.sql");
|
|
||||||
|
|
||||||
conn.execute_batch(INIT_SQL)
|
|
||||||
.context("failed to initialize database")?;
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
conn: Mutex::new(conn),
|
db: Database::open(&config.database).context("failed to open database")?,
|
||||||
used_nonces: Mutex::new(ExpiringSet::new(Duration::from_secs(
|
used_nonces: Mutex::new(ExpiringSet::new(Duration::from_secs(
|
||||||
config.server.timestamp_tolerance_secs,
|
config.server.timestamp_tolerance_secs,
|
||||||
))),
|
))),
|
||||||
|
@ -236,9 +232,8 @@ async fn room_list(
|
||||||
let query = |sql: &str, params: &[(&str, &dyn ToSql)]| -> Result<RoomList, ApiError> {
|
let query = |sql: &str, params: &[(&str, &dyn ToSql)]| -> Result<RoomList, ApiError> {
|
||||||
let mut last_rid = None;
|
let mut last_rid = None;
|
||||||
let rooms = st
|
let rooms = st
|
||||||
.conn
|
.db
|
||||||
.lock()
|
.get()
|
||||||
.unwrap()
|
|
||||||
.prepare(sql)?
|
.prepare(sql)?
|
||||||
.query_map(params, |row| {
|
.query_map(params, |row| {
|
||||||
// TODO: Extract this into a function.
|
// TODO: Extract this into a function.
|
||||||
|
@ -346,7 +341,7 @@ async fn room_create(
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut conn = st.conn.lock().unwrap();
|
let mut conn = st.db.get();
|
||||||
let Some(true) = conn
|
let Some(true) = conn
|
||||||
.query_row(
|
.query_row(
|
||||||
r"
|
r"
|
||||||
|
@ -450,7 +445,7 @@ async fn room_get_item(
|
||||||
OptionalAuth(user): OptionalAuth,
|
OptionalAuth(user): OptionalAuth,
|
||||||
) -> Result<Json<RoomItems>, ApiError> {
|
) -> Result<Json<RoomItems>, ApiError> {
|
||||||
let (items, skip_token) = {
|
let (items, skip_token) = {
|
||||||
let conn = st.conn.lock().unwrap();
|
let conn = st.db.get();
|
||||||
get_room_if_readable(&conn, ruuid, user.as_ref(), |_row| Ok(()))?;
|
get_room_if_readable(&conn, ruuid, user.as_ref(), |_row| Ok(()))?;
|
||||||
query_room_items(&st, &conn, ruuid, pagination)?
|
query_room_items(&st, &conn, ruuid, pagination)?
|
||||||
};
|
};
|
||||||
|
@ -466,13 +461,12 @@ async fn room_get_metadata(
|
||||||
WithRejection(Path(ruuid), _): WithRejection<Path<Uuid>, ApiError>,
|
WithRejection(Path(ruuid), _): WithRejection<Path<Uuid>, ApiError>,
|
||||||
OptionalAuth(user): OptionalAuth,
|
OptionalAuth(user): OptionalAuth,
|
||||||
) -> Result<Json<RoomMetadata>, ApiError> {
|
) -> Result<Json<RoomMetadata>, ApiError> {
|
||||||
let (title, attrs) =
|
let (title, attrs) = get_room_if_readable(&st.db.get(), ruuid, user.as_ref(), |row| {
|
||||||
get_room_if_readable(&st.conn.lock().unwrap(), ruuid, user.as_ref(), |row| {
|
Ok((
|
||||||
Ok((
|
row.get::<_, String>("title")?,
|
||||||
row.get::<_, String>("title")?,
|
row.get::<_, RoomAttrs>("attrs")?,
|
||||||
row.get::<_, RoomAttrs>("attrs")?,
|
))
|
||||||
))
|
})?;
|
||||||
})?;
|
|
||||||
|
|
||||||
Ok(Json(RoomMetadata {
|
Ok(Json(RoomMetadata {
|
||||||
ruuid,
|
ruuid,
|
||||||
|
@ -489,7 +483,7 @@ async fn room_get_feed(
|
||||||
) -> Result<impl IntoResponse, ApiError> {
|
) -> Result<impl IntoResponse, ApiError> {
|
||||||
let title;
|
let title;
|
||||||
let (items, skip_token) = {
|
let (items, skip_token) = {
|
||||||
let conn = st.conn.lock().unwrap();
|
let conn = st.db.get();
|
||||||
title = get_room_if_readable(&conn, ruuid, None, |row| row.get::<_, String>("title"))?;
|
title = get_room_if_readable(&conn, ruuid, None, |row| row.get::<_, String>("title"))?;
|
||||||
query_room_items(&st, &conn, ruuid, pagination)?
|
query_room_items(&st, &conn, ruuid, pagination)?
|
||||||
};
|
};
|
||||||
|
@ -690,7 +684,7 @@ async fn room_post_item(
|
||||||
}
|
}
|
||||||
|
|
||||||
let (cid, txs) = {
|
let (cid, txs) = {
|
||||||
let conn = st.conn.lock().unwrap();
|
let conn = st.db.get();
|
||||||
let Some((rid, uid)) = conn
|
let Some((rid, uid)) = conn
|
||||||
.query_row(
|
.query_row(
|
||||||
r"
|
r"
|
||||||
|
@ -820,7 +814,7 @@ async fn room_join(
|
||||||
user: UserKey,
|
user: UserKey,
|
||||||
permission: MemberPermission,
|
permission: MemberPermission,
|
||||||
) -> Result<(), ApiError> {
|
) -> Result<(), ApiError> {
|
||||||
let mut conn = st.conn.lock().unwrap();
|
let mut conn = st.db.get();
|
||||||
let txn = conn.transaction()?;
|
let txn = conn.transaction()?;
|
||||||
let Some(rid) = txn
|
let Some(rid) = txn
|
||||||
.query_row(
|
.query_row(
|
||||||
|
@ -872,7 +866,7 @@ async fn room_join(
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn room_leave(st: &AppState, ruuid: Uuid, user: UserKey) -> Result<(), ApiError> {
|
async fn room_leave(st: &AppState, ruuid: Uuid, user: UserKey) -> Result<(), ApiError> {
|
||||||
let mut conn = st.conn.lock().unwrap();
|
let mut conn = st.db.get();
|
||||||
let txn = conn.transaction()?;
|
let txn = conn.transaction()?;
|
||||||
|
|
||||||
let Some((rid, uid)) = txn
|
let Some((rid, uid)) = txn
|
||||||
|
|
Loading…
Add table
Reference in a new issue