diff --git a/dowhy/gcm/whatif.py b/dowhy/gcm/whatif.py index de8f3fcab3..c252957a1f 100644 --- a/dowhy/gcm/whatif.py +++ b/dowhy/gcm/whatif.py @@ -15,6 +15,7 @@ validate_causal_dag, ) from dowhy.gcm.fitting_sampling import draw_samples +from dowhy.gcm.util.general import set_random_seed from dowhy.graph import ( DirectedGraph, get_ordered_predecessors, @@ -29,6 +30,7 @@ def interventional_samples( interventions: Dict[Any, Callable[[np.ndarray], Union[float, np.ndarray]]], observed_data: Optional[pd.DataFrame] = None, num_samples_to_draw: Optional[int] = None, + random_seed: Optional[int] = None, ) -> pd.DataFrame: """Performs intervention on nodes in the causal graph. @@ -40,8 +42,12 @@ def interventional_samples( :param observed_data: Optionally, data on which to perform interventions. If None are given, data is generated based on the generative models. :param num_samples_to_draw: Sample size to draw from the interventional distribution. + :param random_seed: Optional seed for the random number generator to make results reproducible. :return: Samples from the interventional distribution. """ + if random_seed is not None: + set_random_seed(random_seed) + validate_causal_dag(causal_model.graph) for node in interventions: validate_node_in_graph(causal_model.graph, node) @@ -107,6 +113,7 @@ def counterfactual_samples( interventions: Dict[Any, Callable[[np.ndarray], Union[float, np.ndarray]]], observed_data: Optional[pd.DataFrame] = None, noise_data: Optional[pd.DataFrame] = None, + random_seed: Optional[int] = None, ) -> pd.DataFrame: """Estimates counterfactual data for observed data if we were to perform specified interventions. This function implements the 3-step process for computing counterfactuals by Pearl (see https://ftp.cs.ucla.edu/pub/stat_ser/r485.pdf). @@ -121,8 +128,12 @@ def counterfactual_samples( :param noise_data: Data of noise terms corresponding to nodes in the causal graph. If not provided, these have to be estimated from observed data. 
Then we require causal models of nodes to be invertible. + :param random_seed: Optional seed for the random number generator to make results reproducible. :return: Estimated counterfactual data. """ + if random_seed is not None: + set_random_seed(random_seed) + for node in interventions: validate_node_in_graph(causal_model.graph, node) @@ -196,6 +207,7 @@ def average_causal_effect( interventions_reference: Dict[Any, Callable[[np.ndarray], Union[float, np.ndarray]]], observed_data: Optional[pd.DataFrame] = None, num_samples_to_draw: Optional[int] = None, + random_seed: Optional[int] = None, ) -> float: """Estimates the average causal effect (ACE) on the target of two different sets of interventions. The interventions can be specified through the parameters `interventions_alternative` and `interventions_reference`. @@ -223,8 +235,11 @@ def average_causal_effect( models. :param num_samples_to_draw: Number of samples drawn from the causal model for estimating ACE if no observed data is given. + :param random_seed: Optional seed for the random number generator to make results reproducible. :return: The estimated average causal effect (ACE). """ + if random_seed is not None: + set_random_seed(random_seed) # For estimating the effect, we only need to consider the nodes that have a directed path to the target node, i.e. # all ancestors of the target. 
causal_model = ProbabilisticCausalModel(node_connected_subgraph_view(causal_model.graph, target_node))
diff --git a/tests/gcm/test_whatif.py b/tests/gcm/test_whatif.py
index a59cb35555..3d80f9e147 100644
--- a/tests/gcm/test_whatif.py
+++ b/tests/gcm/test_whatif.py
@@ -278,3 +278,62 @@ def test_given_discrete_data_when_performing_interventions_then_returns_correct_
     assert np.all(samples["X"].to_numpy() == -2)
     assert np.median(samples["Y"].to_numpy()) == -1
     assert np.mean(samples["Z"].to_numpy()) == approx(-2, abs=0.05)
+
+
+def test_interventional_samples_with_random_seed_is_reproducible():
+    """Same seed must reproduce the interventional samples; a different seed must not."""
+    causal_model = ProbabilisticCausalModel(nx.DiGraph([("X", "Y")]))
+    causal_model.set_causal_mechanism("X", EmpiricalDistribution())
+    causal_model.set_causal_mechanism("Y", AdditiveNoiseModel(prediction_model=create_linear_regressor()))
+    data = pd.DataFrame({"X": np.random.normal(0, 1, 500)})
+    data["Y"] = 2 * data["X"] + np.random.normal(0, 0.1, 500)
+    fit(causal_model, data)
+
+    result1 = interventional_samples(causal_model, {"X": lambda x: 1.0}, num_samples_to_draw=100, random_seed=42)
+    result2 = interventional_samples(causal_model, {"X": lambda x: 1.0}, num_samples_to_draw=100, random_seed=42)
+    np.testing.assert_array_equal(result1.to_numpy(), result2.to_numpy())
+
+    # Guard against a vacuous pass: a different seed has to change the sampled values of Y.
+    result3 = interventional_samples(causal_model, {"X": lambda x: 1.0}, num_samples_to_draw=100, random_seed=43)
+    assert not np.array_equal(result1["Y"].to_numpy(), result3["Y"].to_numpy())
+
+
+def test_counterfactual_samples_with_random_seed_is_reproducible():
+    """Same seed must reproduce the counterfactual samples for the same observed data."""
+    # NOTE(review): with observed data and an invertible model the noise is presumably reconstructed from the
+    # data rather than sampled, so this may pass even if the seed were ignored -- confirm.
+    causal_model = InvertibleStructuralCausalModel(nx.DiGraph([("X", "Y")]))
+    causal_model.set_causal_mechanism("X", EmpiricalDistribution())
+    causal_model.set_causal_mechanism("Y", AdditiveNoiseModel(prediction_model=create_linear_regressor()))
+    data = pd.DataFrame({"X": np.random.normal(0, 1, 200)})
+    data["Y"] = 2 * data["X"] + np.random.normal(0, 0.1, 200)
+    fit(causal_model, data)
+
+    result1 = counterfactual_samples(causal_model, {"X": lambda x: 0.0}, observed_data=data, random_seed=42)
+    result2 = counterfactual_samples(causal_model, {"X": lambda x: 0.0}, observed_data=data, random_seed=42)
+    np.testing.assert_array_equal(result1.to_numpy(), result2.to_numpy())
+
+
+def test_average_causal_effect_with_random_seed_is_reproducible():
+    """Same seed must reproduce the estimated average causal effect."""
+    # Imported locally so this test does not rely on the module-level import list of this file.
+    from dowhy.gcm.whatif import average_causal_effect
+
+    causal_model = ProbabilisticCausalModel(nx.DiGraph([("X", "Y")]))
+    causal_model.set_causal_mechanism("X", EmpiricalDistribution())
+    causal_model.set_causal_mechanism("Y", AdditiveNoiseModel(prediction_model=create_linear_regressor()))
+    data = pd.DataFrame({"X": np.random.normal(0, 1, 500)})
+    data["Y"] = 2 * data["X"] + np.random.normal(0, 0.1, 500)
+    fit(causal_model, data)
+
+    def estimate_ace():
+        return average_causal_effect(
+            causal_model,
+            target_node="Y",
+            interventions_alternative={"X": lambda x: 1.0},
+            interventions_reference={"X": lambda x: 0.0},
+            num_samples_to_draw=100,
+            random_seed=42,
+        )
+
+    assert estimate_ace() == estimate_ace()