def estimate_effect_naive(self, data: pd.DataFrame):
    """Estimate the effect as a plain difference in mean outcomes.

    The mean outcome over rows whose treatment column equals the control
    value is subtracted from the mean outcome over rows whose treatment
    column equals the treatment value. No other column is consulted.

    :param data: Pandas dataframe to estimate effect
    :returns: CausalEstimate built from the difference in means, the
        resolved treatment and outcome column names, and the treatment and
        control values that were compared
    :raises ValueError: if the target estimand supplies a list of treatment
        variables, or a list of outcome variables, whose length is not
        exactly one
    """

    def _single_name(variable, role):
        # The estimand may carry either a bare column name or a list of
        # names. Only one column can be compared here, so a one-element
        # list is unwrapped and any other list length is rejected.
        if not isinstance(variable, list):
            return variable
        if len(variable) != 1:
            raise ValueError(
                f"estimate_effect_naive only supports exactly one {role} variable, "
                f"got {len(variable)}: {variable}"
            )
        return variable[0]

    # TODO Only works for a single treatment variable
    treatment_name = _single_name(self._target_estimand.treatment_variable, "treatment")
    outcome_name = _single_name(self._target_estimand.outcome_variable, "outcome")

    # Use 1 / 0 when the instance does not define these attributes.
    treatment_value = getattr(self, "_treatment_value", 1)
    control_value = getattr(self, "_control_value", 0)

    treated_outcomes = data.loc[data[treatment_name] == treatment_value, outcome_name]
    control_outcomes = data.loc[data[treatment_name] == control_value, outcome_name]
    # A group with no matching rows has a NaN mean, which makes the estimate NaN.
    est = np.mean(treated_outcomes) - np.mean(control_outcomes)

    return CausalEstimate(
        data,
        treatment_name,
        outcome_name,
        est,
        None,
        control_value=control_value,
        treatment_value=treatment_value,
    )