diff --git a/bin/check_contamination.py b/bin/check_contamination.py index a004360a..0f5388bb 100755 --- a/bin/check_contamination.py +++ b/bin/check_contamination.py @@ -6,32 +6,47 @@ import click -import pandas as pd import matplotlib.pyplot as plt +import pandas as pd import seaborn as sns from read_utils import custom_na_values +from utils_filter import germline_mask, somatic_mask # Assuming somatic_variants and germline_variants are loaded as pandas DataFrames def compute_shared_variants(somatic_variants, germline_variants): - """ - # Example usage: - # shared_variants_matrix = compute_shared_variants(somatic_variants, germline_variants) + """Count mutations shared between each somatic sample and each germline sample. + + Parameters + ---------- + somatic_variants : pd.DataFrame + Variant table with at least ``SAMPLE_ID`` and ``MUT_ID`` columns, + providing the somatic side of the comparison. + germline_variants : pd.DataFrame + Variant table with at least ``SAMPLE_ID`` and ``MUT_ID`` columns, + providing the germline side of the comparison. + + Returns + ------- + pd.DataFrame + Integer matrix indexed by somatic ``SAMPLE_ID`` (rows) and germline + ``SAMPLE_ID`` (columns); each cell is the number of ``MUT_ID`` values + shared between that pair of samples. """ - unique_somatic_samples = sorted(somatic_variants['SAMPLE_ID'].unique()) - unique_germline_samples = sorted(germline_variants['SAMPLE_ID'].unique()) + unique_somatic_samples = sorted(somatic_variants["SAMPLE_ID"].unique()) + unique_germline_samples = sorted(germline_variants["SAMPLE_ID"].unique()) # Create a DataFrame to store counts (avoid .fillna downcasting warning) shared_counts = pd.DataFrame(0, index=unique_somatic_samples, columns=unique_germline_samples, dtype=int) # Iterate through each somatic sample for somatic_sample in unique_somatic_samples: - somatic_mutations = set(somatic_variants[somatic_variants['SAMPLE_ID'] == somatic_sample]['MUT_ID']) + somatic_mutations = set(somatic_variants[somatic_variants["SAMPLE_ID"] == somatic_sample]["MUT_ID"]) # Compare with germline mutations of all other samples for germline_sample in unique_germline_samples: - germline_mutations = set(germline_variants[germline_variants['SAMPLE_ID'] == germline_sample]['MUT_ID']) + germline_mutations = set(germline_variants[germline_variants["SAMPLE_ID"] == germline_sample]["MUT_ID"]) # Count shared mutations shared_counts.loc[somatic_sample, germline_sample] = len(somatic_mutations & germline_mutations) @@ -39,27 +54,41 @@ def compute_shared_variants(somatic_variants, germline_variants): return shared_counts - - -def contamination_detection_between_samples(maf_df, somatic_maf_df): - - # this is if we were to consider both unique and no-unique variants - vaf_threshold = 0.2 - germline_vars_all_samples = maf_df.loc[(maf_df["VAF"] > vaf_threshold) & (maf_df["vd_VAF"] > vaf_threshold) & (maf_df["VAF_AM"] > vaf_threshold), - ["SAMPLE_ID", "MUT_ID"]].drop_duplicates() +def contamination_detection_between_samples(maf_df, somatic_maf_df, somatic_vaf_boundary): + """Detect cross-sample contamination by comparing somatic and germline mutations. + + Builds somatic-vs-germline, all-vs-germline and germline-vs-germline shared-mutation + matrices, renders them as heatmaps, flags sample pairs where a large proportion of one + sample's germline variants appear as non-germline (somatic-looking) variants in another, + and writes the contaminated-sample tables and per-pair detail files to the current + working directory. + + Parameters + ---------- + maf_df : pd.DataFrame + Full mutation table for all samples (used to derive germline variants and the + all-variants set), with at least ``SAMPLE_ID``, ``MUT_ID``, ``VAF``, ``vd_VAF`` and + ``VAF_AM`` columns. + somatic_maf_df : pd.DataFrame + Filtered somatic mutation table, with at least ``SAMPLE_ID`` and ``MUT_ID`` columns. + somatic_vaf_boundary : float + VAF threshold passed to ``germline_mask`` to identify germline variants (a variant is + germline when all of ``VAF``/``vd_VAF``/``VAF_AM`` exceed it). + """ + # consider both unique and non-unique variants when collecting germline variants + germline_vars_all_samples = maf_df.loc[ + germline_mask(maf_df, somatic_vaf_boundary), ["SAMPLE_ID", "MUT_ID"] + ].drop_duplicates() print(germline_vars_all_samples["MUT_ID"].shape) print(len(germline_vars_all_samples["MUT_ID"].unique())) - somatic_variants = somatic_maf_df[["SAMPLE_ID", "MUT_ID"]] print(somatic_variants.shape) - all_variants = maf_df[["SAMPLE_ID", "MUT_ID"]] print(all_variants.shape) - ## Somatic vs Germline shared_variants_somatic2germline_matrix = compute_shared_variants(somatic_variants, germline_vars_all_samples) @@ -67,7 +96,9 @@ def contamination_detection_between_samples(maf_df, somatic_maf_df): plt.figure(figsize=(18, 15)) # Compute total number of germline mutations per sample - germline_counts = germline_vars_all_samples["SAMPLE_ID"].value_counts().reindex(shared_variants_somatic2germline_matrix.columns) + germline_counts = ( + germline_vars_all_samples["SAMPLE_ID"].value_counts().reindex(shared_variants_somatic2germline_matrix.columns) + ) # Create custom column labels with germline mutation counts col_labels = [f"(n={germline_counts[col]}) {col}" for col in shared_variants_somatic2germline_matrix.columns] @@ -76,76 +107,72 @@ def contamination_detection_between_samples(maf_df, somatic_maf_df): mask = shared_variants_somatic2germline_matrix > 30 annot = shared_variants_somatic2germline_matrix.where(mask) # convert selected values to nullable int then to string, replace missing with empty string - annot = annot.round(0).astype('Int64').astype(str).replace('', '').fillna('') + annot = annot.round(0).astype("Int64").astype(str).replace("", "").fillna("") sns.heatmap( shared_variants_somatic2germline_matrix, annot=annot, fmt="", cmap="Blues", - cbar_kws={'label': 'Shared Mutations'}, + cbar_kws={"label": "Shared Mutations"}, xticklabels=col_labels, yticklabels=shared_variants_somatic2germline_matrix.index, linewidths=0.5, - annot_kws={"color": "black", "fontsize": 10} + annot_kws={"color": "black", "fontsize": 10}, ) plt.xlabel("Germline Samples", fontsize=14) plt.ylabel("Somatic Samples", fontsize=14) plt.title("Somatic mutations that are germline in other samples", fontsize=16) - plt.savefig("somatic_vs_germline.pdf", bbox_inches = 'tight', dpi = 100) + plt.savefig("somatic_vs_germline.pdf", bbox_inches="tight", dpi=100) plt.show() - - - - ## All vs Germline shared_all_vs_germline_variants_matrix = compute_shared_variants(all_variants, germline_vars_all_samples) # Compute total number of germline mutations per sample - germline_counts = germline_vars_all_samples["SAMPLE_ID"].value_counts().reindex(shared_all_vs_germline_variants_matrix.columns) - - - normalized_shared_all_vs_germline_variants_matrix = shared_all_vs_germline_variants_matrix.divide(germline_counts, axis=1) - + germline_counts = ( + germline_vars_all_samples["SAMPLE_ID"].value_counts().reindex(shared_all_vs_germline_variants_matrix.columns) + ) + normalized_shared_all_vs_germline_variants_matrix = shared_all_vs_germline_variants_matrix.divide( + germline_counts, axis=1 + ) # Count shared mutations between somatic and germline samples plt.figure(figsize=(18, 15)) - # Create custom column labels with germline mutation counts - col_labels = [f"(n={germline_counts[col]}) {col}" for col in normalized_shared_all_vs_germline_variants_matrix.columns] - + col_labels = [ + f"(n={germline_counts[col]}) {col}" for col in normalized_shared_all_vs_germline_variants_matrix.columns + ] # Annotation: show rounded values only when 0.8 < x < 1 - cond = (normalized_shared_all_vs_germline_variants_matrix > 0.8) & (normalized_shared_all_vs_germline_variants_matrix < 1) + cond = (normalized_shared_all_vs_germline_variants_matrix > 0.8) & ( + normalized_shared_all_vs_germline_variants_matrix < 1 + ) annot = normalized_shared_all_vs_germline_variants_matrix.where(cond) - annot = annot.round(2).astype('string').fillna('') - - sns.heatmap(normalized_shared_all_vs_germline_variants_matrix, - annot=annot, - fmt="", - cmap="Blues", - cbar_kws={'label': 'Shared Mutations'}, - xticklabels=col_labels, yticklabels=normalized_shared_all_vs_germline_variants_matrix.index, - annot_kws={"color": "white", "fontsize": 10}, - linewidths=0.5) - - plt.xlabel("Germline Samples", fontsize = 14) - plt.ylabel("All mutations samples", fontsize = 14) - plt.title("All mutations that are germline in other samples", fontsize = 16) - plt.savefig("allmutations_vs_germline.pdf", bbox_inches = 'tight', dpi = 100) - plt.show() - - - - + annot = annot.round(2).astype("string").fillna("") + sns.heatmap( + normalized_shared_all_vs_germline_variants_matrix, + annot=annot, + fmt="", + cmap="Blues", + cbar_kws={"label": "Shared Mutations"}, + xticklabels=col_labels, + yticklabels=normalized_shared_all_vs_germline_variants_matrix.index, + annot_kws={"color": "white", "fontsize": 10}, + linewidths=0.5, + ) + plt.xlabel("Germline Samples", fontsize=14) + plt.ylabel("All mutations samples", fontsize=14) + plt.title("All mutations that are germline in other samples", fontsize=16) + plt.savefig("allmutations_vs_germline.pdf", bbox_inches="tight", dpi=100) + plt.show() ## Germline vs Germline @@ -154,70 +181,74 @@ def contamination_detection_between_samples(maf_df, somatic_maf_df): plt.figure(figsize=(18, 15)) # Compute total number of germline mutations per sample - germline_counts = germline_vars_all_samples["SAMPLE_ID"].value_counts().reindex(shared_germline_variants_matrix.columns) + germline_counts = ( + germline_vars_all_samples["SAMPLE_ID"].value_counts().reindex(shared_germline_variants_matrix.columns) + ) # Create custom column labels with germline mutation counts col_labels = [f"(n={germline_counts[col]}) {col}" for col in shared_germline_variants_matrix.columns] - # Annotation: follow original logic (keep values where < 0, else blank) mask = shared_germline_variants_matrix < 0 annot = shared_germline_variants_matrix.where(mask) - annot = annot.astype('string').fillna('') - - sns.heatmap(shared_germline_variants_matrix, - annot=annot, - fmt="", - cmap="Blues", - cbar_kws={'label': 'Shared Mutations'}, - xticklabels=col_labels, yticklabels=shared_germline_variants_matrix.index, - linewidths=0.5, - annot_kws={"fontsize": 8} - ) - - plt.xlabel("Germline Samples", fontsize = 14) - plt.ylabel("Germline Samples", fontsize = 14) - plt.title("Germline mutations that are germline in other samples", fontsize = 16) - plt.savefig("germline_vs_germline.pdf", bbox_inches = 'tight', dpi = 100) - plt.show() - - + annot = annot.astype("string").fillna("") + sns.heatmap( + shared_germline_variants_matrix, + annot=annot, + fmt="", + cmap="Blues", + cbar_kws={"label": "Shared Mutations"}, + xticklabels=col_labels, + yticklabels=shared_germline_variants_matrix.index, + linewidths=0.5, + annot_kws={"fontsize": 8}, + ) + plt.xlabel("Germline Samples", fontsize=14) + plt.ylabel("Germline Samples", fontsize=14) + plt.title("Germline mutations that are germline in other samples", fontsize=16) + plt.savefig("germline_vs_germline.pdf", bbox_inches="tight", dpi=100) + plt.show() # Compute total number of germline mutations per sample - germline_counts = germline_vars_all_samples["SAMPLE_ID"].value_counts().reindex(shared_germline_variants_matrix.columns) - - normalized_share_germline_vs_germline_variants_matrix = shared_germline_variants_matrix.divide(germline_counts, axis=1) + germline_counts = ( + germline_vars_all_samples["SAMPLE_ID"].value_counts().reindex(shared_germline_variants_matrix.columns) + ) + normalized_share_germline_vs_germline_variants_matrix = shared_germline_variants_matrix.divide( + germline_counts, axis=1 + ) plt.figure(figsize=(18, 15)) - # Create custom column labels with germline mutation counts - col_labels = [f"(n={germline_counts[col]}) {col}" for col in normalized_share_germline_vs_germline_variants_matrix.columns] + col_labels = [ + f"(n={germline_counts[col]}) {col}" for col in normalized_share_germline_vs_germline_variants_matrix.columns + ] - - cond = (normalized_share_germline_vs_germline_variants_matrix > 0.8) & (normalized_share_germline_vs_germline_variants_matrix < 1) + cond = (normalized_share_germline_vs_germline_variants_matrix > 0.8) & ( + normalized_share_germline_vs_germline_variants_matrix < 1 + ) annot = normalized_share_germline_vs_germline_variants_matrix.where(cond) - annot = annot.round(2).astype('string').fillna('') - - sns.heatmap(normalized_share_germline_vs_germline_variants_matrix, - annot=annot, - fmt="", - cmap="Blues", - cbar_kws={'label': 'Shared Mutations'}, - xticklabels=col_labels, yticklabels=normalized_share_germline_vs_germline_variants_matrix.index, - annot_kws={"color": "white", "fontsize": 10}, - linewidths=0.5) - - plt.xlabel("Germline Samples", fontsize = 14) - plt.ylabel("Germline samples", fontsize = 14) - plt.savefig("normalized.germline_vs_germline.pdf", bbox_inches = 'tight', dpi = 100) - plt.show() - + annot = annot.round(2).astype("string").fillna("") + sns.heatmap( + normalized_share_germline_vs_germline_variants_matrix, + annot=annot, + fmt="", + cmap="Blues", + cbar_kws={"label": "Shared Mutations"}, + xticklabels=col_labels, + yticklabels=normalized_share_germline_vs_germline_variants_matrix.index, + annot_kws={"color": "white", "fontsize": 10}, + linewidths=0.5, + ) + plt.xlabel("Germline Samples", fontsize=14) + plt.ylabel("Germline samples", fontsize=14) + plt.savefig("normalized.germline_vs_germline.pdf", bbox_inches="tight", dpi=100) + plt.show() ## Somatic vs Remaining Germline shared_somatic_to_non_shared_germline = shared_all_vs_germline_variants_matrix - shared_germline_variants_matrix @@ -226,102 +257,121 @@ def contamination_detection_between_samples(maf_df, somatic_maf_df): shared_somatic_to_non_shared_germline[shared_somatic_to_non_shared_germline < 5] = 0 # Compute total number of germline mutations per sample - germline_counts = germline_vars_all_samples["SAMPLE_ID"].value_counts().reindex(shared_somatic_to_non_shared_germline.columns) - - - total_germline_available_per_sample = (germline_counts - shared_germline_variants_matrix) - - shared_somatic_to_non_shared_germline_proportion = (shared_somatic_to_non_shared_germline / total_germline_available_per_sample).fillna(0) + germline_counts = ( + germline_vars_all_samples["SAMPLE_ID"].value_counts().reindex(shared_somatic_to_non_shared_germline.columns) + ) + total_germline_available_per_sample = germline_counts - shared_germline_variants_matrix + shared_somatic_to_non_shared_germline_proportion = ( + shared_somatic_to_non_shared_germline / total_germline_available_per_sample + ).fillna(0) plt.figure(figsize=(22, 18)) cond = shared_somatic_to_non_shared_germline_proportion > 0.45 annot = shared_somatic_to_non_shared_germline_proportion.where(cond) - annot = annot.round(2).astype('string').fillna('') - - sns.heatmap(shared_somatic_to_non_shared_germline_proportion, - annot=annot, - fmt="", - cmap="Blues", - cbar_kws={'label': 'Shared Mutations'}, - # xticklabels=col_labels, - yticklabels=shared_somatic_to_non_shared_germline_proportion.index, - annot_kws={"color": "black", "fontsize": 10}, - linewidths=0.5) - - plt.xlabel("Non-shared germline", fontsize = 14) - plt.ylabel("Somatic", fontsize = 14) - plt.title("Somatic mutations that are germline in other samples", fontsize = 16) - plt.savefig("contamination.somatic_vs_remaininggermline.pdf", bbox_inches = 'tight', dpi = 100) + annot = annot.round(2).astype("string").fillna("") + + sns.heatmap( + shared_somatic_to_non_shared_germline_proportion, + annot=annot, + fmt="", + cmap="Blues", + cbar_kws={"label": "Shared Mutations"}, + # xticklabels=col_labels, + yticklabels=shared_somatic_to_non_shared_germline_proportion.index, + annot_kws={"color": "black", "fontsize": 10}, + linewidths=0.5, + ) + + plt.xlabel("Non-shared germline", fontsize=14) + plt.ylabel("Somatic", fontsize=14) + plt.title("Somatic mutations that are germline in other samples", fontsize=16) + plt.savefig("contamination.somatic_vs_remaininggermline.pdf", bbox_inches="tight", dpi=100) plt.show() plt.close() - plt.figure(figsize=(22, 18)) cond = shared_somatic_to_non_shared_germline > 0 annot = shared_somatic_to_non_shared_germline.where(cond) # convert to nullable int then string, replace missing with empty string - annot = annot.round(0).astype('Int64').astype('string').replace('', '').fillna('') - - sns.heatmap(shared_somatic_to_non_shared_germline, - annot=annot, - fmt="", - cmap="Blues", - cbar_kws={'label': 'Shared Mutations'}, - # xticklabels=col_labels, - yticklabels=shared_somatic_to_non_shared_germline.index, - annot_kws={"color": "black", "fontsize": 10}, - linewidths=0.5) - - plt.xlabel("Non-shared germline", fontsize = 14) - plt.ylabel("Somatic", fontsize = 14) - plt.title("Somatic mutations that are germline in other samples (count)", fontsize = 16) - plt.savefig("contamination.somatic_vs_remaininggermline.numbers.pdf", bbox_inches = 'tight', dpi = 100) + annot = annot.round(0).astype("Int64").astype("string").replace("", "").fillna("") + + sns.heatmap( + shared_somatic_to_non_shared_germline, + annot=annot, + fmt="", + cmap="Blues", + cbar_kws={"label": "Shared Mutations"}, + # xticklabels=col_labels, + yticklabels=shared_somatic_to_non_shared_germline.index, + annot_kws={"color": "black", "fontsize": 10}, + linewidths=0.5, + ) + + plt.xlabel("Non-shared germline", fontsize=14) + plt.ylabel("Somatic", fontsize=14) + plt.title("Somatic mutations that are germline in other samples (count)", fontsize=16) + plt.savefig("contamination.somatic_vs_remaininggermline.numbers.pdf", bbox_inches="tight", dpi=100) plt.show() plt.close() - - max_prop_per_sample = shared_somatic_to_non_shared_germline_proportion.max(axis = 'columns') + max_prop_per_sample = shared_somatic_to_non_shared_germline_proportion.max(axis="columns") ## Exploration of contaminated samples receiver_source_pairs = [] - for sample, max_val in max_prop_per_sample[max_prop_per_sample>0.5].reset_index().values: - sample_vals = shared_somatic_to_non_shared_germline_proportion.loc[sample,:] - sample_vals_count = shared_somatic_to_non_shared_germline.loc[sample,:] + for sample, max_val in max_prop_per_sample[max_prop_per_sample > 0.5].reset_index().values: + sample_vals = shared_somatic_to_non_shared_germline_proportion.loc[sample, :] + sample_vals_count = shared_somatic_to_non_shared_germline.loc[sample, :] source_sampleids = sample_vals[sample_vals == max_val].index.values source_sampleid = source_sampleids[0] - receiver_source_pairs.append((sample, round(max_val,3), - list(zip([sample_vals_count[x].item() for x in source_sampleids], source_sampleids)))) + receiver_source_pairs.append( + ( + sample, + round(max_val, 3), + list(zip([sample_vals_count[x].item() for x in source_sampleids], source_sampleids)), + ) + ) - print(f'{sample} has {max_val:.2f} proportion of the germline variants of {source_sampleid} as with a VAF not corresponding to germline variants.') - print(f'Shared variants count: {sample_vals_count[source_sampleid]}') + print( + f"{sample} has {max_val:.2f} proportion of the germline variants of {source_sampleid} as with a VAF not corresponding to germline variants." + ) + print(f"Shared variants count: {sample_vals_count[source_sampleid]}") print() - - subseeeet = maf_df[["SAMPLE_ID", "MUT_ID", 'canonical_SYMBOL', "ALT_DEPTH", "DEPTH", "VAF", 'canonical_Consequence_broader', 'FILTER']] - p_dest = subseeeet[subseeeet["SAMPLE_ID"] == sample].drop("SAMPLE_ID", axis = 1) + subseeeet = maf_df[ + [ + "SAMPLE_ID", + "MUT_ID", + "canonical_SYMBOL", + "ALT_DEPTH", + "DEPTH", + "VAF", + "canonical_Consequence_broader", + "FILTER", + ] + ] + p_dest = subseeeet[subseeeet["SAMPLE_ID"] == sample].drop("SAMPLE_ID", axis=1) p_source_germ = germline_vars_all_samples[germline_vars_all_samples["SAMPLE_ID"] == source_sampleid] - p_source = subseeeet[(subseeeet["SAMPLE_ID"] == source_sampleid) - & (subseeeet["MUT_ID"].isin(p_source_germ["MUT_ID"].values)) - ].drop("SAMPLE_ID", axis = 1) - - merged_samples = p_dest.merge(p_source, - on = ["MUT_ID", 'canonical_SYMBOL', 'canonical_Consequence_broader'], - suffixes = ("_dest", "_source"), - how = 'right' - ) - - merged_samples.sort_values(by =["VAF_dest"], ascending=False - ).to_csv(f"{source_sampleid}.germline_variants_in.{sample}.tsv", - header = True, - sep = '\t', - index = False) - + p_source = subseeeet[ + (subseeeet["SAMPLE_ID"] == source_sampleid) & (subseeeet["MUT_ID"].isin(p_source_germ["MUT_ID"].values)) + ].drop("SAMPLE_ID", axis=1) + + merged_samples = p_dest.merge( + p_source, + on=["MUT_ID", "canonical_SYMBOL", "canonical_Consequence_broader"], + suffixes=("_dest", "_source"), + how="right", + ) + + merged_samples.sort_values(by=["VAF_dest"], ascending=False).to_csv( + f"{source_sampleid}.germline_variants_in.{sample}.tsv", header=True, sep="\t", index=False + ) + # plt.figure(figsize=(8, 6)) # plt.scatter(x = merged_samples["VAF_dest"].fillna(0), # y = merged_samples["VAF_source"].fillna(0), @@ -336,12 +386,10 @@ def contamination_detection_between_samples(maf_df, somatic_maf_df): # plt.show() # Store contamination results - contamination_detailed_df = pd.DataFrame(receiver_source_pairs, - columns=["SAMPLE_ID", "MAX_PROPORTION_GERMLINE_FROM_SOURCE", "SOURCE_SAMPLEID_COUNTS"]) - contamination_detailed_df.to_csv(f"contaminated_samples.detailed.tsv", - header = True, - sep = '\t', - index = False) + contamination_detailed_df = pd.DataFrame( + receiver_source_pairs, columns=["SAMPLE_ID", "MAX_PROPORTION_GERMLINE_FROM_SOURCE", "SOURCE_SAMPLEID_COUNTS"] + ) + contamination_detailed_df.to_csv("contaminated_samples.detailed.tsv", header=True, sep="\t", index=False) if contamination_detailed_df.empty: print("No contaminated samples detected.") @@ -351,38 +399,64 @@ def contamination_detection_between_samples(maf_df, somatic_maf_df): expanded_df.columns = ["SHARED_VARIANT_COUNT", "SOURCE_SAMPLEID"] contamination_detailed_df_long["SHARED_VARIANT_COUNT"] = expanded_df["SHARED_VARIANT_COUNT"].values contamination_detailed_df_long["SOURCE_SAMPLEID"] = expanded_df["SOURCE_SAMPLEID"].values - contamination_detailed_df_long = contamination_detailed_df_long.drop("SOURCE_SAMPLEID_COUNTS", axis = 1) - contamination_detailed_df_long.to_csv(f"contaminated_samples.detailed.long.tsv", - header = True, - sep = '\t', - index = False) - + contamination_detailed_df_long = contamination_detailed_df_long.drop("SOURCE_SAMPLEID_COUNTS", axis=1) + contamination_detailed_df_long.to_csv("contaminated_samples.detailed.long.tsv", header=True, sep="\t", index=False) def data_loading(maf_path, somatic_maf_path): + """Load the full and somatic MAF tables, keeping only covered SNVs. + + Parameters + ---------- + maf_path : str + Path to the full MAF file; rows flagged ``FILTER.not_covered`` are dropped and only + ``TYPE == "SNV"`` rows are kept. + somatic_maf_path : str + Path to the filtered somatic MAF file; only ``TYPE == "SNV"`` rows are kept. + + Returns + ------- + tuple + ``(maf_df, somatic_maf_df)`` — the filtered full mutation table and the filtered + somatic mutation table, both as ``pd.DataFrame``. + """ maf_df = pd.read_table(maf_path, na_values=custom_na_values) print(maf_df.shape) - maf_df = maf_df[~(maf_df["FILTER.not_covered"]) - & (maf_df["TYPE"] == 'SNV') - ].reset_index() + maf_df = maf_df[~(maf_df["FILTER.not_covered"]) & (maf_df["TYPE"] == "SNV")].reset_index() print(maf_df.shape) somatic_maf_df = pd.read_table(somatic_maf_path, na_values=custom_na_values) print(somatic_maf_df.shape) - somatic_maf_df = somatic_maf_df[(somatic_maf_df["TYPE"] == 'SNV')] + somatic_maf_df = somatic_maf_df[(somatic_maf_df["TYPE"] == "SNV")] print(somatic_maf_df.shape) return maf_df, somatic_maf_df def contamination_detection_in_snps(maf): + """Estimate per-sample contamination from the VAF distribution at known SNP positions. + + Restricts to gnomAD SNP positions, splits them into somatic-looking and germline-looking + sets by VAF, computes the per-sample proportion of SNP positions that look somatic, writes + the resulting table, and plots its distribution across samples. + + Parameters + ---------- + maf : pd.DataFrame + Full mutation table with at least ``SAMPLE_ID``, ``MUT_ID``, ``VAF``, ``vd_VAF``, ``VAF_AM``, and the boolean + ``FILTER.gnomAD_SNP`` column. + """ + snp_positions_maf = maf[maf["FILTER.gnomAD_SNP"]][["SAMPLE_ID", "MUT_ID", "VAF", "vd_VAF", "VAF_AM"]].reset_index( + drop=True + ) - snp_positions_maf = maf[maf["FILTER.gnomAD_SNP"]][ - ["SAMPLE_ID", "MUT_ID", "VAF"] - ].reset_index(drop = True) - # being very restrictive in the VAF to count the occurrences of potentially contaminated mutations - somatic_snp_positions_maf = snp_positions_maf[snp_positions_maf["VAF"] < 0.05].reset_index(drop = True) - germline_snp_positions_maf = snp_positions_maf[snp_positions_maf["VAF"] >= 0.05].reset_index(drop = True) + contamination_vaf_threshold = 0.05 + somatic_snp_positions_maf = snp_positions_maf.loc[ + somatic_mask(snp_positions_maf, contamination_vaf_threshold) + ].reset_index(drop=True) + germline_snp_positions_maf = snp_positions_maf.loc[ + germline_mask(snp_positions_maf, contamination_vaf_threshold) + ].reset_index(drop=True) unique_SNP_positions = snp_positions_maf["MUT_ID"].unique() number_unique_SNP_positions = len(unique_SNP_positions) @@ -391,23 +465,31 @@ def contamination_detection_in_snps(maf): for sample in snp_positions_maf["SAMPLE_ID"].unique(): germline_count = len(germline_snp_positions_maf[germline_snp_positions_maf["SAMPLE_ID"] == sample]) somatic_count = len(somatic_snp_positions_maf[somatic_snp_positions_maf["SAMPLE_ID"] == sample]) - remaining_germline = number_unique_SNP_positions-germline_count - sample_SNP_mutation_freq.append([sample, - germline_count, - remaining_germline, - somatic_count, - somatic_count / remaining_germline if remaining_germline > 0 else 1 - ]) + remaining_germline = number_unique_SNP_positions - germline_count + sample_SNP_mutation_freq.append( + [ + sample, + germline_count, + remaining_germline, + somatic_count, + somatic_count / remaining_germline if remaining_germline > 0 else 1, + ] + ) sample_SNP_mutation_freq_df = pd.DataFrame(sample_SNP_mutation_freq) - sample_SNP_mutation_freq_df.columns = ["SAMPLE_ID", "germline_count", "remaining_germline", "somatic_count", "prop_somatic_SNPs"] + sample_SNP_mutation_freq_df.columns = [ + "SAMPLE_ID", + "germline_count", + "remaining_germline", + "somatic_count", + "prop_somatic_SNPs", + ] # identify outliers in the "prop_somatic_SNPs" column - sample_SNP_mutation_freq_df = sample_SNP_mutation_freq_df.sort_values(by = "prop_somatic_SNPs", ascending = False) - sample_SNP_mutation_freq_df.to_csv("sample_SNP_mutation_freq.tsv", header = True, sep = '\t', index = False) + sample_SNP_mutation_freq_df = sample_SNP_mutation_freq_df.sort_values(by="prop_somatic_SNPs", ascending=False) + sample_SNP_mutation_freq_df.to_csv("sample_SNP_mutation_freq.tsv", header=True, sep="\t", index=False) plt.figure(figsize=(6, 3)) - sns.violinplot(data=sample_SNP_mutation_freq_df, x="prop_somatic_SNPs", - fill= False, color="lightgray", inner=None) + sns.violinplot(data=sample_SNP_mutation_freq_df, x="prop_somatic_SNPs", fill=False, color="lightgray", inner=None) sns.swarmplot(data=sample_SNP_mutation_freq_df, x="prop_somatic_SNPs", color="black", size=3) plt.title("Proportion of all SNPs across samples\ndetected as somatic") @@ -418,23 +500,41 @@ def contamination_detection_in_snps(maf): @click.command() -@click.option('--maf_path', type=click.Path(exists=True), required=True, help='Path to the MAF file.') -@click.option('--somatic_maf', type=click.Path(exists=True), required=True, help='Path to the filtered somatic mutations file.') -def main(maf_path, somatic_maf): - """ - CLI entry point for assessing contamination between samples using germline and somatic mutations. +@click.option("--maf_path", type=click.Path(exists=True), required=True, help="Path to the MAF file.") +@click.option( + "--somatic_maf", type=click.Path(exists=True), required=True, help="Path to the filtered somatic mutations file." +) +@click.option( + "--somatic-vaf-boundary", + type=float, + default=0.3, + show_default=True, + help="VAF boundary for somatic variants; a variant with all of VAF/vd_VAF/VAF_AM above it is germline.", +) +def main(maf_path, somatic_maf, somatic_vaf_boundary): + """Assess cross-sample contamination using germline and somatic mutations. + + Loads the input tables and runs both the between-samples and the SNP-based contamination + analyses, writing their tables and plots to the current working directory. + + Parameters + ---------- + maf_path : str + Path to the full MAF file. + somatic_maf : str + Path to the filtered somatic mutations file. + somatic_vaf_boundary : float, optional + VAF boundary separating somatic from germline variants. Default is 0.3. """ - + maf_df, somatic_maf_df = data_loading(maf_path, somatic_maf) print("Running contamination analysis between samples") - contamination_detection_between_samples(maf_df, somatic_maf_df) + contamination_detection_between_samples(maf_df, somatic_maf_df, somatic_vaf_boundary) print("Running general contamination analysis") contamination_detection_in_snps(maf_df) - -if __name__ == '__main__': - +if __name__ == "__main__": main() diff --git a/bin/filter_cohort.py b/bin/filter_cohort.py index 755eee1d..c6dca9a9 100755 --- a/bin/filter_cohort.py +++ b/bin/filter_cohort.py @@ -41,13 +41,12 @@ """ import logging -from pathlib import Path import click import pandas as pd -from utils import add_filter from read_utils import custom_na_values -from utils_filter import expand_filter_column +from utils import add_filter +from utils_filter import expand_filter_column, germline_mask, somatic_mask # Logging logging.basicConfig( @@ -55,9 +54,10 @@ ) LOG = logging.getLogger("filter_cohort") -def flag_repetitive_variants(maf_df: pd.DataFrame, - repetitive_variant_threshold: int, - somatic_vaf_boundary: float) -> pd.DataFrame: + +def flag_repetitive_variants( + maf_df: pd.DataFrame, repetitive_variant_threshold: int, somatic_vaf_boundary: float +) -> pd.DataFrame: """ Flags filter column for repetitive variants from the MAF dataframe. A variant is considered repetitive if it appears in at least ``repetitive_variant_threshold`` samples. Additionally, variants that consistently appear at the same position in reads @@ -87,13 +87,17 @@ def flag_repetitive_variants(maf_df: pd.DataFrame, # Work with already filtered df + somatic only to explore potential artifacts # take only variant and sample info from the df - maf_df_f_somatic = maf_df.loc[maf_df["VAF"] <= somatic_vaf_boundary][["MUT_ID","SAMPLE_ID", "PMEAN", "PSTD"]].reset_index(drop = True) + maf_df_f_somatic = maf_df.loc[somatic_mask(maf_df, somatic_vaf_boundary)][ + ["MUT_ID", "SAMPLE_ID", "PMEAN", "PSTD"] + ].reset_index(drop=True) # Group by 'MUT_ID' and count occurrences maf_df_f_somatic_pivot = maf_df_f_somatic.groupby("MUT_ID").size().reset_index(name="count") # Store repetitive variants - repetitive_variants = maf_df_f_somatic_pivot[maf_df_f_somatic_pivot["count"] >= repetitive_variant_threshold]["MUT_ID"] + repetitive_variants = maf_df_f_somatic_pivot[maf_df_f_somatic_pivot["count"] >= repetitive_variant_threshold][ + "MUT_ID" + ] LOG.info("%s repetitive_variants", len(repetitive_variants)) # Flag repetitive variants in the original dataframe @@ -104,10 +108,10 @@ def flag_repetitive_variants(maf_df: pd.DataFrame, maf_df = maf_df.drop("repetitive_variant", axis=1) # Use the position in read information to filter repetitive variants with a fixed position (likely artifacts) - maf_df_f_somatic_pos_info = maf_df_f_somatic[~(maf_df_f_somatic["PMEAN"].isna()) & - (maf_df_f_somatic["PMEAN"] != -1) & - (maf_df_f_somatic["PSTD"] == 0)] - + maf_df_f_somatic_pos_info = maf_df_f_somatic[ + ~(maf_df_f_somatic["PMEAN"].isna()) & (maf_df_f_somatic["PMEAN"] != -1) & (maf_df_f_somatic["PSTD"] == 0) + ] + # Check if there are any repetitive variants with a fixed position if maf_df_f_somatic_pos_info.shape[0] == 0: LOG.info("No repetitive variants with fixed position found.") @@ -127,19 +131,20 @@ def flag_repetitive_variants(maf_df: pd.DataFrame, # Flag these variants in the maf dataframe maf_df["repetitive_mapping_variant"] = maf_df["MUT_ID"].isin(variants_with_rep_position) LOG.info("%s muts flagged as repetitive_mapping_variant", maf_df["repetitive_mapping_variant"].sum()) - - maf_df["FILTER"] = maf_df[["FILTER","repetitive_mapping_variant"]].apply(lambda x: add_filter(x["FILTER"], x["repetitive_mapping_variant"], "repetitive_mapping_variant"), - axis = 1 - ) - maf_df = maf_df.drop("repetitive_mapping_variant", axis = 1) + + maf_df["FILTER"] = maf_df[["FILTER", "repetitive_mapping_variant"]].apply( + lambda x: add_filter(x["FILTER"], x["repetitive_mapping_variant"], "repetitive_mapping_variant"), axis=1 + ) + maf_df = maf_df.drop("repetitive_mapping_variant", axis=1) return maf_df -def flag_cohort_n_rich(maf_df: pd.DataFrame, - n_rich_cohort_proportion: float, - somatic_vaf_boundary: float) -> pd.DataFrame: + +def flag_cohort_n_rich( + maf_df: pd.DataFrame, n_rich_cohort_proportion: float, somatic_vaf_boundary: float +) -> pd.DataFrame: """ - Flags FILTER column for cohort_n_rich variants from the MAF dataframe. + Flags FILTER column for cohort_n_rich variants from the MAF dataframe. Parameters ---------- @@ -161,62 +166,59 @@ def flag_cohort_n_rich(maf_df: pd.DataFrame, if max_samples < 2: LOG.warning("Not enough samples to identify cohort_n_rich mutations!") return maf_df - + number_of_samples = max(2, (max_samples * n_rich_cohort_proportion) // 1) LOG.info(f"Flagging mutations that are n_rich in at least: {number_of_samples} samples as cohort_n_rich") # Work with already filtered df to explore potential artifacts # take only variant and sample info from the df. - maf_df_f = maf_df[["MUT_ID", "SAMPLE_ID", "VAF_Ns", "FILTER"]].reset_index(drop = True) + maf_df_f = maf_df[["MUT_ID", "SAMPLE_ID", "VAF_Ns", "FILTER"]].reset_index(drop=True) # Aggregate n_rich variants n_rich_vars_df = ( maf_df_f[maf_df_f["FILTER"].str.contains("n_rich")] .groupby("MUT_ID") - .agg( - N_rich_frequency=('SAMPLE_ID', 'count'), - VAF_Ns_threshold=('VAF_Ns', 'min') - ) - ) - + .agg(N_rich_frequency=("SAMPLE_ID", "count"), VAF_Ns_threshold=("VAF_Ns", "min")) + ) + # Flag variants that are n_rich in at least number_of_samples samples -> cohort_n_rich - n_rich_vars = set(n_rich_vars_df[n_rich_vars_df['N_rich_frequency'] >= number_of_samples].index) + n_rich_vars = set(n_rich_vars_df[n_rich_vars_df["N_rich_frequency"] >= number_of_samples].index) maf_df["cohort_n_rich"] = maf_df["MUT_ID"].isin(n_rich_vars) LOG.info("%s muts flagged as cohort_n_rich", maf_df["cohort_n_rich"].sum()) - maf_df["FILTER"] = maf_df[["FILTER","cohort_n_rich"]].apply(lambda x: add_filter(x["FILTER"], x["cohort_n_rich"], "cohort_n_rich"), - axis = 1 - ) - + maf_df["FILTER"] = maf_df[["FILTER", "cohort_n_rich"]].apply( + lambda x: add_filter(x["FILTER"], x["cohort_n_rich"], "cohort_n_rich"), axis=1 + ) + # Flag variants that are n_rich in at least 1 sample -> cohort_n_rich_uni - n_rich_vars_uni = set(n_rich_vars_df[n_rich_vars_df['N_rich_frequency'] > 0].index) + n_rich_vars_uni = set(n_rich_vars_df[n_rich_vars_df["N_rich_frequency"] > 0].index) maf_df["cohort_n_rich_uni"] = maf_df["MUT_ID"].isin(n_rich_vars_uni) LOG.info("%s muts flagged as cohort_n_rich_uni", maf_df["cohort_n_rich_uni"].sum()) - maf_df["FILTER"] = maf_df[["FILTER","cohort_n_rich_uni"]].apply(lambda x: add_filter(x["FILTER"], x["cohort_n_rich_uni"], "cohort_n_rich_uni"), - axis = 1 - ) - + maf_df["FILTER"] = maf_df[["FILTER", "cohort_n_rich_uni"]].apply( + lambda x: add_filter(x["FILTER"], x["cohort_n_rich_uni"], "cohort_n_rich_uni"), axis=1 + ) + # Flag variants that exceed the VAF_Ns threshold -> cohort_n_rich_threshold - maf_df = maf_df.merge(n_rich_vars_df, on = 'MUT_ID', how = 'left') - maf_df['N_rich_frequency'] = maf_df['N_rich_frequency'].fillna(0) - maf_df['VAF_Ns_threshold'] = maf_df['VAF_Ns_threshold'].fillna(1.1) + maf_df = maf_df.merge(n_rich_vars_df, on="MUT_ID", how="left") + maf_df["N_rich_frequency"] = maf_df["N_rich_frequency"].fillna(0) + maf_df["VAF_Ns_threshold"] = maf_df["VAF_Ns_threshold"].fillna(1.1) - maf_df["cohort_n_rich_threshold"] = maf_df["VAF_Ns"] >= maf_df['VAF_Ns_threshold'] + maf_df["cohort_n_rich_threshold"] = maf_df["VAF_Ns"] >= maf_df["VAF_Ns_threshold"] LOG.info("%s muts flagged as cohort_n_rich_threshold", maf_df["cohort_n_rich_threshold"].sum()) - maf_df["FILTER"] = maf_df[["FILTER","cohort_n_rich_threshold"]].apply(lambda x: add_filter(x["FILTER"], x["cohort_n_rich_threshold"], "cohort_n_rich_threshold"), - axis = 1 - ) + maf_df["FILTER"] = maf_df[["FILTER", "cohort_n_rich_threshold"]].apply( + lambda x: add_filter(x["FILTER"], x["cohort_n_rich_threshold"], "cohort_n_rich_threshold"), axis=1 + ) # Drop temporary columns - maf_df = maf_df.drop(["cohort_n_rich", "cohort_n_rich_uni", "cohort_n_rich_threshold"], axis = 1) - + maf_df = maf_df.drop(["cohort_n_rich", "cohort_n_rich_uni", "cohort_n_rich_threshold"], axis=1) + return maf_df -def flag_other_samples_snp(maf_df, - somatic_vaf_boundary: float) -> pd.DataFrame: + +def flag_other_samples_snp(maf_df, somatic_vaf_boundary: float) -> pd.DataFrame: """ Filters out SNPs from other samples from the MAF dataframe @@ -234,28 +236,27 @@ def flag_other_samples_snp(maf_df, """ LOG.info("Flagging SNPs from other samples...") # Get all germline variants from all samples, consider both unique and non-unique variants - germline_vars_all_samples = maf_df.loc[(maf_df["VAF"] > somatic_vaf_boundary) & - (maf_df["VAF_AM"] > somatic_vaf_boundary) & - (maf_df["vd_VAF"] > somatic_vaf_boundary), - "MUT_ID"].unique() - + germline_vars_all_samples = maf_df.loc[germline_mask(maf_df, somatic_vaf_boundary), "MUT_ID"].unique() + LOG.info(f"Using all germline variants of all samples, total: {len(germline_vars_all_samples)} variants.") # Identify variants that are germline in other samples but somatic in the current sample maf_df["other_sample_SNP"] = False - maf_df.loc[(maf_df["MUT_ID"].isin(germline_vars_all_samples)) & - (maf_df["VAF"] <= somatic_vaf_boundary), "other_sample_SNP"] = True - LOG.info("%s muts flagged as other_sample_SNP", maf_df['other_sample_SNP'].sum()) + maf_df.loc[ + (maf_df["MUT_ID"].isin(germline_vars_all_samples)) & somatic_mask(maf_df, somatic_vaf_boundary), + "other_sample_SNP", + ] = True + LOG.info("%s muts flagged as other_sample_SNP", maf_df["other_sample_SNP"].sum()) # Flag variants that are germline in other samples but somatic in the current sample - maf_df["FILTER"] = maf_df[["FILTER","other_sample_SNP"]].apply( - lambda x: add_filter(x["FILTER"], x["other_sample_SNP"], "other_sample_SNP"), - axis = 1 - ) - maf_df = maf_df.drop("other_sample_SNP", axis = 1) + maf_df["FILTER"] = maf_df[["FILTER", "other_sample_SNP"]].apply( + lambda x: add_filter(x["FILTER"], x["other_sample_SNP"], "other_sample_SNP"), axis=1 + ) + maf_df = maf_df.drop("other_sample_SNP", axis=1) return maf_df + def flag_gnomad_snp(maf_df: pd.DataFrame) -> pd.DataFrame: """ Flags gnomAD SNPs in the MAF dataframe @@ -274,19 +275,20 @@ def flag_gnomad_snp(maf_df: pd.DataFrame) -> pd.DataFrame: # Flag gnomAD SNPs if "gnomAD_SNP" in maf_df.columns: - maf_df["gnomAD_SNP"] = maf_df["gnomAD_SNP"].replace({"True": True, "False": False, '-' : False}).fillna(False).astype(bool) + maf_df["gnomAD_SNP"] = ( + maf_df["gnomAD_SNP"].replace({"True": True, "False": False, "-": False}).fillna(False).astype(bool) + ) LOG.info("Out of %d positions, %d are gnomAD SNPs", maf_df["gnomAD_SNP"].shape[0], maf_df["gnomAD_SNP"].sum()) - - maf_df["FILTER"] = maf_df[["FILTER","gnomAD_SNP"]].apply( - lambda x: add_filter(x["FILTER"], x["gnomAD_SNP"], "gnomAD_SNP"), - axis = 1 - ) - maf_df = maf_df.drop("gnomAD_SNP", axis = 1) + + maf_df["FILTER"] = maf_df[["FILTER", "gnomAD_SNP"]].apply( + lambda x: add_filter(x["FILTER"], x["gnomAD_SNP"], "gnomAD_SNP"), axis=1 + ) + maf_df = maf_df.drop("gnomAD_SNP", axis=1) return maf_df -def flag_vaf_ns_threshold(maf_df: pd.DataFrame, vaf_ns_threshold: float) -> pd.DataFrame: +def flag_vaf_ns_threshold(maf_df: pd.DataFrame, vaf_ns_threshold: float) -> pd.DataFrame: """ Flag variants that have a proportion of Ns higher than vaf_ns_threshold @@ -307,14 +309,14 @@ def flag_vaf_ns_threshold(maf_df: pd.DataFrame, vaf_ns_threshold: float) -> pd.D maf_df["high_n_vaf"] = maf_df[["VAF_Ns", "VAF_Ns_AM"]].ge(vaf_ns_threshold).any(axis=1) LOG.info("%s muts flagged as high_n_vaf", maf_df["high_n_vaf"].sum()) - maf_df["FILTER"] = maf_df[["FILTER","high_n_vaf"]].apply( - lambda x: add_filter(x["FILTER"], x["high_n_vaf"], "high_n_vaf"), - axis = 1 - ) - maf_df = maf_df.drop("high_n_vaf", axis = 1) + maf_df["FILTER"] = maf_df[["FILTER", "high_n_vaf"]].apply( + lambda x: add_filter(x["FILTER"], x["high_n_vaf"], "high_n_vaf"), axis=1 + ) + maf_df = maf_df.drop("high_n_vaf", axis=1) return maf_df + def flag_distorted_expanded(maf_df: pd.DataFrame) -> pd.DataFrame: """ If there is a column named VAF_distorted_expanded_sq, add a filter flag for variants with distorted VAF distribution. @@ -331,19 +333,22 @@ def flag_distorted_expanded(maf_df: pd.DataFrame) -> pd.DataFrame: """ LOG.info("Flagging variants with distorted VAF distribution...") - if 'VAF_distorted_expanded_sq' in maf_df.columns: - maf_df["FILTER"] = maf_df[["FILTER","VAF_distorted_expanded_sq"]].apply( - lambda x: add_filter(x["FILTER"], x["VAF_distorted_expanded_sq"], "VAF_distorted_expanded_sq"), - axis = 1 - ) + if "VAF_distorted_expanded_sq" in maf_df.columns: + maf_df["FILTER"] = maf_df[["FILTER", "VAF_distorted_expanded_sq"]].apply( + lambda x: add_filter(x["FILTER"], x["VAF_distorted_expanded_sq"], "VAF_distorted_expanded_sq"), axis=1 + ) return maf_df -def flag_maf(maf_df: pd.DataFrame, sample_name: str, - repetitive_variant_threshold: int, - somatic_vaf_boundary: float, - n_rich_cohort_proportion: float, - vaf_ns_threshold: float) -> None: + +def flag_maf( + maf_df: pd.DataFrame, + sample_name: str, + repetitive_variant_threshold: int, + somatic_vaf_boundary: float, + n_rich_cohort_proportion: float, + vaf_ns_threshold: float, +) -> None: """ Script to process a MAF (Mutation Annotation Format) file. It filters out repetitive variants, cohort_n_rich variants, and SNPs from other samples. @@ -386,36 +391,44 @@ def flag_maf(maf_df: pd.DataFrame, sample_name: str, maf_df = expand_filter_column(maf_df) ## Save final filtered MAF - maf_df.to_csv(f"{sample_name}.cohort.filtered.tsv.gz", - sep = "\t", - header = True, - index = False) - + maf_df.to_csv(f"{sample_name}.cohort.filtered.tsv.gz", sep="\t", header=True, index=False) + LOG.info("Cohort flagging complete!") + @click.command() -@click.option('--maf-df-file', required=True, type=click.Path(exists=True), help='Input gzipped MAF file (TSV)') -@click.option('--sample-name', required=True, type=str, help='Sample name for output file') -@click.option('--repetitive-variant-threshold', required=True, type=int, help='Threshold for repetitive variants') -@click.option('--somatic-vaf-boundary', required=True, type=float, help='VAF boundary for somatic variants') -@click.option('--n-rich-cohort-proportion', required=True, type=float, help='Proportion for n-rich cohort filtering') -@click.option('--vaf-ns-threshold', required=False, type=float, default=0.1, help='VAF of Ns threshold for filtering variants') -def main(maf_df_file: str, sample_name: str, repetitive_variant_threshold: int, - somatic_vaf_boundary: float, n_rich_cohort_proportion: float, vaf_ns_threshold: float): +@click.option("--maf-df-file", required=True, type=click.Path(exists=True), help="Input gzipped MAF file (TSV)") +@click.option("--sample-name", required=True, type=str, help="Sample name for output file") +@click.option("--repetitive-variant-threshold", required=True, type=int, help="Threshold for repetitive variants") +@click.option("--somatic-vaf-boundary", required=True, type=float, help="VAF boundary for somatic variants") +@click.option("--n-rich-cohort-proportion", required=True, type=float, help="Proportion for n-rich cohort filtering") +@click.option( + "--vaf-ns-threshold", required=False, type=float, default=0.1, help="VAF of Ns threshold for filtering variants" +) +def main( + maf_df_file: str, + sample_name: str, + repetitive_variant_threshold: int, + somatic_vaf_boundary: float, + n_rich_cohort_proportion: float, + vaf_ns_threshold: float, +): """ CLI wrapper for flag_maf function. """ # Load MAF dataframe - maf_df = pd.read_csv(maf_df_file, compression='gzip', header=0, sep='\t', na_values=custom_na_values) + maf_df = pd.read_csv(maf_df_file, compression="gzip", header=0, sep="\t", na_values=custom_na_values) LOG.debug(f"{maf_df_file}") # Flag MAF file - flag_maf(maf_df, + flag_maf( + maf_df, sample_name, repetitive_variant_threshold, - somatic_vaf_boundary, + somatic_vaf_boundary, n_rich_cohort_proportion, - vaf_ns_threshold) - + vaf_ns_threshold, + ) + -if __name__ == '__main__': - main() \ No newline at end of file +if __name__ == "__main__": + main() diff --git a/bin/test/test_utils_filter.py b/bin/test/test_utils_filter.py new file mode 100644 index 00000000..cfb630b6 --- /dev/null +++ b/bin/test/test_utils_filter.py @@ -0,0 +1,554 @@ +#!/usr/bin/env python3 +""" +Unit tests for utils_filter.py. + +Covers: + - somatic_mask / germline_mask (each and combined) + - filter_maf (all criterion branches) + - load_filter_criteria (parsing, combining, prefix stripping) + - expand_filter_column (boolean column creation + required columns) + - extract_flagged_regions_bed (empty and non-empty BED output) +""" + +import os +import sys +import tempfile +import unittest +from pathlib import Path + +import pandas as pd + +# Add the bin directory to the path to import sibling modules +sys.path.insert(0, str(Path(__file__).parent.parent)) +from utils_filter import ( + expand_filter_column, + extract_flagged_regions_bed, + filter_maf, + germline_mask, + load_filter_criteria, + somatic_mask, +) + +THRESHOLD = 0.3 + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_df(rows: list[tuple[float, float, float]]) -> pd.DataFrame: + """Build a minimal MAF DataFrame with VAF, vd_VAF, and VAF_AM columns.""" + vafs, vd_vafs, vaf_ams = zip(*rows) + return pd.DataFrame({"VAF": list(vafs), "vd_VAF": list(vd_vafs), "VAF_AM": list(vaf_ams)}) + + +def _make_maf(rows: list[dict]) -> pd.DataFrame: + """Build a MAF DataFrame from a list of row dicts, preserving column order.""" + return pd.DataFrame(rows) + + +# --------------------------------------------------------------------------- +# somatic_mask +# --------------------------------------------------------------------------- + + +class TestSomaticMask(unittest.TestCase): + """Tests for somatic_mask(maf_df, threshold).""" + + def test_all_below_threshold_is_somatic(self): + """All three VAF columns strictly below threshold → somatic True.""" + df = _make_df([(0.1, 0.2, 0.05)]) + result = somatic_mask(df, THRESHOLD) + self.assertListEqual(result.tolist(), [True]) + + def test_all_above_threshold_is_not_somatic(self): + """All three VAF columns strictly above threshold → somatic False.""" + df = _make_df([(0.5, 0.6, 0.4)]) + result = somatic_mask(df, THRESHOLD) + self.assertListEqual(result.tolist(), [False]) + + def test_boundary_equality_is_somatic(self): + """All three VAF columns exactly equal to threshold → somatic True (≤ is inclusive).""" + df = _make_df([(THRESHOLD, THRESHOLD, THRESHOLD)]) + result = somatic_mask(df, THRESHOLD) + self.assertListEqual(result.tolist(), [True]) + + def test_asymmetric_is_not_somatic(self): + """Mixed columns (some ≤ threshold, some > threshold) → somatic False.""" + df = _make_df([(0.1, 0.1, 0.5)]) + result = somatic_mask(df, THRESHOLD) + self.assertListEqual(result.tolist(), [False]) + + def test_multiple_rows(self): + """Full four-case matrix in a single DataFrame.""" + df = _make_df( + [ + (0.1, 0.2, 0.05), # all-below → True + (0.5, 0.6, 0.4), # all-above → False + (THRESHOLD, THRESHOLD, THRESHOLD), # boundary → True + (0.1, 0.1, 0.5), # asymmetric → False + ] + ) + result = somatic_mask(df, THRESHOLD) + self.assertListEqual(result.tolist(), [True, False, True, False]) + + +# --------------------------------------------------------------------------- +# germline_mask +# --------------------------------------------------------------------------- + + +class TestGermlineMask(unittest.TestCase): + """Tests for germline_mask(maf_df, threshold).""" + + def test_all_below_threshold_is_not_germline(self): + """All three VAF columns strictly below threshold → germline False.""" + df = _make_df([(0.1, 0.2, 0.05)]) + result = germline_mask(df, THRESHOLD) + self.assertListEqual(result.tolist(), [False]) + + def test_all_above_threshold_is_germline(self): + """All three VAF columns strictly above threshold → germline True.""" + df = _make_df([(0.5, 0.6, 0.4)]) + result = germline_mask(df, THRESHOLD) + self.assertListEqual(result.tolist(), [True]) + + def test_boundary_equality_is_not_germline(self): + """All three VAF columns exactly equal to threshold → germline False (> is exclusive).""" + df = _make_df([(THRESHOLD, THRESHOLD, THRESHOLD)]) + result = germline_mask(df, THRESHOLD) + self.assertListEqual(result.tolist(), [False]) + + def test_asymmetric_is_not_germline(self): + """Mixed columns (some ≤ threshold, some > threshold) → germline False.""" + df = _make_df([(0.1, 0.1, 0.5)]) + result = germline_mask(df, THRESHOLD) + self.assertListEqual(result.tolist(), [False]) + + def test_multiple_rows(self): + """Full four-case matrix in a single DataFrame.""" + df = _make_df( + [ + (0.1, 0.2, 0.05), # all-below → False + (0.5, 0.6, 0.4), # all-above → True + (THRESHOLD, THRESHOLD, THRESHOLD), # boundary → False + (0.1, 0.1, 0.5), # asymmetric → False + ] + ) + result = germline_mask(df, THRESHOLD) + self.assertListEqual(result.tolist(), [False, True, False, False]) + + +# --------------------------------------------------------------------------- +# somatic + germline masks are never simultaneously True +# --------------------------------------------------------------------------- + + +class TestMasksAreNotComplements(unittest.TestCase): + """Verify that somatic and germline masks are never simultaneously True.""" + + def test_no_row_is_true_in_both_masks(self): + """No row should satisfy both somatic and germline conditions at once.""" + df = _make_df( + [ + (0.1, 0.2, 0.05), # all-below + (0.5, 0.6, 0.4), # all-above + (THRESHOLD, THRESHOLD, THRESHOLD), # boundary + (0.1, 0.1, 0.5), # asymmetric + ] + ) + somatic = somatic_mask(df, THRESHOLD) + germline = germline_mask(df, THRESHOLD) + both_true = (somatic & germline).tolist() + self.assertListEqual(both_true, [False, False, False, False]) + + def test_asymmetric_row_is_false_in_both_masks(self): + """An asymmetric row must be False in somatic AND False in germline (the 'neither' case).""" + df = _make_df([(0.1, 0.1, 0.5)]) + self.assertFalse(somatic_mask(df, THRESHOLD).iloc[0]) + self.assertFalse(germline_mask(df, THRESHOLD).iloc[0]) + + +# --------------------------------------------------------------------------- +# filter_maf +# --------------------------------------------------------------------------- + + +class TestFilterMaf(unittest.TestCase): + """Tests for filter_maf(maf_df, filter_criteria).""" + + def _base_maf(self) -> pd.DataFrame: + """Return a small MAF DataFrame exercising all criterion branches.""" + return _make_maf( + [ + {"MUT_ID": "M1", "VAF": 0.1, "DEPTH": 50, "FILTER": "PASS", "TYPE": "SNV", "FILTER.not_covered": False}, + { + "MUT_ID": "M2", + "VAF": 0.4, + "DEPTH": 30, + "FILTER": "n_rich;NM20", + "TYPE": "SNV", + "FILTER.not_covered": False, + }, + { + "MUT_ID": "M3", + "VAF": 0.2, + "DEPTH": 60, + "FILTER": "low_mappability", + "TYPE": "INDEL", + "FILTER.not_covered": True, + }, + { + "MUT_ID": "M4", + "VAF": 0.05, + "DEPTH": 80, + "FILTER": "PASS", + "TYPE": "SNV", + "FILTER.not_covered": False, + }, + ] + ) + + # --- numeric operator branch (len(operator)==2) --- + + def test_numeric_le_filters_correctly(self): + """('VAF', 'le 0.3') keeps only rows with VAF ≤ 0.3.""" + df = self._base_maf() + result = filter_maf(df, [("VAF", "le 0.3")]) + self.assertListEqual(sorted(result["MUT_ID"].tolist()), ["M1", "M3", "M4"]) + + def test_numeric_ge_filters_correctly(self): + """('DEPTH', 'ge 50') keeps only rows with DEPTH ≥ 50.""" + df = self._base_maf() + result = filter_maf(df, [("DEPTH", "ge 50")]) + self.assertListEqual(sorted(result["MUT_ID"].tolist()), ["M1", "M3", "M4"]) + + def test_numeric_lt_filters_correctly(self): + """('VAF', 'lt 0.2') keeps only rows with VAF < 0.2.""" + df = self._base_maf() + result = filter_maf(df, [("VAF", "lt 0.2")]) + self.assertListEqual(sorted(result["MUT_ID"].tolist()), ["M1", "M4"]) + + def test_multiple_numeric_criteria_are_anded(self): + """Combining two numeric criteria narrows the result set.""" + df = self._base_maf() + result = filter_maf(df, [("VAF", "le 0.3"), ("DEPTH", "ge 50")]) + self.assertListEqual(sorted(result["MUT_ID"].tolist()), ["M1", "M3", "M4"]) + + # --- notcontains / contains branch --- + + def test_notcontains_excludes_matching_filter_token(self): + """('FILTER', 'notcontains n_rich') removes rows whose FILTER cell contains 'n_rich'.""" + df = self._base_maf() + result = filter_maf(df, [("FILTER", "notcontains n_rich")]) + # M2 has 'n_rich' in its FILTER → removed + self.assertNotIn("M2", result["MUT_ID"].tolist()) + self.assertIn("M1", result["MUT_ID"].tolist()) + + def test_notcontains_respects_semicolon_split(self): + """notcontains splits on ';' so a token that is a prefix of another is handled correctly.""" + df = _make_maf( + [ + {"MUT_ID": "A", "FILTER": "NM20;PASS"}, + {"MUT_ID": "B", "FILTER": "NM200;PASS"}, # 'NM20' is NOT a token here + {"MUT_ID": "C", "FILTER": "PASS"}, + ] + ) + result = filter_maf(df, [("FILTER", "notcontains NM20")]) + self.assertNotIn("A", result["MUT_ID"].tolist()) + self.assertIn("B", result["MUT_ID"].tolist()) + self.assertIn("C", result["MUT_ID"].tolist()) + + def test_contains_keeps_only_matching_filter_token(self): + """('FILTER', 'contains n_rich') keeps only rows whose FILTER cell contains 'n_rich'.""" + df = self._base_maf() + result = filter_maf(df, [("FILTER", "contains n_rich")]) + self.assertListEqual(result["MUT_ID"].tolist(), ["M2"]) + + # --- boolean column branch --- + + def test_boolean_criterion_true_selects_matching_rows(self): + """('FILTER.not_covered', True) keeps only rows where FILTER.not_covered is True.""" + df = self._base_maf() + result = filter_maf(df, [("FILTER.not_covered", True)]) + self.assertListEqual(result["MUT_ID"].tolist(), ["M3"]) + + def test_boolean_criterion_false_selects_matching_rows(self): + """('FILTER.not_covered', False) keeps only rows where FILTER.not_covered is False.""" + df = self._base_maf() + result = filter_maf(df, [("FILTER.not_covered", False)]) + self.assertListEqual(sorted(result["MUT_ID"].tolist()), ["M1", "M2", "M4"]) + + # --- plain-value (equality) branch --- + + def test_plain_value_equality_match(self): + """('TYPE', 'SNV') keeps only rows where TYPE == 'SNV'.""" + df = self._base_maf() + result = filter_maf(df, [("TYPE", "SNV")]) + self.assertListEqual(sorted(result["MUT_ID"].tolist()), ["M1", "M2", "M4"]) + + def test_plain_value_no_match_returns_empty(self): + """('TYPE', 'NONEXISTENT') returns an empty DataFrame.""" + df = self._base_maf() + result = filter_maf(df, [("TYPE", "NONEXISTENT")]) + self.assertEqual(len(result), 0) + + # --- no criteria leaves DataFrame unchanged --- + + def test_empty_criteria_returns_all_rows(self): + """An empty criteria list leaves the DataFrame unchanged.""" + df = self._base_maf() + result = filter_maf(df, []) + self.assertEqual(len(result), len(df)) + + +# --------------------------------------------------------------------------- +# load_filter_criteria +# --------------------------------------------------------------------------- + + +class TestLoadFilterCriteria(unittest.TestCase): + """Tests for load_filter_criteria(filters, somatic_filters).""" + + def test_extracts_notcontains_entries_from_filters(self): + """Items starting with 'notcontains ' in filters are returned with prefix stripped.""" + result = load_filter_criteria("notcontains n_rich,notcontains NM20", "") + self.assertListEqual(sorted(result), ["NM20", "n_rich"]) + + def test_extracts_notcontains_entries_from_somatic_filters(self): + """Items from somatic_filters starting with 'notcontains ' are included.""" + result = load_filter_criteria("", "notcontains low_mappability") + self.assertListEqual(result, ["low_mappability"]) + + def test_combines_both_lists(self): + """Entries from both arguments are merged before filtering.""" + result = load_filter_criteria("notcontains n_rich", "notcontains NM20") + self.assertListEqual(sorted(result), ["NM20", "n_rich"]) + + def test_non_notcontains_entries_are_excluded(self): + """Entries that do not start with 'notcontains ' are silently dropped.""" + result = load_filter_criteria("notcontains n_rich,VAF le 0.3,PASS", "") + self.assertListEqual(result, ["n_rich"]) + + def test_empty_strings_return_empty_list(self): + """Both arguments being empty strings yields an empty list.""" + result = load_filter_criteria("", "") + self.assertListEqual(result, []) + + def test_whitespace_is_trimmed_around_items(self): + """Leading/trailing whitespace around comma-separated items is stripped.""" + result = load_filter_criteria(" notcontains n_rich , notcontains NM20 ", "") + self.assertListEqual(sorted(result), ["NM20", "n_rich"]) + + def test_duplicate_entries_are_preserved(self): + """Duplicates across both arguments are preserved (no deduplication contract).""" + result = load_filter_criteria("notcontains n_rich", "notcontains n_rich") + self.assertEqual(result.count("n_rich"), 2) + + +# --------------------------------------------------------------------------- +# expand_filter_column +# --------------------------------------------------------------------------- + + +class TestExpandFilterColumn(unittest.TestCase): + """Tests for expand_filter_column(maf_df).""" + + def _make_filter_df(self, filter_values: list[str]) -> pd.DataFrame: + """Build a minimal MAF DataFrame with only a FILTER column.""" + return pd.DataFrame({"FILTER": filter_values}) + + def test_creates_boolean_column_for_each_token(self): + """Each unique ';'-delimited token gets its own FILTER. boolean column.""" + df = self._make_filter_df(["n_rich;NM20", "PASS", "NM20"]) + result = expand_filter_column(df) + self.assertIn("FILTER.n_rich", result.columns) + self.assertIn("FILTER.NM20", result.columns) + self.assertIn("FILTER.PASS", result.columns) + + def test_boolean_values_are_correct(self): + """True only where the token is present in that row's FILTER value.""" + df = self._make_filter_df(["n_rich;NM20", "PASS", "NM20"]) + result = expand_filter_column(df) + # Row 0: n_rich and NM20 present + self.assertTrue(result.loc[0, "FILTER.n_rich"]) + self.assertTrue(result.loc[0, "FILTER.NM20"]) + self.assertFalse(result.loc[0, "FILTER.PASS"]) + # Row 1: only PASS present + self.assertFalse(result.loc[1, "FILTER.n_rich"]) + self.assertTrue(result.loc[1, "FILTER.PASS"]) + # Row 2: only NM20 present + self.assertFalse(result.loc[2, "FILTER.n_rich"]) + self.assertTrue(result.loc[2, "FILTER.NM20"]) + + def test_required_columns_always_exist(self): + """FILTER.not_covered and FILTER.not_in_exons are always created even if absent in data.""" + df = self._make_filter_df(["PASS", "PASS"]) + result = expand_filter_column(df) + self.assertIn("FILTER.not_covered", result.columns) + self.assertIn("FILTER.not_in_exons", result.columns) + + def test_required_columns_are_false_when_token_absent(self): + """Required columns are all False when neither token appears in the data.""" + df = self._make_filter_df(["PASS", "n_rich"]) + result = expand_filter_column(df) + self.assertFalse(result["FILTER.not_covered"].any()) + self.assertFalse(result["FILTER.not_in_exons"].any()) + + def test_single_token_per_row(self): + """A FILTER column with no semicolons creates one boolean column per distinct value.""" + df = self._make_filter_df(["alpha", "beta", "alpha"]) + result = expand_filter_column(df) + self.assertIn("FILTER.alpha", result.columns) + self.assertIn("FILTER.beta", result.columns) + self.assertListEqual(result["FILTER.alpha"].tolist(), [True, False, True]) + self.assertListEqual(result["FILTER.beta"].tolist(), [False, True, False]) + + def test_all_required_columns_true_when_token_present(self): + """FILTER.not_covered is True exactly for the rows that contain 'not_covered'.""" + df = self._make_filter_df(["not_covered;n_rich", "PASS", "not_covered"]) + result = expand_filter_column(df) + self.assertListEqual(result["FILTER.not_covered"].tolist(), [True, False, True]) + + +# --------------------------------------------------------------------------- +# extract_flagged_regions_bed +# --------------------------------------------------------------------------- + + +class TestExtractFlaggedRegionsBed(unittest.TestCase): + """Tests for extract_flagged_regions_bed(maf_df, name, filters, specification).""" + + def setUp(self): + """Switch into a fresh temporary directory for each test; restore on teardown.""" + self._tmpdir = tempfile.mkdtemp() + self._orig_dir = os.getcwd() + os.chdir(self._tmpdir) + + def tearDown(self): + """Restore original working directory.""" + os.chdir(self._orig_dir) + + def _make_expanded_maf(self, rows: list[dict]) -> pd.DataFrame: + """Build a MAF with CHROM/POS columns then run expand_filter_column.""" + df = pd.DataFrame(rows) + return expand_filter_column(df) + + # --- empty case --- + + def test_empty_case_creates_empty_bed_file(self): + """No flagged rows → an empty .bed file is touched and the function returns None.""" + df = self._make_expanded_maf( + [ + {"CHROM": "chr1", "POS": 100, "FILTER": "PASS"}, + {"CHROM": "chr1", "POS": 200, "FILTER": "PASS"}, + ] + ) + result = extract_flagged_regions_bed(df, "sample1", ["n_rich"]) + self.assertIsNone(result) + bed_path = Path("sample1.flagged-pos.bed") + self.assertTrue(bed_path.exists()) + self.assertEqual(bed_path.stat().st_size, 0) + + def test_empty_case_with_specification_uses_correct_filename(self): + """specification parameter is included in the BED file name for the empty case.""" + df = self._make_expanded_maf([{"CHROM": "chr1", "POS": 100, "FILTER": "PASS"}]) + extract_flagged_regions_bed(df, "sample1", ["n_rich"], specification="cohort-") + self.assertTrue(Path("sample1.cohort-flagged-pos.bed").exists()) + + def test_empty_case_no_matching_filter_columns(self): + """When filter names have no corresponding FILTER.* columns, result is empty BED.""" + df = self._make_expanded_maf([{"CHROM": "chr1", "POS": 100, "FILTER": "PASS"}]) + result = extract_flagged_regions_bed(df, "sampleX", ["nonexistent_filter"]) + self.assertIsNone(result) + self.assertTrue(Path("sampleX.flagged-pos.bed").exists()) + + # --- non-empty case --- + + def test_nonempty_case_writes_bed_with_correct_columns(self): + """BED file has four tab-separated columns: CHROM, START, END, FILTERS.""" + df = self._make_expanded_maf( + [ + {"CHROM": "chr1", "POS": 500, "FILTER": "n_rich"}, + {"CHROM": "chr2", "POS": 1000, "FILTER": "PASS"}, + ] + ) + extract_flagged_regions_bed(df, "sample2", ["n_rich"]) + bed_path = Path("sample2.flagged-pos.bed") + self.assertTrue(bed_path.exists()) + bed = pd.read_csv(bed_path, sep="\t", header=None, names=["CHROM", "START", "END", "FILTERS"]) + self.assertEqual(len(bed), 1) + row = bed.iloc[0] + self.assertEqual(row["CHROM"], "chr1") + self.assertEqual(row["START"], 500) + self.assertEqual(row["END"], 500) + self.assertIn("FILTER.n_rich", row["FILTERS"]) + + def test_nonempty_case_multiple_filters_joined_with_comma(self): + """When a position has two active filter flags, FILTERS column is comma-joined.""" + df = self._make_expanded_maf( + [ + {"CHROM": "chr1", "POS": 300, "FILTER": "n_rich;NM20"}, + ] + ) + extract_flagged_regions_bed(df, "sample3", ["n_rich", "NM20"]) + bed = pd.read_csv( + Path("sample3.flagged-pos.bed"), sep="\t", header=None, names=["CHROM", "START", "END", "FILTERS"] + ) + self.assertEqual(len(bed), 1) + filters_value = bed.iloc[0]["FILTERS"] + # Both filter column names should appear, joined by comma + self.assertIn("FILTER.n_rich", filters_value) + self.assertIn("FILTER.NM20", filters_value) + self.assertIn(",", filters_value) + + def test_nonempty_case_multiple_rows_all_written(self): + """Multiple flagged positions each produce a row in the BED file.""" + df = self._make_expanded_maf( + [ + {"CHROM": "chr1", "POS": 100, "FILTER": "n_rich"}, + {"CHROM": "chr1", "POS": 200, "FILTER": "NM20"}, + {"CHROM": "chr2", "POS": 50, "FILTER": "PASS"}, + ] + ) + extract_flagged_regions_bed(df, "sample4", ["n_rich", "NM20"]) + bed = pd.read_csv( + Path("sample4.flagged-pos.bed"), sep="\t", header=None, names=["CHROM", "START", "END", "FILTERS"] + ) + self.assertEqual(len(bed), 2) + self.assertSetEqual(set(bed["START"].tolist()), {100, 200}) + + def test_nonempty_case_returns_none(self): + """Function has no explicit return in the non-empty path, so returns None.""" + df = self._make_expanded_maf([{"CHROM": "chr1", "POS": 100, "FILTER": "n_rich"}]) + result = extract_flagged_regions_bed(df, "sample5", ["n_rich"]) + self.assertIsNone(result) + + def test_nonempty_case_with_specification_uses_correct_filename(self): + """specification parameter is included in the BED file name for the non-empty case.""" + df = self._make_expanded_maf([{"CHROM": "chr1", "POS": 100, "FILTER": "n_rich"}]) + extract_flagged_regions_bed(df, "sample6", ["n_rich"], specification="cohort-") + self.assertTrue(Path("sample6.cohort-flagged-pos.bed").exists()) + + def test_nonempty_case_bed_is_sorted_by_chrom_and_pos(self): + """BED rows are sorted by CHROM then POS (ascending).""" + df = self._make_expanded_maf( + [ + {"CHROM": "chr2", "POS": 800, "FILTER": "n_rich"}, + {"CHROM": "chr1", "POS": 999, "FILTER": "n_rich"}, + {"CHROM": "chr1", "POS": 100, "FILTER": "n_rich"}, + ] + ) + extract_flagged_regions_bed(df, "sample7", ["n_rich"]) + bed = pd.read_csv( + Path("sample7.flagged-pos.bed"), sep="\t", header=None, names=["CHROM", "START", "END", "FILTERS"] + ) + self.assertEqual(len(bed), 3) + self.assertEqual(bed.iloc[0]["CHROM"], "chr1") + self.assertEqual(bed.iloc[0]["START"], 100) + self.assertEqual(bed.iloc[1]["START"], 999) + self.assertEqual(bed.iloc[2]["CHROM"], "chr2") + + +if __name__ == "__main__": + unittest.main() diff --git a/bin/utils_filter.py b/bin/utils_filter.py index bfd201be..3209ae36 100644 --- a/bin/utils_filter.py +++ b/bin/utils_filter.py @@ -1,15 +1,18 @@ #!/usr/bin/env python import logging -import pandas as pd from pathlib import Path + +import pandas as pd + """ Utility functions for extracting filters from a MAF DataFrame. """ LOG = logging.getLogger(__name__) + def filter_maf(maf_df, filter_criteria): - ''' + """ Filter a MAF dataframe with filtering information coming from a list of tuples. This can be either a dictionary transformed to list with the .items() method or by directly creating a list of tuples. [('VAF', 'le 0.3'), ('VAF_AM', 'le 0.3'), ('vd_VAF', 'le 0.3'), @@ -17,76 +20,127 @@ def filter_maf(maf_df, filter_criteria): ('FILTER', 'notcontains cohort_n_rich_uni'), ('FILTER', 'notcontains NM20'), ('FILTER', 'notcontains no_pileup_support'), ('FILTER', 'notcontains other_sample_SNP'), ('FILTER', 'notcontains low_mappability')] - ''' + """ # Define mappings for operators used in criteria operators = { - 'eq': lambda x, y: x == y, - 'ne': lambda x, y: x != y, - 'lt': lambda x, y: x < y, - 'le': lambda x, y: x <= y, - 'gt': lambda x, y: x > y, - 'ge': lambda x, y: x >= y, - 'not': lambda x, y: x != y, - 'notcontains': lambda x, y: x.apply(lambda z : y not in z.split(";")), # (~maf_df["FILTER"].str.contains("not_in_panel")) - 'contains': lambda x, y: x.apply(lambda z : y in z.split(";")) + "eq": lambda x, y: x == y, + "ne": lambda x, y: x != y, + "lt": lambda x, y: x < y, + "le": lambda x, y: x <= y, + "gt": lambda x, y: x > y, + "ge": lambda x, y: x >= y, + "not": lambda x, y: x != y, + "notcontains": lambda x, y: x.apply( + lambda z: y not in z.split(";") + ), # (~maf_df["FILTER"].str.contains("not_in_panel")) + "contains": lambda x, y: x.apply(lambda z: y in z.split(";")), } # Apply filters based on criteria from the JSON file for col, criterion in filter_criteria: - if isinstance(criterion, bool): pref_len = maf_df.shape[0] maf_df = maf_df[maf_df[col] == criterion] - print(f"Applying {col}:{criterion} filter implied going from {pref_len} mutations to {maf_df.shape[0]} mutations.") + print( + f"Applying {col}:{criterion} filter implied going from {pref_len} mutations to {maf_df.shape[0]} mutations." + ) - elif ' ' in criterion: + elif " " in criterion: operator, value = criterion.split(maxsplit=1) if len(operator) == 2 and operator in operators: # 'VAF' : 'le 0.35' pref_len = maf_df.shape[0] maf_df = maf_df[operators[operator](maf_df[col], float(value))] - print(f"Applying {col}:{criterion} filter implied going from {pref_len} mutations to {maf_df.shape[0]} mutations.") + print( + f"Applying {col}:{criterion} filter implied going from {pref_len} mutations to {maf_df.shape[0]} mutations." + ) elif operator in operators: # 'FILTER' : 'notcontains n_rich', pref_len = maf_df.shape[0] maf_df = maf_df[operators[operator](maf_df[col], value)] - print(f"Applying {col}:{criterion} filter implied going from {pref_len} mutations to {maf_df.shape[0]} mutations.") + print( + f"Applying {col}:{criterion} filter implied going from {pref_len} mutations to {maf_df.shape[0]} mutations." + ) else: print(f"We have no filtering criteria defined for {col}:{criterion} filter.") - else: # 'TYPE' : 'SNV' pref_len = maf_df.shape[0] maf_df = maf_df[maf_df[col] == criterion] - print(f"Applying {col}:{criterion} filter implied going from {pref_len} mutations to {maf_df.shape[0]} mutations.") + print( + f"Applying {col}:{criterion} filter implied going from {pref_len} mutations to {maf_df.shape[0]} mutations." + ) return maf_df + +def somatic_mask(maf_df: pd.DataFrame, threshold: float) -> pd.Series: + """ + Return a boolean mask identifying somatic variants. + + Parameters + ---------- + maf_df : pd.DataFrame + MAF dataframe containing at least the columns ``VAF``, ``vd_VAF``, and + ``VAF_AM``. + threshold : float + Upper bound (inclusive) on VAF for a variant to be called somatic. + + Returns + ------- + pd.Series + Boolean series with the same index as *maf_df*; ``True`` where the + variant is somatic. + """ + return (maf_df["VAF"] <= threshold) & (maf_df["vd_VAF"] <= threshold) & (maf_df["VAF_AM"] <= threshold) + + +def germline_mask(maf_df: pd.DataFrame, threshold: float) -> pd.Series: + """ + Return a boolean mask identifying germline variants. + + Parameters + ---------- + maf_df : pd.DataFrame + MAF dataframe containing at least the columns ``VAF``, ``vd_VAF``, and + ``VAF_AM``. + threshold : float + Lower bound (exclusive) on VAF for a variant to be called germline. + + Returns + ------- + pd.Series + Boolean series with the same index as *maf_df*; ``True`` where the + variant is germline. + """ + return (maf_df["VAF"] > threshold) & (maf_df["vd_VAF"] > threshold) & (maf_df["VAF_AM"] > threshold) + + def load_filter_criteria(filters: str, somatic_filters: str) -> list[str]: """ Parse filter criteria from comma-separated strings. - + Parameters ---------- filters : str Comma-separated list of filter criteria somatic_filters : str Comma-separated list of somatic filter criteria - + Returns ------- list[str] List of filter names to apply """ # Parse comma-separated strings into lists - filter_list = [f.strip() for f in filters.split(',') if f.strip()] - somatic_filter_list = [f.strip() for f in somatic_filters.split(',') if f.strip()] - + filter_list = [f.strip() for f in filters.split(",") if f.strip()] + somatic_filter_list = [f.strip() for f in somatic_filters.split(",") if f.strip()] + # Combine both lists all_filters = filter_list + somatic_filter_list @@ -95,21 +149,22 @@ def load_filter_criteria(filters: str, somatic_filters: str) -> list[str]: LOG.info(f"Loaded {len(result)} filter criteria: {result}") return result + def expand_filter_column(maf_df: pd.DataFrame) -> pd.DataFrame: """ Expands the FILTER column by creating new columns for each unique filter. Each new column indicates if the corresponding filter is present (True/False). """ # Split FILTER column once per row and convert to set for O(1) lookup - filter_sets = maf_df["FILTER"].str.split(";").apply(lambda x: set(x) if x != [''] else set()) - + filter_sets = maf_df["FILTER"].str.split(";").apply(lambda x: set(x) if x != [""] else set()) + # Get all unique filter values (excluding empty strings) all_filters = set( - filter_val - for filter_val in maf_df["FILTER"].str.split(";").explode().unique() - if filter_val and filter_val != '' + filter_val + for filter_val in maf_df["FILTER"].str.split(";").explode().unique() + if filter_val and filter_val != "" ) - + # Ensure "not_covered" and "not_in_exons" exist required_filters = {"not_covered", "not_in_exons"} all_filters.update(required_filters) @@ -120,7 +175,10 @@ def expand_filter_column(maf_df: pd.DataFrame) -> pd.DataFrame: return maf_df -def extract_flagged_regions_bed(maf_df: pd.DataFrame, name: str, FILTERS: list[str], specification: str = "") -> pd.DataFrame | None: + +def extract_flagged_regions_bed( + maf_df: pd.DataFrame, name: str, filters: list[str], specification: str = "" +) -> pd.DataFrame | None: """ Returns a BED file with the regions discarded, including the list of filters applied to each mutation. Creates a properly formatted BED file with 0-based coordinates and half-open intervals. @@ -131,7 +189,7 @@ def extract_flagged_regions_bed(maf_df: pd.DataFrame, name: str, FILTERS: list[s Input MAF dataframe with filter columns. POS column should contain 1-based coordinates. name : str Sample name to be used in the output BED file name. - FILTERS : list[str] + filters : list[str] List of filter criteria to check for in the MAF dataframe. specification : str, optional Additional string to include in the output BED file name (e.g., "cohort-"), by default "". @@ -143,7 +201,7 @@ def extract_flagged_regions_bed(maf_df: pd.DataFrame, name: str, FILTERS: list[s Output coordinates are 0-based with half-open intervals [start, end). """ # List of filter columns you want to check for - filter_columns = [f"FILTER.{f}" for f in FILTERS if f"FILTER.{f}" in maf_df.columns] + filter_columns = [f"FILTER.{f}" for f in filters if f"FILTER.{f}" in maf_df.columns] maf_df_filters = maf_df[maf_df[filter_columns].any(axis=1)] if filter_columns else pd.DataFrame() @@ -157,24 +215,20 @@ def extract_flagged_regions_bed(maf_df: pd.DataFrame, name: str, FILTERS: list[s bed_df = maf_df_filters[["CHROM", "POS"] + filter_columns] # Transform to long format - _bed_melt = (pd.melt(bed_df, - id_vars=["CHROM", "POS"], - value_vars=filter_columns, - var_name="FILTERS") - .query("value == True") - ) + _bed_melt = pd.melt(bed_df, id_vars=["CHROM", "POS"], value_vars=filter_columns, var_name="FILTERS").query( + "value == True" + ) LOG.info("Mutations flagged: %s", _bed_melt.shape[0]) # Aggregate filters per position bed_annotated = ( - _bed_melt - .drop_duplicates() - .sort_values(by=["CHROM", "POS"]) - .groupby(["CHROM","POS"])["FILTERS"] - .agg(','.join) - .reset_index() - .rename(columns={"POS": "START"}) + _bed_melt.drop_duplicates() + .sort_values(by=["CHROM", "POS"]) + .groupby(["CHROM", "POS"])["FILTERS"] + .agg(",".join) + .reset_index() + .rename(columns={"POS": "START"}) ) # The idea is to filter depth files at these positions, so make END = START (1-based) @@ -183,6 +237,8 @@ def extract_flagged_regions_bed(maf_df: pd.DataFrame, name: str, FILTERS: list[s LOG.info("Unique regions flagged: %s", bed_annotated.shape[0]) # Write BED file without header or index - (bed_annotated[["CHROM", "START", "END", "FILTERS"]] - .to_csv(f"{name}.{specification}flagged-pos.bed", sep="\t", header=False, index=False) - ) \ No newline at end of file + ( + bed_annotated[["CHROM", "START", "END", "FILTERS"]].to_csv( + f"{name}.{specification}flagged-pos.bed", sep="\t", header=False, index=False + ) + ) diff --git a/conf/modules.config b/conf/modules.config index 9d3dbe0b..ea499691 100644 --- a/conf/modules.config +++ b/conf/modules.config @@ -357,6 +357,10 @@ process { ext.prop_samples_nrich = params.prop_samples_nrich } + withName: CONTAMINATION { + ext.germline_threshold = params.germline_threshold + } + withName: "TABLE2GROUP" { ext.unique_identifier = params.features_unique_identifier ext.feature_groups = params.features_groups_list diff --git a/modules/local/contamination/main.nf b/modules/local/contamination/main.nf index 7185b5d4..84dc1de0 100644 --- a/modules/local/contamination/main.nf +++ b/modules/local/contamination/main.nf @@ -16,10 +16,12 @@ process COMPUTE_CONTAMINATION { path "versions.yml" , topic: versions script: + def somatic_vaf_boundary = task.ext.germline_threshold ? "--somatic-vaf-boundary ${task.ext.germline_threshold}" : "" """ check_contamination.py \\ --maf_path ${maf} \\ - --somatic_maf ${somatic_maf} + --somatic_maf ${somatic_maf} \\ + ${somatic_vaf_boundary} cat <<-END_VERSIONS > versions.yml "${task.process}":