diff --git a/Cargo.toml b/Cargo.toml index a9e3903..95e5fa6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,3 +25,7 @@ tower = "0.4" fastrand = "1.5" brotli = { version = "3", default-features = false, features = ["std"]} rcgen = { version = "0.9", default-features = false } + +[dev-dependencies] +tokio-test = "0.4" +axum-test-helper = "0.1" diff --git a/src/server.rs b/src/server.rs index 0d6052c..48134c4 100644 --- a/src/server.rs +++ b/src/server.rs @@ -2,7 +2,7 @@ use std::collections::HashMap; use std::net::SocketAddr; use axum::headers::HeaderName; -use axum::http::{HeaderValue, StatusCode, Uri}; +use axum::http::{HeaderMap, HeaderValue, StatusCode, Uri}; use axum::response::{Html, IntoResponse, Response}; use axum::routing::{get, get_service}; use axum::Router; @@ -29,6 +29,36 @@ pub struct Options { } pub async fn run_server(options: Options, output: WasmBindgenOutput) -> Result<()> { + let app = get_router(&options, output); + let mut address_string = options.address; + if !address_string.contains(":") { + address_string += + &(":".to_owned() + &pick_port::pick_free_port(1334, 10).unwrap_or(1334).to_string()); + } + let addr: SocketAddr = address_string.parse().expect("Couldn't parse address"); + + if options.https { + let certificate = rcgen::generate_simple_self_signed([String::from("localhost")])?; + let config = RustlsConfig::from_der( + vec![certificate.serialize_der()?], + certificate.serialize_private_key_der(), + ) + .await?; + + tracing::info!("starting webserver at https://{}", addr); + axum_server_dual_protocol::bind_dual_protocol(addr, config) + .set_upgrade(true) + .serve(app.into_make_service()) + .await?; + } else { + tracing::info!("starting webserver at http://{}", addr); + axum_server::bind(addr).serve(app.into_make_service()).await?; + } + + Ok(()) +} + +fn get_router(options: &Options, output: WasmBindgenOutput) -> Router { let WasmBindgenOutput { js, compressed_wasm, snippets, local_modules } = output; let middleware_stack = ServiceBuilder::new() @@ -53,13 +83,13 @@ pub async fn run_server(options: Options, output: WasmBindgenOutput) -> Result<( let html = html.replace("{{ TITLE }}", &options.title); let serve_dir = - get_service(ServeDir::new(options.directory)).handle_error(internal_server_error); + get_service(ServeDir::new(options.directory.clone())).handle_error(internal_server_error); let serve_wasm = || async move { ([("content-encoding", "br")], WithContentType("application/wasm", compressed_wasm)) }; - let app = Router::new() + Router::new() .route("/", get(move || async { Html(html) })) .route("/api/wasm.js", get(|| async { WithContentType("application/javascript", js) })) .route("/api/wasm.wasm", get(serve_wasm)) @@ -77,34 +107,7 @@ pub async fn run_server(options: Options, output: WasmBindgenOutput) -> Result<( }), ) .fallback(serve_dir) - .layer(middleware_stack); - - let mut address_string = options.address; - if !address_string.contains(":") { - address_string += - &(":".to_owned() + &pick_port::pick_free_port(1334, 10).unwrap_or(1334).to_string()); - } - let addr: SocketAddr = address_string.parse().expect("Couldn't parse address"); - - if options.https { - let certificate = rcgen::generate_simple_self_signed([String::from("localhost")])?; - let config = RustlsConfig::from_der( - vec![certificate.serialize_der()?], - certificate.serialize_private_key_der(), - ) - .await?; - - tracing::info!("starting webserver at https://{}", addr); - axum_server_dual_protocol::bind_dual_protocol(addr, config) - .set_upgrade(true) - .serve(app.into_make_service()) - .await?; - } else { - tracing::info!("starting webserver at http://{}", addr); - axum_server::bind(addr).serve(app.into_make_service()).await?; - } - - Ok(()) + .layer(middleware_stack) } fn get_snippet_source( @@ -165,3 +168,57 @@ mod pick_port { .or_else(ask_free_tcp_port) } } + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + + use crate::server::get_router; + use crate::wasm_bindgen::WasmBindgenOutput; + use crate::Options; + use axum::http::StatusCode; + use axum_test_helper::TestClient; + + const FAKE_BR_COMPRESSED_WASM: [u8; 4] = [1, 2, 3, 4]; + + fn fake_options() -> Options { + Options { + title: "title".to_string(), + address: "127.0.0.1:0".to_string(), + directory: ".".to_string(), + https: false, + no_module: false, + } + } + + fn fake_wasm_bindgen_output() -> WasmBindgenOutput { + WasmBindgenOutput { + js: "fake js".to_string(), + compressed_wasm: FAKE_BR_COMPRESSED_WASM.to_vec(), + snippets: HashMap::>::new(), + local_modules: HashMap::::new(), + } + } + + fn make_test_client() -> TestClient { + let options = fake_options(); + let output = fake_wasm_bindgen_output(); + let router = get_router(&options, output); + TestClient::new(router) + } + + #[tokio::test] + async fn test_router() { + let client = make_test_client(); + + // Test with br compression requested + let mut res = client + .get("/api/wasm.wasm") + .header("accept-encoding", "gzip, deflate, br") + .send() + .await; + assert_eq!(res.status(), StatusCode::OK); + let result = res.chunk().await.unwrap(); + assert_eq!(result.to_vec(), FAKE_BR_COMPRESSED_WASM); + } +}