Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 20 additions & 11 deletions dowhy/causal_estimators/distance_matching_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,10 @@ def estimate_effect(
else:
raise ValueError("Target units string value not supported")

# Save the full treated/control DataFrames before any groupby loops that may rebind these names.
treated_all = treated
control_all = control

if fit_att:
# estimate ATT on treated by summing over difference between matched neighbors
if self.exact_match_cols is None:
Expand Down Expand Up @@ -242,40 +246,45 @@ def estimate_effect(
else:
grouped = updated_df.groupby(self.exact_match_cols)
att = 0
total_treated_matched = 0
for name, group in grouped:
treated = group.loc[group[self._target_estimand.treatment_variable[0]] == 1]
control = group.loc[group[self._target_estimand.treatment_variable[0]] == 0]
if treated.shape[0] == 0:
group_treated = group.loc[group[self._target_estimand.treatment_variable[0]] == 1]
group_control = group.loc[group[self._target_estimand.treatment_variable[0]] == 0]
if group_treated.shape[0] == 0 or group_control.shape[0] == 0:
continue
control_neighbors = NearestNeighbors(
n_neighbors=self.num_matches_per_unit,
metric=self.distance_metric,
algorithm="ball_tree",
**self.distance_metric_params,
).fit(control[self._observed_common_causes.columns].values)
).fit(group_control[self._observed_common_causes.columns].values)
distances, indices = control_neighbors.kneighbors(
treated[self._observed_common_causes.columns].values
group_treated[self._observed_common_causes.columns].values
)
Comment on lines 246 to 264
Copy link

Copilot AI Apr 18, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When exact_match_cols is used, matched_indices_att is never populated (it’s only set in the non-exact-match branch). This makes it impossible to inspect matches for exact matching (and can also leave stale values if the estimator instance is reused). Consider initializing self.matched_indices_att = {} before the groupby loop and filling it per group_treated row using the original dataframe indices, analogous to the non-exact-match path.

Copilot uses AI. Check for mistakes.
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@copilot apply changes based on this feedback

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done in commit adfa02f. Three changes were needed:

  1. distance_matching_estimator.py: Added self.matched_indices_att = {} initialization before the groupby loop, and populated it per group_treated row inside the loop using group_control.iloc[indices[i]].index.tolist() β€” analogous to the non-exact-match path.

  2. Tests method_params fix: The exact-match tests were using method_params={"exact_match_cols": [...]}, which passes exact_match_cols to __init__ where it's discarded as **_ by the base class. Changed to method_params={"fit_params": {"exact_match_cols": [...]}} so it's correctly routed to fit().

  3. test_exact_match_restricts_matches_to_same_group: Updated to iterate matched_indices_att.items() and assert every matched control has the same W_cat value as the treated unit, directly catching regressions like backdoor.distance_matching: Problem with exact matchingΒ #814.

self.logger.debug("distances:")
self.logger.debug(distances)

for i in range(numtreatedunits):
treated_outcome = treated.iloc[i][self._target_estimand.outcome_variable[0]].item()
num_group_treated = group_treated.shape[0]
for i in range(num_group_treated):
treated_outcome = group_treated.iloc[i][self._target_estimand.outcome_variable[0]].item()
control_outcome = np.mean(
control.iloc[indices[i]][self._target_estimand.outcome_variable[0]].values
group_control.iloc[indices[i]][self._target_estimand.outcome_variable[0]].values
)
att += treated_outcome - control_outcome
# self.matched_indices_att[treated_df_index[i]] = control.iloc[indices[i]].index.tolist()
total_treated_matched += num_group_treated

att /= numtreatedunits
if total_treated_matched > 0:
att /= total_treated_matched

if target_units == "att":
est = att
elif target_units == "ate":
est = att * numtreatedunits

Comment on lines +280 to 287
Copy link

Copilot AI Apr 18, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the exact_match_cols ATT branch, ATE is derived via est = att * numtreatedunits, but att is averaged over total_treated_matched (treated units in strata that had both treated+control). If any treated strata are skipped due to group_control.shape[0] == 0, this will over/under-weight the ATT contribution in ATE. Consider accumulating the ATT sum directly (or multiplying by total_treated_matched) and making the ATE weighting consistent with whatever units were actually matched; alternatively, raise/warn when some treated units have no eligible controls under exact matching.

Copilot uses AI. Check for mistakes.
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@copilot apply changes based on this feedback

if fit_atc:
# Now computing ATC
# Now computing ATC using the full treated/control DataFrames (not group-level subsets).
treated = treated_all
control = control_all
treated_neighbors = NearestNeighbors(
n_neighbors=self.num_matches_per_unit,
metric=self.distance_metric,
Expand Down
138 changes: 138 additions & 0 deletions tests/causal_estimators/test_distance_matching_estimator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
import numpy as np
import pandas as pd
import pytest

from dowhy import CausalModel
from dowhy.causal_estimators.distance_matching_estimator import DistanceMatchingEstimator

from .base import SimpleEstimator


@pytest.fixture
def binary_treatment_dataset():
"""Small deterministic dataset for distance matching tests."""
rng = np.random.default_rng(42)
n = 500
w = rng.standard_normal(n)
treatment = (w + rng.standard_normal(n) > 0).astype(int)
outcome = 10 * treatment + 2 * w + rng.standard_normal(n)
return pd.DataFrame({"W": w, "v0": treatment, "y": outcome})


@pytest.fixture
def binary_treatment_dataset_with_exact_col():
"""Dataset with a discrete covariate for exact matching tests."""
rng = np.random.default_rng(0)
n = 1000
w_cont = rng.standard_normal(n)
w_cat = rng.integers(0, 2, size=n) # binary exact-match column
treatment = ((w_cont + w_cat + rng.standard_normal(n)) > 0).astype(int)
outcome = 10 * treatment + 2 * w_cont + 3 * w_cat + rng.standard_normal(n)
return pd.DataFrame({"W": w_cont, "W_cat": w_cat, "v0": treatment, "y": outcome})


GML_SINGLE_CAUSE = """graph [directed 1 node [id "W" label "W"] node [id "v0" label "v0"]
node [id "y" label "y"] edge [source "W" target "v0"] edge [source "W" target "y"]
edge [source "v0" target "y"]]"""

GML_TWO_CAUSES = """graph [directed 1 node [id "W" label "W"] node [id "W_cat" label "W_cat"]
node [id "v0" label "v0"] node [id "y" label "y"]
edge [source "W" target "v0"] edge [source "W" target "y"]
edge [source "W_cat" target "v0"] edge [source "W_cat" target "y"]
edge [source "v0" target "y"]]"""


class TestDistanceMatchingEstimator:
@pytest.mark.parametrize("target_units", ["att", "atc", "ate"])
def test_estimate_is_close_to_true_effect(self, binary_treatment_dataset, target_units):
"""ATT/ATC/ATE estimate should be within a reasonable range of the true beta=10."""
data = binary_treatment_dataset
model = CausalModel(data=data, treatment="v0", outcome="y", graph=GML_SINGLE_CAUSE)
estimand = model.identify_effect(proceed_when_unidentifiable=True)
estimate = model.estimate_effect(
estimand,
method_name="backdoor.distance_matching",
target_units=target_units,
)
assert abs(estimate.value - 10) < 3.0, f"Estimate {estimate.value:.2f} too far from true effect 10"

def test_matched_indices_att_populated(self, binary_treatment_dataset):
"""matched_indices_att should be populated when target_units='att'."""
data = binary_treatment_dataset
model = CausalModel(data=data, treatment="v0", outcome="y", graph=GML_SINGLE_CAUSE)
estimand = model.identify_effect(proceed_when_unidentifiable=True)
estimate = model.estimate_effect(estimand, method_name="backdoor.distance_matching", target_units="att")
estimator = estimate.estimator
assert estimator.matched_indices_att is not None
assert len(estimator.matched_indices_att) == data["v0"].sum()

def test_matched_indices_atc_populated(self, binary_treatment_dataset):
"""matched_indices_atc should be populated when target_units='atc'."""
data = binary_treatment_dataset
model = CausalModel(data=data, treatment="v0", outcome="y", graph=GML_SINGLE_CAUSE)
estimand = model.identify_effect(proceed_when_unidentifiable=True)
estimate = model.estimate_effect(estimand, method_name="backdoor.distance_matching", target_units="atc")
estimator = estimate.estimator
assert estimator.matched_indices_atc is not None
assert len(estimator.matched_indices_atc) == (data["v0"] == 0).sum()

def test_exact_match_restricts_matches_to_same_group(self, binary_treatment_dataset_with_exact_col):
"""With exact_match_cols, every matched control unit must share the same W_cat value."""
data = binary_treatment_dataset_with_exact_col
model = CausalModel(data=data, treatment="v0", outcome="y", graph=GML_TWO_CAUSES)
estimand = model.identify_effect(proceed_when_unidentifiable=True)
estimate = model.estimate_effect(
estimand,
method_name="backdoor.distance_matching",
target_units="att",
method_params={"exact_match_cols": ["W_cat"]},
)
assert estimate.value is not None
assert abs(estimate.value - 10) < 3.5, f"Estimate {estimate.value:.2f} too far from true effect 10"

Comment thread
emrekiciman marked this conversation as resolved.
Outdated
@pytest.mark.parametrize("target_units", ["att", "atc", "ate"])
def test_exact_match_estimate_finite(self, binary_treatment_dataset_with_exact_col, target_units):
"""Estimates with exact_match_cols should be finite for all target_units."""
data = binary_treatment_dataset_with_exact_col
model = CausalModel(data=data, treatment="v0", outcome="y", graph=GML_TWO_CAUSES)
estimand = model.identify_effect(proceed_when_unidentifiable=True)
estimate = model.estimate_effect(
estimand,
method_name="backdoor.distance_matching",
target_units=target_units,
method_params={"exact_match_cols": ["W_cat"]},
)
assert np.isfinite(estimate.value), f"Non-finite estimate for target_units={target_units}"

def test_invalid_target_units_raises(self, binary_treatment_dataset):
"""Passing an unsupported target_units string must raise ValueError."""
data = binary_treatment_dataset
model = CausalModel(data=data, treatment="v0", outcome="y", graph=GML_SINGLE_CAUSE)
estimand = model.identify_effect(proceed_when_unidentifiable=True)
with pytest.raises(ValueError, match="Target units string value not supported"):
model.estimate_effect(estimand, method_name="backdoor.distance_matching", target_units="invalid")

def test_non_binary_treatment_raises(self):
"""DistanceMatchingEstimator must raise when treatment is not binary."""
rng = np.random.default_rng(7)
n = 200
data = pd.DataFrame({"W": rng.standard_normal(n), "v0": rng.integers(0, 4, n), "y": rng.standard_normal(n)})
model = CausalModel(data=data, treatment="v0", outcome="y", graph=GML_SINGLE_CAUSE)
estimand = model.identify_effect(proceed_when_unidentifiable=True)
with pytest.raises(Exception, match="binary"):
model.estimate_effect(estimand, method_name="backdoor.distance_matching", target_units="att")

def test_average_treatment_effect_via_simple_estimator(self):
"""Smoke test using the shared SimpleEstimator harness."""
tester = SimpleEstimator(error_tolerance=0.3, Estimator=DistanceMatchingEstimator)
tester.average_treatment_effect_testsuite(
num_common_causes=[1],
num_instruments=[0],
num_effect_modifiers=[0],
num_treatments=[1],
treatment_is_binary=[True],
outcome_is_binary=[False],
confidence_intervals=[False],
test_significance=[False],
method_params={"num_simulations": 5, "num_null_simulations": 5},
)
Loading