Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 13 additions & 18 deletions dowhy/causal_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,39 +347,33 @@ def get_common_causes(self, nodes1, nodes2):
causes_1 = set()
causes_2 = set()
for node in nodes1:
causes_1 = causes_1.union(self.get_ancestors(node))
causes_1.update(self.get_ancestors(node))
for node in nodes2:
# Cannot simply compute ancestors, since that will also include nodes1 and its parents (e.g. instruments)
parents_2 = self.get_parents(node)
for parent in parents_2:
if parent not in nodes1:
causes_2 = causes_2.union(
set(
[
parent,
]
)
)
causes_2 = causes_2.union(self.get_ancestors(parent))
causes_2.add(parent)
causes_2.update(self.get_ancestors(parent))
return list(causes_1.intersection(causes_2))

def get_effect_modifiers(self, nodes1, nodes2):
# Return effect modifiers according to the graph
modifiers = set()
for node in nodes2:
modifiers = modifiers.union(self.get_ancestors(node))
modifiers = modifiers.difference(nodes1)
modifiers.update(self.get_ancestors(node))
modifiers.difference_update(nodes1)
for node in nodes1:
modifiers = modifiers.difference(self.get_ancestors(node))
modifiers.difference_update(self.get_ancestors(node))
# removing all mediators
for node1 in nodes1:
for node2 in nodes2:
all_directed_paths = nx.all_simple_paths(self._graph, node1, node2)
for path in all_directed_paths:
modifiers = modifiers.difference(path)
modifiers.difference_update(path)
# Also add any effect modifiers that could not be auto-detected (e.g., they are also common causes)
marked_modifiers = [n for n, ndata in self._graph.nodes(data=True) if "effectmodifier" in ndata]
modifiers = modifiers.union(marked_modifiers)
modifiers.update(marked_modifiers)
return list(modifiers)

def get_parents(self, node_name):
Expand All @@ -395,7 +389,7 @@ def get_ancestors(self, node_name, new_graph=None):
def get_descendants(self, nodes):
descendants = set()
for node_name in nodes:
descendants = descendants.union(set(nx.descendants(self._graph, node_name)))
descendants.update(nx.descendants(self._graph, node_name))
return descendants

def all_observed(self, node_names):
Expand Down Expand Up @@ -429,14 +423,15 @@ def get_instruments(self, treatment_nodes, outcome_nodes):
g_no_parents_treatment = self.do_surgery(treatment_nodes, remove_incoming_edges=True)
ancestors_outcome = set()
for node in outcome_nodes:
ancestors_outcome = ancestors_outcome.union(nx.ancestors(g_no_parents_treatment, node))
ancestors_outcome.update(nx.ancestors(g_no_parents_treatment, node))
# [TODO: double check these work with multivariate implementation:]
# Exclusion
candidate_instruments = parents_treatment.difference(ancestors_outcome)
self.logger.debug("Candidate instruments after satisfying exclusion: %s", candidate_instruments)
# As-if-random setup
children_causes_outcome = [nx.descendants(g_no_parents_treatment, v) for v in ancestors_outcome]
children_causes_outcome = set([item for sublist in children_causes_outcome for item in sublist])
children_causes_outcome = set()
for v in ancestors_outcome:
children_causes_outcome.update(nx.descendants(g_no_parents_treatment, v))

# As-if-random
instruments = candidate_instruments.difference(children_causes_outcome)
Expand Down
7 changes: 4 additions & 3 deletions dowhy/gcm/falsify.py
Original file line number Diff line number Diff line change
Expand Up @@ -994,7 +994,8 @@ def _to_frozenset(x: Union[Set, List, str]):


def _get_non_descendants(causal_graph: DirectedGraph, node: Any, exclude_parents: bool = False) -> List[Any]:
nodes_to_exclude = nx.descendants(causal_graph, node).union({node})
nodes_to_exclude = nx.descendants(causal_graph, node)
nodes_to_exclude.add(node)
if exclude_parents:
nodes_to_exclude = nodes_to_exclude.union(causal_graph.predecessors(node))
return list(set(causal_graph.nodes).difference(nodes_to_exclude))
nodes_to_exclude.update(causal_graph.predecessors(node))
return list(set(causal_graph.nodes) - nodes_to_exclude)
2 changes: 1 addition & 1 deletion dowhy/gcm/independence_test/kernel_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def apply_rbf_kernel_with_adaptive_precision(X: np.ndarray) -> np.ndarray:

result = np.ones((X.shape[0], X.shape[0]))
for i in range(X.shape[1]):
distance_matrix = euclidean_distances(X, squared=True)
distance_matrix = euclidean_distances(X[:, [i]], squared=True)
result *= np.exp(-_median_based_precision(distance_matrix) * distance_matrix)

return result
Expand Down
6 changes: 4 additions & 2 deletions dowhy/gcm/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,10 @@ def refute_causal_structure(
validation_summary = dict()
all_p_values = []

parents_per_node = {node: get_ordered_predecessors(causal_graph, node) for node in causal_graph.nodes}

for node in causal_graph.nodes:
parents = get_ordered_predecessors(causal_graph, node)
parents = parents_per_node[node]
non_descendants = _get_non_descendants(causal_graph, node, exclude_parents=True)

lmc_test_result = dict()
Expand Down Expand Up @@ -88,7 +90,7 @@ def refute_causal_structure(
is_dag_valid &= not successes[index]
index += 1

for parent in get_ordered_predecessors(causal_graph, node):
for parent in parents_per_node[node]:
validation_summary[node]["edge_dependence_test"][parent]["fdr_adjusted_p_value"] = adjusted_p_values[index]
validation_summary[node]["edge_dependence_test"][parent]["success"] = successes[index]
is_dag_valid &= successes[index]
Expand Down
Loading