Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 10 additions & 14 deletions dowhy/gcm/fitting_sampling.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""This module provides functionality for fitting probabilistic causal models and drawing samples from them."""

from typing import Any
from typing import Any, Dict

import networkx as nx
import numpy as np
Expand Down Expand Up @@ -91,9 +91,11 @@ def fit_causal_model_of_target(

if is_root_node(causal_model.graph, target_node):
causal_model.causal_mechanism(target_node).fit(X=training_data[target_node].to_numpy()[~y_nan_mask])
ordered_predecessors: list = []
else:
ordered_predecessors = get_ordered_predecessors(causal_model.graph, target_node)
causal_model.causal_mechanism(target_node).fit(
X=training_data[get_ordered_predecessors(causal_model.graph, target_node)].to_numpy()[~y_nan_mask],
X=training_data[ordered_predecessors].to_numpy()[~y_nan_mask],
Y=training_data[target_node].to_numpy()[~y_nan_mask],
)

Expand All @@ -102,9 +104,7 @@ def fit_causal_model_of_target(
# this would automatically fail when the number of parents is different, there are other more subtle cases,
# where the number is still the same, but it's different parents, and therefore different data. That would yield
# wrong results, but would not fail.
causal_model.graph.nodes[target_node][PARENTS_DURING_FIT] = get_ordered_predecessors(
causal_model.graph, target_node
)
causal_model.graph.nodes[target_node][PARENTS_DURING_FIT] = ordered_predecessors


def draw_samples(causal_model: ProbabilisticCausalModel, num_samples: int) -> pd.DataFrame:
Expand All @@ -118,20 +118,16 @@ def draw_samples(causal_model: ProbabilisticCausalModel, num_samples: int) -> pd
validate_causal_dag(causal_model.graph)

sorted_nodes = list(nx.topological_sort(causal_model.graph))
drawn_samples = pd.DataFrame(np.empty((num_samples, len(sorted_nodes))), columns=sorted_nodes)
drawn_samples: Dict[Any, np.ndarray] = {}

for node in sorted_nodes:
causal_mechanism = causal_model.causal_mechanism(node)

if is_root_node(causal_model.graph, node):
drawn_samples[node] = causal_mechanism.draw_samples(num_samples).squeeze()
else:
drawn_samples[node] = causal_mechanism.draw_samples(
_parent_samples_of(node, causal_model, drawn_samples)
).squeeze()

return drawn_samples

predecessors = get_ordered_predecessors(causal_model.graph, node)
parent_data = np.column_stack([drawn_samples[p] for p in predecessors])
drawn_samples[node] = causal_mechanism.draw_samples(parent_data).squeeze()

def _parent_samples_of(node: Any, scm: ProbabilisticCausalModel, samples: pd.DataFrame) -> np.ndarray:
return samples[get_ordered_predecessors(scm.graph, node)].to_numpy()
return pd.DataFrame(drawn_samples, columns=sorted_nodes)