# Patch overview (commentary only: these lines precede the first "diff --git"
# header and belong to no hunk).
#
# dowhy/causal_refuters/graph_refuter.py
#   GraphRefuter.conditional_mutual_information previously compared the raw
#   conditional_MI value against 0.05 (`cmi_val <= 0.05`). After this patch it:
#     * calls conditional_MI with x and y wrapped in single-element lists;
#     * converts the returned value to a G statistic,
#       g_stat = 2 * n * cmi_bits * log(2), with n = len(self._data);
#     * computes df = (x_card - 1) * (y_card - 1) * z_card from the int-cast
#       columns, where z_card is the number of distinct z rows, or 1 if z is
#       empty;
#     * uses p_value = 1.0 when x_card <= 1, y_card <= 1 or df <= 0, otherwise
#       p_value = float(chi2.sf(g_stat, df=df));
#     * appends [x, y, z] to _true_implications when p_value >= 0.05, else to
#       _false_implications, and stores p_value (no longer the CMI value) in
#       self._results[key].
#   Two imports are added for this: math.log and scipy.stats.chi2.
#
# dowhy/utils/cit.py
#   conditional_MI: the docstring now says x, y and z are each a list of column
#   names. The two successive reassignments of Z are removed; instead X_rows,
#   Y_rows and Z_rows are built with itertuples(index=False, name=None) from
#   the int-cast frames, and those row lists (rather than the X / Y frames
#   themselves) are zipped into the four entropy() calls. The lambda parameter
#   is renamed from x to row. The returned expression is unchanged.
#
# tests/causal_refuters/test_graph_refuter.py (new file)
#   Three tests of conditional_mutual_information: an independent pair
#   (expects p_value >= 0.05 and result True), a pair with b = a (expects
#   p_value < 0.05 and result False), and a constant column (expects
#   p_value == 1.0 and result True).
#
# NOTE(review): the "(bits)" conversion presumes entropy() in cit.py uses
#   log base 2; entropy() is not part of this patch -- verify.
# NOTE(review): for an empty z the refuter falls back to z_card = 1, but
#   conditional_MI then builds Z_rows from a zero-column frame. Confirm that
#   itertuples still yields one (empty) tuple per sample there; if it yields
#   nothing, every zip(...) passed to entropy() is empty. No test in this patch
#   exercises an empty z -- TODO confirm / add one.
# NOTE(review): in the independent_data fixture both branches of each np.where
#   are np.random.randint(0, 2, n), so a and b do not actually vary with c even
#   though the comment says they "depend only on c". The independence the test
#   relies on still holds; the comment may need rewording.
# NOTE(review): the 0.05 significance level is hard-coded, as it was before.
# NOTE(review): line breaks and leading whitespace below were reconstructed
#   from the hunk line counts; the indentation of the two "= H(...)" docstring
#   context lines in cit.py is a guess -- verify against the target tree
#   before applying.
diff --git a/dowhy/causal_refuters/graph_refuter.py b/dowhy/causal_refuters/graph_refuter.py
index 0bcac31c24..5e69618428 100644
--- a/dowhy/causal_refuters/graph_refuter.py
+++ b/dowhy/causal_refuters/graph_refuter.py
@@ -1,6 +1,8 @@
 import logging
+from math import log
 
 import numpy as np
+from scipy.stats import chi2
 
 from dowhy.causal_refuter import CausalRefutation, CausalRefuter
 from dowhy.utils.cit import conditional_MI, partial_corr
@@ -56,14 +58,33 @@ def partial_correlation(self, x=None, y=None, z=None):
             self._results[key] = [p_value, True]
 
     def conditional_mutual_information(self, x=None, y=None, z=None):
-        cmi_val = conditional_MI(data=self._data, x=x, y=y, z=list(z))
+        cmi_bits = conditional_MI(data=self._data, x=[x], y=[y], z=list(z))
         key = (x, y) + (z,)
-        if cmi_val <= 0.05:
+
+        n = len(self._data)
+        # Convert CMI (bits) to G-test statistic (asymptotically chi-squared under H0)
+        g_stat = 2 * n * cmi_bits * log(2)
+
+        # Degrees of freedom: (|X| - 1)(|Y| - 1) * number of distinct Z combinations
+        # Compute from the same int-cast data used internally by conditional_MI
+        x_card = self._data[x].astype(int).nunique()
+        y_card = self._data[y].astype(int).nunique()
+        z_card = self._data[list(z)].astype(int).drop_duplicates().shape[0] if z else 1
+        df = (x_card - 1) * (y_card - 1) * z_card
+
+        if x_card <= 1 or y_card <= 1 or df <= 0:
+            # Degenerate contingency structure: the chi-squared approximation is not meaningful.
+            # Treat this as a non-rejection instead of forcing df=1.
+            p_value = 1.0
+        else:
+            p_value = float(chi2.sf(g_stat, df=df))
+
+        if p_value >= 0.05:
             self._true_implications.append([x, y, z])
-            self._results[key] = [cmi_val, True]
+            self._results[key] = [p_value, True]
         else:
             self._false_implications.append([x, y, z])
-            self._results[key] = [cmi_val, False]
+            self._results[key] = [p_value, False]
 
     def refute_model(self, independence_constraints):
         """
diff --git a/dowhy/utils/cit.py b/dowhy/utils/cit.py
index e7ec6408aa..03d8422350 100644
--- a/dowhy/utils/cit.py
+++ b/dowhy/utils/cit.py
@@ -136,17 +136,18 @@ def conditional_MI(data=None, x=None, y=None, z=None):
                 = H(X,Z) - H(Z) - H(X,Y,Z) + H(Y,Z)
                 = H(X,Z) + H(Y,Z) - H(X,Y,Z) - H(Z)
     :param data : dataset
-    :param x,y,z : column names from dataset
+    :param x,y,z : column names from dataset (each should be a list of column name strings)
     :returns : conditional mutual information between X and Y given Z
     """
     X = data[list(x)].astype(int)
     Y = data[list(y)].astype(int)
     t = list(z)
     Z = data[t].astype(int)
