Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions .all-contributorsrc
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,15 @@
"contributions": [
"code"
]
},
{
"login": "kvr06-ai",
"name": "Kaushik Rajan",
"avatar_url": "https://avatars.githubusercontent.com/u/182360080?v=4",
"profile": "http://kaushikrajan.me",
"contributions": [
"code"
]
}
],
"contributorsPerLine": 7,
Expand Down
5 changes: 3 additions & 2 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@

<!-- ALL-CONTRIBUTORS-BADGE:START - Do not remove or modify this section -->
[![All Contributors](https://img.shields.io/badge/all_contributors-51-orange.svg?style=flat-square)](#contributors-)
[![All Contributors](https://img.shields.io/badge/all_contributors-52-orange.svg?style=flat-square)](#contributors-)
<!-- ALL-CONTRIBUTORS-BADGE:END -->

## Contributors ✨
Expand Down Expand Up @@ -73,11 +73,12 @@ Thanks goes to these wonderful people ([emoji key](https://allcontributors.org/d
<td align="center" valign="top" width="14.28%"><a href="https://github.com/emmanuel-ferdman"><img src="https://avatars.githubusercontent.com/u/35470921?v=4?s=100" width="100px;" alt="Emmanuel Ferdman"/><br /><sub><b>Emmanuel Ferdman</b></sub></a><br /><a href="https://github.com/py-why/dowhy/commits?author=emmanuel-ferdman" title="Code">💻</a></td>
<td align="center" valign="top" width="14.28%"><a href="https://github.com/JPZ4-5"><img src="https://avatars.githubusercontent.com/u/103734362?v=4?s=100" width="100px;" alt="JPZ4-5"/><br /><sub><b>JPZ4-5</b></sub></a><br /><a href="https://github.com/py-why/dowhy/commits?author=JPZ4-5" title="Code">💻</a></td>
<td align="center" valign="top" width="14.28%"><a href="https://www.linkedin.com/in/ghiles-meddour/"><img src="https://avatars.githubusercontent.com/u/88532760?v=4?s=100" width="100px;" alt="Ghiles Meddour"/><br /><sub><b>Ghiles Meddour</b></sub></a><br /><a href="https://github.com/py-why/dowhy/commits?author=ghilesmeddour" title="Code">💻</a></td>
<td align="center" valign="top" width="14.28%"><a href="https://github.com/toroleapinc"><img src="https://avatars.githubusercontent.com/u/88481784?v=4?s=100" width="100px;" alt="edvatar"/><br /><sub><b>edvatar</b></sub></a><br /><a href="https://github.com/py-why/dowhy/commits?author=toroleapinc" title="Documentation">📖</a><a href="https://github.com/py-why/dowhy/commits?author=toroleapinc" title="Code">💻</a></td>
<td align="center" valign="top" width="14.28%"><a href="https://github.com/toroleapinc"><img src="https://avatars.githubusercontent.com/u/88481784?v=4?s=100" width="100px;" alt="edvatar"/><br /><sub><b>edvatar</b></sub></a><br /><a href="https://github.com/py-why/dowhy/commits?author=toroleapinc" title="Documentation">📖</a> <a href="https://github.com/py-why/dowhy/commits?author=toroleapinc" title="Code">💻</a></td>
</tr>
<tr>
<td align="center" valign="top" width="14.28%"><a href="https://github.com/kevinchiv"><img src="https://avatars.githubusercontent.com/u/20054278?v=4?s=100" width="100px;" alt="Kevin Chiv"/><br /><sub><b>Kevin Chiv</b></sub></a><br /><a href="https://github.com/py-why/dowhy/commits?author=kevinchiv" title="Code">💻</a></td>
<td align="center" valign="top" width="14.28%"><a href="https://github.com/haoyu-haoyu"><img src="https://avatars.githubusercontent.com/u/85037553?v=4?s=100" width="100px;" alt="haoyu-haoyu"/><br /><sub><b>haoyu-haoyu</b></sub></a><br /><a href="https://github.com/py-why/dowhy/commits?author=haoyu-haoyu" title="Code">💻</a></td>
<td align="center" valign="top" width="14.28%"><a href="http://kaushikrajan.me"><img src="https://avatars.githubusercontent.com/u/182360080?v=4?s=100" width="100px;" alt="Kaushik Rajan"/><br /><sub><b>Kaushik Rajan</b></sub></a><br /><a href="https://github.com/py-why/dowhy/commits?author=kvr06-ai" title="Code">💻</a></td>
</tr>
</tbody>
</table>
Expand Down
15 changes: 12 additions & 3 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,19 @@
author = "PyWhy community"
version = os.environ.get("CURRENT_VERSION")


# Version Information (for version-switcher)
def not_empty(value):
    """Return True when ``value`` contains at least one element/character."""
    return len(value) > 0


def to_tag_obj(tag):
    """Build the version-switcher entry (display name and docs URL) for ``tag``."""
    return {"name": tag, "url": f"/dowhy/{tag}/index.html"}


def has_doc(tag):
    """Return True when ``../../dowhy-docs/<tag>/index.html`` exists on disk.

    The path is relative to the current working directory.
    """
    return os.path.exists(f"../../dowhy-docs/{tag}/index.html")


# TAGS is read as a comma-separated string. Default to "" so that a build
# without TAGS yields no tags instead of failing with AttributeError on None.
git_tags = reversed(list(filter(not_empty, os.environ.get("TAGS", "").split(","))))
# NOTE: ``git_tags`` is a one-shot iterator; building ``doc_tags`` consumes it.
doc_tags = list(filter(has_doc, git_tags))
Expand Down
35 changes: 24 additions & 11 deletions dowhy/causal_estimators/distance_matching_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,10 @@ def estimate_effect(
else:
raise ValueError("Target units string value not supported")

# Save the full treated/control DataFrames before any groupby loops that may rebind these names.
treated_all = treated
control_all = control

if fit_att:
# estimate ATT on treated by summing over difference between matched neighbors
if self.exact_match_cols is None:
Expand Down Expand Up @@ -242,40 +246,49 @@ def estimate_effect(
else:
grouped = updated_df.groupby(self.exact_match_cols)
att = 0
total_treated_matched = 0
self.matched_indices_att = {}
for name, group in grouped:
treated = group.loc[group[self._target_estimand.treatment_variable[0]] == 1]
control = group.loc[group[self._target_estimand.treatment_variable[0]] == 0]
if treated.shape[0] == 0:
group_treated = group.loc[group[self._target_estimand.treatment_variable[0]] == 1]
group_control = group.loc[group[self._target_estimand.treatment_variable[0]] == 0]
if group_treated.shape[0] == 0 or group_control.shape[0] == 0:
continue
control_neighbors = NearestNeighbors(
n_neighbors=self.num_matches_per_unit,
metric=self.distance_metric,
algorithm="ball_tree",
**self.distance_metric_params,
).fit(control[self._observed_common_causes.columns].values)
).fit(group_control[self._observed_common_causes.columns].values)
distances, indices = control_neighbors.kneighbors(
treated[self._observed_common_causes.columns].values
group_treated[self._observed_common_causes.columns].values
)
self.logger.debug("distances:")
self.logger.debug(distances)

for i in range(numtreatedunits):
treated_outcome = treated.iloc[i][self._target_estimand.outcome_variable[0]].item()
num_group_treated = group_treated.shape[0]
group_treated_index = group_treated.index.tolist()
for i in range(num_group_treated):
treated_outcome = group_treated.iloc[i][self._target_estimand.outcome_variable[0]].item()
control_outcome = np.mean(
control.iloc[indices[i]][self._target_estimand.outcome_variable[0]].values
group_control.iloc[indices[i]][self._target_estimand.outcome_variable[0]].values
)
att += treated_outcome - control_outcome
# self.matched_indices_att[treated_df_index[i]] = control.iloc[indices[i]].index.tolist()
matched_ctrl_idx = group_control.iloc[indices[i]].index.tolist()
self.matched_indices_att[group_treated_index[i]] = matched_ctrl_idx
total_treated_matched += num_group_treated

att /= numtreatedunits
if total_treated_matched > 0:
att /= total_treated_matched

if target_units == "att":
est = att
elif target_units == "ate":
est = att * numtreatedunits

if fit_atc:
# Now computing ATC
# Now computing ATC using the full treated/control DataFrames (not group-level subsets).
treated = treated_all
control = control_all
treated_neighbors = NearestNeighbors(
n_neighbors=self.num_matches_per_unit,
metric=self.distance_metric,
Expand Down
77 changes: 68 additions & 9 deletions dowhy/causal_estimators/linear_regression_estimator.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import itertools
from typing import List, Optional, Union

import numpy as np
import pandas as pd
import scipy.stats
import statsmodels.api as sm

from dowhy.causal_estimator import CausalEstimator
Expand Down Expand Up @@ -105,16 +107,65 @@ def _build_model(self, data: pd.DataFrame):
model = sm.OLS(data[self._target_estimand.outcome_variable[0]], features).fit()
return (features, model)

def _ate_and_se_for_treatment(self, treatment_index: int):
    """Compute the unscaled ATE and its standard error for one treatment variable.

    Uses the Delta method: for ATE = c'β, SE = sqrt(c' Σ c), where Σ is the
    OLS parameter covariance matrix and c is the contrast vector.

    The contrast vector is built positionally, assuming the fitted model's
    parameters are ordered as:
        [const, T_0, …, T_k, W_0, …, W_m, T_0·X_0, …, T_0·X_n, T_1·X_0, …]
    Only the total parameter count is verified against this layout.

    :param treatment_index: 0-based index into the treatment variable list.
    :returns: Tuple of (ate_unscaled, se_unscaled).
    :raises IndexError: if ``treatment_index`` is outside the treatment list.
    :raises AssertionError: if the number of fitted parameters does not match
        the assumed column layout.
    """
    n_treatments = len(self._target_estimand.treatment_variable)
    # A negative or too-large index would silently select the intercept or a
    # common-cause coefficient instead of a treatment coefficient.
    if not 0 <= treatment_index < n_treatments:
        raise IndexError(f"treatment_index {treatment_index} is out of range for {n_treatments} treatment(s).")

    # Use the actual number of encoded columns, not the number of variable names.
    # Categorical variables are one-hot encoded (drop_first=True), so a variable
    # with k levels produces k-1 columns — len(names) would be wrong.
    n_common_causes = self._observed_common_causes.shape[1] if self._observed_common_causes is not None else 0
    em_means = self._effect_modifiers.mean(axis=0).to_numpy()
    n_effect_modifiers = len(em_means)

    params = self.model.params.to_numpy()
    cov = self.model.cov_params().to_numpy()

    n_params = len(params)
    expected_params = 1 + n_treatments + n_common_causes + n_treatments * n_effect_modifiers
    # Raised explicitly (rather than via an ``assert`` statement) so the check
    # still runs when Python is started with -O.
    if n_params != expected_params:
        raise AssertionError(
            f"Model has {n_params} params but expected {expected_params}. "
            "Column ordering assumption in _ate_and_se_for_treatment may be broken "
            "(check that encoded column counts are used, not variable name counts)."
        )

    c = np.zeros(n_params)
    # Direct treatment coefficient (offset by 1 for the intercept)
    c[1 + treatment_index] = 1.0
    # Interaction coefficients T_i · X_j start at:
    #   1 (const) + n_treatments + n_common_causes + treatment_index * n_effect_modifiers
    interaction_start = 1 + n_treatments + n_common_causes + treatment_index * n_effect_modifiers
    c[interaction_start : interaction_start + n_effect_modifiers] = em_means

    ate = float(c @ params)
    var_ate = float(c @ cov @ c)
    # Clamp tiny negative variances caused by floating-point round-off.
    se = float(np.sqrt(max(var_ate, 0.0)))
    return ate, se

def _estimate_confidence_intervals(self, confidence_level, method=None):
if self._effect_modifier_names:
# The average treatment effect is a combination of different
# regression coefficients. Complicated to compute the confidence
# interval analytically. For example, if y=a + b1.t + b2.tx, then
# the average treatment effect is b1+b2.mean(x).
# Refer Gelman, Hill. ARM Book. Chapter 9
# http://www.stat.columbia.edu/~gelman/arm/chap9.pdf
# TODO: Looking for contributions
raise NotImplementedError
# Use the Delta method to compute asymptotic confidence intervals for the
# ATE when effect modifiers are present. The ATE is a linear combination
# of the OLS coefficients: ATE = b_T + b_{TX_1}*E[X_1] + …
# Reference: Gelman & Hill, ARM Book, Chapter 9
n_treatments = len(self._target_estimand.treatment_variable)
scale = self._treatment_value - self._control_value
t_score = scipy.stats.t.ppf((1.0 + confidence_level) / 2.0, df=self.model.df_resid)
rows = []
for i in range(n_treatments):
ate_unscaled, se_unscaled = self._ate_and_se_for_treatment(i)
ate_scaled = scale * ate_unscaled
margin = abs(scale) * t_score * se_unscaled
rows.append([ate_scaled - margin, ate_scaled + margin])
return np.array(rows)
else:
conf_ints = self.model.conf_int(alpha=1 - confidence_level)
# For a linear regression model, the causal effect of a variable is equal to the coefficient corresponding to the
Expand All @@ -126,7 +177,15 @@ def _estimate_confidence_intervals(self, confidence_level, method=None):

def _estimate_std_error(self, method=None):
if self._effect_modifier_names:
raise NotImplementedError
# Delta method: SE(scale * ATE) = |scale| * sqrt(c' Σ c)
scale = self._treatment_value - self._control_value
ses = np.array(
[
abs(scale) * self._ate_and_se_for_treatment(i)[1]
for i in range(len(self._target_estimand.treatment_variable))
]
)
return ses
else:
std_error = self.model.bse[1 : (len(self._target_estimand.treatment_variable) + 1)]

Expand Down
15 changes: 8 additions & 7 deletions dowhy/causal_identifier/auto_identifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -942,14 +942,15 @@ def identify_generalized_adjustment_set(


def identify_mediation(graph: nx.DiGraph, action_nodes: List[str], outcome_nodes: List[str]):
"""Find a valid mediator if it exists.
"""Find all valid mediators between action and outcome nodes.

Currently only supports a single variable mediator set.
Returns a list of all variables that lie on a directed path from
action_nodes to outcome_nodes (each individually blocks at least one
such path when conditioned on).
"""
mediation_var = None
mediation_vars = []
mediation_paths = get_all_directed_paths(graph, action_nodes, outcome_nodes)
eligible_variables = get_descendants(graph, action_nodes) - set(outcome_nodes)
# For simplicity, assuming a one-variable mediation set
for candidate_var in eligible_variables:
is_valid_mediation = check_valid_mediation_set(
graph,
Expand All @@ -960,9 +961,9 @@ def identify_mediation(graph: nx.DiGraph, action_nodes: List[str], outcome_nodes
)
logger.debug("Candidate mediation set: {0}, on_mediating_path: {1}".format(candidate_var, is_valid_mediation))
if is_valid_mediation:
mediation_var = candidate_var
break
return parse_state(mediation_var)
mediation_vars.append(candidate_var)
# Sort for deterministic output — eligible_variables is a set.
return sorted(mediation_vars)


def identify_mediation_first_stage_confounders(
Expand Down
Loading