Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions dowhy/gcm/whatif.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
validate_causal_dag,
)
from dowhy.gcm.fitting_sampling import draw_samples
from dowhy.gcm.util.general import set_random_seed
from dowhy.graph import (
DirectedGraph,
get_ordered_predecessors,
Expand All @@ -29,6 +30,7 @@ def interventional_samples(
interventions: Dict[Any, Callable[[np.ndarray], Union[float, np.ndarray]]],
observed_data: Optional[pd.DataFrame] = None,
num_samples_to_draw: Optional[int] = None,
random_seed: Optional[int] = None,
) -> pd.DataFrame:
"""Performs intervention on nodes in the causal graph.

Expand All @@ -40,8 +42,12 @@ def interventional_samples(
:param observed_data: Optionally, data on which to perform interventions. If None are given, data is generated based
on the generative models.
:param num_samples_to_draw: Sample size to draw from the interventional distribution.
:param random_seed: Optional seed for the random number generator to make results reproducible.
:return: Samples from the interventional distribution.
"""
if random_seed is not None:
set_random_seed(random_seed)

validate_causal_dag(causal_model.graph)
for node in interventions:
validate_node_in_graph(causal_model.graph, node)
Expand Down Expand Up @@ -107,6 +113,7 @@ def counterfactual_samples(
interventions: Dict[Any, Callable[[np.ndarray], Union[float, np.ndarray]]],
observed_data: Optional[pd.DataFrame] = None,
noise_data: Optional[pd.DataFrame] = None,
random_seed: Optional[int] = None,
) -> pd.DataFrame:
"""Estimates counterfactual data for observed data if we were to perform specified interventions. This function
implements the 3-step process for computing counterfactuals by Pearl (see https://ftp.cs.ucla.edu/pub/stat_ser/r485.pdf).
Expand All @@ -121,8 +128,12 @@ def counterfactual_samples(
:param noise_data: Data of noise terms corresponding to nodes in the causal graph. If not provided,
these have to be estimated from observed data. Then we require causal models of nodes to be
invertible.
:param random_seed: Optional seed for the random number generator to make results reproducible.
:return: Estimated counterfactual data.
"""
if random_seed is not None:
set_random_seed(random_seed)

for node in interventions:
validate_node_in_graph(causal_model.graph, node)

Expand Down Expand Up @@ -196,6 +207,7 @@ def average_causal_effect(
interventions_reference: Dict[Any, Callable[[np.ndarray], Union[float, np.ndarray]]],
observed_data: Optional[pd.DataFrame] = None,
num_samples_to_draw: Optional[int] = None,
random_seed: Optional[int] = None,
) -> float:
"""Estimates the average causal effect (ACE) on the target of two different sets of interventions.
The interventions can be specified through the parameters `interventions_alternative` and `interventions_reference`.
Expand Down Expand Up @@ -223,8 +235,11 @@ def average_causal_effect(
models.
:param num_samples_to_draw: Number of samples drawn from the causal model for estimating ACE if no observed data is
given.
:param random_seed: Optional seed for the random number generator to make results reproducible.
:return: The estimated average causal effect (ACE).
"""
if random_seed is not None:
set_random_seed(random_seed)
# For estimating the effect, we only need to consider the nodes that have a directed path to the target node, i.e.
# all ancestors of the target.
causal_model = ProbabilisticCausalModel(node_connected_subgraph_view(causal_model.graph, target_node))
Expand Down
26 changes: 26 additions & 0 deletions tests/gcm/test_whatif.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,3 +278,29 @@ def test_given_discrete_data_when_performing_interventions_then_returns_correct_
assert np.all(samples["X"].to_numpy() == -2)
assert np.median(samples["Y"].to_numpy()) == -1
assert np.mean(samples["Z"].to_numpy()) == approx(-2, abs=0.05)


def test_interventional_samples_with_random_seed_is_reproducible():
causal_model = ProbabilisticCausalModel(nx.DiGraph([("X", "Y")]))
causal_model.set_causal_mechanism("X", EmpiricalDistribution())
causal_model.set_causal_mechanism("Y", AdditiveNoiseModel(prediction_model=create_linear_regressor()))
data = pd.DataFrame({"X": np.random.normal(0, 1, 500)})
data["Y"] = 2 * data["X"] + np.random.normal(0, 0.1, 500)
fit(causal_model, data)

result1 = interventional_samples(causal_model, {"X": lambda x: 1.0}, num_samples_to_draw=100, random_seed=42)
result2 = interventional_samples(causal_model, {"X": lambda x: 1.0}, num_samples_to_draw=100, random_seed=42)
np.testing.assert_array_equal(result1.to_numpy(), result2.to_numpy())


def test_counterfactual_samples_with_random_seed_is_reproducible():
causal_model = InvertibleStructuralCausalModel(nx.DiGraph([("X", "Y")]))
causal_model.set_causal_mechanism("X", EmpiricalDistribution())
causal_model.set_causal_mechanism("Y", AdditiveNoiseModel(prediction_model=create_linear_regressor()))
data = pd.DataFrame({"X": np.random.normal(0, 1, 200)})
data["Y"] = 2 * data["X"] + np.random.normal(0, 0.1, 200)
fit(causal_model, data)

result1 = counterfactual_samples(causal_model, {"X": lambda x: 0.0}, observed_data=data, random_seed=42)
result2 = counterfactual_samples(causal_model, {"X": lambda x: 0.0}, observed_data=data, random_seed=42)
np.testing.assert_array_equal(result1.to_numpy(), result2.to_numpy())