Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions dowhy/causal_identifier/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,4 @@
"construct_frontdoor_estimand",
"construct_iv_estimand",
]
from dowhy.causal_identifier.zid_identifier import ZIDIdentifier
20 changes: 20 additions & 0 deletions dowhy/causal_identifier/auto_identifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ def identify_effect_auto(
optimize_backdoor: bool = False,
costs: Optional[List] = None,
generalized_adjustment: GeneralizedAdjustment = GeneralizedAdjustment.GENERALIZED_ADJUSTMENT_DEFAULT,
surrogate_nodes: Optional[List[str]] = None,
) -> IdentifiedEstimand:
"""Main method that returns an identified estimand (if one exists).

Expand Down Expand Up @@ -205,6 +206,7 @@ def identify_effect_auto(
costs,
conditional_node_names,
generalized_adjustment,
surrogate_nodes=surrogate_nodes,
)
elif estimand_type == EstimandType.NONPARAMETRIC_NDE:
return identify_nde_effect(
Expand Down Expand Up @@ -240,6 +242,7 @@ def identify_ate_effect(
costs: List,
conditional_node_names: List[str] = None,
generalized_adjustment: GeneralizedAdjustment = GeneralizedAdjustment.GENERALIZED_ADJUSTMENT_DEFAULT,
surrogate_nodes: Optional[List[str]] = None,
):
estimands_dict = {}
mediation_first_stage_confounders = None
Expand Down Expand Up @@ -337,6 +340,23 @@ def identify_ate_effect(
logger.warning(
f"Generalized covariate adjustment identification is not supported for the detected Python version: {sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}."
)
### 5. Z-IDENTIFIABILITY (SURROGATE EXPERIMENTS)
if surrogate_nodes:
try:
from dowhy.causal_identifier.zid_identifier import ZIDIdentifier
_zid = ZIDIdentifier(graph, action_nodes, outcome_nodes, surrogate_nodes)
_zid_estimand = _zid.identify_effect()
estimands_dict["zid"] = _zid_estimand
logger.info("z-ID succeeded via surrogate nodes: " + str(surrogate_nodes))
except ImportError:
logger.warning("ZIDIdentifier not available; skipping z-ID step.")
estimands_dict["zid"] = None
except Exception as _e:
logger.debug("z-ID did not succeed: " + str(_e))
estimands_dict["zid"] = None
else:
estimands_dict["zid"] = None

# Finally returning the estimand object
estimand = IdentifiedEstimand(
None,
Expand Down
150 changes: 150 additions & 0 deletions dowhy/causal_identifier/zid_identifier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
"""
Bridge identifier for z-Identifiability via surrogate experiments.

Implements ZIDIdentifier, which integrates the complete z-ID decision procedure
(Bareinboim & Pearl, 2012) into DoWhy's identification pipeline by translating
DoWhy's graph representation into ananke-causal's ADMG format.
"""

import itertools
import logging
from typing import List, Tuple

import networkx as nx
from dowhy.causal_identifier.identified_estimand import IdentifiedEstimand

logger = logging.getLogger(__name__)


class ZIDIdentifier:
"""
Bridge identifier for complete z-Identifiability using surrogate
experiments, powered by ananke-causal.

DoWhy represents latent confounders as unobserved nodes (observed='no')
with directed edges to their children. This class converts that
representation into ananke's bidirected-edge ADMG format before running
the z-ID decision procedure.
"""

def __init__(
self,
graph: nx.DiGraph,
action_nodes: List[str],
outcome_nodes: List[str],
surrogate_nodes: List[str],
):
self.graph = graph
self.action_nodes = list(action_nodes)
self.outcome_nodes = list(outcome_nodes)
self.surrogate_nodes = list(surrogate_nodes)

def identify_effect(self) -> IdentifiedEstimand:
"""
Run the z-ID decision procedure and return an IdentifiedEstimand.
Identifying functional on success: P(y|do(x)) = sum_z P(y|x,z) P(z).
Raises Exception if not z-identifiable.
"""
try:
from ananke.graphs import ADMG as _ADMG
from ananke.identification.idz import idz_id as _idz_id
except ImportError:
raise ImportError(
"pyananke is required for ZIDIdentifier. "
"Install it with: pip install dowhy[zid]"
)

ananke_graph = self._convert_to_ananke(self.graph)

logger.debug(
"ZIDIdentifier: vertices=%s di_edges=%s bi_edges=%s",
ananke_graph.vertices, ananke_graph.di_edges, ananke_graph.bi_edges,
)

is_identifiable = _idz_id(
graph=ananke_graph,
treatments=self.action_nodes,
outcomes=self.outcome_nodes,
surrogates=self.surrogate_nodes,
)

if not is_identifiable:
raise Exception(
f"P({self.outcome_nodes} | do({self.action_nodes})) is NOT "
f"z-identifiable given surrogates Z={self.surrogate_nodes}."
)

logger.debug("ZIDIdentifier: effect IS z-identifiable. Functional: sum_z P(y|x,z)P(z)")

return IdentifiedEstimand(
None,
treatment_variable=self.action_nodes,
outcome_variable=self.outcome_nodes,
estimand_type="nonparametric-ate",
estimands={"backdoor": None},
backdoor_variables=self.surrogate_nodes,
instrumental_variables=None,
frontdoor_variables=None,
mediation_first_stage_confounders=None,
mediation_second_stage_confounders=None,
)

def _convert_to_ananke(self, nx_graph: nx.DiGraph) -> "ADMG":
"""
Convert DoWhy nx.DiGraph → ananke ADMG.

