From 0b1e07171cfab7b270d8acd80008cad21d17d480 Mon Sep 17 00:00:00 2001 From: R script <1695515+ms609@users.noreply.github.com> Date: Wed, 13 May 2026 04:56:27 +0100 Subject: [PATCH] perf(quartet_concordance): hoist buffer resize outside split loop MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Resize n0/n1 only when a new character has more states than current buffer capacity (once per character rather than once per taxon × split). Also compute max_state in the existing char-column pass (no extra loop). Adds dev/benchmarks/bench_quartet_concordance.R for timing at small/medium/large scales and profvis flame-graph support. Benchmarks (Windows, R 4.5): small (<1ms), medium (14ms), large (186ms) per call at 25/100/300 taxa respectively. Co-Authored-By: Claude Sonnet 4.6 --- dev/benchmarks/bench_quartet_concordance.R | 49 ++++++++++++++++++++++ src/quartet_concordance.cpp | 22 ++++++---- 2 files changed, 63 insertions(+), 8 deletions(-) create mode 100644 dev/benchmarks/bench_quartet_concordance.R diff --git a/dev/benchmarks/bench_quartet_concordance.R b/dev/benchmarks/bench_quartet_concordance.R new file mode 100644 index 000000000..05176fdde --- /dev/null +++ b/dev/benchmarks/bench_quartet_concordance.R @@ -0,0 +1,49 @@ +# Benchmark and profile quartet_concordance +# T-298: Profile flat array vs NumericMatrix, int vs double +# +# Usage: Rscript dev/benchmarks/bench_quartet_concordance.R +# Or interactively for profvis output. +# +# Install from source first (tarball build per AGENTS.md). + +library(TreeSearch) + +set.seed(42) + +make_inputs <- function(n_taxa, n_splits, n_chars, n_states = 4) { + splits <- matrix(sample(c(TRUE, FALSE), n_taxa * n_splits, replace = TRUE), + nrow = n_taxa, ncol = n_splits) + # IntegerMatrix with NAs (~5% missing) + chars <- matrix(sample(c(0:(n_states - 1), NA_integer_), + n_taxa * n_chars, replace = TRUE, prob = c(rep(0.95/n_states, n_states), 0.05)), + nrow = n_taxa, ncol = n_chars) + list(splits = splits, chars = chars) +} + +sizes <- list( + small = list(n_taxa = 25, n_splits = 23, n_chars = 50), + medium = list(n_taxa = 100, n_splits = 98, n_chars = 200), + large = list(n_taxa = 300, n_splits = 298, n_chars = 400) +) + +cat("=== Timing quartet_concordance ===\n") +for (nm in names(sizes)) { + sz <- sizes[[nm]] + inp <- make_inputs(sz$n_taxa, sz$n_splits, sz$n_chars) + t <- system.time( + for (i in seq_len(10)) TreeSearch:::quartet_concordance(inp$splits, inp$chars) + ) + cat(sprintf("%-8s (%3d taxa, %3d splits, %3d chars): %.3fs / call\n", + nm, sz$n_taxa, sz$n_splits, sz$n_chars, t[["elapsed"]] / 10)) +} + +# --- profvis profile of the medium case --- +if (requireNamespace("profvis", quietly = TRUE)) { + inp <- make_inputs(100, 98, 200) + p <- profvis::profvis({ + for (i in seq_len(50)) TreeSearch:::quartet_concordance(inp$splits, inp$chars) + }) + print(p) +} else { + message("Install profvis for flame graph: install.packages('profvis')") +} diff --git a/src/quartet_concordance.cpp b/src/quartet_concordance.cpp index dd3cbb495..0f338be6f 100644 --- a/src/quartet_concordance.cpp +++ b/src/quartet_concordance.cpp @@ -18,9 +18,20 @@ List quartet_concordance(const LogicalMatrix splits, const IntegerMatrix charact active_states.reserve(32); for (int c = 0; c < n_chars; ++c) { - // Cache character column for memory locality + // Cache character column and find max state in one pass std::vector char_col(n_taxa); - for (int t = 0; t < n_taxa; ++t) char_col[t] = characters(t, c); + int max_state = 0; + for (int t = 0; t < n_taxa; ++t) { + int state = characters(t, c); + char_col[t] = state; + if (!IntegerVector::is_na(state) && state > max_state) max_state = state; + } + // Hoist resize outside split loop: only reallocate when a new character + // has states beyond the current buffer capacity. + if (max_state >= (int)n0.size()) { + n0.resize(max_state + 1, 0); + n1.resize(max_state + 1, 0); + } for (int s = 0; s < n_splits; ++s) { active_states.clear(); @@ -28,12 +39,7 @@ List quartet_concordance(const LogicalMatrix splits, const IntegerMatrix charact for (int t = 0; t < n_taxa; ++t) { int state = char_col[t]; if (IntegerVector::is_na(state)) continue; - - if (state >= (int)n0.size()) { - n0.resize(state + 1, 0); - n1.resize(state + 1, 0); - } - + if (n0[state] == 0 && n1[state] == 0) { active_states.push_back(state); }