diff --git a/inst/validation/method_comparison.R b/inst/validation/method_comparison.R new file mode 100644 index 0000000..dc64c97 --- /dev/null +++ b/inst/validation/method_comparison.R @@ -0,0 +1,310 @@ +## Method comparison: marginal vs joint g-computation across scenarios. +## Phase 1.5 of the joint g-comp refactor (issue #65). Produces +## side-by-side target-stat comparison tables under matched parameter +## sets, identifies stats whose values shift materially under the +## method correction, and writes the result as a markdown report. +## +## Usage: +## source(system.file("validation/method_comparison.R", package = "ARTnet")) +## res <- compare_methods() +## summarize_comparison(res) +## render_comparison_report(res, file = "inst/validation/method_comparison.md") +## +## Requires ARTnetData. The full run takes ~30s on a recent laptop +## (4 scenarios x 2 methods x build_epistats + build_netparams + build_netstats). + +# ---- Scenario definitions ---------------------------------------------------- + +# Each scenario: a list with $name and the args to build_{epistats,netparams,netstats}. +# Conceptually: +# atlanta_default = baseline EpiModelHIV-Template setup +# national_no_geog = no geographic stratification (sanity: no geogYN baked in) +# atlanta_nhbs_shifted = Atlanta with NHBS-MSM-like race composition +# (35% Black, 25% Hispanic, 40% W/Other; vs Atlanta's +# default ~52/5/44 mix). Tests population-shift bias. +# atlanta_no_race = race = FALSE path (sanity) +COMPARISON_SCENARIOS <- list( + list( + name = "atlanta_default", + epistats = list(geog.lvl = "city", geog.cat = "Atlanta", + init.hiv.prev = c(0.33, 0.137, 0.084), + race = TRUE, time.unit = 7), + netparams = list(smooth.main.dur = TRUE), + netstats = list(expect.mort = 0.000478213, network.size = 5000) + ), + list( + name = "national_no_geog", + epistats = list(race = TRUE, time.unit = 7), + netparams = list(smooth.main.dur = TRUE), + netstats = list(expect.mort = 0.000478213, network.size = 5000) + ), + list( + name = "atlanta_nhbs_shifted", + epistats = list(geog.lvl = "city", geog.cat = "Atlanta", + init.hiv.prev = c(0.33, 0.137, 0.084), + race = TRUE, time.unit = 7), + netparams = list(smooth.main.dur = TRUE), + netstats = list(expect.mort = 0.000478213, network.size = 5000, + race.prop = c(0.35, 0.25, 0.40)) + ), + list( + name = "atlanta_no_race", + epistats = list(geog.lvl = "city", geog.cat = "Atlanta", + init.hiv.prev = c(0.33, 0.137, 0.084), + race = FALSE, time.unit = 7), + netparams = list(smooth.main.dur = TRUE), + netstats = list(expect.mort = 0.000478213, network.size = 5000) + ) +) + +.COMPARISON_SEED <- 20260420L + + +# ---- Utilities -------------------------------------------------------------- + +.require_artnetdata <- function() { + if (system.file(package = "ARTnetData") == "") { + stop("ARTnetData not installed; method comparison cannot run.") + } +} + +# Walk a netstats object and produce a long-format data.frame of every +# target statistic with columns scenario, method, layer, stat, level, value. +# `level` is NA for scalars; integer index for vector-valued stats. +.extract_target_stats <- function(netstats, scenario, method) { + rows <- list() + add_scalar <- function(layer, stat, x) { + if (is.null(x) || length(x) != 1) return(NULL) + data.frame(scenario = scenario, method = method, layer = layer, + stat = stat, level = NA_integer_, value = as.numeric(x), + stringsAsFactors = FALSE) + } + add_vec <- function(layer, stat, v) { + if (is.null(v) || length(v) < 1) return(NULL) + data.frame(scenario = scenario, method = method, layer = layer, + stat = stat, level = seq_along(v), value = unname(as.numeric(v)), + stringsAsFactors = FALSE) + } + for (layer in c("main", "casl", "inst")) { + L <- netstats[[layer]] + if (is.null(L)) next + rows <- c(rows, list( + add_scalar(layer, "edges", L$edges), + add_scalar(layer, "concurrent", L$concurrent), + add_scalar(layer, "nodematch_race_diffF", L$nodematch_race_diffF), + add_scalar(layer, "absdiff_age", L$absdiff_age), + add_scalar(layer, "absdiff_sqrt.age", L$absdiff_sqrt.age), + add_vec(layer, "nodefactor_race", L$nodefactor_race), + add_vec(layer, "nodefactor_age.grp", L$nodefactor_age.grp), + add_vec(layer, "nodefactor_diag.status", L$nodefactor_diag.status), + add_vec(layer, "nodefactor_deg.casl", L$nodefactor_deg.casl), + add_vec(layer, "nodefactor_deg.main", L$nodefactor_deg.main), + add_vec(layer, "nodefactor_deg.tot", L$nodefactor_deg.tot), + add_vec(layer, "nodefactor_risk.grp", L$nodefactor_risk.grp), + add_vec(layer, "nodematch_race", L$nodematch_race), + add_vec(layer, "nodematch_age.grp", L$nodematch_age.grp) + )) + # Stratum-level dissolution durations (main / casl only — inst is + # ~offset(edges) with a fixed 1). + if (layer != "inst" && !is.null(L$diss.byage$duration) && + length(L$diss.byage$duration) > 1) { + rows <- c(rows, list( + add_vec(layer, "dissolution_duration", L$diss.byage$duration) + )) + } + } + do.call(rbind, rows[!vapply(rows, is.null, logical(1))]) +} + +# Run one scenario through both methods, return both extracted stat tables. +.run_one_scenario <- function(scenario) { + do_run <- function(np_method, dur_method) { + set.seed(.COMPARISON_SEED) + epistats <- do.call(ARTnet::build_epistats, scenario$epistats) + set.seed(.COMPARISON_SEED) + np_args <- c(list(epistats = epistats), scenario$netparams, + list(method = np_method, duration.method = dur_method)) + netparams <- do.call(ARTnet::build_netparams, np_args) + set.seed(.COMPARISON_SEED) + ns_args <- c(list(epistats = epistats, netparams = netparams), + scenario$netstats, list(method = np_method)) + do.call(ARTnet::build_netstats, ns_args) + } + list( + existing = .extract_target_stats(do_run("existing", "empirical"), + scenario$name, "existing"), + joint = .extract_target_stats(do_run("joint", "joint_lm"), + scenario$name, "joint") + ) +} + + +# ---- Public entry points ----------------------------------------------------- + +#' Run the method comparison across scenarios. +#' +#' @param scenarios List of scenario specs (defaults to COMPARISON_SCENARIOS). +#' @return Long-format data.frame with columns scenario, layer, stat, level, +#' existing, joint, abs_diff, pct_diff (set to NA when existing == 0). +#' Each row is a single (scenario, layer, stat, level) cell with both +#' methods' values side-by-side. +compare_methods <- function(scenarios = COMPARISON_SCENARIOS) { + .require_artnetdata() + message("Running ", length(scenarios), " scenarios x 2 methods (", + "this is build_*x6, expect ~30s)...") + pieces <- list() + for (s in scenarios) { + message(" ", s$name) + res <- .run_one_scenario(s) + pieces[[s$name]] <- res + } + # Wide on method + long_existing <- do.call(rbind, lapply(pieces, `[[`, "existing")) + long_joint <- do.call(rbind, lapply(pieces, `[[`, "joint")) + key <- c("scenario", "layer", "stat", "level") + wide <- merge(long_existing[, c(key, "value")], + long_joint[, c(key, "value")], + by = key, all = TRUE, + suffixes = c("_existing", "_joint")) + names(wide)[names(wide) == "value_existing"] <- "existing" + names(wide)[names(wide) == "value_joint"] <- "joint" + wide$abs_diff <- wide$joint - wide$existing + wide$pct_diff <- ifelse(abs(wide$existing) > 1e-12, + 100 * wide$abs_diff / wide$existing, + NA_real_) + wide[order(wide$scenario, wide$layer, wide$stat, + ifelse(is.na(wide$level), -1L, wide$level)), , drop = FALSE] +} + + +#' Print a high-level summary of the comparison. +summarize_comparison <- function(comparison, threshold_pct = 5) { + cat(sprintf("\nTotal cells: %d (across %d scenarios)\n", + nrow(comparison), length(unique(comparison$scenario)))) + ok <- !is.na(comparison$pct_diff) + cat(sprintf("Cells with |%%diff| > %g%%: %d\n", + threshold_pct, + sum(abs(comparison$pct_diff) > threshold_pct, na.rm = TRUE))) + for (s in unique(comparison$scenario)) { + sub <- comparison[comparison$scenario == s, , drop = FALSE] + nbig <- sum(abs(sub$pct_diff) > threshold_pct, na.rm = TRUE) + cat(sprintf("\n=== %s: %d cells, %d materially shifted (>%g%%) ===\n", + s, nrow(sub), nbig, threshold_pct)) + if (nbig == 0) { + cat(" (no material shifts)\n") + next + } + big <- sub[!is.na(sub$pct_diff) & abs(sub$pct_diff) > threshold_pct, ] + big <- big[order(-abs(big$pct_diff)), , drop = FALSE] + n_show <- min(10, nrow(big)) + cat(sprintf(" Top %d by |%%diff|:\n", n_show)) + for (i in seq_len(n_show)) { + r <- big[i, ] + level_str <- if (is.na(r$level)) "" else sprintf("[%d]", r$level) + cat(sprintf(" %-7s %-22s%-5s existing=%9.2f joint=%9.2f (%+0.1f%%)\n", + r$layer, r$stat, level_str, r$existing, r$joint, r$pct_diff)) + } + } + invisible(comparison) +} + + +#' Render a markdown report of the comparison results. +render_comparison_report <- function(comparison, + file = "inst/validation/method_comparison.md", + threshold_pct = 5) { + con <- file(file, "w") + on.exit(close(con)) + out <- function(...) cat(..., "\n", sep = "", file = con) + + out("# Method comparison: marginal vs joint g-computation") + out() + out("Generated by `inst/validation/method_comparison.R` on ", as.character(Sys.Date()), + ". Phase 1.5 of the joint g-comp refactor; closes part of issue #65.") + out() + out("ARTnet version: ", as.character(packageVersion("ARTnet")), ". Seed: ", + .COMPARISON_SEED, ". Network size: 5000.") + out() + out("## Scenarios") + out() + out("| Scenario | Description |") + out("|---|---|") + out("| `atlanta_default` | Baseline EpiModelHIV-Template config (Atlanta, race = TRUE) |") + out("| `national_no_geog` | No geographic stratification (sanity check) |") + out("| `atlanta_nhbs_shifted` | Atlanta config with `race.prop = c(0.35, 0.25, 0.40)` (NHBS-MSM-like) |") + out("| `atlanta_no_race` | `race = FALSE` path (sanity check) |") + out() + out("## High-level summary") + out() + ok <- !is.na(comparison$pct_diff) + total_big <- sum(abs(comparison$pct_diff) > threshold_pct, na.rm = TRUE) + out("- Total target-stat cells across scenarios: ", nrow(comparison)) + out("- Cells where |joint vs existing %diff| > ", threshold_pct, "%: **", + total_big, "**") + out() + for (s in unique(comparison$scenario)) { + sub <- comparison[comparison$scenario == s, , drop = FALSE] + nbig <- sum(abs(sub$pct_diff) > threshold_pct, na.rm = TRUE) + out("- `", s, "`: ", nrow(sub), " cells, ", nbig, " materially shifted (>", + threshold_pct, "%)") + } + out() + out("## Per-scenario top shifts (by |% diff|)") + out() + for (s in unique(comparison$scenario)) { + sub <- comparison[comparison$scenario == s, , drop = FALSE] + big <- sub[!is.na(sub$pct_diff) & abs(sub$pct_diff) > threshold_pct, ] + big <- big[order(-abs(big$pct_diff)), , drop = FALSE] + out("### ", s) + out() + if (nrow(big) == 0) { + out("_No cells exceed |", threshold_pct, "%| threshold._") + out() + next + } + out("| Layer | Stat | Level | Existing | Joint | %Δ |") + out("|---|---|---:|---:|---:|---:|") + for (i in seq_len(min(15, nrow(big)))) { + r <- big[i, ] + level_str <- if (is.na(r$level)) "—" else as.character(r$level) + out(sprintf("| %s | %s | %s | %.2f | %.2f | %+.1f%% |", + r$layer, r$stat, level_str, + r$existing, r$joint, r$pct_diff)) + } + if (nrow(big) > 15) { + out() + out("_...and ", nrow(big) - 15, " more cells over threshold._") + } + out() + } + out("## Edges / concurrent comparison (all scenarios)") + out() + out("| Scenario | Layer | Existing edges | Joint edges | %Δ | Existing concurrent | Joint concurrent | %Δ |") + out("|---|---|---:|---:|---:|---:|---:|---:|") + edges <- comparison[comparison$stat == "edges", ] + conc <- comparison[comparison$stat == "concurrent", ] + for (s in unique(comparison$scenario)) { + for (l in c("main", "casl", "inst")) { + e <- edges[edges$scenario == s & edges$layer == l, ] + c1 <- conc[conc$scenario == s & conc$layer == l, ] + if (nrow(e) == 0) next + conc_existing <- if (nrow(c1) > 0) sprintf("%.2f", c1$existing) else "—" + conc_joint <- if (nrow(c1) > 0) sprintf("%.2f", c1$joint) else "—" + conc_pct <- if (nrow(c1) > 0 && !is.na(c1$pct_diff)) + sprintf("%+.1f%%", c1$pct_diff) else "—" + out(sprintf("| %s | %s | %.2f | %.2f | %+.1f%% | %s | %s | %s |", + s, l, e$existing, e$joint, e$pct_diff, + conc_existing, conc_joint, conc_pct)) + } + } + out() + out("## Reproducibility") + out() + out("```r") + out("source(system.file(\"validation/method_comparison.R\", package = \"ARTnet\"))") + out("res <- compare_methods()") + out("summarize_comparison(res)") + out("render_comparison_report(res)") + out("```") + invisible(file) +} diff --git a/inst/validation/method_comparison.md b/inst/validation/method_comparison.md new file mode 100644 index 0000000..6b69d04 --- /dev/null +++ b/inst/validation/method_comparison.md @@ -0,0 +1,140 @@ +# Method comparison: marginal vs joint g-computation + +Generated by `inst/validation/method_comparison.R` on 2026-04-25. Phase 1.5 of the joint g-comp refactor; closes part of issue #65. + +ARTnet version: 2.9.0. Seed: 20260420. Network size: 5000. + +## Scenarios + +| Scenario | Description | +|---|---| +| `atlanta_default` | Baseline EpiModelHIV-Template config (Atlanta, race = TRUE) | +| `national_no_geog` | No geographic stratification (sanity check) | +| `atlanta_nhbs_shifted` | Atlanta config with `race.prop = c(0.35, 0.25, 0.40)` (NHBS-MSM-like) | +| `atlanta_no_race` | `race = FALSE` path (sanity check) | + +## High-level summary + +- Total target-stat cells across scenarios: 363 +- Cells where |joint vs existing %diff| > 5%: **229** + +- `atlanta_default`: 96 cells, 63 materially shifted (>5%) +- `atlanta_nhbs_shifted`: 96 cells, 66 materially shifted (>5%) +- `atlanta_no_race`: 75 cells, 49 materially shifted (>5%) +- `national_no_geog`: 96 cells, 51 materially shifted (>5%) + +## Per-scenario top shifts (by |% diff|) + +### atlanta_default + +| Layer | Stat | Level | Existing | Joint | %Δ | +|---|---|---:|---:|---:|---:| +| inst | nodematch_age.grp | 5 | 7.90 | 3.86 | -51.1% | +| main | dissolution_duration | 6 | 934.45 | 491.48 | -47.4% | +| inst | nodefactor_diag.status | 2 | 150.36 | 89.69 | -40.3% | +| main | dissolution_duration | 4 | 539.20 | 323.74 | -40.0% | +| casl | nodefactor_deg.main | 3 | 11.74 | 16.38 | +39.5% | +| main | dissolution_duration | 5 | 682.59 | 428.11 | -37.3% | +| casl | dissolution_duration | 2 | 50.44 | 68.93 | +36.7% | +| inst | nodefactor_age.grp | 1 | 46.94 | 63.36 | +35.0% | +| inst | nodefactor_age.grp | 5 | 85.81 | 57.25 | -33.3% | +| inst | nodematch_age.grp | 4 | 11.22 | 7.99 | -28.8% | +| casl | dissolution_duration | 4 | 113.16 | 81.76 | -27.7% | +| casl | dissolution_duration | 3 | 73.17 | 92.29 | +26.1% | +| casl | nodefactor_race | 2 | 119.99 | 149.17 | +24.3% | +| casl | dissolution_duration | 6 | 150.10 | 186.04 | +23.9% | +| inst | nodefactor_age.grp | 4 | 85.36 | 65.21 | -23.6% | + +_...and 48 more cells over threshold._ + +### atlanta_nhbs_shifted + +| Layer | Stat | Level | Existing | Joint | %Δ | +|---|---|---:|---:|---:|---:| +| main | dissolution_duration | 6 | 934.45 | 491.48 | -47.4% | +| inst | nodematch_age.grp | 5 | 7.90 | 4.18 | -47.1% | +| main | dissolution_duration | 4 | 539.20 | 323.74 | -40.0% | +| inst | nodefactor_diag.status | 2 | 129.36 | 80.06 | -38.1% | +| main | dissolution_duration | 5 | 682.59 | 428.11 | -37.3% | +| casl | nodefactor_deg.main | 3 | 11.74 | 15.96 | +36.0% | +| casl | dissolution_duration | 5 | 159.54 | 102.54 | -35.7% | +| inst | nodefactor_age.grp | 1 | 46.94 | 62.72 | +33.6% | +| inst | nodefactor_age.grp | 5 | 85.81 | 61.75 | -28.0% | +| casl | dissolution_duration | 4 | 113.16 | 81.76 | -27.7% | +| casl | absdiff_age | — | 12912.10 | 16137.21 | +25.0% | +| inst | nodematch_age.grp | 4 | 11.22 | 8.54 | -23.9% | +| main | dissolution_duration | 3 | 253.67 | 196.29 | -22.6% | +| casl | absdiff_sqrt.age | — | 1043.61 | 1278.92 | +22.5% | +| main | nodematch_race_diffF | — | 759.45 | 588.51 | -22.5% | + +_...and 51 more cells over threshold._ + +### atlanta_no_race + +| Layer | Stat | Level | Existing | Joint | %Δ | +|---|---|---:|---:|---:|---:| +| main | dissolution_duration | 6 | 934.45 | 489.84 | -47.6% | +| inst | nodefactor_age.grp | 1 | 46.94 | 67.75 | +44.4% | +| casl | nodefactor_deg.main | 3 | 11.74 | 16.80 | +43.1% | +| main | dissolution_duration | 4 | 539.20 | 322.52 | -40.2% | +| main | dissolution_duration | 5 | 682.59 | 426.73 | -37.5% | +| inst | nodematch_age.grp | 5 | 7.90 | 5.18 | -34.5% | +| casl | dissolution_duration | 5 | 159.54 | 107.80 | -32.4% | +| inst | nodematch_age.grp | 1 | 13.72 | 17.90 | +30.5% | +| casl | absdiff_age | — | 12912.10 | 16581.25 | +28.4% | +| main | absdiff_age | — | 5028.10 | 6346.42 | +26.2% | +| casl | absdiff_sqrt.age | — | 1043.61 | 1316.79 | +26.2% | +| inst | nodefactor_diag.status | 2 | 232.52 | 174.43 | -25.0% | +| casl | dissolution_duration | 1 | 105.85 | 79.69 | -24.7% | +| casl | dissolution_duration | 4 | 113.16 | 86.61 | -23.5% | +| main | dissolution_duration | 3 | 253.67 | 195.22 | -23.0% | + +_...and 34 more cells over threshold._ + +### national_no_geog + +| Layer | Stat | Level | Existing | Joint | %Δ | +|---|---|---:|---:|---:|---:| +| main | dissolution_duration | 6 | 916.66 | 459.09 | -49.9% | +| main | dissolution_duration | 4 | 528.94 | 302.73 | -42.8% | +| main | dissolution_duration | 5 | 669.60 | 400.16 | -40.2% | +| inst | nodematch_age.grp | 5 | 8.51 | 5.15 | -39.6% | +| casl | dissolution_duration | 5 | 161.98 | 104.87 | -35.3% | +| inst | nodefactor_diag.status | 2 | 93.33 | 61.53 | -34.1% | +| casl | nodefactor_deg.main | 3 | 19.56 | 25.18 | +28.7% | +| casl | dissolution_duration | 1 | 107.47 | 76.90 | -28.4% | +| casl | dissolution_duration | 4 | 114.89 | 83.60 | -27.2% | +| main | dissolution_duration | 3 | 248.85 | 183.55 | -26.2% | +| inst | nodefactor_age.grp | 1 | 55.02 | 68.06 | +23.7% | +| casl | nodefactor_race | 2 | 510.68 | 626.72 | +22.7% | +| casl | absdiff_age | — | 14388.91 | 17503.66 | +21.6% | +| inst | nodefactor_age.grp | 5 | 100.47 | 80.16 | -20.2% | +| casl | absdiff_sqrt.age | — | 1174.49 | 1401.01 | +19.3% | + +_...and 36 more cells over threshold._ + +## Edges / concurrent comparison (all scenarios) + +| Scenario | Layer | Existing edges | Joint edges | %Δ | Existing concurrent | Joint concurrent | %Δ | +|---|---|---:|---:|---:|---:|---:|---:| +| atlanta_default | main | 995.15 | 844.80 | -15.1% | 48.54 | 47.53 | -2.1% | +| atlanta_default | casl | 1334.95 | 1437.45 | +7.7% | 728.16 | 806.54 | +10.8% | +| atlanta_default | inst | 191.37 | 171.30 | -10.5% | — | — | — | +| atlanta_nhbs_shifted | main | 995.15 | 900.86 | -9.5% | 48.54 | 58.54 | +20.6% | +| atlanta_nhbs_shifted | casl | 1334.95 | 1462.26 | +9.5% | 728.16 | 820.76 | +12.7% | +| atlanta_nhbs_shifted | inst | 191.37 | 177.73 | -7.1% | — | — | — | +| atlanta_no_race | main | 995.15 | 1021.72 | +2.7% | 48.54 | 59.13 | +21.8% | +| atlanta_no_race | casl | 1334.95 | 1502.02 | +12.5% | 728.16 | 836.50 | +14.9% | +| atlanta_no_race | inst | 191.37 | 205.30 | +7.3% | — | — | — | +| national_no_geog | main | 1008.64 | 964.45 | -4.4% | 79.17 | 88.73 | +12.1% | +| national_no_geog | casl | 1362.33 | 1487.06 | +9.2% | 716.64 | 802.16 | +11.9% | +| national_no_geog | inst | 219.35 | 213.25 | -2.8% | — | — | — | + +## Reproducibility + +```r +source(system.file("validation/method_comparison.R", package = "ARTnet")) +res <- compare_methods() +summarize_comparison(res) +render_comparison_report(res) +``` diff --git a/tests/testthat/test-method-comparison.R b/tests/testthat/test-method-comparison.R new file mode 100644 index 0000000..ba0f781 --- /dev/null +++ b/tests/testthat/test-method-comparison.R @@ -0,0 +1,79 @@ +# Tests for the method-comparison validation harness in +# inst/validation/method_comparison.R (Phase 1.5 / issue #65). +# +# These tests exercise the helper structure on a single small scenario +# rather than the full suite (full suite is intentionally slow and +# produces inst/validation/method_comparison.md as its real output). + +skip_without_artnetdata <- function() { + testthat::skip_if(system.file(package = "ARTnetData") == "", + "ARTnetData not installed") +} + +source_helper <- function() { + path <- system.file("validation/method_comparison.R", package = "ARTnet") + if (!nzchar(path)) { + path <- "inst/validation/method_comparison.R" + } + testthat::skip_if(!file.exists(path), "method_comparison.R not found") + source(path, local = parent.frame()) +} + +mini_scenario <- list( + list( + name = "test_atlanta", + epistats = list(geog.lvl = "city", geog.cat = "Atlanta", + init.hiv.prev = c(0.33, 0.137, 0.084), + race = TRUE, time.unit = 7), + netparams = list(smooth.main.dur = TRUE), + netstats = list(expect.mort = 0.000478213, network.size = 2000) + ) +) + +test_that("compare_methods returns expected long-format structure", { + skip_without_artnetdata() + source_helper() + res <- suppressMessages(compare_methods(mini_scenario)) + expect_s3_class(res, "data.frame") + expect_named(res, c("scenario", "layer", "stat", "level", + "existing", "joint", "abs_diff", "pct_diff"), + ignore.order = TRUE) + expect_true(all(res$layer %in% c("main", "casl", "inst"))) + expect_true(all(res$scenario == "test_atlanta")) + expect_true(all(c("edges", "nodefactor_race", "nodefactor_age.grp", + "dissolution_duration") %in% res$stat)) + # All stats produced both methods + expect_false(any(is.na(res$existing))) + expect_false(any(is.na(res$joint))) +}) + +test_that("abs_diff and pct_diff are computed consistently", { + skip_without_artnetdata() + source_helper() + res <- suppressMessages(compare_methods(mini_scenario)) + expect_equal(res$abs_diff, res$joint - res$existing, tolerance = 1e-9) + # When existing is non-zero, pct_diff matches the formula + ok <- abs(res$existing) > 1e-12 & !is.na(res$pct_diff) + expect_equal(res$pct_diff[ok], + 100 * res$abs_diff[ok] / res$existing[ok], + tolerance = 1e-9) +}) + +test_that("comparison includes dissolution_duration for main and casl, not inst", { + skip_without_artnetdata() + source_helper() + res <- suppressMessages(compare_methods(mini_scenario)) + dur <- res[res$stat == "dissolution_duration", , drop = FALSE] + expect_true(all(dur$layer %in% c("main", "casl"))) + expect_true(any(dur$layer == "main")) + expect_true(any(dur$layer == "casl")) +}) + +test_that("at least one cell shifts > 5% between methods on Atlanta default", { + skip_without_artnetdata() + source_helper() + res <- suppressMessages(compare_methods(mini_scenario)) + # The whole point of the refactor is that joint differs from existing + # in at least some places. If this isn't true, something is broken. + expect_true(sum(abs(res$pct_diff) > 5, na.rm = TRUE) > 5) +})