diff --git a/dowhy/causal_estimators/distance_matching_estimator.py b/dowhy/causal_estimators/distance_matching_estimator.py index 3d0bd70f2..c35242b78 100644 --- a/dowhy/causal_estimators/distance_matching_estimator.py +++ b/dowhy/causal_estimators/distance_matching_estimator.py @@ -217,7 +217,7 @@ def estimate_effect( n_neighbors=self.num_matches_per_unit, metric=self.distance_metric, algorithm="ball_tree", - **self.distance_metric_params, + metric_params=self.distance_metric_params, ).fit(control[self._observed_common_causes.columns].values) distances, indices = control_neighbors.kneighbors(treated[self._observed_common_causes.columns].values) self.logger.debug("distances:") @@ -257,7 +257,7 @@ def estimate_effect( n_neighbors=self.num_matches_per_unit, metric=self.distance_metric, algorithm="ball_tree", - **self.distance_metric_params, + metric_params=self.distance_metric_params, ).fit(group_control[self._observed_common_causes.columns].values) distances, indices = control_neighbors.kneighbors( group_treated[self._observed_common_causes.columns].values @@ -293,7 +293,7 @@ def estimate_effect( n_neighbors=self.num_matches_per_unit, metric=self.distance_metric, algorithm="ball_tree", - **self.distance_metric_params, + metric_params=self.distance_metric_params, ).fit(treated[self._observed_common_causes.columns].values) distances, indices = treated_neighbors.kneighbors(control[self._observed_common_causes.columns].values) diff --git a/tests/causal_estimators/test_distance_matching_estimator.py b/tests/causal_estimators/test_distance_matching_estimator.py index 7549a6a39..3a4bd97bc 100644 --- a/tests/causal_estimators/test_distance_matching_estimator.py +++ b/tests/causal_estimators/test_distance_matching_estimator.py @@ -149,3 +149,26 @@ def test_average_treatment_effect_via_simple_estimator(self): test_significance=[False], method_params={"num_simulations": 5, "num_null_simulations": 5}, ) + + def test_distance_matching_with_mahalanobis_and_v_param(self, binary_treatment_dataset): + """Regression test for issue #1390: ensure V param is correctly passed as metric_params to NearestNeighbors.""" + data = binary_treatment_dataset + model = CausalModel(data=data, treatment="v0", outcome="y", graph=GML_SINGLE_CAUSE) + estimand = model.identify_effect(proceed_when_unidentifiable=True) + + # Calculate covariance matrix for W to use as V parameter + X = data[["W"]].values + V_matrix = np.cov(X.T) + if V_matrix.ndim == 0: + V_matrix = np.array([[V_matrix]]) + + estimate = model.estimate_effect( + estimand, + method_name="backdoor.distance_matching", + target_units="att", + method_params={ + "distance_metric": "mahalanobis", + "V": V_matrix, + }, + ) + assert np.isfinite(estimate.value), "Estimate with Mahalanobis and V param should be finite."