Librarify blahd for testing

This commit is contained in:
oxalica 2024-09-09 00:30:15 -04:00
parent a92f661003
commit 4f0f1405dc
9 changed files with 288 additions and 101 deletions

53
blahd/src/bin/blahd.rs Normal file
View file

@ -0,0 +1,53 @@
use std::path::PathBuf;
use anyhow::{Context, Result};
use blahd::config::Config;
use blahd::{AppState, Database};
/// Blah Chat Server
#[derive(Debug, clap::Parser)]
#[clap(about, version = option_env!("CFG_RELEASE").unwrap_or(env!("CARGO_PKG_VERSION")))]
enum Cli {
/// Run the server with given configuration.
Serve {
/// The path to the configuration file.
#[arg(long, short)]
config: PathBuf,
},
/// Validate the configuration file and exit.
Validate {
/// The path to the configuration file.
#[arg(long, short)]
config: PathBuf,
},
}
fn main() -> Result<()> {
tracing_subscriber::fmt::init();
let cli = <Cli as clap::Parser>::parse();
fn parse_config(path: &std::path::Path) -> Result<Config> {
let src = std::fs::read_to_string(path)?;
let config = basic_toml::from_str::<Config>(&src)?;
config.validate()?;
Ok(config)
}
match cli {
Cli::Serve { config } => {
let config = parse_config(&config)?;
let db = Database::open(&config.database).context("failed to open database")?;
let st = AppState::new(db, config.server);
tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.context("failed to initialize tokio runtime")?
.block_on(st.serve())
}
Cli::Validate { config } => {
parse_config(&config)?;
Ok(())
}
}
}

View file