-    Z = Z.values.tolist()
-    Z = list(data[t].itertuples(index=False, name=None))
-    Hxz = entropy(map(lambda x: "%s/%s" % x, zip(X, Z)))  # Finding Joint entropy of X and Z
-    Hyz = entropy(map(lambda x: "%s/%s" % x, zip(Y, Z)))  # Finding Joint entropy of Y and Z
-    Hz = entropy(Z)  # Finding Entropy of Z
-    Hxyz = entropy(map(lambda x: "%s/%s/%s" % x, zip(X, Y, Z)))  # Finding Joint Entropy of X, Y and Z
+    X_rows = list(X.itertuples(index=False, name=None))
+    Y_rows = list(Y.itertuples(index=False, name=None))
+    Z_rows = list(Z.itertuples(index=False, name=None))
+    Hxz = entropy(map(lambda row: "%s/%s" % row, zip(X_rows, Z_rows)))  # Finding Joint entropy of X and Z
+    Hyz = entropy(map(lambda row: "%s/%s" % row, zip(Y_rows, Z_rows)))  # Finding Joint entropy of Y and Z
+    Hz = entropy(Z_rows)  # Finding Entropy of Z
+    Hxyz = entropy(map(lambda row: "%s/%s/%s" % row, zip(X_rows, Y_rows, Z_rows)))  # Finding Joint Entropy of X, Y and Z
     return Hxz + Hyz - Hxyz - Hz
diff --git a/tests/causal_refuters/test_graph_refuter.py b/tests/causal_refuters/test_graph_refuter.py
new file mode 100644
index 0000000000..17f040c11c
--- /dev/null
+++ b/tests/causal_refuters/test_graph_refuter.py
@@ -0,0 +1,76 @@
+import numpy as np
+import pandas as pd
+import pytest
+
+from dowhy.causal_refuters.graph_refuter import GraphRefuter
+
+
+class TestGraphRefuterCMI:
+    """Focused unit tests for the conditional_mutual_information method of GraphRefuter."""
+
+    @pytest.fixture
+    def independent_data(self):
+        """Generate data where a and b are conditionally independent given c."""
+        np.random.seed(42)
+        n = 5000
+        c = np.random.randint(0, 2, n)
+        # a and b each depend only on c, so a ⊥ b | c
+        a = np.where(c == 0, np.random.randint(0, 2, n), np.random.randint(0, 2, n))
+        b = np.where(c == 0, np.random.randint(0, 2, n), np.random.randint(0, 2, n))
+        return pd.DataFrame({"a": a.astype(np.int64), "b": b.astype(np.int64), "c": c.astype(np.int64)})
+
+    @pytest.fixture
+    def dependent_data(self):
+        """Generate data where a and b are conditionally dependent given c."""
+        np.random.seed(42)
+        n = 5000
+        c = np.random.randint(0, 2, n)
+        a = np.random.randint(0, 2, n)
+        # b is perfectly determined by a, so a and b are NOT independent given c
+        b = a
+        return pd.DataFrame({"a": a.astype(np.int64), "b": b.astype(np.int64), "c": c.astype(np.int64)})
+
+    def test_independent_pair_yields_high_pvalue(self, independent_data):
+        """An approximately independent discrete pair should yield p_value >= 0.05."""
+        refuter = GraphRefuter(data=independent_data)
+        refuter.conditional_mutual_information(x="a", y="b", z=frozenset(["c"]))
+
+        assert len(refuter._true_implications) == 1, "Independent pair should be accepted as conditionally independent"
+        assert len(refuter._false_implications) == 0
+
+        key = ("a", "b") + (frozenset(["c"]),)
+        p_value, result = refuter._results[key]
+        assert p_value >= 0.05, f"Expected p_value >= 0.05 for independent pair, got {p_value}"
+        assert result is True
+
+    def test_dependent_pair_yields_low_pvalue(self, dependent_data):
+        """A strongly dependent discrete pair should yield p_value < 0.05."""
+        refuter = GraphRefuter(data=dependent_data)
+        refuter.conditional_mutual_information(x="a", y="b", z=frozenset(["c"]))
+
+        assert len(refuter._false_implications) == 1, "Dependent pair should be rejected as conditionally independent"
+        assert len(refuter._true_implications) == 0
+
+        key = ("a", "b") + (frozenset(["c"]),)
+        p_value, result = refuter._results[key]
+        assert p_value < 0.05, f"Expected p_value < 0.05 for dependent pair, got {p_value}"
+        assert result is False
+
+    def test_degenerate_constant_column(self):
+        """A constant column (cardinality 1) should be treated as non-rejection (p_value=1.0)."""
+        np.random.seed(0)
+        n = 100
+        data = pd.DataFrame(
+            {
+                "a": np.zeros(n, dtype=np.int64),  # constant: cardinality 1
+                "b": np.random.randint(0, 2, n).astype(np.int64),
+                "c": np.random.randint(0, 2, n).astype(np.int64),
+            }
+        )
+        refuter = GraphRefuter(data=data)
+        refuter.conditional_mutual_information(x="a", y="b", z=frozenset(["c"]))
+
+        key = ("a", "b") + (frozenset(["c"]),)
+        p_value, result = refuter._results[key]
+        assert p_value == 1.0, "Degenerate (constant) variable should give p_value=1.0"
+        assert result is True