mirror of
https://github.com/Blah-IM/blahrs.git
synced 2025-07-09 15:45:33 +00:00
Librarify blahd for testing
This commit is contained in:
parent
a92f661003
commit
4f0f1405dc
9 changed files with 288 additions and 101 deletions
53
blahd/src/bin/blahd.rs
Normal file
53
blahd/src/bin/blahd.rs
Normal file
|
@ -0,0 +1,53 @@
|
|||
use std::path::PathBuf;
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use blahd::config::Config;
|
||||
use blahd::{AppState, Database};
|
||||
|
||||
/// Blah Chat Server
|
||||
#[derive(Debug, clap::Parser)]
|
||||
#[clap(about, version = option_env!("CFG_RELEASE").unwrap_or(env!("CARGO_PKG_VERSION")))]
|
||||
enum Cli {
|
||||
/// Run the server with given configuration.
|
||||
Serve {
|
||||
/// The path to the configuration file.
|
||||
#[arg(long, short)]
|
||||
config: PathBuf,
|
||||
},
|
||||
|
||||
/// Validate the configuration file and exit.
|
||||
Validate {
|
||||
/// The path to the configuration file.
|
||||
#[arg(long, short)]
|
||||
config: PathBuf,
|
||||
},
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
tracing_subscriber::fmt::init();
|
||||
let cli = <Cli as clap::Parser>::parse();
|
||||
|
||||
fn parse_config(path: &std::path::Path) -> Result<Config> {
|
||||
let src = std::fs::read_to_string(path)?;
|
||||
let config = basic_toml::from_str::<Config>(&src)?;
|
||||
config.validate()?;
|
||||
Ok(config)
|
||||
}
|
||||
|
||||
match cli {
|
||||
Cli::Serve { config } => {
|
||||
let config = parse_config(&config)?;
|
||||
let db = Database::open(&config.database).context("failed to open database")?;
|
||||
let st = AppState::new(db, config.server);
|
||||
tokio::runtime::Builder::new_multi_thread()
|
||||
.enable_all()
|
||||
.build()
|
||||
.context("failed to initialize tokio runtime")?
|
||||
.block_on(st.serve())
|
||||
}
|
||||
Cli::Validate { config } => {
|
||||
parse_config(&config)?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
}
|
|
@ -18,6 +18,7 @@ pub struct Config {
|
|||
#[derive(Debug, Clone, Deserialize)]
|
||||
#[serde(deny_unknown_fields)]
|
||||
pub struct DatabaseConfig {
|
||||
pub in_memory: bool,
|
||||
#[serde_inline_default("/var/lib/blahd/db.sqlite".into())]
|
||||
pub path: PathBuf,
|
||||
#[serde_inline_default(true)]
|
||||
|
|
|
@ -18,14 +18,31 @@ pub struct Database {
|
|||
}
|
||||
|
||||
impl Database {
|
||||
pub fn open(config: &DatabaseConfig) -> Result<Self> {
|
||||
let mut flags = OpenFlags::SQLITE_OPEN_READ_WRITE | OpenFlags::SQLITE_OPEN_NO_MUTEX;
|
||||
if !config.path.try_exists()? {
|
||||
flags.set(OpenFlags::SQLITE_OPEN_CREATE, config.create);
|
||||
}
|
||||
/// Use an existing database connection and do no initialization or schema checking.
|
||||
/// This should only be used for testing purpose.
|
||||
pub fn from_raw(conn: Connection) -> Result<Self> {
|
||||
conn.pragma_update(None, "foreign_keys", "TRUE")?;
|
||||
Ok(Self { conn: conn.into() })
|
||||
}
|
||||
|
||||
let mut conn = Connection::open_with_flags(&config.path, flags)
|
||||
.context("failed to connect database")?;
|
||||
pub fn open(config: &DatabaseConfig) -> Result<Self> {
|
||||
let mut conn = if config.in_memory {
|
||||
Connection::open_in_memory().context("failed to open in-memory database")?
|
||||
} else {
|
||||
let mut flags = OpenFlags::SQLITE_OPEN_READ_WRITE | OpenFlags::SQLITE_OPEN_NO_MUTEX;
|
||||
if !config.path.try_exists()? {
|
||||
flags.set(OpenFlags::SQLITE_OPEN_CREATE, config.create);
|
||||
}
|
||||
Connection::open_with_flags(&config.path, flags)
|
||||
.context("failed to connect database")?
|
||||
};
|
||||
Self::maybe_init(&mut conn)?;
|
||||
Ok(Self {
|
||||
conn: Mutex::new(conn),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn maybe_init(conn: &mut Connection) -> Result<()> {
|
||||
// Connection-specific pragmas.
|
||||
conn.pragma_update(None, "journal_mode", "WAL")?;
|
||||
conn.pragma_update(None, "foreign_keys", "TRUE")?;
|
||||
|
@ -50,10 +67,7 @@ impl Database {
|
|||
.context("failed to initialize database")?;
|
||||
txn.pragma_update(None, "application_id", APPLICATION_ID)?;
|
||||
txn.commit()?;
|
||||
|
||||
Ok(Self {
|
||||
conn: Mutex::new(conn),
|
||||
})
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn get(&self) -> impl DerefMut<Target = Connection> + '_ {
|
||||
|
|
|
@ -19,6 +19,7 @@ use tokio::sync::broadcast;
|
|||
use tokio_stream::wrappers::errors::BroadcastStreamRecvError;
|
||||
use tokio_stream::wrappers::BroadcastStream;
|
||||
|
||||
use crate::config::ServerConfig;
|
||||
use crate::AppState;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
|
@ -52,14 +53,14 @@ impl std::error::Error for StreamEnded {}
|
|||
|
||||
struct WsSenderWrapper<'ws, 'c> {
|
||||
inner: SplitSink<&'ws mut WebSocket, Message>,
|
||||
config: &'c crate::config::Config,
|
||||
config: &'c ServerConfig,
|
||||
}
|
||||
|
||||
impl WsSenderWrapper<'_, '_> {
|
||||
async fn send(&mut self, msg: &Outgoing<'_>) -> Result<()> {
|
||||
let data = serde_json::to_string(&msg).expect("serialization cannot fail");
|
||||
let fut = tokio::time::timeout(
|
||||
self.config.server.ws_send_timeout_sec,
|
||||
self.config.ws_send_timeout_sec,
|
||||
self.inner.send(Message::Text(data)),
|
||||
);
|
||||
match fut.await {
|
||||
|
@ -108,7 +109,7 @@ pub async fn handle_ws(st: Arc<AppState>, ws: &mut WebSocket) -> Result<Infallib
|
|||
};
|
||||
|
||||
let uid = {
|
||||
let payload = tokio::time::timeout(st.config.server.ws_auth_timeout_sec, ws_rx.next())
|
||||
let payload = tokio::time::timeout(st.config.ws_auth_timeout_sec, ws_rx.next())
|
||||
.await
|
||||
.context("authentication timeout")?
|
||||
.ok_or(StreamEnded)??;
|
||||
|
@ -136,7 +137,7 @@ pub async fn handle_ws(st: Arc<AppState>, ws: &mut WebSocket) -> Result<Infallib
|
|||
let rx = match st.event.user_listeners.lock().entry(uid) {
|
||||
Entry::Occupied(ent) => ent.get().subscribe(),
|
||||
Entry::Vacant(ent) => {
|
||||
let (tx, rx) = broadcast::channel(st.config.server.ws_event_queue_len);
|
||||
let (tx, rx) = broadcast::channel(st.config.ws_event_queue_len);
|
||||
ent.insert(tx);
|
||||
rx
|
||||
}
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
use std::num::NonZeroUsize;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use std::time::{Duration, SystemTime};
|
||||
|
||||
|
@ -15,8 +14,7 @@ use blah::types::{
|
|||
ChatItem, ChatPayload, CreateRoomPayload, Id, MemberPermission, RoomAdminOp, RoomAdminPayload,
|
||||
RoomAttrs, ServerPermission, Signee, UserKey, WithItemId, WithSig,
|
||||
};
|
||||
use config::Config;
|
||||
use database::Database;
|
||||
use config::ServerConfig;
|
||||
use ed25519_dalek::SIGNATURE_LENGTH;
|
||||
use id::IdExt;
|
||||
use middleware::{ApiError, Auth, MaybeAuth, ResultExt as _, SignedJson};
|
||||
|
@ -28,80 +26,39 @@ use utils::ExpiringSet;
|
|||
|
||||
#[macro_use]
|
||||
mod middleware;
|
||||
mod config;
|
||||
pub mod config;
|
||||
mod database;
|
||||
mod event;
|
||||
mod id;
|
||||
mod utils;
|
||||
|
||||
/// Blah Chat Server
|
||||
#[derive(Debug, clap::Parser)]
|
||||
#[clap(about, version = option_env!("CFG_RELEASE").unwrap_or(env!("CARGO_PKG_VERSION")))]
|
||||
enum Cli {
|
||||
/// Run the server with given configuration.
|
||||
Serve {
|
||||
/// The path to the configuration file.
|
||||
#[arg(long, short)]
|
||||
config: PathBuf,
|
||||
},
|
||||
|
||||
/// Validate the configuration file and exit.
|
||||
Validate {
|
||||
/// The path to the configuration file.
|
||||
#[arg(long, short)]
|
||||
config: PathBuf,
|
||||
},
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
tracing_subscriber::fmt::init();
|
||||
let cli = <Cli as clap::Parser>::parse();
|
||||
|
||||
fn parse_config(path: &std::path::Path) -> Result<Config> {
|
||||
let src = std::fs::read_to_string(path)?;
|
||||
let config = basic_toml::from_str::<Config>(&src)?;
|
||||
config.validate()?;
|
||||
Ok(config)
|
||||
}
|
||||
|
||||
match cli {
|
||||
Cli::Serve { config } => {
|
||||
let config = parse_config(&config)?;
|
||||
let st = AppState::init(config).context("failed to initialize state")?;
|
||||
tokio::runtime::Builder::new_multi_thread()
|
||||
.enable_all()
|
||||
.build()
|
||||
.context("failed to initialize tokio runtime")?
|
||||
.block_on(main_async(st))
|
||||
}
|
||||
Cli::Validate { config } => {
|
||||
parse_config(&config)?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
}
|
||||
pub use database::Database;
|
||||
|
||||
// Locks must be grabbed in the field order.
|
||||
#[derive(Debug)]
|
||||
struct AppState {
|
||||
pub struct AppState {
|
||||
db: Database,
|
||||
used_nonces: Mutex<ExpiringSet<u32>>,
|
||||
event: event::State,
|
||||
|
||||
config: Config,
|
||||
config: ServerConfig,
|
||||
}
|
||||
|
||||
impl AppState {
|
||||
fn init(config: Config) -> Result<Self> {
|
||||
Ok(Self {
|
||||
db: Database::open(&config.database).context("failed to open database")?,
|
||||
pub fn new(db: Database, config: ServerConfig) -> Self {
|
||||
Self {
|
||||
db,
|
||||
used_nonces: Mutex::new(ExpiringSet::new(Duration::from_secs(
|
||||
config.server.timestamp_tolerance_secs,
|
||||
config.timestamp_tolerance_secs,
|
||||
))),
|
||||
event: event::State::default(),
|
||||
|
||||
config,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn serve(self) -> Result<()> {
|
||||
serve(Arc::new(self)).await
|
||||
}
|
||||
|
||||
fn verify_signed_data<T: Serialize>(&self, data: &WithSig<T>) -> Result<(), ApiError> {
|
||||
|
@ -117,7 +74,7 @@ impl AppState {
|
|||
.expect("after UNIX epoch")
|
||||
.as_secs()
|
||||
.abs_diff(data.signee.timestamp);
|
||||
if timestamp_diff > self.config.server.timestamp_tolerance_secs {
|
||||
if timestamp_diff > self.config.timestamp_tolerance_secs {
|
||||
return Err(error_response!(
|
||||
StatusCode::BAD_REQUEST,
|
||||
"invalid_timestamp",
|
||||
|
@ -137,10 +94,22 @@ impl AppState {
|
|||
|
||||
type ArcState = State<Arc<AppState>>;
|
||||
|
||||
async fn main_async(st: AppState) -> Result<()> {
|
||||
let st = Arc::new(st);
|
||||
async fn serve(st: Arc<AppState>) -> Result<()> {
|
||||
let listener = tokio::net::TcpListener::bind(&st.config.listen)
|
||||
.await
|
||||
.context("failed to listen on socket")?;
|
||||
tracing::info!("listening on {}", st.config.listen);
|
||||
let router = router(st.clone());
|
||||
let _ = sd_notify::notify(true, &[sd_notify::NotifyState::Ready]);
|
||||
|
||||
let app = Router::new()
|
||||
axum::serve(listener, router)
|
||||
.await
|
||||
.context("failed to serve")?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn router(st: Arc<AppState>) -> Router {
|
||||
Router::new()
|
||||
.route("/ws", get(handle_ws))
|
||||
.route("/room", get(room_list))
|
||||
.route("/room/create", post(room_create))
|
||||
|
@ -150,27 +119,16 @@ async fn main_async(st: AppState) -> Result<()> {
|
|||
.route("/room/:rid/item", get(room_item_list).post(room_item_post))
|
||||
.route("/room/:rid/item/:cid/seen", post(room_item_mark_seen))
|
||||
.route("/room/:rid/admin", post(room_admin))
|
||||
.with_state(st.clone())
|
||||
.layer(tower_http::limit::RequestBodyLimitLayer::new(
|
||||
st.config.server.max_request_len,
|
||||
st.config.max_request_len,
|
||||
))
|
||||
// NB. This comes at last (outmost layer), so inner errors will still be wrapped with
|
||||
// correct CORS headers. Also `Authorization` must be explicitly included besides `*`.
|
||||
.layer(
|
||||
tower_http::cors::CorsLayer::permissive()
|
||||
.allow_headers([header::HeaderName::from_static("*"), header::AUTHORIZATION]),
|
||||
);
|
||||
|
||||
let listener = tokio::net::TcpListener::bind(&st.config.server.listen)
|
||||
.await
|
||||
.context("failed to listen on socket")?;
|
||||
tracing::info!("listening on {}", st.config.server.listen);
|
||||
let _ = sd_notify::notify(true, &[sd_notify::NotifyState::Ready]);
|
||||
|
||||
axum::serve(listener, app)
|
||||
.await
|
||||
.context("failed to serve")?;
|
||||
Ok(())
|
||||
)
|
||||
.with_state(st)
|
||||
}
|
||||
|
||||
type RE<T> = R<T, ApiError>;
|
||||
|
@ -193,11 +151,11 @@ async fn handle_ws(State(st): ArcState, ws: WebSocketUpgrade) -> Response {
|
|||
})
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct RoomList {
|
||||
rooms: Vec<RoomMetadata>,
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub struct RoomList {
|
||||
pub rooms: Vec<RoomMetadata>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
skip_token: Option<Id>,
|
||||
pub skip_token: Option<Id>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
|
@ -463,16 +421,16 @@ impl Pagination {
|
|||
fn effective_page_len(&self, st: &AppState) -> usize {
|
||||
self.top
|
||||
.unwrap_or(usize::MAX.try_into().expect("not zero"))
|
||||
.min(st.config.server.max_page_len)
|
||||
.min(st.config.max_page_len)
|
||||
.get()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct RoomItems {
|
||||
items: Vec<WithItemId<ChatItem>>,
|
||||
pub struct RoomItems {
|
||||
pub items: Vec<WithItemId<ChatItem>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
skip_token: Option<Id>,
|
||||
pub skip_token: Option<Id>,
|
||||
}
|
||||
|
||||
async fn room_item_list(
|
||||
|
@ -548,7 +506,6 @@ async fn room_get_feed(
|
|||
|
||||
let feed_url = st
|
||||
.config
|
||||
.server
|
||||
.base_url
|
||||
.join(&format!("/room/{rid}/feed.json"))
|
||||
.expect("base_url must be valid");
|
||||
|
@ -616,7 +573,7 @@ struct FeedItemExtra {
|
|||
sig: [u8; SIGNATURE_LENGTH],
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub struct RoomMetadata {
|
||||
pub rid: Id,
|
||||
pub title: String,
|
Loading…
Add table
Add a link
Reference in a new issue