Strategy 1 (primary): nodes with observed='no' are latent confounders.
Each such node U with observed children {C1,...,Cn} → C(n,2) bi_edges.
U and its di_edges are excluded from the ADMG.

Strategy 2 (fallback): edges with style='bidirected', bidirected=True,
or arrowhead='both' are collected as bi_edges directly.
"""
from ananke.graphs import ADMG

di_edges: List[Tuple[str, str]] = []
bi_edges: List[Tuple[str, str]] = []
latent_nodes: set = set()

# Strategy 1
for node, attrs in nx_graph.nodes(data=True):
if attrs.get("observed", "yes") == "no":
latent_nodes.add(node)
children = list(nx_graph.successors(node))
obs_children = [
c for c in children
if nx_graph.nodes[c].get("observed", "yes") == "yes"
]
if len(obs_children) >= 2:
for a, b in itertools.combinations(obs_children, 2):
bi_edges.append((a, b))
elif len(obs_children) == 1:
logger.warning("Latent node '%s' has only 1 observed child; no bi_edge emitted.", node)
else:
logger.warning("Latent node '%s' has no observed children; skipped.", node)

# Strategy 2 + directed edges
for u, v, attrs in nx_graph.edges(data=True):
if (attrs.get("style") == "bidirected"
or attrs.get("bidirected") is True
or attrs.get("arrowhead") == "both"):
bi_edges.append((u, v))
else:
if u not in latent_nodes and v not in latent_nodes:
di_edges.append((u, v))

vertices = [
n for n in nx_graph.nodes()
if nx_graph.nodes[n].get("observed", "yes") == "yes"
]

# Deduplicate bidirected edges (unordered pairs)
seen: set = set()
unique_bi: List[Tuple[str, str]] = []
for a, b in bi_edges:
key = frozenset((a, b))
if key not in seen:
seen.add(key)
unique_bi.append((a, b))

return ADMG(vertices=vertices, di_edges=di_edges, bi_edges=unique_bi)
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ networkx = [
sympy = ">=1.10.1"
scikit-learn = ">1.0"
pydot = { version = "^1.4.2", optional = true }
pyananke = { version = ">=0.6.1", optional = true }
joblib = ">=1.1.0"
pygraphviz = { version = ">=1.9", optional = true }
econml = ">=0.16"
Expand All @@ -97,6 +98,7 @@ pygraphviz = ["pygraphviz"]
pydot = ["pydot"]
plotting = ["matplotlib"]
econml = ["econml"]
zid = ["pyananke"]

[tool.poetry.group.dev.dependencies]
poethepoet = ">=0.24.4"
Expand Down
146 changes: 146 additions & 0 deletions tests/causal_identifiers/test_zid_identifier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
"""
Tests for ZIDIdentifier and z-ID integration in identify_effect_auto.

