diff --git a/dowhy/utils/cit.py b/dowhy/utils/cit.py index e7ec6408aa..74dd1a5b25 100644 --- a/dowhy/utils/cit.py +++ b/dowhy/utils/cit.py @@ -139,14 +139,16 @@ def conditional_MI(data=None, x=None, y=None, z=None): :param x,y,z : column names from dataset :returns : conditional mutual information between X and Y given Z """ - X = data[list(x)].astype(int) - Y = data[list(y)].astype(int) - t = list(z) - Z = data[t].astype(int) - Z = Z.values.tolist() + x_cols = [x] if isinstance(x, str) else list(x) + y_cols = [y] if isinstance(y, str) else list(y) + t = list(z) if not isinstance(z, list) else z + # Use squeeze() to convert single-column DataFrames to Series so that + # iterating yields row values instead of column labels. + X = data[x_cols].astype(int).squeeze(axis=1) + Y = data[y_cols].astype(int).squeeze(axis=1) Z = list(data[t].itertuples(index=False, name=None)) - Hxz = entropy(map(lambda x: "%s/%s" % x, zip(X, Z))) # Finding Joint entropy of X and Z - Hyz = entropy(map(lambda x: "%s/%s" % x, zip(Y, Z))) # Finding Joint entropy of Y and Z + Hxz = entropy(map(lambda v: "%s/%s" % v, zip(X, Z))) # Finding Joint entropy of X and Z + Hyz = entropy(map(lambda v: "%s/%s" % v, zip(Y, Z))) # Finding Joint entropy of Y and Z Hz = entropy(Z) # Finding Entropy of Z - Hxyz = entropy(map(lambda x: "%s/%s/%s" % x, zip(X, Y, Z))) # Finding Joint Entropy of X, Y and Z + Hxyz = entropy(map(lambda v: "%s/%s/%s" % v, zip(X, Y, Z))) # Finding Joint Entropy of X, Y and Z return Hxz + Hyz - Hxyz - Hz diff --git a/tests/utils/test_cit.py b/tests/utils/test_cit.py new file mode 100644 index 0000000000..d95c9c9d65 --- /dev/null +++ b/tests/utils/test_cit.py @@ -0,0 +1,69 @@ +import numpy as np +import pandas as pd + +from dowhy.utils.cit import conditional_MI + + +class TestConditionalMI: + """Tests for the conditional_MI function.""" + + def test_multi_char_column_names(self): + """Regression test for #949: column names with >1 character were iterated as chars.""" + rng = np.random.default_rng(42) + n = 200 + # Use multi-character column names that would be broken by list("Foo") -> ['F','o','o'] + df = pd.DataFrame( + { + "Foo": rng.integers(0, 3, size=n), + "Bar": rng.integers(0, 3, size=n), + "Baz": rng.integers(0, 3, size=n), + } + ) + # Should not raise KeyError + result = conditional_MI(data=df, x="Foo", y="Bar", z=["Baz"]) + assert isinstance(result, float) + + def test_single_char_column_names(self): + """Single-character column names should still work.""" + rng = np.random.default_rng(0) + n = 200 + df = pd.DataFrame( + { + "X": rng.integers(0, 2, size=n), + "Y": rng.integers(0, 2, size=n), + "Z": rng.integers(0, 2, size=n), + } + ) + result = conditional_MI(data=df, x="X", y="Y", z=["Z"]) + assert isinstance(result, float) + + def test_independent_variables_low_cmi(self): + """Independent variables should have low conditional mutual information.""" + rng = np.random.default_rng(7) + n = 5000 + df = pd.DataFrame( + { + "Alpha": rng.integers(0, 2, size=n), # independent of Beta + "Beta": rng.integers(0, 2, size=n), + "Gamma": rng.integers(0, 2, size=n), + } + ) + result = conditional_MI(data=df, x="Alpha", y="Beta", z=["Gamma"]) + # Truly independent variables should yield low CMI + assert result < 0.05 + + def test_dependent_variables_high_cmi(self): + """Fully dependent variables should have high conditional mutual information.""" + rng = np.random.default_rng(42) + n = 1000 + x_vals = rng.integers(0, 2, size=n) + df = pd.DataFrame( + { + "Foo": x_vals, + "Bar": x_vals, # identical to Foo -> fully dependent + "Baz": rng.integers(0, 2, size=n), + } + ) + result = conditional_MI(data=df, x="Foo", y="Bar", z=["Baz"]) + # Fully dependent variables should yield CMI close to 1 bit + assert result > 0.5