diff --git a/.gitignore b/.gitignore index c66905f5..8e560bbe 100644 --- a/.gitignore +++ b/.gitignore @@ -229,3 +229,5 @@ paper_analyses alphaquant/resources/reference_databases alphaquant/resources/phosphopred_databases +.claude/settings.local.json +.claude/settings.local.json diff --git a/alphaquant/classify/ml_info_table.py b/alphaquant/classify/ml_info_table.py index a92b7915..a975a3a9 100644 --- a/alphaquant/classify/ml_info_table.py +++ b/alphaquant/classify/ml_info_table.py @@ -64,10 +64,12 @@ def _adapt_precursor_name_to_modification_type(self): def _define_ml_info_filename(self): - self.ml_info_filename = aq_utils.get_progress_folder_filename(self._input_file, ".ml_info_table.tsv") + self.ml_info_filename = aq_utils.get_progress_folder_filename(self._input_file, ".ml_info_table.tsv.zip") def _write_ml_info_table(self): - self._ml_info_df.to_csv(self.ml_info_filename, sep="\t", index=False) + archive_name = self.ml_info_filename.split("/")[-1].removesuffix(".zip") + compression = {"method": "zip", "archive_name": archive_name} + self._ml_info_df.to_csv(self.ml_info_filename, sep="\t", index=False, compression=compression) class MLInfoTableLoader(): @@ -81,4 +83,3 @@ def __init__(self, ml_info_file, samples_used): def _subset_df_to_relevant_samples(self): self.ml_info_df = self.ml_info_df[self.ml_info_df["sample_ID"].isin(self._samples_used)] self.ml_info_df = self.ml_info_df.drop(columns=["sample_ID"]) - diff --git a/alphaquant/cluster/cluster_ions.py b/alphaquant/cluster/cluster_ions.py index 8767054b..9e24ee00 100644 --- a/alphaquant/cluster/cluster_ions.py +++ b/alphaquant/cluster/cluster_ions.py @@ -1,3 +1,28 @@ +"""Hierarchical proteoform clustering and differential-expression aggregation. + +This module builds a hierarchical tree for each protein (fragments → peptides +→ modified peptides → unmodified peptides → protein) and performs two +*independent* statistical procedures at each level: + +1. 
**Proteoform clustering** — pairwise double-differential tests assess + whether sibling ions share the same fold change (null hypothesis: "ions + have equal regulation"). The resulting similarity p-values are corrected + with Benjamini-Yekutieli (appropriate for the intrinsic dependencies among + pairwise comparisons) and then used for hierarchical clustering to separate + proteoforms. + +2. **Differential-expression aggregation** — the per-ion z-values from + individual differential-expression tests (null hypothesis: "no change + between conditions") are combined into parent-level z-values via Stouffer's + method. This aggregation propagates evidence of regulation up the tree. + +The two sets of p-values address fundamentally different questions and are +corrected separately. The Benjamini-Yekutieli correction in step 1 applies +*within* a single protein to the similarity matrix; the final protein-level +p-values produced by step 2 are later corrected across all proteins with +Benjamini-Hochberg (see ``diffquant_table.py``). +""" + import scipy.spatial.distance import scipy.cluster.hierarchy import alphaquant.cluster.cluster_utils as aqcluster_utils @@ -28,7 +53,7 @@ -def get_scored_clusterselected_ions(gene_name, diffions, normed_c1, normed_c2, ion2diffDist, p2z, deedpair2doublediffdist, pval_threshold_basis, fcfc_threshold, take_median_ion, fcdiff_cutoff_clustermerge, fragment_outlier_filtering=True): +def get_scored_clusterselected_ions(gene_name, diffions, normed_c1, normed_c2, ion2diffDist, p2z, deedpair2doublediffdist, pval_threshold_basis, fcfc_threshold, take_median_ion, fcdiff_cutoff_clustermerge, aggregation_mode="stouffer_decorrelation", cluster_threshold_ion_type=0.01): """Main entry point for hierarchical clustering and tree-based quantification of a protein. 
This function creates a hierarchical tree structure from fragment ions up to the protein level @@ -47,7 +72,7 @@ def get_scored_clusterselected_ions(gene_name, diffions, normed_c1, normed_c2, i fcfc_threshold: Fold-change difference threshold for clustering take_median_ion: If True, use median-centered ions for clustering fcdiff_cutoff_clustermerge: Fold-change threshold for merging similar clusters - fragment_outlier_filtering: Whether to filter outlier fragments when aggregating to peptides + aggregation_mode: Strategy for combining child z-values (see cluster_utils.AGGREGATION_MODES) Returns: anytree.Node: Root node of the hierarchical tree containing all statistics and clustering results @@ -63,7 +88,7 @@ def get_scored_clusterselected_ions(gene_name, diffions, normed_c1, normed_c2, i root_node = create_hierarchical_ion_grouping(gene_name, diffions) add_reduced_names_to_root(root_node) #LOGGER.info(anytree.RenderTree(root_node)) - root_node_clust = cluster_along_specified_levels(root_node, name2diffion, normed_c1, normed_c2, ion2diffDist, p2z, deedpair2doublediffdist, pval_threshold_basis, fcfc_threshold, take_median_ion, fragment_outlier_filtering) + root_node_clust = cluster_along_specified_levels(root_node, name2diffion, normed_c1, normed_c2, ion2diffDist, p2z, deedpair2doublediffdist, pval_threshold_basis, fcfc_threshold, take_median_ion, aggregation_mode=aggregation_mode, cluster_threshold_ion_type=cluster_threshold_ion_type) level_sorted_nodes = [[node for node in children] for children in anytree.ZigZagGroupIter(root_node_clust)] level_sorted_nodes.reverse() #the base nodes are first @@ -114,7 +139,7 @@ def add_reduced_names_to_root(node): import pandas as pd -def cluster_along_specified_levels(root_node, ionname2diffion, normed_c1, normed_c2, ion2diffDist, p2z, deedpair2doublediffdist, pval_threshold_basis, fcfc_threshold, take_median_ion, fragment_outlier_filtering=True):#~60% of overall runtime +def cluster_along_specified_levels(root_node, 
ionname2diffion, normed_c1, normed_c2, ion2diffDist, p2z, deedpair2doublediffdist, pval_threshold_basis, fcfc_threshold, take_median_ion, aggregation_mode="stouffer_decorrelation", cluster_threshold_ion_type=0.01):#~60% of overall runtime """Performs hierarchical clustering at each level of the tree from bottom to top. Starting from base ions (fragments/MS1), this function iterates through each level @@ -123,6 +148,31 @@ def cluster_along_specified_levels(root_node, ionname2diffion, normed_c1, normed pairwise for consistent fold-change differences, clustered hierarchically, and statistics are aggregated to parent nodes. + Important: this function interleaves two *independent* statistical procedures + that address different questions: + + 1. **Proteoform clustering** (``find_fold_change_clusters``): tests whether + pairs of sibling ions (e.g. two peptides of the same protein) have + *different* fold changes, i.e. whether they belong to different + proteoforms. The resulting pairwise p-values form a dependent similarity + matrix that is corrected for multiple testing with Benjamini-Yekutieli + (appropriate for dependent tests) *before* hierarchical clustering. + + 2. **Differential-expression aggregation** (``aggregate_node_properties``): + combines the per-ion z-values (derived from individual differential + expression tests) into a single parent-level z-value via Stouffer's + method. These z-values quantify *how much* each ion changes between + conditions and are unrelated to the proteoform-similarity p-values + from step 1. + + The Benjamini-Yekutieli correction in step 1 therefore precedes the + Stouffer aggregation in step 2 by design, because they operate on + different null hypotheses: step 1 asks "do these two peptides differ + from each other?" while step 2 asks "does this protein change between + conditions?". The final protein-level p-values produced in step 2 + are later corrected with Benjamini-Hochberg across all proteins (see + ``diffquant_table.py``). 
+ Args: root_node: Root of the hierarchical tree (protein level) ionname2diffion: Dictionary mapping ion names to DifferentialIon objects @@ -134,7 +184,7 @@ def cluster_along_specified_levels(root_node, ionname2diffion, normed_c1, normed pval_threshold_basis: P-value threshold for clustering decisions fcfc_threshold: Fold-change threshold for clustering take_median_ion: Whether to use median-centered ions - fragment_outlier_filtering: Whether to filter fragment outliers + aggregation_mode: Strategy for combining child z-values (see cluster_utils.AGGREGATION_MODES) Returns: anytree.Node: The root node with all clustering annotations and aggregated statistics @@ -164,7 +214,7 @@ def cluster_along_specified_levels(root_node, ionname2diffion, normed_c1, normed if take_median_ion: grouped_mainclust_leafs = aqcluster_utils.select_median_fc_leafs(grouped_mainclust_leafs) diffions = aqcluster_utils.map_grouped_leafs_to_diffions(grouped_mainclust_leafs, ionname2diffion) #the diffions are the ions that are actually compared - childnode2clust = find_fold_change_clusters(type_node, diffions, normed_c1, normed_c2, ion2diffDist, p2z, deedpair2doublediffdist, pval_threshold_basis, fcfc_threshold) #the clustering is performed on the child nodes + childnode2clust = find_fold_change_clusters(type_node, diffions, normed_c1, normed_c2, ion2diffDist, p2z, deedpair2doublediffdist, pval_threshold_basis, fcfc_threshold, cluster_threshold_ion_type=cluster_threshold_ion_type) #the clustering is performed on the child nodes childnode2clust = merge_similar_clusters_if_applicable(childnode2clust, type_node, fcdiff_cutoff_clustermerge = FCDIFF_CUTOFF_CLUSTERMERGE) childnode2clust = aq_cluster_sorting.decide_cluster_order(childnode2clust) @@ -172,7 +222,7 @@ def cluster_along_specified_levels(root_node, ionname2diffion, normed_c1, normed aqcluster_utils.assign_clusterstats_to_type_node(type_node, childnode2clust) aqcluster_utils.annotate_mainclust_leaves(childnode2clust) 
aqcluster_utils.assign_cluster_number(type_node, childnode2clust) - aqcluster_utils.aggregate_node_properties(type_node,only_use_mainclust=True, peptide_outlier_filtering=False, fragment_outlier_filtering=fragment_outlier_filtering) + aqcluster_utils.aggregate_node_properties(type_node,only_use_mainclust=True, peptide_outlier_filtering=False, aggregation_mode=aggregation_mode) return root_node @@ -183,21 +233,46 @@ def get_childnode2clust_for_single_ion(type_node): return {type_node.children[0]: 0} -def find_fold_change_clusters(type_node, diffions, normed_c1, normed_c2, ion2diffDist, p2z, deedpair2doublediffdist, pval_threshold_basis, fcfc_threshold): - """Compares the fold changes of the ions corresponding to the nodes that are compared and returns the set of ions with consistent fold changes. +def find_fold_change_clusters(type_node, diffions, normed_c1, normed_c2, ion2diffDist, p2z, deedpair2doublediffdist, pval_threshold_basis, fcfc_threshold, cluster_threshold_ion_type=0.01): + """Cluster sibling ions by the similarity of their fold changes (proteoform inference). + + For each pair of sibling ion groups, a *double-differential* test + (``evaluate_similarity``) computes a p-value for the null hypothesis that + their fold changes are equal (i.e. they belong to the same proteoform). + These pairwise p-values form a condensed similarity matrix that is + corrected for multiple testing with the Benjamini-Yekutieli procedure + (``get_multiple_testing_corrected_condensed_similarity_matrix``), which is + appropriate here because the pairwise comparisons are intrinsically + dependent. The corrected matrix is then converted to a distance matrix and + subjected to Ward hierarchical clustering. + + Note: the p-values computed and corrected here test *inter-ion similarity* + for proteoform grouping. They are entirely separate from the per-ion + differential-expression p-values that are later aggregated via Stouffer's + method in ``aggregate_node_properties``. 
Args: - diffions (list[list[ionnames]]): contains the sets of ions to be tested, for example [[fragion1_precursor1, fragion2_precursor1, fragion3_precursor1],[fragion1_precursor2],[fragion1_precursor3, fragion2_precursor3]]. The ions are assumed to be similar in type (e.g. fragment, precursor)! - normed_c1 (ConditionBackground): [description] - normed_c2 (ConditionBackground): [description] - ion2diffDist (dict(ion : SubtractedBackground)): [description] - p2z ([type]): [description] - deedpair2doublediffdist ([type]): [description] - fc_threshold (float, optional): [description]. Defaults to 0. - pval_threshold_basis (float, optional): the threshold at which to merge peptides at the gene level. Defaults to 0.01 + type_node: Parent node whose children are being clustered + diffions: List of lists of DifferentialIon objects, one sublist per + child node, e.g. ``[[fragion1_prec1, fragion2_prec1], [fragion1_prec2]]`` + normed_c1: ConditionBackgrounds for condition 1 + normed_c2: ConditionBackgrounds for condition 2 + ion2diffDist: Mapping of ion pairs to differential background distributions + p2z: Cache for p-value to z-value conversions + deedpair2doublediffdist: Cache for double-differential distributions + pval_threshold_basis: P-value threshold at the gene level for cutting the + clustering dendrogram; thresholds for lower levels are looked up in + ``LEVEL2PVALTHRESH`` + fcfc_threshold: Minimum fold-change difference below which two ion groups + are assumed similar without a formal test + cluster_threshold_ion_type: P-value threshold at the ion_type level + + Returns: + list[tuple[Node, int]]: Pairs of (child node, cluster index) sorted by + node name for reproducibility """ - pval_threshold_basis = get_pval_threshold_basis(type_node, pval_threshold_basis) + pval_threshold_basis = get_pval_threshold_basis(type_node, pval_threshold_basis, cluster_threshold_ion_type=cluster_threshold_ion_type) diffions_idxs = [[x] for x in range(len(diffions))] diffions_fcs = 
aqcluster_utils.get_fcs_ions(diffions) #mt_corrected_pval_thresh = pval_threshold_basis/len(diffions) @@ -216,26 +291,36 @@ def find_fold_change_clusters(type_node, diffions, normed_c1, normed_c2, ion2dif return childnode2clust -def get_pval_threshold_basis(type_node, pval_threshold_basis): #the pval threshold is only set at the gene level, the rest of the levels are set as specified in the LEVEL2PVALTHRESH dictionary +def get_pval_threshold_basis(type_node, pval_threshold_basis, cluster_threshold_ion_type=0.01): #the pval threshold is only set at the gene level, the rest of the levels are set as specified in the LEVEL2PVALTHRESH dictionary if type_node.level == "gene": return pval_threshold_basis - else: - return LEVEL2PVALTHRESH.get(type_node.level, 0.2) + if type_node.level == "ion_type": + return cluster_threshold_ion_type + return LEVEL2PVALTHRESH.get(type_node.level, 0.2) def get_multiple_testing_corrected_condensed_similarity_matrix(condensed_distance_matrix: np.array): - """ - condensed_distance_matrix contains all p-values of the pairwise comparisons of the ions. They are by definition dependent. + """Apply Benjamini-Yekutieli FDR correction to pairwise ion-similarity p-values. + + The condensed matrix contains p-values from *double-differential* tests + between all pairs of sibling ions (see ``evaluate_similarity``). Each + p-value tests the null hypothesis that two ions share the same fold + change. Because every ion appears in multiple pairwise comparisons, the + tests are intrinsically dependent, which is why the Benjamini-Yekutieli + (``fdr_by``) procedure is used instead of Benjamini-Hochberg. + + These corrected p-values are used exclusively for proteoform clustering + (deciding which ions belong together) and are unrelated to the per-ion + differential-expression p-values that are later aggregated via Stouffer's + method. Args: - condensed_distance_matrix (np.array): Condensed distance matrix containing p-values of pairwise comparisons. 
+ condensed_distance_matrix: 1-D array of pairwise similarity p-values + in scipy condensed-matrix format. Returns: - np.array: Corrected condensed distance matrix. + np.array: Corrected p-values in the same condensed-matrix layout. """ - # Apply Benjamini-Yekutieli correction _, corrected_pvalues, _, _ = multitest.multipletests(condensed_distance_matrix, method='fdr_by') - - # Return the corrected condensed matrix return corrected_pvalues @@ -341,4 +426,3 @@ def exclude_node(node): node.is_included = False for descendant in node.descendants: descendant.is_included = False - diff --git a/alphaquant/cluster/cluster_utils.py b/alphaquant/cluster/cluster_utils.py index a5f1bf5c..90370ecf 100644 --- a/alphaquant/cluster/cluster_utils.py +++ b/alphaquant/cluster/cluster_utils.py @@ -5,7 +5,6 @@ import collections import alphaquant.config.variables as aqvariables from anytree import Node, LevelOrderGroupIter -import alphaquant.utils.diffquant_utils as aq_utils_diffquant import re import alphaquant.config.config as aqconfig @@ -18,14 +17,49 @@ LEVELS_UNIQUE = ["base","ion_type", "mod_seq_charge", "mod_seq", "seq", "gene"] TYPE2LEVEL = dict(zip(TYPES, LEVELS)) - -def aggregate_node_properties(node, only_use_mainclust, peptide_outlier_filtering=False, fragment_outlier_filtering=True): - """Aggregates statistical properties from child nodes to a parent node in the tree. - - This is the core function for propagating statistics up the hierarchical tree structure. - It combines z-values, fold changes, and quality metrics from child nodes (e.g., peptides) - into parent node (e.g., protein) statistics. The aggregation can optionally exclude - proteoforms (non-main clusters) and filter outlier children. 
+DEFAULT_AGGREGATION_MODE = "stouffer_decorrelation" +AGGREGATION_MODES = ( + "stouffer_decorrelation", + "mean_z", + "median_z", + "min_median_max_z", + "min_max_z", + "summed_z", +) +_LEGACY_AGGREGATION_MODES = ("stouffer_icc",) + +# Node types where alternative aggregation modes (mean_z, median_z, …) may +# be selected via the aggregation_mode parameter. For all other node types, +# Stouffer is always used. +_DEPENDENT_NODE_TYPES = {"frgion", "ms1_isotopes"} + +def node_is_excluded_from_aggregation(node): + """Return True for children removed by post-clustering correction steps.""" + return ( + getattr(node, "exclude_residual_decorrelation", False) + or getattr(node, "exclude_ptm_fragment_selection", False) + ) + + +def aggregate_node_properties(node, only_use_mainclust, peptide_outlier_filtering=False, aggregation_mode=DEFAULT_AGGREGATION_MODE): + """Aggregates differential-expression statistics from child nodes to a parent node. + + This is the core function for propagating statistics up the hierarchical tree + structure. It combines z-values, fold changes, and quality metrics from child + nodes (e.g. peptides) into parent node (e.g. protein) statistics using + Stouffer's Z-score method (``sum_and_re_scale_zvalues``). The aggregation can + optionally exclude proteoform variants (non-main clusters) and filter outlier + children. + + The z-values aggregated here originate from *per-ion differential-expression* + tests (null hypothesis: "no change between conditions for this ion"). They + are *not* the proteoform-similarity p-values computed in + ``find_fold_change_clusters``, which test a different null hypothesis + ("do two ions share the same fold change?") and are corrected separately + with Benjamini-Yekutieli *before* this function is called. 
In other words, + the Benjamini-Yekutieli correction applied during proteoform clustering and + the Stouffer aggregation performed here address independent statistical + questions and operate on different sets of p-values. Args: node: The parent node whose properties will be computed from its children @@ -33,8 +67,17 @@ def aggregate_node_properties(node, only_use_mainclust, peptide_outlier_filterin excluding proteoform variants peptide_outlier_filtering: If True and node is a protein, exclude peptides identified as statistical outliers (default: False) - fragment_outlier_filtering: If True and node is a peptide, exclude extreme - fragment ions before aggregation (default: True) + aggregation_mode: Strategy for combining child z-values at dependent levels + (frgion, ms1_isotopes). Higher levels always use Stouffer. + Can be a single string applied to all dependent levels, or a dict + mapping node types to modes (e.g. ``{"frgion": "stouffer_decorrelation", + "ms1_isotopes": "median_z"}``). Allowed mode strings: + "stouffer_decorrelation" - Stouffer's method used after residual decorrelation (default) + "mean_z" - arithmetic mean of z-values + "median_z" - median z-value + "min_median_max_z" - combine min, median, max z-values assuming independence + "min_max_z" - combine min, max z-values assuming independence (2-point summary) + "summed_z" - classic Stouffer assuming independence (rho=0, ignores ICC) Side effects: Sets node.z_val, node.p_val, node.fc, node.cv, node.min_intensity, @@ -42,11 +85,31 @@ def aggregate_node_properties(node, only_use_mainclust, peptide_outlier_filterin optionally node.ml_score based on aggregated child values. 
""" if only_use_mainclust: - childs = [x for x in node.children if x.is_included & (x.cluster ==0)] + childs = [ + x for x in node.children + if x.is_included & (x.cluster == 0) and not node_is_excluded_from_aggregation(x) + ] else: - childs = [x for x in node.children if x.is_included] + childs = [ + x for x in node.children + if x.is_included and not node_is_excluded_from_aggregation(x) + ] + + if len(childs) == 0: + if any(node_is_excluded_from_aggregation(x) for x in node.children): + # Residual decorrelation and PTM fragment selection together exhausted + # all eligible children. Keep the existing z-value so this node still + # contributes to higher levels — cascading the exclusion upward would + # double-penalise PTM data where both filters operate independently. + return + raise ValueError(f"Node {node.name!r} ({node.type}) has no eligible children to aggregate.") - childs_zfiltered = get_selected_nodes_for_zvalcalc(childs, peptide_outlier_filtering, node, fragment_outlier_filtering) + childs_zfiltered = get_selected_nodes_for_zvalcalc(childs, peptide_outlier_filtering, node) + + if len(childs_zfiltered) == 0: + if any(node_is_excluded_from_aggregation(x) for x in node.children): + return + raise ValueError(f"Node {node.name!r} ({node.type}) has no eligible children after filtering.") zvals = get_feature_numpy_array_from_nodes(nodes=childs_zfiltered, feature_name="z_val") @@ -64,9 +127,14 @@ def aggregate_node_properties(node, only_use_mainclust, peptide_outlier_filterin fraction_consistent = sum([x.fraction_consistent/len(node.children) for x in childs if x.cluster ==0]) - - - z_normed = sum_and_re_scale_zvalues(zvals) + rho = getattr(node, 'icc_correction', 0.0) + if node.type not in _DEPENDENT_NODE_TYPES: + effective_mode = DEFAULT_AGGREGATION_MODE + elif isinstance(aggregation_mode, dict): + effective_mode = aggregation_mode.get(node.type, DEFAULT_AGGREGATION_MODE) + else: + effective_mode = aggregation_mode + z_normed = combine_zvalues(zvals, rho=rho, 
mode=effective_mode) p_val = transform_znormed_to_pval(z_normed) p_val = set_bounds_for_p_if_too_extreme(p_val) @@ -125,7 +193,7 @@ def _select_peptides_around_median_z(peptide_nodes, max_peptides=31): return selected_peptides -def get_selected_nodes_for_zvalcalc(childs, peptide_outlier_filtering, node, fragment_outlier_filtering=True): +def get_selected_nodes_for_zvalcalc(childs, peptide_outlier_filtering, node): if peptide_outlier_filtering and node.type == "gene": filtered_childs = [x for x in childs if not x.is_outlier_peptide] # Additional restriction: if more than 31 peptides, keep only 31 closest to median z-value @@ -133,91 +201,67 @@ def get_selected_nodes_for_zvalcalc(childs, peptide_outlier_filtering, node, fra filtered_childs = _select_peptides_around_median_z(filtered_childs, max_peptides=31) return filtered_childs - elif fragment_outlier_filtering and node.type == "frgion": - return remove_outlier_fragion_childs(childs) - else: - return childs - + if node.type == "frgion": + filtered = childs + if aqvariables.ION_OUTLIER_MAD_THRESHOLD is not None: + filtered = _filter_ions_by_mad(filtered, aqvariables.ION_OUTLIER_MAD_THRESHOLD) + if aqvariables.MAX_N_FRAGMENTS is not None and len(filtered) > aqvariables.MAX_N_FRAGMENTS: + filtered = _select_peptides_around_median_z(filtered, max_peptides=aqvariables.MAX_N_FRAGMENTS) + if aqvariables.CLASSIC_FRAGMENT_OUTLIER_FILTERING: + filtered = remove_outlier_fragion_childs(filtered) + return filtered + return childs -def filter_fewpeps_per_protein(peptide_nodes): - peps_filtered = [] - pepnode2zval2numleaves = [] - for pepnode in peptide_nodes: - pepleaves = [x for x in pepnode.leaves if "seq" in getattr(x,"inclusion_levels", [])] - pepnode2zval2numleaves.append((pepnode, pepnode.z_val,len(pepleaves))) - pepnode2zval2numleaves = sorted(pepnode2zval2numleaves, key=lambda x : abs(x[1])) #sort with lowest absolute z-val (least significant) first +def _filter_ions_by_mad(nodes, threshold): + """Remove ion nodes whose 
z-value is a MAD-outlier among siblings. - return get_median_peptides(pepnode2zval2numleaves) - -def filter_outlier_peptides_old(peptide_nodes, fraction_highly_significant): - """ - Filters outlier peptides based on p-value significance. - - Checks if there's a minority of peptides (<40%) that has substantially more - significant p-values (at least a factor of 5) compared to the median. - Only starts checking if the median p-value is 0.05 or higher. - If this minority case exists, returns only the less significant half of peptides. + Requires at least 4 nodes to attempt filtering, and always retains + at least 2 nodes (the two closest to the median). Args: - peptide_nodes: List of peptide nodes with p_val attributes + nodes: List of child nodes with z_val attributes. + threshold: Number of scaled-MAD units beyond which a node is + considered an outlier (e.g. 3.0). Returns: - Filtered list of peptide nodes + Filtered list of nodes with outliers removed. """ - if len(peptide_nodes) < 4: - return peptide_nodes + if len(nodes) < 4: + return nodes - # Get p-values from peptide nodes - p_values = [node.p_val for node in peptide_nodes] - median_p_val = np.median(p_values) + z_vals = np.array([n.z_val for n in nodes]) + median_z = float(np.median(z_vals)) - # Only check for outliers if median p-value is 0.05 or higher - if median_p_val < 0.05: - return peptide_nodes + robust_std = _robust_std_estimate(z_vals, median_z) + if robust_std == 0: + return nodes - # Check for minority with substantially more significant p-values - threshold_p_val = median_p_val / 5.0 # at least 5x more significant (lower p-value) - highly_significant_nodes = [node for node in peptide_nodes if node.p_val <= threshold_p_val] - remaining_nodes = [node for node in peptide_nodes if node.p_val > threshold_p_val] + cutoff = threshold * robust_std + kept = [n for n in nodes if abs(n.z_val - median_z) <= cutoff] - # Check if this is a minority (<40%) - if len(highly_significant_nodes) / len(peptide_nodes) 
< 0.3: - return _filter_minority_highly_significant(highly_significant_nodes, remaining_nodes, fraction_highly_significant) + if len(kept) < 2: + kept = sorted(nodes, key=lambda n: abs(n.z_val - median_z))[:2] - return peptide_nodes + return kept -def _filter_minority_highly_significant_old(highly_significant_nodes, remaining_nodes, fraction_highly_significant): - """ - Handle filtering when highly significant nodes are a minority (<40%). +def _robust_std_estimate(values, median): + """Estimate std via scaled MAD, falling back to IQR if MAD is zero.""" + MAD_SCALE = 1.4826 # makes MAD consistent with std for normal data + IQR_SCALE = 1.349 # makes IQR consistent with std for normal data - Args: - highly_significant_nodes: Nodes with p-value <= threshold_p_val - remaining_nodes: All peptide nodes - threshold_p_val: The p-value threshold used to identify highly significant nodes - fraction_highly_significant: Global fraction of highly significant ions + mad = float(np.median(np.abs(values - median))) + if mad > 0: + return MAD_SCALE * mad - Returns: - Filtered list of peptide nodes to exclude for analysis - """ - # if len(highly_significant_nodes) == 1: - # return highly_significant_nodes+remaining_nodes - # Calculate how many highly significant nodes to exclude - num_to_exclude = int(len(highly_significant_nodes) * (fraction_highly_significant / 0.08)) - num_to_exclude_bounded = max(1, min(len(highly_significant_nodes)-1, num_to_exclude)) + q75, q25 = np.percentile(values, [75, 25]) + iqr = float(q75 - q25) + if iqr > 0: + return iqr / IQR_SCALE - # Sort by p-value (most significant first) and exclude the best ones - highly_significant_nodes_sorted = sorted(highly_significant_nodes, key=lambda x: x.p_val) - nodes_to_keep = highly_significant_nodes_sorted[num_to_exclude_bounded:] #keep the least significant ones - return nodes_to_keep + remaining_nodes + return 0.0 import math -def get_median_peptides(pepnode2zval2numleaves): #least significant peptides are sorted 
first - median_idx = math.floor(len(pepnode2zval2numleaves)/2) - if len(pepnode2zval2numleaves)<3: - return [x[0] for x in pepnode2zval2numleaves] - else: - return [x[0] for x in pepnode2zval2numleaves[:median_idx+1]] def remove_outlier_fragion_childs(childs): """Filters extreme fragment ions before aggregating to peptide level. @@ -237,15 +281,7 @@ def remove_outlier_fragion_childs(childs): list: Filtered subset of fragment ion nodes to use for aggregation """ zvals = get_feature_numpy_array_from_nodes(nodes=childs, feature_name="z_val") - if aqvariables.PTM_FRAGMENT_SELECTION: - sorted_idxs_zvals = np.argsort(np.abs(zvals)) - median_idx = math.floor(len(zvals)/2) - median_idx = 7 if median_idx > 7 else median_idx - if median_idx < len(sorted_idxs_zvals): - idxs_to_use = sorted_idxs_zvals[:median_idx+1] - else: - idxs_to_use = sorted_idxs_zvals - elif len(zvals) > 4: + if len(zvals) > 4: sorted_idxs_zvals = np.argsort(zvals) median_idx = math.floor(len(zvals)/2) idx_start = median_idx - 2 @@ -263,16 +299,153 @@ def remove_outlier_fragion_childs(childs): return [childs[idx] for idx in idxs_to_use] -def sum_and_re_scale_zvalues(zvals): +def apply_ptm_fragment_selection(protnodes, max_keep=8): + """Apply PTM low-|Z| fragment filtering after residual decorrelation. + + For every ``frgion`` parent, currently eligible base-ion children are sorted + by ``abs(z_val)``. The least extreme children through the median rank are + retained, capped by ``max_keep``. The default ``max_keep=8`` matches the + legacy PTM fragment-selection rule in ``remove_outlier_fragion_childs``. + + Returns: + tuple[int, int]: ``(children_dropped, parents_touched)``. 
+ """ + try: + max_keep = int(max_keep) + except (TypeError, ValueError): + max_keep = 8 + max_keep = max(1, max_keep) + + n_dropped = 0 + n_touched = 0 + for protnode in protnodes: + for parent in anytree.PreOrderIter(protnode): + if getattr(parent, "type", None) != "frgion": + continue + eligible = [ + child for child in parent.children + if child.is_included and not node_is_excluded_from_aggregation(child) + ] + n_live = len(eligible) + if n_live <= 1: + continue + + keep_n = min((n_live // 2) + 1, max_keep, n_live) + if keep_n >= n_live: + continue + + zvals = get_feature_numpy_array_from_nodes( + nodes=eligible, feature_name="z_val") + order = np.argsort(np.abs(zvals)) + kept = {eligible[idx] for idx in order[:keep_n]} + dropped_here = 0 + for child in eligible: + keep = child in kept + child.exclude_ptm_fragment_selection = not keep + if not keep: + child.is_outlier_fragment = True + dropped_here += 1 + elif not getattr(child, "is_outlier_fragment", False): + child.is_outlier_fragment = False + if dropped_here: + n_dropped += dropped_here + n_touched += 1 + + return n_dropped, n_touched + + +def combine_zvalues(zvals, rho=0.0, mode=DEFAULT_AGGREGATION_MODE): + """Dispatch function that selects the z-value combination strategy. + + Args: + zvals: Array or list of z-values to combine + rho: Correlation design-effect parameter for Stouffer aggregation. + mode: One of AGGREGATION_MODES + + Returns: + float: Combined z-value on a standard normal scale + """ + if mode not in AGGREGATION_MODES and mode not in _LEGACY_AGGREGATION_MODES: + raise ValueError(f"Unknown aggregation mode: {mode!r}. 
Choose from {AGGREGATION_MODES}") + + if len(zvals) == 1: + return zvals[0] + + if mode in ("stouffer_decorrelation", "stouffer_icc"): + return sum_and_re_scale_zvalues(zvals, rho=rho) + elif mode == "mean_z": + return _combine_mean_z(zvals) + elif mode == "median_z": + return _combine_median_z(zvals) + elif mode == "min_median_max_z": + return _combine_min_median_max_z(zvals) + elif mode == "min_max_z": + return _combine_min_max_z(zvals) + elif mode == "summed_z": + return _combine_summed_z(zvals) + + +def _combine_mean_z(zvals): + """Arithmetic mean of z-values — treats children as a single effective measurement.""" + return float(np.mean(zvals)) + + +def _combine_median_z(zvals): + """Median z-value — robust to outlier children.""" + return float(np.median(zvals)) + + +def _combine_min_median_max_z(zvals): + """Pick min, median, max z-values and combine via Stouffer assuming independence. + + Provides a 3-point summary that captures the full spread of evidence. + For n <= 3, falls back to Stouffer on all values (the summary would be + the full set anyway). + """ + if len(zvals) <= 3: + return sum_and_re_scale_zvalues(zvals, rho=0.0) + z_min = float(np.min(zvals)) + z_med = float(np.median(zvals)) + z_max = float(np.max(zvals)) + return sum_and_re_scale_zvalues(np.array([z_min, z_med, z_max]), rho=0.0) + + +def _combine_min_max_z(zvals): + """Pick min and max z-values and combine via Stouffer assuming independence. + + A 2-point summary that captures the extreme spread of evidence. + For n <= 2, falls back to Stouffer on all values. + """ + if len(zvals) <= 2: + return sum_and_re_scale_zvalues(zvals, rho=0.0) + z_min = float(np.min(zvals)) + z_max = float(np.max(zvals)) + return sum_and_re_scale_zvalues(np.array([z_min, z_max]), rho=0.0) + + +def _combine_summed_z(zvals): + """Classic Stouffer combination assuming full independence (rho=0). + + Unlike the Stouffer modes, this ignores any estimated ICC correction + and always treats child z-values as independent. 
+ """ + return sum_and_re_scale_zvalues(zvals, rho=0.0) + + +def sum_and_re_scale_zvalues(zvals, rho=0.0): """Combines multiple z-values into a single aggregated z-value using Stouffer's method. This implements Stouffer's Z-score method for meta-analysis: z-values are summed - and divided by sqrt(n) to account for the number of tests. The result is then - rescaled back to a standard normal distribution. This allows combining evidence - from multiple ions/peptides while maintaining proper statistical interpretation. + and divided by sqrt(n * DEFF) to account for both the number of tests and their + correlation. The design effect DEFF = 1 + (n-1) * rho corrects for the fact that + correlated z-values carry less independent information than n truly independent ones. + The result is then rescaled back to a standard normal distribution. Args: zvals: Array or list of z-values to combine + rho: Intraclass correlation (ICC) among the z-values. 0.0 assumes independence + (classic Stouffer), higher values produce more conservative (less significant) + combined z-values. 
Returns: float: Combined z-value following a standard normal distribution under the null @@ -280,8 +453,10 @@ def sum_and_re_scale_zvalues(zvals): if len(zvals) == 1: return zvals[0] # No aggregation needed for single values - avoids floating-point precision errors + n = len(zvals) + deff = 1.0 + (n - 1) * rho # design effect: inflated variance due to intra-group correlation z_sum = sum(zvals) - p_z = NormalDist(mu = 0, sigma = np.sqrt(len(zvals))).cdf(z_sum) + p_z = NormalDist(mu = 0, sigma = np.sqrt(n * deff)).cdf(z_sum) p_z = set_bounds_for_p_if_too_extreme(p_z) z_normed = NormalDist(mu = 0, sigma=1).inv_cdf(p_z) #this is just a re-scaling of the z-value to a standard normal distribution return z_normed @@ -306,64 +481,6 @@ def set_bounds_for_p_if_too_extreme(p_val): else: return p_val -def calc_fold_change_from_included_leaves_fcs(node): - included_leaves = obtain_all_included_leaves(node) - list_of_fcs = [x.fcs for x in included_leaves] - merged_fcs = np.concatenate(list_of_fcs) - return np.median(merged_fcs) - -def calc_weighted_fold_change_from_included_leaves_fcs(node): - included_leaves = obtain_all_included_leaves(node) - list_of_fcs = [x.fcs for x in included_leaves] - weights = [get_weight_of_leaf(x) for x in included_leaves] - weighted_median = calculate_weighted_median(weights, list_of_fcs) - return weighted_median - -def get_weight_of_leaf(leaf): - if hasattr(leaf, "ml_score_fragion"): - return 2**-leaf.ml_score_fragion - else: - return 1 - -def calculate_weighted_median(weights, fcs): - weighted_fcs = [(fc, weight) for weight, fc_list in zip(weights, fcs) for fc in fc_list] - sorted_weighted_fcs = sorted(weighted_fcs, key=lambda x: x[0]) - sorted_fcs, sorted_weights = zip(*sorted_weighted_fcs) - cumulative_weights = np.cumsum(sorted_weights) - total_weight = cumulative_weights[-1] - median_cutoff = total_weight / 2 - median_idx = np.where(cumulative_weights >= median_cutoff)[0][0] - weighted_median = sorted_fcs[median_idx] - return weighted_median - 
-def obtain_all_included_leaves(node): - list_of_included_leaves = [] - traverse_and_add_included_leaves(node, list_of_included_leaves) - return list_of_included_leaves - -def traverse_and_add_included_leaves(node, list_of_included_leaves, is_root=True): - """ - Recursively searches for leaves from the given node, where each node in the - path to the leaf has the 'is_included' attribute set to True, except for the initial node. - Fills up the list_of_included_leaves with the included leaves. - - Parameters: - node (anytree.Node): The node to start the search from. - list_of_included_leaves (list): The list to store the included leaves in. - is_root (bool): Indicates if the current node is the root node of the traversal. - """ - - if len(node.children) == 0: # if the node is a leaf - if is_root or (node.is_included and node.cluster == 0): - list_of_included_leaves.append(node) - return - - # If it's the root node or if the current node is included, then proceed to its children - if is_root or (node.is_included and node.cluster == 0): - for child in node.children: - # Recursive call with is_root set to False, as we are now dealing with child nodes - traverse_and_add_included_leaves(child, list_of_included_leaves, is_root=False) - def sum_ml_scores(ml_scores): abs_ml_scores = [abs(x) for x in ml_scores] return sum(abs_ml_scores) @@ -382,15 +499,6 @@ def get_grouped_mainclust_leafs(child_nodes): grouped_leafs.append(child_leaves_mainclust) return grouped_leafs -def select_highid_lowcv_leafs(grouped_leafs): - grouped_leafs_lowcv = [] - for leafs in grouped_leafs: - top_quantile_idx = math.ceil(len(leafs) * 0.2) - leafs_repsorted = sorted(leafs, key = lambda x : x.min_reps)[:top_quantile_idx] - leafs_repsorted_cvsorted = sorted(leafs_repsorted, key = lambda x : x.cv) - grouped_leafs_lowcv.append([leafs_repsorted_cvsorted[0]]) - return grouped_leafs_lowcv - def select_median_fc_leafs(grouped_leafs): grouped_leafs_medianfc = [] for leafs in grouped_leafs: @@ -566,22 +674,6 
@@ def remove_unnecessary_attributes(node, attributes_to_remove): import os -def get_nodes_of_type(cond1, cond2, results_folder, node_type = 'mod_seq_charge'): - - tree_sn = aqutils.read_condpair_tree(cond1, cond2, results_folder=results_folder) - tree_sn.type = "asd" - return anytree.findall(tree_sn, filter_= lambda x : (x.type == node_type)) - - - -def get_levelnodes_from_nodeslist(nodeslist, level): - levelnodes = [] - for node in nodeslist: - precursors = anytree.findall(node, filter_= lambda x : (x.type == level)) - levelnodes.extend(precursors) - return levelnodes - - def find_node_parent_at_level(node, level): if node.type == level: return node @@ -609,13 +701,6 @@ def shorten_root_to_level(root, parent_level): -def get_parent2children_dict(tree, parent_level): - parent2children = {} - parent_nodes = anytree.search.findall(tree, filter_=lambda node: node.level == parent_level) - for parent_node in parent_nodes: - parent2children[parent_node.name] = [child.name for child in parent_node.children] - return parent2children - def get_parent2leaves_dict(protein): """Returns a dict that maps the parent node name to the names of the leaves of the parent node """ diff --git a/alphaquant/cluster/icc_correction.py b/alphaquant/cluster/icc_correction.py new file mode 100644 index 00000000..be186e2d --- /dev/null +++ b/alphaquant/cluster/icc_correction.py @@ -0,0 +1,521 @@ +"""Data-driven ICC (intraclass correlation) correction for Stouffer z-value aggregation. + +Estimates a global ICC at every aggregation level of the protein tree and +annotates each node with an ``icc_correction`` attribute so that +``aggregate_node_properties`` can apply a design-effect correction. 
+ +The tree levels processed (bottom-to-top): + - ``frgion``: correlation among fragment ions within the same precursor + - ``ms1_isotopes``: correlation among MS1 isotopes within the same precursor + - ``mod_seq_charge``: correlation among ion families within the same precursor + - ``mod_seq``: correlation among charge states within the same modified peptide + - ``seq``: correlation among modification variants within the same sequence + - ``gene``: correlation among peptides within the same protein + +A protein contributes to the null distribution only when *both* its +protein-level p-value and the p-values of its individual group nodes +exceed the significance threshold. This two-level filter ensures the ICC +is estimated purely from technical noise, free of biological signal at +any level. + +The resulting null-median ICC is applied uniformly to all proteins. +Per-protein estimation is deliberately avoided: the ICC captures a +property of the *measurement* (shared chromatographic / detector effects), +not of the individual protein, and per-protein estimates are too noisy +to be reliable given typical group sizes. + +Processing proceeds bottom-to-top: after each level's ICC is estimated +and applied, trees are re-aggregated so that higher levels see the +corrected z-values from below. +""" + +import anytree +import numpy as np +import matplotlib.pyplot as plt + +import alphaquant.cluster.cluster_utils as aqcluster_utils + +import alphaquant.config.config as aqconfig +import alphaquant.config.variables as aqvariables +import logging +aqconfig.setup_logging() +LOGGER = logging.getLogger(__name__) + +# Node types that receive ICC correction, ordered bottom-to-top. 
+_ICC_NODE_TYPES = ("frgion", "ms1_isotopes", "mod_seq_charge", "mod_seq", "seq", "gene") + +_MIN_GROUPS = 3 # minimum number of group nodes for a reliable ICC estimate +_MIN_IONS = 6 # minimum total number of child items across those groups + +# Null protein selection threshold is read at runtime from +# aqvariables.ICC_NULL_PVAL_THRESHOLD (default 0.1, configurable via +# run_pipeline icc_null_pval_threshold argument). + +_N_PERMUTATIONS = 3 + +# Gene-level subsampled estimation: each draw picks a random subset of +# null proteins and computes ICC treating each protein as a group. +_GENE_SUBSAMPLE_SIZE = 5 +_GENE_N_DRAWS = 500 + +_NODE_TYPE_LABELS = { + "frgion": "Fragment ion", + "ms1_isotopes": "MS1 isotope", + "mod_seq_charge": "Ion family", + "mod_seq": "Precursor", + "seq": "Mod. peptide", + "gene": "Peptide", +} + + +# --------------------------------------------------------------------------- +# Public entry point +# --------------------------------------------------------------------------- + +def estimate_and_apply_icc_correction(protnodes, runtime_plots=False, aggregation_mode="stouffer_decorrelation"): + """Estimate a global ICC at every tree level, apply uniformly, and re-aggregate. + + Processes levels bottom-to-top. After each level's ICC is estimated and + assigned, all trees are re-aggregated so that the next (higher) level + sees corrected z-values from below. 
+ + Args: + protnodes: list of protein root nodes (anytree.Node) + runtime_plots: if True, show ICC distribution plots for all levels + aggregation_mode: z-value combination strategy forwarded to re-aggregation + """ + if not protnodes: + return + + all_level_results = {} + + for node_type in _ICC_NODE_TYPES: + if not _has_node_type(protnodes, node_type): + continue + + LOGGER.info(f"ICC correction: estimating for node type '{node_type}'") + + if node_type == "gene": + null_iccs, perm_iccs, icc_median = _estimate_gene_level_icc(protnodes) + else: + null_iccs, perm_iccs, icc_median = _estimate_null_icc_distribution( + protnodes, node_type + ) + + n_annotated = _assign_icc_to_all_proteins(protnodes, node_type, icc_median) + + obs_med = float(np.median(null_iccs)) if null_iccs else 0.0 + perm_med = float(np.median(perm_iccs)) if perm_iccs else 0.0 + LOGGER.info( + f"ICC correction ({node_type}): applied={icc_median:.4f} " + f"(obs={obs_med:.4f} - perm={perm_med:.4f}), " + f"n_null={len(null_iccs)}, n_perm={len(perm_iccs)}, " + f"n_annotated={n_annotated}" + ) + + all_level_results[node_type] = (null_iccs, perm_iccs, icc_median) + + _re_aggregate_trees(protnodes, aggregation_mode=aggregation_mode) + + if runtime_plots and all_level_results: + _plot_icc_distributions_all_levels(all_level_results) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _has_node_type(protnodes, node_type): + """Return True if at least one protein tree contains a node of *node_type*.""" + for prot in protnodes: + matches = anytree.search.findall( + prot, filter_=lambda n: n.type == node_type + ) + if matches: + return True + return False + + +def _estimate_null_icc_distribution(protnodes, node_type): + """Compute observed and permutation-based ICC for each null protein. 
+ + A protein qualifies as null only when *both*: + - its protein-level p_val > aqvariables.ICC_NULL_PVAL_THRESHOLD, *and* + - only group nodes whose own p_val > threshold are used. + + Returns: + (null_iccs, permutation_iccs, icc_median). + If no null proteins qualify, returns ([], [], 0.0). + """ + null_iccs = [] + null_group_zvals_list = [] + + pval_threshold = aqvariables.get_icc_null_pval_threshold(node_type) + for prot in protnodes: + if not hasattr(prot, "p_val"): + continue + if prot.p_val <= pval_threshold: + continue + + group_zvals = _collect_group_zvals(prot, node_type, node_p_val_threshold=pval_threshold) + if not group_zvals: + continue + + icc = _compute_icc_from_groups(group_zvals) + if icc is not None: + null_iccs.append(icc) + null_group_zvals_list.append(group_zvals) + + if len(null_iccs) == 0: + LOGGER.warning( + f"ICC correction ({node_type}): no null proteins qualified; " + "falling back to icc_median=0.0" + ) + return [], [], 0.0 + + permutation_iccs = _compute_permutation_null(null_group_zvals_list) + + icc_median = _null_normalized_icc(null_iccs, permutation_iccs) + return null_iccs, permutation_iccs, icc_median + + +def _estimate_gene_level_icc(protnodes, seed=42): + """Estimate ICC at the protein (gene) level using subsampled estimation. + + Since each protein has exactly one gene node, per-protein ICC is + undefined. Instead, we subsample groups of null proteins and compute + ICC treating each protein as a group with its peptide z-values as items. + + The permutation null shuffles all z-values across proteins (destroying + the protein-level grouping) and repeats the same subsampled estimation. + + Returns: + (null_iccs, permutation_iccs, icc_median). 
+ """ + protein_groups = [] + gene_threshold = aqvariables.get_icc_null_pval_threshold("gene") + for prot in protnodes: + if not hasattr(prot, "p_val") or prot.p_val <= gene_threshold: + continue + children_zvals = [c.z_val for c in prot.children if hasattr(c, "z_val")] + if len(children_zvals) >= 2: + protein_groups.append(np.array(children_zvals)) + + if len(protein_groups) < _GENE_SUBSAMPLE_SIZE: + LOGGER.warning( + f"ICC correction (gene): only {len(protein_groups)} null proteins " + f"with ≥2 peptides (need {_GENE_SUBSAMPLE_SIZE}); falling back to 0.0" + ) + return [], [], 0.0 + + rng = np.random.RandomState(seed) + null_iccs = _subsampled_icc(protein_groups, rng) + + sizes = [len(g) for g in protein_groups] + pooled = np.concatenate(protein_groups) + rng_perm = np.random.RandomState(seed + 1) + rng_perm.shuffle(pooled) + shuffled_groups = [] + offset = 0 + for s in sizes: + shuffled_groups.append(pooled[offset:offset + s].copy()) + offset += s + perm_iccs = _subsampled_icc(shuffled_groups, np.random.RandomState(seed)) + + if not null_iccs: + LOGGER.warning("ICC correction (gene): no valid subsamples; falling back to 0.0") + return [], [], 0.0 + + icc_median = _null_normalized_icc(null_iccs, perm_iccs) + return null_iccs, perm_iccs, icc_median + + +def _subsampled_icc(protein_groups, rng, + subset_size=_GENE_SUBSAMPLE_SIZE, + n_draws=_GENE_N_DRAWS): + """Compute ICC on random subsets of protein groups. + + Each draw selects *subset_size* proteins, treats each as a group + (with its peptide z-values as items), and computes one ICC value. 
+ """ + iccs = [] + n_proteins = len(protein_groups) + for _ in range(n_draws): + indices = rng.choice(n_proteins, size=subset_size, replace=False) + group_zvals = [protein_groups[i] for i in indices] + icc = _compute_icc_from_groups(group_zvals) + if icc is not None: + iccs.append(icc) + return iccs + + +def _compute_permutation_null(group_zvals_list, n_permutations=_N_PERMUTATIONS, seed=42): + """Compute ICC on shuffled data to establish a permutation baseline. + + For each protein's grouped z-values, all values are pooled and randomly + reassigned to groups of the original sizes. This destroys any real + intra-group correlation, so the resulting ICC distribution reflects pure + sampling noise. + """ + rng = np.random.RandomState(seed) + perm_iccs = [] + + for group_zvals in group_zvals_list: + sizes = [len(g) for g in group_zvals] + pooled = np.concatenate(group_zvals) + + for _ in range(n_permutations): + rng.shuffle(pooled) + shuffled_groups = [] + offset = 0 + for s in sizes: + shuffled_groups.append(pooled[offset:offset + s].copy()) + offset += s + + icc = _compute_icc_from_groups(shuffled_groups) + if icc is not None: + perm_iccs.append(icc) + + return perm_iccs + + +def _null_normalized_icc(null_iccs, perm_iccs): + """Return the null-normalized ICC: median(observed) - median(shuffled), clamped to ≥0.""" + obs_med = float(np.median(null_iccs)) + perm_med = float(np.median(perm_iccs)) if len(perm_iccs) > 0 else 0.0 + return max(0.0, obs_med - perm_med) + + +def _assign_icc_to_all_proteins(protnodes, node_type, icc_median): + """Assign the global null-median ICC uniformly to every node of *node_type*. + + Returns: + Number of nodes annotated. 
+ """ + n_annotated = 0 + for prot in protnodes: + target_nodes = anytree.search.findall( + prot, filter_=lambda n: n.type == node_type + ) + for node in target_nodes: + node.icc_correction = icc_median + n_annotated += 1 + return n_annotated + + +def _collect_group_zvals(protein_node, node_type, node_p_val_threshold=None): + """Extract grouped z-values from a protein tree. + + For each node of *node_type*, collects the z-values of its immediate + children. At the ion-type level (frgion, ms1_isotopes) the children + are base ions; at higher levels they are aggregated child nodes. + + Returns: + list[np.ndarray]: One array of child z-values per qualifying group + node, or an empty list if insufficient data. + """ + group_nodes = anytree.search.findall( + protein_node, filter_=lambda n: n.type == node_type + ) + if not group_nodes: + return [] + + if node_p_val_threshold is not None: + group_nodes = [ + n for n in group_nodes + if hasattr(n, "p_val") and n.p_val > node_p_val_threshold + ] + + group_zvals = [] + for gnode in group_nodes: + children = [ + c for c in gnode.children + if hasattr(c, "z_val") + ] + if children: + group_zvals.append( + np.array([c.z_val for c in children]) + ) + + n_groups = len(group_zvals) + if n_groups < _MIN_GROUPS: + return [] + + n_items = sum(len(g) for g in group_zvals) + if n_items < _MIN_IONS: + return [] + + return group_zvals + + +def _compute_icc_from_groups(group_zvals): + """Compute one-way random-effects ICC from pre-collected grouped z-values. + + Returns: + ICC (float) or None if degenerate. 
+ """ + n_groups = len(group_zvals) + if n_groups < _MIN_GROUPS: + return None + + all_vals = np.concatenate(group_zvals) + n_items = len(all_vals) + if n_items < _MIN_IONS: + return None + + grand_mean = all_vals.mean() + + ss_within = 0.0 + group_sizes = np.empty(n_groups, dtype=int) + group_means = np.empty(n_groups) + for i, gv in enumerate(group_zvals): + gm = gv.mean() + group_means[i] = gm + group_sizes[i] = len(gv) + ss_within += np.sum((gv - gm) ** 2) + + df_within = n_items - n_groups + if df_within < 1: + return None + sigma2_residual = ss_within / df_within + + expanded_group_means = np.repeat(group_means, group_sizes) + ss_between = np.sum((expanded_group_means - grand_mean) ** 2) + df_between = n_groups - 1 + ms_between = ss_between / df_between + + n0 = (n_items - np.sum(group_sizes ** 2) / n_items) / df_between + if n0 < 1e-15: + return None + + sigma2_group = (ms_between - sigma2_residual) / n0 + sigma2_group = max(sigma2_group, 0.0) + + total = sigma2_group + sigma2_residual + if total < 1e-15: + return None + + return sigma2_group / total + + +def _compute_icc_from_tree(protein_node, node_type, node_p_val_threshold=None): + """Convenience wrapper: extract groups from tree, then compute ICC.""" + group_zvals = _collect_group_zvals(protein_node, node_type, node_p_val_threshold) + if not group_zvals: + return None + return _compute_icc_from_groups(group_zvals) + + +def _re_aggregate_trees(protnodes, aggregation_mode="stouffer_decorrelation"): + """Re-aggregate all protein trees bottom-to-top after ICC annotation. + + Walks every tree from leaves upward, re-computing z-values and p-values + at each level so that the newly annotated ``icc_correction`` attributes + take effect. 
+ """ + for prot in protnodes: + for level_nodes in aqcluster_utils.iterate_through_tree_levels_bottom_to_top(prot): + node_types = list(set(node.type for node in level_nodes)) + if node_types == ["base"]: + continue + for node in level_nodes: + if node.children: + aqcluster_utils.aggregate_node_properties( + node, only_use_mainclust=True, peptide_outlier_filtering=False, + aggregation_mode=aggregation_mode + ) + + +# --------------------------------------------------------------------------- +# Plotting +# --------------------------------------------------------------------------- + +def _plot_icc_distributions_all_levels(all_level_results): + """Multi-row figure with histogram + CDF per level. + + Each row corresponds to one tree level. Left panel shows histograms of + observed-null and permutation-null ICC distributions; right panel shows + cumulative distribution functions. A vertical line marks the applied + median ICC. + + Args: + all_level_results: dict mapping node_type → + (null_iccs, perm_iccs, icc_median). 
+ """ + ordered_types = [nt for nt in _ICC_NODE_TYPES if nt in all_level_results] + display_order = list(reversed(ordered_types)) + n_levels = len(display_order) + if n_levels == 0: + return + + fig, axes = plt.subplots(n_levels, 2, figsize=(7.5, 1.8 * n_levels), + squeeze=False) + fig.subplots_adjust(hspace=0.65, wspace=0.35, top=0.96, bottom=0.06, + left=0.12, right=0.98) + + color_obs = "#4C72B0" + color_perm = "#DD8452" + bins = np.linspace(0, 1, 40) + + for row, node_type in enumerate(display_order): + null_iccs, perm_iccs, icc_median = all_level_results[node_type] + ax_hist, ax_cdf = axes[row] + label = _NODE_TYPE_LABELS.get(node_type, node_type) + + legend_handles = [] + + if len(perm_iccs) > 0: + perm_med = float(np.median(perm_iccs)) + ax_hist.hist(perm_iccs, bins=bins, alpha=0.45, color=color_perm, + edgecolor="none") + s = np.sort(perm_iccs) + ax_cdf.plot(s, np.arange(1, len(s) + 1) / len(s), color=color_perm) + legend_handles.append( + plt.matplotlib.patches.Patch( + facecolor=color_perm, alpha=0.6, + label=f"shuffled (med={perm_med:.2f})") + ) + + if len(null_iccs) > 0: + obs_med = float(np.median(null_iccs)) + ax_hist.hist(null_iccs, bins=bins, alpha=0.45, color=color_obs, + edgecolor="none") + s = np.sort(null_iccs) + ax_cdf.plot(s, np.arange(1, len(s) + 1) / len(s), color=color_obs) + legend_handles.append( + plt.matplotlib.patches.Patch( + facecolor=color_obs, alpha=0.6, + label=f"observed (med={obs_med:.2f})") + ) + + ax_cdf.axvline(0.5, color="gray", linestyle=":", linewidth=0.6) + + for ax in (ax_hist, ax_cdf): + ax.set_xlim(-0.05, 1.05) + ax.tick_params(axis='both', which='both', length=3, pad=2) + + ax_hist.set_ylabel("count") + ax_cdf.set_ylabel("cum. 
fraction") + ax_cdf.set_ylim(-0.02, 1.02) + ax_cdf.set_yticks([0.0, 0.5, 1.0]) + + if row == n_levels - 1: + ax_hist.set_xlabel("ICC") + ax_cdf.set_xlabel("ICC") + else: + ax_hist.set_xticklabels([]) + ax_cdf.set_xticklabels([]) + + ax_hist.legend(handles=legend_handles, loc='upper right', + fontsize=6, frameon=False, handlelength=1, + handletextpad=0.4, borderpad=0.2) + + pos_l = ax_hist.get_position() + pos_r = ax_cdf.get_position() + x_mid = (pos_l.x0 + pos_r.x1) / 2 + fig.text(x_mid, pos_l.y1 + 0.008, label, + ha='center', va='bottom', fontsize=10, fontweight='bold') + + fig.suptitle("ICC distributions across tree levels", fontsize=12, + fontweight='bold', y=0.995) + plt.show() diff --git a/alphaquant/cluster/ml_reorder.py b/alphaquant/cluster/ml_reorder.py index 0a0d75de..f4546a96 100644 --- a/alphaquant/cluster/ml_reorder.py +++ b/alphaquant/cluster/ml_reorder.py @@ -3,7 +3,7 @@ import anytree -def update_nodes_w_ml_score(protnodes : list[anytree.Node]): +def update_nodes_w_ml_score(protnodes : list[anytree.Node], aggregation_mode="stouffer_decorrelation"): """ Update and re-order clusters within protein nodes based on ML scores. @@ -13,16 +13,17 @@ def update_nodes_w_ml_score(protnodes : list[anytree.Node]): Args: protnodes (list[anytree.Node]): A list of protein nodes to be processed. + aggregation_mode: Strategy for combining child z-values during re-aggregation. Returns: None """ for prot in protnodes: - _re_order_depending_on_ml_score(prot) + _re_order_depending_on_ml_score(prot, aggregation_mode=aggregation_mode) -def _re_order_depending_on_ml_score(protnode : anytree.Node): +def _re_order_depending_on_ml_score(protnode : anytree.Node, aggregation_mode="stouffer_decorrelation"): """ Reorder clusters in a protein node tree based on machine learning scores. @@ -38,6 +39,7 @@ def _re_order_depending_on_ml_score(protnode : anytree.Node): Args: protnode (anytree.Node): The protein node to be processed. 
+ aggregation_mode: Strategy for combining child z-values during re-aggregation. Returns: None @@ -51,13 +53,18 @@ def _re_order_depending_on_ml_score(protnode : anytree.Node): if len(type_nodes)==0: continue for type_node in type_nodes: #go through the nodes, re-order the children. Propagate the values from the newly ordered children to the type node - child_nodes = type_node.children - had_ml_score = hasattr(child_nodes[0], 'ml_score') + child_nodes = [ + x for x in type_node.children + if x.is_included and not aqcluster_utils.node_is_excluded_from_aggregation(x) + ] + if len(child_nodes) == 0: + continue + had_ml_score = all(hasattr(child, 'ml_score') for child in child_nodes) if had_ml_score: clust2newclust = _get_clust2newclust(child_nodes) _re_assign_proteoform_stats(child_nodes, clust2newclust) _re_order_clusters_by_ml_score(child_nodes, clust2newclust) - aqcluster_utils.aggregate_node_properties(type_node,only_use_mainclust=True, peptide_outlier_filtering=False) + aqcluster_utils.aggregate_node_properties(type_node,only_use_mainclust=True, peptide_outlier_filtering=False, aggregation_mode=aggregation_mode) def _get_clust2newclust(nodes: list[anytree.Node]) -> dict[int, int]: @@ -162,6 +169,3 @@ def _re_order_clusters_by_ml_score(nodes : list[anytree.Node], clust2newclust : """ for node in nodes: node.cluster =clust2newclust.get(node.cluster) - - - diff --git a/alphaquant/cluster/outlier_filtering.py b/alphaquant/cluster/outlier_filtering.py index 7ac38025..d6e8f1d8 100644 --- a/alphaquant/cluster/outlier_filtering.py +++ b/alphaquant/cluster/outlier_filtering.py @@ -1,13 +1,14 @@ import alphaquant.cluster.cluster_utils as aqcluster_utils +import alphaquant.config.variables as aqvariables import anytree import numpy as np -def apply_peptide_outlier_filtering(protnodes: list[anytree.Node]): +def apply_peptide_outlier_filtering(protnodes: list[anytree.Node], aggregation_mode="stouffer_decorrelation"): regulation_score = calculate_regulation_score(protnodes) 
for protnode in protnodes: _determine_and_annotate_outlier_status_of_peptides(protnode, regulation_score) - aqcluster_utils.aggregate_node_properties(protnode, only_use_mainclust=True, peptide_outlier_filtering=True) + aqcluster_utils.aggregate_node_properties(protnode, only_use_mainclust=True, peptide_outlier_filtering=True, aggregation_mode=aggregation_mode) @@ -23,7 +24,12 @@ def calculate_regulation_score(protnodes: list[anytree.Node]): fraction_sig = num_sig / (num_sig + num_insig) log2fc_ratio_sig_vs_insig = np.median(abs_log2fc[sig_mask_005]) / (np.median(abs_log2fc[nonsig_mask]) + 1e-6) - regulation_score = min(1, log2fc_ratio_sig_vs_insig * fraction_sig/10) #merges the regulation strength and the fraction of significant proteins into one score divided by to normalize it, the normalization factor corresponds to a very stongly regulated dataset + regulation_score = min( + 1, + log2fc_ratio_sig_vs_insig + * fraction_sig + / aqvariables.PEPTIDE_OUTLIER_REGULATION_NORMALIZATION_FACTOR, + ) # merges the regulation strength and the fraction of significant proteins into one score; the normalization factor corresponds to a strongly regulated dataset return regulation_score diff --git a/alphaquant/cluster/residual_decorrelation.py b/alphaquant/cluster/residual_decorrelation.py new file mode 100644 index 00000000..4519019f --- /dev/null +++ b/alphaquant/cluster/residual_decorrelation.py @@ -0,0 +1,722 @@ +"""Residual decorrelation: remove correlated siblings before z-score aggregation. + +Overview +-------- +When AlphaQuant aggregates child nodes (e.g. peptides → protein, fragments → +peptide) using Stouffer's method, inflated sibling correlations bias the +combined z-score upward. This module identifies and prunes the most +correlated children at every level of the ion tree so that the surviving +set's pairwise correlation distribution matches a condition-shuffled null. + +Algorithm +--------- +1. 
**Residual computation** (``attach_lm_residuals``): + For every base ion the within-condition mean intensity is subtracted, + yielding condition-mean residuals. Residuals for higher-level nodes are + the row-wise mean of their children's residuals, propagated bottom-up. + These residuals capture shared technical variation independently of the + fold-change signal. + +2. **Per-parent precomputation** (``_build_parent``, ``ParentPrecompute``): + For each parent node at a given level the children's residual vectors are + stacked into a matrix and a Pearson correlation matrix ``C`` is computed. + A greedy removal order is then computed once: at each step the child with + the highest mean pairwise correlation to the surviving set is removed. + The maximum pairwise correlation after each removal is stored as + ``max_r_trajectory``, making it cheap to replay any cutoff later via + ``survivors_at``. + +3. **Null distribution** (``_cross_parent_shuffle_null``): + Rows are permuted across parents (cross-parent shuffle) to produce a + baseline that represents what sibling correlations look like when children + are exchanged between unrelated proteins. + +4. **Level sweep** (``run_level_sweep``): + A grid of correlation cutoffs (default 1.0 → 0.1) is scanned. For each + cutoff the surviving correlation values across all parents are collected + and compared to the null via a one-sided excess-CDF distance ``D`` + (``_excess_cdf_distance``): the maximum over all ``r`` of + ``F_null(r) − F_corrected(r)``, i.e. how much the corrected distribution + still exceeds the null. The lowest cutoff with ``D ≤ tolerance`` is + chosen. If none qualifies the tightest cutoff is used regardless. + +5. **Application** (``apply_residual_decorrelation``): + Main entry point. Runs steps 1–4 for every ``LEVEL_PAIRS`` pair, marks + pruned children with ``node.exclude_residual_decorrelation = True``, then + re-aggregates node statistics bottom-up with the decorrelation-aware + aggregation mode. 
+ +Structure +--------- +Data classes + ParentPrecompute – precomputed correlation matrix + removal trajectory + LevelSweepResult – outcome of one level sweep (cutoffs, distances, traces) + +Internal helpers + _node_matches_level – type-string → node matching + _build_parent – build ParentPrecompute for one parent node + _pair_rs_from_C – extract upper-triangle r values given survivors + _cross_parent_shuffle_null – permutation-based null distribution + _aggregate_pair_rs – collect r values across parents at a cutoff + _excess_cdf_distance – one-sided CDF distance metric + +Public API + attach_lm_residuals – attach within-condition residuals to tree nodes + run_level_sweep – sweep cutoffs for one (parent, child) level pair + apply_residual_decorrelation – orchestrate the full pipeline (main entry point) + plot_level_sweep_cdfs – CDF comparison plot for one level result + plot_level_sweep_diagnostics – full diagnostic figure (CDF + sweep trace) +""" +from __future__ import annotations + +from dataclasses import dataclass, field +import warnings + +import anytree +from anytree import PreOrderIter +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd + +import alphaquant.cluster.cluster_utils as aqcluster_utils +import alphaquant.config.variables as aqvariables +import alphaquant.config.config as aqconfig +import logging + +aqconfig.setup_logging() +LOGGER = logging.getLogger(__name__) + + +LEVEL_PAIRS = ( + ("gene", "seq"), + ("seq", "mod_seq"), + ("mod_seq", "mod_seq_charge"), + ("mod_seq_charge", "ion_type"), + ("frgion", "base"), + ("ms1_isotopes", "base"), +) + +DEFAULT_CUTOFF_GRID = tuple(round(1.0 - 0.1 * k, 2) for k in range(10)) +DEFAULT_TOLERANCE = 0.10 +DEFAULT_MIN_KEEP = 1 + + +@dataclass +class ParentPrecompute: + """Precomputed correlation structure for one parent node. 
+ + Built once per parent by ``_build_parent``; the removal order and + max-r trajectory are stored so that ``survivors_at`` can replay any + cutoff in O(n) without recomputing correlations. + + Attributes + ---------- + parent_node: the tree node this precompute belongs to + child_nodes: ordered tuple of children whose residuals were used + C: (n × n) Pearson correlation matrix; diagonal is NaN + remove_order: greedy removal sequence (indices into child_nodes) + max_r_trajectory: max pairwise r after removing k children (length n) + """ + + parent_node: anytree.Node + child_nodes: tuple[anytree.Node, ...] + C: np.ndarray + remove_order: np.ndarray + max_r_trajectory: np.ndarray + + def survivors_at(self, cutoff: float, min_keep: int) -> np.ndarray: + """Return a boolean mask of children that survive at ``cutoff``. + + Replays the greedy removal order until ``max_r_trajectory`` drops + below ``cutoff``, always retaining at least ``min_keep`` children. + """ + n = self.C.shape[0] + if n == 0: + return np.zeros(0, dtype=bool) + # at most n-min_keep children may be removed + k_max = max(0, n - min_keep) + # scan the trajectory: stop at the first step where max_r is already below cutoff + k = 0 + while k <= k_max and self.max_r_trajectory[k] > cutoff: + k += 1 + # clamp in case the loop overshot (shouldn't happen with monotone trajectory) + if k > k_max: + k = k_max + # replay: mark the first k entries of the greedy removal order as dead + alive = np.ones(n, dtype=bool) + if k > 0: + alive[self.remove_order[:k]] = False + return alive + + +@dataclass +class LevelSweepResult: + """Outcome of a full cutoff sweep for one (parent_level, child_level) pair. 
+ + Attributes + ---------- + level: (parent_level, child_level) string pair + cutoff: chosen correlation cutoff + d_before/d_after: excess CDF distance before and after pruning + n_parents: total parents examined at this level + parents_touched: parents where at least one child was dropped + children_dropped: total children marked for exclusion + grid_trace: list of (cutoff, distance, dropped, touched) per grid step + unmodified_sorted: sorted pairwise r values before pruning (for plotting) + corrected_sorted: sorted pairwise r values after pruning (for plotting) + null_sorted: sorted null distribution r values (for plotting) + """ + + level: tuple[str, str] + cutoff: float + d_before: float + d_after: float + n_parents: int + parents_touched: int + children_dropped: int + grid_trace: list[tuple[float, float, int, int]] = field(default_factory=list) + unmodified_sorted: np.ndarray = field(default_factory=lambda: np.empty(0, dtype=np.float64)) + corrected_sorted: np.ndarray = field(default_factory=lambda: np.empty(0, dtype=np.float64)) + null_sorted: np.ndarray = field(default_factory=lambda: np.empty(0, dtype=np.float64)) + + +def _node_matches_level(node, level: str) -> bool: + if level == "ion_type": + return node.type in {"frgion", "ms1_isotopes", "precursor"} + return node.type == level + + +def _build_parent(parent_node, child_nodes, mat): + """Build a ParentPrecompute by greedily computing the removal order. + + At each step the child with the highest mean pairwise correlation to the + remaining set is removed. The maximum pairwise r after each removal is + recorded in ``max_r_trajectory`` and monotonically enforced (non-increasing) + so that ``survivors_at`` can use a simple threshold scan. 
+ """ + if mat.shape[0] < 2: + return None + + # compute full Pearson correlation matrix across children's residual vectors + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", r"invalid value encountered") + C = np.corrcoef(mat) + # constant rows produce NaN correlations; replace with 0 so they don't distort means + if not np.all(np.isfinite(C)): + C = np.nan_to_num(C, nan=0.0, posinf=1.0, neginf=-1.0) + C = C.copy() + # NaN on diagonal so nanmean excludes self-correlation when averaging rows + np.fill_diagonal(C, np.nan) + + n = C.shape[0] + alive = np.ones(n, dtype=bool) + remove_order = [] + max_r = [] + + # record the max pairwise r before any removal (step 0 of the trajectory) + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", r"All-NaN slice encountered") + init_max = float(np.nanmax(C)) if n >= 2 else -np.inf + if not np.isfinite(init_max): + init_max = -np.inf + max_r.append(init_max) + + while alive.sum() > 1: + # extract the submatrix of currently surviving children + sub_idx = np.where(alive)[0] + cc = C[np.ix_(sub_idx, sub_idx)] + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", r"Mean of empty slice") + warnings.filterwarnings("ignore", r"All-NaN slice encountered") + mean_r = np.nanmean(cc, axis=1) + # the "worst" child is the one most correlated on average with its siblings; + # replace NaN with -inf so it is never chosen as worst (avoids argmax crash) + worst_local = int(np.nanargmax(np.where(np.isnan(mean_r), -np.inf, mean_r))) + alive[sub_idx[worst_local]] = False + remove_order.append(sub_idx[worst_local]) + # record the new maximum pairwise r among survivors after this removal + if alive.sum() >= 2: + rem = np.where(alive)[0] + cc2 = C[np.ix_(rem, rem)] + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", r"All-NaN slice encountered") + m = float(np.nanmax(cc2)) + if not np.isfinite(m): + m = -np.inf + else: + m = -np.inf + max_r.append(m) + + traj = np.asarray(max_r, 
dtype=np.float64) + # enforce monotone non-increasing: numerical noise can cause slight upward bumps + # which would break the threshold scan in survivors_at + for i in range(1, traj.size): + if traj[i] > traj[i - 1]: + traj[i] = traj[i - 1] + + return ParentPrecompute( + parent_node=parent_node, + child_nodes=tuple(child_nodes), + C=C.astype(np.float64), + remove_order=np.asarray(remove_order, dtype=np.int64), + max_r_trajectory=traj, + ) + + +def _pair_rs_from_C(C: np.ndarray, survivors: np.ndarray) -> np.ndarray: + if survivors.sum() < 2: + return np.empty(0, dtype=np.float64) + idx = np.where(survivors)[0] + sub = C[np.ix_(idx, idx)] + # k=1 skips the diagonal, giving only unique off-diagonal pairs + iu = np.triu_indices(sub.shape[0], k=1) + vals = sub[iu] + # drop NaN entries that arise from originally constant residual vectors + vals = vals[~np.isnan(vals)] + return vals.astype(np.float64, copy=False) + + +def _cross_parent_shuffle_null(mats: list[np.ndarray], rng: np.random.Generator) -> np.ndarray: + """Build a null correlation distribution by shuffling rows across parents. + + All residual rows from every parent are pooled, randomly permuted, then + re-assigned to groups of the original sizes. Pairwise correlations within + each group are computed and concatenated. This destroys within-parent + structure while preserving group sizes, giving the expected correlation + distribution under the null hypothesis that children are unrelated. 
+ """ + if not mats: + return np.empty(0, dtype=np.float64) + # remember original group sizes so shuffled rows can be re-partitioned identically + sizes = [m.shape[0] for m in mats] + # pool all residual rows from all parents into one matrix and shuffle + pool = np.vstack(mats) + pool = pool[rng.permutation(pool.shape[0])] + out = [] + idx = 0 + for size in sizes: + # re-assign the next 'size' shuffled rows to this group + chunk = pool[idx:idx + size] + idx += size + if size < 2: + continue + # skip constant rows: they produce NaN correlations and add no information + keep = chunk.std(axis=1) > 0 + if keep.sum() < 2: + continue + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", r"invalid value encountered") + c = np.corrcoef(chunk[keep]) + iu = np.triu_indices(c.shape[0], k=1) + vals = c[iu] + vals = vals[~np.isnan(vals)] + if vals.size: + out.append(vals) + if not out: + return np.empty(0, dtype=np.float64) + return np.concatenate(out).astype(np.float64, copy=False) + + +def _aggregate_pair_rs(parents: list[ParentPrecompute], cutoff: float, min_keep: int) -> np.ndarray: + chunks = [] + for pp in parents: + alive = pp.survivors_at(cutoff, min_keep) + if alive.sum() < 2: + continue + chunks.append(_pair_rs_from_C(pp.C, alive)) + if not chunks: + return np.empty(0, dtype=np.float64) + return np.concatenate(chunks) + + +def _excess_cdf_distance(corrected: np.ndarray, null_sorted: np.ndarray) -> float: + """One-sided excess CDF distance between corrected and null distributions. + + Returns max over r of ``F_null(r) − F_corrected(r)``, clipped to 0 from + below. A value of 0 means the corrected distribution is nowhere above the + null; higher values indicate residual excess correlation. 
+ """ + if corrected.size == 0 or null_sorted.size == 0: + return 0.0 + corr_sorted = np.sort(corrected) + # evaluate both CDFs on the union of all observed r values + grid = np.unique(np.concatenate([corr_sorted, null_sorted])) + # searchsorted with side="right" gives F(r) = P(X ≤ r) at each grid point + f_corr = np.searchsorted(corr_sorted, grid, side="right") / corr_sorted.size + f_null = np.searchsorted(null_sorted, grid, side="right") / null_sorted.size + # one-sided: only penalise when the corrected distribution exceeds the null + # (F_null > F_corr means more mass above r in corrected than null → excess correlation) + return float(np.max(np.maximum(f_null - f_corr, 0.0))) + + +def run_level_sweep( + parents: list[ParentPrecompute], + null_sorted: np.ndarray, + *, + cutoff_grid: tuple[float, ...] = DEFAULT_CUTOFF_GRID, + tolerance: float = DEFAULT_TOLERANCE, + min_keep: int = DEFAULT_MIN_KEEP, + level: tuple[str, str] = ("", ""), +): + """Sweep correlation cutoffs and return the mildest one within tolerance. + + For each cutoff in ``cutoff_grid`` (scanned from loose to tight), the + surviving pairwise r values are collected and the excess CDF distance to + the null is computed. The first cutoff whose distance falls at or below + ``tolerance`` is chosen. If none qualifies the tightest cutoff is used + unconditionally. + + Parameters + ---------- + parents: precomputed parent structures for this level pair + null_sorted: sorted null pairwise r values (from cross-parent shuffle) + cutoff_grid: correlation thresholds to scan, ordered loose → tight + tolerance: maximum allowed excess CDF distance D after pruning + min_keep: minimum children to retain per parent regardless of cutoff + level: (parent_level, child_level) label stored in the result + + Returns + ------- + LevelSweepResult with the chosen cutoff and full diagnostic traces. 
+ """ + # measure baseline excess distance with no pruning (cutoff = 1.0 keeps everyone) + baseline = _aggregate_pair_rs(parents, 1.0, min_keep) + d_before = _excess_cdf_distance(baseline, null_sorted) + + chosen = None + trace = [] + for cutoff in cutoff_grid: + touched = 0 + dropped = 0 + chunks = [] + for pp in parents: + alive = pp.survivors_at(cutoff, min_keep) + n_drop = int(pp.C.shape[0] - alive.sum()) + if n_drop > 0: + touched += 1 + dropped += n_drop + if alive.sum() >= 2: + chunks.append(_pair_rs_from_C(pp.C, alive)) + corrected = np.concatenate(chunks) if chunks else np.empty(0, dtype=np.float64) + d = _excess_cdf_distance(corrected, null_sorted) + trace.append((cutoff, d, dropped, touched)) + # take the first (loosest) cutoff that already satisfies the tolerance — + # prefer dropping as few children as possible + if chosen is None and d <= tolerance: + chosen = (cutoff, d, corrected, dropped, touched) + + # if no cutoff reached tolerance, fall back to the tightest one in the grid + if chosen is None: + cutoff, d, dropped, touched = trace[-1] + corrected = _aggregate_pair_rs(parents, cutoff, min_keep) + chosen = (cutoff, d, corrected, dropped, touched) + + cutoff, d_after, corrected, dropped, touched = chosen + return LevelSweepResult( + level=level, + cutoff=float(cutoff), + d_before=float(d_before), + d_after=float(d_after), + n_parents=len(parents), + parents_touched=int(touched), + children_dropped=int(dropped), + grid_trace=trace, + unmodified_sorted=np.sort(baseline), + corrected_sorted=np.sort(corrected), + null_sorted=np.asarray(null_sorted, dtype=np.float64), + ) + + +def attach_lm_residuals(protnodes, df_c1_normed, df_c2_normed, min_n_per_cond=2): + """Attach per-ion residuals from ``log2(intensity) ~ condition``. + + The AlphaQuant pipeline passes log2-normalized intensities here. 
The + saved ``*.normed.tsv`` tables are exponentiated for output, so callers that + start from those files must convert them back to log2 before using this + helper. + """ + # build a single intensity matrix with all samples from both conditions + X = pd.concat([df_c1_normed, df_c2_normed], axis=1) + c1_cols = list(df_c1_normed.columns) + c2_cols = list(df_c2_normed.columns) + X = X.astype(float) + + # compute within-condition means per ion (row-wise) + m1 = X[c1_cols].mean(axis=1, skipna=True) + m2 = X[c2_cols].mean(axis=1, skipna=True) + + # subtract the within-condition mean: residual = intensity - condition mean + # this removes the fold-change signal, leaving only condition-independent noise + res = X.copy() + res[c1_cols] = X[c1_cols].sub(m1, axis=0) + res[c2_cols] = X[c2_cols].sub(m2, axis=0) + # mask ions with too few valid values in either condition — their residuals are unreliable + n1_ok = X[c1_cols].notna().sum(axis=1) >= int(min_n_per_cond) + n2_ok = X[c2_cols].notna().sum(axis=1) >= int(min_n_per_cond) + res.loc[~(n1_ok & n2_ok), :] = np.nan + + for protnode in protnodes: + # initialise residuals to None on every node before filling + for node in PreOrderIter(protnode): + node.residuals = None + + # assign residual vectors to base (leaf) ions by matching ion name to the matrix index + for node in PreOrderIter(protnode): + if node.type != "base": + continue + if node.name in res.index: + node.residuals = res.loc[node.name].to_numpy(dtype=float) + else: + node.residuals = None + + # propagate residuals bottom-up: each non-base node gets the column-wise mean + # of its children's residual vectors, so higher-level nodes carry an averaged + # representation of shared technical variation across their subtree + for level_nodes in aqcluster_utils.iterate_through_tree_levels_bottom_to_top(protnode): + for node in level_nodes: + if node.type == "base": + continue + vecs = [ + child.residuals + for child in node.children + if isinstance(getattr(child, 
"residuals", None), np.ndarray) + ] + if not vecs: + node.residuals = None + continue + stacked = np.vstack(vecs) + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", message="Mean of empty slice") + with np.errstate(all="ignore"): + mean_vec = np.nanmean(stacked, axis=0) + # if all samples are NaN after averaging, treat as missing + node.residuals = None if np.all(np.isnan(mean_vec)) else mean_vec + + +def _collect_level_parents(protnodes, parent_level, child_level): + parents = [] + for protnode in protnodes: + for node in PreOrderIter(protnode): + if not _node_matches_level(node, parent_level): + continue + child_nodes = [] + vecs = [] + for child in node.children: + if not _node_matches_level(child, child_level): + continue + v = getattr(child, "residuals", None) + # skip children without residuals or with any NaN sample — + # NaN entries would propagate into the correlation matrix + if v is None or not isinstance(v, np.ndarray): + continue + if np.any(np.isnan(v)): + continue + child_nodes.append(child) + vecs.append(v) + # need at least 2 children to form a correlation matrix + if len(vecs) < 2: + continue + mat = np.vstack(vecs) + pp = _build_parent(node, child_nodes, mat) + if pp is not None: + parents.append(pp) + return parents + + +def apply_residual_decorrelation( + protnodes, + df_c1_normed, + df_c2_normed, + *, + tolerance=DEFAULT_TOLERANCE, + min_keep=DEFAULT_MIN_KEEP, + cutoff_grid=DEFAULT_CUTOFF_GRID, + aggregation_mode="stouffer_decorrelation", + null_seed=42, +): + """Main entry point: run full residual decorrelation on a list of protein nodes. + + Steps + ----- + 1. Attach within-condition mean residuals to every node (``attach_lm_residuals``). + 2. For each level pair in ``LEVEL_PAIRS``, build a cross-parent shuffle null, + run the cutoff sweep, and mark pruned children with + ``node.exclude_residual_decorrelation = True``. + 3. Optionally apply PTM fragment selection if ``PTM_FRAGMENT_SELECTION`` is set. + 4. 
Re-aggregate all node statistics bottom-up using ``aggregation_mode``. + 5. Strip residual arrays from nodes to keep the tree serializable. + + Parameters + ---------- + protnodes: list of root protein nodes (anytree) + df_c1_normed: log2-normalized intensities for condition 1 (ions × samples) + df_c2_normed: log2-normalized intensities for condition 2 (ions × samples) + tolerance: maximum excess CDF distance D allowed after pruning + min_keep: minimum children retained per parent at each level + cutoff_grid: correlation cutoffs to sweep per level pair + aggregation_mode: z-aggregation mode used when re-aggregating after pruning + null_seed: random seed for the cross-parent shuffle null + + Returns + ------- + pandas.DataFrame summarising cutoffs, distances, and drop counts per level. + """ + # reset exclusion flags in case this function is called more than once on the same nodes + for protnode in protnodes: + for node in PreOrderIter(protnode): + node.exclude_residual_decorrelation = False + node.exclude_ptm_fragment_selection = False + + # step 1: compute within-condition residuals and attach them to every node + attach_lm_residuals(protnodes, df_c1_normed, df_c2_normed) + + rng = np.random.default_rng(null_seed) + level_results = [] + + # step 2: run the sweep for each level pair independently + for parent_level, child_level in LEVEL_PAIRS: + parents = _collect_level_parents(protnodes, parent_level, child_level) + # build the null from the same residual matrices used for the real sweep + mats = [ + np.vstack([child.residuals for child in pp.child_nodes]) + for pp in parents + ] + null_sorted = np.sort(_cross_parent_shuffle_null(mats, rng)) + sweep = run_level_sweep( + parents, + null_sorted, + cutoff_grid=cutoff_grid, + tolerance=tolerance, + min_keep=min_keep, + level=(parent_level, child_level), + ) + level_results.append(sweep) + msg = ( + f"residual decorrelation {parent_level}->{child_level}: " + f"cutoff={sweep.cutoff:g} " + 
f"D={sweep.d_before:.4g}->{sweep.d_after:.4g} " + f"dropped={sweep.children_dropped:,} " + f"parents_touched={sweep.parents_touched:,}/{sweep.n_parents:,}" + ) + LOGGER.info(msg) + print(msg, flush=True) + + # mark children that did not survive the chosen cutoff + for pp in parents: + survivors = pp.survivors_at(sweep.cutoff, min_keep) + for keep, child in zip(survivors, pp.child_nodes): + if not keep: + child.exclude_residual_decorrelation = True + + # step 3 (optional): apply PTM fragment selection on top of decorrelation exclusions + if aqvariables.PTM_FRAGMENT_SELECTION: + n_ptm_dropped, n_ptm_parents = aqcluster_utils.apply_ptm_fragment_selection( + protnodes, + ) + LOGGER.info( + "PTM fragment low-|Z| selection after residual decorrelation: " + "dropped %s children across %s frgion parents", + n_ptm_dropped, + n_ptm_parents, + ) + print( + "PTM fragment low-|Z| selection after residual decorrelation: " + f"dropped {n_ptm_dropped:,} children across " + f"{n_ptm_parents:,} frgion parents", + flush=True, + ) + + # step 4: re-aggregate node statistics bottom-up now that exclusion flags are set + for protnode in protnodes: + for level_nodes in aqcluster_utils.iterate_through_tree_levels_bottom_to_top(protnode): + for node in level_nodes: + if node.type == "base": + continue + aqcluster_utils.aggregate_node_properties( + node, + only_use_mainclust=True, + peptide_outlier_filtering=False, + aggregation_mode=aggregation_mode, + ) + + # step 5: residual vectors are only needed during the sweep; remove them before + # downstream ML reordering / JSON export to keep the tree serializable + for protnode in protnodes: + for node in PreOrderIter(protnode): + if hasattr(node, "residuals"): + delattr(node, "residuals") + + summary = pd.DataFrame( + [ + { + "parent_level": result.level[0], + "child_level": result.level[1], + "cutoff": result.cutoff, + "d_before": result.d_before, + "d_after": result.d_after, + "n_parents": result.n_parents, + "parents_touched": 
result.parents_touched, + "children_dropped": result.children_dropped, + } + for result in level_results + ] + ) + if not summary.empty: + LOGGER.info("Residual decorrelation summary:\n%s", summary.to_string(index=False)) + return summary + + +def plot_level_sweep_cdfs(level_result: LevelSweepResult, *, ax=None, title: str | None = None): + if ax is None: + _, ax = plt.subplots(figsize=(6, 4)) + + for values, label, color in ( + (level_result.null_sorted, "null", "#7f8c8d"), + (level_result.unmodified_sorted, "before", "#d35400"), + (level_result.corrected_sorted, "after", "#1f77b4"), + ): + if values.size == 0: + continue + y = np.arange(1, values.size + 1) / values.size + ax.step(values, y, where="post", label=label, color=color) + + ax.set_xlabel("pairwise residual correlation r") + ax.set_ylabel("cumulative fraction") + if title is None: + title = ( + f"{level_result.level[0]} -> {level_result.level[1]} | " + f"cutoff={level_result.cutoff:.2f}, D={level_result.d_after:.3f}" + ) + ax.set_title(title) + ax.legend() + ax.grid(alpha=0.25) + return ax + + +def plot_level_sweep_diagnostics(level_result: LevelSweepResult): + fig, axes = plt.subplots(1, 2, figsize=(11, 4.2)) + plot_level_sweep_cdfs(level_result, ax=axes[0]) + + if level_result.grid_trace: + cutoffs = [x[0] for x in level_result.grid_trace] + distances = [x[1] for x in level_result.grid_trace] + dropped = [x[2] for x in level_result.grid_trace] + axes[1].plot(cutoffs, distances, marker="o", color="#1f77b4", label="excess CDF distance") + axes[1].axvline(level_result.cutoff, color="#d35400", linestyle="--", label="chosen cutoff") + axes[1].scatter([level_result.cutoff], [level_result.d_after], color="#d35400", zorder=3) + ax2 = axes[1].twinx() + ax2.bar(cutoffs, dropped, width=0.06, alpha=0.18, color="#2c3e50", label="children dropped") + axes[1].set_xlabel("cutoff") + axes[1].set_ylabel("excess CDF distance") + ax2.set_ylabel("children dropped") + axes[1].invert_xaxis() + axes[1].grid(alpha=0.25) + 
lines1, labels1 = axes[1].get_legend_handles_labels() + lines2, labels2 = ax2.get_legend_handles_labels() + axes[1].legend(lines1 + lines2, labels1 + labels2, loc="best") + axes[1].set_title( + f"{level_result.level[0]} -> {level_result.level[1]} sweep" + ) + + fig.tight_layout() + return fig diff --git a/alphaquant/config/quant_reader_config.yaml b/alphaquant/config/quant_reader_config.yaml index 050149b6..ffb79612 100644 --- a/alphaquant/config/quant_reader_config.yaml +++ b/alphaquant/config/quant_reader_config.yaml @@ -863,6 +863,9 @@ spectronaut_fragion_ms1_gene: value: "True" use_iontree: True ml_level: CHARGE + variance_predictor_cols: + - EG.Cscore + - FG.ShapeQualityScore spectronaut_fragion_ms1_gene: @@ -912,6 +915,9 @@ spectronaut_fragion_ms1_gene: value: 5.0 use_iontree: True ml_level: CHARGE + variance_predictor_cols: + - EG.Cscore + - FG.ShapeQualityScore spectronaut_precursor_fragion_ms1: format: longtable @@ -971,6 +977,9 @@ spectronaut_precursor_fragion_ms1: value: "True" use_iontree: True ml_level: CHARGE + variance_predictor_cols: + - EG.Cscore + - FG.ShapeQualityScore spectronaut_precursor_fragion_ms1: @@ -1027,6 +1036,9 @@ spectronaut_precursor_fragion_ms1: value: 5.0 use_iontree: True ml_level: CHARGE + variance_predictor_cols: + - EG.Cscore + - FG.ShapeQualityScore spectronaut_precursor_fragion_ms1_protein: @@ -1083,6 +1095,9 @@ spectronaut_precursor_fragion_ms1_protein: value: 5.0 use_iontree: True ml_level: CHARGE + variance_predictor_cols: + - EG.Cscore + - FG.ShapeQualityScore spectronaut_fragion_ms1_protein: @@ -1132,6 +1147,9 @@ spectronaut_fragion_ms1_protein: value: 5.0 use_iontree: True ml_level: CHARGE + variance_predictor_cols: + - EG.Cscore + - FG.ShapeQualityScore annotation_columns: - PG.Genes @@ -1170,6 +1188,9 @@ spectronaut_fragion_protein: value: 5.0 use_iontree: True ml_level: CHARGE + variance_predictor_cols: + - EG.Cscore + - FG.ShapeQualityScore annotation_columns: - PG.Genes diff --git a/alphaquant/config/variables.py 
b/alphaquant/config/variables.py index 50a87308..4d8cf96c 100644 --- a/alphaquant/config/variables.py +++ b/alphaquant/config/variables.py @@ -5,7 +5,14 @@ PROGRESS_FOLDER = "progress" PREFER_PRECURSORS_FOR_CLUSTERING = True PEPTIDE_OUTLIER_FILTERING = True +PEPTIDE_OUTLIER_REGULATION_NORMALIZATION_FACTOR = 15.0 +OUTLIER_CORRECTION_FACTOR = 1.0 PTM_FRAGMENT_SELECTION = False +MAX_N_FRAGMENTS = None +ION_OUTLIER_MAD_THRESHOLD = None +CLASSIC_FRAGMENT_OUTLIER_FILTERING = False +ICC_NULL_PVAL_THRESHOLD = 0.1 +NUM_BG_CONTEXTS = 10 CONDITION_PAIR_SEPARATOR = "_VS_" #prefixes for the different ion types @@ -15,6 +22,10 @@ FRG = "FRG" ION = "ION" +INPUT_TYPE = None # e.g. "diann_precursor_fragion", set via set_input_config() +CONFIG_DICT = None # the full config dict for the detected input type + + def determine_variables(input_file, input_type): _determine_quant_id(input_file) _determine_prefer_precursors_for_clustering(input_type) @@ -36,18 +47,63 @@ def _determine_prefer_precursors_for_clustering(input_type): else: PREFER_PRECURSORS_FOR_CLUSTERING = False -def set_quant_id(quant_id): - global QUANT_ID - QUANT_ID = quant_id - def set_peptide_outlier_filtering(peptide_outlier_filtering): global PEPTIDE_OUTLIER_FILTERING PEPTIDE_OUTLIER_FILTERING = peptide_outlier_filtering +def set_outlier_correction_factor(outlier_correction_factor): + global OUTLIER_CORRECTION_FACTOR + OUTLIER_CORRECTION_FACTOR = float(outlier_correction_factor) + +def set_max_n_fragments(max_n_fragments): + global MAX_N_FRAGMENTS + MAX_N_FRAGMENTS = int(max_n_fragments) if max_n_fragments is not None else None + +def set_ion_outlier_mad_threshold(threshold): + global ION_OUTLIER_MAD_THRESHOLD + ION_OUTLIER_MAD_THRESHOLD = float(threshold) if threshold is not None else None + +def set_classic_fragment_outlier_filtering(enabled): + global CLASSIC_FRAGMENT_OUTLIER_FILTERING + CLASSIC_FRAGMENT_OUTLIER_FILTERING = bool(enabled) + +def set_icc_null_pval_threshold(threshold): + """Set the ICC null p-value 
threshold. + + Args: + threshold: Float applied to all ICC-estimation levels. + """ + global ICC_NULL_PVAL_THRESHOLD + ICC_NULL_PVAL_THRESHOLD = float(threshold) + + +def get_icc_null_pval_threshold(node_type=None): + """Return the ICC null p-value threshold for *node_type*. + + ``node_type`` is accepted for compatibility with callers that ask for the + level-specific value, but one global threshold is now used for all levels. + """ + return ICC_NULL_PVAL_THRESHOLD + def set_ptm_fragment_selection(is_ptm: bool): global PTM_FRAGMENT_SELECTION PTM_FRAGMENT_SELECTION = bool(is_ptm) -# Backwards-compat alias -def set_phospho_fragment_selection(is_phospho: bool): - set_ptm_fragment_selection(is_phospho) +def set_input_config(input_type, config_dict): + """Store the detected input type and its full config dict as module globals. + + Called once during pipeline setup so that other modules (e.g. + ``background_distributions``) can inspect the active configuration + without passing it through every function signature. + + Args: + input_type (str): Identifier of the detected input format + (e.g. ``"diann_precursor_fragion"``). + config_dict (dict): The complete YAML config dict for *input_type*. + + Side effects: + Sets ``INPUT_TYPE`` and ``CONFIG_DICT`` at module level. + """ + global INPUT_TYPE, CONFIG_DICT + INPUT_TYPE = input_type + CONFIG_DICT = config_dict diff --git a/alphaquant/diffquant/background_distributions.py b/alphaquant/diffquant/background_distributions.py index b02a7ec1..b7146e9a 100644 --- a/alphaquant/diffquant/background_distributions.py +++ b/alphaquant/diffquant/background_distributions.py @@ -11,50 +11,160 @@ from numba import njit from statistics import NormalDist import alphaquant.diffquant.diffutils as aqdiffutils - +import alphaquant.config.variables as aqvariables class ConditionBackgrounds(): + """Orchestrates background distribution calculation. + + For single-pool mode, delegates to one ``_BackgroundCalculation``. 
+ When ``split_by_ion_type=True`` and the config defines both fragment ions + and MS1 isotopes, runs a separate calculation per ion type and combines + the resulting per-ion dicts. + """ + + def __init__(self, normed_condition_df, p2z, ion2varscore=None, split_by_ion_type=False): + """Initialise condition backgrounds. - def __init__(self, normed_condition_df, p2z): - self.backgrounds = [] + Args: + normed_condition_df (pd.DataFrame): Normalised intensity matrix + (ions x samples) for one condition. + p2z (dict): Pre-computed p-value to z-value lookup. + ion2varscore (dict | None): Optional per-ion variance-predictor + scores used for sorting ions before background partitioning. + When None, ions are sorted by median intensity. + split_by_ion_type (bool): If True **and** the current config + defines both fragment ions and MS1 isotopes, build separate + background pools for each ion type. Defaults to False. + """ self.ion2background = {} self.ion2nonNanvals = {} self.ion2allvals = {} - self.idx2ion = {} - self.init_ion2nonNanvals(normed_condition_df) - self.context_ranges = [] - self.select_intensity_ranges(p2z) + + if split_by_ion_type and self._has_multiple_ion_types(): + self._build_split(normed_condition_df, p2z, ion2varscore) + else: + self._build_single(normed_condition_df, p2z, ion2varscore) self.all_intensities = np.concatenate(list(self.ion2nonNanvals.values())) self.num_replicates = len(next(iter(self.ion2allvals.values()))) + def _build_single(self, normed_condition_df, p2z, ion2varscore): + """Build backgrounds from a single pool of all ions.""" + calc = _BackgroundCalculation(normed_condition_df, p2z, ion2varscore) + self._update_backgrounds(calc) + + def _build_split(self, normed_condition_df, p2z, ion2varscore): + """Run a separate calculation for fragment ions and MS1 isotopes. + + Falls back to single-pool when either subset has fewer than 10 ions. 
+ """ + ion_type_group = self._split_by_ion_type(normed_condition_df.index) + if ion_type_group["FRGION"].sum() < 10 or ion_type_group["MS1ISOTOPES"].sum() < 10: + self._build_single(normed_condition_df, p2z, ion2varscore) + return + + for marker, mask in ion_type_group.items(): + sub_df = normed_condition_df.loc[mask].copy() + LOGGER.info(f"Building background for ion type '{marker}' ({len(sub_df)} ions)") + calc = _BackgroundCalculation(sub_df, p2z, ion2varscore=ion2varscore) + self._update_backgrounds(calc) + + def _update_backgrounds(self, calc): + """Merge results from a ``_BackgroundCalculation`` into this instance.""" + self.ion2background.update(calc.ion2background) + self.ion2nonNanvals.update(calc.ion2nonNanvals) + self.ion2allvals.update(calc.ion2allvals) + + @staticmethod + def _has_multiple_ion_types(): + """Return True if the config defines both fragment ions and MS1 isotopes.""" + if aqvariables.CONFIG_DICT is None: + return False + ion_hierarchy = aqvariables.CONFIG_DICT.get("ion_hierarchy", {}) + return "fragion" in ion_hierarchy and "ms1iso" in ion_hierarchy + + @staticmethod + def _split_by_ion_type(index): + """Split an index into fragment-ion and MS1-isotope subsets.""" + index_str = np.array([str(x) for x in index]) + frgion_mask = np.array(["FRGION" in s for s in index_str]) + ms1_mask = np.array(["MS1ISOTOPES" in s for s in index_str]) + return {"FRGION": frgion_mask, "MS1ISOTOPES": ms1_mask} + + +class _BackgroundCalculation(): + """Computes background distributions for a single pool of ions. + + Sorts the ions (by variance-predictor score or median intensity), + partitions them into overlapping intensity ranges, and creates a + BackGroundDistribution for each range. + + After construction the following dicts are populated: + ``ion2background``, ``ion2nonNanvals``, ``ion2allvals``. 
+ """ + + def __init__(self, normed_condition_df, p2z, ion2varscore=None): + self.ion2background = {} + self.ion2nonNanvals = {} + self.ion2allvals = {} + + self._sort_and_index(normed_condition_df, ion2varscore) + self._create_background_distributions(p2z) + def _sort_and_index(self, normed_condition_df, ion2varscore): + """Sort ions and build index-to-ion mappings. - def init_ion2nonNanvals(self, normed_condition_df): - normed_condition_df['median'] = normed_condition_df.median(numeric_only=True, axis=1) - normed_condition_df = normed_condition_df.sort_values(by='median').drop('median', axis=1) - self.normed_condition_df = normed_condition_df - #nonan_array = get_nonna_array(normed_condition_df.to_numpy()) - #self.ion2nonNanvals = dict(zip(normed_condition_df.index, nonan_array)) + When *ion2varscore* is provided, ions are sorted by their + variance-predictor score (rank-based). Otherwise, falls back to + sorting by row-wise median intensity. + + After sorting, populates ``ion2nonNanvals``, ``ion2allvals``, and + the private ``_idx2ion`` mapping. + + Args: + normed_condition_df (pd.DataFrame): Normalised intensity matrix + (ions x samples). + ion2varscore (dict | None): Mapping from ion id to a combined + variance-predictor score, or None for median-intensity sorting. 
+ """ + if ion2varscore is not None: + sort_scores = normed_condition_df.index.map( + lambda x: ion2varscore.get(x, 0.5) + ) + normed_condition_df = normed_condition_df.assign( + _sort_score=sort_scores + ).sort_values(by='_sort_score').drop('_sort_score', axis=1) + else: + normed_condition_df = normed_condition_df.assign( + median=normed_condition_df.median(numeric_only=True, axis=1) + ).sort_values(by='median').drop('median', axis=1) + self._normed_condition_df = normed_condition_df self.ion2nonNanvals = aqutils.get_non_nas_from_pd_df(normed_condition_df) self.ion2allvals = aqutils.get_ionints_from_pd_df(normed_condition_df) - self.idx2ion = dict(zip(range(len(normed_condition_df.index)), normed_condition_df.index)) + self._idx2ion = dict(zip(range(len(normed_condition_df.index)), normed_condition_df.index)) + def _create_background_distributions(self, p2z): + """Partition sorted ions into overlapping intensity ranges and build backgrounds. - def select_intensity_ranges(self, p2z): + Creates ``BackGroundDistribution`` objects for overlapping windows + of ions and assigns every ion to one of these distributions via + ``self.ion2background``. + + Args: + p2z (dict): Pre-computed p-value to z-value lookup (passed through + to ``BackGroundDistribution``). 
+ """ total_available_comparisons =0 - num_contexts = 10 - cumulative_counts = np.zeros(self.normed_condition_df.shape[0]) + num_contexts = aqvariables.NUM_BG_CONTEXTS + cumulative_counts = np.zeros(self._normed_condition_df.shape[0]) - for idx ,count in enumerate(self.normed_condition_df.count(axis=1)): + for idx ,count in enumerate(self._normed_condition_df.count(axis=1)): total_available_comparisons+=count-1 cumulative_counts[idx] = int(total_available_comparisons/2) - - #assign the context sizes context_size = np.max([1000, int(total_available_comparisons/(1+num_contexts/2))]) if context_size> total_available_comparisons: context_size = int(total_available_comparisons/2) @@ -64,15 +174,13 @@ def select_intensity_ranges(self, p2z): middle_idx = int(np.searchsorted(cumulative_counts, halfcontext_size)) end_idx = int(np.searchsorted(cumulative_counts, context_size)) - context_boundaries[0] = 0 context_boundaries[1] = middle_idx context_boundaries[2] = end_idx while context_boundaries[1] < len(cumulative_counts): - bgdist = BackGroundDistribution(context_boundaries[0], context_boundaries[2], self.ion2nonNanvals, self.idx2ion, p2z) - self.context_ranges.append([context_boundaries[0], context_boundaries[2]]) - self.assign_ions2bgdists(context_boundaries[0], context_boundaries[2], bgdist) - self.backgrounds.append(bgdist) + bgdist = BackGroundDistribution(context_boundaries[0], context_boundaries[2], self.ion2nonNanvals, self._idx2ion, p2z) + for idx in range(context_boundaries[0], context_boundaries[2]): + self.ion2background[self._idx2ion[idx]] = bgdist context_boundaries[0] = context_boundaries[1] context_boundaries[1] = context_boundaries[2] end_idx = np.searchsorted(cumulative_counts, context_size + cumulative_counts[context_boundaries[0]]) @@ -80,12 +188,6 @@ def select_intensity_ranges(self, p2z): end_idx = len(cumulative_counts) context_boundaries[2] = end_idx - def assign_ions2bgdists(self, boundaries1, boundaries2, bgdist): - ion2bg_local = {} 
#dict(map(lambda _idx : (self.normed_condition_df.index.values[_idx], bgdist), range(boundaries1, boundaries2))) - for idx in range(boundaries1, boundaries2): - ion2bg_local.update({self.idx2ion.get(idx) : bgdist}) - self.ion2background.update(ion2bg_local) - # Cell import numpy as np import random @@ -451,10 +553,6 @@ def get_doublediff_bg(deed_ion1, deed_ion2, deedpair2doublediffdist, p2z): return subtr_bg -def invert_deedkey(deedkey): - return (deedkey[1], deedkey[0]) - - # Cell from numba import njit @@ -509,12 +607,3 @@ def transform_cumulative_into_fc2count(cumulative, min_fc): fcs, counts = _transform_cumulative_vectorized(cumulative, min_fc) return dict(zip(fcs, counts)) -# Cell -@njit -def get_cumul_from_freq(freq): - res = np.zeros(len(freq), dtype=np.int64) - res[0] = freq[0] - for i in range(1,len(freq)): - res[i] = res[i-1] + freq[i] - - return res diff --git a/alphaquant/diffquant/condpair_analysis.py b/alphaquant/diffquant/condpair_analysis.py index 00ad54c3..f1a785a5 100644 --- a/alphaquant/diffquant/condpair_analysis.py +++ b/alphaquant/diffquant/condpair_analysis.py @@ -1,10 +1,12 @@ import alphaquant.diffquant.background_distributions as aqbg import alphaquant.diffquant.diff_analysis as aqdiff +import alphaquant.diffquant.intensity_summarization as aq_summarization import alphaquant.norm.normalization as aqnorm import alphaquant.plotting.pairwise as aq_plot_pairwise import alphaquant.diffquant.diffutils as aqutils import alphaquant.cluster.cluster_ions as aqclust import alphaquant.classify.classify_precursors as aq_class_precursors +import alphaquant.diffquant.variance_predictor as aq_variance_predictor import alphaquant.cluster.ml_reorder as aq_clust_mlreorder import alphaquant.tables.diffquant_table as aq_tablewriter_protein import alphaquant.tables.proteoformtable as aq_tablewriter_proteoform @@ -12,6 +14,7 @@ import alphaquant.cluster.cluster_utils as aqclust_utils import alphaquant.cluster.cluster_missingval as aq_clust_missingval import 
alphaquant.cluster.outlier_filtering as aq_clust_outlier +import alphaquant.cluster.residual_decorrelation as aq_clust_resid import pandas as pd import numpy as np @@ -66,10 +69,28 @@ def analyze_condpair(*,runconfig, condpair): df_c1_normed, df_c2_normed = aqnorm.normalize_if_specified(df_c1 = df_c1, df_c2 = df_c2, c1_samples = c1_samples, c2_samples = c2_samples, normalize_within_conds = runconfig.normalize, normalize_between_conds = runconfig.normalize, runtime_plots = runconfig.runtime_plots, protein_subset_for_normalization_file=runconfig.protein_subset_for_normalization_file, pep2prot = pep2prot)#, "./test_data/normed_intensities.tsv") + summarization_nodes = getattr(runconfig, 'summarization_nodes', []) + if summarization_nodes: + df_c1_normed, df_c2_normed, pep2prot = aq_summarization.apply_summarization( + df_c1_normed, df_c2_normed, pep2prot, summarization_nodes + ) + if runconfig.results_dir != None: write_out_normed_df(df_c1_normed, df_c2_normed, pep2prot, runconfig.results_dir, condpair) - normed_c1 = aqbg.ConditionBackgrounds(df_c1_normed, p2z) - normed_c2 = aqbg.ConditionBackgrounds(df_c2_normed, p2z) + + ion_index = df_c1_normed.index.union(df_c2_normed.index) + ion_variance = _compute_pooled_ion_variance(df_c1_normed, df_c2_normed, ion_index) + ion_median_intensity = _compute_pooled_median_intensity(df_c1_normed, df_c2_normed, ion_index) + if getattr(runconfig, 'use_variance_predictor', False): + ion2varscore = _load_variance_predictor_scores(runconfig, c1_samples, c2_samples, + ion_index, ion_variance, + ion_median_intensity) + else: + ion2varscore = None + + split_backgrounds_if_possible = getattr(runconfig, 'split_ion_backgrounds', False) + normed_c1 = aqbg.ConditionBackgrounds(df_c1_normed, p2z, ion2varscore=ion2varscore, split_by_ion_type=split_backgrounds_if_possible) + normed_c2 = aqbg.ConditionBackgrounds(df_c2_normed, p2z, ion2varscore=ion2varscore, split_by_ion_type=split_backgrounds_if_possible) ions_to_check = 
normed_c1.ion2nonNanvals.keys() & normed_c2.ion2nonNanvals.keys() ions_to_check = sorted(ions_to_check) @@ -113,7 +134,8 @@ def analyze_condpair(*,runconfig, condpair): clustered_prot_node = aqclust.get_scored_clusterselected_ions(prot, ions, normed_c1, normed_c2, bgpair2diffDist, p2z, deedpair2doublediffdist, pval_threshold_basis = runconfig.cluster_threshold_pval, fcfc_threshold = runconfig.cluster_threshold_fcfc, take_median_ion=runconfig.take_median_ion, fcdiff_cutoff_clustermerge= runconfig.fcdiff_cutoff_clustermerge, - fragment_outlier_filtering=runconfig.fragment_outlier_filtering) + aggregation_mode=runconfig.aggregation_mode, + cluster_threshold_ion_type=getattr(runconfig, "cluster_threshold_ion_type", 0.01)) protnodes.append(clustered_prot_node) if count_prots%100==0: @@ -121,6 +143,23 @@ def analyze_condpair(*,runconfig, condpair): count_prots+=1 + aggregation_mode = getattr(runconfig, "aggregation_mode", "stouffer_decorrelation") + uses_stouffer_decorrelation = ( + aggregation_mode == "stouffer_decorrelation" + or ( + isinstance(aggregation_mode, dict) + and "stouffer_decorrelation" in aggregation_mode.values() + ) + ) + if uses_stouffer_decorrelation: + aq_clust_resid.apply_residual_decorrelation( + protnodes, + df_c1_normed, + df_c2_normed, + tolerance=getattr(runconfig, "residual_decorrelation_tolerance", 0.10), + min_keep=getattr(runconfig, "residual_decorrelation_min_keep", 1), + aggregation_mode=runconfig.aggregation_mode, + ) if len(prot2missingval_diffions.keys())>0: LOGGER.info(f"start analysis of proteins w. 
completely missing values") @@ -146,7 +185,7 @@ def analyze_condpair(*,runconfig, condpair): if ml_successfull and (ml_performance_dict["r2_score"] >0.05): #only use the ml score if it is meaningful - aq_clust_mlreorder.update_nodes_w_ml_score(protnodes) + aq_clust_mlreorder.update_nodes_w_ml_score(protnodes, aggregation_mode=runconfig.aggregation_mode) LOGGER.info(f"ML based quality score above quality threshold and added to the nodes.") runconfig.ml_based_quality_score = True else: @@ -154,7 +193,7 @@ def analyze_condpair(*,runconfig, condpair): runconfig.ml_based_quality_score = False if runconfig.peptide_outlier_filtering: - aq_clust_outlier.apply_peptide_outlier_filtering(protnodes) + aq_clust_outlier.apply_peptide_outlier_filtering(protnodes, aggregation_mode=runconfig.aggregation_mode) protnodes_combined = protnodes + protnodes_missingval condpair_node = aqclust_utils.get_condpair_node(protnodes_combined, condpair) @@ -165,6 +204,93 @@ def analyze_condpair(*,runconfig, condpair): return res_df, pep_df +def _compute_pooled_ion_variance(df_c1_normed, df_c2_normed, ion_index): + """Compute pooled within-condition variance for each ion. + + For each ion, the variance is computed within each condition + (row-wise across replicate columns) and then averaged. Ions that + appear in only one condition get that condition's variance. + + Args: + df_c1_normed (pd.DataFrame): Normalised intensities for condition 1 + (ions x samples). + df_c2_normed (pd.DataFrame): Normalised intensities for condition 2. + ion_index (pd.Index): Union of ions from both conditions. + + Returns: + pd.Series: Per-ion pooled variance, indexed by ion id. + """ + var_c1 = df_c1_normed.var(axis=1) + var_c2 = df_c2_normed.var(axis=1) + pooled = pd.concat([var_c1, var_c2], axis=1).mean(axis=1) + return pooled.reindex(ion_index) + + +def _compute_pooled_median_intensity(df_c1_normed, df_c2_normed, ion_index): + """Compute per-ion median intensity pooled across conditions. 
+ + For each ion, the row-wise median is computed within each condition + and then averaged. + + Args: + df_c1_normed (pd.DataFrame): Normalised intensities for condition 1. + df_c2_normed (pd.DataFrame): Normalised intensities for condition 2. + ion_index (pd.Index): Union of ions from both conditions. + + Returns: + pd.Series: Per-ion pooled median intensity, indexed by ion id. + """ + med_c1 = df_c1_normed.median(axis=1) + med_c2 = df_c2_normed.median(axis=1) + pooled = pd.concat([med_c1, med_c2], axis=1).mean(axis=1) + return pooled.reindex(ion_index) + + +def _load_variance_predictor_scores(runconfig, c1_samples, c2_samples, + ion_index, ion_variance, + ion_median_intensity): + """Load quality metrics and fit a linear model predicting ion variance. + + Uses the quality-metric columns from the ml_info_table together with + the per-ion median intensity as predictors, and the observed per-ion + variance as the regression target. The predicted values serve as + scores for background-distribution ordering. + + Args: + runconfig: Pipeline run configuration object. Relevant attributes are + ``variance_predictor_cols`` (list[str] | None) and + ``ml_input_file`` (str | None). + c1_samples (list[str]): Sample names for condition 1. + c2_samples (list[str]): Sample names for condition 2. + ion_index (pd.Index): Ion identifiers to score. + ion_variance (pd.Series): Observed per-ion pooled variance. + ion_median_intensity (pd.Series): Per-ion pooled median intensity. + + Returns: + dict[str, float] | None: Mapping from ion id to predicted variance + score, or None when the configuration is missing or fitting fails + (triggering fallback to intensity-based sorting). 
+ """ + variance_predictor_cols = getattr(runconfig, 'variance_predictor_cols', None) + ml_input_file = getattr(runconfig, 'ml_input_file', None) + + if not variance_predictor_cols or not ml_input_file: + return None + + try: + return aq_variance_predictor.load_variance_predictor_scores( + ml_info_file=ml_input_file, + samples_used=c1_samples + c2_samples, + variance_predictor_cols=variance_predictor_cols, + ion_index=ion_index, + ion_variance=ion_variance, + ion_median_intensity=ion_median_intensity, + ) + except Exception as e: + LOGGER.warning("Failed to load variance predictor scores: %s. Falling back to intensity.", e) + return None + + import alphaquant.diffquant.diffutils as aqutils def get_unnormed_df_condpair(input_file:str, samplemap_df:pd.DataFrame, condpair:str, file_has_alphaquant_format: bool) -> pd.DataFrame: diff --git a/alphaquant/diffquant/diff_analysis.py b/alphaquant/diffquant/diff_analysis.py index 30b9453a..52552316 100644 --- a/alphaquant/diffquant/diff_analysis.py +++ b/alphaquant/diffquant/diff_analysis.py @@ -6,6 +6,7 @@ from scipy.stats import ttest_ind from scipy.stats import t as student_t import alphaquant.diffquant.diffutils as aqdiffutils +import alphaquant.config.variables as aqvariables class DifferentialIon(): """Computes differential statistics for an ion using empirical background distributions. 
@@ -188,7 +189,7 @@ def _calc_ttest_peptide(self, noNanvals_from, noNanvals_to, p2z, outlier_correct if outlier_correction and se_standard > 0 and n1 > 1 and n2 > 1: se_robust = _calc_robust_se_ttest(noNanvals_from, noNanvals_to) if se_robust > 0: - scaling_factor = max(1.0, min(5.0, se_robust / se_standard)) + scaling_factor = max(1.0, min(5.0, se_robust / se_standard * aqvariables.OUTLIER_CORRECTION_FACTOR)) t_adj = t_stat / scaling_factor p_val = 2.0 * float(student_t.sf(abs(t_adj), df)) @@ -207,6 +208,8 @@ def calc_outlier_scaling_factor(noNanvals_from, noNanvals_to, diffDist): (e.g., due to biological variability or technical outliers), the variance estimate is inflated accordingly. This makes the test more conservative when data quality is poor. + The result is further scaled by ``aqvariables.OUTLIER_CORRECTION_FACTOR`` (default 1.0). + Args: noNanvals_from: Log2 intensities from condition 1 noNanvals_to: Log2 intensities from condition 2 @@ -228,7 +231,7 @@ def calc_outlier_scaling_factor(noNanvals_from, noNanvals_to, diffDist): highest_SD_to = max(between_rep_SD_to, sd_to) highest_SD_combined = math.sqrt(highest_SD_from**2 + highest_SD_to**2) - scaling_factor = max(1.0, highest_SD_combined/diffDist.SD) + scaling_factor = max(1.0, highest_SD_combined/diffDist.SD * aqvariables.OUTLIER_CORRECTION_FACTOR) return scaling_factor def _robust_sd(x): diff --git a/alphaquant/diffquant/intensity_summarization.py b/alphaquant/diffquant/intensity_summarization.py new file mode 100644 index 00000000..2903a38f --- /dev/null +++ b/alphaquant/diffquant/intensity_summarization.py @@ -0,0 +1,326 @@ +"""Intensity summarization for hierarchical ion trees. + +Sums base-ion intensities at specified tree node types before differential +analysis. For example, specifying ``["frgion"]`` will sum all individual +fragment ions under each fragment-ion group into a single intensity per +replicate, while leaving MS1-isotope and precursor base ions untouched. 
+ +When the requested summarization level is above the ion-type level +(e.g. ``"mod_seq_charge"`` or ``"seq"``), leaves are split by their +ion-type ancestor so that fragment and MS1 intensities are never mixed. +""" + +import re +from collections import defaultdict + +import anytree +import numpy as np +import pandas as pd + +from alphaquant.cluster.cluster_ions import REGEX_FRGIONS_ISOTOPES, LEVEL_NAMES + +import alphaquant.config.config as aqconfig +import logging + +aqconfig.setup_logging() +LOGGER = logging.getLogger(__name__) + +ION_TYPE_NODES = {"frgion", "ms1_isotopes", "precursor"} + +# Appended to an ion_type node name so the result is a valid base-ion name +# that the downstream tree builder can parse back into the hierarchy. +_NODE_TYPE_TO_COMPLETION_SUFFIX = { + "frgion": "ION_SUM", + "ms1_isotopes": "ISOTOPES_SUM", + "precursor": "URSOR_SUM", +} + +# When summarising above ion_type level we insert a synthetic path fragment +# between the higher-level node name and the ion-type suffix. +_LEVEL_TO_SYNTHETIC_INFIX = { + "mod_seq_charge": "_", + "mod_seq": "CHARGE_0_", + "seq": "SUM_CHARGE_0_", +} + +_ION_TYPE_TO_FULL_SUFFIX = { + "frgion": "FRGION_SUM", + "ms1_isotopes": "MS1ISOTOPES_SUM", + "precursor": "PRECURSOR_SUM", +} + + +# --------------------------------------------------------------------------- +# Tree construction (lightweight, from ion-name strings only) +# --------------------------------------------------------------------------- + +def build_tree_from_ion_names(protein_name, ion_names): + """Build a hierarchical tree from base-ion name strings. + + Uses the same regex logic as + :func:`~alphaquant.cluster.cluster_ions.create_hierarchical_ion_grouping` + but operates on plain strings rather than ``DifferentialIon`` objects. 
+ """ + nodes = [ + anytree.Node(name, type="base", level="base") + for name in ion_names + ] + + for level_idx, level_patterns in enumerate(REGEX_FRGIONS_ISOTOPES): + name2node = {} + for pattern, node_type in level_patterns: + for node in nodes: + m = re.match(pattern, node.name) + if m: + matching_name = m.group(1) + if matching_name not in name2node: + name2node[matching_name] = anytree.Node( + matching_name, + type=node_type, + level=LEVEL_NAMES[level_idx], + ) + node.parent = name2node[matching_name] + if name2node: + nodes = list(name2node.values()) + + root = anytree.Node(protein_name, type="gene", level="gene") + for node in nodes: + node.parent = root + return root + + +# --------------------------------------------------------------------------- +# Naming helpers +# --------------------------------------------------------------------------- + +def _make_summarized_name_for_ion_type_node(node): + """Parseable summarized name for a node at the ion_type level.""" + return node.name + _NODE_TYPE_TO_COMPLETION_SUFFIX[node.type] + + +def _make_summarized_name_for_higher_node(parent_node, ion_type): + """Parseable summarized name when summarising above ion_type level. + + Inserts a synthetic path fragment so that the downstream tree builder + can still parse the resulting name into the full hierarchy. + """ + infix = _LEVEL_TO_SYNTHETIC_INFIX[parent_node.level] + suffix = _ION_TYPE_TO_FULL_SUFFIX[ion_type] + return parent_node.name + infix + suffix + + +# --------------------------------------------------------------------------- +# Grouping logic +# --------------------------------------------------------------------------- + +def compute_summarization_groups(pep2prot, ion_names, summarization_nodes): + """Determine which base ions to group and what to name the summaries. + + Args: + pep2prot: dict mapping ion name -> protein name. + ion_names: iterable of all base-ion names present in either condition. 
+ summarization_nodes: list of node types to summarize + (e.g. ``["frgion"]``). + + Returns: + groups: list of ``(new_name, [leaf_ion_names], protein)`` tuples. + remaining: set of ion names that stay as individual rows. + """ + if not summarization_nodes: + return [], set(ion_names) + + prot2ions = defaultdict(list) + for ion in ion_names: + prot = pep2prot.get(ion) + if prot is not None: + prot2ions[prot].append(ion) + + groups = [] + summarized_ions = set() + + for prot, ions in prot2ions.items(): + tree = build_tree_from_ion_names(prot, ions) + + for node_type in summarization_nodes: + target_nodes = anytree.findall( + tree, filter_=lambda n, nt=node_type: n.type == nt + ) + + for target_node in target_nodes: + if node_type in ION_TYPE_NODES: + leaf_names = [ + l.name for l in target_node.leaves if l.type == "base" + ] + if leaf_names: + new_name = _make_summarized_name_for_ion_type_node( + target_node + ) + groups.append((new_name, leaf_names, prot)) + summarized_ions.update(leaf_names) + else: + # Above ion_type: split by ion type to avoid mixing + type_to_leaves = defaultdict(list) + for desc in anytree.PreOrderIter(target_node): + if desc.type in ION_TYPE_NODES: + for leaf in desc.leaves: + if leaf.type == "base": + type_to_leaves[desc.type].append(leaf.name) + for ion_type, leaf_names in type_to_leaves.items(): + new_name = _make_summarized_name_for_higher_node( + target_node, ion_type + ) + groups.append((new_name, leaf_names, prot)) + summarized_ions.update(leaf_names) + + remaining = set(ion_names) - summarized_ions + return groups, remaining + + +# --------------------------------------------------------------------------- +# DataFrame summarization +# --------------------------------------------------------------------------- + +def summarize_condition_df(df, groups, remaining_ions): + """Apply summarization to a per-condition dataframe. + + Sums intensities in **linear** space for grouped ions, keeps remaining + ions as-is. 
+ + Args: + df: DataFrame with log2 intensities, index = quant_id, + columns = sample names. + groups: list of ``(new_name, [leaf_ion_names], protein)`` tuples. + remaining_ions: set of ion names to keep unchanged. + + Returns: + Summarized DataFrame (same column layout, modified index). + """ + parts = [] + + remaining_in_df = df.index.intersection(remaining_ions) + if len(remaining_in_df) > 0: + parts.append(df.loc[remaining_in_df]) + + for new_name, leaf_names, _prot in groups: + present = [ion for ion in leaf_names if ion in df.index] + if not present: + continue + subset = df.loc[present] + linear = 2.0 ** subset + summed = linear.sum(axis=0) + all_nan = subset.isna().all(axis=0) + with np.errstate(divide='ignore'): + log2_summed = np.log2(summed) + log2_summed[all_nan] = np.nan + log2_summed.name = new_name + parts.append(log2_summed.to_frame().T) + + if not parts: + return pd.DataFrame(columns=df.columns) + + return pd.concat(parts) + + +# --------------------------------------------------------------------------- +# Ion quality filtering per group +# --------------------------------------------------------------------------- + +def _filter_group_ions(leaf_names, df_c1, df_c2): + """Select which leaf ions to include in a summarization group. + + Strategy: + 1. Keep only ions that have values in ALL replicates of BOTH conditions. + 2. If no ion qualifies, pick the single ion with the most non-NaN values + across both conditions. + + Returns: + Filtered list of ion names. 
+ """ + present_c1 = [ion for ion in leaf_names if ion in df_c1.index] + present_c2 = [ion for ion in leaf_names if ion in df_c2.index] + present_both = set(present_c1) & set(present_c2) + + if not present_both: + all_present = set(present_c1) | set(present_c2) + if not all_present: + return [] + best_ion = max(all_present, key=lambda ion: ( + df_c1.loc[ion].notna().sum() if ion in df_c1.index else 0 + ) + ( + df_c2.loc[ion].notna().sum() if ion in df_c2.index else 0 + )) + return [best_ion] + + n_cols_c1 = df_c1.shape[1] + n_cols_c2 = df_c2.shape[1] + + complete = [ + ion for ion in present_both + if df_c1.loc[ion].notna().sum() == n_cols_c1 + and df_c2.loc[ion].notna().sum() == n_cols_c2 + ] + + if complete: + return complete + + best_ion = max(present_both, key=lambda ion: + df_c1.loc[ion].notna().sum() + df_c2.loc[ion].notna().sum() + ) + return [best_ion] + + +# --------------------------------------------------------------------------- +# Public entry point +# --------------------------------------------------------------------------- + +def apply_summarization(df_c1, df_c2, pep2prot, summarization_nodes): + """Summarize ion intensities at specified tree levels. + + This is the main entry point called from + :func:`~alphaquant.diffquant.condpair_analysis.analyze_condpair`. + + Args: + df_c1: Per-condition DataFrame for condition 1 (log2 intensities). + df_c2: Per-condition DataFrame for condition 2 (log2 intensities). + pep2prot: dict mapping ion name -> protein name. + summarization_nodes: list of node types to summarize + (e.g. ``["frgion"]``). 
+ + Returns: + ``(df_c1_new, df_c2_new, pep2prot_new)`` + """ + all_ions = set(df_c1.index) | set(df_c2.index) + groups, remaining = compute_summarization_groups( + pep2prot, all_ions, summarization_nodes + ) + + if not groups: + return df_c1, df_c2, pep2prot + + filtered_groups = [] + for new_name, leaf_names, prot in groups: + selected = _filter_group_ions(leaf_names, df_c1, df_c2) + if selected: + filtered_groups.append((new_name, selected, prot)) + + if not filtered_groups: + return df_c1, df_c2, pep2prot + + df_c1_new = summarize_condition_df(df_c1, filtered_groups, remaining) + df_c2_new = summarize_condition_df(df_c2, filtered_groups, remaining) + + pep2prot_new = {ion: pep2prot[ion] for ion in remaining if ion in pep2prot} + for new_name, _leaves, prot in filtered_groups: + pep2prot_new[new_name] = prot + + LOGGER.info( + "Summarization at %s: %d base ions -> %d entries " + "(%d summarized groups, %d unchanged)", + summarization_nodes, + len(all_ions), + len(remaining) + len(filtered_groups), + len(filtered_groups), + len(remaining), + ) + + return df_c1_new, df_c2_new, pep2prot_new diff --git a/alphaquant/diffquant/variance_predictor.py b/alphaquant/diffquant/variance_predictor.py new file mode 100644 index 00000000..2b91ac97 --- /dev/null +++ b/alphaquant/diffquant/variance_predictor.py @@ -0,0 +1,200 @@ +"""Variance-predictor scoring for background-distribution ion ordering. + +Loads quality-metric columns from the ml_info_table (at precursor level) +and fits a linear model predicting observed per-ion variance from these +metrics. The predicted variance is used to sort ions so that those with +similar expected variance are grouped together, producing tighter +empirical background distributions. + +The linear-regression approach automatically handles columns that +correlate with variance in different directions (positive or negative +coefficients) and weights each column by its predictive power. 
+""" + +import re +import logging + +import numpy as np +import pandas as pd + +import alphaquant.utils.reader_utils as aq_reader_utils + +LOGGER = logging.getLogger(__name__) + +_ION_SPLIT_PAT = re.compile(r"(_FRGION_|_MS1ISOTOPES_|_PRECURSOR_)") + + +def load_variance_predictor_scores(ml_info_file, samples_used, + variance_predictor_cols, ion_index, + ion_variance, ion_median_intensity=None): + """Fit a linear model predicting observed ion variance from quality + metrics and return the predicted scores for sorting. + + The ml_info_table is at precursor level (SEQ_MOD_CHARGE). Fragment-level + ion ids are mapped to their parent precursor via ``_ION_SPLIT_PAT``. + + A simple OLS regression is fitted:: + + ion_variance ~ median_intensity + col1 + col2 + ... + + Median intensity is always included as a built-in predictor (when + provided) because it is the strongest single correlate of ion-level + variance. The quality-metric columns from the config refine this + baseline. The predicted values serve as the combined score. + + Args: + ml_info_file (str): Path to the ml_info_table TSV. + samples_used (list[str] | None): Sample IDs to keep; None keeps all. + variance_predictor_cols (list[str]): Column names in the ml_info_table + to use as predictors. + ion_index (pd.Index): Ion identifiers to score. + ion_variance (pd.Series): Observed per-ion variance (indexed by ion id), + used as the regression target. + ion_median_intensity (pd.Series | None): Per-ion pooled median + intensity. When provided, prepended as a built-in predictor. + + Returns: + dict[str, float] | None: Mapping from ion id to predicted variance + score, or None if the required columns were not found or the + regression could not be fitted. 
+ """ + if not variance_predictor_cols: + return None + + try: + usecols = ["quant_id", "sample_ID"] + list(variance_predictor_cols) + ml_df = aq_reader_utils.read_file(ml_info_file, sep="\t", usecols=usecols) + except (ValueError, KeyError): + LOGGER.warning( + "Could not load variance predictor columns %s from %s. " + "Falling back to intensity-only background sorting.", + variance_predictor_cols, ml_info_file, + ) + return None + + if samples_used is not None: + ml_df = ml_df[ml_df["sample_ID"].isin(samples_used)] + ml_df = ml_df.drop(columns=["sample_ID"]) + + for col in variance_predictor_cols: + ml_df[col] = pd.to_numeric(ml_df[col], errors="coerce") + + available_cols = [c for c in variance_predictor_cols if c in ml_df.columns] + if not available_cols: + LOGGER.warning("None of the variance predictor columns found. Falling back.") + return None + + prec_features = ml_df.groupby("quant_id")[available_cols].median() + + ion_strings = np.array([str(x) for x in ion_index]) + precursor_ids = np.array([_ION_SPLIT_PAT.split(s)[0] for s in ion_strings]) + + scores = _fit_and_predict(prec_features, available_cols, + precursor_ids, ion_index, ion_variance, + ion_median_intensity) + if scores is None: + return None + + all_cols = (["median_intensity"] if ion_median_intensity is not None + else []) + available_cols + ion2varscore = dict(zip(ion_index, scores)) + LOGGER.info( + "Variance predictor scores computed for %d/%d ions using columns %s", + np.isfinite(scores).sum(), len(ion_index), all_cols, + ) + return ion2varscore + + +def _fit_and_predict(prec_features, available_cols, precursor_ids, + ion_index, ion_variance, + ion_median_intensity=None): + """Build feature matrix, fit OLS on observed variance, return predictions. + + Median intensity (when provided) is always prepended as the first + predictor column because it is typically the strongest correlate of + ion-level variance. The quality-metric columns refine the prediction. 
+ + Ions with missing features or missing variance are excluded from fitting + but receive the median predicted score as fallback. + + Args: + prec_features (pd.DataFrame): Precursor-level feature medians + (index = quant_id, columns = available_cols). + available_cols (list[str]): Column names to use as predictors. + precursor_ids (np.ndarray): Precursor id for each ion in ion_index. + ion_index (pd.Index): Ion identifiers. + ion_variance (pd.Series): Observed per-ion variance. + ion_median_intensity (pd.Series | None): Per-ion pooled median + intensity. When provided, used as a built-in first predictor. + + Returns: + np.ndarray of predicted scores (length = len(ion_index)), or None + if the regression cannot be fitted. + """ + n_ions = len(ion_index) + n_cols = len(available_cols) + + X = np.full((n_ions, n_cols), np.nan) + for j, col in enumerate(available_cols): + col_vals = prec_features[col] + X[:, j] = [col_vals.get(pid, np.nan) for pid in precursor_ids] + + if ion_median_intensity is not None: + intensity_col = np.array( + [ion_median_intensity.get(ion, np.nan) for ion in ion_index], + dtype=float, + ).reshape(-1, 1) + X = np.column_stack([intensity_col, X]) + all_col_names = ["median_intensity"] + list(available_cols) + else: + all_col_names = list(available_cols) + + n_features = X.shape[1] + + y = np.array([ion_variance.get(ion, np.nan) for ion in ion_index], + dtype=float) + + valid = np.isfinite(X).all(axis=1) & np.isfinite(y) + if valid.sum() < max(10, n_features + 1): + LOGGER.warning( + "Too few valid ions (%d) for variance predictor regression. 
" + "Falling back to intensity-only sorting.", valid.sum() + ) + return None + + X_fit = X[valid] + y_fit = y[valid] + + # Standardise features for numerical stability + X_mean = X_fit.mean(axis=0) + X_std = X_fit.std(axis=0) + X_std[X_std < 1e-15] = 1.0 + X_fit_z = (X_fit - X_mean) / X_std + + # OLS with intercept: y = X_z @ beta + intercept + X_design = np.column_stack([np.ones(X_fit_z.shape[0]), X_fit_z]) + try: + beta, _, _, _ = np.linalg.lstsq(X_design, y_fit, rcond=None) + except np.linalg.LinAlgError: + LOGGER.warning("OLS fit failed for variance predictor. Falling back.") + return None + + LOGGER.info( + "Variance predictor coefficients (standardised): %s", + dict(zip(all_col_names, beta[1:])), + ) + + # Predict for all ions (including those excluded from fit) + X_all_z = (X - X_mean) / X_std + X_all_design = np.column_stack([np.ones(n_ions), X_all_z]) + predicted = X_all_design @ beta + + # Fallback for ions with missing features + finite_mask = np.isfinite(predicted) + if finite_mask.any(): + median_pred = np.median(predicted[finite_mask]) + else: + median_pred = 0.0 + predicted = np.where(finite_mask, predicted, median_pred) + + return predicted diff --git a/alphaquant/plotting/fcviz.py b/alphaquant/plotting/fcviz.py index 1851d857..97759f0d 100644 --- a/alphaquant/plotting/fcviz.py +++ b/alphaquant/plotting/fcviz.py @@ -118,7 +118,10 @@ def __init__(self, label_rotation = 90, add_stripplot = False, narrowing_factor_ shortened_xticklabels = False, remove_leaf_labels_in_tree = False, hide_root_in_tree = False, - exclude_outlier_fragments = True): + exclude_outlier_fragments = True, + highlight_excluded_nodes = True, + show_excluded_node_counts = True, + show_exclusion_legend = True): """ Configuration class for plotting. @@ -139,6 +142,10 @@ def __init__(self, label_rotation = 90, add_stripplot = False, narrowing_factor_ exclude_outlier_fragments (bool): Whether to exclude fragment ions marked as outliers from plots. 
When True (default), only fragments used in statistical aggregation are displayed. Mirrors the fragment_outlier_filtering behavior from the analysis pipeline. + highlight_excluded_nodes (bool): Whether tree plots should highlight nodes excluded from aggregation. + show_excluded_node_counts (bool): Whether visible parents should annotate how many hidden descendants + were excluded from aggregation. + show_exclusion_legend (bool): Whether tree plots should include a legend for exclusion highlighting. """ self.label_rotation = label_rotation self.add_stripplot = add_stripplot @@ -158,6 +165,9 @@ def __init__(self, label_rotation = 90, add_stripplot = False, narrowing_factor_ self.remove_leaf_labels_in_tree = remove_leaf_labels_in_tree self.hide_root_in_tree = hide_root_in_tree self.exclude_outlier_fragments = exclude_outlier_fragments + self.highlight_excluded_nodes = highlight_excluded_nodes + self.show_excluded_node_counts = show_excluded_node_counts + self.show_exclusion_legend = show_exclusion_legend # Node annotation configuration self.show_node_annotations = show_node_annotations @@ -180,6 +190,9 @@ def __init__(self, label_rotation = 90, add_stripplot = False, narrowing_factor_ 'min_reps': 'reps={}', 'fraction_consistent': 'cons={:.2f}', 'is_included': 'incl={}', + 'exclude_residual_decorrelation': 'decorr excl={}', + 'is_outlier_fragment': 'frag excl={}', + 'is_outlier_peptide': 'pep excl={}', 'missingval': 'miss={}' } else: diff --git a/alphaquant/plotting/multicond.py b/alphaquant/plotting/multicond.py index e69de29b..8705c9f4 100644 --- a/alphaquant/plotting/multicond.py +++ b/alphaquant/plotting/multicond.py @@ -0,0 +1,41 @@ +import matplotlib.pyplot as plt +import pandas as pd +import numpy as np + + +def plot_proteoform_intensity_profiles(proteoform_df: pd.DataFrame): + """Creates a panel of proteoform intensity profiles for each protein in the dataframe. 
+ + Args: + proteoform_df: proteoform dataframe loaded from the AlphaQuant output file `medianref_proteoforms.tsv` + + Returns: + fig, axes: matplotlib figure and axes objects + """ + fixed_columns = ['proteoform_id', 'peptides', 'number_of_peptides', 'protein', 'corr_to_ref', 'is_reference'] + conditions = proteoform_df.columns.difference(fixed_columns) + + grouped = proteoform_df.groupby('protein') + n_plots = len(grouped) + + n_cols = int(np.ceil(np.sqrt(n_plots))) + n_rows = int(np.ceil(n_plots / n_cols)) + + fig, axes = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 4 * n_rows)) + axes = axes.flatten() + + for ax, (protein, sub_df) in zip(axes, grouped): + for idx, row in sub_df.iterrows(): + style = '-' if row['is_reference'] else '--' + ax.plot(conditions, row[conditions], style, label=row['proteoform_id']) + ax.set_title(f'{protein}') + ax.set_xlabel('Tissue') + ax.set_ylabel('Expression Level') + ax.legend(title='Proteoform ID') + + for i in range(n_plots, len(axes)): + axes[i].axis('off') + + fig.tight_layout() + plt.show() + return fig, axes diff --git a/alphaquant/plotting/pairwise.py b/alphaquant/plotting/pairwise.py index 4fc97817..c56fc66a 100644 --- a/alphaquant/plotting/pairwise.py +++ b/alphaquant/plotting/pairwise.py @@ -11,6 +11,32 @@ LOGGER = logging.getLogger(__name__) +def plot_normalization_overview(normed_df, samplemap_df): + normed_df, sample2cond = aq_diff_utils.prepare_loaded_tables(normed_df, samplemap_df) + sample2cond = dict(zip(samplemap_df["sample"], samplemap_df["condition"])) + conditions = list(set([sample2cond.get(x) for x in normed_df.columns])) + conditions = [x for x in conditions if x is not None] + df_c1 = normed_df[[x for x in normed_df.columns if sample2cond.get(x) == conditions[0]]] + df_c2 = normed_df[[x for x in normed_df.columns if sample2cond.get(x) == conditions[1]]] + + plot_betweencond_fcs(df_c1, df_c2, merge_samples=True) + plot_sample_vs_median_fcs(df_c1, df_c2) + + +def 
plot_sample_vs_median_fcs(df_c1_normed, df_c2_normed): + combined_median = pd.concat([df_c1_normed, df_c2_normed], axis=1).median(axis=1, skipna=True) + fig, axes = plt.subplots() + for df in [df_c1_normed, df_c2_normed]: + for col in df.columns: + diff_fcs = df[col].subtract(combined_median) + axes.axvline(0, color='red', linestyle="dashed") + cutoff = max(abs(np.nanquantile(diff_fcs, 0.025)), abs(np.nanquantile(diff_fcs, 0.975))) + axes.hist(diff_fcs, 80, density=True, histtype='step', range=(-cutoff, cutoff), label=col) + axes.set_xlabel("log2(fc)") + axes.legend(bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0.) + return fig, axes + + def plot_withincond_normalization(df_c1, df_c2): LOGGER.info("without missingvals (if applicable)") plot_betweencond_fcs(aqnorm.drop_nas_if_possible(df_c1), aqnorm.drop_nas_if_possible(df_c2), True) diff --git a/alphaquant/plotting/treeviz.py b/alphaquant/plotting/treeviz.py index 63d2515e..7642ec74 100644 --- a/alphaquant/plotting/treeviz.py +++ b/alphaquant/plotting/treeviz.py @@ -4,7 +4,9 @@ import networkx as nx import anytree import re +import shlex from matplotlib import gridspec +from matplotlib.lines import Line2D import matplotlib.pyplot as plt import alphaquant.cluster.cluster_utils as aqcluster_utils import alphaquant.plotting.base_functions as aqviz @@ -77,7 +79,7 @@ def _define_colorlist(self): self._colorlist_hex = [aqviz.rgb_to_hex(x) for x in self._plotconfig.colorlist] def _format_graph(self): - pos = nx.drawing.nx_agraph.graphviz_layout(self.graph, **self._graph_parameters.layout_params) + pos = _graphviz_layout(self.graph, self._graph_parameters.layout_params) root_id = id(self._protein) hide_root = getattr(self._plotconfig, 'hide_root_in_tree', False) @@ -87,14 +89,8 @@ def _format_graph(self): for node in nodes_to_draw: matching_anynode = self._id2anytree_node[node] - is_included = matching_anynode.is_included - if not is_included: - self._graph_parameters.node_options["alpha"] = 
self._graph_parameters.alpha_excluded - self._graph_parameters.node_options["node_color"] = self._determine_cluster_color(matching_anynode) - # Allow overriding node size from plotconfig - if hasattr(self._plotconfig, 'node_size') and self._plotconfig.node_size is not None: - self._graph_parameters.node_options["node_size"] = self._plotconfig.node_size - nx.draw_networkx_nodes(self.graph, pos, nodelist=[node], ax=self._ax, **self._graph_parameters.node_options) + node_options = self._get_node_options(matching_anynode) + nx.draw_networkx_nodes(self.graph, pos, nodelist=[node], ax=self._ax, **node_options) label_dict = nx.get_node_attributes(self.graph, 'label') @@ -112,9 +108,10 @@ def _format_graph(self): rotation = self._plotconfig.label_rotation if len(matching_anynode.children) == 0 else 0 self._ax.text(x, y, labelstring, verticalalignment='center', horizontalalignment='center', fontsize=self._plotconfig.node_fontsize, family='monospace', - weight = "bold", rotation = rotation) + weight = "bold", rotation = rotation, color=self._get_label_color(matching_anynode)) nx.draw_networkx_edges(self.graph, pos, edgelist=edges_to_draw, ax=self._ax, **self._graph_parameters.edge_options) + self._add_exclusion_legend(nodes_to_draw) # Add vertical padding to avoid cutting labels at top/bottom and hide axis frame try: @@ -133,25 +130,97 @@ def _format_graph(self): def _determine_cluster_color(self, anynode): return self._colorlist_hex[anynode.cluster] + def _get_node_options(self, anynode): + node_options = dict(self._graph_parameters.node_options) + if hasattr(self._plotconfig, 'node_size') and self._plotconfig.node_size is not None: + node_options["node_size"] = self._plotconfig.node_size + + node_options["node_color"] = self._determine_cluster_color(anynode) + node_options["alpha"] = self._graph_parameters.alpha_included + node_options["edgecolors"] = self._graph_parameters.included_edge_color + node_options["linewidths"] = self._graph_parameters.included_linewidth + + if 
self._should_highlight_exclusions() and self._is_directly_excluded(anynode): + node_options["node_color"] = self._graph_parameters.excluded_color + node_options["alpha"] = self._graph_parameters.alpha_excluded + node_options["edgecolors"] = self._graph_parameters.excluded_edge_color + node_options["linewidths"] = self._graph_parameters.excluded_linewidth + + return node_options + + def _get_label_color(self, anynode): + if self._should_highlight_exclusions() and self._is_directly_excluded(anynode): + return self._graph_parameters.excluded_label_color + return self._graph_parameters.included_label_color + + def _should_highlight_exclusions(self): + return getattr(self._plotconfig, "highlight_excluded_nodes", True) + + @classmethod + def _is_directly_excluded(cls, anynode): + return ( + not getattr(anynode, "is_included", True) + or getattr(anynode, "exclude_residual_decorrelation", False) + or getattr(anynode, "is_outlier_fragment", False) + or getattr(anynode, "is_outlier_peptide", False) + ) + + def _add_exclusion_legend(self, nodes_to_draw): + if not ( + self._should_highlight_exclusions() + and getattr(self._plotconfig, "show_exclusion_legend", True) + ): + return + + drawn_anynodes = [self._id2anytree_node[node] for node in nodes_to_draw] + has_direct = any(self._is_directly_excluded(node) for node in drawn_anynodes) + if not has_direct: + return + + legend_handles = [] + if has_direct: + legend_handles.append( + Line2D( + [0], + [0], + marker="o", + linestyle="", + markerfacecolor=self._graph_parameters.excluded_color, + markeredgecolor=self._graph_parameters.excluded_edge_color, + markeredgewidth=self._graph_parameters.excluded_linewidth, + markersize=8, + label="excluded from aggregation", + ) + ) + + self._ax.legend( + handles=legend_handles, + loc="upper right", + frameon=False, + fontsize=max(7, self._plotconfig.node_fontsize - 3), + ) - @staticmethod - def render_tree(root): - for pre, _, node in anytree.RenderTree(root): - print("%s%s" % (pre, 
node.name)) class GraphParameters(): def __init__(self): self.included_color = "skyblue" - self.excluded_color = "lightgrey" + self.excluded_color = "#D9D9D9" + self.included_edge_color = "#404040" + self.excluded_edge_color = "#808080" + self.included_label_color = "#202020" + self.excluded_label_color = "#333333" self.alpha_included = 0.6 # More transparent nodes - self.alpha_excluded = 0.3 # More transparent excluded nodes + self.alpha_excluded = 0.8 + self.included_linewidth = 1 + self.excluded_linewidth = 2.0 self.node_options = { "node_color": self.included_color, "node_size": 1500, - "linewidths": 1, + "linewidths": self.included_linewidth, + "edgecolors": self.included_edge_color, "alpha": self.alpha_included, # default alpha } @@ -171,6 +240,38 @@ def __init__(self): } +def _graphviz_layout(graph, layout_params): + try: + return nx.drawing.nx_agraph.graphviz_layout(graph, **layout_params) + except ImportError: + return _graphviz_plain_layout(graph, layout_params) + + +def _graphviz_plain_layout(graph, layout_params): + import graphviz + + prog = layout_params.get("prog", "dot") + graph_attr = {} + for token in shlex.split(layout_params.get("args", "")): + if token.startswith("-G") and "=" in token: + key, value = token[2:].split("=", 1) + graph_attr[key] = value + + dot = graphviz.Digraph(engine=prog, graph_attr=graph_attr) + for node in graph.nodes: + dot.node(str(node)) + for parent, child in graph.edges: + dot.edge(str(parent), str(child)) + + plain = dot.pipe(format="plain").decode("utf-8") + pos = {} + for line in plain.splitlines(): + parts = line.split() + if len(parts) >= 4 and parts[0] == "node": + pos[int(parts[1])] = (float(parts[2]) * 72.0, float(parts[3]) * 72.0) + return pos + + class TreeLabelFormatter: @classmethod @@ -252,8 +353,17 @@ def get_annotation_lines(cls, node, plotconfig): formatted = f"{attr}={value}" annotations.append(formatted) + return annotations + @classmethod + def get_exclusion_annotation_lines(cls, node): + """Return 
compact annotations for aggregation-excluded nodes.""" + if GraphCreator._is_directly_excluded(node): + return ["excluded"] + + return [] + class AnnotatedGraphCreator(GraphCreator): """Enhanced GraphCreator that supports configurable node annotations.""" @@ -264,7 +374,7 @@ def __init__(self, protein, ax, plotconfig): def _format_graph(self): """Override _format_graph to use the enhanced label formatter.""" - pos = nx.drawing.nx_agraph.graphviz_layout(self.graph, **self._graph_parameters.layout_params) + pos = _graphviz_layout(self.graph, self._graph_parameters.layout_params) root_id = id(self._protein) hide_root = getattr(self._plotconfig, 'hide_root_in_tree', False) @@ -274,17 +384,8 @@ def _format_graph(self): for node in nodes_to_draw: matching_anynode = self._id2anytree_node[node] - is_included = matching_anynode.is_included - if not is_included: - self._graph_parameters.node_options["alpha"] = self._graph_parameters.alpha_excluded - else: - self._graph_parameters.node_options["alpha"] = self._graph_parameters.alpha_included - - self._graph_parameters.node_options["node_color"] = self._determine_cluster_color(matching_anynode) - # Allow overriding node size from plotconfig - if hasattr(self._plotconfig, 'node_size') and self._plotconfig.node_size is not None: - self._graph_parameters.node_options["node_size"] = self._plotconfig.node_size - nx.draw_networkx_nodes(self.graph, pos, nodelist=[node], ax=self._ax, **self._graph_parameters.node_options) + node_options = self._get_node_options(matching_anynode) + nx.draw_networkx_nodes(self.graph, pos, nodelist=[node], ax=self._ax, **node_options) label_dict = nx.get_node_attributes(self.graph, 'label') @@ -316,9 +417,10 @@ def _format_graph(self): self._ax.text(x, y, labelstring, verticalalignment='center', horizontalalignment='center', fontsize=fontsize, family='monospace', weight="bold", - rotation=rotation) + rotation=rotation, color=self._get_label_color(matching_anynode)) nx.draw_networkx_edges(self.graph, pos, 
edgelist=edges_to_draw, ax=self._ax, **self._graph_parameters.edge_options) + self._add_exclusion_legend(nodes_to_draw) # Add vertical padding to avoid cutting labels at top/bottom and hide axis frame try: @@ -389,6 +491,3 @@ def define_tree_fig_and_ax(self): fig_width = min(max(8, num_leaves * 1.3),100) fig_height = max(8, max_depth * 2) self.fig, self.ax_tree = plt.subplots(figsize=(fig_width, fig_height)) - - - diff --git a/alphaquant/ptm/ptmsite_mapping.py b/alphaquant/ptm/ptmsite_mapping.py index 4a2fdcb3..f26d7ac0 100644 --- a/alphaquant/ptm/ptmsite_mapping.py +++ b/alphaquant/ptm/ptmsite_mapping.py @@ -9,6 +9,8 @@ aqconfig.setup_logging() LOGGER = logging.getLogger(__name__) import alphaquant.config.variables as aq_variables +import io +import zipfile #helper classes headers_dicts = {'Spectronaut' : {"label_column" : "R.Label", "fg_id_column" : "FG.Id", 'sequence' : "PEP.StrippedSequence", 'proteins' : "PG.UniProtIds", 'precursor_mz' : "FG.PrecMz", "precursor_charge" : "FG.Charge", @@ -563,26 +565,6 @@ def get_site_prob_overview(modpeps, refprot, refgene): return series_collected -# Cell -def add_ptmsite_infos_spectronaut(input_df, ptm_ids_df): - intersect_columns = input_df.columns.intersection(ptm_ids_df.columns) - if(len(intersect_columns)==2): - LOGGER.info(f"assigning ptms based on columns {intersect_columns}") - input_df = input_df.merge(ptm_ids_df, on=list(intersect_columns), how= 'left') - else: - raise Exception(f"Number of intersecting columns {intersect_columns} not as expected") - input_df = add_ptm_precursor_names_spectronaut(input_df) - input_df = input_df[~input_df["conditions"].isna()] - return input_df - -# Cell -def add_ptm_precursor_names_spectronaut(ptm_annotated_input): - delimiter = pd.Series(["_" for x in range(len(ptm_annotated_input.index))]) - ptm_annotated_input[QUANT_ID] = ptm_annotated_input["PEP.StrippedSequence"] + delimiter + ptm_annotated_input["FG.PrecMz"].astype('str') + delimiter + 
ptm_annotated_input["FG.Charge"].astype('str') + delimiter + ptm_annotated_input["REFPROT"] + delimiter +ptm_annotated_input["site"].astype('str') - ptm_annotated_input.gene.fillna('', inplace=True) - ptm_annotated_input["site_id"] = ptm_annotated_input["gene"].astype('str')+delimiter+ptm_annotated_input["REFPROT"].astype('str') + delimiter +ptm_annotated_input["site"].astype('str') - return ptm_annotated_input - # Cell def filter_input_table(input_type, modification_type,input_df): if input_type == "Spectronaut": @@ -699,7 +681,7 @@ def merge_ptmsite_mappings_write_table(spectronaut_file, mapped_df, modification # Write deduplicated result LOGGER.info(f"Writing deduplicated PTM table with {len(deduplicated_df)} rows to {ptmmapped_table_filename}") - deduplicated_df.to_csv(ptmmapped_table_filename, sep='\t', index=False) + write_dataframe_to_single_file_zip(deduplicated_df, ptmmapped_table_filename) else: # Write chunks directly for non-Spectronaut data (DIANN, etc.) @@ -767,7 +749,17 @@ def get_ptmmapped_filename(spectronaut_file): foldername = os.path.dirname(spectronaut_file_abspath) filename = os.path.basename(spectronaut_file_abspath) filename_reduced = filename.replace(".tsv", "") - return f"{foldername}/{filename_reduced}.ptmsite_mapped.tsv" #this file is not written to the progress folder + return f"{foldername}/{filename_reduced}.ptmsite_mapped.tsv.zip" #this file is not written to the progress folder + + +def write_dataframe_to_single_file_zip(df, zip_filename, sep="\t", index=False): + archive_name = os.path.basename(zip_filename) + if archive_name.endswith(".zip"): + archive_name = archive_name[:-4] + with zipfile.ZipFile(zip_filename, mode="w", compression=zipfile.ZIP_DEFLATED) as zf: + with zf.open(archive_name, mode="w", force_zip64=True) as buffer: + with io.TextIOWrapper(buffer, encoding="utf-8", newline="") as text: + df.to_csv(text, sep=sep, index=index) @@ -777,9 +769,15 @@ def add_ptmsite_info_to_subtable(spectronaut_df, labelid2ptmid, 
labelid2site, mo spectronaut_df = spectronaut_df[[x in labelid2ptmid.keys() for x in spectronaut_df["labelid"]]].copy() #drop peptides that have no ptm spectronaut_df["ptm_id"] = np.array([labelid2ptmid.get(x) for x in spectronaut_df["labelid"]]) #add the ptm_id row to the spectronaut table - modseq_typereplaced = np.array([str(x.replace(modification_type, "")) for x in spectronaut_df["EG.ModifiedSequence"]]) #EG.ModifiedSequence already determines a localization of the modification type. Replace all localizations and add the new localizations below - sites = np.array([str(labelid2site.get(x)) for x in spectronaut_df["labelid"]]) - spectronaut_df["ptm_mapped_modseq"] = np.char.add(modseq_typereplaced, sites) + modseq_typereplaced = pd.Series( + [str(x).replace(modification_type, "") for x in spectronaut_df["EG.ModifiedSequence"]], + index=spectronaut_df.index, + ) #EG.ModifiedSequence already determines a localization of the modification type. Replace all localizations and add the new localizations below + sites = pd.Series( + [str(labelid2site.get(x)) for x in spectronaut_df["labelid"]], + index=spectronaut_df.index, + ) + spectronaut_df["ptm_mapped_modseq"] = modseq_typereplaced + sites return spectronaut_df @@ -791,86 +789,3 @@ def get_ptmid_mappings(mapped_df): labelid2ptmid = dict(zip(labelid, ptm_ids)) labelid2site = dict(zip(labelid, site)) return labelid2ptmid, labelid2site - - - - -# Detect Changes in site occupancy - -import pandas as pd -import numpy as np - - -def initialize_ptmsite_df(ptmsite_file, samplemap_file): - """returns ptmsite_df, samplemap_df from files""" - samplemap_df, _ = initialize_sample2cond(samplemap_file) - ptmsite_df = pd.read_csv(ptmsite_file, sep = "\t") - return ptmsite_df, samplemap_df - -def detect_site_occupancy_change(cond1, cond2, ptmsite_df ,samplemap_df, min_valid_values = 2, threshold_prob = 0.05): - """ - uses a PTMsite df with headers "REFPROT", "gene","site", and headers for sample1, sample2, etc and determines - 
whether a site appears/dissappears between conditions based on some probability threshold - """ - - ptmsite_df["site_id"] = ptmsite_df["REFPROT"] + ptmsite_df["site"].astype("str") - ptmsite_df = ptmsite_df.set_index("site_id") - cond1_samples = list(set(samplemap_df[(samplemap_df["condition"]==cond1)]["sample"]).intersection(set(ptmsite_df.columns))) - cond2_samples = list(set(samplemap_df[(samplemap_df["condition"]==cond2)]["sample"]).intersection(set(ptmsite_df.columns))) - - ptmsite_df = ptmsite_df[cond1_samples + cond2_samples + ["REFPROT", "gene", "site"]] - filtvec = [(sum(~np.isnan(x))>0) for _, x in ptmsite_df[cond1_samples + cond2_samples].iterrows()] - ptmsite_df = ptmsite_df[filtvec] - ptmsite_df = ptmsite_df.sort_index() - - regulated_sites = [] - count = 0 - for ptmsite in ptmsite_df.index.unique(): - - site_df = ptmsite_df.loc[[ptmsite]] - if count%1000 ==0: - num_checks = len(ptmsite_df.index.unique()) - LOGGER.info(f"{count} of {num_checks} {count/num_checks :.2f}") - count+=1 - - cond1_vals = site_df[cond1_samples].to_numpy() - cond2_vals = site_df[cond2_samples].to_numpy() - - cond1_vals = cond1_vals[~np.isnan(cond1_vals)] - cond2_vals = cond2_vals[~np.isnan(cond2_vals)] - - numrep_c1 = len(cond1_vals) - numrep_c2 = len(cond2_vals) - - if(numrep_c11-threshold_prob - likely_c2 = cond2_prob>1-threshold_prob - direction = 0 - - if(unlikely_c1&likely_c2): - direction = -1 - if(unlikely_c2&likely_c1): - direction = 1 - - if direction!=0: - LOGGER.info("occpancy change detected") - refprot = site_df["REFPROT"].values[0] - gene = site_df["gene"].values[0] - site = site_df["site"].values[0] - regulated_sites.append([refprot, gene, site, direction, cond1_prob, cond2_prob, numrep_c1, numrep_c2]) - - - df_occupancy_change = pd.DataFrame(regulated_sites, columns=["REFPROT", "gene", "site", "direction", "c1_meanprob", "c2_meanprob", "c1_nrep", "c2_nrep"]) - return df_occupancy_change - - - - - diff --git a/alphaquant/quant_reader/table_reformatter.py 
b/alphaquant/quant_reader/table_reformatter.py index 9c3ebf0d..6dc17d8f 100644 --- a/alphaquant/quant_reader/table_reformatter.py +++ b/alphaquant/quant_reader/table_reformatter.py @@ -72,12 +72,11 @@ def merge_protein_cols_and_config_dict( def join_columns(df, columns, separator="_"): if len(columns) == 1: - return df[columns[0]].fillna("nan").infer_objects(copy=False).astype(str) + return df[columns[0]].fillna("nan").astype(str) else: return ( df[columns] .fillna("nan") - .infer_objects(copy=False) .astype(str) .agg(separator.join, axis=1) ) diff --git a/alphaquant/run_pipeline.py b/alphaquant/run_pipeline.py index e55653d4..a73e1f8d 100644 --- a/alphaquant/run_pipeline.py +++ b/alphaquant/run_pipeline.py @@ -46,22 +46,27 @@ def run_pipeline(input_file: str, condpairs_list: Optional[List[Tuple[str, str]]] = None, file_has_alphaquant_format: bool = False, min_valid_values: int = 2, - valid_values_filter_mode: str = "either", #options: "either", "and", "per_condition" + valid_values_filter_mode: str = "either", #options: "either", "both", "per_condition" min_valid_values_c1: int = 0, min_valid_values_c2: int = 0, min_num_ions: int = 1, minpep: int = 1, organism: Optional[str] = None, cluster_threshold_pval: float = 0.001, + cluster_threshold_ion_type: float = 0.01, cluster_threshold_fcfc: float = 0, fcdiff_cutoff_clustermerge = 0.5, use_ml: bool = True, + residual_decorrelation_tolerance: float = 0.10, + residual_decorrelation_min_keep: int = 1, + aggregation_mode: Union[str, dict] = "stouffer_decorrelation", take_median_ion: bool = True, perform_ptm_mapping: bool = False, perform_phospho_inference: bool = False, enable_experimental_ptm_counting_statistics: bool = False, ptm_fragment_selection: bool = False, outlier_correction: bool = True, + outlier_correction_factor: float = 1.0, normalize: bool = True, use_iontree_if_possible: bool = True, write_out_results_tree: bool = True, @@ -76,8 +81,14 @@ def run_pipeline(input_file: str, peptides_to_exclude_file: 
Optional[str] = None, reset_progress_folder: bool = False, peptide_outlier_filtering: bool = True, - fragment_outlier_filtering: bool = True, + max_n_fragments: Optional[int] = None, + ion_outlier_mad_threshold: Optional[float] = None, + classic_fragment_outlier_filtering: bool = False, + split_ion_backgrounds: bool = True, + use_variance_predictor: bool = False, + num_bg_contexts: int = 10, ion_test_method: str = 'diffdist', + summarization_nodes: Optional[List[str]] = None, minrep_both: Optional[int] = None, #deprecated minrep_either: Optional[int] = None, #deprecated minrep_c1: Optional[int] = None, #deprecated @@ -107,16 +118,32 @@ def run_pipeline(input_file: str, min_num_ions (int): Minimum number of ions required per peptide. Defaults to 1. minpep (int): Minimum number of peptides required per protein. Defaults to 1. organism (str): Organism name for PTM mapping (e.g., 'human', 'mouse'). Required if perform_ptm_mapping is True. - cluster_threshold_pval (float): P-value threshold for statistical clustering. Defaults to 0.001. + cluster_threshold_pval (float): P-value threshold for statistical clustering at the protein/gene level. Defaults to 0.001. + cluster_threshold_ion_type (float): P-value threshold for clustering at the fragment/MS1 isotope (ion_type) level. Defaults to 0.01. cluster_threshold_fcfc (float): Fold change threshold for clustering. Defaults to 0. fcdiff_cutoff_clustermerge (float): Fold change difference cutoff for merging peptide clusters. Defaults to 0.5. use_ml (bool): Enable machine learning analysis. Defaults to True. + residual_decorrelation_tolerance (float): Maximum allowed one-sided excess-CDF distance between corrected and null sibling-correlation distributions. Defaults to 0.10. + residual_decorrelation_min_keep (int): Minimum number of children to retain per parent during residual decorrelation pruning. Defaults to 1. 
+ aggregation_mode (str | dict): Strategy for combining child z-values at the fragment/MS1 level + (where ions show intra-group dependencies). Higher levels always use Stouffer. + Can be a single string (applied to all dependent levels) or a dict mapping node types + to modes (e.g. ``{"frgion": "min_median_max_z", "ms1_isotopes": "median_z"}``). + - "stouffer_decorrelation" (default): Stouffer's method after residual-decorrelation pruning. + - "mean_z": Arithmetic mean of z-values (conservative, treats children as one measurement). + - "median_z": Median z-value (robust to outlier children). + - "min_median_max_z": Combine min, median, max z-values assuming independence (3-point summary). + - "min_max_z": Combine min, max z-values assuming independence (2-point summary). + - "summed_z": Classic Stouffer assuming independence (rho=0, ignores ICC correction). take_median_ion (bool): Use median-centered fragment ions for peptide comparisons. Defaults to True. perform_ptm_mapping (bool): Enable PTM site mapping analysis. Defaults to False. perform_phospho_inference (bool): Enable phosphorylation-prone region annotation. Defaults to False. enable_experimental_ptm_counting_statistics (bool): Allow experimental PTM counting statistics with "either" mode or zero min_valid_values. Defaults to False. ptm_fragment_selection (bool): If True, enable PTM-oriented fragment selection in clustering. outlier_correction (bool): Enable outlier correction in differential testing. Defaults to True. + outlier_correction_factor (float): Multiplicative factor for the outlier correction scaling. + Values > 1.0 make the correction more aggressive (more conservative p-values), + values < 1.0 make it less aggressive. Only effective when outlier_correction is True. Defaults to 1.0. normalize (bool): Enable sample and condition normalization. Defaults to True. use_iontree_if_possible (bool): Use ion tree structure when available. Defaults to True. 
write_out_results_tree (bool): Write results in hierarchical tree format. Defaults to True. @@ -131,13 +158,42 @@ def run_pipeline(input_file: str, peptides_to_exclude_file (str): File listing peptides to exclude (e.g., shared between species). reset_progress_folder (bool): Clear and recreate the progress folder. Defaults to False. peptide_outlier_filtering (bool): Enable few peptides per protein filtering for statistical outlier correction. When True, filters outlier peptides based on significance distribution within the protein/gene. Defaults to True. - fragment_outlier_filtering (bool): Enable fragment outlier filtering when aggregating fragments to peptides. When True, removes extreme fragments before statistical aggregation. Defaults to True. + max_n_fragments (int or None): Maximum number of fragment ions to keep per peptide when aggregating. + When set, only the fragments with z-values closest to the median are retained; the rest are + discarded. None (default) means no limit. + ion_outlier_mad_threshold (float or None): MAD-based outlier threshold for base ion z-values. + When set, ions whose z-value deviates more than ``threshold * MAD`` from the sibling median + are removed before aggregation (requires >= 4 siblings; always keeps >= 2). None (default) + disables this filter. Typical values: 2.5 – 3.0. + classic_fragment_outlier_filtering (bool): Use the legacy fragment outlier filter that keeps only + the 4 most central fragment ions (by z-value) when a peptide has more than 4 fragments. + Applied after MAD / max_n_fragments filtering. Defaults to False. + split_ion_backgrounds (bool): Build separate empirical background distributions for fragment ions + and MS1 isotopes instead of pooling them together. Defaults to True. + use_variance_predictor (bool): Use a linear regression model to predict + ion variance from quality metrics (e.g. Cscore, ShapeQualityScore) for + sorting ions before background partitioning. 
When False, ions are sorted + by median intensity instead. Defaults to False. + num_bg_contexts (int): Number of overlapping windows used to partition + ions into empirical background distributions. Higher values create + more fine-grained intensity-dependent backgrounds. Defaults to 10. ion_test_method (str): Ion-level test to compute ion statistics. Options: - "diffdist" (default): Use empirical background distributions (DifferentialIon). - "ttest": Use Welch two-sample t-test (DifferentialIonTTest), p→z via cached fast inversion. + summarization_nodes (list[str]): Node types at which to sum child intensities + before differential analysis. For each specified node type the base-ion + leaves are collected and their linear intensities summed per replicate. + Ion types (frgion, ms1_isotopes, precursor) are never mixed. + Examples: ``["frgion"]`` sums fragments per precursor; + ``["frgion", "ms1_isotopes"]`` sums both; ``["mod_seq_charge"]`` sums + everything per precursor (split by ion type). Defaults to None (no + summarization). """ LOGGER.info("Starting AlphaQuant") + if summarization_nodes is None: + summarization_nodes = [] + ######################################################### # TODO: this backwards compatibility can be removed beginning of 2026 # to ensure backwards compatibility: in case the minrep paramters are set, we need to convert them to the min_valid_values and valid_values_filter_mode parameters @@ -168,16 +224,17 @@ def run_pipeline(input_file: str, if file_has_alphaquant_format: LOGGER.info("Input file is already in AlphaQuant format. 
Skipping reformatting.") input_file_reformat = input_file_original - # For pre-formatted files, use a generic input type that doesn't require specific columns input_type = input_type_to_use if input_type_to_use is not None else "generic_preformatted" annotation_file = None - use_ml = False # Disable ML for pre-formatted files - # Skip to the main analysis + use_ml = False + variance_predictor_cols = None else: create_progress_folder_if_applicable(input_file_original, reset_progress_folder) input_type, config_dict, _ = config_dict_loader.get_input_type_and_config_dict(input_file_original, input_type_to_use) annotation_file = load_annotation_file(input_file_original, input_type, annotation_columns) use_ml = check_if_table_supports_ml(config_dict) & use_ml + variance_predictor_cols = config_dict.get("variance_predictor_cols", None) + aqvariables.set_input_config(input_type, config_dict) if perform_ptm_mapping and not file_has_alphaquant_format: if modification_type is None: @@ -218,8 +275,13 @@ def run_pipeline(input_file: str, aqvariables.determine_variables(input_file_reformat, input_type) aqvariables.set_peptide_outlier_filtering(peptide_outlier_filtering) + aqvariables.set_outlier_correction_factor(outlier_correction_factor) + aqvariables.NUM_BG_CONTEXTS = num_bg_contexts # Configure PTM-specific fragment selection: enabled if either PTM mapping is performed or explicit flag is set aqvariables.set_ptm_fragment_selection(perform_ptm_mapping or ptm_fragment_selection) + aqvariables.set_max_n_fragments(max_n_fragments) + aqvariables.set_ion_outlier_mad_threshold(ion_outlier_mad_threshold) + aqvariables.set_classic_fragment_outlier_filtering(classic_fragment_outlier_filtering) #use runconfig object to store the parameters runconfig = ConfigOfRunPipeline(locals()) #all the parameters given into the function are transfered to the runconfig object! 
The runconfig is then used as the input for the run_analysis functions @@ -309,10 +371,14 @@ def check_if_table_supports_ml(config_dict): return is_longtable and ml_level_charge def load_ml_info_file(input_file, input_type, modification_type = None): - ml_info_filename = aq_utils.get_progress_folder_filename(input_file, f".ml_info_table.tsv") + ml_info_filename = aq_utils.get_progress_folder_filename(input_file, f".ml_info_table.tsv.zip") + old_ml_info_filename = aq_utils.get_progress_folder_filename(input_file, f".ml_info_table.tsv") if os.path.exists(ml_info_filename):#in case there already is a reformatted file, we don't need to reformat it again LOGGER.info(f"ML info file already exists. Using ML info file of type {input_type}") return ml_info_filename + elif os.path.exists(old_ml_info_filename): + LOGGER.info(f"Uncompressed ML info file already exists. Using ML info file of type {input_type}") + return old_ml_info_filename else: return aq_ml_info_table.MLInfoTableCreator(input_file, input_type, modification_type).ml_info_filename @@ -353,4 +419,3 @@ def run_analysis_multiprocess(condpair_combinations, runconfig, num_cores): aqcondpair.analyze_condpair(runconfig= runconfig, condpair = condpair) ,condpair_combinations) - diff --git a/alphaquant/tables/alphadia_reader.py b/alphaquant/tables/alphadia_reader.py index f7f421ec..63be1412 100644 --- a/alphaquant/tables/alphadia_reader.py +++ b/alphaquant/tables/alphadia_reader.py @@ -26,17 +26,23 @@ def __init__(self, fragment_matrix_file: str): fragment_matrix_file (str): Path to the fragment matrix file """ - self.ml_info_file = aq_utils.get_progress_folder_filename(fragment_matrix_file, ".ml_info_table.tsv") + self.ml_info_file = aq_utils.get_progress_folder_filename(fragment_matrix_file, ".ml_info_table.tsv.zip") + self.old_ml_info_file = aq_utils.get_progress_folder_filename(fragment_matrix_file, ".ml_info_table.tsv") self.input_file_reformat = aq_utils.get_progress_folder_filename(fragment_matrix_file, 
".alphadia_fragion.aq_reformat.tsv", remove_extension=False) precursor_file = os.path.join(os.path.dirname(fragment_matrix_file), "precursors.tsv") self._precursor_df = aq_reader_utils.read_file(precursor_file, sep="\t") self._precursor2quantID = self._precursor2quantid() - if not os.path.exists(self.ml_info_file): + if os.path.exists(self.old_ml_info_file) and not os.path.exists(self.ml_info_file): + self.ml_info_file = self.old_ml_info_file + LOGGER.info(f"ML info file already exists at {self.ml_info_file}") + elif not os.path.exists(self.ml_info_file): LOGGER.info(f"Creating ML info file") self.ml_info_df = self._define_ml_info_table() - self.ml_info_df.to_csv(self.ml_info_file, sep="\t", index=False) + archive_name = os.path.basename(self.ml_info_file).removesuffix(".zip") + compression = {"method": "zip", "archive_name": archive_name} + self.ml_info_df.to_csv(self.ml_info_file, sep="\t", index=False, compression=compression) else: LOGGER.info(f"ML info file already exists at {self.ml_info_file}") diff --git a/alphaquant/ui/dashboard_parts_run_pipeline.py b/alphaquant/ui/dashboard_parts_run_pipeline.py index 608637f6..31b4d889 100644 --- a/alphaquant/ui/dashboard_parts_run_pipeline.py +++ b/alphaquant/ui/dashboard_parts_run_pipeline.py @@ -15,8 +15,6 @@ import alphaquant.run_pipeline as diffmgr import alphaquant.config.variables as aq_variables import alphaquant.ui.dashboad_parts_plots_basic as dashboad_parts_plots_basic -import alphaquant.ui.dashboard_parts_plots_proteoforms as dashboad_parts_plots_proteoforms -import alphaquant.ui.gui as gui import alphaquant.ui.gui_textfields as gui_textfields import alphaquant.utils.reader_utils as aq_reader_utils @@ -374,6 +372,14 @@ def _make_widgets(self): description='Fold change threshold for highlighting significant changes in the volcano plot' ) + self.aggregation_mode = pn.widgets.Select( + name='Z-value aggregation mode:', + options=['stouffer_decorrelation', 'mean_z', 'median_z', 'min_median_max_z'], + 
value='stouffer_decorrelation', + width=300, + description='Strategy for combining child z-values during tree propagation' + ) + self.condition_comparison_header = pn.pane.Markdown( "### Available Condition Comparisons", visible=True @@ -402,7 +408,7 @@ def _make_widgets(self): name='Enable machine learning', value=True, width=300 - ), + ), 'take_median_ion': pn.widgets.Checkbox( name='Use median-centered ions', value=True, @@ -447,6 +453,11 @@ def _make_widgets(self): name='Generate runtime plots', value=True, width=300 + ), + 'split_ion_backgrounds': pn.widgets.Checkbox( + name='Separate backgrounds by ion type', + value=True, + width=300 ), 'peptide_outlier_filtering': pn.widgets.Checkbox( name='Use few peptides per protein', @@ -466,6 +477,7 @@ def _make_widgets(self): 'write_out_results_tree': pn.pane.Markdown('Save detailed results in a tree structure'), 'use_multiprocessing': pn.pane.Markdown('Use multiple CPU cores to speed up processing (may use more memory)'), 'runtime_plots': pn.pane.Markdown('Create plots during analysis to visualize the process'), + 'split_ion_backgrounds': pn.pane.Markdown('Build separate empirical backgrounds for fragment ions and MS1 isotopes to reduce conservative bias'), 'peptide_outlier_filtering': pn.pane.Markdown('Filter outlier peptides based on significance for proteins with gene-level nodes'), } @@ -586,6 +598,9 @@ def create_checkbox_with_description(key, checkbox): self.minpep, self.cluster_threshold_pval, pn.layout.Divider(), + "### Aggregation", + self.aggregation_mode, + pn.layout.Divider(), "### Analysis Options", *checkbox_items, ), @@ -806,6 +821,7 @@ def _run_pipeline(self, *events): "min_valid_values_c2": self.min_valid_values_c2.value if self.valid_values_filter_mode.value == 'set min. 
valid values per condition' else None, # Add the switch values to the pipeline parameters 'use_ml': self.switches['use_ml'].value, + 'aggregation_mode': self.aggregation_mode.value, 'take_median_ion': self.switches['take_median_ion'].value, 'perform_ptm_mapping': self.switches['perform_ptm_mapping'].value, 'perform_phospho_inference': self.switches['perform_phospho_inference'].value, @@ -815,6 +831,7 @@ def _run_pipeline(self, *events): 'write_out_results_tree': self.switches['write_out_results_tree'].value, 'use_multiprocessing': self.switches['use_multiprocessing'].value, 'runtime_plots': self.switches['runtime_plots'].value, + 'split_ion_backgrounds': self.switches['split_ion_backgrounds'].value, 'peptide_outlier_filtering': self.switches['peptide_outlier_filtering'].value, } @@ -1470,63 +1487,3 @@ def _update_tabs(self, event=None): self.main_tabs[1] = ('Plotting', pn.pane.Markdown( f"### Visualization Error\n\n{error_msg}" )) - - -def build_dashboard(): - """Build the overall dashboard layout.""" - # Create state manager first - state_manager = gui.DashboardState() - - header = HeaderWidget( - title="AlphaQuant Dashboard", - img_folder_path="./assets", - github_url="https://github.com/" - ) - main_text = MainWidget( - description=( - "Welcome to our analysis dashboard. " - "Please load your data and run the pipeline." 
- ), - manual_path="path/to/manual.pdf" - ) - - # Create pipeline instance with state manager - pipeline = RunPipeline(state=state_manager) - pipeline_layout = pipeline.create() - - # Create plotting tabs with state manager - plotting_tab = dashboad_parts_plots_basic.PlottingTab(state=state_manager) - proteoform_tab = dashboad_parts_plots_proteoforms.ProteoformPlottingTab(state=state_manager) - - # Register subscribers - state_manager.register_subscriber(plotting_tab) - state_manager.register_subscriber(proteoform_tab) - - # Create tabs - all_tabs = pn.Tabs( - ('Pipeline', pipeline_layout), - ('Single Comparison', plotting_tab.panel()), - ('Plotting', proteoform_tab.panel()), - dynamic=True, - tabs_location='above', - sizing_mode='stretch_width' - ) - - # Main layout - main_layout = pn.Column( - header.create(), - pn.layout.Divider(), - main_text.create(), - all_tabs, - sizing_mode='stretch_width' - ) - - template = pn.template.FastListTemplate( - title="AlphaQuant Analysis", - sidebar=[], - main=[main_layout], - theme='dark', - main_max_width="1200px", - main_layout="width" - ) - return template diff --git a/tests/unit_tests/test_background_distributions.py b/tests/unit_tests/test_background_distributions.py index 743fab95..c77fe48b 100644 --- a/tests/unit_tests/test_background_distributions.py +++ b/tests/unit_tests/test_background_distributions.py @@ -27,8 +27,16 @@ def fixed_input(sample2cond_df): def background_distributions(fixed_input): """Create background distributions for testing caching""" condbg = aq_diff_bg.ConditionBackgrounds(fixed_input, {}) - # Get a few different background distributions for testing - bg_list = list(condbg.backgrounds[:5]) # Get first 5 backgrounds + # Collect unique BackGroundDistribution objects from ion2background + seen = set() + bg_list = [] + for bg in condbg.ion2background.values(): + bg_id = id(bg) + if bg_id not in seen: + seen.add(bg_id) + bg_list.append(bg) + if len(bg_list) >= 5: + break return bg_list def 
test_condition_backgrounds(fixed_input): @@ -195,12 +203,10 @@ def test_cache_key_uniqueness_across_different_distributions(self, fixed_input): condbg1 = aq_diff_bg.ConditionBackgrounds(fixed_input, {}) condbg2 = aq_diff_bg.ConditionBackgrounds(fixed_input, {}) # Different instance with same data - # Get some backgrounds from each - bg1_from_condbg1 = condbg1.backgrounds[0] - bg1_from_condbg2 = condbg2.backgrounds[0] + # Get one background from each + bg1_from_condbg1 = next(iter(condbg1.ion2background.values())) + bg1_from_condbg2 = next(iter(condbg2.ion2background.values())) - # Even though they're created from the same data, they should have different keys - # (since they represent different object instances) key1 = bg1_from_condbg1.get_cache_key() key2 = bg1_from_condbg2.get_cache_key() @@ -231,3 +237,97 @@ def test_cache_key_efficiency_and_reliability(self, background_distributions): for bg_key in cache_key: assert isinstance(bg_key, tuple), "Each background key should be a tuple" assert len(bg_key) == 6, "Each background key should have 6 elements" + + +# --------------------------------------------------------------------------- +# Tests for _split_by_ion_type, _has_multiple_ion_types, and split background building +# (new on add_summarization_approach branch) +# --------------------------------------------------------------------------- +import alphaquant.config.variables as aqvariables + + +class TestSplitByIonType: + def test_correct_masks(self): + index = pd.Index([ + "PEP1_FRGION_y3", + "PEP1_FRGION_y4", + "PEP2_MS1ISOTOPES_0", + "PEP2_MS1ISOTOPES_1", + ]) + result = aq_diff_bg.ConditionBackgrounds._split_by_ion_type(index) + assert list(result["FRGION"]) == [True, True, False, False] + assert list(result["MS1ISOTOPES"]) == [False, False, True, True] + + def test_all_frgion(self): + index = pd.Index(["A_FRGION_y1", "B_FRGION_b2"]) + result = aq_diff_bg.ConditionBackgrounds._split_by_ion_type(index) + assert result["FRGION"].all() + assert not 
result["MS1ISOTOPES"].any() + + def test_empty_index(self): + index = pd.Index([]) + result = aq_diff_bg.ConditionBackgrounds._split_by_ion_type(index) + assert len(result["FRGION"]) == 0 + assert len(result["MS1ISOTOPES"]) == 0 + + +class TestHasMultipleIonTypes: + def setup_method(self): + self._original_config = aqvariables.CONFIG_DICT + + def teardown_method(self): + aqvariables.CONFIG_DICT = self._original_config + + def test_true_when_both_present(self): + aqvariables.CONFIG_DICT = { + "ion_hierarchy": {"fragion": {}, "ms1iso": {}} + } + assert aq_diff_bg.ConditionBackgrounds._has_multiple_ion_types() is True + + def test_false_when_only_fragion(self): + aqvariables.CONFIG_DICT = { + "ion_hierarchy": {"fragion": {}} + } + assert aq_diff_bg.ConditionBackgrounds._has_multiple_ion_types() is False + + def test_false_when_config_is_none(self): + aqvariables.CONFIG_DICT = None + assert aq_diff_bg.ConditionBackgrounds._has_multiple_ion_types() is False + + +class TestConditionBackgroundsSplit: + """Integration-level tests: verify split vs. 
single-pool mode.""" + + def _make_mixed_df(self, n_frg=50, n_ms1=50, n_samples=4): + rng = np.random.RandomState(42) + frg_names = [f"PEP{i}_FRGION_y{j}" for i in range(n_frg) for j in [1]] + ms1_names = [f"PEP{i}_MS1ISOTOPES_{j}" for i in range(n_ms1) for j in [0]] + all_names = frg_names + ms1_names + data = 10 + rng.randn(len(all_names), n_samples) + cols = [f"S{i}" for i in range(n_samples)] + return pd.DataFrame(data, index=all_names, columns=cols) + + def setup_method(self): + self._original_config = aqvariables.CONFIG_DICT + + def teardown_method(self): + aqvariables.CONFIG_DICT = self._original_config + + def test_split_mode_assigns_all_ions(self): + aqvariables.CONFIG_DICT = { + "ion_hierarchy": {"fragion": {}, "ms1iso": {}} + } + df = self._make_mixed_df() + cb = aq_diff_bg.ConditionBackgrounds(df, {}, split_by_ion_type=True) + assert set(cb.ion2background.keys()) == set(df.index) + + def test_single_pool_mode_assigns_all_ions(self): + df = self._make_mixed_df() + cb = aq_diff_bg.ConditionBackgrounds(df, {}, split_by_ion_type=False) + assert set(cb.ion2background.keys()) == set(df.index) + + def test_split_false_when_config_missing(self): + aqvariables.CONFIG_DICT = None + df = self._make_mixed_df() + cb = aq_diff_bg.ConditionBackgrounds(df, {}, split_by_ion_type=True) + assert set(cb.ion2background.keys()) == set(df.index) diff --git a/tests/unit_tests/test_cluster_missingval.py b/tests/unit_tests/test_cluster_missingval.py new file mode 100644 index 00000000..50622f4d --- /dev/null +++ b/tests/unit_tests/test_cluster_missingval.py @@ -0,0 +1,110 @@ +import anytree +import pytest + +import alphaquant.cluster.cluster_missingval as aq_missingval + + +def _reset_global(): + """Reset the module-level test-level global before each test.""" + aq_missingval.MISSINGVAL_TEST_LEVEL = None + + +# --------------------------------------------------------------------------- +# Helper: build minimal tree structures for each scenario +# 
--------------------------------------------------------------------------- + +def _tree_with_mod_seq_charge(): + """gene -> seq -> mod_seq -> mod_seq_charge -> frgion -> base""" + root = anytree.Node("gene1", type="gene") + seq = anytree.Node("seq1", parent=root, type="seq") + mod = anytree.Node("mod1", parent=seq, type="mod_seq") + msc = anytree.Node("msc1", parent=mod, type="mod_seq_charge") + frg = anytree.Node("frg1", parent=msc, type="frgion") + anytree.Node("base1", parent=frg, type="base") + anytree.Node("base2", parent=frg, type="base") + return root + + +def _tree_mod_seq_above_leaves(): + """gene -> seq -> mod_seq -> base (no mod_seq_charge)""" + root = anytree.Node("gene1", type="gene") + seq = anytree.Node("seq1", parent=root, type="seq") + mod = anytree.Node("mod1", parent=seq, type="mod_seq") + anytree.Node("base1", parent=mod, type="base") + anytree.Node("base2", parent=mod, type="base") + return root + + +def _tree_seq_above_leaves(): + """gene -> seq -> base (precursor-only, no mod_seq)""" + root = anytree.Node("gene1", type="gene") + seq = anytree.Node("seq1", parent=root, type="seq") + anytree.Node("base1", parent=seq, type="base") + anytree.Node("base2", parent=seq, type="base") + return root + + +def _tree_gene_above_leaves(): + """gene -> base (simplest hierarchy)""" + root = anytree.Node("gene1", type="gene") + anytree.Node("base1", parent=root, type="base") + anytree.Node("base2", parent=root, type="base") + return root + + +# --------------------------------------------------------------------------- +# Tests for determine_missingval_test_level +# --------------------------------------------------------------------------- + +class TestDetermineMissingvalTestLevel: + def setup_method(self): + _reset_global() + + def test_mod_seq_charge_tree(self): + root = _tree_with_mod_seq_charge() + aq_missingval.determine_missingval_test_level(root) + assert aq_missingval.MISSINGVAL_TEST_LEVEL == "mod_seq_charge" + + def test_mod_seq_above_leaves(self): 
+ root = _tree_mod_seq_above_leaves() + aq_missingval.determine_missingval_test_level(root) + assert aq_missingval.MISSINGVAL_TEST_LEVEL == "base" + + def test_seq_above_leaves(self): + root = _tree_seq_above_leaves() + aq_missingval.determine_missingval_test_level(root) + assert aq_missingval.MISSINGVAL_TEST_LEVEL == "base" + + def test_gene_above_leaves(self): + root = _tree_gene_above_leaves() + aq_missingval.determine_missingval_test_level(root) + assert aq_missingval.MISSINGVAL_TEST_LEVEL == "base" + + def test_unexpected_structure_raises(self): + root = anytree.Node("root", type="unknown_top") + anytree.Node("leaf", parent=root, type="base") + with pytest.raises(ValueError, match="Unexpected tree structure"): + aq_missingval.determine_missingval_test_level(root) + + +class TestGetNodesToTest: + """Tests for MissingValProtNodeCreator._get_nodes_to_test.""" + + def setup_method(self): + _reset_global() + + def test_returns_mod_seq_charge_nodes_when_present(self): + root = _tree_with_mod_seq_charge() + nodes = aq_missingval.MissingValProtNodeCreator._get_nodes_to_test(root) + assert all(n.type == "mod_seq_charge" for n in nodes) + + def test_returns_leaves_for_mod_seq_tree(self): + root = _tree_mod_seq_above_leaves() + nodes = aq_missingval.MissingValProtNodeCreator._get_nodes_to_test(root) + assert all(n.type == "base" for n in nodes) + assert set(n.name for n in nodes) == {"base1", "base2"} + + def test_returns_leaves_for_gene_tree(self): + root = _tree_gene_above_leaves() + nodes = aq_missingval.MissingValProtNodeCreator._get_nodes_to_test(root) + assert all(n.type == "base" for n in nodes) diff --git a/tests/unit_tests/test_clusterutils.py b/tests/unit_tests/test_clusterutils.py index 001ad134..b9f2caea 100644 --- a/tests/unit_tests/test_clusterutils.py +++ b/tests/unit_tests/test_clusterutils.py @@ -66,43 +66,6 @@ def test_remove_unnecessary_attributes(): -def test_traverse_and_add_included_leaves_anytree(): - # Constructing the tree - root = 
anytree.Node("root", is_included=True, cluster=0) - node1 = anytree.Node("node1", parent=root, is_included=True, cluster=0) - node2 = anytree.Node("node2", parent=root, is_included=True, cluster=0) - leaf1 = anytree.Node("leaf1", parent=node1, is_included=True, cluster=0) - leaf2 = anytree.Node("leaf2", parent=node1, is_included=False, cluster=1) - leaf3 = anytree.Node("leaf3", parent=node2, is_included=True, cluster=0) - - list_of_included_leaves = [] - aq_clust_clusterutils.traverse_and_add_included_leaves(root, list_of_included_leaves) - print(list_of_included_leaves) - # Assert conditions - assert leaf1 in list_of_included_leaves, "leaf1 is missing from the result." - assert leaf3 in list_of_included_leaves, "leaf3 is missing from the result." - assert len(list_of_included_leaves) == 2, "The number of included leaves is incorrect." - - - root = anytree.Node("root", is_included=True, cluster=0) - node1 = anytree.Node("node1", parent=root, is_included=False, cluster=1) - node2 = anytree.Node("node2", parent=root, is_included=True, cluster=0) - leaf1 = anytree.Node("leaf1", parent=node1, is_included=True, cluster=0) - leaf2 = anytree.Node("leaf2", parent=node1, is_included=False, cluster=1) - leaf3 = anytree.Node("leaf3", parent=node2, is_included=True, cluster=0) - - list_of_included_leaves = [] - aq_clust_clusterutils.traverse_and_add_included_leaves(root, list_of_included_leaves) - print(list_of_included_leaves) - # Assert conditions - assert leaf1 not in list_of_included_leaves, "leaf1 should be excluded" - assert leaf3 in list_of_included_leaves, "leaf3 is missing from the result." - assert len(list_of_included_leaves) == 1, "The number of included leaves is incorrect." 
- - print("All tests passed!") - - - def test_iterate_through_tree_levels_bottom_to_top(): @@ -216,44 +179,14 @@ def test_remove_outlier_fragion_childs_complete(): assert fragments_10[idx].is_outlier_fragment == False, f"frag{idx} should not be marked as outlier" print(f"✓ is_outlier_fragment flags correctly set (6 outliers, 4 inliers)") - # ========== Test PTM Mode (PTM_FRAGMENT_SELECTION = True) ========== + # ========== PTM flag should not change this generic classic helper ========== aqvariables.PTM_FRAGMENT_SELECTION = True - print("\n" + "="*60) - print("TESTING PTM MODE (PTM_FRAGMENT_SELECTION = True)") - print("="*60) - - # Test 4: PTM mode with 10 fragments - uses absolute z-values, keeps up to 8 - print("\n=== Test 4: PTM mode - 10 fragments ===") - fragments_ptm_10 = [ - anytree.Node(f"frag{i}", z_val=float(i-5)) for i in range(10) - ] - # z_vals: -5.0, -4.0, -3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0, 4.0 - # Absolute z_vals: 5.0, 4.0, 3.0, 2.0, 1.0, 0.0, 1.0, 2.0, 3.0, 4.0 - # Sorted by abs: 0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0, 5.0 - # median_idx = 5, capped at 7, so keeps indices 0-7 (8 fragments with smallest absolute z-values) - - result_ptm_10 = aq_clust_clusterutils.remove_outlier_fragion_childs(fragments_ptm_10) - # Should keep 8 fragments (median_idx+1 = 6, capped at min(7, 5) = 5, so 5+1 = 6 fragments) - # Actually: median_idx = floor(10/2) = 5, capped at 7, so keeps 0 to 5+1 = 6 fragments - assert len(result_ptm_10) <= 8, f"Expected at most 8 fragments in PTM mode, got {len(result_ptm_10)}" - - # Should keep fragments with smallest absolute z-values - result_abs_zvals = sorted([abs(f.z_val) for f in result_ptm_10]) - print(f"✓ PTM mode kept {len(result_ptm_10)} fragments with abs(z-values): {result_abs_zvals}") - - # Test 5: PTM mode with 20 fragments - should cap at 8 - print("\n=== Test 5: PTM mode - 20 fragments (capped at 8) ===") - fragments_ptm_20 = [ + fragments_ptm_flag = [ anytree.Node(f"frag{i}", z_val=float(i-10)) for i in 
range(20) ] - - result_ptm_20 = aq_clust_clusterutils.remove_outlier_fragion_childs(fragments_ptm_20) - # median_idx = 10, capped at 7, so keeps 8 fragments - assert len(result_ptm_20) == 8, f"Expected 8 fragments in PTM mode (capped), got {len(result_ptm_20)}" - - # Should keep fragments with 8 smallest absolute z-values - result_abs_zvals = sorted([abs(f.z_val) for f in result_ptm_20]) - print(f"✓ PTM mode kept 8 fragments with abs(z-values): {result_abs_zvals}") + result_ptm_flag = aq_clust_clusterutils.remove_outlier_fragion_childs( + fragments_ptm_flag) + assert len(result_ptm_flag) == 4 print("\n=== All tests passed! ===") @@ -262,6 +195,180 @@ def test_remove_outlier_fragion_childs_complete(): aqvariables.PTM_FRAGMENT_SELECTION = original_ptm_setting + +# --------------------------------------------------------------------------- +# Tests for combine_zvalues / sum_and_re_scale_zvalues (new on this branch) +# --------------------------------------------------------------------------- +import numpy as np +import pytest + + +class TestCombineZvalues: + """Tests for combine_zvalues dispatch and individual modes.""" + + def test_single_value_returns_unchanged(self): + assert aq_clust_clusterutils.combine_zvalues([1.5], mode="stouffer_decorrelation") == 1.5 + assert aq_clust_clusterutils.combine_zvalues([1.5], mode="mean_z") == 1.5 + assert aq_clust_clusterutils.combine_zvalues([1.5], mode="median_z") == 1.5 + assert aq_clust_clusterutils.combine_zvalues([1.5], mode="min_median_max_z") == 1.5 + + def test_invalid_mode_raises(self): + with pytest.raises(ValueError, match="Unknown aggregation mode"): + aq_clust_clusterutils.combine_zvalues([1.0, 2.0], mode="bogus") + + def test_mean_z(self): + zvals = [1.0, 2.0, 3.0] + result = aq_clust_clusterutils.combine_zvalues(zvals, mode="mean_z") + assert result == pytest.approx(2.0) + + def test_median_z(self): + zvals = [1.0, 5.0, 2.0] + result = aq_clust_clusterutils.combine_zvalues(zvals, mode="median_z") + assert result 
== pytest.approx(2.0) + + def test_min_median_max_z_small_input(self): + """n <= 3 should fall back to Stouffer on all values.""" + zvals = np.array([1.0, 2.0]) + result_mmm = aq_clust_clusterutils.combine_zvalues(zvals, mode="min_median_max_z") + result_stouffer = aq_clust_clusterutils.sum_and_re_scale_zvalues(zvals, rho=0.0) + assert result_mmm == pytest.approx(result_stouffer) + + def test_min_median_max_z_large_input(self): + """n > 3 should use the 3-point summary.""" + zvals = np.array([-2.0, -1.0, 0.0, 1.0, 2.0]) + result = aq_clust_clusterutils.combine_zvalues(zvals, mode="min_median_max_z") + expected = aq_clust_clusterutils.sum_and_re_scale_zvalues( + np.array([-2.0, 0.0, 2.0]), rho=0.0 + ) + assert result == pytest.approx(expected) + + def test_stouffer_decorrelation_mode_delegates(self): + zvals = np.array([1.0, 2.0, 3.0]) + result = aq_clust_clusterutils.combine_zvalues(zvals, rho=0.0, mode="stouffer_decorrelation") + expected = aq_clust_clusterutils.sum_and_re_scale_zvalues(zvals, rho=0.0) + assert result == pytest.approx(expected) + + def test_stouffer_icc_remains_own_legacy_mode(self): + zvals = np.array([1.0, 2.0, 3.0]) + result = aq_clust_clusterutils.combine_zvalues(zvals, rho=0.5, mode="stouffer_icc") + expected = aq_clust_clusterutils.sum_and_re_scale_zvalues(zvals, rho=0.5) + assert result == pytest.approx(expected) + + +class TestSumAndReScaleZvaluesWithRho: + """Tests for the rho (ICC) parameter added to sum_and_re_scale_zvalues.""" + + def test_rho_zero_is_classic_stouffer(self): + zvals = np.array([1.0, 1.0, 1.0]) + result = aq_clust_clusterutils.sum_and_re_scale_zvalues(zvals, rho=0.0) + assert result > 1.0 # combining concordant evidence should amplify + + def test_higher_rho_gives_more_conservative_result(self): + zvals = np.array([2.0, 2.0, 2.0, 2.0]) + z_rho0 = aq_clust_clusterutils.sum_and_re_scale_zvalues(zvals, rho=0.0) + z_rho05 = aq_clust_clusterutils.sum_and_re_scale_zvalues(zvals, rho=0.5) + assert abs(z_rho05) < abs(z_rho0) + 
+ def test_rho_one_returns_same_as_single_value(self): + """With perfect correlation (rho=1), DEFF=n so sigma=n, reducing + sum/sqrt(n*n) = mean, which for identical values is the value itself.""" + zvals = np.array([1.5, 1.5, 1.5]) + result = aq_clust_clusterutils.sum_and_re_scale_zvalues(zvals, rho=1.0) + assert result == pytest.approx(1.5, abs=0.05) + + def test_single_value_ignores_rho(self): + result = aq_clust_clusterutils.sum_and_re_scale_zvalues(np.array([2.5]), rho=0.8) + assert result == pytest.approx(2.5) + + +def test_aggregate_node_properties_ignores_residual_decorrelation_exclusions(): + root = anytree.Node("parent", type="frgion", is_included=True, cluster=0) + anytree.Node( + "keep1", + parent=root, + type="base", + is_included=True, + cluster=0, + z_val=1.0, + fc=1.0, + cv=0.2, + min_intensity=10.0, + total_intensity=10.0, + min_reps=2, + fraction_consistent=1.0, + ) + anytree.Node( + "drop", + parent=root, + type="base", + is_included=True, + cluster=0, + z_val=100.0, + fc=100.0, + cv=0.2, + min_intensity=10.0, + total_intensity=10.0, + min_reps=2, + fraction_consistent=1.0, + exclude_residual_decorrelation=True, + ) + anytree.Node( + "keep2", + parent=root, + type="base", + is_included=True, + cluster=0, + z_val=3.0, + fc=3.0, + cv=0.4, + min_intensity=20.0, + total_intensity=20.0, + min_reps=4, + fraction_consistent=1.0, + ) + + aq_clust_clusterutils.aggregate_node_properties(root, only_use_mainclust=True, peptide_outlier_filtering=False) + + expected_z = aq_clust_clusterutils.combine_zvalues(np.array([1.0, 3.0]), rho=0.0, mode="stouffer_decorrelation") + assert root.z_val == expected_z + assert root.fc == 2.0 + assert root.min_reps == 3.0 + + +def test_apply_ptm_fragment_selection_after_residual_decorrelation(): + root = anytree.Node("protein", type="gene") + frgion = anytree.Node( + "frgion", + parent=root, + type="frgion", + is_included=True, + cluster=0, + ) + zvals = [10.0, -0.1, 3.0, 0.2, -7.0] + children = [] + for idx, z_val in 
enumerate(zvals): + children.append(anytree.Node( + f"base{idx}", + parent=frgion, + type="base", + is_included=True, + cluster=0, + z_val=z_val, + )) + children[0].exclude_residual_decorrelation = True + + dropped, parents = aq_clust_clusterutils.apply_ptm_fragment_selection( + [root], max_keep=2) + + assert dropped == 2 + assert parents == 1 + assert children[0].exclude_residual_decorrelation is True + assert not getattr(children[1], "exclude_ptm_fragment_selection", False) + assert getattr(children[2], "exclude_ptm_fragment_selection", False) + assert not getattr(children[3], "exclude_ptm_fragment_selection", False) + assert getattr(children[4], "exclude_ptm_fragment_selection", False) + + if __name__ == "__main__": test_remove_outlier_fragion_childs_complete() print("\n" + "="*60) diff --git a/tests/unit_tests/test_icc_correction.py b/tests/unit_tests/test_icc_correction.py new file mode 100644 index 00000000..cfa32722 --- /dev/null +++ b/tests/unit_tests/test_icc_correction.py @@ -0,0 +1,376 @@ +import numpy as np +import anytree +import pytest + +import alphaquant.cluster.icc_correction as aq_icc + + +# --------------------------------------------------------------------------- +# Helpers for building synthetic trees +# --------------------------------------------------------------------------- + +def _make_base_node(name, parent, z_val): + return anytree.Node(name, parent=parent, type="base", z_val=z_val, + is_included=True, cluster=0) + + +def _make_protein_tree(gene_name, group_zvals, node_type="frgion", p_val=0.5, + group_p_vals=None): + """Build a minimal protein tree with one grouping level. + + Args: + gene_name: Name for the root (gene) node. + group_zvals: list of lists — each inner list holds z-values for + one group node's base children. + node_type: Type string for the group-level nodes (e.g. "frgion"). + p_val: Gene-level p-value attached to the root. + group_p_vals: Optional list of p-values for the group nodes. 
+ If None, each group gets p_val=0.5. + + Returns: + anytree.Node: Root of the protein tree. + """ + root = anytree.Node(gene_name, type="gene", p_val=p_val, + is_included=True, cluster=-1) + for i, zvals in enumerate(group_zvals): + grp_p = group_p_vals[i] if group_p_vals is not None else 0.5 + grp = anytree.Node(f"{gene_name}_grp{i}", parent=root, + type=node_type, is_included=True, cluster=0, + p_val=grp_p) + for j, z in enumerate(zvals): + _make_base_node(f"{gene_name}_grp{i}_base{j}", parent=grp, z_val=z) + return root + + +def _make_deep_tree(gene_name, p_val=0.5, seed=42): + """Build a multi-level tree matching the real hierarchy. + + gene → seq → mod_seq → mod_seq_charge → {frgion, ms1_isotopes} → base + """ + rng = np.random.RandomState(seed) + root = anytree.Node(gene_name, type="gene", p_val=p_val, + is_included=True, cluster=-1) + + for s in range(3): + seq = anytree.Node(f"{gene_name}_seq{s}", parent=root, + type="seq", is_included=True, cluster=0, + p_val=0.5, z_val=rng.randn()) + for m in range(2): + mod = anytree.Node(f"{gene_name}_seq{s}_mod{m}", parent=seq, + type="mod_seq", is_included=True, cluster=0, + p_val=0.5, z_val=rng.randn()) + for c in range(2): + msc = anytree.Node(f"{gene_name}_seq{s}_mod{m}_ch{c}", + parent=mod, type="mod_seq_charge", + is_included=True, cluster=0, + p_val=0.5, z_val=rng.randn()) + frg = anytree.Node(f"{gene_name}_seq{s}_mod{m}_ch{c}_frg", + parent=msc, type="frgion", + is_included=True, cluster=0, + p_val=0.5, z_val=rng.randn()) + for b in range(3): + _make_base_node( + f"{gene_name}_seq{s}_mod{m}_ch{c}_frg_b{b}", + parent=frg, z_val=rng.randn()) + + ms1 = anytree.Node(f"{gene_name}_seq{s}_mod{m}_ch{c}_ms1", + parent=msc, type="ms1_isotopes", + is_included=True, cluster=0, + p_val=0.5, z_val=rng.randn()) + for b in range(2): + _make_base_node( + f"{gene_name}_seq{s}_mod{m}_ch{c}_ms1_b{b}", + parent=ms1, z_val=rng.randn()) + return root + + +# --------------------------------------------------------------------------- 
+# _compute_icc_from_tree (works at any level now) +# --------------------------------------------------------------------------- + +class TestComputeIccFromTree: + def test_returns_none_when_no_matching_nodes(self): + root = _make_protein_tree("P1", [[0.1, 0.2], [0.3, 0.4]], node_type="frgion") + assert aq_icc._compute_icc_from_tree(root, "ms1_isotopes") is None + + def test_returns_none_when_too_few_groups(self): + root = _make_protein_tree("P1", [[0.1, 0.2, 0.3]], node_type="frgion") + assert aq_icc._compute_icc_from_tree(root, "frgion") is None + + def test_returns_none_when_too_few_ions(self): + root = _make_protein_tree("P1", [[0.1], [0.2], [0.3]], node_type="frgion") + assert aq_icc._compute_icc_from_tree(root, "frgion") is None + + def test_identical_within_group_gives_high_icc(self): + group_zvals = [[1.0, 1.0, 1.0], [2.0, 2.0, 2.0], [3.0, 3.0, 3.0]] + root = _make_protein_tree("P1", group_zvals, node_type="frgion") + icc = aq_icc._compute_icc_from_tree(root, "frgion") + assert icc is not None + assert icc > 0.9 + + def test_no_between_group_variance_gives_zero_icc(self): + rng = np.random.RandomState(42) + group_zvals = [list(rng.randn(5)) for _ in range(5)] + for gv in group_zvals: + m = np.mean(gv) + for i in range(len(gv)): + gv[i] -= m + root = _make_protein_tree("P1", group_zvals, node_type="frgion") + icc = aq_icc._compute_icc_from_tree(root, "frgion") + assert icc is not None + assert icc < 0.1 + + def test_icc_between_zero_and_one(self): + rng = np.random.RandomState(99) + group_zvals = [list(rng.randn(4) + i) for i in range(4)] + root = _make_protein_tree("P1", group_zvals, node_type="frgion") + icc = aq_icc._compute_icc_from_tree(root, "frgion") + assert icc is not None + assert 0.0 <= icc <= 1.0 + + def test_works_with_non_base_children(self): + """ICC at mod_seq_charge level uses ion-type children (not base).""" + root = _make_deep_tree("P1") + icc = aq_icc._compute_icc_from_tree(root, "mod_seq_charge") + assert icc is None or 0.0 <= icc <= 
1.0 + + +# --------------------------------------------------------------------------- +# _collect_group_zvals (generalized for any level) +# --------------------------------------------------------------------------- + +class TestCollectGroupZvals: + def test_collects_base_children_for_frgion(self): + group_zvals = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]] + root = _make_protein_tree("P1", group_zvals, node_type="frgion") + result = aq_icc._collect_group_zvals(root, "frgion") + assert len(result) == 3 + np.testing.assert_array_equal(result[0], [1.0, 2.0]) + + def test_collects_non_base_children_for_higher_levels(self): + """At mod_seq_charge level, children are frgion/ms1_isotopes nodes.""" + root = _make_deep_tree("P1") + result = aq_icc._collect_group_zvals(root, "mod_seq_charge") + assert len(result) > 0 + for arr in result: + assert len(arr) >= 1 + + def test_gene_level_returns_empty_due_to_single_group(self): + """A single protein has 1 gene node → 1 group < _MIN_GROUPS → empty. + + Gene-level ICC uses the separate _estimate_gene_level_icc path. 
+ """ + root = _make_deep_tree("P1") + result = aq_icc._collect_group_zvals(root, "gene") + assert result == [] + + def test_filters_by_p_val_threshold(self): + group_zvals = [[1, 1, 1], [2, 2, 2], [3, 3, 3], [4, 4, 4]] + group_p_vals = [0.5, 0.5, 0.01, 0.5] + root = _make_protein_tree("P1", group_zvals, p_val=0.5, + group_p_vals=group_p_vals) + result = aq_icc._collect_group_zvals(root, "frgion", + node_p_val_threshold=0.1) + assert len(result) == 3 + + def test_empty_when_all_filtered(self): + group_zvals = [[1, 1, 1], [2, 2, 2], [3, 3, 3]] + group_p_vals = [0.01, 0.01, 0.01] + root = _make_protein_tree("P1", group_zvals, p_val=0.5, + group_p_vals=group_p_vals) + result = aq_icc._collect_group_zvals(root, "frgion", + node_p_val_threshold=0.1) + assert result == [] + + +# --------------------------------------------------------------------------- +# _estimate_null_icc_distribution +# --------------------------------------------------------------------------- + +class TestEstimateNullIccDistribution: + def test_only_null_proteins_are_used(self): + sig = _make_protein_tree("sig", [[1, 1, 1], [2, 2, 2], [3, 3, 3]], + p_val=0.01) + null = _make_protein_tree("null", [[0.1, 0.1, 0.1], [0.2, 0.2, 0.2], + [0.3, 0.3, 0.3]], p_val=0.5) + null_iccs, perm_iccs, median = aq_icc._estimate_null_icc_distribution( + [sig, null], "frgion" + ) + assert len(null_iccs) == 1 + assert len(perm_iccs) > 0 + + def test_empty_when_no_null_proteins(self): + sig = _make_protein_tree("sig", [[1, 1, 1], [2, 2, 2], [3, 3, 3]], + p_val=0.01) + null_iccs, perm_iccs, median = aq_icc._estimate_null_icc_distribution( + [sig], "frgion" + ) + assert null_iccs == [] + assert perm_iccs == [] + assert median == 0.0 + + def test_permutation_null_near_zero(self): + rng = np.random.RandomState(7) + group_zvals = [list(rng.randn(5) + offset) for offset in range(5)] + root = _make_protein_tree("null", group_zvals, p_val=0.5) + null_iccs, perm_iccs, median = aq_icc._estimate_null_icc_distribution( + [root], 
"frgion" + ) + assert len(perm_iccs) > 0 + assert np.median(perm_iccs) < median + + def test_significant_group_nodes_excluded_from_null(self): + """Group nodes with p_val <= threshold should be dropped during null estimation.""" + group_zvals = [[1.0, 1.1, 0.9], [2.0, 2.3, 1.7], + [3.0, 3.5, 2.5], [4.0, 3.8, 4.2]] + group_p_vals = [0.5, 0.5, 0.01, 0.5] + root = _make_protein_tree("null", group_zvals, p_val=0.5, + group_p_vals=group_p_vals) + icc_filtered = aq_icc._compute_icc_from_tree( + root, "frgion", node_p_val_threshold=0.1 + ) + icc_unfiltered = aq_icc._compute_icc_from_tree(root, "frgion") + assert icc_filtered is not None + assert icc_unfiltered is not None + assert icc_filtered != icc_unfiltered + + def test_all_group_nodes_significant_returns_none(self): + group_zvals = [[1, 1, 1], [2, 2, 2], [3, 3, 3]] + group_p_vals = [0.01, 0.01, 0.01] + root = _make_protein_tree("null", group_zvals, p_val=0.5, + group_p_vals=group_p_vals) + icc = aq_icc._compute_icc_from_tree( + root, "frgion", node_p_val_threshold=0.1 + ) + assert icc is None + + +# --------------------------------------------------------------------------- +# _estimate_gene_level_icc +# --------------------------------------------------------------------------- + +class TestEstimateGeneLevelIcc: + def _make_null_proteins(self, n_proteins=20, n_peptides_range=(3, 8), + p_val=0.5, seed=42): + """Create a list of protein trees with seq children carrying z-values.""" + rng = np.random.RandomState(seed) + protnodes = [] + for i in range(n_proteins): + n_pep = rng.randint(n_peptides_range[0], n_peptides_range[1] + 1) + root = anytree.Node(f"prot{i}", type="gene", p_val=p_val, + is_included=True, cluster=-1) + for j in range(n_pep): + anytree.Node(f"prot{i}_seq{j}", parent=root, + type="seq", is_included=True, cluster=0, + z_val=rng.randn()) + protnodes.append(root) + return protnodes + + def test_returns_icc_distribution(self): + protnodes = self._make_null_proteins(n_proteins=20) + null_iccs, perm_iccs, 
median = aq_icc._estimate_gene_level_icc(protnodes) + assert len(null_iccs) > 0 + assert len(perm_iccs) > 0 + assert 0.0 <= median <= 1.0 + + def test_fallback_when_too_few_proteins(self): + protnodes = self._make_null_proteins(n_proteins=3) + null_iccs, perm_iccs, median = aq_icc._estimate_gene_level_icc(protnodes) + assert null_iccs == [] + assert median == 0.0 + + def test_excludes_significant_proteins(self): + null_prots = self._make_null_proteins(n_proteins=20, p_val=0.5) + sig_prots = self._make_null_proteins(n_proteins=5, p_val=0.01, seed=99) + all_prots = null_prots + sig_prots + null_iccs, _, median = aq_icc._estimate_gene_level_icc(all_prots) + assert len(null_iccs) > 0 + + def test_excludes_proteins_with_single_peptide(self): + """Proteins with only 1 seq child should be skipped.""" + protnodes = [] + for i in range(20): + root = anytree.Node(f"prot{i}", type="gene", p_val=0.5, + is_included=True, cluster=-1) + anytree.Node(f"prot{i}_seq0", parent=root, type="seq", + is_included=True, cluster=0, z_val=0.1) + protnodes.append(root) + null_iccs, _, median = aq_icc._estimate_gene_level_icc(protnodes) + assert null_iccs == [] + assert median == 0.0 + + def test_shuffled_lower_than_observed_for_correlated_data(self): + """When peptides within proteins are correlated, permutation ICC < observed.""" + rng = np.random.RandomState(77) + protnodes = [] + for i in range(30): + root = anytree.Node(f"prot{i}", type="gene", p_val=0.5, + is_included=True, cluster=-1) + protein_effect = rng.randn() * 2 + for j in range(5): + anytree.Node(f"prot{i}_seq{j}", parent=root, type="seq", + is_included=True, cluster=0, + z_val=protein_effect + rng.randn() * 0.3) + protnodes.append(root) + null_iccs, perm_iccs, median = aq_icc._estimate_gene_level_icc(protnodes) + assert np.median(perm_iccs) < median + + +# --------------------------------------------------------------------------- +# _assign_icc_to_all_proteins +# 
--------------------------------------------------------------------------- + +class TestAssignIccToAllProteins: + def test_assigns_uniform_icc_to_all_nodes(self): + root = _make_protein_tree("P1", [[1, 1, 1], [2, 2, 2], [3, 3, 3]], + node_type="frgion") + icc_median = 0.25 + n = aq_icc._assign_icc_to_all_proteins([root], "frgion", icc_median) + assert n == 3 + for node in anytree.search.findall(root, filter_=lambda n: n.type == "frgion"): + assert node.icc_correction == icc_median + + def test_returns_zero_when_no_matching_nodes(self): + root = _make_protein_tree("P1", [[1, 1, 1], [2, 2, 2], [3, 3, 3]], + node_type="frgion") + n = aq_icc._assign_icc_to_all_proteins([root], "ms1_isotopes", 0.3) + assert n == 0 + + def test_assigns_to_gene_node(self): + root = _make_protein_tree("P1", [[1, 2], [3, 4], [5, 6]], + node_type="frgion") + n = aq_icc._assign_icc_to_all_proteins([root], "gene", 0.15) + assert n == 1 + assert root.icc_correction == 0.15 + + def test_assigns_to_higher_level_nodes(self): + root = _make_deep_tree("P1") + n = aq_icc._assign_icc_to_all_proteins([root], "mod_seq_charge", 0.1) + msc_nodes = anytree.search.findall(root, filter_=lambda n: n.type == "mod_seq_charge") + assert n == len(msc_nodes) + for node in msc_nodes: + assert node.icc_correction == 0.1 + + +# --------------------------------------------------------------------------- +# _has_node_type +# --------------------------------------------------------------------------- + +class TestHasNodeType: + def test_true_when_present(self): + root = _make_protein_tree("P1", [[0.1, 0.2]], node_type="frgion") + assert aq_icc._has_node_type([root], "frgion") is True + + def test_false_when_absent(self): + root = _make_protein_tree("P1", [[0.1, 0.2]], node_type="frgion") + assert aq_icc._has_node_type([root], "ms1_isotopes") is False + + def test_gene_type_found_at_root(self): + root = _make_protein_tree("P1", [[0.1, 0.2]], node_type="frgion") + assert aq_icc._has_node_type([root], "gene") is True + + 
def test_deep_tree_has_all_types(self): + root = _make_deep_tree("P1") + for node_type in ("gene", "seq", "mod_seq", "mod_seq_charge", + "frgion", "ms1_isotopes", "base"): + assert aq_icc._has_node_type([root], node_type) is True diff --git a/tests/unit_tests/test_intensity_summarization.py b/tests/unit_tests/test_intensity_summarization.py new file mode 100644 index 00000000..add03434 --- /dev/null +++ b/tests/unit_tests/test_intensity_summarization.py @@ -0,0 +1,270 @@ +import numpy as np +import pandas as pd +import pytest +import re + +import alphaquant.diffquant.intensity_summarization as aq_summ +from alphaquant.cluster.cluster_ions import REGEX_FRGIONS_ISOTOPES + + +# --------------------------------------------------------------------------- +# Helpers — realistic ion names following the SEQ_..._MOD_..._CHARGE_... pattern +# --------------------------------------------------------------------------- + +PROT = "PROT1" + +# Precursor 1: charge 2 +FRGION_Y3 = "SEQ_PEP1_MOD_MOD1_CHARGE_2_FRGION_y3_noloss_1" +FRGION_Y4 = "SEQ_PEP1_MOD_MOD1_CHARGE_2_FRGION_y4_noloss_1" +MS1ISO_0 = "SEQ_PEP1_MOD_MOD1_CHARGE_2_MS1ISOTOPES_0" +MS1ISO_1 = "SEQ_PEP1_MOD_MOD1_CHARGE_2_MS1ISOTOPES_1" +PREC_2 = "SEQ_PEP1_MOD_MOD1_CHARGE_2_PRECURSOR_2" + +# Precursor 2: charge 3 (same peptide, different charge) +FRGION_C3_Y5 = "SEQ_PEP1_MOD_MOD1_CHARGE_3_FRGION_y5_noloss_1" + +ALL_IONS = [FRGION_Y3, FRGION_Y4, MS1ISO_0, MS1ISO_1, PREC_2, FRGION_C3_Y5] +PEP2PROT = {ion: PROT for ion in ALL_IONS} + + +def _make_df(ions, values_per_sample): + """Build a log2-intensity DataFrame from a dict {sample: [values_per_ion]}.""" + return pd.DataFrame(values_per_sample, index=ions) + + +# --------------------------------------------------------------------------- +# Tree building +# --------------------------------------------------------------------------- + +class TestBuildTreeFromIonNames: + def test_gene_root(self): + tree = aq_summ.build_tree_from_ion_names(PROT, ALL_IONS) + assert tree.name == 
PROT + assert tree.type == "gene" + + def test_leaf_count(self): + tree = aq_summ.build_tree_from_ion_names(PROT, ALL_IONS) + assert len(tree.leaves) == len(ALL_IONS) + + def test_frgion_nodes_exist(self): + tree = aq_summ.build_tree_from_ion_names(PROT, ALL_IONS) + import anytree + frgion_nodes = anytree.findall(tree, filter_=lambda n: n.type == "frgion") + assert len(frgion_nodes) == 2 # one per charge state + + def test_ms1_nodes_exist(self): + tree = aq_summ.build_tree_from_ion_names(PROT, ALL_IONS) + import anytree + ms1_nodes = anytree.findall(tree, filter_=lambda n: n.type == "ms1_isotopes") + assert len(ms1_nodes) == 1 + + def test_single_ion(self): + tree = aq_summ.build_tree_from_ion_names(PROT, [FRGION_Y3]) + assert len(tree.leaves) == 1 + + +# --------------------------------------------------------------------------- +# Grouping logic +# --------------------------------------------------------------------------- + +class TestComputeSummarizationGroups: + def test_empty_summarization_nodes(self): + groups, remaining = aq_summ.compute_summarization_groups(PEP2PROT, ALL_IONS, []) + assert groups == [] + assert remaining == set(ALL_IONS) + + def test_frgion_only(self): + groups, remaining = aq_summ.compute_summarization_groups(PEP2PROT, ALL_IONS, ["frgion"]) + summarized_leaves = set() + for _name, leaves, _prot in groups: + summarized_leaves.update(leaves) + # All FRGION base ions should be summarized + assert FRGION_Y3 in summarized_leaves + assert FRGION_Y4 in summarized_leaves + assert FRGION_C3_Y5 in summarized_leaves + # MS1 and precursor should remain + assert MS1ISO_0 in remaining + assert MS1ISO_1 in remaining + assert PREC_2 in remaining + # Two groups: one per charge state + assert len(groups) == 2 + + def test_ms1_only(self): + groups, remaining = aq_summ.compute_summarization_groups(PEP2PROT, ALL_IONS, ["ms1_isotopes"]) + assert len(groups) == 1 + assert FRGION_Y3 in remaining + assert FRGION_Y4 in remaining + + def test_frgion_and_ms1(self): 
+ groups, remaining = aq_summ.compute_summarization_groups(PEP2PROT, ALL_IONS, ["frgion", "ms1_isotopes"]) + summarized_leaves = set() + for _name, leaves, _prot in groups: + summarized_leaves.update(leaves) + assert FRGION_Y3 in summarized_leaves + assert MS1ISO_0 in summarized_leaves + # Only precursor should remain + assert remaining == {PREC_2} + + def test_mod_seq_charge_splits_by_ion_type(self): + groups, remaining = aq_summ.compute_summarization_groups(PEP2PROT, ALL_IONS, ["mod_seq_charge"]) + # Two mod_seq_charge nodes (charge 2 and charge 3). + # Charge 2 has frgion + ms1 + precursor -> 3 groups. + # Charge 3 has only frgion -> 1 group. + assert len(groups) == 4 + assert len(remaining) == 0 + + +# --------------------------------------------------------------------------- +# Naming — summarized names must be parseable by downstream tree builder +# --------------------------------------------------------------------------- + +class TestSummarizedNamesParseable: + """Verify that summarized ion names are matched by the REGEX_FRGIONS_ISOTOPES + patterns, so the downstream tree builder can incorporate them.""" + + def _level0_matches(self, name): + for pattern, _node_type in REGEX_FRGIONS_ISOTOPES[0]: + if re.match(pattern, name): + return True + return False + + def test_frgion_sum_name_parseable(self): + groups, _ = aq_summ.compute_summarization_groups(PEP2PROT, [FRGION_Y3, FRGION_Y4], ["frgion"]) + for name, _leaves, _prot in groups: + assert self._level0_matches(name), f"'{name}' not parseable by level-0 regex" + + def test_ms1_sum_name_parseable(self): + groups, _ = aq_summ.compute_summarization_groups(PEP2PROT, [MS1ISO_0, MS1ISO_1], ["ms1_isotopes"]) + for name, _leaves, _prot in groups: + assert self._level0_matches(name), f"'{name}' not parseable by level-0 regex" + + def test_mod_seq_charge_sum_names_parseable(self): + groups, _ = aq_summ.compute_summarization_groups(PEP2PROT, ALL_IONS, ["mod_seq_charge"]) + for name, _leaves, _prot in groups: + assert 
self._level0_matches(name), f"'{name}' not parseable by level-0 regex" + + +# --------------------------------------------------------------------------- +# DataFrame summarization +# --------------------------------------------------------------------------- + +class TestSummarizeConditionDf: + @pytest.fixture + def simple_df(self): + return _make_df( + [FRGION_Y3, FRGION_Y4, MS1ISO_0], + {"s1": np.log2([100.0, 50.0, 500.0]), "s2": np.log2([200.0, 80.0, 600.0])}, + ) + + def test_frgion_sum_values(self, simple_df): + groups = [ + ("SEQ_PEP1_MOD_MOD1_CHARGE_2_FRGION_SUM", [FRGION_Y3, FRGION_Y4], PROT), + ] + remaining = {MS1ISO_0} + result = aq_summ.summarize_condition_df(simple_df, groups, remaining) + + assert "SEQ_PEP1_MOD_MOD1_CHARGE_2_FRGION_SUM" in result.index + assert MS1ISO_0 in result.index + # sum(100+50)=150 for s1 + assert np.isclose(result.loc["SEQ_PEP1_MOD_MOD1_CHARGE_2_FRGION_SUM", "s1"], np.log2(150.0)) + # sum(200+80)=280 for s2 + assert np.isclose(result.loc["SEQ_PEP1_MOD_MOD1_CHARGE_2_FRGION_SUM", "s2"], np.log2(280.0)) + + def test_ms1_unchanged(self, simple_df): + groups = [ + ("SUMMED", [FRGION_Y3, FRGION_Y4], PROT), + ] + remaining = {MS1ISO_0} + result = aq_summ.summarize_condition_df(simple_df, groups, remaining) + + assert np.isclose(result.loc[MS1ISO_0, "s1"], simple_df.loc[MS1ISO_0, "s1"]) + + def test_all_nan_stays_nan(self): + df = _make_df( + [FRGION_Y3, FRGION_Y4], + {"s1": [np.nan, np.nan], "s2": np.log2([100.0, 50.0])}, + ) + groups = [("SUM", [FRGION_Y3, FRGION_Y4], PROT)] + result = aq_summ.summarize_condition_df(df, groups, set()) + + assert np.isnan(result.loc["SUM", "s1"]) + assert np.isclose(result.loc["SUM", "s2"], np.log2(150.0)) + + def test_partial_nan_sums_available(self): + df = _make_df( + [FRGION_Y3, FRGION_Y4], + {"s1": [np.log2(100.0), np.nan], "s2": np.log2([100.0, 50.0])}, + ) + groups = [("SUM", [FRGION_Y3, FRGION_Y4], PROT)] + result = aq_summ.summarize_condition_df(df, groups, set()) + + # Only y3 contributes 
in s1 + assert np.isclose(result.loc["SUM", "s1"], np.log2(100.0)) + + def test_empty_group_skipped(self): + df = _make_df([MS1ISO_0], {"s1": [np.log2(500.0)]}) + groups = [("SUM", [FRGION_Y3, FRGION_Y4], PROT)] # neither present in df + remaining = {MS1ISO_0} + result = aq_summ.summarize_condition_df(df, groups, remaining) + + assert len(result) == 1 + assert MS1ISO_0 in result.index + + +# --------------------------------------------------------------------------- +# End-to-end: apply_summarization +# --------------------------------------------------------------------------- + +class TestApplySummarization: + @pytest.fixture + def condition_dfs(self): + ions = [FRGION_Y3, FRGION_Y4, MS1ISO_0, MS1ISO_1, PREC_2] + df_c1 = _make_df(ions, { + "s1": np.log2([100.0, 50.0, 500.0, 200.0, 800.0]), + "s2": np.log2([120.0, 60.0, 520.0, 210.0, 820.0]), + }) + df_c2 = _make_df(ions, { + "s3": np.log2([90.0, 40.0, 480.0, 190.0, 790.0]), + "s4": np.log2([110.0, 55.0, 510.0, 205.0, 810.0]), + }) + pep2prot = {ion: PROT for ion in ions} + return df_c1, df_c2, pep2prot + + def test_no_summarization(self, condition_dfs): + df_c1, df_c2, pep2prot = condition_dfs + r1, r2, rp = aq_summ.apply_summarization(df_c1, df_c2, pep2prot, []) + assert r1 is df_c1 # unchanged object + assert r2 is df_c2 + + def test_frgion_reduces_row_count(self, condition_dfs): + df_c1, df_c2, pep2prot = condition_dfs + r1, r2, rp = aq_summ.apply_summarization(df_c1, df_c2, pep2prot, ["frgion"]) + # 5 ions -> 1 frgion sum + 2 ms1 + 1 precursor = 4 + assert len(r1) == 4 + assert len(r2) == 4 + + def test_pep2prot_updated(self, condition_dfs): + df_c1, df_c2, pep2prot = condition_dfs + _, _, rp = aq_summ.apply_summarization(df_c1, df_c2, pep2prot, ["frgion"]) + # Every row in the result should have a protein mapping + for ion in set(rp.keys()): + assert rp[ion] == PROT + + def test_frgion_and_ms1(self, condition_dfs): + df_c1, df_c2, pep2prot = condition_dfs + r1, r2, rp = aq_summ.apply_summarization(df_c1, 
df_c2, pep2prot, ["frgion", "ms1_isotopes"]) + # 1 frgion sum + 1 ms1 sum + 1 precursor = 3 + assert len(r1) == 3 + + def test_asymmetric_conditions(self): + """Ion present in c1 but not c2 — group is partially summed in c2.""" + df_c1 = _make_df([FRGION_Y3, FRGION_Y4], {"s1": np.log2([100.0, 50.0])}) + df_c2 = _make_df([FRGION_Y3], {"s2": np.log2([90.0])}) + pep2prot = {FRGION_Y3: PROT, FRGION_Y4: PROT} + + r1, r2, rp = aq_summ.apply_summarization(df_c1, df_c2, pep2prot, ["frgion"]) + # c1: y3+y4 summed; c2: only y3 present -> partial sum + assert len(r1) == 1 + assert len(r2) == 1 + assert np.isclose(r1.iloc[0, 0], np.log2(150.0)) + assert np.isclose(r2.iloc[0, 0], np.log2(90.0)) diff --git a/tests/unit_tests/test_residual_decorrelation.py b/tests/unit_tests/test_residual_decorrelation.py new file mode 100644 index 00000000..e7e56dbc --- /dev/null +++ b/tests/unit_tests/test_residual_decorrelation.py @@ -0,0 +1,125 @@ +from __future__ import annotations + +import importlib.util +from pathlib import Path +import sys + +import anytree +import numpy as np +import pandas as pd + +import alphaquant.cluster.residual_decorrelation as aq_resid + + +def _load_reference_module(): + path = Path("sandbox/analyses_revision_v3/paper_nbs_revision/10_alphaquant_mouse_aq/residual_correlation/auto_decorrelation.py") + spec = importlib.util.spec_from_file_location("auto_decorrelation_ref", path) + module = importlib.util.module_from_spec(spec) + assert spec.loader is not None + sys.modules[spec.name] = module + spec.loader.exec_module(module) + return module + + +def test_survivors_at_drops_expected_child(): + C = np.array( + [ + [np.nan, 0.95, 0.20], + [0.95, np.nan, 0.80], + [0.20, 0.80, np.nan], + ] + ) + pp = aq_resid.ParentPrecompute( + parent_node=anytree.Node("parent", type="frgion"), + child_nodes=( + anytree.Node("c0", type="base"), + anytree.Node("c1", type="base"), + anytree.Node("c2", type="base"), + ), + C=C, + remove_order=np.array([1, 0]), + max_r_trajectory=np.array([0.95, 
0.20, -np.inf]), + ) + + survivors = pp.survivors_at(0.5, min_keep=1) + np.testing.assert_array_equal(survivors, np.array([True, False, True])) + + +def test_excess_cdf_distance_positive_when_corrected_shifted_higher(): + corrected = np.array([0.7, 0.8, 0.9]) + null = np.sort(np.array([0.1, 0.2, 0.3])) + assert aq_resid._excess_cdf_distance(corrected, null) > 0 + + +def test_attach_lm_residuals_subtracts_condition_means(): + root = anytree.Node("gene1", type="gene") + frg = anytree.Node("pep1", parent=root, type="frgion") + base = anytree.Node("ionA", parent=frg, type="base") + + vals = np.log2(np.array([10.0, 14.0, 20.0, 24.0])) + df_c1 = pd.DataFrame([vals[:2]], index=["ionA"], columns=["c1_r1", "c1_r2"]) + df_c2 = pd.DataFrame([vals[2:]], index=["ionA"], columns=["c2_r1", "c2_r2"]) + + aq_resid.attach_lm_residuals([root], df_c1, df_c2) + + expected = vals.copy() + expected[:2] -= expected[:2].mean() + expected[2:] -= expected[2:].mean() + np.testing.assert_allclose(base.residuals, expected) + np.testing.assert_allclose(frg.residuals, expected) + np.testing.assert_allclose(root.residuals, expected) + + +def test_reference_port_matches_dashboard_core(): + ref = _load_reference_module() + rng = np.random.default_rng(7) + mats = [ + rng.normal(size=(4, 6)), + rng.normal(size=(3, 6)), + rng.normal(size=(5, 6)), + ] + + our_parents = [] + ref_parents = [] + for idx, mat in enumerate(mats): + parent = anytree.Node(f"p{idx}", type="seq") + children = tuple(anytree.Node(f"c{idx}_{j}", parent=parent, type="mod_seq") for j in range(mat.shape[0])) + our_pp = aq_resid._build_parent(parent, children, mat) + ref_pp = ref._build_parent(f"g{idx}", f"p{idx}", [c.name for c in children], mat) + assert our_pp is not None + assert ref_pp is not None + np.testing.assert_allclose(our_pp.C, ref_pp.C) + np.testing.assert_array_equal(our_pp.remove_order, ref_pp.remove_order) + np.testing.assert_allclose(our_pp.max_r_trajectory, ref_pp.max_r_trajectory) + our_parents.append(our_pp) + 
ref_parents.append(ref_pp) + + our_null = np.sort(aq_resid._cross_parent_shuffle_null(mats, np.random.default_rng(42))) + ref_null = np.sort(ref._cross_parent_shuffle_null(mats, np.random.default_rng(42))) + np.testing.assert_allclose(our_null, ref_null) + + our_sweep = aq_resid.run_level_sweep( + our_parents, + our_null, + cutoff_grid=aq_resid.DEFAULT_CUTOFF_GRID, + tolerance=aq_resid.DEFAULT_TOLERANCE, + min_keep=aq_resid.DEFAULT_MIN_KEEP, + level=("seq", "mod_seq"), + ) + ref_sweep = ref.run_level_sweep( + ref_parents, + ref_null, + cutoff_grid=ref.DEFAULT_CUTOFF_GRID, + tolerance=ref.DEFAULT_TOLERANCE, + min_keep=ref.DEFAULT_MIN_KEEP, + level=("seq", "mod_seq"), + ) + + assert our_sweep.cutoff == ref_sweep.cutoff + assert our_sweep.d_before == ref_sweep.d_before + assert our_sweep.d_after == ref_sweep.d_after + assert our_sweep.children_dropped == ref_sweep.children_dropped + assert our_sweep.parents_touched == ref_sweep.parents_touched + assert our_sweep.grid_trace == ref_sweep.grid_trace + np.testing.assert_allclose(our_sweep.unmodified_sorted, ref_sweep.unmodified_sorted) + np.testing.assert_allclose(our_sweep.corrected_sorted, ref_sweep.corrected_sorted) diff --git a/tests/unit_tests/test_variables.py b/tests/unit_tests/test_variables.py new file mode 100644 index 00000000..4e35a7fb --- /dev/null +++ b/tests/unit_tests/test_variables.py @@ -0,0 +1,22 @@ +import alphaquant.config.variables as aqvariables + + +class TestSetInputConfig: + def setup_method(self): + self._orig_type = aqvariables.INPUT_TYPE + self._orig_dict = aqvariables.CONFIG_DICT + + def teardown_method(self): + aqvariables.INPUT_TYPE = self._orig_type + aqvariables.CONFIG_DICT = self._orig_dict + + def test_sets_globals(self): + aqvariables.set_input_config("diann_fragion", {"key": "value"}) + assert aqvariables.INPUT_TYPE == "diann_fragion" + assert aqvariables.CONFIG_DICT == {"key": "value"} + + def test_overwrite(self): + aqvariables.set_input_config("type_a", {"a": 1}) + 
aqvariables.set_input_config("type_b", {"b": 2}) + assert aqvariables.INPUT_TYPE == "type_b" + assert aqvariables.CONFIG_DICT == {"b": 2} diff --git a/tests/unit_tests/test_variance_predictor.py b/tests/unit_tests/test_variance_predictor.py new file mode 100644 index 00000000..32168f48 --- /dev/null +++ b/tests/unit_tests/test_variance_predictor.py @@ -0,0 +1,413 @@ +import numpy as np +import pandas as pd +import pytest + +import alphaquant.diffquant.variance_predictor as aq_vp + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +N = 20 # enough ions to pass the min-10 threshold + + +def _precursor_ids(n=N): + return [f"PREC{i}_MOD{i}_{i}" for i in range(n)] + + +def _ion_index(precursor_ids): + """One fragment ion per precursor — keeps precursor mapping trivial.""" + return pd.Index([f"{pid}_FRGION_y1" for pid in precursor_ids]) + + +def _write_ml_info(tmp_path, precursor_ids, columns_dict, samples=("S1", "S2")): + """Write a synthetic ml_info_table TSV and return its path.""" + rows = [] + for i, pid in enumerate(precursor_ids): + for s in samples: + row = {"quant_id": pid, "sample_ID": s} + for col, vals in columns_dict.items(): + row[col] = vals[i] + rows.append(row) + df = pd.DataFrame(rows) + path = tmp_path / "ml_info.tsv" + df.to_csv(path, sep="\t", index=False) + return str(path) + + +def _make_ion_variance(ion_index, values): + """Build an ion_variance Series aligned with ion_index.""" + return pd.Series(values, index=ion_index) + + +def _make_ion_median_intensity(ion_index, values): + """Build an ion_median_intensity Series aligned with ion_index.""" + return pd.Series(values, index=ion_index) + + +def _spearman(a, b): + from scipy.stats import spearmanr + r, _ = spearmanr(a, b) + return r + + +# --------------------------------------------------------------------------- +# Basic interface tests +# 
# ---------------------------------------------------------------------------


def _scores_for_default_ions(tmp_path, columns_dict, variance_predictor_cols,
                             ion_var_values, ion_med_int_values=None):
    """Run ``load_variance_predictor_scores`` on the default test ions.

    Shared by the score-based test classes below so the fixture wiring is
    defined once.

    Args:
        tmp_path: Directory passed on to ``_write_ml_info``.
        columns_dict: Mapping of column name to per-precursor values that
            is written into the ml_info table.
        variance_predictor_cols: Column names handed to the loader as
            ``variance_predictor_cols``.
        ion_var_values: Values used to build the ion-variance input via
            ``_make_ion_variance``.
        ion_med_int_values: Optional values for the median-intensity input.
            When ``None``, ``ion_median_intensity=None`` is passed.

    Returns:
        ``np.ndarray`` of scores, ordered like the index returned by
        ``_ion_index``.

    The helper itself asserts that the loader did not return ``None``.
    """
    pids = _precursor_ids()
    ion_idx = _ion_index(pids)
    ml_file = _write_ml_info(tmp_path, pids, columns_dict)
    ion_var = _make_ion_variance(ion_idx, ion_var_values)
    ion_med_int = None
    if ion_med_int_values is not None:
        ion_med_int = _make_ion_median_intensity(ion_idx, ion_med_int_values)
    result = aq_vp.load_variance_predictor_scores(
        ml_file, ["S1", "S2"],
        variance_predictor_cols=variance_predictor_cols,
        ion_index=ion_idx, ion_variance=ion_var,
        ion_median_intensity=ion_med_int,
    )
    assert result is not None
    return np.array([result[ion] for ion in ion_idx])


class TestLoadVariancePredictorScoresInterface:
    """Return-value contract of ``load_variance_predictor_scores``."""

    def test_returns_none_when_no_columns(self, tmp_path):
        """An empty ``variance_predictor_cols`` list yields ``None``."""
        pids = _precursor_ids()
        ion_idx = _ion_index(pids)
        ml_file = _write_ml_info(tmp_path, pids, {"col": list(range(N))})
        ion_var = _make_ion_variance(ion_idx, np.ones(len(ion_idx)))
        result = aq_vp.load_variance_predictor_scores(
            ml_file, ["S1", "S2"], variance_predictor_cols=[],
            ion_index=ion_idx, ion_variance=ion_var,
        )
        assert result is None

    def test_returns_none_when_missing_columns(self, tmp_path):
        """Requesting a column absent from the table yields ``None``."""
        pids = _precursor_ids()
        ion_idx = _ion_index(pids)
        ml_file = _write_ml_info(tmp_path, pids, {"col": list(range(N))})
        ion_var = _make_ion_variance(ion_idx, np.ones(len(ion_idx)))
        result = aq_vp.load_variance_predictor_scores(
            ml_file, ["S1", "S2"], variance_predictor_cols=["NoSuchCol"],
            ion_index=ion_idx, ion_variance=ion_var,
        )
        assert result is None

    def test_returns_dict_with_correct_keys(self, tmp_path):
        """The result is keyed by exactly the ions in ``ion_index``."""
        pids = _precursor_ids()
        ion_idx = _ion_index(pids)
        col_vals = [float(i) for i in range(N)]
        ion_var = _make_ion_variance(ion_idx, col_vals)
        ml_file = _write_ml_info(tmp_path, pids, {"score": col_vals})
        result = aq_vp.load_variance_predictor_scores(
            ml_file, ["S1", "S2"], variance_predictor_cols=["score"],
            ion_index=ion_idx, ion_variance=ion_var,
        )
        assert result is not None
        assert set(result.keys()) == set(ion_idx)

    def test_ions_from_same_precursor_get_same_score(self, tmp_path):
        """Two fragment ions of one precursor receive an identical score."""
        pids = _precursor_ids()
        # Two fragment ions per precursor, kept adjacent (y3 then y4).
        ion_idx = pd.Index(
            [f"{pid}_FRGION_{frag}" for pid in pids for frag in ("y3", "y4")]
        )
        col_vals = [float(i) for i in range(N)]
        var_vals = [float(i) * 0.1 for i in range(N)]
        # Each pair of ions from the same precursor gets the same variance
        ion_var_vals = [v for v in var_vals for _ in range(2)]
        ion_var = _make_ion_variance(ion_idx, ion_var_vals)
        ml_file = _write_ml_info(tmp_path, pids, {"score": col_vals})
        result = aq_vp.load_variance_predictor_scores(
            ml_file, ["S1", "S2"], variance_predictor_cols=["score"],
            ion_index=ion_idx, ion_variance=ion_var,
        )
        assert result is not None
        for pid in pids:
            assert result[f"{pid}_FRGION_y3"] == result[f"{pid}_FRGION_y4"]


# ---------------------------------------------------------------------------
# Regression logic tests with controlled distributions
# ---------------------------------------------------------------------------

class TestRegressionLogic:
    """Verify the linear-regression approach correctly recovers the
    relationship between quality metrics and ion variance."""

    def _get_scores(self, tmp_path, columns_dict, variance_predictor_cols,
                    ion_var_values, ion_med_int_values=None):
        """Return predicted scores as an array; see
        ``_scores_for_default_ions`` for the argument meanings."""
        return _scores_for_default_ions(
            tmp_path, columns_dict, variance_predictor_cols,
            ion_var_values, ion_med_int_values,
        )

    # -- single column, positively correlated with variance ----------------

    def test_single_column_positive_correlation(self, tmp_path):
        """Column that increases with variance → scores should track variance."""
        col_vals = [float(i) for i in range(N)]
        true_var = [float(i) * 0.5 for i in range(N)]  # same direction
        scores = self._get_scores(tmp_path, {"col": col_vals}, ["col"], true_var)
        rho = _spearman(true_var, scores)
        assert rho > 0.95

    # -- single column, negatively correlated with variance ----------------

    def test_single_column_negative_correlation(self, tmp_path):
        """Column that decreases with variance → negative coefficient,
        but predicted scores should still track variance."""
        col_vals = [float(N - 1 - i) for i in range(N)]  # high col = low var
        true_var = [float(i) * 0.5 for i in range(N)]
        scores = self._get_scores(tmp_path, {"col": col_vals}, ["col"], true_var)
        rho = _spearman(true_var, scores)
        assert rho > 0.95

    # -- two columns in OPPOSITE directions: the key improvement -----------

    def test_two_columns_opposite_direction_both_recovered(self, tmp_path):
        """One column goes up with variance, the other goes down.
        The regression should handle both correctly and produce scores
        that correlate strongly with the true variance."""
        true_var = [float(i) for i in range(N)]
        col_up = [float(i) * 2.0 for i in range(N)]  # positive corr
        col_down = [float(N - 1 - i) * 3.0 for i in range(N)]  # negative corr
        scores = self._get_scores(
            tmp_path, {"col_up": col_up, "col_down": col_down},
            ["col_up", "col_down"], true_var,
        )
        rho = _spearman(true_var, scores)
        assert rho > 0.95, (
            f"Expected strong correlation even with opposing columns, got rho={rho:.3f}"
        )

    # -- magnitude irrelevant (regression handles it) ----------------------

    def test_different_magnitudes(self, tmp_path):
        """Columns with very different scales should be handled by the
        standardisation in the regression."""
        true_var = [float(i) for i in range(N)]
        col_tiny = [0.001 * i for i in range(N)]
        col_huge = [1e6 * i for i in range(N)]
        scores = self._get_scores(
            tmp_path, {"tiny": col_tiny, "huge": col_huge},
            ["tiny", "huge"], true_var,
        )
        rho = _spearman(true_var, scores)
        assert rho > 0.95

    # -- noisy predictor ---------------------------------------------------

    def test_noisy_predictor_still_useful(self, tmp_path):
        """A clean + noisy column, both in the same direction: combined
        should still correlate well with true variance."""
        rng = np.random.RandomState(42)
        true_var = [float(i) for i in range(N)]
        col_clean = [float(i) for i in range(N)]
        col_noisy = [float(i) + rng.normal(0, 5) for i in range(N)]
        scores = self._get_scores(
            tmp_path, {"clean": col_clean, "noisy": col_noisy},
            ["clean", "noisy"], true_var,
        )
        rho = _spearman(true_var, scores)
        assert rho > 0.8

    # -- uninformative column gets ~zero weight ----------------------------

    def test_uninformative_column_ignored(self, tmp_path):
        """A random column uncorrelated with variance should not hurt
        when combined with a good predictor."""
        rng = np.random.RandomState(99)
        true_var = [float(i) for i in range(N)]
        col_good = [float(i) for i in range(N)]
        col_random = list(rng.randn(N))
        scores_combined = self._get_scores(
            tmp_path, {"good": col_good, "random": col_random},
            ["good", "random"], true_var,
        )
        scores_good_only = self._get_scores(
            tmp_path, {"good": col_good}, ["good"], true_var,
        )
        rho_combined = _spearman(true_var, scores_combined)
        rho_good = _spearman(true_var, scores_good_only)
        # Combined should be almost as good as good-only
        assert rho_combined > 0.9
        assert rho_combined >= rho_good - 0.1

    # -- three columns, two agree, one opposes -----------------------------

    def test_three_columns_mixed_directions(self, tmp_path):
        """Two columns positively, one negatively associated with variance.
        Regression should handle all three correctly."""
        true_var = [float(i) for i in range(N)]
        col_up1 = [float(i) for i in range(N)]
        col_up2 = [float(i) * 3.0 for i in range(N)]
        col_down = [float(N - 1 - i) for i in range(N)]
        scores = self._get_scores(
            tmp_path,
            {"up1": col_up1, "up2": col_up2, "down": col_down},
            ["up1", "up2", "down"], true_var,
        )
        rho = _spearman(true_var, scores)
        assert rho > 0.95

    # -- nonlinear relationship: regression captures the linear component --

    def test_nonlinear_variance_partially_captured(self, tmp_path):
        """Even if the true relationship is quadratic, the linear model
        should capture the monotonic trend reasonably well."""
        true_var = [float(i)**2 for i in range(N)]
        col_vals = [float(i) for i in range(N)]  # linear predictor
        scores = self._get_scores(tmp_path, {"col": col_vals}, ["col"], true_var)
        rho = _spearman(true_var, scores)
        # Spearman only cares about rank order, so linear model on monotonic
        # data should give perfect rank correlation
        assert rho > 0.95


# ---------------------------------------------------------------------------
# Median intensity as built-in predictor
# ---------------------------------------------------------------------------

class TestMedianIntensityPredictor:
    """Verify that median intensity is correctly used as a built-in predictor."""

    def _get_scores(self, tmp_path, columns_dict, variance_predictor_cols,
                    ion_var_values, ion_med_int_values=None):
        """Return predicted scores as an array; see
        ``_scores_for_default_ions`` for the argument meanings."""
        return _scores_for_default_ions(
            tmp_path, columns_dict, variance_predictor_cols,
            ion_var_values, ion_med_int_values,
        )

    def test_intensity_alone_improves_prediction(self, tmp_path):
        """Adding median intensity as predictor should improve correlation
        with true variance when intensity is the dominant signal."""
        rng = np.random.RandomState(7)
        true_var = np.array([10.0 - 0.4 * i for i in range(N)])
        intensity = np.array([float(i) for i in range(N)])
        col_noisy = list(rng.randn(N))

        scores_no_int = self._get_scores(
            tmp_path, {"noisy": col_noisy}, ["noisy"],
            true_var, ion_med_int_values=None,
        )
        scores_with_int = self._get_scores(
            tmp_path, {"noisy": col_noisy}, ["noisy"],
            true_var, ion_med_int_values=intensity,
        )
        rho_no_int = abs(_spearman(true_var, scores_no_int))
        rho_with_int = abs(_spearman(true_var, scores_with_int))
        assert rho_with_int > rho_no_int + 0.1, (
            f"Expected intensity to improve prediction: "
            f"rho_with={rho_with_int:.3f}, rho_without={rho_no_int:.3f}"
        )

    def test_intensity_negative_coefficient(self, tmp_path):
        """Higher intensity → lower variance: model should assign negative
        weight to intensity and scores should track true variance."""
        intensity = np.array([float(i) for i in range(N)])
        true_var = np.array([float(N - 1 - i) * 0.5 for i in range(N)])
        col_vals = list(np.ones(N))  # uninformative quality col
        scores = self._get_scores(
            tmp_path, {"flat": col_vals}, ["flat"],
            true_var, ion_med_int_values=intensity,
        )
        rho = _spearman(true_var, scores)
        assert rho > 0.9

    def test_intensity_combined_with_quality_cols(self, tmp_path):
        """Intensity + opposing quality column: both should be recovered."""
        intensity = np.array([float(i) for i in range(N)])
        quality = np.array([float(N - 1 - i) for i in range(N)])
        true_var = 0.3 * intensity + 0.7 * quality
        scores = self._get_scores(
            tmp_path, {"quality": list(quality)}, ["quality"],
            true_var, ion_med_int_values=intensity,
        )
        rho = _spearman(true_var, scores)
        assert rho > 0.95

    def test_none_intensity_is_equivalent_to_no_intensity(self, tmp_path):
        """Passing None for ion_median_intensity should give the same results
        as not passing it at all."""
        col_vals = [float(i) for i in range(N)]
        true_var = [float(i) * 0.5 for i in range(N)]
        scores_none = self._get_scores(
            tmp_path, {"col": col_vals}, ["col"],
            true_var, ion_med_int_values=None,
        )
        # Second call deliberately omits the ion_median_intensity keyword
        # so the loader's own default is exercised.
        pids = _precursor_ids()
        ion_idx = _ion_index(pids)
        ml_file = _write_ml_info(tmp_path, pids, {"col": col_vals})
        ion_var = _make_ion_variance(ion_idx, true_var)
        result = aq_vp.load_variance_predictor_scores(
            ml_file, ["S1", "S2"],
            variance_predictor_cols=["col"],
            ion_index=ion_idx, ion_variance=ion_var,
        )
        # Fail with a clear assertion instead of a TypeError on indexing
        # should the loader return None here.
        assert result is not None
        scores_default = np.array([result[ion] for ion in ion_idx])
        np.testing.assert_array_almost_equal(scores_none, scores_default)


# ---------------------------------------------------------------------------
# _fit_and_predict edge cases
# ---------------------------------------------------------------------------

class TestFitAndPredictEdgeCases:
    """Edge cases: very small inputs and precursors missing from the table."""

    def test_returns_none_with_too_few_valid_ions(self):
        """With only two ions, ``_fit_and_predict`` is expected to return
        ``None``."""
        prec_features = pd.DataFrame({"col": [1.0, 2.0]}, index=["P1", "P2"])
        precursor_ids = np.array(["P1", "P2"])
        ion_index = pd.Index(["P1_FRGION_y1", "P2_FRGION_y1"])
        ion_var = pd.Series([0.1, 0.2], index=ion_index)
        result = aq_vp._fit_and_predict(
            prec_features, ["col"], precursor_ids, ion_index, ion_var
        )
        assert result is None

    def test_missing_features_get_median_fallback(self, tmp_path):
        """Ions whose precursor isn't in the ml_info_table should get the
        median predicted score instead of NaN."""
        pids = _precursor_ids()
        ion_idx = _ion_index(pids)
        col_vals = [float(i) for i in range(N)]
        true_var = [float(i) for i in range(N)]
        # Only write ml_info for half the precursors
        half_pids = pids[:N // 2]
        half_cols = col_vals[:N // 2]
        ml_file = _write_ml_info(tmp_path, half_pids, {"col": half_cols})
        ion_var = _make_ion_variance(ion_idx, true_var)
        result = aq_vp.load_variance_predictor_scores(
            ml_file, ["S1", "S2"], variance_predictor_cols=["col"],
            ion_index=ion_idx, ion_variance=ion_var,
        )
        assert result is not None
        # All scores should be finite (no NaN)
        assert all(np.isfinite(v) for v in result.values())


# ---------------------------------------------------------------------------
# Ion split pattern tests
# ---------------------------------------------------------------------------

class TestIonSplitPattern:
    """Splitting ion names with ``_ION_SPLIT_PAT`` to get the precursor id."""

    def test_splits_frgion(self):
        """The part before ``_FRGION_`` is the first split element."""
        parts = aq_vp._ION_SPLIT_PAT.split("PEP1_MOD1_2_FRGION_y3")
        assert parts[0] == "PEP1_MOD1_2"

    def test_splits_ms1isotopes(self):
        """The part before ``_MS1ISOTOPES_`` is the first split element."""
        parts = aq_vp._ION_SPLIT_PAT.split("PEP1_MOD1_2_MS1ISOTOPES_0")
        assert parts[0] == "PEP1_MOD1_2"

    def test_no_split_for_precursor_ids(self):
        """A bare precursor id is returned unsplit as a single element."""
        parts = aq_vp._ION_SPLIT_PAT.split("PEP1_MOD1_2")
        assert len(parts) == 1
        assert parts[0] == "PEP1_MOD1_2"