Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions tests/test_causal_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -698,6 +698,57 @@ def test_causal_estimator_cache(self):
assert (estimates[1].estimator) == model.get_estimator(methods[1])
assert (estimates[0].estimator) != model.get_estimator(methods[1]) # check not same object

def test_fit_estimator_false_reuses_cached_estimator(self):
"""Test that fit_estimator=False reuses the cached estimator without refitting.

After a first call with fit_estimator=True (the default), a subsequent call
with fit_estimator=False must:
- return the same estimator object from the cache (no re-instantiation)
- produce an identical estimate value for deterministic estimators
"""
data = dowhy.datasets.linear_dataset(
beta=10,
num_common_causes=3,
num_samples=500,
num_treatments=1,
treatment_is_binary=True,
)
model = CausalModel(
data=data["df"],
treatment=data["treatment_name"],
outcome=data["outcome_name"],
graph=data["gml_graph"],
proceed_when_unidentifiable=True,
test_significance=None,
)
identified_estimand = model.identify_effect(proceed_when_unidentifiable=True)
method = "backdoor.linear_regression"

# First call: fit the estimator and cache it.
estimate1 = model.estimate_effect(
identified_estimand,
method_name=method,
control_value=0,
treatment_value=1,
)
estimator_after_first_call = model.get_estimator(method)

# Second call with fit_estimator=False must reuse the cached estimator.
estimate2 = model.estimate_effect(
identified_estimand,
method_name=method,
control_value=0,
treatment_value=1,
fit_estimator=False,
)
Comment on lines +736 to +743
estimator_after_second_call = model.get_estimator(method)

# Same object identity — no new estimator was created.
assert estimator_after_first_call is estimator_after_second_call

# Linear regression is deterministic: both calls must yield the same estimate.
assert estimate1.value == pytest.approx(estimate2.value)


if __name__ == "__main__":
pytest.main([__file__])