Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 10 additions & 8 deletions dowhy/utils/cit.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,14 +139,16 @@ def conditional_MI(data=None, x=None, y=None, z=None):
:param x,y,z : column names from dataset
:returns : conditional mutual information between X and Y given Z
"""
X = data[list(x)].astype(int)
Y = data[list(y)].astype(int)
t = list(z)
Z = data[t].astype(int)
Z = Z.values.tolist()
x_cols = [x] if isinstance(x, str) else list(x)
y_cols = [y] if isinstance(y, str) else list(y)
t = list(z) if not isinstance(z, list) else z
# Use squeeze() to convert single-column DataFrames to Series so that
# iterating yields row values instead of column labels.
X = data[x_cols].astype(int).squeeze(axis=1)
Y = data[y_cols].astype(int).squeeze(axis=1)
Z = list(data[t].itertuples(index=False, name=None))
Hxz = entropy(map(lambda x: "%s/%s" % x, zip(X, Z))) # Finding Joint entropy of X and Z
Hyz = entropy(map(lambda x: "%s/%s" % x, zip(Y, Z))) # Finding Joint entropy of Y and Z
Hxz = entropy(map(lambda v: "%s/%s" % v, zip(X, Z))) # Finding Joint entropy of X and Z
Hyz = entropy(map(lambda v: "%s/%s" % v, zip(Y, Z))) # Finding Joint entropy of Y and Z
Hz = entropy(Z) # Finding Entropy of Z
Hxyz = entropy(map(lambda x: "%s/%s/%s" % x, zip(X, Y, Z))) # Finding Joint Entropy of X, Y and Z
Hxyz = entropy(map(lambda v: "%s/%s/%s" % v, zip(X, Y, Z))) # Finding Joint Entropy of X, Y and Z
return Hxz + Hyz - Hxyz - Hz
69 changes: 69 additions & 0 deletions tests/utils/test_cit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import numpy as np
import pandas as pd

from dowhy.utils.cit import conditional_MI


class TestConditionalMI:
"""Tests for the conditional_MI function."""

def test_multi_char_column_names(self):
"""Regression test for #949: column names with >1 character were iterated as chars."""
rng = np.random.default_rng(42)
n = 200
# Use multi-character column names that would be broken by list("Foo") -> ['F','o','o']
df = pd.DataFrame(
{
"Foo": rng.integers(0, 3, size=n),
"Bar": rng.integers(0, 3, size=n),
"Baz": rng.integers(0, 3, size=n),
}
)
# Should not raise KeyError
result = conditional_MI(data=df, x="Foo", y="Bar", z=["Baz"])
assert isinstance(result, float)

def test_single_char_column_names(self):
"""Single-character column names should still work."""
rng = np.random.default_rng(0)
n = 200
df = pd.DataFrame(
{
"X": rng.integers(0, 2, size=n),
"Y": rng.integers(0, 2, size=n),
"Z": rng.integers(0, 2, size=n),
}
)
result = conditional_MI(data=df, x="X", y="Y", z=["Z"])
assert isinstance(result, float)

def test_independent_variables_low_cmi(self):
"""Independent variables should have low conditional mutual information."""
rng = np.random.default_rng(7)
n = 5000
df = pd.DataFrame(
{
"Alpha": rng.integers(0, 2, size=n), # independent of Beta
"Beta": rng.integers(0, 2, size=n),
"Gamma": rng.integers(0, 2, size=n),
}
)
result = conditional_MI(data=df, x="Alpha", y="Beta", z=["Gamma"])
# Truly independent variables should yield low CMI
assert result < 0.05

def test_dependent_variables_high_cmi(self):
"""Fully dependent variables should have high conditional mutual information."""
rng = np.random.default_rng(42)
n = 1000
x_vals = rng.integers(0, 2, size=n)
df = pd.DataFrame(
{
"Foo": x_vals,
"Bar": x_vals, # identical to Foo -> fully dependent
"Baz": rng.integers(0, 2, size=n),
}
)
result = conditional_MI(data=df, x="Foo", y="Bar", z=["Baz"])
# Fully dependent variables should yield CMI close to 1 bit
assert result > 0.5