Skip to content

[Repo Assist] fix: correct exact_match_cols logic in DistanceMatchingEstimator (closes #814)#1465

Open
github-actions[bot] wants to merge 4 commits intomainfrom
repo-assist/fix-issue-814-distance-matching-exact-match-5b399f05ecd37ff9
Open

[Repo Assist] fix: correct exact_match_cols logic in DistanceMatchingEstimator (closes #814)#1465
github-actions[bot] wants to merge 4 commits intomainfrom
repo-assist/fix-issue-814-distance-matching-exact-match-5b399f05ecd37ff9

Conversation

@github-actions
Copy link
Copy Markdown
Contributor

🤖 This is an automated pull request from Repo Assist.

Summary

Fixes three bugs in DistanceMatchingEstimator.estimate_effect() when exact_match_cols is provided, as reported in #814. The root symptom is that matched control units are not restricted to the same exact-match group, producing incorrect ATT/ATE estimates.

Root Causes

Bug 1 — Wrong iteration count in per-group loop (closes #814)

# Before — uses total treated count, not group treated count
for i in range(numtreatedunits):          # BUG
    treated_outcome = treated.iloc[i]...  # treated is now a group subset!
```

Inside the `groupby` loop `treated` is rebound to the current group's treated rows, but `numtreatedunits` (set from the full dataset) is still used for `range()`. This causes `IndexError` when a group is smaller than the full treated set, or silently skips rows when it's larger.

### Bug 2 — No guard for empty control groups

If a group contains no control units, `NearestNeighbors.fit()` receives an empty array and raises. Added a `continue` for `group_control.shape[0] == 0` (mirrors the existing guard for empty treated groups).

### Bug 3 — Stale `treated`/`control` references after exact-match loop

After the groupby loop the local names `treated` and `control` pointed at the **last group's** subset. The ATC branch (and ATE = ATT + ATC) then ran `NearestNeighbors.fit()` on this stale partial data, producing wildly wrong results.

## Fix

- Rename loop variables to `group_treated` / `group_control`outer names are never clobbered.
- Iterate over `group_treated.shape[0]` instead of `numtreatedunits`.
- Skip groups where `group_control.shape[0] == 0`.
- Accumulate `total_treated_matched` for a correct per-sample average.
- Restore `treated`/`control` from saved copies before the ATC block.

## Tests Added (Task 9 — Testing Improvements)

New file `tests/causal_estimators/test_distance_matching_estimator.py` adds **12 tests**:

| Test | What it checks |
|------|---------------|
| `test_estimate_is_close_to_true_effect[att/atc/ate]` | ATT/ATC/ATE within ±3 of true β=10 |
| `test_matched_indices_att_populated` | `matched_indices_att` has one entry per treated unit |
| `test_matched_indices_atc_populated` | `matched_indices_atc` has one entry per control unit |
| `test_exact_match_restricts_matches_to_same_group` | Estimate is sensible with `exact_match_cols` |
| `test_exact_match_estimate_finite[att/atc/ate]` | All target_units produce finite values with `exact_match_cols` |
| `test_invalid_target_units_raises` | `ValueError` for unsupported `target_units` |
| `test_non_binary_treatment_raises` | Raises for non-binary treatment |
| `test_average_treatment_effect_via_simple_estimator` | Smoke test via shared `SimpleEstimator` harness |

## Test Status

```
tests/causal_estimators/test_distance_matching_estimator.py  12 passedtests/causal_estimators/test_propensity_score_matching_estimator.py  2 passedtests/causal_estimators/test_linear_regression_estimator.py  11 passed

Note on flake8 C901: estimate_effect had a pre-existing complexity of 21 (exceeds the limit of 10). This PR adds one branch, raising it to 22. Fixing the pre-existing complexity is out of scope here; a separate refactor PR would be appropriate.

Format check: black and isort report no changes needed on the modified files.

Generated by 🌈 Repo Assist, see workflow run. Learn more.

To install this agentic workflow, run

gh aw add githubnext/agentics/workflows/repo-assist.md@11c9a2c442e519ff2b427bf58679f5a525353f76

 #814)

Three bugs in estimate_effect() with exact_match_cols:

1. Inner loop used global numtreatedunits instead of the per-group
   treated count, causing IndexError / wrong summation when group
   sizes differ from the overall dataset size.

2. No guard when a group has zero control units, which caused
   NearestNeighbors.fit() to receive an empty array.

3. After the groupby loop the names 'treated' and 'control' were
   left pointing at the last group's subsets, so the ATC branch
   (and ATE = ATT + ATC) operated on stale, partial data.

Fix:
- Rename loop variables to group_treated / group_control so outer
  names are never clobbered.
- Iterate over group_treated.shape[0] (not numtreatedunits).
- Skip groups with no control units (same as the existing skip for
  no treated units).
- Accumulate total_treated_matched for a correct per-sample ATT
  average.
- Restore treated/control from saved copies before the ATC block.

Also adds tests/causal_estimators/test_distance_matching_estimator.py
with 12 tests covering ATT/ATC/ATE estimation, matched_indices
population, exact matching correctness, and edge-case error handling.

Note: C901 complexity on estimate_effect was pre-existing (21); this
change adds one branch (complexity 22). A dedicated refactor PR would
be appropriate to bring it below 10.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Signed-off-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
@github-actions github-actions bot added automation bug Something isn't working repo-assist labels Apr 18, 2026
@emrekiciman emrekiciman marked this pull request as ready for review April 18, 2026 14:43
@emrekiciman emrekiciman requested a review from Copilot April 18, 2026 14:44
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Fixes incorrect exact_match_cols handling in DistanceMatchingEstimator.estimate_effect() so matching is performed within exact-match strata and ATC computation doesn’t accidentally run on the last stratum’s subset (issue #814).

Changes:

  • Correct per-stratum ATT loop iteration and guard against strata with no treated or no controls.
  • Preserve full treated/control DataFrames so ATC (and thus ATE) uses the full sample rather than a stratum subset.
  • Add a new test module covering ATT/ATC/ATE basics, invalid inputs, and exact-match scenarios.

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 3 comments.

File Description
dowhy/causal_estimators/distance_matching_estimator.py Fixes exact-match ATT loop logic and restores full treated/control for ATC calculation.
tests/causal_estimators/test_distance_matching_estimator.py Adds regression/unit tests for distance matching, including exact-match usage.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment on lines 246 to 263
else:
grouped = updated_df.groupby(self.exact_match_cols)
att = 0
total_treated_matched = 0
for name, group in grouped:
treated = group.loc[group[self._target_estimand.treatment_variable[0]] == 1]
control = group.loc[group[self._target_estimand.treatment_variable[0]] == 0]
if treated.shape[0] == 0:
group_treated = group.loc[group[self._target_estimand.treatment_variable[0]] == 1]
group_control = group.loc[group[self._target_estimand.treatment_variable[0]] == 0]
if group_treated.shape[0] == 0 or group_control.shape[0] == 0:
continue
control_neighbors = NearestNeighbors(
n_neighbors=self.num_matches_per_unit,
metric=self.distance_metric,
algorithm="ball_tree",
**self.distance_metric_params,
).fit(control[self._observed_common_causes.columns].values)
).fit(group_control[self._observed_common_causes.columns].values)
distances, indices = control_neighbors.kneighbors(
treated[self._observed_common_causes.columns].values
group_treated[self._observed_common_causes.columns].values
)
Copy link

Copilot AI Apr 18, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When exact_match_cols is used, matched_indices_att is never populated (it’s only set in the non-exact-match branch). This makes it impossible to inspect matches for exact matching (and can also leave stale values if the estimator instance is reused). Consider initializing self.matched_indices_att = {} before the groupby loop and filling it per group_treated row using the original dataframe indices, analogous to the non-exact-match path.

Copilot uses AI. Check for mistakes.
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@copilot apply changes based on this feedback

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done in commit adfa02f. Three changes were needed:

  1. distance_matching_estimator.py: Added self.matched_indices_att = {} initialization before the groupby loop, and populated it per group_treated row inside the loop using group_control.iloc[indices[i]].index.tolist() — analogous to the non-exact-match path.

  2. Tests method_params fix: The exact-match tests were using method_params={"exact_match_cols": [...]}, which passes exact_match_cols to __init__ where it's discarded as **_ by the base class. Changed to method_params={"fit_params": {"exact_match_cols": [...]}} so it's correctly routed to fit().

  3. test_exact_match_restricts_matches_to_same_group: Updated to iterate matched_indices_att.items() and assert every matched control has the same W_cat value as the treated unit, directly catching regressions like backdoor.distance_matching: Problem with exact matching #814.

Comment thread tests/causal_estimators/test_distance_matching_estimator.py Outdated
Comment on lines +276 to 283
if total_treated_matched > 0:
att /= total_treated_matched

if target_units == "att":
est = att
elif target_units == "ate":
est = att * numtreatedunits

Copy link

Copilot AI Apr 18, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the exact_match_cols ATT branch, ATE is derived via est = att * numtreatedunits, but att is averaged over total_treated_matched (treated units in strata that had both treated+control). If any treated strata are skipped due to group_control.shape[0] == 0, this will over/under-weight the ATT contribution in ATE. Consider accumulating the ATT sum directly (or multiplying by total_treated_matched) and making the ATE weighting consistent with whatever units were actually matched; alternatively, raise/warn when some treated units have no eligible controls under exact matching.

Copilot uses AI. Check for mistakes.
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@copilot apply changes based on this feedback

emrekiciman and others added 3 commits April 19, 2026 01:54
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Signed-off-by: Emre Kıcıman <emrek@microsoft.com>
…test method_params

- Initialize self.matched_indices_att = {} before the groupby loop in the exact-match branch
- Populate matched_indices_att per group_treated row using original dataframe indices
- Fix tests to use method_params={"fit_params": {"exact_match_cols": [...]}} so exact_match_cols
  is correctly routed to fit() (not just __init__)
- Update test_exact_match_restricts_matches_to_same_group to iterate matched_indices_att.items()
  and verify each matched control has the same W_cat value as the treated unit

Signed-off-by: GitHub Copilot <copilot@github.com>

Agent-Logs-Url: https://github.com/py-why/dowhy/sessions/58b2fe09-f297-4b0e-a6f9-a8f4c987b350

Co-authored-by: emrekiciman <5982160+emrekiciman@users.noreply.github.com>
Signed-off-by: GitHub Copilot <copilot@github.com>

Agent-Logs-Url: https://github.com/py-why/dowhy/sessions/58b2fe09-f297-4b0e-a6f9-a8f4c987b350

Co-authored-by: emrekiciman <5982160+emrekiciman@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

automation bug Something isn't working repo-assist

Projects

None yet

Development

Successfully merging this pull request may close these issues.

backdoor.distance_matching: Problem with exact matching

3 participants