diff --git a/tests/test_causal_model.py b/tests/test_causal_model.py index 2f97075b1..d529b6b13 100644 --- a/tests/test_causal_model.py +++ b/tests/test_causal_model.py @@ -698,6 +698,57 @@ def test_causal_estimator_cache(self): assert (estimates[1].estimator) == model.get_estimator(methods[1]) assert (estimates[0].estimator) != model.get_estimator(methods[1]) # check not same object + def test_fit_estimator_false_reuses_cached_estimator(self): + """Test that fit_estimator=False reuses the cached estimator without refitting. + + After a first call with fit_estimator=True (the default), a subsequent call + with fit_estimator=False must: + - return the same estimator object from the cache (no re-instantiation) + - produce an identical estimate value for deterministic estimators + """ + data = dowhy.datasets.linear_dataset( + beta=10, + num_common_causes=3, + num_samples=500, + num_treatments=1, + treatment_is_binary=True, + ) + model = CausalModel( + data=data["df"], + treatment=data["treatment_name"], + outcome=data["outcome_name"], + graph=data["gml_graph"], + proceed_when_unidentifiable=True, + test_significance=None, + ) + identified_estimand = model.identify_effect(proceed_when_unidentifiable=True) + method = "backdoor.linear_regression" + + # First call: fit the estimator and cache it. + estimate1 = model.estimate_effect( + identified_estimand, + method_name=method, + control_value=0, + treatment_value=1, + ) + estimator_after_first_call = model.get_estimator(method) + + # Second call with fit_estimator=False must reuse the cached estimator. + estimate2 = model.estimate_effect( + identified_estimand, + method_name=method, + control_value=0, + treatment_value=1, + fit_estimator=False, + ) + estimator_after_second_call = model.get_estimator(method) + + # Same object identity — no new estimator was created. + assert estimator_after_first_call is estimator_after_second_call + + # Linear regression is deterministic: both calls must yield the same estimate. + assert estimate1.value == pytest.approx(estimate2.value) + if __name__ == "__main__": pytest.main([__file__])