diff --git a/dowhy/gcm/fitting_sampling.py b/dowhy/gcm/fitting_sampling.py index e3dbe4bd7..8303df059 100644 --- a/dowhy/gcm/fitting_sampling.py +++ b/dowhy/gcm/fitting_sampling.py @@ -1,6 +1,6 @@ """This module provides functionality for fitting probabilistic causal models and drawing samples from them.""" -from typing import Any +from typing import Any, Dict import networkx as nx import numpy as np @@ -91,9 +91,11 @@ def fit_causal_model_of_target( if is_root_node(causal_model.graph, target_node): causal_model.causal_mechanism(target_node).fit(X=training_data[target_node].to_numpy()[~y_nan_mask]) + ordered_predecessors: list = [] else: + ordered_predecessors = get_ordered_predecessors(causal_model.graph, target_node) causal_model.causal_mechanism(target_node).fit( - X=training_data[get_ordered_predecessors(causal_model.graph, target_node)].to_numpy()[~y_nan_mask], + X=training_data[ordered_predecessors].to_numpy()[~y_nan_mask], Y=training_data[target_node].to_numpy()[~y_nan_mask], ) @@ -102,9 +104,7 @@ def fit_causal_model_of_target( # this would automatically fail when the number of parents is different, there are other more subtle cases, # where the number is still the same, but it's different parents, and therefore different data. That would yield # wrong results, but would not fail. - causal_model.graph.nodes[target_node][PARENTS_DURING_FIT] = get_ordered_predecessors( - causal_model.graph, target_node - ) + causal_model.graph.nodes[target_node][PARENTS_DURING_FIT] = ordered_predecessors def draw_samples(causal_model: ProbabilisticCausalModel, num_samples: int) -> pd.DataFrame: @@ -118,7 +118,7 @@ def draw_samples(causal_model: ProbabilisticCausalModel, num_samples: int) -> pd validate_causal_dag(causal_model.graph) sorted_nodes = list(nx.topological_sort(causal_model.graph)) - drawn_samples = pd.DataFrame(np.empty((num_samples, len(sorted_nodes))), columns=sorted_nodes) + drawn_samples: Dict[Any, np.ndarray] = {} for node in sorted_nodes: causal_mechanism = causal_model.causal_mechanism(node) @@ -126,12 +126,8 @@ def draw_samples(causal_model: ProbabilisticCausalModel, num_samples: int) -> pd if is_root_node(causal_model.graph, node): drawn_samples[node] = causal_mechanism.draw_samples(num_samples).squeeze() else: - drawn_samples[node] = causal_mechanism.draw_samples( - _parent_samples_of(node, causal_model, drawn_samples) - ).squeeze() - - return drawn_samples - + predecessors = get_ordered_predecessors(causal_model.graph, node) + parent_data = np.column_stack([drawn_samples[p] for p in predecessors]) + drawn_samples[node] = causal_mechanism.draw_samples(parent_data).squeeze() -def _parent_samples_of(node: Any, scm: ProbabilisticCausalModel, samples: pd.DataFrame) -> np.ndarray: - return samples[get_ordered_predecessors(scm.graph, node)].to_numpy() + return pd.DataFrame(drawn_samples, columns=sorted_nodes)