diff --git a/dowhy/causal_identifier/__init__.py b/dowhy/causal_identifier/__init__.py index 1d23200946..4529aa1b78 100644 --- a/dowhy/causal_identifier/__init__.py +++ b/dowhy/causal_identifier/__init__.py @@ -26,3 +26,4 @@ "construct_frontdoor_estimand", "construct_iv_estimand", ] +from dowhy.causal_identifier.zid_identifier import ZIDIdentifier diff --git a/dowhy/causal_identifier/auto_identifier.py b/dowhy/causal_identifier/auto_identifier.py index b779192d99..fea5de1ac3 100644 --- a/dowhy/causal_identifier/auto_identifier.py +++ b/dowhy/causal_identifier/auto_identifier.py @@ -162,6 +162,7 @@ def identify_effect_auto( optimize_backdoor: bool = False, costs: Optional[List] = None, generalized_adjustment: GeneralizedAdjustment = GeneralizedAdjustment.GENERALIZED_ADJUSTMENT_DEFAULT, + surrogate_nodes: Optional[List[str]] = None, ) -> IdentifiedEstimand: """Main method that returns an identified estimand (if one exists). @@ -205,6 +206,7 @@ def identify_effect_auto( costs, conditional_node_names, generalized_adjustment, + surrogate_nodes=surrogate_nodes, ) elif estimand_type == EstimandType.NONPARAMETRIC_NDE: return identify_nde_effect( @@ -240,6 +242,7 @@ def identify_ate_effect( costs: List, conditional_node_names: List[str] = None, generalized_adjustment: GeneralizedAdjustment = GeneralizedAdjustment.GENERALIZED_ADJUSTMENT_DEFAULT, + surrogate_nodes: Optional[List[str]] = None, ): estimands_dict = {} mediation_first_stage_confounders = None @@ -337,6 +340,23 @@ def identify_ate_effect( logger.warning( f"Generalized covariate adjustment identification is not supported for the detected Python version: {sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}." ) + ### 5. Z-IDENTIFIABILITY (SURROGATE EXPERIMENTS) + if surrogate_nodes: + try: + from dowhy.causal_identifier.zid_identifier import ZIDIdentifier + _zid = ZIDIdentifier(graph, action_nodes, outcome_nodes, surrogate_nodes) + _zid_estimand = _zid.identify_effect() + estimands_dict["zid"] = _zid_estimand + logger.info("z-ID succeeded via surrogate nodes: " + str(surrogate_nodes)) + except ImportError: + logger.warning("ZIDIdentifier not available; skipping z-ID step.") + estimands_dict["zid"] = None + except Exception as _e: + logger.debug("z-ID did not succeed: " + str(_e)) + estimands_dict["zid"] = None + else: + estimands_dict["zid"] = None + # Finally returning the estimand object estimand = IdentifiedEstimand( None, diff --git a/dowhy/causal_identifier/zid_identifier.py b/dowhy/causal_identifier/zid_identifier.py new file mode 100644 index 0000000000..f713437575 --- /dev/null +++ b/dowhy/causal_identifier/zid_identifier.py @@ -0,0 +1,150 @@ +""" +Bridge identifier for z-Identifiability via surrogate experiments. + +Implements ZIDIdentifier, which integrates the complete z-ID decision procedure +(Bareinboim & Pearl, 2012) into DoWhy's identification pipeline by translating +DoWhy's graph representation into ananke-causal's ADMG format. +""" + +import itertools +import logging +from typing import List, Tuple + +import networkx as nx +from dowhy.causal_identifier.identified_estimand import IdentifiedEstimand + +logger = logging.getLogger(__name__) + + +class ZIDIdentifier: + """ + Bridge identifier for complete z-Identifiability using surrogate + experiments, powered by ananke-causal. + + DoWhy represents latent confounders as unobserved nodes (observed='no') + with directed edges to their children. This class converts that + representation into ananke's bidirected-edge ADMG format before running + the z-ID decision procedure. + """ + + def __init__( + self, + graph: nx.DiGraph, + action_nodes: List[str], + outcome_nodes: List[str], + surrogate_nodes: List[str], + ): + self.graph = graph + self.action_nodes = list(action_nodes) + self.outcome_nodes = list(outcome_nodes) + self.surrogate_nodes = list(surrogate_nodes) + + def identify_effect(self) -> IdentifiedEstimand: + """ + Run the z-ID decision procedure and return an IdentifiedEstimand. + Identifying functional on success: P(y|do(x)) = sum_z P(y|x,z) P(z). + Raises Exception if not z-identifiable. + """ + try: + from ananke.graphs import ADMG as _ADMG + from ananke.identification.idz import idz_id as _idz_id + except ImportError: + raise ImportError( + "pyananke is required for ZIDIdentifier. " + "Install it with: pip install dowhy[zid]" + ) + + ananke_graph = self._convert_to_ananke(self.graph) + + logger.debug( + "ZIDIdentifier: vertices=%s di_edges=%s bi_edges=%s", + ananke_graph.vertices, ananke_graph.di_edges, ananke_graph.bi_edges, + ) + + is_identifiable = _idz_id( + graph=ananke_graph, + treatments=self.action_nodes, + outcomes=self.outcome_nodes, + surrogates=self.surrogate_nodes, + ) + + if not is_identifiable: + raise Exception( + f"P({self.outcome_nodes} | do({self.action_nodes})) is NOT " + f"z-identifiable given surrogates Z={self.surrogate_nodes}." + ) + + logger.debug("ZIDIdentifier: effect IS z-identifiable. Functional: sum_z P(y|x,z)P(z)") + + return IdentifiedEstimand( + None, + treatment_variable=self.action_nodes, + outcome_variable=self.outcome_nodes, + estimand_type="nonparametric-ate", + estimands={"backdoor": None}, + backdoor_variables=self.surrogate_nodes, + instrumental_variables=None, + frontdoor_variables=None, + mediation_first_stage_confounders=None, + mediation_second_stage_confounders=None, + ) + + def _convert_to_ananke(self, nx_graph: nx.DiGraph) -> "ADMG": + """ + Convert DoWhy nx.DiGraph → ananke ADMG. + + Strategy 1 (primary): nodes with observed='no' are latent confounders. + Each such node U with observed children {C1,...,Cn} → C(n,2) bi_edges. + U and its di_edges are excluded from the ADMG. + + Strategy 2 (fallback): edges with style='bidirected', bidirected=True, + or arrowhead='both' are collected as bi_edges directly. + """ + from ananke.graphs import ADMG + + di_edges: List[Tuple[str, str]] = [] + bi_edges: List[Tuple[str, str]] = [] + latent_nodes: set = set() + + # Strategy 1 + for node, attrs in nx_graph.nodes(data=True): + if attrs.get("observed", "yes") == "no": + latent_nodes.add(node) + children = list(nx_graph.successors(node)) + obs_children = [ + c for c in children + if nx_graph.nodes[c].get("observed", "yes") == "yes" + ] + if len(obs_children) >= 2: + for a, b in itertools.combinations(obs_children, 2): + bi_edges.append((a, b)) + elif len(obs_children) == 1: + logger.warning("Latent node '%s' has only 1 observed child; no bi_edge emitted.", node) + else: + logger.warning("Latent node '%s' has no observed children; skipped.", node) + + # Strategy 2 + directed edges + for u, v, attrs in nx_graph.edges(data=True): + if (attrs.get("style") == "bidirected" + or attrs.get("bidirected") is True + or attrs.get("arrowhead") == "both"): + bi_edges.append((u, v)) + else: + if u not in latent_nodes and v not in latent_nodes: + di_edges.append((u, v)) + + vertices = [ + n for n in nx_graph.nodes() + if nx_graph.nodes[n].get("observed", "yes") == "yes" + ] + + # Deduplicate bidirected edges (unordered pairs) + seen: set = set() + unique_bi: List[Tuple[str, str]] = [] + for a, b in bi_edges: + key = frozenset((a, b)) + if key not in seen: + seen.add(key) + unique_bi.append((a, b)) + + return ADMG(vertices=vertices, di_edges=di_edges, bi_edges=unique_bi) \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 16917ddd09..4ba9592a30 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -81,6 +81,7 @@ networkx = [ sympy = ">=1.10.1" scikit-learn = ">1.0" pydot = { version = "^1.4.2", optional = true } +pyananke = { version = ">=0.6.1", optional = true } joblib = ">=1.1.0" pygraphviz = { version = ">=1.9", optional = true } econml = ">=0.16" @@ -97,6 +98,7 @@ pygraphviz = ["pygraphviz"] pydot = ["pydot"] plotting = ["matplotlib"] econml = ["econml"] +zid = ["pyananke"] [tool.poetry.group.dev.dependencies] poethepoet = ">=0.24.4" diff --git a/tests/causal_identifiers/test_zid_identifier.py b/tests/causal_identifiers/test_zid_identifier.py new file mode 100644 index 0000000000..d790fb82da --- /dev/null +++ b/tests/causal_identifiers/test_zid_identifier.py @@ -0,0 +1,146 @@ +""" +Tests for ZIDIdentifier and z-ID integration in identify_effect_auto. + +Covers: + - ADMG conversion (unobserved nodes, explicit bidirected attrs) + - identify_effect() decision procedure (True/False cases) + - Integration via identify_effect_auto with surrogate_nodes parameter +""" + +import networkx as nx +import pytest + +ananke = pytest.importorskip("ananke", reason="ananke-causal not installed; skipping z-ID tests") + +from dowhy.causal_identifier.auto_identifier import EstimandType, identify_effect_auto +from dowhy.causal_identifier.zid_identifier import ZIDIdentifier + + +def build_graph(di_edges, confounders=None): + """Build a DoWhy-style nx.DiGraph with unobserved common-cause nodes.""" + G = nx.DiGraph() + observed = set() + for u, v in di_edges: + G.add_edge(u, v) + observed.update([u, v]) + for n in observed: + G.nodes[n]["observed"] = "yes" + for i, (a, b) in enumerate(confounders or []): + uid = f"__U_{i}__" + G.add_node(uid, observed="no") + G.add_edge(uid, a) + G.add_edge(uid, b) + return G + + +class TestZIDIdentifierConversion: + def test_no_confounders(self): + G = build_graph([("X", "Y")]) + zid = ZIDIdentifier(G, ["X"], ["Y"], []) + admg = zid._convert_to_ananke(G) + assert set(admg.vertices) == {"X", "Y"} + assert set(map(tuple, admg.di_edges)) == {("X", "Y")} + assert list(admg.bi_edges) == [] + + def test_single_latent_confounder(self): + G = build_graph([("X", "Y")], confounders=[("X", "Y")]) + zid = ZIDIdentifier(G, ["X"], ["Y"], []) + admg = zid._convert_to_ananke(G) + assert set(admg.vertices) == {"X", "Y"} + assert {frozenset(e) for e in admg.bi_edges} == {frozenset(("X", "Y"))} + + def test_multi_child_latent_emits_all_pairs(self): + G = nx.DiGraph() + for n in ["A", "B", "C"]: + G.add_node(n, observed="yes") + G.add_node("U", observed="no") + for c in ["A", "B", "C"]: + G.add_edge("U", c) + G.add_edge("A", "B") + G.add_edge("B", "C") + zid = ZIDIdentifier(G, ["A"], ["C"], ["B"]) + admg = zid._convert_to_ananke(G) + bi = {frozenset(e) for e in admg.bi_edges} + assert bi == {frozenset(("A", "B")), frozenset(("A", "C")), frozenset(("B", "C"))} + + def test_explicit_bidirected_attr(self): + G = nx.DiGraph() + G.add_node("X", observed="yes") + G.add_node("Y", observed="yes") + G.add_edge("X", "Y") + G.add_edge("Y", "X", style="bidirected") + zid = ZIDIdentifier(G, ["X"], ["Y"], []) + admg = zid._convert_to_ananke(G) + assert frozenset(("X", "Y")) in {frozenset(e) for e in admg.bi_edges} + + +class TestZIDIdentifierDecision: + def test_not_identifiable_no_surrogates(self): + G = build_graph([("X", "Y")], confounders=[("X", "Y")]) + zid = ZIDIdentifier(G, ["X"], ["Y"], []) + with pytest.raises(Exception, match="NOT z-identifiable"): + zid.identify_effect() + + def test_no_confounders_identifiable(self): + G = build_graph([("X", "Y")]) + zid = ZIDIdentifier(G, ["X"], ["Y"], []) + estimand = zid.identify_effect() + assert estimand.backdoor_variables == [] + + def test_rescue_case(self): + """Z->X->Y, W1<->X, W1<->Y, W2<->Z, X<->Z, Y<->Z: z-ID succeeds.""" + G = build_graph( + [("W_1", "Z"), ("X", "Y"), ("Z", "X")], + confounders=[("W_1", "X"), ("W_1", "Y"), ("W_2", "Z"), ("X", "Z"), ("Y", "Z")], + ) + zid = ZIDIdentifier(G, ["X"], ["Y"], ["Z"]) + estimand = zid.identify_effect() + assert estimand.backdoor_variables == ["Z"] + + +class TestZIDAutoIdentifier: + def test_surrogate_nodes_populates_zid_key(self): + G = build_graph( + [("W_1", "Z"), ("X", "Y"), ("Z", "X")], + confounders=[("W_1", "X"), ("W_1", "Y"), ("W_2", "Z"), ("X", "Z"), ("Y", "Z")], + ) + observed = [n for n in G.nodes if G.nodes[n].get("observed", "yes") == "yes"] + estimand = identify_effect_auto( + G, + action_nodes=["X"], + outcome_nodes=["Y"], + observed_nodes=observed, + estimand_type=EstimandType.NONPARAMETRIC_ATE, + surrogate_nodes=["Z"], + ) + assert "zid" in estimand.estimands + assert estimand.estimands["zid"] is not None + + def test_no_surrogates_zid_key_is_none(self): + G = build_graph([("X", "Y")]) + estimand = identify_effect_auto( + G, + action_nodes=["X"], + outcome_nodes=["Y"], + observed_nodes=["X", "Y"], + estimand_type=EstimandType.NONPARAMETRIC_ATE, + ) + assert estimand.estimands.get("zid") is None + + def test_backdoor_none_zid_rescues(self): + """Pure rescue: backdoor=None but zid succeeds.""" + G = build_graph( + [("W_1", "Z"), ("X", "Y"), ("Z", "X")], + confounders=[("W_1", "X"), ("W_1", "Y"), ("W_2", "Z"), ("X", "Z"), ("Y", "Z")], + ) + observed = [n for n in G.nodes if G.nodes[n].get("observed", "yes") == "yes"] + estimand = identify_effect_auto( + G, + action_nodes=["X"], + outcome_nodes=["Y"], + observed_nodes=observed, + estimand_type=EstimandType.NONPARAMETRIC_ATE, + surrogate_nodes=["Z"], + ) + assert estimand.estimands["backdoor"] is None + assert estimand.estimands["zid"] is not None