diff --git a/dowhy/causal_estimators/two_stage_regression_estimator.py b/dowhy/causal_estimators/two_stage_regression_estimator.py
index 1ca1c8b67..a058571c4 100644
--- a/dowhy/causal_estimators/two_stage_regression_estimator.py
+++ b/dowhy/causal_estimators/two_stage_regression_estimator.py
@@ -116,17 +116,19 @@ def __init__(
         modified_target_estimand.identifier_method = "backdoor"
         modified_target_estimand.backdoor_variables = self._target_estimand.mediation_second_stage_confounders
         if second_stage_model is not None:
-            self._second_stage_model = (
-                second_stage_model
-                if isinstance(second_stage_model, CausalEstimator)
-                else second_stage_model(
+            if isinstance(second_stage_model, CausalEstimator):
+                self._second_stage_model = second_stage_model
+                # Update the estimand so the second-stage model uses the correct
+                # backdoor configuration rather than the original mediation estimand.
+                self._second_stage_model._target_estimand = modified_target_estimand
+            else:
+                self._second_stage_model = second_stage_model(
                     modified_target_estimand,
                     test_significance=self._significance_test,
                     evaluate_effect_strength=self._effect_strength_eval,
                     confidence_intervals=self._confidence_intervals,
                     **kwargs,
                 )
-            )
         else:
             self._second_stage_model = self.__class__.DEFAULT_SECOND_STAGE_MODEL(
                 modified_target_estimand,
diff --git a/tests/causal_estimators/test_two_stage_regression_estimator.py b/tests/causal_estimators/test_two_stage_regression_estimator.py
index c27b21c49..3174e9712 100644
--- a/tests/causal_estimators/test_two_stage_regression_estimator.py
+++ b/tests/causal_estimators/test_two_stage_regression_estimator.py
@@ -316,3 +316,52 @@ def test_nde_estimand_uses_correct_backdoor_variables(self):
         nde_estimand = estimator._second_stage_model_nde._target_estimand
         assert nde_estimand.identifier_method == "backdoor"
         assert nde_estimand.backdoor_variables == estimand.mediation_second_stage_confounders
+
+
+class TestTwoStageRegressionPreinstantiatedSecondStage:
+    """Regression tests for #1335: KeyError when second_stage_model is a pre-instantiated CausalEstimator.
+
+    When a user passes an already-constructed estimator instance as second_stage_model,
+    the TwoStageRegressionEstimator must update its _target_estimand to use the
+    modified (backdoor) estimand rather than the original mediation estimand.
+    """
+
+    def test_nie_with_preinstantiated_second_stage_no_keyerror(self):
+        """Passing a pre-instantiated second_stage_model must not raise KeyError."""
+        import statsmodels.api as sm
+
+        from dowhy.causal_estimators.generalized_linear_model_estimator import GeneralizedLinearModelEstimator
+
+        df = _make_mediation_data()
+        model = CausalModel(data=df, treatment="X", outcome="Y", graph=_MEDIATION_GML)
+        estimand = model.identify_effect(
+            estimand_type=EstimandType.NONPARAMETRIC_NIE,
+            proceed_when_unidentifiable=True,
+        )
+        second_stage = GeneralizedLinearModelEstimator(identified_estimand=estimand, glm_family=sm.families.Gaussian())
+        # This must not raise KeyError: None
+        estimate = model.estimate_effect(
+            identified_estimand=estimand,
+            method_name="mediation.two_stage_regression",
+            method_params={"second_stage_model": second_stage},
+        )
+        assert np.isfinite(estimate.value)
+
+    def test_nie_preinstantiated_second_stage_estimand_updated(self):
+        """The pre-instantiated second_stage_model's _target_estimand is updated to backdoor."""
+        import statsmodels.api as sm
+
+        from dowhy.causal_estimators.generalized_linear_model_estimator import GeneralizedLinearModelEstimator
+
+        df = _make_mediation_data()
+        model = CausalModel(data=df, treatment="X", outcome="Y", graph=_MEDIATION_GML)
+        estimand = model.identify_effect(
+            estimand_type=EstimandType.NONPARAMETRIC_NIE,
+            proceed_when_unidentifiable=True,
+        )
+        second_stage = GeneralizedLinearModelEstimator(identified_estimand=estimand, glm_family=sm.families.Gaussian())
+        estimator = TwoStageRegressionEstimator(
+            identified_estimand=estimand,
+            second_stage_model=second_stage,
+        )
+        assert estimator._second_stage_model._target_estimand.identifier_method == "backdoor"