Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 22 additions & 2 deletions dowhy/causal_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def __init__(
confidence_level: float = DEFAULT_CONFIDENCE_LEVEL,
need_conditional_estimates: Union[bool, str] = "auto",
num_quantiles_to_discretize_cont_cols: int = NUM_QUANTILES_TO_DISCRETIZE_CONT_COLS,
random_state: Optional[Union[int, np.random.RandomState]] = None,
**_,
):
"""Initializes an estimator with data and names of relevant variables.
Expand All @@ -87,6 +88,9 @@ def __init__(
:param num_quantiles_to_discretize_cont_cols: The number of quantiles
into which a numeric effect modifier is split, to enable
estimation of conditional treatment effect over it.
:param random_state: Seed or numpy RandomState used to make the
bootstrap confidence intervals and significance tests reproducible.
If None (default), results vary between runs.
:param kwargs: (optional) Additional estimator-specific parameters
:returns: an instance of the estimator class.
"""
Expand All @@ -113,8 +117,21 @@ def __init__(
self._bootstrap_estimates = None
self._bootstrap_null_estimates = None

self._random_state = random_state

self._encoders = Encoders()

def _get_random_state(self):
"""Return a numpy RandomState built from the estimator's random_state.

A RandomState instance is returned as-is; an int (or None) is used to
seed a new RandomState. None preserves the previous non-deterministic
behavior.
"""
if isinstance(self._random_state, np.random.RandomState):
return self._random_state
return np.random.RandomState(self._random_state)

def __getstate__(self):
"""Return picklable state, excluding the non-picklable logger (Python < 3.12)."""
state = self.__dict__.copy()
Expand Down Expand Up @@ -336,9 +353,11 @@ def _generate_bootstrap_estimates(self, data: pd.DataFrame, num_bootstrap_simula
self.logger.info("INFO: The sample size: {}".format(sample_size))
self.logger.info("INFO: The number of simulations: {}".format(num_bootstrap_simulations))

random_state = self._get_random_state()

# Perform the set number of simulations
for index in range(num_bootstrap_simulations):
new_data = resample(data, n_samples=sample_size)
new_data = resample(data, n_samples=sample_size, random_state=random_state)
new_estimator = self.get_new_estimator_object(
self._target_estimand,
# names of treatment and outcome
Expand Down Expand Up @@ -547,8 +566,9 @@ def _test_significance_with_bootstrap(self, data: pd.DataFrame, estimate_value,
null_estimates = np.zeros(num_null_simulations)
new_estimand = copy.deepcopy(self._target_estimand)
new_estimand.outcome_variable = ["dummy_outcome"]
random_state = self._get_random_state()
for i in range(num_null_simulations):
new_outcome = np.random.permutation(data[self._target_estimand.outcome_variable])
new_outcome = random_state.permutation(data[self._target_estimand.outcome_variable])
new_data = data.assign(dummy_outcome=new_outcome)
new_estimator = self.get_new_estimator_object(
new_estimand,
Expand Down
47 changes: 47 additions & 0 deletions tests/causal_estimators/test_bootstrap_reproducibility.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import numpy as np
import pandas as pd

from dowhy import CausalModel


def _make_data(n=400):
rng = np.random.RandomState(0)
w = rng.normal(size=n)
v = (rng.uniform(size=n) < 1 / (1 + np.exp(-w))).astype(int)
y = 2 * v + w + rng.normal(size=n)
return pd.DataFrame({"v0": v, "W0": w, "y": y})


def _estimate(random_state):
data = _make_data()
model = CausalModel(data=data, treatment="v0", outcome="y", common_causes=["W0"])
identified_estimand = model.identify_effect(proceed_when_unidentifiable=True)
return model.estimate_effect(
identified_estimand,
method_name="backdoor.propensity_score_weighting",
test_significance=True,
confidence_intervals=True,
method_params={
"init_params": {"random_state": random_state},
"num_null_simulations": 20,
"num_simulations": 50,
},
)


def test_bootstrap_confidence_intervals_are_reproducible_with_random_state():
first = _estimate(random_state=42).get_confidence_intervals()
second = _estimate(random_state=42).get_confidence_intervals()
assert np.allclose(first, second)


def test_bootstrap_significance_is_reproducible_with_random_state():
first = _estimate(random_state=42).test_stat_significance()["p_value"]
second = _estimate(random_state=42).test_stat_significance()["p_value"]
assert np.allclose(first, second)


def test_bootstrap_confidence_intervals_differ_across_random_states():
first = _estimate(random_state=42).get_confidence_intervals()
second = _estimate(random_state=7).get_confidence_intervals()
assert not np.allclose(first, second)
Loading