@ -18,6 +18,7 @@ pub struct Config {
#[derive(Debug, Clone, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct DatabaseConfig {
pub in_memory: bool,
#[serde_inline_default("/var/lib/blahd/db.sqlite".into())]
pub path: PathBuf,
#[serde_inline_default(true)]

View file

@ -18,14 +18,31 @@ pub struct Database {
}
impl Database {
pub fn open(config: &DatabaseConfig) -> Result<Self> {
let mut flags = OpenFlags::SQLITE_OPEN_READ_WRITE | OpenFlags::SQLITE_OPEN_NO_MUTEX;
if !config.path.try_exists()? {
flags.set(OpenFlags::SQLITE_OPEN_CREATE, config.create);
}
/// Use an existing database connection and do no initialization or schema checking.
/// This should only be used for testing purpose.
pub fn from_raw(conn: Connection) -> Result<Self> {
conn.pragma_update(None, "foreign_keys", "TRUE")?;
Ok(Self { conn: conn.into() })
}
let mut conn = Connection::open_with_flags(&config.path, flags)
.context("failed to connect database")?;
pub fn open(config: &DatabaseConfig) -> Result<Self> {
let mut conn = if config.in_memory {
Connection::open_in_memory().context("failed to open in-memory database")?
} else {
let mut flags = OpenFlags::SQLITE_OPEN_READ_WRITE | OpenFlags::SQLITE_OPEN_NO_MUTEX;
if !config.path.try_exists()? {
flags.set(OpenFlags::SQLITE_OPEN_CREATE, config.create);
}
Connection::open_with_flags(&config.path, flags)
.context("failed to connect database")?
};
Self::maybe_init(&mut conn)?;
Ok(Self {
conn: Mutex::new(conn),
})
}
pub fn maybe_init(conn: &mut Connection) -> Result<()> {
// Connection-specific pragmas.
conn.pragma_update(None, "journal_mode", "WAL")?;
conn.pragma_update(None, "foreign_keys", "TRUE")?;
@ -50,10 +67,7 @@ impl Database {
.context("failed to initialize database")?;
txn.pragma_update(None, "application_id", APPLICATION_ID)?;
txn.commit()?;
Ok(Self {
conn: Mutex::new(conn),
})
Ok(())
}
pub fn get(&self) -> impl DerefMut<Target = Connection> + '_ {

View file

@ -19,6 +19,7 @@ use tokio::sync::broadcast;
use tokio_stream::wrappers::errors::BroadcastStreamRecvError;
use tokio_stream::wrappers::BroadcastStream;
use crate::config::ServerConfig;
use crate::AppState;
#[derive(Debug, Deserialize)]
@ -52,14 +53,14 @@ impl std::error::Error for StreamEnded {}
struct WsSenderWrapper<'ws, 'c> {
inner: SplitSink<&'ws mut WebSocket, Message>,
config: &'c crate::config::Config,
config: &'c ServerConfig,
}
impl WsSenderWrapper<'_, '_> {
async fn send(&mut self, msg: &Outgoing<'_>) -> Result<()> {
let data = serde_json::to_string(&msg).expect("serialization cannot fail");
let fut = tokio::time::timeout(
self.config.server.ws_send_timeout_sec,
self.config.ws_send_timeout_sec,
self.inner.send(Message::Text(data)),
);
match fut.await {
@ -108,7 +109,7 @@ pub async fn handle_ws(st: Arc<AppState>, ws: &mut WebSocket) -> Result<Infallib
};
let uid = {
let payload = tokio::time::timeout(st.config.server.ws_auth_timeout_sec, ws_rx.next())
let payload = tokio::time::timeout(st.config.ws_auth_timeout_sec, ws_rx.next())
.await
.context("authentication timeout")?
.ok_or(StreamEnded)??;
@ -136,7 +137,7 @@ pub async fn handle_ws(st: Arc<AppState>, ws: &mut WebSocket) -> Result<Infallib
let rx = match st.event.user_listeners.lock().entry(uid) {
Entry::Occupied(ent) => ent.get().subscribe(),
Entry::Vacant(ent) => {
let (tx, rx) = broadcast::channel(st.config.server.ws_event_queue_len);
let (tx, rx) = broadcast::channel(st.config.ws_event_queue_len);
ent.insert(tx);
rx
}

View file

@ -1,5 +1,4 @@
use std::num::NonZeroUsize;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::{Duration, SystemTime};
@ -15,8 +14,7 @@ use blah::types::{
ChatItem, ChatPayload, CreateRoomPayload, Id, MemberPermission, RoomAdminOp, RoomAdminPayload,
RoomAttrs, ServerPermission, Signee, UserKey, WithItemId, WithSig,
};
use config::Config;
use database::Database;
use config::ServerConfig;
use ed25519_dalek::SIGNATURE_LENGTH;
use id::IdExt;
use middleware::{ApiError, Auth, MaybeAuth, ResultExt as _, SignedJson};
@ -28,80 +26,39 @@ use utils::ExpiringSet;
#[macro_use]
mod middleware;
mod config;
pub mod config;
mod database;
mod event;
mod id;
mod utils;
/// Blah Chat Server
#[derive(Debug, clap::Parser)]
#[clap(about, version = option_env!("CFG_RELEASE").unwrap_or(env!("CARGO_PKG_VERSION")))]
enum Cli {
/// Run the server with given configuration.
Serve {
/// The path to the configuration file.
#[arg(long, short)]
config: PathBuf,
},
/// Validate the configuration file and exit.
Validate {
/// The path to the configuration file.
#[arg(long, short)]
config: PathBuf,
},
}
fn main() -> Result<()> {
tracing_subscriber::fmt::init();
let cli = <Cli as clap::Parser>::parse();
fn parse_config(path: &std::path::Path) -> Result<Config> {
let src = std::fs::read_to_string(path)?;
let config = basic_toml::from_str::<Config>(&src)?;
config.validate()?;
Ok(config)
}
match cli {
Cli::Serve { config } => {
let config = parse_config(&config)?;
let st = AppState::init(config).context("failed to initialize state")?;
tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.context("failed to initialize tokio runtime")?
.block_on(main_async(st))
}
Cli::Validate { config } => {
parse_config(&config)?;
Ok(())
}
}
}
pub use database::Database;
// Locks must be grabbed in the field order.
#[derive(Debug)]
struct AppState {
pub struct AppState {
db: Database,
used_nonces: Mutex<ExpiringSet<u32>>,
event: event::State,
config: Config,
config: ServerConfig,
}
impl AppState {
fn init(config: Config) -> Result<Self> {
Ok(Self {
db: Database::open(&config.database).context("failed to open database")?,
pub fn new(db: Database, config: ServerConfig) -> Self {
Self {
db,
used_nonces: Mutex::new(ExpiringSet::new(Duration::from_secs(
config.server.timestamp_tolerance_secs,
config.timestamp_tolerance_secs,
))),
event: event::State::default(),
config,
})
}
}
pub async fn serve(self) -> Result<()> {
serve(Arc::new(self)).await
}
fn verify_signed_data<T: Serialize>(&self, data: &WithSig<T>) -> Result<(), ApiError> {
@ -117,7 +74,7 @@ impl AppState {
.expect("after UNIX epoch")
.as_secs()
.abs_diff(data.signee.timestamp);
if timestamp_diff > self.config.server.timestamp_tolerance_secs {
if timestamp_diff > self.config.timestamp_tolerance_secs {
return Err(error_response!(
StatusCode::BAD_REQUEST,
"invalid_timestamp",
@ -137,10 +94,22 @@ impl AppState {
type ArcState = State<Arc<AppState>>;
async fn main_async(st: AppState) -> Result<()> {
let st = Arc::new(st);
async fn serve(st: Arc<AppState>) -> Result<()> {
let listener = tokio::net::TcpListener::bind(&st.config.listen)
.await
.context("failed to listen on socket")?;
tracing::info!("listening on {}", st.config.listen);
let router = router(st.clone());
let _ = sd_notify::notify(true, &[sd_notify::NotifyState::Ready]);
let app = Router::new()
axum::serve(listener, router)
.await
.context("failed to serve")?;
Ok(())
}
pub fn router(st: Arc<AppState>) -> Router {
Router::new()
.route("/ws", get(handle_ws))
.route("/room", get(room_list))
.route("/room/create", post(room_create))
@ -150,27 +119,16 @@ async fn main_async(st: AppState) -> Result<()> {
.route("/room/:rid/item", get(room_item_list).post(room_item_post))
.route("/room/:rid/item/:cid/seen", post(room_item_mark_seen))
.route("/room/:rid/admin", post(room_admin))
.with_state(st.clone())
.layer(tower_http::limit::RequestBodyLimitLayer::new(
st.config.server.max_request_len,
st.config.max_request_len,
))
// NB. This comes at last (outmost layer), so inner errors will still be wrapped with
// correct CORS headers. Also `Authorization` must be explicitly included besides `*`.
.layer(
tower_http::cors::CorsLayer::permissive()
.allow_headers([header::HeaderName::from_static("*"), header::AUTHORIZATION]),
);
let listener = tokio::net::TcpListener::bind(&st.config.server.listen)
.await
.context("failed to listen on socket")?;
tracing::info!("listening on {}", st.config.server.listen);
let _ = sd_notify::notify(true, &[sd_notify::NotifyState::Ready]);
axum::serve(listener, app)
.await
.context("failed to serve")?;
Ok(())
)
.with_state(st)
}
type RE<T> = R<T, ApiError>;
@ -193,11 +151,11 @@ async fn handle_ws(State(st): ArcState, ws: WebSocketUpgrade) -> Response {
})
}
#[derive(Debug, Serialize)]
struct RoomList {
rooms: Vec<RoomMetadata>,
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct RoomList {
pub rooms: Vec<RoomMetadata>,
#[serde(skip_serializing_if = "Option::is_none")]
skip_token: Option<Id>,
pub skip_token: Option<Id>,
}
#[derive(Debug, Deserialize)]
@ -463,16 +421,16 @@ impl Pagination {
fn effective_page_len(&self, st: &AppState) -> usize {
self.top
.unwrap_or(usize::MAX.try_into().expect("not zero"))
.min(st.config.server.max_page_len)
.min(st.config.max_page_len)
.get()
}
}
#[derive(Debug, Serialize)]
struct RoomItems {
items: Vec<WithItemId<ChatItem>>,
pub struct RoomItems {
pub items: Vec<WithItemId<ChatItem>>,
#[serde(skip_serializing_if = "Option::is_none")]
skip_token: Option<Id>,
pub skip_token: Option<Id>,
}
async fn room_item_list(
@ -548,7 +506,6 @@ async fn room_get_feed(
let feed_url = st
.config
.server
.base_url
.join(&format!("/room/{rid}/feed.json"))
.expect("base_url must be valid");
@ -616,7 +573,7 @@ struct FeedItemExtra {
sig: [u8; SIGNATURE_LENGTH],
}
#[derive(Debug, Serialize)]
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct RoomMetadata {
pub rid: Id,
pub title: String,