Librarify blahd for testing

This commit is contained in:
oxalica 2024-09-09 00:30:15 -04:00
parent a92f661003
commit 4f0f1405dc
9 changed files with 288 additions and 101 deletions

79
Cargo.lock generated
View file

@ -29,6 +29,15 @@ dependencies = [
"zerocopy",
]
[[package]]
name = "aho-corasick"
version = "1.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8e60d3430d3a69478ad0993f19238d2df97c507009a52b3c10addcd7f6bcb916"
dependencies = [
"memchr",
]
[[package]]
name = "android-tzdata"
version = "0.1.1"
@ -308,6 +317,8 @@ dependencies = [
"hex",
"humantime",
"parking_lot",
"reqwest",
"rstest",
"rusqlite",
"sd-notify",
"serde",
@ -763,6 +774,12 @@ version = "0.29.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "40ecd4077b5ae9fd2e9e169b102c6c330d0605168eb0e8bf79952b256dbefffd"
[[package]]
name = "glob"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b"
[[package]]
name = "h2"
version = "0.4.6"
@ -1398,6 +1415,41 @@ dependencies = [
"bitflags",
]
[[package]]
name = "regex"
version = "1.10.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4219d74c6b67a3654a9fbebc4b419e22126d13d2f3c4a07ee0cb61ff79a79619"
dependencies = [
"aho-corasick",
"memchr",
"regex-automata",
"regex-syntax",
]
[[package]]
name = "regex-automata"
version = "0.4.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "38caf58cc5ef2fed281f89292ef23f6365465ed9a41b7a7754eb4e26496c92df"
dependencies = [
"aho-corasick",
"memchr",
"regex-syntax",
]
[[package]]
name = "regex-syntax"
version = "0.8.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7a66a03ae7c801facd77a29370b4faec201768915ac14a721ba36f20bc9c209b"
[[package]]
name = "relative-path"
version = "1.9.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ba39f3699c378cd8970968dcbff9c43159ea4cfbd88d43c00b22f2ef10a435d2"
[[package]]
name = "reqwest"
version = "0.12.7"
@ -1456,6 +1508,33 @@ dependencies = [
"windows-sys 0.52.0",
]
[[package]]
name = "rstest"
version = "0.22.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7b423f0e62bdd61734b67cd21ff50871dfaeb9cc74f869dcd6af974fbcb19936"
dependencies = [
"rstest_macros",
"rustc_version",
]
[[package]]
name = "rstest_macros"
version = "0.22.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c5e1711e7d14f74b12a58411c542185ef7fb7f2e7f8ee6e2940a883628522b42"
dependencies = [
"cfg-if",
"glob",
"proc-macro2",
"quote",
"regex",
"relative-path",
"rustc_version",
"syn",
"unicode-ident",
]
[[package]]
name = "rusqlite"
version = "0.32.1"

View file

@ -29,5 +29,9 @@ url = { version = "2.5.2", features = ["serde"] }
blah = { path = "..", features = ["rusqlite"] }
[dev-dependencies]
reqwest = { version = "0.12.7", features = ["json"] }
rstest = { version = "0.22.0", default-features = false }
[lints]
workspace = true

View file

@ -3,7 +3,10 @@
# the default value.
[database]
# (Required)
# If enabled, a in-memory non-persistent database is used instead. Options
# `path` and `create` are ignored. This should only be used for testing.
in_memory = false
# The path to the main SQLite database.
# The file will be created and initialized if not exist, but missing directory
# will not.

53
blahd/src/bin/blahd.rs Normal file
View file

@ -0,0 +1,53 @@
use std::path::PathBuf;
use anyhow::{Context, Result};
use blahd::config::Config;
use blahd::{AppState, Database};
/// Blah Chat Server
#[derive(Debug, clap::Parser)]
#[clap(about, version = option_env!("CFG_RELEASE").unwrap_or(env!("CARGO_PKG_VERSION")))]
enum Cli {
/// Run the server with given configuration.
Serve {
/// The path to the configuration file.
#[arg(long, short)]
config: PathBuf,
},
/// Validate the configuration file and exit.
Validate {
/// The path to the configuration file.
#[arg(long, short)]
config: PathBuf,
},
}
fn main() -> Result<()> {
tracing_subscriber::fmt::init();
let cli = <Cli as clap::Parser>::parse();
fn parse_config(path: &std::path::Path) -> Result<Config> {
let src = std::fs::read_to_string(path)?;
let config = basic_toml::from_str::<Config>(&src)?;
config.validate()?;
Ok(config)
}
match cli {
Cli::Serve { config } => {
let config = parse_config(&config)?;
let db = Database::open(&config.database).context("failed to open database")?;
let st = AppState::new(db, config.server);
tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.context("failed to initialize tokio runtime")?
.block_on(st.serve())
}
Cli::Validate { config } => {
parse_config(&config)?;
Ok(())
}
}
}

View file

@ -18,6 +18,7 @@ pub struct Config {
#[derive(Debug, Clone, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct DatabaseConfig {
pub in_memory: bool,
#[serde_inline_default("/var/lib/blahd/db.sqlite".into())]
pub path: PathBuf,
#[serde_inline_default(true)]

View file

@ -18,14 +18,31 @@ pub struct Database {
}
impl Database {
pub fn open(config: &DatabaseConfig) -> Result<Self> {
let mut flags = OpenFlags::SQLITE_OPEN_READ_WRITE | OpenFlags::SQLITE_OPEN_NO_MUTEX;
if !config.path.try_exists()? {
flags.set(OpenFlags::SQLITE_OPEN_CREATE, config.create);
}
/// Use an existing database connection and do no initialization or schema checking.
/// This should only be used for testing purpose.
pub fn from_raw(conn: Connection) -> Result<Self> {
conn.pragma_update(None, "foreign_keys", "TRUE")?;
Ok(Self { conn: conn.into() })
}
let mut conn = Connection::open_with_flags(&config.path, flags)
.context("failed to connect database")?;
pub fn open(config: &DatabaseConfig) -> Result<Self> {
let mut conn = if config.in_memory {
Connection::open_in_memory().context("failed to open in-memory database")?
} else {
let mut flags = OpenFlags::SQLITE_OPEN_READ_WRITE | OpenFlags::SQLITE_OPEN_NO_MUTEX;
if !config.path.try_exists()? {
flags.set(OpenFlags::SQLITE_OPEN_CREATE, config.create);
}
Connection::open_with_flags(&config.path, flags)
.context("failed to connect database")?
};
Self::maybe_init(&mut conn)?;
Ok(Self {
conn: Mutex::new(conn),
})
}
pub fn maybe_init(conn: &mut Connection) -> Result<()> {
// Connection-specific pragmas.
conn.pragma_update(None, "journal_mode", "WAL")?;
conn.pragma_update(None, "foreign_keys", "TRUE")?;
@ -50,10 +67,7 @@ impl Database {
.context("failed to initialize database")?;
txn.pragma_update(None, "application_id", APPLICATION_ID)?;
txn.commit()?;
Ok(Self {
conn: Mutex::new(conn),
})
Ok(())
}
pub fn get(&self) -> impl DerefMut<Target = Connection> + '_ {

View file

@ -19,6 +19,7 @@ use tokio::sync::broadcast;
use tokio_stream::wrappers::errors::BroadcastStreamRecvError;
use tokio_stream::wrappers::BroadcastStream;
use crate::config::ServerConfig;
use crate::AppState;
#[derive(Debug, Deserialize)]
@ -52,14 +53,14 @@ impl std::error::Error for StreamEnded {}
struct WsSenderWrapper<'ws, 'c> {
inner: SplitSink<&'ws mut WebSocket, Message>,
config: &'c crate::config::Config,
config: &'c ServerConfig,
}
impl WsSenderWrapper<'_, '_> {
async fn send(&mut self, msg: &Outgoing<'_>) -> Result<()> {
let data = serde_json::to_string(&msg).expect("serialization cannot fail");
let fut = tokio::time::timeout(
self.config.server.ws_send_timeout_sec,
self.config.ws_send_timeout_sec,
self.inner.send(Message::Text(data)),
);
match fut.await {
@ -108,7 +109,7 @@ pub async fn handle_ws(st: Arc<AppState>, ws: &mut WebSocket) -> Result<Infallib
};
let uid = {
let payload = tokio::time::timeout(st.config.server.ws_auth_timeout_sec, ws_rx.next())
let payload = tokio::time::timeout(st.config.ws_auth_timeout_sec, ws_rx.next())
.await
.context("authentication timeout")?
.ok_or(StreamEnded)??;
@ -136,7 +137,7 @@ pub async fn handle_ws(st: Arc<AppState>, ws: &mut WebSocket) -> Result<Infallib
let rx = match st.event.user_listeners.lock().entry(uid) {
Entry::Occupied(ent) => ent.get().subscribe(),
Entry::Vacant(ent) => {
let (tx, rx) = broadcast::channel(st.config.server.ws_event_queue_len);
let (tx, rx) = broadcast::channel(st.config.ws_event_queue_len);
ent.insert(tx);
rx
}

View file

@ -1,5 +1,4 @@
use std::num::NonZeroUsize;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::{Duration, SystemTime};
@ -15,8 +14,7 @@ use blah::types::{
ChatItem, ChatPayload, CreateRoomPayload, Id, MemberPermission, RoomAdminOp, RoomAdminPayload,
RoomAttrs, ServerPermission, Signee, UserKey, WithItemId, WithSig,
};
use config::Config;
use database::Database;
use config::ServerConfig;
use ed25519_dalek::SIGNATURE_LENGTH;
use id::IdExt;
use middleware::{ApiError, Auth, MaybeAuth, ResultExt as _, SignedJson};
@ -28,80 +26,39 @@ use utils::ExpiringSet;
#[macro_use]
mod middleware;
mod config;
pub mod config;
mod database;
mod event;
mod id;
mod utils;
/// Blah Chat Server
#[derive(Debug, clap::Parser)]
#[clap(about, version = option_env!("CFG_RELEASE").unwrap_or(env!("CARGO_PKG_VERSION")))]
enum Cli {
/// Run the server with given configuration.
Serve {
/// The path to the configuration file.
#[arg(long, short)]
config: PathBuf,
},
/// Validate the configuration file and exit.
Validate {
/// The path to the configuration file.
#[arg(long, short)]
config: PathBuf,
},
}
fn main() -> Result<()> {
tracing_subscriber::fmt::init();
let cli = <Cli as clap::Parser>::parse();
fn parse_config(path: &std::path::Path) -> Result<Config> {
let src = std::fs::read_to_string(path)?;
let config = basic_toml::from_str::<Config>(&src)?;
config.validate()?;
Ok(config)
}
match cli {
Cli::Serve { config } => {
let config = parse_config(&config)?;
let st = AppState::init(config).context("failed to initialize state")?;
tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.context("failed to initialize tokio runtime")?
.block_on(main_async(st))
}
Cli::Validate { config } => {
parse_config(&config)?;
Ok(())
}
}
}
pub use database::Database;
// Locks must be grabbed in the field order.
#[derive(Debug)]
struct AppState {
pub struct AppState {
db: Database,
used_nonces: Mutex<ExpiringSet<u32>>,
event: event::State,
config: Config,
config: ServerConfig,
}
impl AppState {
fn init(config: Config) -> Result<Self> {
Ok(Self {
db: Database::open(&config.database).context("failed to open database")?,
pub fn new(db: Database, config: ServerConfig) -> Self {
Self {
db,
used_nonces: Mutex::new(ExpiringSet::new(Duration::from_secs(
config.server.timestamp_tolerance_secs,
config.timestamp_tolerance_secs,
))),
event: event::State::default(),
config,
})
}
}
pub async fn serve(self) -> Result<()> {
serve(Arc::new(self)).await
}
fn verify_signed_data<T: Serialize>(&self, data: &WithSig<T>) -> Result<(), ApiError> {
@ -117,7 +74,7 @@ impl AppState {
.expect("after UNIX epoch")
.as_secs()
.abs_diff(data.signee.timestamp);
if timestamp_diff > self.config.server.timestamp_tolerance_secs {
if timestamp_diff > self.config.timestamp_tolerance_secs {
return Err(error_response!(
StatusCode::BAD_REQUEST,
"invalid_timestamp",
@ -137,10 +94,22 @@ impl AppState {
type ArcState = State<Arc<AppState>>;
async fn main_async(st: AppState) -> Result<()> {
let st = Arc::new(st);
async fn serve(st: Arc<AppState>) -> Result<()> {
let listener = tokio::net::TcpListener::bind(&st.config.listen)
.await
.context("failed to listen on socket")?;
tracing::info!("listening on {}", st.config.listen);
let router = router(st.clone());
let _ = sd_notify::notify(true, &[sd_notify::NotifyState::Ready]);
let app = Router::new()
axum::serve(listener, router)
.await
.context("failed to serve")?;
Ok(())
}
pub fn router(st: Arc<AppState>) -> Router {
Router::new()
.route("/ws", get(handle_ws))
.route("/room", get(room_list))
.route("/room/create", post(room_create))
@ -150,27 +119,16 @@ async fn main_async(st: AppState) -> Result<()> {
.route("/room/:rid/item", get(room_item_list).post(room_item_post))
.route("/room/:rid/item/:cid/seen", post(room_item_mark_seen))
.route("/room/:rid/admin", post(room_admin))
.with_state(st.clone())
.layer(tower_http::limit::RequestBodyLimitLayer::new(
st.config.server.max_request_len,
st.config.max_request_len,
))
// NB. This comes at last (outmost layer), so inner errors will still be wrapped with
// correct CORS headers. Also `Authorization` must be explicitly included besides `*`.
.layer(
tower_http::cors::CorsLayer::permissive()
.allow_headers([header::HeaderName::from_static("*"), header::AUTHORIZATION]),
);
let listener = tokio::net::TcpListener::bind(&st.config.server.listen)
.await
.context("failed to listen on socket")?;
tracing::info!("listening on {}", st.config.server.listen);
let _ = sd_notify::notify(true, &[sd_notify::NotifyState::Ready]);
axum::serve(listener, app)
.await
.context("failed to serve")?;
Ok(())
)
.with_state(st)
}
type RE<T> = R<T, ApiError>;
@ -193,11 +151,11 @@ async fn handle_ws(State(st): ArcState, ws: WebSocketUpgrade) -> Response {
})
}
#[derive(Debug, Serialize)]
struct RoomList {
rooms: Vec<RoomMetadata>,
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct RoomList {
pub rooms: Vec<RoomMetadata>,
#[serde(skip_serializing_if = "Option::is_none")]
skip_token: Option<Id>,
pub skip_token: Option<Id>,
}
#[derive(Debug, Deserialize)]
@ -463,16 +421,16 @@ impl Pagination {
fn effective_page_len(&self, st: &AppState) -> usize {
self.top
.unwrap_or(usize::MAX.try_into().expect("not zero"))
.min(st.config.server.max_page_len)
.min(st.config.max_page_len)
.get()
}
}
#[derive(Debug, Serialize)]
struct RoomItems {
items: Vec<WithItemId<ChatItem>>,
pub struct RoomItems {
pub items: Vec<WithItemId<ChatItem>>,
#[serde(skip_serializing_if = "Option::is_none")]
skip_token: Option<Id>,
pub skip_token: Option<Id>,
}
async fn room_item_list(
@ -548,7 +506,6 @@ async fn room_get_feed(
let feed_url = st
.config
.server
.base_url
.join(&format!("/room/{rid}/feed.json"))
.expect("base_url must be valid");
@ -616,7 +573,7 @@ struct FeedItemExtra {
sig: [u8; SIGNATURE_LENGTH],
}
#[derive(Debug, Serialize)]
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct RoomMetadata {
pub rid: Id,
pub title: String,

75
blahd/tests/basic.rs Normal file
View file

@ -0,0 +1,75 @@
// FIXME: False positive?
#![allow(clippy::unwrap_used)]
use std::fmt;
use std::future::IntoFuture;
use std::sync::Arc;
use blahd::{AppState, Database, RoomList};
use reqwest::Client;
use rstest::{fixture, rstest};
use rusqlite::Connection;
use tokio::net::TcpListener;
// Avoid name resolution.
const LOCALHOST: &str = "127.0.0.1";
#[derive(Debug)]
struct Server {
port: u16,
}
impl Server {
fn url(&self, rhs: impl fmt::Display) -> String {
format!("http://{}:{}{}", LOCALHOST, self.port, rhs)
}
}
#[fixture]
fn server() -> Server {
let mut conn = Connection::open_in_memory().unwrap();
Database::maybe_init(&mut conn).unwrap();
let db = Database::from_raw(conn).unwrap();
// Use std's to avoid async, since we need no name resolution.
let listener = std::net::TcpListener::bind(format!("{LOCALHOST}:0")).unwrap();
listener.set_nonblocking(true).unwrap();
let port = listener.local_addr().unwrap().port();
let listener = TcpListener::from_std(listener).unwrap();
// TODO: Testing config is hard to build because it does have a `Default` impl.
let config = basic_toml::from_str(&format!(
r#"
listen = "" # TODO: unused
base_url = "http://{LOCALHOST}:{port}"
"#
))
.unwrap();
let st = AppState::new(db, config);
let router = blahd::router(Arc::new(st));
tokio::spawn(axum::serve(listener, router).into_future());
Server { port }
}
#[fixture]
fn client() -> Client {
Client::new()
}
#[rstest]
#[tokio::test]
async fn smoke(client: Client, server: Server) {
let got = client
.get(server.url("/room?filter=public"))
.send()
.await
.unwrap()
.json::<RoomList>()
.await
.unwrap();
let exp = RoomList {
rooms: Vec::new(),
skip_token: None,
};
assert_eq!(got, exp);
}