-
Notifications
You must be signed in to change notification settings - Fork 1k
[Repo Assist] fix: correct exact_match_cols logic in DistanceMatchingEstimator (closes #814) #1465
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 1 commit
855bc76
833d8d2
28fa528
adfa02f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -206,6 +206,10 @@ def estimate_effect( | |
| else: | ||
| raise ValueError("Target units string value not supported") | ||
|
|
||
| # Save the full treated/control DataFrames before any groupby loops that may rebind these names. | ||
| treated_all = treated | ||
| control_all = control | ||
|
|
||
| if fit_att: | ||
| # estimate ATT on treated by summing over difference between matched neighbors | ||
| if self.exact_match_cols is None: | ||
|
|
@@ -242,40 +246,45 @@ def estimate_effect( | |
| else: | ||
| grouped = updated_df.groupby(self.exact_match_cols) | ||
| att = 0 | ||
| total_treated_matched = 0 | ||
| for name, group in grouped: | ||
| treated = group.loc[group[self._target_estimand.treatment_variable[0]] == 1] | ||
| control = group.loc[group[self._target_estimand.treatment_variable[0]] == 0] | ||
| if treated.shape[0] == 0: | ||
| group_treated = group.loc[group[self._target_estimand.treatment_variable[0]] == 1] | ||
| group_control = group.loc[group[self._target_estimand.treatment_variable[0]] == 0] | ||
| if group_treated.shape[0] == 0 or group_control.shape[0] == 0: | ||
| continue | ||
| control_neighbors = NearestNeighbors( | ||
| n_neighbors=self.num_matches_per_unit, | ||
| metric=self.distance_metric, | ||
| algorithm="ball_tree", | ||
| **self.distance_metric_params, | ||
| ).fit(control[self._observed_common_causes.columns].values) | ||
| ).fit(group_control[self._observed_common_causes.columns].values) | ||
| distances, indices = control_neighbors.kneighbors( | ||
| treated[self._observed_common_causes.columns].values | ||
| group_treated[self._observed_common_causes.columns].values | ||
| ) | ||
| self.logger.debug("distances:") | ||
| self.logger.debug(distances) | ||
|
|
||
| for i in range(numtreatedunits): | ||
| treated_outcome = treated.iloc[i][self._target_estimand.outcome_variable[0]].item() | ||
| num_group_treated = group_treated.shape[0] | ||
| for i in range(num_group_treated): | ||
| treated_outcome = group_treated.iloc[i][self._target_estimand.outcome_variable[0]].item() | ||
| control_outcome = np.mean( | ||
| control.iloc[indices[i]][self._target_estimand.outcome_variable[0]].values | ||
| group_control.iloc[indices[i]][self._target_estimand.outcome_variable[0]].values | ||
| ) | ||
| att += treated_outcome - control_outcome | ||
| # self.matched_indices_att[treated_df_index[i]] = control.iloc[indices[i]].index.tolist() | ||
| total_treated_matched += num_group_treated | ||
|
|
||
| att /= numtreatedunits | ||
| if total_treated_matched > 0: | ||
| att /= total_treated_matched | ||
|
|
||
| if target_units == "att": | ||
| est = att | ||
| elif target_units == "ate": | ||
| est = att * numtreatedunits | ||
|
|
||
|
Comment on lines
+280
to
287
|
||
| if fit_atc: | ||
| # Now computing ATC | ||
| # Now computing ATC using the full treated/control DataFrames (not group-level subsets). | ||
| treated = treated_all | ||
| control = control_all | ||
| treated_neighbors = NearestNeighbors( | ||
| n_neighbors=self.num_matches_per_unit, | ||
| metric=self.distance_metric, | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,138 @@ | ||
| import numpy as np | ||
| import pandas as pd | ||
| import pytest | ||
|
|
||
| from dowhy import CausalModel | ||
| from dowhy.causal_estimators.distance_matching_estimator import DistanceMatchingEstimator | ||
|
|
||
| from .base import SimpleEstimator | ||
|
|
||
|
|
||
| @pytest.fixture | ||
| def binary_treatment_dataset(): | ||
| """Small deterministic dataset for distance matching tests.""" | ||
| rng = np.random.default_rng(42) | ||
| n = 500 | ||
| w = rng.standard_normal(n) | ||
| treatment = (w + rng.standard_normal(n) > 0).astype(int) | ||
| outcome = 10 * treatment + 2 * w + rng.standard_normal(n) | ||
| return pd.DataFrame({"W": w, "v0": treatment, "y": outcome}) | ||
|
|
||
|
|
||
| @pytest.fixture | ||
| def binary_treatment_dataset_with_exact_col(): | ||
| """Dataset with a discrete covariate for exact matching tests.""" | ||
| rng = np.random.default_rng(0) | ||
| n = 1000 | ||
| w_cont = rng.standard_normal(n) | ||
| w_cat = rng.integers(0, 2, size=n) # binary exact-match column | ||
| treatment = ((w_cont + w_cat + rng.standard_normal(n)) > 0).astype(int) | ||
| outcome = 10 * treatment + 2 * w_cont + 3 * w_cat + rng.standard_normal(n) | ||
| return pd.DataFrame({"W": w_cont, "W_cat": w_cat, "v0": treatment, "y": outcome}) | ||
|
|
||
|
|
||
| GML_SINGLE_CAUSE = """graph [directed 1 node [id "W" label "W"] node [id "v0" label "v0"] | ||
| node [id "y" label "y"] edge [source "W" target "v0"] edge [source "W" target "y"] | ||
| edge [source "v0" target "y"]]""" | ||
|
|
||
| GML_TWO_CAUSES = """graph [directed 1 node [id "W" label "W"] node [id "W_cat" label "W_cat"] | ||
| node [id "v0" label "v0"] node [id "y" label "y"] | ||
| edge [source "W" target "v0"] edge [source "W" target "y"] | ||
| edge [source "W_cat" target "v0"] edge [source "W_cat" target "y"] | ||
| edge [source "v0" target "y"]]""" | ||
|
|
||
|
|
||
| class TestDistanceMatchingEstimator: | ||
| @pytest.mark.parametrize("target_units", ["att", "atc", "ate"]) | ||
| def test_estimate_is_close_to_true_effect(self, binary_treatment_dataset, target_units): | ||
| """ATT/ATC/ATE estimate should be within a reasonable range of the true beta=10.""" | ||
| data = binary_treatment_dataset | ||
| model = CausalModel(data=data, treatment="v0", outcome="y", graph=GML_SINGLE_CAUSE) | ||
| estimand = model.identify_effect(proceed_when_unidentifiable=True) | ||
| estimate = model.estimate_effect( | ||
| estimand, | ||
| method_name="backdoor.distance_matching", | ||
| target_units=target_units, | ||
| ) | ||
| assert abs(estimate.value - 10) < 3.0, f"Estimate {estimate.value:.2f} too far from true effect 10" | ||
|
|
||
| def test_matched_indices_att_populated(self, binary_treatment_dataset): | ||
| """matched_indices_att should be populated when target_units='att'.""" | ||
| data = binary_treatment_dataset | ||
| model = CausalModel(data=data, treatment="v0", outcome="y", graph=GML_SINGLE_CAUSE) | ||
| estimand = model.identify_effect(proceed_when_unidentifiable=True) | ||
| estimate = model.estimate_effect(estimand, method_name="backdoor.distance_matching", target_units="att") | ||
| estimator = estimate.estimator | ||
| assert estimator.matched_indices_att is not None | ||
| assert len(estimator.matched_indices_att) == data["v0"].sum() | ||
|
|
||
| def test_matched_indices_atc_populated(self, binary_treatment_dataset): | ||
| """matched_indices_atc should be populated when target_units='atc'.""" | ||
| data = binary_treatment_dataset | ||
| model = CausalModel(data=data, treatment="v0", outcome="y", graph=GML_SINGLE_CAUSE) | ||
| estimand = model.identify_effect(proceed_when_unidentifiable=True) | ||
| estimate = model.estimate_effect(estimand, method_name="backdoor.distance_matching", target_units="atc") | ||
| estimator = estimate.estimator | ||
| assert estimator.matched_indices_atc is not None | ||
| assert len(estimator.matched_indices_atc) == (data["v0"] == 0).sum() | ||
|
|
||
| def test_exact_match_restricts_matches_to_same_group(self, binary_treatment_dataset_with_exact_col): | ||
| """With exact_match_cols, every matched control unit must share the same W_cat value.""" | ||
| data = binary_treatment_dataset_with_exact_col | ||
| model = CausalModel(data=data, treatment="v0", outcome="y", graph=GML_TWO_CAUSES) | ||
| estimand = model.identify_effect(proceed_when_unidentifiable=True) | ||
| estimate = model.estimate_effect( | ||
| estimand, | ||
| method_name="backdoor.distance_matching", | ||
| target_units="att", | ||
| method_params={"exact_match_cols": ["W_cat"]}, | ||
| ) | ||
| assert estimate.value is not None | ||
| assert abs(estimate.value - 10) < 3.5, f"Estimate {estimate.value:.2f} too far from true effect 10" | ||
|
|
||
|
emrekiciman marked this conversation as resolved.
Outdated
|
||
| @pytest.mark.parametrize("target_units", ["att", "atc", "ate"]) | ||
| def test_exact_match_estimate_finite(self, binary_treatment_dataset_with_exact_col, target_units): | ||
| """Estimates with exact_match_cols should be finite for all target_units.""" | ||
| data = binary_treatment_dataset_with_exact_col | ||
| model = CausalModel(data=data, treatment="v0", outcome="y", graph=GML_TWO_CAUSES) | ||
| estimand = model.identify_effect(proceed_when_unidentifiable=True) | ||
| estimate = model.estimate_effect( | ||
| estimand, | ||
| method_name="backdoor.distance_matching", | ||
| target_units=target_units, | ||
| method_params={"exact_match_cols": ["W_cat"]}, | ||
| ) | ||
| assert np.isfinite(estimate.value), f"Non-finite estimate for target_units={target_units}" | ||
|
|
||
| def test_invalid_target_units_raises(self, binary_treatment_dataset): | ||
| """Passing an unsupported target_units string must raise ValueError.""" | ||
| data = binary_treatment_dataset | ||
| model = CausalModel(data=data, treatment="v0", outcome="y", graph=GML_SINGLE_CAUSE) | ||
| estimand = model.identify_effect(proceed_when_unidentifiable=True) | ||
| with pytest.raises(ValueError, match="Target units string value not supported"): | ||
| model.estimate_effect(estimand, method_name="backdoor.distance_matching", target_units="invalid") | ||
|
|
||
| def test_non_binary_treatment_raises(self): | ||
| """DistanceMatchingEstimator must raise when treatment is not binary.""" | ||
| rng = np.random.default_rng(7) | ||
| n = 200 | ||
| data = pd.DataFrame({"W": rng.standard_normal(n), "v0": rng.integers(0, 4, n), "y": rng.standard_normal(n)}) | ||
| model = CausalModel(data=data, treatment="v0", outcome="y", graph=GML_SINGLE_CAUSE) | ||
| estimand = model.identify_effect(proceed_when_unidentifiable=True) | ||
| with pytest.raises(Exception, match="binary"): | ||
| model.estimate_effect(estimand, method_name="backdoor.distance_matching", target_units="att") | ||
|
|
||
| def test_average_treatment_effect_via_simple_estimator(self): | ||
| """Smoke test using the shared SimpleEstimator harness.""" | ||
| tester = SimpleEstimator(error_tolerance=0.3, Estimator=DistanceMatchingEstimator) | ||
| tester.average_treatment_effect_testsuite( | ||
| num_common_causes=[1], | ||
| num_instruments=[0], | ||
| num_effect_modifiers=[0], | ||
| num_treatments=[1], | ||
| treatment_is_binary=[True], | ||
| outcome_is_binary=[False], | ||
| confidence_intervals=[False], | ||
| test_significance=[False], | ||
| method_params={"num_simulations": 5, "num_null_simulations": 5}, | ||
| ) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When
`exact_match_cols` is used, `matched_indices_att` is never populated (it’s only set in the non-exact-match branch). This makes it impossible to inspect matches for exact matching (and can also leave stale values if the estimator instance is reused). Consider initializing `self.matched_indices_att = {}` before the groupby loop and filling it per `group_treated` row using the original dataframe indices, analogous to the non-exact-match path.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@copilot apply changes based on this feedback
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done in commit adfa02f. Three changes were needed:
`distance_matching_estimator.py`: Added `self.matched_indices_att = {}` initialization before the groupby loop, and populated it per `group_treated` row inside the loop using `group_control.iloc[indices[i]].index.tolist()` — analogous to the non-exact-match path.
Tests `method_params` fix: The exact-match tests were using `method_params={"exact_match_cols": [...]}`, which passes `exact_match_cols` to `__init__` where it's discarded as `**_` by the base class. Changed to `method_params={"fit_params": {"exact_match_cols": [...]}}` so it's correctly routed to `fit()`.
`test_exact_match_restricts_matches_to_same_group`: Updated to iterate `matched_indices_att.items()` and assert every matched control has the same `W_cat` value as the treated unit, directly catching regressions like "backdoor.distance_matching: Problem with exact matching" (#814).