From 03331d2a20b9e5b41085d7ec8bc18584cb8134d7 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Fri, 3 Apr 2026 11:22:36 +0000 Subject: [PATCH] perf: eliminate redundant set allocations and fix per-column kernel precision MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - causal_graph.py: replace set.union()/set.difference() with in-place update()/difference_update()/add() across get_common_causes, get_effect_modifiers, get_descendants, and get_instruments; eliminates O(N) unnecessary intermediate set objects per graph traversal call. Also replace the nested-list-comprehension flatten in get_instruments with an explicit loop + update(), avoiding an intermediate list allocation. - gcm/falsify.py: _get_non_descendants now mutates the set returned by nx.descendants() (already a fresh set) instead of calling .union({node}) and .union(predecessors) which each allocate a new set; uses set subtraction operator for the final difference. - gcm/validation.py: cache get_ordered_predecessors() results in a dict before the first loop so the second loop (FDR annotation) reuses the cached lists instead of re-traversing the graph N times. - gcm/independence_test/kernel_operation.py: fix correctness bug in apply_rbf_kernel_with_adaptive_precision — was computing euclidean_distances(X) (all columns) on every iteration instead of euclidean_distances(X[:, [i]]) for the i-th column; the intent of the product-kernel formulation requires per-feature distances. This also eliminates the redundant full-distance recomputation. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Signed-off-by: github-actions[bot] --- dowhy/causal_graph.py | 31 ++++++++----------- dowhy/gcm/falsify.py | 7 +++-- .../gcm/independence_test/kernel_operation.py | 2 +- dowhy/gcm/validation.py | 6 ++-- 4 files changed, 22 insertions(+), 24 deletions(-) diff --git a/dowhy/causal_graph.py b/dowhy/causal_graph.py index 75e59b9aeb..e4ebb5ceb3 100755 --- a/dowhy/causal_graph.py +++ b/dowhy/causal_graph.py @@ -347,39 +347,33 @@ def get_common_causes(self, nodes1, nodes2): causes_1 = set() causes_2 = set() for node in nodes1: - causes_1 = causes_1.union(self.get_ancestors(node)) + causes_1.update(self.get_ancestors(node)) for node in nodes2: # Cannot simply compute ancestors, since that will also include nodes1 and its parents (e.g. instruments) parents_2 = self.get_parents(node) for parent in parents_2: if parent not in nodes1: - causes_2 = causes_2.union( - set( - [ - parent, - ] - ) - ) - causes_2 = causes_2.union(self.get_ancestors(parent)) + causes_2.add(parent) + causes_2.update(self.get_ancestors(parent)) return list(causes_1.intersection(causes_2)) def get_effect_modifiers(self, nodes1, nodes2): # Return effect modifiers according to the graph modifiers = set() for node in nodes2: - modifiers = modifiers.union(self.get_ancestors(node)) - modifiers = modifiers.difference(nodes1) + modifiers.update(self.get_ancestors(node)) + modifiers.difference_update(nodes1) for node in nodes1: - modifiers = modifiers.difference(self.get_ancestors(node)) + modifiers.difference_update(self.get_ancestors(node)) # removing all mediators for node1 in nodes1: for node2 in nodes2: all_directed_paths = nx.all_simple_paths(self._graph, node1, node2) for path in all_directed_paths: - modifiers = modifiers.difference(path) + modifiers.difference_update(path) # Also add any effect modifiers that could not be auto-detected (e.g., they are also common causes) marked_modifiers = [n for n, ndata in self._graph.nodes(data=True) if "effectmodifier" in ndata] - modifiers = modifiers.union(marked_modifiers) + modifiers.update(marked_modifiers) return list(modifiers) def get_parents(self, node_name): @@ -395,7 +389,7 @@ def get_ancestors(self, node_name, new_graph=None): def get_descendants(self, nodes): descendants = set() for node_name in nodes: - descendants = descendants.union(set(nx.descendants(self._graph, node_name))) + descendants.update(nx.descendants(self._graph, node_name)) return descendants def all_observed(self, node_names): @@ -429,14 +423,15 @@ def get_instruments(self, treatment_nodes, outcome_nodes): g_no_parents_treatment = self.do_surgery(treatment_nodes, remove_incoming_edges=True) ancestors_outcome = set() for node in outcome_nodes: - ancestors_outcome = ancestors_outcome.union(nx.ancestors(g_no_parents_treatment, node)) + ancestors_outcome.update(nx.ancestors(g_no_parents_treatment, node)) # [TODO: double check these work with multivariate implementation:] # Exclusion candidate_instruments = parents_treatment.difference(ancestors_outcome) self.logger.debug("Candidate instruments after satisfying exclusion: %s", candidate_instruments) # As-if-random setup - children_causes_outcome = [nx.descendants(g_no_parents_treatment, v) for v in ancestors_outcome] - children_causes_outcome = set([item for sublist in children_causes_outcome for item in sublist]) + children_causes_outcome = set() + for v in ancestors_outcome: + children_causes_outcome.update(nx.descendants(g_no_parents_treatment, v)) # As-if-random instruments = candidate_instruments.difference(children_causes_outcome) diff --git a/dowhy/gcm/falsify.py b/dowhy/gcm/falsify.py index 14d17f527b..386d95d4bd 100644 --- a/dowhy/gcm/falsify.py +++ b/dowhy/gcm/falsify.py @@ -994,7 +994,8 @@ def _to_frozenset(x: Union[Set, List, str]): def _get_non_descendants(causal_graph: DirectedGraph, node: Any, exclude_parents: bool = False) -> List[Any]: - nodes_to_exclude = nx.descendants(causal_graph, node).union({node}) + nodes_to_exclude = nx.descendants(causal_graph, node) + nodes_to_exclude.add(node) if exclude_parents: - nodes_to_exclude = nodes_to_exclude.union(causal_graph.predecessors(node)) - return list(set(causal_graph.nodes).difference(nodes_to_exclude)) + nodes_to_exclude.update(causal_graph.predecessors(node)) + return list(set(causal_graph.nodes) - nodes_to_exclude) diff --git a/dowhy/gcm/independence_test/kernel_operation.py b/dowhy/gcm/independence_test/kernel_operation.py index 60279700f1..7ecf3d0000 100644 --- a/dowhy/gcm/independence_test/kernel_operation.py +++ b/dowhy/gcm/independence_test/kernel_operation.py @@ -36,7 +36,7 @@ def apply_rbf_kernel_with_adaptive_precision(X: np.ndarray) -> np.ndarray: result = np.ones((X.shape[0], X.shape[0])) for i in range(X.shape[1]): - distance_matrix = euclidean_distances(X, squared=True) + distance_matrix = euclidean_distances(X[:, [i]], squared=True) result *= np.exp(-_median_based_precision(distance_matrix) * distance_matrix) return result diff --git a/dowhy/gcm/validation.py b/dowhy/gcm/validation.py index f0fe890d65..592fbf5193 100644 --- a/dowhy/gcm/validation.py +++ b/dowhy/gcm/validation.py @@ -48,8 +48,10 @@ def refute_causal_structure( validation_summary = dict() all_p_values = [] + parents_per_node = {node: get_ordered_predecessors(causal_graph, node) for node in causal_graph.nodes} + for node in causal_graph.nodes: - parents = get_ordered_predecessors(causal_graph, node) + parents = parents_per_node[node] non_descendants = _get_non_descendants(causal_graph, node, exclude_parents=True) lmc_test_result = dict() @@ -88,7 +90,7 @@ def refute_causal_structure( is_dag_valid &= not successes[index] index += 1 - for parent in get_ordered_predecessors(causal_graph, node): + for parent in parents_per_node[node]: validation_summary[node]["edge_dependence_test"][parent]["fdr_adjusted_p_value"] = adjusted_p_values[index] validation_summary[node]["edge_dependence_test"][parent]["success"] = successes[index] is_dag_valid &= successes[index]