From 055c7a6fc40f6088dbafa20426e52d9be8f82eb3 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Tue, 31 Mar 2026 17:26:39 +0000 Subject: [PATCH 1/2] fix: compute proper chi-squared p-value for CMI independence test MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The conditional_mutual_information() method in GraphRefuter was comparing raw conditional mutual information (in bits) against 0.05 as if it were a p-value. This is statistically incorrect: CMI is an information-theoretic quantity, not a probability. Fix: apply the G-test (likelihood ratio test). Under the null hypothesis of conditional independence, the test statistic G = 2 * N * CMI_nats is asymptotically chi-squared with df = (|X|-1)(|Y|-1)*|Z_combos| degrees of freedom. Convert CMI from bits to nats (multiply by ln 2), compute the chi-squared survival function to get a proper p-value, then compare against α = 0.05 — consistent with partial_correlation's approach. Closes #413 Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Signed-off-by: Copilot <223556219+Copilot@users.noreply.github.com> --- dowhy/causal_refuters/graph_refuter.py | 23 +++++++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/dowhy/causal_refuters/graph_refuter.py b/dowhy/causal_refuters/graph_refuter.py index 0bcac31c24..eae90d6bf9 100644 --- a/dowhy/causal_refuters/graph_refuter.py +++ b/dowhy/causal_refuters/graph_refuter.py @@ -1,6 +1,8 @@ import logging +from math import log import numpy as np +from scipy.stats import chi2 from dowhy.causal_refuter import CausalRefutation, CausalRefuter from dowhy.utils.cit import conditional_MI, partial_corr @@ -56,14 +58,27 @@ def partial_correlation(self, x=None, y=None, z=None): self._results[key] = [p_value, True] def conditional_mutual_information(self, x=None, y=None, z=None): - cmi_val = conditional_MI(data=self._data, x=x, y=y, z=list(z)) + cmi_bits = conditional_MI(data=self._data, x=x, y=y, z=list(z)) key = (x, y) + (z,) - if cmi_val <= 0.05: + + n = len(self._data) + # Convert CMI (bits) to G-test statistic (asymptotically chi-squared under H0) + g_stat = 2 * n * cmi_bits * log(2) + + # Degrees of freedom: (|X| - 1)(|Y| - 1) * number of distinct Z combinations + x_card = self._data[x].nunique() + y_card = self._data[y].nunique() + z_card = self._data[list(z)].drop_duplicates().shape[0] if z else 1 + df = max(1, (x_card - 1) * (y_card - 1) * z_card) + + p_value = float(chi2.sf(g_stat, df=df)) + + if p_value >= 0.05: self._true_implications.append([x, y, z]) - self._results[key] = [cmi_val, True] + self._results[key] = [p_value, True] else: self._false_implications.append([x, y, z]) - self._results[key] = [cmi_val, False] + self._results[key] = [p_value, False] def refute_model(self, independence_constraints): """ From 2e79e9f989a4568381531b702c7444f40bc4b2dc Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 17 Apr 2026 07:02:52 +0000 Subject: [PATCH 2/2] fix: address review comments - fix conditional_MI iteration, cardinalities, degenerate cases, add tests Agent-Logs-Url: https://github.com/py-why/dowhy/sessions/ad7f75b5-016f-452f-81fd-143df9e22fce Co-authored-by: emrekiciman <5982160+emrekiciman@users.noreply.github.com> --- dowhy/causal_refuters/graph_refuter.py | 20 ++++-- dowhy/utils/cit.py | 15 ++-- tests/causal_refuters/test_graph_refuter.py | 76 +++++++++++++++++++++ 3 files changed, 97 insertions(+), 14 deletions(-) create mode 100644 tests/causal_refuters/test_graph_refuter.py diff --git a/dowhy/causal_refuters/graph_refuter.py b/dowhy/causal_refuters/graph_refuter.py index eae90d6bf9..5e69618428 100644 --- a/dowhy/causal_refuters/graph_refuter.py +++ b/dowhy/causal_refuters/graph_refuter.py @@ -58,7 +58,7 @@ def partial_correlation(self, x=None, y=None, z=None): self._results[key] = [p_value, True] def conditional_mutual_information(self, x=None, y=None, z=None): - cmi_bits = conditional_MI(data=self._data, x=x, y=y, z=list(z)) + cmi_bits = conditional_MI(data=self._data, x=[x], y=[y], z=list(z)) key = (x, y) + (z,) n = len(self._data) @@ -66,12 +66,18 @@ def conditional_mutual_information(self, x=None, y=None, z=None): g_stat = 2 * n * cmi_bits * log(2) # Degrees of freedom: (|X| - 1)(|Y| - 1) * number of distinct Z combinations - x_card = self._data[x].nunique() - y_card = self._data[y].nunique() - z_card = self._data[list(z)].drop_duplicates().shape[0] if z else 1 - df = max(1, (x_card - 1) * (y_card - 1) * z_card) - - p_value = float(chi2.sf(g_stat, df=df)) + # Compute from the same int-cast data used internally by conditional_MI + x_card = self._data[x].astype(int).nunique() + y_card = self._data[y].astype(int).nunique() + z_card = self._data[list(z)].astype(int).drop_duplicates().shape[0] if z else 1 + df = (x_card - 1) * (y_card - 1) * z_card + + if x_card <= 1 or y_card <= 1 or df <= 0: + # Degenerate contingency structure: the chi-squared approximation is not meaningful. + # Treat this as a non-rejection instead of forcing df=1. + p_value = 1.0 + else: + p_value = float(chi2.sf(g_stat, df=df)) if p_value >= 0.05: self._true_implications.append([x, y, z]) diff --git a/dowhy/utils/cit.py b/dowhy/utils/cit.py index e7ec6408aa..03d8422350 100644 --- a/dowhy/utils/cit.py +++ b/dowhy/utils/cit.py @@ -136,17 +136,18 @@ def conditional_MI(data=None, x=None, y=None, z=None): = H(X,Z) - H(Z) - H(X,Y,Z) + H(Y,Z) = H(X,Z) + H(Y,Z) - H(X,Y,Z) - H(Z) :param data : dataset - :param x,y,z : column names from dataset + :param x,y,z : column names from dataset (each should be a list of column name strings) :returns : conditional mutual information between X and Y given Z """ X = data[list(x)].astype(int) Y = data[list(y)].astype(int) t = list(z) Z = data[t].astype(int) - Z = Z.values.tolist() - Z = list(data[t].itertuples(index=False, name=None)) - Hxz = entropy(map(lambda x: "%s/%s" % x, zip(X, Z))) # Finding Joint entropy of X and Z - Hyz = entropy(map(lambda x: "%s/%s" % x, zip(Y, Z))) # Finding Joint entropy of Y and Z - Hz = entropy(Z) # Finding Entropy of Z - Hxyz = entropy(map(lambda x: "%s/%s/%s" % x, zip(X, Y, Z))) # Finding Joint Entropy of X, Y and Z + X_rows = list(X.itertuples(index=False, name=None)) + Y_rows = list(Y.itertuples(index=False, name=None)) + Z_rows = list(Z.itertuples(index=False, name=None)) + Hxz = entropy(map(lambda row: "%s/%s" % row, zip(X_rows, Z_rows))) # Finding Joint entropy of X and Z + Hyz = entropy(map(lambda row: "%s/%s" % row, zip(Y_rows, Z_rows))) # Finding Joint entropy of Y and Z + Hz = entropy(Z_rows) # Finding Entropy of Z + Hxyz = entropy(map(lambda row: "%s/%s/%s" % row, zip(X_rows, Y_rows, Z_rows))) # Finding Joint Entropy of X, Y and Z return Hxz + Hyz - Hxyz - Hz diff --git a/tests/causal_refuters/test_graph_refuter.py b/tests/causal_refuters/test_graph_refuter.py new file mode 100644 index 0000000000..17f040c11c --- /dev/null +++ b/tests/causal_refuters/test_graph_refuter.py @@ -0,0 +1,76 @@ +import numpy as np +import pandas as pd +import pytest + +from dowhy.causal_refuters.graph_refuter import GraphRefuter + + +class TestGraphRefuterCMI: + """Focused unit tests for the conditional_mutual_information method of GraphRefuter.""" + + @pytest.fixture + def independent_data(self): + """Generate data where a and b are conditionally independent given c.""" + np.random.seed(42) + n = 5000 + c = np.random.randint(0, 2, n) + # a and b each depend only on c, so a ⊥ b | c + a = np.where(c == 0, np.random.randint(0, 2, n), np.random.randint(0, 2, n)) + b = np.where(c == 0, np.random.randint(0, 2, n), np.random.randint(0, 2, n)) + return pd.DataFrame({"a": a.astype(np.int64), "b": b.astype(np.int64), "c": c.astype(np.int64)}) + + @pytest.fixture + def dependent_data(self): + """Generate data where a and b are conditionally dependent given c.""" + np.random.seed(42) + n = 5000 + c = np.random.randint(0, 2, n) + a = np.random.randint(0, 2, n) + # b is perfectly determined by a, so a and b are NOT independent given c + b = a + return pd.DataFrame({"a": a.astype(np.int64), "b": b.astype(np.int64), "c": c.astype(np.int64)}) + + def test_independent_pair_yields_high_pvalue(self, independent_data): + """An approximately independent discrete pair should yield p_value >= 0.05.""" + refuter = GraphRefuter(data=independent_data) + refuter.conditional_mutual_information(x="a", y="b", z=frozenset(["c"])) + + assert len(refuter._true_implications) == 1, "Independent pair should be accepted as conditionally independent" + assert len(refuter._false_implications) == 0 + + key = ("a", "b") + (frozenset(["c"]),) + p_value, result = refuter._results[key] + assert p_value >= 0.05, f"Expected p_value >= 0.05 for independent pair, got {p_value}" + assert result is True + + def test_dependent_pair_yields_low_pvalue(self, dependent_data): + """A strongly dependent discrete pair should yield p_value < 0.05.""" + refuter = GraphRefuter(data=dependent_data) + refuter.conditional_mutual_information(x="a", y="b", z=frozenset(["c"])) + + assert len(refuter._false_implications) == 1, "Dependent pair should be rejected as conditionally independent" + assert len(refuter._true_implications) == 0 + + key = ("a", "b") + (frozenset(["c"]),) + p_value, result = refuter._results[key] + assert p_value < 0.05, f"Expected p_value < 0.05 for dependent pair, got {p_value}" + assert result is False + + def test_degenerate_constant_column(self): + """A constant column (cardinality 1) should be treated as non-rejection (p_value=1.0).""" + np.random.seed(0) + n = 100 + data = pd.DataFrame( + { + "a": np.zeros(n, dtype=np.int64), # constant: cardinality 1 + "b": np.random.randint(0, 2, n).astype(np.int64), + "c": np.random.randint(0, 2, n).astype(np.int64), + } + ) + refuter = GraphRefuter(data=data) + refuter.conditional_mutual_information(x="a", y="b", z=frozenset(["c"])) + + key = ("a", "b") + (frozenset(["c"]),) + p_value, result = refuter._results[key] + assert p_value == 1.0, "Degenerate (constant) variable should give p_value=1.0" + assert result is True