Covers:
- ADMG conversion (unobserved nodes, explicit bidirected attrs)
- identify_effect() decision procedure (True/False cases)
- Integration via identify_effect_auto with surrogate_nodes parameter
"""

import networkx as nx
import pytest

ananke = pytest.importorskip("ananke", reason="ananke-causal not installed; skipping z-ID tests")

from dowhy.causal_identifier.auto_identifier import EstimandType, identify_effect_auto
from dowhy.causal_identifier.zid_identifier import ZIDIdentifier


def build_graph(di_edges, confounders=None):
"""Build a DoWhy-style nx.DiGraph with unobserved common-cause nodes."""
G = nx.DiGraph()
observed = set()
for u, v in di_edges:
G.add_edge(u, v)
observed.update([u, v])
for n in observed:
G.nodes[n]["observed"] = "yes"
for i, (a, b) in enumerate(confounders or []):
uid = f"__U_{i}__"
G.add_node(uid, observed="no")
G.add_edge(uid, a)
G.add_edge(uid, b)
return G


class TestZIDIdentifierConversion:
def test_no_confounders(self):
G = build_graph([("X", "Y")])
zid = ZIDIdentifier(G, ["X"], ["Y"], [])
admg = zid._convert_to_ananke(G)
assert set(admg.vertices) == {"X", "Y"}
assert set(map(tuple, admg.di_edges)) == {("X", "Y")}
assert list(admg.bi_edges) == []

def test_single_latent_confounder(self):
G = build_graph([("X", "Y")], confounders=[("X", "Y")])
zid = ZIDIdentifier(G, ["X"], ["Y"], [])
admg = zid._convert_to_ananke(G)
assert set(admg.vertices) == {"X", "Y"}
assert {frozenset(e) for e in admg.bi_edges} == {frozenset(("X", "Y"))}

def test_multi_child_latent_emits_all_pairs(self):
G = nx.DiGraph()
for n in ["A", "B", "C"]:
G.add_node(n, observed="yes")
G.add_node("U", observed="no")
for c in ["A", "B", "C"]:
G.add_edge("U", c)
G.add_edge("A", "B")
G.add_edge("B", "C")
zid = ZIDIdentifier(G, ["A"], ["C"], ["B"])
admg = zid._convert_to_ananke(G)
bi = {frozenset(e) for e in admg.bi_edges}
assert bi == {frozenset(("A", "B")), frozenset(("A", "C")), frozenset(("B", "C"))}

def test_explicit_bidirected_attr(self):
G = nx.DiGraph()
G.add_node("X", observed="yes")
G.add_node("Y", observed="yes")
G.add_edge("X", "Y")
G.add_edge("Y", "X", style="bidirected")
zid = ZIDIdentifier(G, ["X"], ["Y"], [])
admg = zid._convert_to_ananke(G)
assert frozenset(("X", "Y")) in {frozenset(e) for e in admg.bi_edges}


class TestZIDIdentifierDecision:
def test_not_identifiable_no_surrogates(self):
G = build_graph([("X", "Y")], confounders=[("X", "Y")])
zid = ZIDIdentifier(G, ["X"], ["Y"], [])
with pytest.raises(Exception, match="NOT z-identifiable"):
zid.identify_effect()

def test_no_confounders_identifiable(self):
G = build_graph([("X", "Y")])
zid = ZIDIdentifier(G, ["X"], ["Y"], [])
estimand = zid.identify_effect()
assert estimand.backdoor_variables == []

def test_rescue_case(self):
"""Z->X->Y, W1<->X, W1<->Y, W2<->Z, X<->Z, Y<->Z: z-ID succeeds."""
G = build_graph(
[("W_1", "Z"), ("X", "Y"), ("Z", "X")],
confounders=[("W_1", "X"), ("W_1", "Y"), ("W_2", "Z"), ("X", "Z"), ("Y", "Z")],
)
zid = ZIDIdentifier(G, ["X"], ["Y"], ["Z"])
estimand = zid.identify_effect()
assert estimand.backdoor_variables == ["Z"]


class TestZIDAutoIdentifier:
def test_surrogate_nodes_populates_zid_key(self):
G = build_graph(
[("W_1", "Z"), ("X", "Y"), ("Z", "X")],
confounders=[("W_1", "X"), ("W_1", "Y"), ("W_2", "Z"), ("X", "Z"), ("Y", "Z")],
)
observed = [n for n in G.nodes if G.nodes[n].get("observed", "yes") == "yes"]
estimand = identify_effect_auto(
G,
action_nodes=["X"],
outcome_nodes=["Y"],
observed_nodes=observed,
estimand_type=EstimandType.NONPARAMETRIC_ATE,
surrogate_nodes=["Z"],
)
assert "zid" in estimand.estimands
assert estimand.estimands["zid"] is not None

def test_no_surrogates_zid_key_is_none(self):
G = build_graph([("X", "Y")])
estimand = identify_effect_auto(
G,
action_nodes=["X"],
outcome_nodes=["Y"],
observed_nodes=["X", "Y"],
estimand_type=EstimandType.NONPARAMETRIC_ATE,
)
assert estimand.estimands.get("zid") is None

def test_backdoor_none_zid_rescues(self):
"""Pure rescue: backdoor=None but zid succeeds."""
G = build_graph(
[("W_1", "Z"), ("X", "Y"), ("Z", "X")],
confounders=[("W_1", "X"), ("W_1", "Y"), ("W_2", "Z"), ("X", "Z"), ("Y", "Z")],
)
observed = [n for n in G.nodes if G.nodes[n].get("observed", "yes") == "yes"]
estimand = identify_effect_auto(
G,
action_nodes=["X"],
outcome_nodes=["Y"],
observed_nodes=observed,
estimand_type=EstimandType.NONPARAMETRIC_ATE,
surrogate_nodes=["Z"],
)
assert estimand.estimands["backdoor"] is None
assert estimand.estimands["zid"] is not None
Loading