[Repo Assist] fix: correct exact_match_cols logic in DistanceMatchingEstimator (closes #814) #1465
#814)

Three bugs in estimate_effect() with exact_match_cols:
1. Inner loop used global numtreatedunits instead of the per-group treated count, causing IndexError / wrong summation when group sizes differ from the overall dataset size.
2. No guard when a group has zero control units, which caused NearestNeighbors.fit() to receive an empty array.
3. After the groupby loop the names 'treated' and 'control' were left pointing at the last group's subsets, so the ATC branch (and ATE = ATT + ATC) operated on stale, partial data.

Fix:
- Rename loop variables to group_treated / group_control so outer names are never clobbered.
- Iterate over group_treated.shape[0] (not numtreatedunits).
- Skip groups with no control units (same as the existing skip for no treated units).
- Accumulate total_treated_matched for a correct per-sample ATT average.
- Restore treated/control from saved copies before the ATC block.

Also adds tests/causal_estimators/test_distance_matching_estimator.py with 12 tests covering ATT/ATC/ATE estimation, matched_indices population, exact matching correctness, and edge-case error handling.

Note: C901 complexity on estimate_effect was pre-existing (21); this change adds one branch (complexity 22). A dedicated refactor PR would be appropriate to bring it below 10.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Signed-off-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
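The fix list in this commit message can be sketched in plain Python. This is an illustration only, not the DoWhy implementation: the row layout (`stratum`, `treated`, `x`, `y`), the function name `stratified_att`, and the one-dimensional absolute-distance matching are all assumptions made for the example, and capping the neighbor count at the number of available controls is a simplification of this sketch.

```python
from collections import defaultdict


def stratified_att(rows, k=1):
    """Average treated-minus-matched-control outcome, matching only within a stratum.

    rows: list of dicts with hypothetical keys 'stratum', 'treated' (0/1), 'x', 'y'.
    Returns (att, total_treated_matched).
    """
    groups = defaultdict(list)
    for row in rows:
        groups[row["stratum"]].append(row)

    att_sum = 0.0
    total_treated_matched = 0
    for group in groups.values():
        # Distinct names, so no outer 'treated'/'control' variable is overwritten.
        group_treated = [r for r in group if r["treated"] == 1]
        group_control = [r for r in group if r["treated"] == 0]
        # Skip strata where matching is impossible.
        if not group_treated or not group_control:
            continue
        n_neighbors = min(k, len(group_control))  # simplification of this sketch
        # Iterate over this stratum's treated units, not a global count.
        for t in group_treated:
            nearest = sorted(group_control, key=lambda c: abs(c["x"] - t["x"]))
            nearest = nearest[:n_neighbors]
            matched_outcome = sum(c["y"] for c in nearest) / len(nearest)
            att_sum += t["y"] - matched_outcome
            total_treated_matched += 1

    if total_treated_matched == 0:
        raise ValueError("no stratum contains both treated and control units")
    return att_sum / total_treated_matched, total_treated_matched


rows = [
    {"stratum": "a", "treated": 1, "x": 1.0, "y": 5.0},
    {"stratum": "a", "treated": 0, "x": 1.1, "y": 3.0},
    {"stratum": "a", "treated": 0, "x": 5.0, "y": 10.0},
    {"stratum": "b", "treated": 1, "x": 5.0, "y": 8.0},
    {"stratum": "b", "treated": 0, "x": 0.0, "y": 7.0},
    {"stratum": "c", "treated": 1, "x": 2.0, "y": 9.0},  # no controls: skipped
]
att, n_matched = stratified_att(rows)
```

In this toy data the treated unit in stratum "b" is matched to the only control in "b", even though a control in stratum "a" has an identical `x`; that restriction is the behavior exact matching is meant to guarantee.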
Pull request overview
Fixes incorrect exact_match_cols handling in DistanceMatchingEstimator.estimate_effect() so matching is performed within exact-match strata and ATC computation doesn’t accidentally run on the last stratum’s subset (issue #814).
Changes:
- Correct per-stratum ATT loop iteration and guard against strata with no treated or no controls.
- Preserve full treated/control DataFrames so ATC (and thus ATE) uses the full sample rather than a stratum subset.
- Add a new test module covering ATT/ATC/ATE basics, invalid inputs, and exact-match scenarios.
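The second change above addresses a loop that reuses outer variable names. A minimal sketch of that failure mode, using hypothetical unit ids and plain lists instead of DataFrames:

```python
# Hypothetical strata: name -> (treated ids, control ids).
STRATA = {"a": ([1, 2], [10]), "b": ([3, 4], [11, 12])}
ALL_TREATED = [1, 2, 3, 4]
ALL_CONTROL = [10, 11, 12]


def buggy_loop():
    treated, control = ALL_TREATED, ALL_CONTROL
    for _name, (t, c) in STRATA.items():
        # Reusing the outer names overwrites the full-sample subsets.
        treated, control = t, c
    # Whatever runs next (for example an ATC pass) sees only the last stratum.
    return treated, control


def fixed_loop():
    treated, control = ALL_TREATED, ALL_CONTROL
    for _name, (t, c) in STRATA.items():
        # Loop-local names leave the outer variables untouched.
        group_treated, group_control = t, c
        assert group_treated and group_control
    return treated, control


stale = buggy_loop()
full = fixed_loop()
```

Here `stale` holds only stratum "b", while `full` still refers to the whole sample.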
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| dowhy/causal_estimators/distance_matching_estimator.py | Fixes exact-match ATT loop logic and restores full treated/control for ATC calculation. |
| tests/causal_estimators/test_distance_matching_estimator.py | Adds regression/unit tests for distance matching, including exact-match usage. |
             else:
                 grouped = updated_df.groupby(self.exact_match_cols)
                 att = 0
    +            total_treated_matched = 0
                 for name, group in grouped:
    -                treated = group.loc[group[self._target_estimand.treatment_variable[0]] == 1]
    -                control = group.loc[group[self._target_estimand.treatment_variable[0]] == 0]
    -                if treated.shape[0] == 0:
    +                group_treated = group.loc[group[self._target_estimand.treatment_variable[0]] == 1]
    +                group_control = group.loc[group[self._target_estimand.treatment_variable[0]] == 0]
    +                if group_treated.shape[0] == 0 or group_control.shape[0] == 0:
                         continue
                     control_neighbors = NearestNeighbors(
                         n_neighbors=self.num_matches_per_unit,
                         metric=self.distance_metric,
                         algorithm="ball_tree",
                         **self.distance_metric_params,
    -                ).fit(control[self._observed_common_causes.columns].values)
    +                ).fit(group_control[self._observed_common_causes.columns].values)
                     distances, indices = control_neighbors.kneighbors(
    -                    treated[self._observed_common_causes.columns].values
    +                    group_treated[self._observed_common_causes.columns].values
                     )
When exact_match_cols is used, matched_indices_att is never populated (it’s only set in the non-exact-match branch). This makes it impossible to inspect matches for exact matching (and can also leave stale values if the estimator instance is reused). Consider initializing self.matched_indices_att = {} before the groupby loop and filling it per group_treated row using the original dataframe indices, analogous to the non-exact-match path.
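A rough sketch of what this comment asks for, in plain Python. A nearest-neighbor search within a stratum returns positions into that stratum's control subset, so they have to be translated back to the original row labels before being stored. The `(label, x)` pair layout and the helper name `match_within_group` are assumptions made for illustration.

```python
def match_within_group(treated, control, k=1):
    """treated/control: lists of (original_label, x) pairs from one stratum.

    Returns {treated_label: [matched_control_labels]}.
    """
    matched = {}
    for t_label, t_x in treated:
        # Positions are relative to this stratum's control list ...
        positions = sorted(
            range(len(control)), key=lambda j: abs(control[j][1] - t_x)
        )[:k]
        # ... so map them back to original labels before storing.
        matched[t_label] = [control[j][0] for j in positions]
    return matched


# Reset before the loop so a reused estimator cannot expose stale matches.
matched_indices_att = {}
strata = {
    "a": {"treated": [(0, 1.0)], "control": [(7, 0.9), (3, 4.0)]},
    "b": {"treated": [(5, 2.0)], "control": [(8, 2.5)]},
}
for stratum in strata.values():
    matched_indices_att.update(
        match_within_group(stratum["treated"], stratum["control"])
    )
```

The treated row labelled 0 is matched to the control labelled 7, which sits at position 0 of its stratum; storing the position instead of the label would point at the wrong row in the full dataset.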
Done in commit adfa02f. Three changes were needed:
- distance_matching_estimator.py: Added self.matched_indices_att = {} initialization before the groupby loop, and populated it per group_treated row inside the loop using group_control.iloc[indices[i]].index.tolist(), analogous to the non-exact-match path.
- Tests method_params fix: The exact-match tests were using method_params={"exact_match_cols": [...]}, which passes exact_match_cols to __init__ where it's discarded as **_ by the base class. Changed to method_params={"fit_params": {"exact_match_cols": [...]}} so it's correctly routed to fit().
- test_exact_match_restricts_matches_to_same_group: Updated to iterate matched_indices_att.items() and assert every matched control has the same W_cat value as the treated unit, directly catching regressions like issue #814 (backdoor.distance_matching: Problem with exact matching).
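The method_params routing problem described above can be reproduced with toy classes. Everything below is a hypothetical stand-in (`ToyBaseEstimator`, `ToyMatchingEstimator`, `build_and_fit`), not DoWhy's actual class hierarchy or dispatch code; it only shows how a `**_` catch-all silently drops an unrecognized keyword.

```python
class ToyBaseEstimator:
    def __init__(self, identified_estimand=None, **_):
        # Unknown keyword arguments are accepted and silently discarded.
        self.identified_estimand = identified_estimand


class ToyMatchingEstimator(ToyBaseEstimator):
    def __init__(self, num_matches_per_unit=1, **kwargs):
        super().__init__(**kwargs)
        self.num_matches_per_unit = num_matches_per_unit
        self.exact_match_cols = None

    def fit(self, data, exact_match_cols=None):
        # Only fit() actually records exact_match_cols in this toy model.
        self.exact_match_cols = exact_match_cols
        return self


def build_and_fit(method_params):
    """Hypothetical dispatcher: 'fit_params' goes to fit(), the rest to __init__."""
    params = dict(method_params)
    fit_params = params.pop("fit_params", {})
    estimator = ToyMatchingEstimator(**params)
    return estimator.fit(data=None, **fit_params)


dropped = build_and_fit({"exact_match_cols": ["W_cat"]})
routed = build_and_fit({"fit_params": {"exact_match_cols": ["W_cat"]}})
```

No error is raised in the first call, which is why a test written that way can pass without ever exercising the exact-match branch.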
    +            if total_treated_matched > 0:
    +                att /= total_treated_matched

             if target_units == "att":
                 est = att
             elif target_units == "ate":
                 est = att * numtreatedunits
In the exact_match_cols ATT branch, ATE is derived via est = att * numtreatedunits, but att is averaged over total_treated_matched (treated units in strata that had both treated+control). If any treated strata are skipped due to group_control.shape[0] == 0, this will over/under-weight the ATT contribution in ATE. Consider accumulating the ATT sum directly (or multiplying by total_treated_matched) and making the ATE weighting consistent with whatever units were actually matched; alternatively, raise/warn when some treated units have no eligible controls under exact matching.
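A small worked example of the mismatch raised in this comment, with made-up numbers. The helper `combine_ate` assumes ATE is a size-weighted average of ATT and ATC; that formula and all values here are assumptions for illustration, not taken from the estimator's code.

```python
# Hypothetical data: three treated units overall, but one sits in a stratum
# with no controls, so only two per-unit effects were actually computed.
effects_matched = [2.0, 1.0]
numtreatedunits = 3
total_treated_matched = len(effects_matched)

att = sum(effects_matched) / total_treated_matched

# Rescaling by the global treated count no longer recovers the summed effects.
scaled_by_all_treated = att * numtreatedunits
scaled_by_matched = att * total_treated_matched


def combine_ate(att_value, n_treated, atc_value, n_control):
    """Size-weighted average of ATT and ATC (assumed form, for illustration)."""
    total = att_value * n_treated + atc_value * n_control
    return total / (n_treated + n_control)


# Hypothetical ATC of 1.0 over three control units.
ate_using_all_treated = combine_ate(att, numtreatedunits, 1.0, 3)
ate_using_matched = combine_ate(att, total_treated_matched, 1.0, 3)
```

The two weightings give different ATE values on the same data (1.25 versus 1.2 here), so whichever convention is chosen should be applied consistently, or the skipped treated units should be surfaced with a warning as the comment suggests.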
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: Emre Kıcıman <emrek@microsoft.com>
…test method_params
- Initialize self.matched_indices_att = {} before the groupby loop in the exact-match branch
- Populate matched_indices_att per group_treated row using original dataframe indices
- Fix tests to use method_params={"fit_params": {"exact_match_cols": [...]}} so exact_match_cols
is correctly routed to fit() (not just __init__)
- Update test_exact_match_restricts_matches_to_same_group to iterate matched_indices_att.items()
and verify each matched control has the same W_cat value as the treated unit
Signed-off-by: GitHub Copilot <copilot@github.com>
Agent-Logs-Url: https://github.com/py-why/dowhy/sessions/58b2fe09-f297-4b0e-a6f9-a8f4c987b350
Co-authored-by: emrekiciman <5982160+emrekiciman@users.noreply.github.com>
🤖 This is an automated pull request from Repo Assist.
Summary
Fixes three bugs in DistanceMatchingEstimator.estimate_effect() when exact_match_cols is provided, as reported in #814. The root symptom is that matched control units are not restricted to the same exact-match group, producing incorrect ATT/ATE estimates.
Root Causes
Bug 1 — Wrong iteration count in per-group loop (closes #814)
Note on flake8 C901: estimate_effect had a pre-existing complexity of 21 (exceeds the limit of 10). This PR adds one branch, raising it to 22. Fixing the pre-existing complexity is out of scope here; a separate refactor PR would be appropriate.
Format check: black and isort report no changes needed on the modified files.