From 4ce6a9bd8eda7a5d99eb213d63d663001d3c2a0f Mon Sep 17 00:00:00 2001 From: JPZ4-5 Date: Thu, 27 Nov 2025 21:51:56 +0800 Subject: [PATCH 1/7] fix: optimize MMD calculation and fix tensor grouping bug Signed-off-by: JPZ4-5 --- .../algorithms/regularization.py | 363 ++++++++---------- 1 file changed, 162 insertions(+), 201 deletions(-) diff --git a/dowhy/causal_prediction/algorithms/regularization.py b/dowhy/causal_prediction/algorithms/regularization.py index 909d571719..448448ebdc 100644 --- a/dowhy/causal_prediction/algorithms/regularization.py +++ b/dowhy/causal_prediction/algorithms/regularization.py @@ -1,14 +1,13 @@ -import numpy as np import torch -from torch import nn -from torch.nn import functional as F +from torch import tensor -from dowhy.causal_prediction.algorithms.utils import mmd_compute +from dowhy.causal_prediction.algorithms.utils import gaussian_kernel, mmd_compute class Regularizer: """ Implements methods for applying unconditional and conditional regularization. + Optimized for GPU throughput by minimizing Python control flow. """ def __init__( @@ -33,11 +32,65 @@ def __init__( def mmd(self, x, y): """ Compute MMD penalty between x and y. - """ return mmd_compute(x, y, self.kernel_type, self.gamma) - def unconditional_reg(self, classifs, attribute_labels, num_envs, E_eq_A=False): + def _optimized_mmd_penalty(self, list_of_feature_tensors: list): + """ + Computes the sum of MMD penalties between all pairs in a list of tensors. + This optimized version handles both 'gaussian' and other kernel types efficiently, + avoiding redundant computations while maintaining backward compatibility. + + :param list_of_feature_tensors: A list where each element is a feature tensor + corresponding to a unique attribute value. + :return: The total MMD penalty as a scalar tensor. + """ + valid_tensors = [t.double() for t in list_of_feature_tensors if t.shape[0] > 0] + k = len(valid_tensors) + + if k <= 1: + return 0.0 + + original_dtype = list_of_feature_tensors[0].dtype + original_device = list_of_feature_tensors[0].device + + if self.kernel_type == "gaussian": + sum_K_ii = sum(gaussian_kernel(t, t, self.gamma).mean() for t in valid_tensors) + sum_K_ij = 0 + for i in range(k): + for j in range(i + 1, k): + sum_K_ij += gaussian_kernel(valid_tensors[i], valid_tensors[j], self.gamma).mean() + penalty = (k - 1) * sum_K_ii - 2 * sum_K_ij + + else: + means = [t.mean(0, keepdim=True) for t in valid_tensors] + cents = [t - m for t, m in zip(valid_tensors, means)] + + covas = [] + for t, cent in zip(valid_tensors, cents): + n = t.shape[0] + if n > 1: + covas.append((cent.t() @ cent) / (n - 1)) + else: + covas.append(torch.zeros_like(means[0].diag())) + + penalty = torch.tensor(0.0, device=original_device, dtype=torch.float64) + for i in range(k): + for j in range(i + 1, k): + mean_diff = (means[i] - means[j]).pow(2).mean() + cova_diff = (covas[i] - covas[j]).pow(2).mean() + penalty += mean_diff + cova_diff + + return penalty.to(dtype=original_dtype) + + def _split_by_attribute(self, features, labels): + unique_labels = torch.unique(labels) + if len(unique_labels) < 2: + return [features] + + return [features[labels == label] for label in unique_labels] + + def unconditional_reg(self, classifs, attribute_labels, num_envs, E_eq_A=False, use_optimization=False): """ Implement unconditional regularization φ(x) ⊥⊥ A_i @@ -45,58 +98,54 @@ def unconditional_reg(self, classifs, attribute_labels, num_envs, E_eq_A=False): :param attribute_labels: attribute labels loaded with the dataset for attribute A_i :param num_envs: number of environments/domains :param E_eq_A: Binary flag indicating whether attribute (A_i) coinicides with environment (E) definition + :param use_optimization: If True, uses an algebraically optimized method to compute the penalty. + If False, uses the original nested loop, which is more extensible for new MMD """ - penalty = 0 + penalty = tensor(0.0, dtype=classifs[0].dtype, device=classifs[0].device) if E_eq_A: # Environment (E) and attribute (A) coincide if self.E_conditioned is False: # there is no correlation between E and X_c - for i in range(num_envs): - for j in range(i + 1, num_envs): - penalty += self.mmd(classifs[i], classifs[j]) + if use_optimization: + penalty += self._optimized_mmd_penalty(classifs) + else: + for i in range(num_envs): + for j in range(i + 1, num_envs): + penalty += self.mmd(classifs[i], classifs[j]) else: if self.E_conditioned: for i in range(num_envs): - unique_attr_labels = torch.unique(attribute_labels[i]) - unique_attr_label_indices = [] - for label in unique_attr_labels: - label_ind = [ind for ind, j in enumerate(attribute_labels[i]) if j == label] - unique_attr_label_indices.append(label_ind) - - nulabels = unique_attr_labels.shape[0] - for aidx in range(nulabels): - for bidx in range(aidx + 1, nulabels): - penalty += self.mmd( - classifs[i][unique_attr_label_indices[aidx]], - classifs[i][unique_attr_label_indices[bidx]], - ) - - else: # this currently assumes we have a disjoint set of attributes (Aind) across environments i.e., environment is defined by multiple closely related values of the attribute - overall_nmb_indices, nmb_id = [], [] - for i in range(num_envs): - unique_attrs = torch.unique(attribute_labels[i]) - unique_attr_indices = [] - for attr in unique_attrs: - attr_ind = [ind for ind, j in enumerate(attribute_labels[i]) if j == attr] - unique_attr_indices.append(attr_ind) - overall_nmb_indices.append(attr_ind) - nmb_id.append(i) - - nuattr = len(overall_nmb_indices) - for aidx in range(nuattr): - for bidx in range(aidx + 1, nuattr): - a_nmb_id = nmb_id[aidx] - b_nmb_id = nmb_id[bidx] - penalty += self.mmd( - classifs[a_nmb_id][overall_nmb_indices[aidx]], - classifs[b_nmb_id][overall_nmb_indices[bidx]], - ) + tensors_list = self._split_by_attribute(classifs[i], attribute_labels[i]) + + if use_optimization: + penalty += self._optimized_mmd_penalty(tensors_list) + else: + k = len(tensors_list) + for aidx in range(k): + for bidx in range(aidx + 1, k): + penalty += self.mmd(tensors_list[aidx], tensors_list[bidx]) + + else: + all_features = torch.cat(classifs, dim=0) + all_labels = torch.cat(attribute_labels, dim=0) + + tensors_list = self._split_by_attribute(all_features, all_labels) + + if use_optimization: + penalty += self._optimized_mmd_penalty(tensors_list) + else: + k = len(tensors_list) + for aidx in range(k): + for bidx in range(aidx + 1, k): + penalty += self.mmd(tensors_list[aidx], tensors_list[bidx]) return penalty - def conditional_reg(self, classifs, attribute_labels, conditioning_subset, num_envs, E_eq_A=False): + def conditional_reg( + self, classifs, attribute_labels, conditioning_subset, num_envs, E_eq_A=False, use_optimization=False + ): """ Implement conditional regularization φ(x) ⊥⊥ A_i | A_s @@ -105,6 +154,8 @@ def conditional_reg(self, classifs, attribute_labels, conditioning_subset, num_e :param conditioning_subset: list of subset of observed variables A_s (attributes + targets) such that (X_c, A_i) are d-separated conditioned on this subset :param num_envs: number of environments/domains :param E_eq_A: Binary flag indicating whether attribute (A_i) coinicides with environment (E) definition + :param use_optimization: If True, uses an algebraically optimized method to compute the penalty. + If False, uses the original nested loop, which is more extensible for new MMD Find group indices for conditional regularization based on conditioning subset by taking all possible combinations e.g., conditioning_subset = [A1, Y], where A1 is in {0, 1} and Y is in {0, 1, 2}, @@ -125,173 +176,83 @@ def conditional_reg(self, classifs, attribute_labels, conditioning_subset, num_e }` """ - - penalty = 0 + penalty = tensor(0.0, dtype=classifs[0].dtype, device=classifs[0].device) if E_eq_A: # Environment (E) and attribute (A) coincide - if self.E_conditioned is False: # there is no correlation between E and X_c - overall_group_vindices = {} # storing group indices - overall_group_eindices = {} # storing corresponding environment indices + if self.E_conditioned is False: + all_feats = [] + all_groups = [] + all_attrs = [] for i in range(num_envs): - conditioning_subset_i = [subset_var[i] for subset_var in conditioning_subset] - conditioning_subset_i_uniform = [ - ele.unsqueeze(1) if ele.dim() == 1 else ele for ele in conditioning_subset_i - ] - grouping_data = torch.cat(conditioning_subset_i_uniform, 1) - assert grouping_data.min() >= 0, "Group numbers cannot be negative." - cardinality = 1 + torch.max(grouping_data, dim=0)[0] - cumprod = torch.cumprod(cardinality, dim=0) - n_groups = cumprod[-1].item() - factors = torch.cat((torch.tensor([1], dtype=cumprod.dtype, device=cumprod.device), cumprod[:-1])) - group_indices = (grouping_data.float() @ factors.float()).long() - - for group_idx in range(n_groups): - group_idx_indices = [ - gp_idx for gp_idx in range(len(group_indices)) if group_indices[gp_idx] == group_idx - ] - - if group_idx not in overall_group_vindices: - overall_group_vindices[group_idx] = {} - overall_group_eindices[group_idx] = {} - - unique_attrs = torch.unique( - attribute_labels[i][group_idx_indices] - ) # find distinct attributes in environment with same group_idx_indices - unique_attr_indices = [] - for attr in unique_attrs: # storing indices with same attribute value and group label - if attr not in overall_group_vindices[group_idx]: - overall_group_vindices[group_idx][attr] = [] - overall_group_eindices[group_idx][attr] = [] - single_attr = [] - for group_idx_indices_attr in group_idx_indices: - if attribute_labels[i][group_idx_indices_attr] == attr: - single_attr.append(group_idx_indices_attr) - overall_group_vindices[group_idx][attr].append(single_attr) - overall_group_eindices[group_idx][attr].append(i) - unique_attr_indices.append(single_attr) - - for ( - group_label - ) in ( - overall_group_vindices - ): # applying MMD penalty between distributions P(φ(x)|ai, g), P(φ(x)|aj, g) i.e samples with different attribute labelues but same group label - tensors_list = [] - for attr in overall_group_vindices[group_label]: - attrs_list = [] - if overall_group_vindices[group_label][attr] != []: - for il_ind, indices_list in enumerate(overall_group_vindices[group_label][attr]): - attrs_list.append( - classifs[overall_group_eindices[group_label][attr][il_ind]][indices_list] - ) - if len(attrs_list) > 0: - tensor_attrs = torch.cat(attrs_list, 0) - tensors_list.append(tensor_attrs) - - nuattr = len(tensors_list) - for aidx in range(nuattr): - for bidx in range(aidx + 1, nuattr): - penalty += self.mmd(tensors_list[aidx], tensors_list[bidx]) + cond_subset_i = [subset_var[i] for subset_var in conditioning_subset] + cond_subset_i_uniform = [ele.unsqueeze(1) if ele.dim() == 1 else ele for ele in cond_subset_i] + if cond_subset_i_uniform: + group_data_i = torch.cat(cond_subset_i_uniform, 1) + else: + group_data_i = torch.zeros((classifs[i].shape[0], 1), device=classifs[i].device) + + all_feats.append(classifs[i]) + all_groups.append(group_data_i) + all_attrs.append(torch.full((classifs[i].shape[0],), i, device=classifs[i].device)) + + total_feats = torch.cat(all_feats, 0) + total_groups = torch.cat(all_groups, 0) + total_attrs = torch.cat(all_attrs, 0) + + penalty += self._compute_conditional_penalty(total_feats, total_attrs, total_groups, use_optimization) else: if self.E_conditioned: for i in range(num_envs): - conditioning_subset_i = [subset_var[i] for subset_var in conditioning_subset] - conditioning_subset_i_uniform = [ - ele.unsqueeze(1) if ele.dim() == 1 else ele for ele in conditioning_subset_i - ] - grouping_data = torch.cat(conditioning_subset_i_uniform, 1) - assert grouping_data.min() >= 0, "Group numbers cannot be negative." - cardinality = 1 + torch.max(grouping_data, dim=0)[0] - cumprod = torch.cumprod(cardinality, dim=0) - n_groups = cumprod[-1].item() - factors = torch.cat((torch.tensor([1], dtype=cumprod.dtype, device=cumprod.device), cumprod[:-1])) - group_indices = (grouping_data.float() @ factors.float()).long() - - for group_idx in range(n_groups): - group_idx_indices = [ - gp_idx for gp_idx in range(len(group_indices)) if group_indices[gp_idx] == group_idx - ] - unique_attrs = torch.unique( - attribute_labels[i][group_idx_indices] - ) # find distinct attributes in environment with same group_idx_indices - unique_attr_indices = [] - for attr in unique_attrs: - single_attr = [] - for group_idx_indices_attr in group_idx_indices: - if attribute_labels[i][group_idx_indices_attr] == attr: - single_attr.append(group_idx_indices_attr) - unique_attr_indices.append(single_attr) - - nuattr = unique_attrs.shape[0] - for aidx in range(nuattr): - for bidx in range(aidx + 1, nuattr): - penalty += self.mmd( - classifs[i][unique_attr_indices[aidx]], classifs[i][unique_attr_indices[bidx]] - ) + cond_subset_i = [subset_var[i] for subset_var in conditioning_subset] + cond_subset_i_uniform = [ele.unsqueeze(1) if ele.dim() == 1 else ele for ele in cond_subset_i] + if cond_subset_i_uniform: + group_data = torch.cat(cond_subset_i_uniform, 1) + else: + group_data = torch.zeros((classifs[i].shape[0], 1), device=classifs[i].device) + + penalty += self._compute_conditional_penalty( + classifs[i], attribute_labels[i], group_data, use_optimization + ) else: - overall_group_vindices = {} # storing group indices - overall_group_eindices = {} # storing corresponding environment indices + all_feats = torch.cat(classifs, 0) + all_attrs = torch.cat(attribute_labels, 0) - for i in range(num_envs): - conditioning_subset_i = [subset_var[i] for subset_var in conditioning_subset] - conditioning_subset_i_uniform = [ - ele.unsqueeze(1) if ele.dim() == 1 else ele for ele in conditioning_subset_i - ] - grouping_data = torch.cat(conditioning_subset_i_uniform, 1) - assert grouping_data.min() >= 0, "Group numbers cannot be negative." - cardinality = 1 + torch.max(grouping_data, dim=0)[0] - cumprod = torch.cumprod(cardinality, dim=0) - n_groups = cumprod[-1].item() - factors = torch.cat((torch.tensor([1], dtype=cumprod.dtype, device=cumprod.device), cumprod[:-1])) - group_indices = (grouping_data.float() @ factors.float()).long() - - for group_idx in range(n_groups): - group_idx_indices = [ - gp_idx for gp_idx in range(len(group_indices)) if group_indices[gp_idx] == group_idx - ] - - if group_idx not in overall_group_vindices: - overall_group_vindices[group_idx] = {} - overall_group_eindices[group_idx] = {} - - unique_attrs = torch.unique( - attribute_labels[i][group_idx_indices] - ) # find distinct attributes in environment with same group_idx_indices - unique_attr_indices = [] - for attr in unique_attrs: # storing indices with same attribute value and group label - if attr not in overall_group_vindices[group_idx]: - overall_group_vindices[group_idx][attr] = [] - overall_group_eindices[group_idx][attr] = [] - single_attr = [] - for group_idx_indices_attr in group_idx_indices: - if attribute_labels[i][group_idx_indices_attr] == attr: - single_attr.append(group_idx_indices_attr) - overall_group_vindices[group_idx][attr].append(single_attr) - overall_group_eindices[group_idx][attr].append(i) - unique_attr_indices.append(single_attr) - - for ( - group_label - ) in ( - overall_group_vindices - ): # applying MMD penalty between distributions P(φ(x)|ai, g), P(φ(x)|aj, g) i.e samples with different attribute labelues but same group label - tensors_list = [] - for attr in overall_group_vindices[group_label]: - attrs_list = [] - if overall_group_vindices[group_label][attr] != []: - for il_ind, indices_list in enumerate(overall_group_vindices[group_label][attr]): - attrs_list.append( - classifs[overall_group_eindices[group_label][attr][il_ind]][indices_list] - ) - if len(attrs_list) > 0: - tensor_attrs = torch.cat(attrs_list, 0) - tensors_list.append(tensor_attrs) - - nuattr = len(tensors_list) - for aidx in range(nuattr): - for bidx in range(aidx + 1, nuattr): + all_cond_vars = [] + for var_list in conditioning_subset: + all_cond_vars.append(torch.cat(var_list, 0)) + + cond_uniform = [ele.unsqueeze(1) if ele.dim() == 1 else ele for ele in all_cond_vars] + if cond_uniform: + total_groups = torch.cat(cond_uniform, 1) + else: + total_groups = torch.zeros((all_feats.shape[0], 1), device=all_feats.device) + + penalty += self._compute_conditional_penalty(all_feats, all_attrs, total_groups, use_optimization) + + return penalty + + def _compute_conditional_penalty(self, features, attributes, group_data, use_optimization): + penalty = tensor(0.0, dtype=features[0].dtype, device=features[0].device) + + unique_groups, group_indices = torch.unique(group_data, dim=0, return_inverse=True) + present_group_ids = torch.unique(group_indices) + + for gid in present_group_ids: + mask = group_indices == gid + group_feats = features[mask] + group_attrs = attributes[mask] + tensors_list = self._split_by_attribute(group_feats, group_attrs) + + if len(tensors_list) > 1: + if use_optimization: + penalty += self._optimized_mmd_penalty(tensors_list) + else: + k = len(tensors_list) + for aidx in range(k): + for bidx in range(aidx + 1, k): penalty += self.mmd(tensors_list[aidx], tensors_list[bidx]) return penalty From a10d3c7c2128e88f204dba0af820ce9c8e1ff239 Mon Sep 17 00:00:00 2001 From: JPZ4-5 Date: Sat, 29 Nov 2025 23:05:47 +0800 Subject: [PATCH 2/7] fix: address code review comments (logic fixes & style) Signed-off-by: JPZ4-5 --- .../algorithms/regularization.py | 20 ++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/dowhy/causal_prediction/algorithms/regularization.py b/dowhy/causal_prediction/algorithms/regularization.py index 448448ebdc..71042ba9cf 100644 --- a/dowhy/causal_prediction/algorithms/regularization.py +++ b/dowhy/causal_prediction/algorithms/regularization.py @@ -69,12 +69,15 @@ def _optimized_mmd_penalty(self, list_of_feature_tensors: list): covas = [] for t, cent in zip(valid_tensors, cents): n = t.shape[0] + d_dim = t.shape[1] if n > 1: covas.append((cent.t() @ cent) / (n - 1)) else: - covas.append(torch.zeros_like(means[0].diag())) + covas.append( + torch.zeros_likes((d_dim, d_dim), device=original_device, dtype=torch.float64) + ) - penalty = torch.tensor(0.0, device=original_device, dtype=torch.float64) + penalty = tensor(0.0, device=original_device, dtype=torch.float64) for i in range(k): for j in range(i + 1, k): mean_diff = (means[i] - means[j]).pow(2).mean() @@ -97,9 +100,12 @@ def unconditional_reg(self, classifs, attribute_labels, num_envs, E_eq_A=False, :param classifs: feature representations output from classifier layer (gφ(x)) :param attribute_labels: attribute labels loaded with the dataset for attribute A_i :param num_envs: number of environments/domains - :param E_eq_A: Binary flag indicating whether attribute (A_i) coinicides with environment (E) definition - :param use_optimization: If True, uses an algebraically optimized method to compute the penalty. - If False, uses the original nested loop, which is more extensible for new MMD + :param E_eq_A: Binary flag indicating whether attribute (A_i) coincides with environment (E) definition + :param use_optimization: If True, uses an algebraically optimized method to compute the penalty, which is + faster and suitable for standard MMD computations. If False, uses the original nested loop, which is more + extensible for implementing new or custom MMD variants, but may be slower. Choose True for performance + with standard MMD, and False when correctness or extensibility for new MMD types is required. + """ @@ -153,7 +159,7 @@ def conditional_reg( :param attribute_labels: attribute labels loaded with the dataset for attribute A_i :param conditioning_subset: list of subset of observed variables A_s (attributes + targets) such that (X_c, A_i) are d-separated conditioned on this subset :param num_envs: number of environments/domains - :param E_eq_A: Binary flag indicating whether attribute (A_i) coinicides with environment (E) definition + :param E_eq_A: Binary flag indicating whether attribute (A_i) coincides with environment (E) definition :param use_optimization: If True, uses an algebraically optimized method to compute the penalty. If False, uses the original nested loop, which is more extensible for new MMD @@ -235,7 +241,7 @@ def conditional_reg( return penalty def _compute_conditional_penalty(self, features, attributes, group_data, use_optimization): - penalty = tensor(0.0, dtype=features[0].dtype, device=features[0].device) + penalty = tensor(0.0, dtype=features.dtype, device=features.device) unique_groups, group_indices = torch.unique(group_data, dim=0, return_inverse=True) present_group_ids = torch.unique(group_indices) From b9c14c83221f3b1ad83aa2f3688889968ab47fc2 Mon Sep 17 00:00:00 2001 From: JPZ4-5 Date: Sat, 29 Nov 2025 23:20:09 +0800 Subject: [PATCH 3/7] reformat with black Signed-off-by: JPZ4-5 --- dowhy/causal_prediction/algorithms/regularization.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/dowhy/causal_prediction/algorithms/regularization.py b/dowhy/causal_prediction/algorithms/regularization.py index 71042ba9cf..42db190592 100644 --- a/dowhy/causal_prediction/algorithms/regularization.py +++ b/dowhy/causal_prediction/algorithms/regularization.py @@ -73,9 +73,7 @@ def _optimized_mmd_penalty(self, list_of_feature_tensors: list): if n > 1: covas.append((cent.t() @ cent) / (n - 1)) else: - covas.append( - torch.zeros_likes((d_dim, d_dim), device=original_device, dtype=torch.float64) - ) + covas.append(torch.zeros_likes((d_dim, d_dim), device=original_device, dtype=torch.float64)) penalty = tensor(0.0, device=original_device, dtype=torch.float64) for i in range(k): From 2960c67fdc9615069cf392b6cea71d0a1b5a59f9 Mon Sep 17 00:00:00 2001 From: JPZ4-5 Date: Fri, 5 Dec 2025 17:37:16 +0800 Subject: [PATCH 4/7] bug fix Signed-off-by: JPZ4-5 --- dowhy/causal_prediction/algorithms/regularization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dowhy/causal_prediction/algorithms/regularization.py b/dowhy/causal_prediction/algorithms/regularization.py index 42db190592..e1dabdcd12 100644 --- a/dowhy/causal_prediction/algorithms/regularization.py +++ b/dowhy/causal_prediction/algorithms/regularization.py @@ -73,7 +73,7 @@ def _optimized_mmd_penalty(self, list_of_feature_tensors: list): if n > 1: covas.append((cent.t() @ cent) / (n - 1)) else: - covas.append(torch.zeros_likes((d_dim, d_dim), device=original_device, dtype=torch.float64)) + covas.append(torch.zeros((d_dim, d_dim), device=original_device, dtype=torch.float64)) penalty = tensor(0.0, device=original_device, dtype=torch.float64) for i in range(k): From ef4a1c88e5dff9f488a32df9d3119eca22eb5222 Mon Sep 17 00:00:00 2001 From: JPZ4-5 Date: Sat, 14 Feb 2026 00:23:11 +0800 Subject: [PATCH 5/7] Fix: Change mutable default args to None and add AdamW optimizer Signed-off-by: JPZ4-5 --- dowhy/causal_prediction/algorithms/base_algorithm.py | 8 ++++++-- dowhy/causal_prediction/algorithms/cacm.py | 8 ++++---- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/dowhy/causal_prediction/algorithms/base_algorithm.py b/dowhy/causal_prediction/algorithms/base_algorithm.py index 9a61ccaa53..795cea8360 100644 --- a/dowhy/causal_prediction/algorithms/base_algorithm.py +++ b/dowhy/causal_prediction/algorithms/base_algorithm.py @@ -26,8 +26,8 @@ def __init__(self, model, optimizer, lr, weight_decay, betas, momentum): self.momentum = momentum # Check if the optimizer is currently supported - if self.optimizer not in ["Adam", "SGD"]: - error_msg = self.optimizer + " is not implemented currently. Try Adam or SGD." + if self.optimizer not in ["Adam", "AdamW", "SGD"]: + error_msg = self.optimizer + " is not implemented currently. Try Adam, AdamW or SGD." raise Exception(error_msg) def training_step(self, train_batch, batch_idx): @@ -89,6 +89,10 @@ def configure_optimizers(self): optimizer = torch.optim.Adam( self.parameters(), lr=self.lr, weight_decay=self.weight_decay, betas=self.betas ) + elif self.optimizer == "AdamW": + optimizer = torch.optim.AdamW( + self.parameters(), lr=self.lr, weight_decay=self.weight_decay, betas=self.betas + ) elif self.optimizer == "SGD": optimizer = torch.optim.SGD( self.parameters(), lr=self.lr, weight_decay=self.weight_decay, momentum=self.momentum diff --git a/dowhy/causal_prediction/algorithms/cacm.py b/dowhy/causal_prediction/algorithms/cacm.py index f08c26e1b7..6544e04d14 100644 --- a/dowhy/causal_prediction/algorithms/cacm.py +++ b/dowhy/causal_prediction/algorithms/cacm.py @@ -16,9 +16,9 @@ def __init__( momentum=0.9, kernel_type="gaussian", ci_test="mmd", - attr_types=[], + attr_types=None, E_conditioned=True, - E_eq_A=[], + E_eq_A=None, gamma=1e-6, lambda_causal=1.0, lambda_conf=1.0, @@ -97,13 +97,13 @@ def training_step(self, train_batch, batch_idx): objective /= nmb loss = objective - if self.attr_types != []: + if self.attr_types is not None: for attr_type_idx, attr_type in enumerate(self.attr_types): attribute_labels = [ ai for _, _, ai in minibatches ] # [(batch_size, num_attrs)_1, batch_size, num_attrs)_2, ..., (batch_size, num_attrs)_(num_environments)] - E_eq_A_attr = attr_type_idx in self.E_eq_A + E_eq_A_attr = False if self.E_eq_A is None else attr_type_idx in self.E_eq_A # Acause regularization if attr_type == "causal": From 74134c2a13ff9e9b408c5075fb843d2fefce4a40 Mon Sep 17 00:00:00 2001 From: JPZ4-5 Date: Sat, 14 Feb 2026 01:17:28 +0800 Subject: [PATCH 6/7] chore: refresh poetry.lock Signed-off-by: JPZ4-5 --- poetry.lock | 42 +++++++++++++++++++++--------------------- 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/poetry.lock b/poetry.lock index 64c302bb67..4b12fb9a45 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 2.2.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 2.1.4 and should not be changed by hand. [[package]] name = "absl-py" @@ -509,14 +509,14 @@ css = ["tinycss2 (>=1.1.0,<1.5)"] [[package]] name = "causal-learn" -version = "0.1.4.3" +version = "0.1.4.4" description = "causal-learn Python Package" optional = false python-versions = ">=3.7" groups = ["main"] files = [ - {file = "causal_learn-0.1.4.3-py3-none-any.whl", hash = "sha256:47ca196b786ea3b899e0e0e2c6d607c37c900df9457781f6ab0b97a2c6a747d0"}, - {file = "causal_learn-0.1.4.3.tar.gz", hash = "sha256:e832171668ded6b3ced6499be62951c1cca56b1d21c6cf67e920e50e3a671854"}, + {file = "causal_learn-0.1.4.4-py3-none-any.whl", hash = "sha256:e3d51dae578b58d6e4bba0544a18817dc63f574fa1e0120019febe7fee90baff"}, + {file = "causal_learn-0.1.4.4.tar.gz", hash = "sha256:82e65ba9593a31b0a33b5f313e5b824c0763de3a8da2a062458f6c12bd46017e"}, ] [package.dependencies] @@ -3707,7 +3707,7 @@ description = "CUBLAS native runtime libraries" optional = false python-versions = ">=3" groups = ["dev"] -markers = "python_version < \"3.13\" and platform_machine == \"x86_64\" and platform_system == \"Linux\"" +markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version < \"3.13\"" files = [ {file = "nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl", hash = "sha256:ee53ccca76a6fc08fb9701aa95b6ceb242cdaab118c3bb152af4e579af792728"}, {file = "nvidia_cublas_cu12-12.1.3.1-py3-none-win_amd64.whl", hash = "sha256:2b964d60e8cf11b5e1073d179d85fa340c120e99b3067558f3cf98dd69d02906"}, @@ -3734,7 +3734,7 @@ description = "CUDA profiling tools runtime libs." optional = false python-versions = ">=3" groups = ["dev"] -markers = "python_version < \"3.13\" and platform_machine == \"x86_64\" and platform_system == \"Linux\"" +markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version < \"3.13\"" files = [ {file = "nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:e54fde3983165c624cb79254ae9818a456eb6e87a7fd4d56a2352c24ee542d7e"}, {file = "nvidia_cuda_cupti_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:bea8236d13a0ac7190bd2919c3e8e6ce1e402104276e6f9694479e48bb0eb2a4"}, @@ -3763,7 +3763,7 @@ description = "NVRTC native runtime libraries" optional = false python-versions = ">=3" groups = ["dev"] -markers = "python_version < \"3.13\" and platform_machine == \"x86_64\" and platform_system == \"Linux\"" +markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version < \"3.13\"" files = [ {file = "nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:339b385f50c309763ca65456ec75e17bbefcbbf2893f462cb8b90584cd27a1c2"}, {file = "nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:0a98a522d9ff138b96c010a65e145dc1b4850e9ecb75a0172371793752fd46ed"}, @@ -3790,7 +3790,7 @@ description = "CUDA Runtime native Libraries" optional = false python-versions = ">=3" groups = ["dev"] -markers = "python_version < \"3.13\" and platform_machine == \"x86_64\" and platform_system == \"Linux\"" +markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version < \"3.13\"" files = [ {file = "nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:6e258468ddf5796e25f1dc591a31029fa317d97a0a94ed93468fc86301d61e40"}, {file = "nvidia_cuda_runtime_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:dfb46ef84d73fababab44cf03e3b83f80700d27ca300e537f85f636fac474344"}, @@ -3819,7 +3819,7 @@ description = "cuDNN runtime libraries" optional = false python-versions = ">=3" groups = ["dev"] -markers = "python_version < \"3.13\" and platform_machine == \"x86_64\" and platform_system == \"Linux\"" +markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version < \"3.13\"" files = [ {file = "nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl", hash = "sha256:165764f44ef8c61fcdfdfdbe769d687e06374059fbb388b6c89ecb0e28793a6f"}, {file = "nvidia_cudnn_cu12-9.1.0.70-py3-none-win_amd64.whl", hash = "sha256:6278562929433d68365a07a4a1546c237ba2849852c0d4b2262a486e805b977a"}, @@ -3852,7 +3852,7 @@ description = "CUFFT native runtime libraries" optional = false python-versions = ">=3" groups = ["dev"] -markers = "python_version < \"3.13\" and platform_machine == \"x86_64\" and platform_system == \"Linux\"" +markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version < \"3.13\"" files = [ {file = "nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl", hash = "sha256:794e3948a1aa71fd817c3775866943936774d1c14e7628c74f6f7417224cdf56"}, {file = "nvidia_cufft_cu12-11.0.2.54-py3-none-win_amd64.whl", hash = "sha256:d9ac353f78ff89951da4af698f80870b1534ed69993f10a4cf1d96f21357e253"}, @@ -3897,7 +3897,7 @@ description = "CURAND native runtime libraries" optional = false python-versions = ">=3" groups = ["dev"] -markers = "python_version < \"3.13\" and platform_machine == \"x86_64\" and platform_system == \"Linux\"" +markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version < \"3.13\"" files = [ {file = "nvidia_curand_cu12-10.3.2.106-py3-none-manylinux1_x86_64.whl", hash = "sha256:9d264c5036dde4e64f1de8c50ae753237c12e0b1348738169cd0f8a536c0e1e0"}, {file = "nvidia_curand_cu12-10.3.2.106-py3-none-win_amd64.whl", hash = "sha256:75b6b0c574c0037839121317e17fd01f8a69fd2ef8e25853d826fec30bdba74a"}, @@ -3926,7 +3926,7 @@ description = "CUDA solver native runtime libraries" optional = false python-versions = ">=3" groups = ["dev"] -markers = "python_version < \"3.13\" and platform_machine == \"x86_64\" and platform_system == \"Linux\"" +markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version < \"3.13\"" files = [ {file = "nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl", hash = "sha256:8a7ec542f0412294b15072fa7dab71d31334014a69f953004ea7a118206fe0dd"}, {file = "nvidia_cusolver_cu12-11.4.5.107-py3-none-win_amd64.whl", hash = "sha256:74e0c3a24c78612192a74fcd90dd117f1cf21dea4822e66d89e8ea80e3cd2da5"}, @@ -3965,7 +3965,7 @@ description = "CUSPARSE native runtime libraries" optional = false python-versions = ">=3" groups = ["dev"] -markers = "python_version < \"3.13\" and platform_machine == \"x86_64\" and platform_system == \"Linux\"" +markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version < \"3.13\"" files = [ {file = "nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl", hash = "sha256:f3b50f42cf363f86ab21f720998517a659a48131e8d538dc02f8768237bd884c"}, {file = "nvidia_cusparse_cu12-12.1.0.106-py3-none-win_amd64.whl", hash = "sha256:b798237e81b9719373e8fae8d4f091b70a0cf09d9d85c95a557e11df2d8e9a5a"}, @@ -4014,7 +4014,7 @@ description = "NVIDIA Collective Communication Library (NCCL) Runtime" optional = false python-versions = ">=3" groups = ["dev"] -markers = "python_version < \"3.13\" and platform_machine != \"aarch64\" and platform_system == \"Linux\"" +markers = "platform_system == \"Linux\" and platform_machine != \"aarch64\" and python_version < \"3.13\"" files = [ {file = "nvidia_nccl_cu12-2.20.5-py3-none-manylinux2014_aarch64.whl", hash = "sha256:1fc150d5c3250b170b29410ba682384b14581db722b2531b0d8d33c595f33d01"}, {file = "nvidia_nccl_cu12-2.20.5-py3-none-manylinux2014_x86_64.whl", hash = "sha256:057f6bf9685f75215d0c53bf3ac4a10b3e6578351de307abad9e18a99182af56"}, @@ -4054,7 +4054,7 @@ description = "Nvidia JIT LTO Library" optional = false python-versions = ">=3" groups = ["dev"] -markers = "python_version < \"3.13\" and platform_machine == \"x86_64\" and platform_system == \"Linux\"" +markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version < \"3.13\"" files = [ {file = "nvidia_nvjitlink_cu12-12.9.86-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl", hash = "sha256:e3f1171dbdc83c5932a45f0f4c99180a70de9bd2718c1ab77d14104f6d7147f9"}, {file = "nvidia_nvjitlink_cu12-12.9.86-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:994a05ef08ef4b0b299829cde613a424382aff7efb08a7172c1fa616cc3af2ca"}, @@ -4068,7 +4068,7 @@ description = "NVIDIA Tools Extension" optional = false python-versions = ">=3" groups = ["dev"] -markers = "python_version < \"3.13\" and platform_machine == \"x86_64\" and platform_system == \"Linux\"" +markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version < \"3.13\"" files = [ {file = "nvidia_nvtx_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:dc21cf308ca5691e7c04d962e213f8a4aa9bbfa23d95412f452254c2caeb09e5"}, {file = "nvidia_nvtx_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:65f4d98982b31b60026e0e6de73fbdfc09d08a96f4656dd3665ca616a11e1e82"}, @@ -4779,7 +4779,7 @@ files = [ {file = "ptyprocess-0.7.0-py2.py3-none-any.whl", hash = "sha256:4b41f3967fce3af57cc7e94b888626c18bf37a083e3651ca8feeb66d492fef35"}, {file = "ptyprocess-0.7.0.tar.gz", hash = "sha256:5c5d0a3b48ceee0b48485e0c26037c0acd7d29765ca3fbb5cb3831d347423220"}, ] -markers = {dev = "(os_name != \"nt\" or sys_platform != \"win32\") and (os_name != \"nt\" or sys_platform != \"win32\" and sys_platform != \"emscripten\" or python_version == \"3.9\")", docs = "sys_platform != \"win32\" and sys_platform != \"emscripten\" or python_version == \"3.9\" and sys_platform != \"win32\""} +markers = {dev = "(sys_platform != \"win32\" and sys_platform != \"emscripten\" or os_name != \"nt\" or python_version == \"3.9\") and (sys_platform != \"win32\" or os_name != \"nt\")", docs = "sys_platform != \"win32\" and sys_platform != \"emscripten\" or python_version == \"3.9\" and sys_platform != \"win32\""} [[package]] name = "pure-eval" @@ -5637,7 +5637,7 @@ description = "Fundamental algorithms for scientific computing in Python" optional = false python-versions = ">=3.10" groups = ["main", "dev", "docs"] -markers = "python_version >= \"3.10\" and python_version < \"3.13\"" +markers = "python_version < \"3.13\" and python_version >= \"3.10\"" files = [ {file = "scipy-1.15.3-cp310-cp310-macosx_10_13_x86_64.whl", hash = "sha256:a345928c86d535060c9c2b25e71e87c39ab2f22fc96e9636bd74d1dbf9de448c"}, {file = "scipy-1.15.3-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:ad3432cb0f9ed87477a8d97f03b763fd1d57709f1bbde3c9369b1dff5503b253"}, @@ -6035,7 +6035,7 @@ description = "Sparse n-dimensional arrays for the PyData ecosystem" optional = false python-versions = ">=3.10" groups = ["main"] -markers = "extra == \"econml\" and python_version >= \"3.10\"" +markers = "python_version >= \"3.10\" and extra == \"econml\"" files = [ {file = "sparse-0.17.0-py2.py3-none-any.whl", hash = "sha256:1922d1d97f692b1061c4f03a1dd6ee21850aedc88e171aa845715f5069952f18"}, {file = "sparse-0.17.0.tar.gz", hash = "sha256:6b1ad51a810c5be40b6f95e28513ec810fe1c785923bd83b2e4839a751df4bf7"}, @@ -6833,7 +6833,7 @@ description = "A language and compiler for custom Deep Learning operations" optional = false python-versions = "*" groups = ["dev"] -markers = "python_version < \"3.13\" and platform_machine == \"x86_64\" and platform_system == \"Linux\"" +markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version < \"3.13\"" files = [ {file = "triton-3.0.0-1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:e1efef76935b2febc365bfadf74bcb65a6f959a9872e5bddf44cc9e0adce1e1a"}, {file = "triton-3.0.0-1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:5ce8520437c602fb633f1324cc3871c47bee3b67acf9756c1a66309b60e3216c"}, @@ -7222,4 +7222,4 @@ pygraphviz = ["pygraphviz"] [metadata] lock-version = "2.1" python-versions = ">=3.9,<3.14" -content-hash = "cb45922c7229cb9dba5a59588b9a446aa64ccfe8baf263362e26ac7b131fe581" +content-hash = "24504cd4b76eb1985e762e4db4fca43aab5e5095456c9a0c47c2a523849507fe" From ff9b7e383e0a9a625da05196913c19d787cd4bf8 Mon Sep 17 00:00:00 2001 From: JPZ4-5 Date: Sat, 14 Feb 2026 19:36:52 +0800 Subject: [PATCH 7/7] chore: revert dependency files to upstream versions Signed-off-by: JPZ4-5 --- poetry.lock | 42 +++++++++++++++++++++--------------------- 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/poetry.lock b/poetry.lock index 4b12fb9a45..64c302bb67 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 2.1.4 and should not be changed by hand. +# This file is automatically @generated by Poetry 2.2.1 and should not be changed by hand. [[package]] name = "absl-py" @@ -509,14 +509,14 @@ css = ["tinycss2 (>=1.1.0,<1.5)"] [[package]] name = "causal-learn" -version = "0.1.4.4" +version = "0.1.4.3" description = "causal-learn Python Package" optional = false python-versions = ">=3.7" groups = ["main"] files = [ - {file = "causal_learn-0.1.4.4-py3-none-any.whl", hash = "sha256:e3d51dae578b58d6e4bba0544a18817dc63f574fa1e0120019febe7fee90baff"}, - {file = "causal_learn-0.1.4.4.tar.gz", hash = "sha256:82e65ba9593a31b0a33b5f313e5b824c0763de3a8da2a062458f6c12bd46017e"}, + {file = "causal_learn-0.1.4.3-py3-none-any.whl", hash = "sha256:47ca196b786ea3b899e0e0e2c6d607c37c900df9457781f6ab0b97a2c6a747d0"}, + {file = "causal_learn-0.1.4.3.tar.gz", hash = "sha256:e832171668ded6b3ced6499be62951c1cca56b1d21c6cf67e920e50e3a671854"}, ] [package.dependencies] @@ -3707,7 +3707,7 @@ description = "CUBLAS native runtime libraries" optional = false python-versions = ">=3" groups = ["dev"] -markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version < \"3.13\"" +markers = "python_version < \"3.13\" and platform_machine == \"x86_64\" and platform_system == \"Linux\"" files = [ {file = "nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl", hash = "sha256:ee53ccca76a6fc08fb9701aa95b6ceb242cdaab118c3bb152af4e579af792728"}, {file = "nvidia_cublas_cu12-12.1.3.1-py3-none-win_amd64.whl", hash = "sha256:2b964d60e8cf11b5e1073d179d85fa340c120e99b3067558f3cf98dd69d02906"}, @@ -3734,7 +3734,7 @@ description = "CUDA profiling tools runtime libs." optional = false python-versions = ">=3" groups = ["dev"] -markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version < \"3.13\"" +markers = "python_version < \"3.13\" and platform_machine == \"x86_64\" and platform_system == \"Linux\"" files = [ {file = "nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:e54fde3983165c624cb79254ae9818a456eb6e87a7fd4d56a2352c24ee542d7e"}, {file = "nvidia_cuda_cupti_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:bea8236d13a0ac7190bd2919c3e8e6ce1e402104276e6f9694479e48bb0eb2a4"}, @@ -3763,7 +3763,7 @@ description = "NVRTC native runtime libraries" optional = false python-versions = ">=3" groups = ["dev"] -markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version < \"3.13\"" +markers = "python_version < \"3.13\" and platform_machine == \"x86_64\" and platform_system == \"Linux\"" files = [ {file = "nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:339b385f50c309763ca65456ec75e17bbefcbbf2893f462cb8b90584cd27a1c2"}, {file = "nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:0a98a522d9ff138b96c010a65e145dc1b4850e9ecb75a0172371793752fd46ed"}, @@ -3790,7 +3790,7 @@ description = "CUDA Runtime native Libraries" optional = false python-versions = ">=3" groups = ["dev"] -markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version < \"3.13\"" +markers = "python_version < \"3.13\" and platform_machine == \"x86_64\" and platform_system == \"Linux\"" files = [ {file = "nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:6e258468ddf5796e25f1dc591a31029fa317d97a0a94ed93468fc86301d61e40"}, {file = "nvidia_cuda_runtime_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:dfb46ef84d73fababab44cf03e3b83f80700d27ca300e537f85f636fac474344"}, @@ -3819,7 +3819,7 @@ description = "cuDNN runtime libraries" optional = false python-versions = ">=3" groups = ["dev"] -markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version < \"3.13\"" +markers = "python_version < \"3.13\" and platform_machine == \"x86_64\" and platform_system == \"Linux\"" files = [ {file = "nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl", hash = "sha256:165764f44ef8c61fcdfdfdbe769d687e06374059fbb388b6c89ecb0e28793a6f"}, {file = "nvidia_cudnn_cu12-9.1.0.70-py3-none-win_amd64.whl", hash = "sha256:6278562929433d68365a07a4a1546c237ba2849852c0d4b2262a486e805b977a"}, @@ -3852,7 +3852,7 @@ description = "CUFFT native runtime libraries" optional = false python-versions = ">=3" groups = ["dev"] -markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version < \"3.13\"" +markers = "python_version < \"3.13\" and platform_machine == \"x86_64\" and platform_system == \"Linux\"" files = [ {file = "nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl", hash = "sha256:794e3948a1aa71fd817c3775866943936774d1c14e7628c74f6f7417224cdf56"}, {file = "nvidia_cufft_cu12-11.0.2.54-py3-none-win_amd64.whl", hash = "sha256:d9ac353f78ff89951da4af698f80870b1534ed69993f10a4cf1d96f21357e253"}, @@ -3897,7 +3897,7 @@ description = "CURAND native runtime libraries" optional = false python-versions = ">=3" groups = ["dev"] -markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version < \"3.13\"" +markers = "python_version < \"3.13\" and platform_machine == \"x86_64\" and platform_system == \"Linux\"" files = [ {file = "nvidia_curand_cu12-10.3.2.106-py3-none-manylinux1_x86_64.whl", hash = "sha256:9d264c5036dde4e64f1de8c50ae753237c12e0b1348738169cd0f8a536c0e1e0"}, {file = "nvidia_curand_cu12-10.3.2.106-py3-none-win_amd64.whl", hash = "sha256:75b6b0c574c0037839121317e17fd01f8a69fd2ef8e25853d826fec30bdba74a"}, @@ -3926,7 +3926,7 @@ description = "CUDA solver native runtime libraries" optional = false python-versions = ">=3" groups = ["dev"] -markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version < \"3.13\"" +markers = "python_version < \"3.13\" and platform_machine == \"x86_64\" and platform_system == \"Linux\"" files = [ {file = "nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl", hash = "sha256:8a7ec542f0412294b15072fa7dab71d31334014a69f953004ea7a118206fe0dd"}, {file = "nvidia_cusolver_cu12-11.4.5.107-py3-none-win_amd64.whl", hash = "sha256:74e0c3a24c78612192a74fcd90dd117f1cf21dea4822e66d89e8ea80e3cd2da5"}, @@ -3965,7 +3965,7 @@ description = "CUSPARSE native runtime libraries" optional = false python-versions = ">=3" groups = ["dev"] -markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version < \"3.13\"" +markers = "python_version < \"3.13\" and platform_machine == \"x86_64\" and platform_system == \"Linux\"" files = [ {file = "nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl", hash = "sha256:f3b50f42cf363f86ab21f720998517a659a48131e8d538dc02f8768237bd884c"}, {file = "nvidia_cusparse_cu12-12.1.0.106-py3-none-win_amd64.whl", hash = "sha256:b798237e81b9719373e8fae8d4f091b70a0cf09d9d85c95a557e11df2d8e9a5a"}, @@ -4014,7 +4014,7 @@ description = "NVIDIA Collective Communication Library (NCCL) Runtime" optional = false python-versions = ">=3" groups = ["dev"] -markers = "platform_system == \"Linux\" and platform_machine != \"aarch64\" and python_version < \"3.13\"" +markers = "python_version < \"3.13\" and platform_machine != \"aarch64\" and platform_system == \"Linux\"" files = [ {file = "nvidia_nccl_cu12-2.20.5-py3-none-manylinux2014_aarch64.whl", hash = "sha256:1fc150d5c3250b170b29410ba682384b14581db722b2531b0d8d33c595f33d01"}, {file = "nvidia_nccl_cu12-2.20.5-py3-none-manylinux2014_x86_64.whl", hash = "sha256:057f6bf9685f75215d0c53bf3ac4a10b3e6578351de307abad9e18a99182af56"}, @@ -4054,7 +4054,7 @@ description = "Nvidia JIT LTO Library" optional = false python-versions = ">=3" groups = ["dev"] -markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version < \"3.13\"" +markers = "python_version < \"3.13\" and platform_machine == \"x86_64\" and platform_system == \"Linux\"" files = [ {file = "nvidia_nvjitlink_cu12-12.9.86-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl", hash = "sha256:e3f1171dbdc83c5932a45f0f4c99180a70de9bd2718c1ab77d14104f6d7147f9"}, {file = "nvidia_nvjitlink_cu12-12.9.86-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:994a05ef08ef4b0b299829cde613a424382aff7efb08a7172c1fa616cc3af2ca"}, @@ -4068,7 +4068,7 @@ description = "NVIDIA Tools Extension" optional = false python-versions = ">=3" groups = ["dev"] -markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version < \"3.13\"" +markers = "python_version < \"3.13\" and platform_machine == \"x86_64\" and platform_system == \"Linux\"" files = [ {file = "nvidia_nvtx_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:dc21cf308ca5691e7c04d962e213f8a4aa9bbfa23d95412f452254c2caeb09e5"}, {file = "nvidia_nvtx_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:65f4d98982b31b60026e0e6de73fbdfc09d08a96f4656dd3665ca616a11e1e82"}, @@ -4779,7 +4779,7 @@ files = [ {file = "ptyprocess-0.7.0-py2.py3-none-any.whl", hash = "sha256:4b41f3967fce3af57cc7e94b888626c18bf37a083e3651ca8feeb66d492fef35"}, {file = "ptyprocess-0.7.0.tar.gz", hash = "sha256:5c5d0a3b48ceee0b48485e0c26037c0acd7d29765ca3fbb5cb3831d347423220"}, ] -markers = {dev = "(sys_platform != \"win32\" and sys_platform != \"emscripten\" or os_name != \"nt\" or python_version == \"3.9\") and (sys_platform != \"win32\" or os_name != \"nt\")", docs = "sys_platform != \"win32\" and sys_platform != \"emscripten\" or python_version == \"3.9\" and sys_platform != \"win32\""} +markers = {dev = "(os_name != \"nt\" or sys_platform != \"win32\") and (os_name != \"nt\" or sys_platform != \"win32\" and sys_platform != \"emscripten\" or python_version == \"3.9\")", docs = "sys_platform != \"win32\" and sys_platform != \"emscripten\" or python_version == \"3.9\" and sys_platform != \"win32\""} [[package]] name = "pure-eval" @@ -5637,7 +5637,7 @@ description = "Fundamental algorithms for scientific computing in Python" optional = false python-versions = ">=3.10" groups = ["main", "dev", "docs"] -markers = "python_version < \"3.13\" and python_version >= \"3.10\"" +markers = "python_version >= \"3.10\" and python_version < \"3.13\"" files = [ {file = "scipy-1.15.3-cp310-cp310-macosx_10_13_x86_64.whl", hash = "sha256:a345928c86d535060c9c2b25e71e87c39ab2f22fc96e9636bd74d1dbf9de448c"}, {file = "scipy-1.15.3-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:ad3432cb0f9ed87477a8d97f03b763fd1d57709f1bbde3c9369b1dff5503b253"}, @@ -6035,7 +6035,7 @@ description = "Sparse n-dimensional arrays for the PyData ecosystem" optional = false python-versions = ">=3.10" groups = ["main"] -markers = "python_version >= \"3.10\" and extra == \"econml\"" +markers = "extra == \"econml\" and python_version >= \"3.10\"" files = [ {file = "sparse-0.17.0-py2.py3-none-any.whl", hash = "sha256:1922d1d97f692b1061c4f03a1dd6ee21850aedc88e171aa845715f5069952f18"}, {file = "sparse-0.17.0.tar.gz", hash = "sha256:6b1ad51a810c5be40b6f95e28513ec810fe1c785923bd83b2e4839a751df4bf7"}, @@ -6833,7 +6833,7 @@ description = "A language and compiler for custom Deep Learning operations" optional = false python-versions = "*" groups = ["dev"] -markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version < \"3.13\"" +markers = "python_version < \"3.13\" and platform_machine == \"x86_64\" and platform_system == \"Linux\"" files = [ {file = "triton-3.0.0-1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:e1efef76935b2febc365bfadf74bcb65a6f959a9872e5bddf44cc9e0adce1e1a"}, {file = "triton-3.0.0-1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:5ce8520437c602fb633f1324cc3871c47bee3b67acf9756c1a66309b60e3216c"}, @@ -7222,4 +7222,4 @@ pygraphviz = ["pygraphviz"] [metadata] lock-version = "2.1" python-versions = ">=3.9,<3.14" -content-hash = "24504cd4b76eb1985e762e4db4fca43aab5e5095456c9a0c47c2a523849507fe" +content-hash = "cb45922c7229cb9dba5a59588b9a446aa64ccfe8baf263362e26ac7b131fe581"