mirror of
https://github.com/Blah-IM/blahrs.git
synced 2025-05-01 00:31:09 +00:00
Librarify blahd for testing
This commit is contained in:
parent
a92f661003
commit
4f0f1405dc
9 changed files with 288 additions and 101 deletions
79
Cargo.lock
generated
79
Cargo.lock
generated
|
@ -29,6 +29,15 @@ dependencies = [
|
||||||
"zerocopy",
|
"zerocopy",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "aho-corasick"
|
||||||
|
version = "1.1.3"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "8e60d3430d3a69478ad0993f19238d2df97c507009a52b3c10addcd7f6bcb916"
|
||||||
|
dependencies = [
|
||||||
|
"memchr",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "android-tzdata"
|
name = "android-tzdata"
|
||||||
version = "0.1.1"
|
version = "0.1.1"
|
||||||
|
@ -308,6 +317,8 @@ dependencies = [
|
||||||
"hex",
|
"hex",
|
||||||
"humantime",
|
"humantime",
|
||||||
"parking_lot",
|
"parking_lot",
|
||||||
|
"reqwest",
|
||||||
|
"rstest",
|
||||||
"rusqlite",
|
"rusqlite",
|
||||||
"sd-notify",
|
"sd-notify",
|
||||||
"serde",
|
"serde",
|
||||||
|
@ -763,6 +774,12 @@ version = "0.29.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "40ecd4077b5ae9fd2e9e169b102c6c330d0605168eb0e8bf79952b256dbefffd"
|
checksum = "40ecd4077b5ae9fd2e9e169b102c6c330d0605168eb0e8bf79952b256dbefffd"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "glob"
|
||||||
|
version = "0.3.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "h2"
|
name = "h2"
|
||||||
version = "0.4.6"
|
version = "0.4.6"
|
||||||
|
@ -1398,6 +1415,41 @@ dependencies = [
|
||||||
"bitflags",
|
"bitflags",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "regex"
|
||||||
|
version = "1.10.6"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "4219d74c6b67a3654a9fbebc4b419e22126d13d2f3c4a07ee0cb61ff79a79619"
|
||||||
|
dependencies = [
|
||||||
|
"aho-corasick",
|
||||||
|
"memchr",
|
||||||
|
"regex-automata",
|
||||||
|
"regex-syntax",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "regex-automata"
|
||||||
|
version = "0.4.7"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "38caf58cc5ef2fed281f89292ef23f6365465ed9a41b7a7754eb4e26496c92df"
|
||||||
|
dependencies = [
|
||||||
|
"aho-corasick",
|
||||||
|
"memchr",
|
||||||
|
"regex-syntax",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "regex-syntax"
|
||||||
|
version = "0.8.4"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "7a66a03ae7c801facd77a29370b4faec201768915ac14a721ba36f20bc9c209b"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "relative-path"
|
||||||
|
version = "1.9.3"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "ba39f3699c378cd8970968dcbff9c43159ea4cfbd88d43c00b22f2ef10a435d2"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "reqwest"
|
name = "reqwest"
|
||||||
version = "0.12.7"
|
version = "0.12.7"
|
||||||
|
@ -1456,6 +1508,33 @@ dependencies = [
|
||||||
"windows-sys 0.52.0",
|
"windows-sys 0.52.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "rstest"
|
||||||
|
version = "0.22.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "7b423f0e62bdd61734b67cd21ff50871dfaeb9cc74f869dcd6af974fbcb19936"
|
||||||
|
dependencies = [
|
||||||
|
"rstest_macros",
|
||||||
|
"rustc_version",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "rstest_macros"
|
||||||
|
version = "0.22.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "c5e1711e7d14f74b12a58411c542185ef7fb7f2e7f8ee6e2940a883628522b42"
|
||||||
|
dependencies = [
|
||||||
|
"cfg-if",
|
||||||
|
"glob",
|
||||||
|
"proc-macro2",
|
||||||
|
"quote",
|
||||||
|
"regex",
|
||||||
|
"relative-path",
|
||||||
|
"rustc_version",
|
||||||
|
"syn",
|
||||||
|
"unicode-ident",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "rusqlite"
|
name = "rusqlite"
|
||||||
version = "0.32.1"
|
version = "0.32.1"
|
||||||
|
|
|
@ -29,5 +29,9 @@ url = { version = "2.5.2", features = ["serde"] }
|
||||||
|
|
||||||
blah = { path = "..", features = ["rusqlite"] }
|
blah = { path = "..", features = ["rusqlite"] }
|
||||||
|
|
||||||
|
[dev-dependencies]
|
||||||
|
reqwest = { version = "0.12.7", features = ["json"] }
|
||||||
|
rstest = { version = "0.22.0", default-features = false }
|
||||||
|
|
||||||
[lints]
|
[lints]
|
||||||
workspace = true
|
workspace = true
|
||||||
|
|
|
@ -3,7 +3,10 @@
|
||||||
# the default value.
|
# the default value.
|
||||||
|
|
||||||
[database]
|
[database]
|
||||||
# (Required)
|
# If enabled, a in-memory non-persistent database is used instead. Options
|
||||||
|
# `path` and `create` are ignored. This should only be used for testing.
|
||||||
|
in_memory = false
|
||||||
|
|
||||||
# The path to the main SQLite database.
|
# The path to the main SQLite database.
|
||||||
# The file will be created and initialized if not exist, but missing directory
|
# The file will be created and initialized if not exist, but missing directory
|
||||||
# will not.
|
# will not.
|
||||||
|
|
53
blahd/src/bin/blahd.rs
Normal file
53
blahd/src/bin/blahd.rs
Normal file
|
@ -0,0 +1,53 @@
|
||||||
|
use std::path::PathBuf;
|
||||||
|
|
||||||
|
use anyhow::{Context, Result};
|
||||||
|
use blahd::config::Config;
|
||||||
|
use blahd::{AppState, Database};
|
||||||
|
|
||||||
|
/// Blah Chat Server
|
||||||
|
#[derive(Debug, clap::Parser)]
|
||||||
|
#[clap(about, version = option_env!("CFG_RELEASE").unwrap_or(env!("CARGO_PKG_VERSION")))]
|
||||||
|
enum Cli {
|
||||||
|
/// Run the server with given configuration.
|
||||||
|
Serve {
|
||||||
|
/// The path to the configuration file.
|
||||||
|
#[arg(long, short)]
|
||||||
|
config: PathBuf,
|
||||||
|
},
|
||||||
|
|
||||||
|
/// Validate the configuration file and exit.
|
||||||
|
Validate {
|
||||||
|
/// The path to the configuration file.
|
||||||
|
#[arg(long, short)]
|
||||||
|
config: PathBuf,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
fn main() -> Result<()> {
|
||||||
|
tracing_subscriber::fmt::init();
|
||||||
|
let cli = <Cli as clap::Parser>::parse();
|
||||||
|
|
||||||
|
fn parse_config(path: &std::path::Path) -> Result<Config> {
|
||||||
|
let src = std::fs::read_to_string(path)?;
|
||||||
|
let config = basic_toml::from_str::<Config>(&src)?;
|
||||||
|
config.validate()?;
|
||||||
|
Ok(config)
|
||||||
|
}
|
||||||
|
|
||||||
|
match cli {
|
||||||
|
Cli::Serve { config } => {
|
||||||
|
let config = parse_config(&config)?;
|
||||||
|
let db = Database::open(&config.database).context("failed to open database")?;
|
||||||
|
let st = AppState::new(db, config.server);
|
||||||
|
tokio::runtime::Builder::new_multi_thread()
|
||||||
|
.enable_all()
|
||||||
|
.build()
|
||||||
|
.context("failed to initialize tokio runtime")?
|
||||||
|
.block_on(st.serve())
|
||||||
|
}
|
||||||
|
Cli::Validate { config } => {
|
||||||
|
parse_config(&config)?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -18,6 +18,7 @@ pub struct Config {
|
||||||
#[derive(Debug, Clone, Deserialize)]
|
#[derive(Debug, Clone, Deserialize)]
|
||||||
#[serde(deny_unknown_fields)]
|
#[serde(deny_unknown_fields)]
|
||||||
pub struct DatabaseConfig {
|
pub struct DatabaseConfig {
|
||||||
|
pub in_memory: bool,
|
||||||
#[serde_inline_default("/var/lib/blahd/db.sqlite".into())]
|
#[serde_inline_default("/var/lib/blahd/db.sqlite".into())]
|
||||||
pub path: PathBuf,
|
pub path: PathBuf,
|
||||||
#[serde_inline_default(true)]
|
#[serde_inline_default(true)]
|
||||||
|
|
|
@ -18,14 +18,31 @@ pub struct Database {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Database {
|
impl Database {
|
||||||
pub fn open(config: &DatabaseConfig) -> Result<Self> {
|
/// Use an existing database connection and do no initialization or schema checking.
|
||||||
let mut flags = OpenFlags::SQLITE_OPEN_READ_WRITE | OpenFlags::SQLITE_OPEN_NO_MUTEX;
|
/// This should only be used for testing purpose.
|
||||||
if !config.path.try_exists()? {
|
pub fn from_raw(conn: Connection) -> Result<Self> {
|
||||||
flags.set(OpenFlags::SQLITE_OPEN_CREATE, config.create);
|
conn.pragma_update(None, "foreign_keys", "TRUE")?;
|
||||||
}
|
Ok(Self { conn: conn.into() })
|
||||||
|
}
|
||||||
|
|
||||||
let mut conn = Connection::open_with_flags(&config.path, flags)
|
pub fn open(config: &DatabaseConfig) -> Result<Self> {
|
||||||
.context("failed to connect database")?;
|
let mut conn = if config.in_memory {
|
||||||
|
Connection::open_in_memory().context("failed to open in-memory database")?
|
||||||
|
} else {
|
||||||
|
let mut flags = OpenFlags::SQLITE_OPEN_READ_WRITE | OpenFlags::SQLITE_OPEN_NO_MUTEX;
|
||||||
|
if !config.path.try_exists()? {
|
||||||
|
flags.set(OpenFlags::SQLITE_OPEN_CREATE, config.create);
|
||||||
|
}
|
||||||
|
Connection::open_with_flags(&config.path, flags)
|
||||||
|
.context("failed to connect database")?
|
||||||
|
};
|
||||||
|
Self::maybe_init(&mut conn)?;
|
||||||
|
Ok(Self {
|
||||||
|
conn: Mutex::new(conn),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn maybe_init(conn: &mut Connection) -> Result<()> {
|
||||||
// Connection-specific pragmas.
|
// Connection-specific pragmas.
|
||||||
conn.pragma_update(None, "journal_mode", "WAL")?;
|
conn.pragma_update(None, "journal_mode", "WAL")?;
|
||||||
conn.pragma_update(None, "foreign_keys", "TRUE")?;
|
conn.pragma_update(None, "foreign_keys", "TRUE")?;
|
||||||
|
@ -50,10 +67,7 @@ impl Database {
|
||||||
.context("failed to initialize database")?;
|
.context("failed to initialize database")?;
|
||||||
txn.pragma_update(None, "application_id", APPLICATION_ID)?;
|
txn.pragma_update(None, "application_id", APPLICATION_ID)?;
|
||||||
txn.commit()?;
|
txn.commit()?;
|
||||||
|
Ok(())
|
||||||
Ok(Self {
|
|
||||||
conn: Mutex::new(conn),
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get(&self) -> impl DerefMut<Target = Connection> + '_ {
|
pub fn get(&self) -> impl DerefMut<Target = Connection> + '_ {
|
||||||
|
|
|
@ -19,6 +19,7 @@ use tokio::sync::broadcast;
|
||||||
use tokio_stream::wrappers::errors::BroadcastStreamRecvError;
|
use tokio_stream::wrappers::errors::BroadcastStreamRecvError;
|
||||||
use tokio_stream::wrappers::BroadcastStream;
|
use tokio_stream::wrappers::BroadcastStream;
|
||||||
|
|
||||||
|
use crate::config::ServerConfig;
|
||||||
use crate::AppState;
|
use crate::AppState;
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize)]
|
||||||
|
@ -52,14 +53,14 @@ impl std::error::Error for StreamEnded {}
|
||||||
|
|
||||||
struct WsSenderWrapper<'ws, 'c> {
|
struct WsSenderWrapper<'ws, 'c> {
|
||||||
inner: SplitSink<&'ws mut WebSocket, Message>,
|
inner: SplitSink<&'ws mut WebSocket, Message>,
|
||||||
config: &'c crate::config::Config,
|
config: &'c ServerConfig,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl WsSenderWrapper<'_, '_> {
|
impl WsSenderWrapper<'_, '_> {
|
||||||
async fn send(&mut self, msg: &Outgoing<'_>) -> Result<()> {
|
async fn send(&mut self, msg: &Outgoing<'_>) -> Result<()> {
|
||||||
let data = serde_json::to_string(&msg).expect("serialization cannot fail");
|
let data = serde_json::to_string(&msg).expect("serialization cannot fail");
|
||||||
let fut = tokio::time::timeout(
|
let fut = tokio::time::timeout(
|
||||||
self.config.server.ws_send_timeout_sec,
|
self.config.ws_send_timeout_sec,
|
||||||
self.inner.send(Message::Text(data)),
|
self.inner.send(Message::Text(data)),
|
||||||
);
|
);
|
||||||
match fut.await {
|
match fut.await {
|
||||||
|
@ -108,7 +109,7 @@ pub async fn handle_ws(st: Arc<AppState>, ws: &mut WebSocket) -> Result<Infallib
|
||||||
};
|
};
|
||||||
|
|
||||||
let uid = {
|
let uid = {
|
||||||
let payload = tokio::time::timeout(st.config.server.ws_auth_timeout_sec, ws_rx.next())
|
let payload = tokio::time::timeout(st.config.ws_auth_timeout_sec, ws_rx.next())
|
||||||
.await
|
.await
|
||||||
.context("authentication timeout")?
|
.context("authentication timeout")?
|
||||||
.ok_or(StreamEnded)??;
|
.ok_or(StreamEnded)??;
|
||||||
|
@ -136,7 +137,7 @@ pub async fn handle_ws(st: Arc<AppState>, ws: &mut WebSocket) -> Result<Infallib
|
||||||
let rx = match st.event.user_listeners.lock().entry(uid) {
|
let rx = match st.event.user_listeners.lock().entry(uid) {
|
||||||
Entry::Occupied(ent) => ent.get().subscribe(),
|
Entry::Occupied(ent) => ent.get().subscribe(),
|
||||||
Entry::Vacant(ent) => {
|
Entry::Vacant(ent) => {
|
||||||
let (tx, rx) = broadcast::channel(st.config.server.ws_event_queue_len);
|
let (tx, rx) = broadcast::channel(st.config.ws_event_queue_len);
|
||||||
ent.insert(tx);
|
ent.insert(tx);
|
||||||
rx
|
rx
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,5 +1,4 @@
|
||||||
use std::num::NonZeroUsize;
|
use std::num::NonZeroUsize;
|
||||||
use std::path::PathBuf;
|
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::time::{Duration, SystemTime};
|
use std::time::{Duration, SystemTime};
|
||||||
|
|
||||||
|
@ -15,8 +14,7 @@ use blah::types::{
|
||||||
ChatItem, ChatPayload, CreateRoomPayload, Id, MemberPermission, RoomAdminOp, RoomAdminPayload,
|
ChatItem, ChatPayload, CreateRoomPayload, Id, MemberPermission, RoomAdminOp, RoomAdminPayload,
|
||||||
RoomAttrs, ServerPermission, Signee, UserKey, WithItemId, WithSig,
|
RoomAttrs, ServerPermission, Signee, UserKey, WithItemId, WithSig,
|
||||||
};
|
};
|
||||||
use config::Config;
|
use config::ServerConfig;
|
||||||
use database::Database;
|
|
||||||
use ed25519_dalek::SIGNATURE_LENGTH;
|
use ed25519_dalek::SIGNATURE_LENGTH;
|
||||||
use id::IdExt;
|
use id::IdExt;
|
||||||
use middleware::{ApiError, Auth, MaybeAuth, ResultExt as _, SignedJson};
|
use middleware::{ApiError, Auth, MaybeAuth, ResultExt as _, SignedJson};
|
||||||
|
@ -28,80 +26,39 @@ use utils::ExpiringSet;
|
||||||
|
|
||||||
#[macro_use]
|
#[macro_use]
|
||||||
mod middleware;
|
mod middleware;
|
||||||
mod config;
|
pub mod config;
|
||||||
mod database;
|
mod database;
|
||||||
mod event;
|
mod event;
|
||||||
mod id;
|
mod id;
|
||||||
mod utils;
|
mod utils;
|
||||||
|
|
||||||
/// Blah Chat Server
|
pub use database::Database;
|
||||||
#[derive(Debug, clap::Parser)]
|
|
||||||
#[clap(about, version = option_env!("CFG_RELEASE").unwrap_or(env!("CARGO_PKG_VERSION")))]
|
|
||||||
enum Cli {
|
|
||||||
/// Run the server with given configuration.
|
|
||||||
Serve {
|
|
||||||
/// The path to the configuration file.
|
|
||||||
#[arg(long, short)]
|
|
||||||
config: PathBuf,
|
|
||||||
},
|
|
||||||
|
|
||||||
/// Validate the configuration file and exit.
|
|
||||||
Validate {
|
|
||||||
/// The path to the configuration file.
|
|
||||||
#[arg(long, short)]
|
|
||||||
config: PathBuf,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
fn main() -> Result<()> {
|
|
||||||
tracing_subscriber::fmt::init();
|
|
||||||
let cli = <Cli as clap::Parser>::parse();
|
|
||||||
|
|
||||||
fn parse_config(path: &std::path::Path) -> Result<Config> {
|
|
||||||
let src = std::fs::read_to_string(path)?;
|
|
||||||
let config = basic_toml::from_str::<Config>(&src)?;
|
|
||||||
config.validate()?;
|
|
||||||
Ok(config)
|
|
||||||
}
|
|
||||||
|
|
||||||
match cli {
|
|
||||||
Cli::Serve { config } => {
|
|
||||||
let config = parse_config(&config)?;
|
|
||||||
let st = AppState::init(config).context("failed to initialize state")?;
|
|
||||||
tokio::runtime::Builder::new_multi_thread()
|
|
||||||
.enable_all()
|
|
||||||
.build()
|
|
||||||
.context("failed to initialize tokio runtime")?
|
|
||||||
.block_on(main_async(st))
|
|
||||||
}
|
|
||||||
Cli::Validate { config } => {
|
|
||||||
parse_config(&config)?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Locks must be grabbed in the field order.
|
// Locks must be grabbed in the field order.
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
struct AppState {
|
pub struct AppState {
|
||||||
db: Database,
|
db: Database,
|
||||||
used_nonces: Mutex<ExpiringSet<u32>>,
|
used_nonces: Mutex<ExpiringSet<u32>>,
|
||||||
event: event::State,
|
event: event::State,
|
||||||
|
|
||||||
config: Config,
|
config: ServerConfig,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl AppState {
|
impl AppState {
|
||||||
fn init(config: Config) -> Result<Self> {
|
pub fn new(db: Database, config: ServerConfig) -> Self {
|
||||||
Ok(Self {
|
Self {
|
||||||
db: Database::open(&config.database).context("failed to open database")?,
|
db,
|
||||||
used_nonces: Mutex::new(ExpiringSet::new(Duration::from_secs(
|
used_nonces: Mutex::new(ExpiringSet::new(Duration::from_secs(
|
||||||
config.server.timestamp_tolerance_secs,
|
config.timestamp_tolerance_secs,
|
||||||
))),
|
))),
|
||||||
event: event::State::default(),
|
event: event::State::default(),
|
||||||
|
|
||||||
config,
|
config,
|
||||||
})
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn serve(self) -> Result<()> {
|
||||||
|
serve(Arc::new(self)).await
|
||||||
}
|
}
|
||||||
|
|
||||||
fn verify_signed_data<T: Serialize>(&self, data: &WithSig<T>) -> Result<(), ApiError> {
|
fn verify_signed_data<T: Serialize>(&self, data: &WithSig<T>) -> Result<(), ApiError> {
|
||||||
|
@ -117,7 +74,7 @@ impl AppState {
|
||||||
.expect("after UNIX epoch")
|
.expect("after UNIX epoch")
|
||||||
.as_secs()
|
.as_secs()
|
||||||
.abs_diff(data.signee.timestamp);
|
.abs_diff(data.signee.timestamp);
|
||||||
if timestamp_diff > self.config.server.timestamp_tolerance_secs {
|
if timestamp_diff > self.config.timestamp_tolerance_secs {
|
||||||
return Err(error_response!(
|
return Err(error_response!(
|
||||||
StatusCode::BAD_REQUEST,
|
StatusCode::BAD_REQUEST,
|
||||||
"invalid_timestamp",
|
"invalid_timestamp",
|
||||||
|
@ -137,10 +94,22 @@ impl AppState {
|
||||||
|
|
||||||
type ArcState = State<Arc<AppState>>;
|
type ArcState = State<Arc<AppState>>;
|
||||||
|
|
||||||
async fn main_async(st: AppState) -> Result<()> {
|
async fn serve(st: Arc<AppState>) -> Result<()> {
|
||||||
let st = Arc::new(st);
|
let listener = tokio::net::TcpListener::bind(&st.config.listen)
|
||||||
|
.await
|
||||||
|
.context("failed to listen on socket")?;
|
||||||
|
tracing::info!("listening on {}", st.config.listen);
|
||||||
|
let router = router(st.clone());
|
||||||
|
let _ = sd_notify::notify(true, &[sd_notify::NotifyState::Ready]);
|
||||||
|
|
||||||
let app = Router::new()
|
axum::serve(listener, router)
|
||||||
|
.await
|
||||||
|
.context("failed to serve")?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn router(st: Arc<AppState>) -> Router {
|
||||||
|
Router::new()
|
||||||
.route("/ws", get(handle_ws))
|
.route("/ws", get(handle_ws))
|
||||||
.route("/room", get(room_list))
|
.route("/room", get(room_list))
|
||||||
.route("/room/create", post(room_create))
|
.route("/room/create", post(room_create))
|
||||||
|
@ -150,27 +119,16 @@ async fn main_async(st: AppState) -> Result<()> {
|
||||||
.route("/room/:rid/item", get(room_item_list).post(room_item_post))
|
.route("/room/:rid/item", get(room_item_list).post(room_item_post))
|
||||||
.route("/room/:rid/item/:cid/seen", post(room_item_mark_seen))
|
.route("/room/:rid/item/:cid/seen", post(room_item_mark_seen))
|
||||||
.route("/room/:rid/admin", post(room_admin))
|
.route("/room/:rid/admin", post(room_admin))
|
||||||
.with_state(st.clone())
|
|
||||||
.layer(tower_http::limit::RequestBodyLimitLayer::new(
|
.layer(tower_http::limit::RequestBodyLimitLayer::new(
|
||||||
st.config.server.max_request_len,
|
st.config.max_request_len,
|
||||||
))
|
))
|
||||||
// NB. This comes at last (outmost layer), so inner errors will still be wrapped with
|
// NB. This comes at last (outmost layer), so inner errors will still be wrapped with
|
||||||
// correct CORS headers. Also `Authorization` must be explicitly included besides `*`.
|
// correct CORS headers. Also `Authorization` must be explicitly included besides `*`.
|
||||||
.layer(
|
.layer(
|
||||||
tower_http::cors::CorsLayer::permissive()
|
tower_http::cors::CorsLayer::permissive()
|
||||||
.allow_headers([header::HeaderName::from_static("*"), header::AUTHORIZATION]),
|
.allow_headers([header::HeaderName::from_static("*"), header::AUTHORIZATION]),
|
||||||
);
|
)
|
||||||
|
.with_state(st)
|
||||||
let listener = tokio::net::TcpListener::bind(&st.config.server.listen)
|
|
||||||
.await
|
|
||||||
.context("failed to listen on socket")?;
|
|
||||||
tracing::info!("listening on {}", st.config.server.listen);
|
|
||||||
let _ = sd_notify::notify(true, &[sd_notify::NotifyState::Ready]);
|
|
||||||
|
|
||||||
axum::serve(listener, app)
|
|
||||||
.await
|
|
||||||
.context("failed to serve")?;
|
|
||||||
Ok(())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type RE<T> = R<T, ApiError>;
|
type RE<T> = R<T, ApiError>;
|
||||||
|
@ -193,11 +151,11 @@ async fn handle_ws(State(st): ArcState, ws: WebSocketUpgrade) -> Response {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize)]
|
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||||
struct RoomList {
|
pub struct RoomList {
|
||||||
rooms: Vec<RoomMetadata>,
|
pub rooms: Vec<RoomMetadata>,
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
skip_token: Option<Id>,
|
pub skip_token: Option<Id>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize)]
|
||||||
|
@ -463,16 +421,16 @@ impl Pagination {
|
||||||
fn effective_page_len(&self, st: &AppState) -> usize {
|
fn effective_page_len(&self, st: &AppState) -> usize {
|
||||||
self.top
|
self.top
|
||||||
.unwrap_or(usize::MAX.try_into().expect("not zero"))
|
.unwrap_or(usize::MAX.try_into().expect("not zero"))
|
||||||
.min(st.config.server.max_page_len)
|
.min(st.config.max_page_len)
|
||||||
.get()
|
.get()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize)]
|
#[derive(Debug, Serialize)]
|
||||||
struct RoomItems {
|
pub struct RoomItems {
|
||||||
items: Vec<WithItemId<ChatItem>>,
|
pub items: Vec<WithItemId<ChatItem>>,
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
skip_token: Option<Id>,
|
pub skip_token: Option<Id>,
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn room_item_list(
|
async fn room_item_list(
|
||||||
|
@ -548,7 +506,6 @@ async fn room_get_feed(
|
||||||
|
|
||||||
let feed_url = st
|
let feed_url = st
|
||||||
.config
|
.config
|
||||||
.server
|
|
||||||
.base_url
|
.base_url
|
||||||
.join(&format!("/room/{rid}/feed.json"))
|
.join(&format!("/room/{rid}/feed.json"))
|
||||||
.expect("base_url must be valid");
|
.expect("base_url must be valid");
|
||||||
|
@ -616,7 +573,7 @@ struct FeedItemExtra {
|
||||||
sig: [u8; SIGNATURE_LENGTH],
|
sig: [u8; SIGNATURE_LENGTH],
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize)]
|
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||||
pub struct RoomMetadata {
|
pub struct RoomMetadata {
|
||||||
pub rid: Id,
|
pub rid: Id,
|
||||||
pub title: String,
|
pub title: String,
|
75
blahd/tests/basic.rs
Normal file
75
blahd/tests/basic.rs
Normal file
|
@ -0,0 +1,75 @@
|
||||||
|
// FIXME: False positive?
|
||||||
|
#![allow(clippy::unwrap_used)]
|
||||||
|
use std::fmt;
|
||||||
|
use std::future::IntoFuture;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use blahd::{AppState, Database, RoomList};
|
||||||
|
use reqwest::Client;
|
||||||
|
use rstest::{fixture, rstest};
|
||||||
|
use rusqlite::Connection;
|
||||||
|
use tokio::net::TcpListener;
|
||||||
|
|
||||||
|
// Avoid name resolution.
|
||||||
|
const LOCALHOST: &str = "127.0.0.1";
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
struct Server {
|
||||||
|
port: u16,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Server {
|
||||||
|
fn url(&self, rhs: impl fmt::Display) -> String {
|
||||||
|
format!("http://{}:{}{}", LOCALHOST, self.port, rhs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[fixture]
|
||||||
|
fn server() -> Server {
|
||||||
|
let mut conn = Connection::open_in_memory().unwrap();
|
||||||
|
Database::maybe_init(&mut conn).unwrap();
|
||||||
|
let db = Database::from_raw(conn).unwrap();
|
||||||
|
|
||||||
|
// Use std's to avoid async, since we need no name resolution.
|
||||||
|
let listener = std::net::TcpListener::bind(format!("{LOCALHOST}:0")).unwrap();
|
||||||
|
listener.set_nonblocking(true).unwrap();
|
||||||
|
let port = listener.local_addr().unwrap().port();
|
||||||
|
let listener = TcpListener::from_std(listener).unwrap();
|
||||||
|
|
||||||
|
// TODO: Testing config is hard to build because it does have a `Default` impl.
|
||||||
|
let config = basic_toml::from_str(&format!(
|
||||||
|
r#"
|
||||||
|
listen = "" # TODO: unused
|
||||||
|
base_url = "http://{LOCALHOST}:{port}"
|
||||||
|
"#
|
||||||
|
))
|
||||||
|
.unwrap();
|
||||||
|
let st = AppState::new(db, config);
|
||||||
|
let router = blahd::router(Arc::new(st));
|
||||||
|
|
||||||
|
tokio::spawn(axum::serve(listener, router).into_future());
|
||||||
|
Server { port }
|
||||||
|
}
|
||||||
|
|
||||||
|
#[fixture]
|
||||||
|
fn client() -> Client {
|
||||||
|
Client::new()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[rstest]
|
||||||
|
#[tokio::test]
|
||||||
|
async fn smoke(client: Client, server: Server) {
|
||||||
|
let got = client
|
||||||
|
.get(server.url("/room?filter=public"))
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.unwrap()
|
||||||
|
.json::<RoomList>()
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
let exp = RoomList {
|
||||||
|
rooms: Vec::new(),
|
||||||
|
skip_token: None,
|
||||||
|
};
|
||||||
|
assert_eq!(got, exp);
|
||||||
|
}
|
Loading…
Add table
Reference in a new issue