From 04f4e25f74725ac284f7fef2e43e5c15d5720be2 Mon Sep 17 00:00:00 2001 From: Heng Ge Date: Wed, 20 May 2026 00:49:35 -0700 Subject: [PATCH 01/10] feat(mem_wal): LsmFtsSearchPlanner with Local mode + lance-index candidate APIs Adds FTS support to LsmScanner spanning base table, flushed memtable generations, and active/frozen in-memory memtables. Local scoring mode is wired end-to-end (each source uses its own BM25 stats; coordinator unions per-source plans, per-partition top-K sort + sort-preserving merge). LocalWithGlobalRescore returns a clear NotSupported error until the rescore exec lands in a follow-up. Lance-index extensions used by Rescore mode (when wired) are in place: - InvertedIndex::bm25_candidate_search returns top-K' candidates with raw (doc_len, term_freqs[input_order]) and local_score, parallel to bm25_search. - FtsMemIndex::search_candidates returns the same per-doc stats from the in-memory tail + frozen partitions. - FtsMemIndex::bm25_stats_for_terms exports segment-level (N, sumdl, df_t) so the in-memory index can feed a global scorer. Also aligns `_score` nullability on FtsIndexExec with the on-disk FTS schema so Local mode can UNION active + base/flushed without a schema mismatch. 
--- rust/lance-index/src/scalar/inverted/index.rs | 352 ++++++++++ rust/lance/src/dataset/mem_wal/index.rs | 2 +- rust/lance/src/dataset/mem_wal/index/fts.rs | 285 ++++++++ .../mem_wal/memtable/scanner/exec/fts.rs | 7 +- rust/lance/src/dataset/mem_wal/scanner.rs | 5 + .../src/dataset/mem_wal/scanner/builder.rs | 27 + .../src/dataset/mem_wal/scanner/fts_search.rs | 637 ++++++++++++++++++ 7 files changed, 1313 insertions(+), 2 deletions(-) create mode 100644 rust/lance/src/dataset/mem_wal/scanner/fts_search.rs diff --git a/rust/lance-index/src/scalar/inverted/index.rs b/rust/lance-index/src/scalar/inverted/index.rs index d08dacd26e7..13a375cb288 100644 --- a/rust/lance-index/src/scalar/inverted/index.rs +++ b/rust/lance-index/src/scalar/inverted/index.rs @@ -212,6 +212,46 @@ impl PartitionCandidates { } } +/// A single hit from [`InvertedIndex::bm25_candidate_search`] decorated with +/// the BM25 sufficient statistics needed to recompute the score against a +/// different (e.g. globally-aggregated) corpus. +/// +/// `term_freqs[i]` is the frequency of the caller-supplied query token at +/// input position `i`. A `0` means the term did not appear in the doc. +#[derive(Debug, Clone)] +pub struct InvertedIndexCandidate { + pub row_id: u64, + pub doc_length: u32, + pub term_freqs: Vec, + /// Score this segment assigned under the scorer that drove pruning. + pub local_score: f32, +} + +/// Heap entry used internally by [`InvertedIndex::bm25_candidate_search`] to +/// keep a bounded top-K' across partitions. Ordering is by score so a +/// min-heap (`Reverse`) prunes the lowest score on overflow. 
+struct ScoredCandidate { + score: crate::vector::graph::OrderedFloat, + inner: InvertedIndexCandidate, +} + +impl PartialEq for ScoredCandidate { + fn eq(&self, other: &Self) -> bool { + self.score == other.score + } +} +impl Eq for ScoredCandidate {} +impl PartialOrd for ScoredCandidate { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} +impl Ord for ScoredCandidate { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + self.score.cmp(&other.score) + } +} + #[derive(Copy, Clone, Debug, Eq, PartialEq, Hash, Default)] pub enum TokenSetFormat { Arrow, @@ -677,6 +717,171 @@ impl InvertedIndex { .unzip()) } + /// Per-candidate variant of [`Self::bm25_search`] that returns the raw BM25 + /// sufficient statistics each hit was scored from, instead of just + /// `(row_id, score)`. + /// + /// Callers building a multi-segment query plan (`Local` mode + `Global` + /// rescore) need to combine per-segment top-K' results into a single + /// globally-comparable ranking. With only `(row_id, score)` they cannot + /// recompute scores against globally-aggregated stats; with + /// `(row_id, doc_length, term_freqs)` they can. + /// + /// Pruning and ranking inside this segment still use `base_scorer` (or + /// the local IndexBM25Scorer when `None`) — same selection policy as + /// [`Self::bm25_search`], so the returned set is the same as that method would + /// return for the same `params.limit`, just decorated with sufficient + /// stats. `term_freqs` is sized to `tokens.len()` and indexed by the + /// caller's input token order; a `0` means the term did not appear in + /// the doc (e.g., because the query is OR and only some terms matched). 
+ #[instrument(level = "debug", skip_all)] + pub async fn bm25_candidate_search( + &self, + tokens: Arc, + params: Arc, + operator: Operator, + prefilter: Arc, + metrics: Arc, + base_scorer: Option<&MemBM25Scorer>, + ) -> Result> { + let local_scorer; + let scorer: &dyn Scorer = if let Some(base_scorer) = base_scorer { + base_scorer + } else { + local_scorer = IndexBM25Scorer::new(self.partitions.iter().map(|part| part.as_ref())); + &local_scorer + }; + + let limit = params.limit.unwrap_or(usize::MAX); + if limit == 0 { + return Ok(Vec::new()); + } + let mask = prefilter.mask(); + let num_query_terms = tokens.len(); + + // Min-heap of size `limit` keyed by local_score so we can prune the + // smallest score in O(log limit) when a new candidate arrives. + let mut heap: BinaryHeap> = BinaryHeap::new(); + + let parts = self + .partitions + .iter() + .map(|part| { + let part = part.clone(); + let tokens = tokens.clone(); + let params = params.clone(); + let mask = mask.clone(); + let metrics = metrics.clone(); + async move { + let postings = part + .load_posting_lists(tokens.as_ref(), params.as_ref(), metrics.as_ref()) + .await?; + if postings.is_empty() { + return Result::Ok(PartitionCandidates::empty()); + } + let max_position = postings + .iter() + .map(|posting| posting.term_index() as usize) + .max() + .unwrap_or_default(); + let mut tokens_by_position = vec![String::new(); max_position + 1]; + for posting in &postings { + let idx = posting.term_index() as usize; + tokens_by_position[idx] = posting.token().to_owned(); + } + let params = params.clone(); + let mask = mask.clone(); + let metrics = metrics.clone(); + spawn_cpu(move || { + let candidates = part.bm25_search( + params.as_ref(), + operator, + mask, + postings, + metrics.as_ref(), + )?; + Ok(PartitionCandidates { + tokens_by_position, + candidates, + }) + }) + .await + } + }) + .collect::>(); + let mut parts = stream::iter(parts).buffer_unordered(get_num_compute_intensive_cpus()); + let mut idf_cache: 
HashMap = HashMap::new(); + while let Some(res) = parts.try_next().await? { + if res.candidates.is_empty() { + continue; + } + // Map this partition's `term_index` → the caller's query-token + // position. `tokens_by_position` is partition-local; reorder so + // returned `term_freqs[input_idx]` matches the caller's input. + let term_index_to_input_idx: Vec> = res + .tokens_by_position + .iter() + .map(|tok| tokens.token_index(tok)) + .collect(); + let mut idf_by_position = Vec::with_capacity(res.tokens_by_position.len()); + for token in &res.tokens_by_position { + let idf_weight = match idf_cache.get(token) { + Some(weight) => *weight, + None => { + let weight = scorer.query_weight(token); + idf_cache.insert(token.clone(), weight); + weight + } + }; + idf_by_position.push(idf_weight); + } + for DocCandidate { + row_id, + freqs, + doc_length, + } in res.candidates + { + let mut term_freqs = vec![0u32; num_query_terms]; + let mut score = 0.0; + for (term_index, freq) in freqs.into_iter() { + let pos = term_index as usize; + debug_assert!(pos < idf_by_position.len()); + score += idf_by_position[pos] * scorer.doc_weight(freq, doc_length); + if let Some(input_idx) = term_index_to_input_idx[pos] { + term_freqs[input_idx] = freq; + } + } + let candidate = InvertedIndexCandidate { + row_id, + doc_length, + term_freqs, + local_score: score, + }; + if heap.len() < limit { + heap.push(Reverse(ScoredCandidate { + score: crate::vector::graph::OrderedFloat(score), + inner: candidate, + })); + } else if heap.peek().unwrap().0.score.0 < score { + heap.pop(); + heap.push(Reverse(ScoredCandidate { + score: crate::vector::graph::OrderedFloat(score), + inner: candidate, + })); + } + } + } + // `BinaryHeap>` is a min-heap on score, so + // `into_sorted_vec` returns ascending by `Reverse(score)`, which is + // descending by `score` — exactly the "best first" order callers + // expect, matching `bm25_search`'s convention. 
+ Ok(heap + .into_sorted_vec() + .into_iter() + .map(|Reverse(s)| s.inner) + .collect()) + } + async fn load_legacy_index( store: Arc, frag_reuse_index: Option>, @@ -5455,6 +5660,153 @@ mod tests { } } + #[tokio::test] + async fn test_bm25_candidate_search_matches_bm25_search() { + // Same 2-partition fixture as test_bm25_search_uses_global_idf, but + // we also call bm25_candidate_search and verify: + // - the returned row_id set matches bm25_search + // - doc_length and term_freqs are correct against the inputs + // - local_score equals bm25_search's score for the same row + let tmpdir = TempObjDir::default(); + let store = Arc::new(LanceIndexStore::new( + ObjectStore::local().into(), + tmpdir.clone(), + Arc::new(LanceCache::no_cache()), + )); + + let mut builder0 = InnerBuilder::new(0, false, TokenSetFormat::default()); + builder0.tokens.add("alpha".to_owned()); + builder0.tokens.add("beta".to_owned()); + builder0.posting_lists.push(PostingListBuilder::new(false)); + builder0.posting_lists.push(PostingListBuilder::new(false)); + // doc 0 ("alpha"): freq 2, dl 5 → only doc with non-1 freq, used to + // prove term_freqs really comes from the posting. 
+ builder0.posting_lists[0].add(0, PositionRecorder::Count(2)); + builder0.posting_lists[1].add(1, PositionRecorder::Count(1)); + builder0.posting_lists[1].add(2, PositionRecorder::Count(1)); + builder0.docs.append(100, 5); + builder0.docs.append(101, 1); + builder0.docs.append(102, 1); + builder0.write(store.as_ref()).await.unwrap(); + + let mut builder1 = InnerBuilder::new(1, false, TokenSetFormat::default()); + builder1.tokens.add("alpha".to_owned()); + builder1.posting_lists.push(PostingListBuilder::new(false)); + builder1.posting_lists[0].add(0, PositionRecorder::Count(1)); + builder1.docs.append(200, 1); + builder1.write(store.as_ref()).await.unwrap(); + + let metadata = std::collections::HashMap::from_iter(vec![ + ( + "partitions".to_owned(), + serde_json::to_string(&vec![0u64, 1u64]).unwrap(), + ), + ( + "params".to_owned(), + serde_json::to_string(&InvertedIndexParams::default()).unwrap(), + ), + ( + TOKEN_SET_FORMAT_KEY.to_owned(), + TokenSetFormat::default().to_string(), + ), + ]); + let mut writer = store + .new_index_file(METADATA_FILE, Arc::new(arrow_schema::Schema::empty())) + .await + .unwrap(); + writer.finish_with_metadata(metadata).await.unwrap(); + + let cache = Arc::new(LanceCache::with_capacity(4096)); + let index = InvertedIndex::load(store.clone(), None, cache.as_ref()) + .await + .unwrap(); + + // Query 2 terms; "missing" must not crash and must yield freq=0. 
+ let tokens = Arc::new(Tokens::new( + vec!["alpha".to_string(), "missing".to_string()], + DocType::Text, + )); + let params = Arc::new(FtsSearchParams::new().with_limit(Some(10))); + let prefilter = Arc::new(NoFilter); + let metrics = Arc::new(NoOpMetricsCollector); + + let (row_ids, scores) = index + .bm25_search( + tokens.clone(), + params.clone(), + Operator::Or, + prefilter.clone(), + metrics.clone(), + None, + ) + .await + .unwrap(); + + let candidates = index + .bm25_candidate_search( + tokens.clone(), + params, + Operator::Or, + prefilter, + metrics, + None, + ) + .await + .unwrap(); + + // Same hit set. + let bm25_rows: std::collections::HashSet = row_ids.iter().copied().collect(); + let cand_rows: std::collections::HashSet = + candidates.iter().map(|c| c.row_id).collect(); + assert_eq!(bm25_rows, cand_rows); + + // Candidates are sorted by local_score DESC. + for w in candidates.windows(2) { + assert!( + w[0].local_score >= w[1].local_score, + "candidates not sorted: {:?} then {:?}", + w[0].local_score, + w[1].local_score + ); + } + + // term_freqs is in input-token order: ["alpha", "missing"]. + for c in &candidates { + assert_eq!(c.term_freqs.len(), 2); + // No row has "missing", so column 1 must be 0. + assert_eq!(c.term_freqs[1], 0); + } + let by_row: std::collections::HashMap = + candidates.iter().map(|c| (c.row_id, c)).collect(); + + // Doc 100 had freq=2 ("alpha" added twice) and dl=5. + let c100 = by_row.get(&100).expect("row 100"); + assert_eq!(c100.term_freqs[0], 2); + assert_eq!(c100.doc_length, 5); + + // Doc 200 had freq=1 and dl=1. + let c200 = by_row.get(&200).expect("row 200"); + assert_eq!(c200.term_freqs[0], 1); + assert_eq!(c200.doc_length, 1); + + // local_score must match bm25_search's score for the same row. 
+ let bm25_by_row: std::collections::HashMap = row_ids + .iter() + .copied() + .zip(scores.iter().copied()) + .collect(); + for c in &candidates { + let want = bm25_by_row[&c.row_id]; + assert!( + (c.local_score - want).abs() < 1e-6, + "row {} local_score={} vs bm25_search {}", + c.row_id, + c.local_score, + want, + ); + } + } + #[tokio::test] async fn test_phrase_query_reads_legacy_per_doc_positions() { let tmpdir = TempObjDir::default(); diff --git a/rust/lance/src/dataset/mem_wal/index.rs b/rust/lance/src/dataset/mem_wal/index.rs index 6a796b8b0c4..e7057c2242d 100644 --- a/rust/lance/src/dataset/mem_wal/index.rs +++ b/rust/lance/src/dataset/mem_wal/index.rs @@ -41,7 +41,7 @@ pub type RowPosition = u64; // Re-export public types used externally pub use btree::{BTreeIndexConfig, BTreeMemIndex}; -pub use fts::{FtsIndexConfig, FtsMemIndex, FtsQueryExpr, SearchOptions}; +pub use fts::{FtsCandidate, FtsIndexConfig, FtsMemIndex, FtsQueryExpr, SearchOptions}; pub use hnsw::{HnswIndexConfig, HnswMemIndex}; // ============================================================================ diff --git a/rust/lance/src/dataset/mem_wal/index/fts.rs b/rust/lance/src/dataset/mem_wal/index/fts.rs index 535b9c53fbd..adc571c60fd 100644 --- a/rust/lance/src/dataset/mem_wal/index/fts.rs +++ b/rust/lance/src/dataset/mem_wal/index/fts.rs @@ -78,6 +78,26 @@ pub struct FtsEntry { pub score: f32, } +/// Per-candidate sufficient statistics for global-stats rescoring. +/// +/// Carries everything a downstream coordinator needs to recompute BM25 with a +/// globally-aggregated `MemBM25Scorer`: the doc length and the per-query-term +/// frequencies. `local_score` is the score this index already computed under +/// its own corpus statistics — useful for diagnostics and for the `Local` +/// scoring mode that doesn't rescore. +#[derive(Debug, Clone)] +pub struct FtsCandidate { + /// Row position in MemTable. + pub row_position: RowPosition, + /// Token count for this document (`|d|`). 
+ pub doc_len: u32, + /// One frequency per query term, in the order terms were passed to + /// `search_candidates`. A zero means the term is absent from the doc. + pub term_freqs: Vec, + /// BM25 score under this index's local corpus statistics. + pub local_score: f32, +} + /// Full-text search query expression for composable queries. #[derive(Debug, Clone)] pub enum FtsQueryExpr { @@ -1389,6 +1409,153 @@ impl FtsMemIndex { out } + // ------------------------------------------------------------------ + // Stats export and rescore-friendly candidate search + // ------------------------------------------------------------------ + + /// Segment summary for the supplied terms. + /// + /// Returns `(total_tokens, num_docs, df_per_term)` where each `df_per_term[i]` + /// is the number of visible documents containing `terms[i]`. The element + /// order matches `terms`, so callers can feed this directly into + /// `build_global_bm25_scorer`-style aggregation without re-keying. + /// + /// All inputs are tokenized exactly as the caller passes them — no extra + /// search-time tokenization is applied. Pass already-tokenized strings if + /// the goal is to compare against `bm25_search`'s scorer. + pub fn bm25_stats_for_terms(&self, terms: &[String]) -> (u64, usize, Vec) { + let st = self.state.load_full(); + let tail_snap = st.tail.snapshot(); + let mut total_tokens = tail_snap.cumulative_total_tokens; + let mut num_docs = tail_snap.cumulative_doc_count as usize; + for p in st.partitions.iter() { + total_tokens += p.total_tokens(); + num_docs += p.doc_count(); + } + let dfs: Vec = terms + .iter() + .map(|t| { + let mut df = tail_token_df(&st.tail.terms, t, tail_snap.visible_count); + for p in st.partitions.iter() { + df += p.token_df(t); + } + df + }) + .collect(); + (total_tokens, num_docs, dfs) + } + + /// Search and return the top-`k_prime` candidates with raw BM25 + /// sufficient statistics for downstream global rescoring. 
+ /// + /// `query_terms` must already be tokenized — this method does not invoke + /// the index tokenizer. Tokens not present in the index contribute zero + /// frequencies and zero score (matching the OR semantics of + /// `search_match`). + /// + /// Documents are ranked by local BM25 (this index's own corpus stats); + /// the per-candidate `local_score` is reported alongside `doc_len` and + /// `term_freqs` so a coordinator can recompute with global stats without + /// touching the index again. + pub fn search_candidates(&self, query_terms: &[String], k_prime: usize) -> Vec { + if query_terms.is_empty() || k_prime == 0 { + return Vec::new(); + } + let st = self.state.load_full(); + let tail_snap = st.tail.snapshot(); + let scorer = build_scorer(&st, &tail_snap, query_terms); + if scorer.num_docs() == 0 { + return Vec::new(); + } + + // Aggregate (doc_len, tfs_per_query_term) per matching row. + let num_terms = query_terms.len(); + let mut per_doc: HashMap)> = HashMap::new(); + + // Tail contribution. + for (ti, token) in query_terms.iter().enumerate() { + let Some(entry) = st.tail.terms.get(token.as_str()) else { + continue; + }; + let slice = entry.value().load_full(); + for chunk in &slice.chunks { + if chunk.batch_position >= tail_snap.visible_count { + continue; + } + let Some(meta) = tail_snap.batch_for(chunk.batch_position) else { + continue; + }; + for (i, &row_position) in chunk.row_positions.iter().enumerate() { + let dl = meta.dl(row_position).unwrap_or(1); + let freq = chunk.frequencies[i]; + let slot = per_doc + .entry(row_position) + .or_insert_with(|| (dl, vec![0u32; num_terms])); + // dl can come in from a later term but for a tail row + // it's the same value either way, so overwriting it is harmless. + slot.0 = dl; + slot.1[ti] = freq; + } + } + } + + // Frozen partitions contribution. 
+ for partition in st.partitions.iter() { + for (ti, token) in query_terms.iter().enumerate() { + let Some(term_id) = partition.term_id(token) else { + continue; + }; + let mut cursor = PostingCursor::new(partition, term_id); + while let Some(local_doc) = cursor.cursor_doc() { + let row_pos = partition.docs.row_id(local_doc); + let dl = partition.docs.num_tokens(local_doc); + let freq = cursor.freq(); + let slot = per_doc + .entry(row_pos) + .or_insert_with(|| (dl, vec![0u32; num_terms])); + slot.0 = dl; + slot.1[ti] = freq; + cursor.advance(); + } + } + } + + // Score each row with local stats and keep the top-K'. + let mut candidates: Vec = per_doc + .into_iter() + .map(|(row_position, (doc_len, term_freqs))| { + let mut score = 0f32; + for (ti, &tf) in term_freqs.iter().enumerate() { + if tf > 0 { + score += + scorer.query_weight(&query_terms[ti]) * scorer.doc_weight(tf, doc_len); + } + } + FtsCandidate { + row_position, + doc_len, + term_freqs, + local_score: score, + } + }) + .collect(); + + if candidates.len() > k_prime { + candidates.select_nth_unstable_by(k_prime, |a, b| { + b.local_score + .partial_cmp(&a.local_score) + .unwrap_or(std::cmp::Ordering::Equal) + }); + candidates.truncate(k_prime); + } + candidates.sort_by(|a, b| { + b.local_score + .partial_cmp(&a.local_score) + .unwrap_or(std::cmp::Ordering::Equal) + }); + candidates + } + // ------------------------------------------------------------------ // Flush to Lance inverted index format // ------------------------------------------------------------------ @@ -2767,6 +2934,124 @@ mod tests { .unwrap() } + #[test] + fn test_bm25_stats_for_terms_tail_only() { + let schema = create_test_schema(); + let index = FtsMemIndex::new(1, "description".to_string()); + let batch = create_test_batch(&schema); + index.insert(&batch, 0).unwrap(); + + let (total_tokens, num_docs, df) = index.bm25_stats_for_terms(&[ + "hello".to_string(), + "world".to_string(), + "missing".to_string(), + ]); + assert_eq!(num_docs, 
3); + // "hello world" (2) + "goodbye world" (2) + "hello again" (2) = 6 tokens + assert_eq!(total_tokens, 6); + assert_eq!(df, vec![2, 2, 0]); + } + + #[test] + fn test_bm25_stats_for_terms_with_frozen_partition() { + let schema = create_test_schema(); + // Freeze after every batch so the partition path is exercised. + let index = FtsMemIndex::new(1, "description".to_string()).with_freeze_threshold_rows(1); + let batch = create_test_batch(&schema); + index.insert(&batch, 0).unwrap(); + + let (total_tokens, num_docs, df) = + index.bm25_stats_for_terms(&["hello".to_string(), "world".to_string()]); + // Same corpus, just lives in frozen partitions now. Stats must match + // the tail-only result exactly so global aggregation is freeze-invariant. + assert_eq!(num_docs, 3); + assert_eq!(total_tokens, 6); + assert_eq!(df, vec![2, 2]); + } + + #[test] + fn test_search_candidates_returns_doc_len_and_tfs() { + let schema = create_test_schema(); + let index = FtsMemIndex::new(1, "description".to_string()); + let batch = create_test_batch(&schema); + index.insert(&batch, 0).unwrap(); + + // Query terms: "hello" matches rows 0,2; "world" matches rows 0,1. + // Row 0 contains both → tfs = [1, 1]; row 1 → [0, 1]; row 2 → [1, 0]. + let candidates = index.search_candidates(&["hello".to_string(), "world".to_string()], 10); + assert_eq!(candidates.len(), 3); + + let by_pos: HashMap = + candidates.iter().map(|c| (c.row_position, c)).collect(); + + let c0 = by_pos.get(&0).expect("row 0 hit"); + assert_eq!(c0.doc_len, 2); + assert_eq!(c0.term_freqs, vec![1, 1]); + assert!(c0.local_score > 0.0); + + let c1 = by_pos.get(&1).expect("row 1 hit"); + assert_eq!(c1.doc_len, 2); + assert_eq!(c1.term_freqs, vec![0, 1]); + + let c2 = by_pos.get(&2).expect("row 2 hit"); + assert_eq!(c2.doc_len, 2); + assert_eq!(c2.term_freqs, vec![1, 0]); + + // Output must be sorted by local_score DESC. 
+ for w in candidates.windows(2) { + assert!(w[0].local_score >= w[1].local_score); + } + } + + #[test] + fn test_search_candidates_k_prime_truncation() { + let schema = create_test_schema(); + let index = FtsMemIndex::new(1, "description".to_string()); + let batch = create_test_batch(&schema); + index.insert(&batch, 0).unwrap(); + + let top1 = index.search_candidates(&["world".to_string()], 1); + assert_eq!(top1.len(), 1); + + let top0 = index.search_candidates(&["world".to_string()], 0); + assert!( + top0.is_empty(), + "k_prime=0 must return empty without panicking" + ); + + let none = index.search_candidates(&[], 10); + assert!(none.is_empty(), "empty query must return empty"); + } + + #[test] + fn test_search_candidates_consistent_across_freeze() { + // search_candidates must return identical stats whether the data is + // in the tail or already frozen into a partition — this is the + // invariant the LSM rescore path depends on. + let schema = create_test_schema(); + + let tail_only = FtsMemIndex::new(1, "description".to_string()); + tail_only.insert(&create_test_batch(&schema), 0).unwrap(); + + let frozen = FtsMemIndex::new(1, "description".to_string()).with_freeze_threshold_rows(1); + frozen.insert(&create_test_batch(&schema), 0).unwrap(); + + let terms = vec!["hello".to_string(), "world".to_string()]; + let a = tail_only.search_candidates(&terms, 10); + let b = frozen.search_candidates(&terms, 10); + assert_eq!(a.len(), b.len()); + + let a_map: HashMap)> = a + .into_iter() + .map(|c| (c.row_position, (c.doc_len, c.term_freqs))) + .collect(); + let b_map: HashMap)> = b + .into_iter() + .map(|c| (c.row_position, (c.doc_len, c.term_freqs))) + .collect(); + assert_eq!(a_map, b_map); + } + #[test] fn test_fts_index_insert_and_search() { let schema = create_test_schema(); diff --git a/rust/lance/src/dataset/mem_wal/memtable/scanner/exec/fts.rs b/rust/lance/src/dataset/mem_wal/memtable/scanner/exec/fts.rs index 595572919f8..ddcc31a00d9 100644 --- 
a/rust/lance/src/dataset/mem_wal/memtable/scanner/exec/fts.rs +++ b/rust/lance/src/dataset/mem_wal/memtable/scanner/exec/fts.rs @@ -106,7 +106,12 @@ impl FtsIndexExec { .iter() .map(|f| f.as_ref().clone()) .collect(); - fields.push(Field::new(SCORE_COLUMN, DataType::Float32, false)); + // `_score` is nullable here to stay schema-compatible with + // `lance_index::scalar::inverted::FTS_SCHEMA` (the schema base/flushed + // FTS exec nodes emit). The LSM `full_text_search` planner UNIONs the + // active arm with base/flushed arms; UnionExec requires schema equality + // including nullability. The actual emitted column is always populated. + fields.push(Field::new(SCORE_COLUMN, DataType::Float32, true)); if with_row_id { fields.push(Field::new(lance_core::ROW_ID, DataType::UInt64, true)); } diff --git a/rust/lance/src/dataset/mem_wal/scanner.rs b/rust/lance/src/dataset/mem_wal/scanner.rs index ec179653096..10ec08d7353 100644 --- a/rust/lance/src/dataset/mem_wal/scanner.rs +++ b/rust/lance/src/dataset/mem_wal/scanner.rs @@ -36,6 +36,7 @@ mod collector; mod data_source; pub mod exec; mod flushed_cache; +mod fts_search; mod planner; mod point_lookup; mod projection; @@ -47,6 +48,10 @@ pub use collector::{ }; pub use data_source::{FlushedGeneration, LsmDataSource, LsmGeneration, ShardSnapshot}; pub use flushed_cache::FlushedMemTableCache; +pub use fts_search::{ + DEFAULT_RESCORE_FACTOR, FtsScoringMode, LsmFtsSearchPlanner, MIN_RESCORE_CANDIDATES, + SCORE_COLUMN, +}; pub use point_lookup::LsmPointLookupPlanner; pub use projection::DISTANCE_COLUMN; pub use vector_search::LsmVectorSearchPlanner; diff --git a/rust/lance/src/dataset/mem_wal/scanner/builder.rs b/rust/lance/src/dataset/mem_wal/scanner/builder.rs index 570a3e0cfc9..7109a84fc68 100644 --- a/rust/lance/src/dataset/mem_wal/scanner/builder.rs +++ b/rust/lance/src/dataset/mem_wal/scanner/builder.rs @@ -300,6 +300,33 @@ impl LsmScanner { .await } + /// Build an FTS execution plan spanning base + flushed + active 
sources. + /// + /// Routes through [`super::LsmFtsSearchPlanner`]. Output schema is + /// `user_projection ∪ pk_columns + _score`. Per-source `_score` is + /// ranked DESC across the union and capped at `k`. See the planner + /// docs for scoring-mode semantics. + /// + /// `column` must be FTS-indexed on every source. The current + /// scoring modes are [`super::FtsScoringMode::Local`] (wired) and + /// `LocalWithGlobalRescore` (placeholder until the rescore exec + /// lands in a follow-up PR). + pub async fn full_text_search( + &self, + column: &str, + query: lance_index::scalar::FullTextSearchQuery, + k: usize, + mode: super::FtsScoringMode, + ) -> Result> { + let collector = self.build_collector(); + let base_schema = self.schema(); + let planner = + super::LsmFtsSearchPlanner::new(collector, self.pk_columns.clone(), base_schema); + planner + .plan_search(column, query, k, self.projection.as_deref(), mode) + .await + } + /// Execute the scan and return a stream of record batches. pub async fn try_into_stream(&self) -> Result { let plan = self.create_plan().await?; diff --git a/rust/lance/src/dataset/mem_wal/scanner/fts_search.rs b/rust/lance/src/dataset/mem_wal/scanner/fts_search.rs new file mode 100644 index 00000000000..14f1f30478b --- /dev/null +++ b/rust/lance/src/dataset/mem_wal/scanner/fts_search.rs @@ -0,0 +1,637 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +//! Full-text search planner for LSM scanner. +//! +//! Builds an execution plan that scores an FTS query across the base +//! table, flushed memtable generations, and active/frozen-undrained +//! in-memory memtables, returning rows ordered by BM25 `_score` DESC. +//! +//! # Scoring modes +//! +//! - [`FtsScoringMode::Local`] — each source uses its own corpus +//! statistics to score. Cross-source `_score` values are only +//! approximately comparable, but the plan is single-pass and never +//! coordinates stats across sources. +//! 
- [`FtsScoringMode::LocalWithGlobalRescore`] — each source returns +//! top-K' candidates with the raw BM25 sufficient statistics +//! (`doc_len`, per-term frequencies), and a coordinator rescores +//! them with globally-aggregated stats. NOT YET IMPLEMENTED at the +//! planner level — returns a descriptive error today; will land +//! alongside the rescore-aware per-source exec nodes. +//! +//! Staleness: per-source results are returned as-is. The same primary +//! key may appear from multiple sources if it was updated across +//! generations; the caller is responsible for dedup if they need it. +//! This is the user-chosen behavior captured in `DESIGN.md §3`. + +use std::sync::Arc; + +use arrow_schema::{DataType, Field, Schema, SchemaRef, SortOptions}; +use datafusion::physical_expr::expressions::Column; +use datafusion::physical_expr::{LexOrdering, PhysicalSortExpr}; +use datafusion::physical_plan::ExecutionPlan; +use datafusion::physical_plan::sorts::sort::SortExec; +use datafusion::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; +use datafusion::physical_plan::union::UnionExec; +use lance_core::{Error, Result, is_system_column}; +use lance_index::scalar::FullTextSearchQuery; +use tracing::instrument; + +use super::collector::LsmDataSourceCollector; +use super::data_source::LsmDataSource; +use super::projection::project_to_canonical; +use crate::dataset::mem_wal::memtable::scanner::MemTableScanner; + +/// `_score` column name in FTS results — kept aligned with +/// `lance_index::scalar::inverted::SCORE_COL` so this module doesn't +/// require an import for one string constant. +pub const SCORE_COLUMN: &str = "_score"; + +/// Default candidate multiplier for `LocalWithGlobalRescore`. +/// +/// Picked to match wjones127's draft on [discussion +/// #6789](https://github.com/lance-format/lance/discussions/6789): K' = +/// `rescore_factor * k`, floored at `max(k, 100)`. 
Subject to the +/// benchmark in `BENCH.md` — if the recall@k curve flattens earlier +/// we'll lower this. +pub const DEFAULT_RESCORE_FACTOR: u32 = 10; + +/// Floor for K' to keep rescore reasonable when `k` is tiny +/// (e.g., `k = 1` shouldn't collapse to one candidate per source). +pub const MIN_RESCORE_CANDIDATES: usize = 100; + +/// How per-source BM25 contributes to the final `_score`. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum FtsScoringMode { + /// Each source scores with its own corpus stats. + /// + /// Cheapest mode: a single round-trip and no coordinator state. + /// `_score` values are NOT strictly comparable across sources, + /// but ranking within each source is correct and the union is + /// merged by `_score` DESC. + Local, + /// Each source returns top-K' candidates with raw + /// `(doc_len, term_freqs)`; a coordinator rescores them with + /// globally-aggregated BM25 statistics. K' = `rescore_factor * k` + /// floored at [`MIN_RESCORE_CANDIDATES`]. + LocalWithGlobalRescore { rescore_factor: u32 }, +} + +impl FtsScoringMode { + /// Convenience constructor for `LocalWithGlobalRescore` with the + /// project default rescore factor. + pub fn local_with_global_rescore_default() -> Self { + Self::LocalWithGlobalRescore { + rescore_factor: DEFAULT_RESCORE_FACTOR, + } + } + + /// Effective K' for a user-supplied `k` (floored at the minimum). + pub fn rescore_k_prime(&self, k: usize) -> usize { + match self { + Self::Local => k, + Self::LocalWithGlobalRescore { rescore_factor } => (*rescore_factor as usize) + .saturating_mul(k.max(1)) + .max(MIN_RESCORE_CANDIDATES) + .max(k), + } + } +} + +/// Plans FTS queries over LSM data. +pub struct LsmFtsSearchPlanner { + collector: LsmDataSourceCollector, + pk_columns: Vec, + base_schema: SchemaRef, +} + +impl LsmFtsSearchPlanner { + /// Create a new planner. 
+ pub fn new( + collector: LsmDataSourceCollector, + pk_columns: Vec, + base_schema: SchemaRef, + ) -> Self { + Self { + collector, + pk_columns, + base_schema, + } + } + + /// Build the FTS execution plan. + /// + /// # Arguments + /// + /// * `column` — text column to search; must have an FTS index on + /// the base dataset, every flushed memtable dataset, and every + /// active/frozen `IndexStore`. + /// * `query` — the FTS query (match / phrase / boolean / fuzzy). + /// * `k` — global top-k to return. + /// * `projection` — user columns to project. PK columns are + /// auto-included. `_score` is always appended. + /// * `mode` — see [`FtsScoringMode`]. + #[instrument( + name = "lsm_fts_search", + level = "info", + skip_all, + fields(column = %column, k, mode = ?mode) + )] + pub async fn plan_search( + &self, + column: &str, + query: FullTextSearchQuery, + k: usize, + projection: Option<&[String]>, + mode: FtsScoringMode, + ) -> Result> { + match mode { + FtsScoringMode::Local => self.plan_local(column, query, k, projection).await, + FtsScoringMode::LocalWithGlobalRescore { .. } => Err(Error::not_supported(format!( + "LocalWithGlobalRescore FTS planner not yet implemented; tracked under \ + ~/ai/analysis/lance/FTSRead/lsm-fts-search-with-global-rescore/PLAN.md \ + (T3 phase 2). Use FtsScoringMode::Local for now (k={k}, column={column})." 
+ ))), + } + } + + async fn plan_local( + &self, + column: &str, + query: FullTextSearchQuery, + k: usize, + projection: Option<&[String]>, + ) -> Result> { + let sources = self.collector.collect()?; + let target_schema = self.canonical_fts_schema(projection); + + if sources.is_empty() { + return self.empty_plan(&target_schema); + } + + let mut per_source_plans: Vec> = Vec::with_capacity(sources.len()); + for source in &sources { + let plan = self + .build_source_local(source, column, &query, k, projection) + .await?; + let normalized = project_to_canonical(plan, &target_schema)?; + per_source_plans.push(normalized); + } + + // Single source: skip Union and the merge. + let merged: Arc = if per_source_plans.len() == 1 { + per_source_plans.into_iter().next().unwrap() + } else { + #[allow(deprecated)] + let union: Arc = Arc::new(UnionExec::new(per_source_plans)); + union + }; + + let score_idx = merged.schema().index_of(SCORE_COLUMN).map_err(|_| { + Error::internal(format!( + "{SCORE_COLUMN} missing from canonical FTS schema after merge" + )) + })?; + + let sort_expr = vec![PhysicalSortExpr { + expr: Arc::new(Column::new(SCORE_COLUMN, score_idx)), + options: SortOptions { + descending: true, + nulls_first: false, + }, + }]; + let lex_ordering = LexOrdering::new(sort_expr).ok_or_else(|| { + Error::internal("Failed to build LexOrdering for FTS _score sort".to_string()) + })?; + + // Per-partition sort with `fetch=k` so each upstream partition + // can early-terminate at k; the preserving merge then does a + // K-way heap merge also capped at k. Same pattern as + // LsmVectorSearchPlanner. 
+ let per_partition_sorted: Arc = Arc::new( + SortExec::new(lex_ordering.clone(), merged) + .with_preserve_partitioning(true) + .with_fetch(Some(k)), + ); + let merged_sorted: Arc = Arc::new( + SortPreservingMergeExec::new(lex_ordering, per_partition_sorted).with_fetch(Some(k)), + ); + + Ok(merged_sorted) + } + + async fn build_source_local( + &self, + source: &LsmDataSource, + column: &str, + query: &FullTextSearchQuery, + k: usize, + projection: Option<&[String]>, + ) -> Result> { + match source { + LsmDataSource::BaseTable { dataset } => { + let mut scanner = dataset.scan(); + let cols = self.fts_scanner_projection(projection); + scanner.project(&cols.iter().map(|s| s.as_str()).collect::>())?; + let bound_query = query + .clone() + .with_column(column.to_string())? + .limit(Some(k as i64)); + scanner.full_text_search(bound_query)?; + scanner.create_plan().await + } + LsmDataSource::FlushedMemTable { path, .. } => { + let dataset = crate::dataset::DatasetBuilder::from_uri(path) + .load() + .await?; + let mut scanner = dataset.scan(); + let cols = self.fts_scanner_projection(projection); + scanner.project(&cols.iter().map(|s| s.as_str()).collect::>())?; + let bound_query = query + .clone() + .with_column(column.to_string())? + .limit(Some(k as i64)); + scanner.full_text_search(bound_query)?; + scanner.create_plan().await + } + LsmDataSource::ActiveMemTable { + batch_store, + index_store, + schema, + .. + } => { + let mut scanner = + MemTableScanner::new(batch_store.clone(), index_store.clone(), schema.clone()); + let cols = self.fts_scanner_projection(projection); + scanner.project(&cols.iter().map(|s| s.as_str()).collect::>()); + // `MemTableScanner::full_text_search` takes a raw match + // string; richer query shapes (phrase/boolean/fuzzy) + // can be plumbed via FtsQuery directly using the + // private setter once we need them. 
+ let match_str = match &query.query { + lance_index::scalar::inverted::query::FtsQuery::Match(m) => m.terms.clone(), + other => { + return Err(Error::not_supported(format!( + "Active memtable FTS via LsmFtsSearchPlanner currently only \ + supports MatchQuery, got: {other:?}" + ))); + } + }; + let _ = scanner.full_text_search(column, &match_str); + // Active arm doesn't take a top-K hint via the builder + // today; per-partition Sort+fetch above bounds the + // emitted rows. + let _ = k; + scanner.create_plan().await + } + } + } + + /// Columns to pass to the underlying scanner: user projection + /// minus system / `_score`, with PK columns appended. + fn fts_scanner_projection(&self, user_projection: Option<&[String]>) -> Vec { + let mut cols: Vec = if let Some(p) = user_projection { + p.iter() + .filter(|c| !is_system_column(c) && c.as_str() != SCORE_COLUMN) + .cloned() + .collect() + } else { + self.base_schema + .fields() + .iter() + .map(|f| f.name().clone()) + .collect() + }; + for pk in &self.pk_columns { + if !cols.contains(pk) { + cols.push(pk.clone()); + } + } + cols + } + + /// Canonical FTS output: user-projected cols + PK + `_score`. 
+ fn canonical_fts_schema(&self, user_projection: Option<&[String]>) -> SchemaRef { + let mut ordered: Vec = if let Some(p) = user_projection { + p.to_vec() + } else { + self.base_schema + .fields() + .iter() + .map(|f| f.name().clone()) + .collect() + }; + for pk in &self.pk_columns { + if !ordered.contains(pk) { + ordered.push(pk.clone()); + } + } + if !ordered.iter().any(|c| c == SCORE_COLUMN) { + ordered.push(SCORE_COLUMN.to_string()); + } + let fields: Vec> = ordered + .iter() + .filter_map(|name| { + if name == SCORE_COLUMN { + Some(Arc::new(Field::new(SCORE_COLUMN, DataType::Float32, true))) + } else if is_system_column(name) { + Some(Arc::new(Field::new(name.clone(), DataType::UInt64, true))) + } else { + self.base_schema + .field_with_name(name) + .ok() + .map(|f| Arc::new(f.clone())) + } + }) + .collect(); + Arc::new(Schema::new(fields)) + } + + fn empty_plan(&self, schema: &SchemaRef) -> Result> { + use datafusion::physical_plan::empty::EmptyExec; + Ok(Arc::new(EmptyExec::new(schema.clone()))) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::dataset::mem_wal::scanner::collector::{InMemoryMemTableRef, InMemoryMemTables}; + use crate::dataset::mem_wal::write::{BatchStore, IndexStore}; + use crate::dataset::{Dataset, WriteParams}; + use arrow_array::{Int32Array, RecordBatch, RecordBatchIterator, StringArray}; + use arrow_schema::{DataType, Field, Schema as ArrowSchema}; + use futures::TryStreamExt; + use std::collections::HashMap; + + fn fts_schema() -> Arc { + let mut id_meta = HashMap::new(); + id_meta.insert( + "lance-schema:unenforced-primary-key".to_string(), + "true".to_string(), + ); + let id_field = Field::new("id", DataType::Int32, false).with_metadata(id_meta); + Arc::new(ArrowSchema::new(vec![ + id_field, + Field::new("text", DataType::Utf8, true), + ])) + } + + fn make_batch(schema: &ArrowSchema, ids: &[i32], texts: &[&str]) -> RecordBatch { + RecordBatch::try_new( + Arc::new(schema.clone()), + vec![ + 
Arc::new(Int32Array::from(ids.to_vec())), + Arc::new(StringArray::from(texts.to_vec())), + ], + ) + .unwrap() + } + + async fn write_dataset(uri: &str, batches: Vec) -> Dataset { + let schema = batches[0].schema(); + let reader = RecordBatchIterator::new(batches.into_iter().map(Ok), schema); + Dataset::write(reader, uri, Some(WriteParams::default())) + .await + .unwrap() + } + + #[test] + fn rescore_k_prime_respects_floor_and_factor() { + let mode = FtsScoringMode::LocalWithGlobalRescore { rescore_factor: 10 }; + // factor * k, floored at MIN_RESCORE_CANDIDATES + assert_eq!(mode.rescore_k_prime(10), 100); + assert_eq!(mode.rescore_k_prime(20), 200); + // tiny k → floor kicks in + assert_eq!(mode.rescore_k_prime(1), MIN_RESCORE_CANDIDATES); + // Local mode passes k through + assert_eq!(FtsScoringMode::Local.rescore_k_prime(50), 50); + } + + #[tokio::test] + async fn rescore_mode_returns_clear_not_implemented_error() { + let schema = fts_schema(); + let tmp = tempfile::tempdir().unwrap(); + let base_uri = format!("{}/base", tmp.path().to_str().unwrap()); + write_dataset(&base_uri, vec![make_batch(&schema, &[1], &["hello"])]).await; + let collector = LsmDataSourceCollector::without_base_table(base_uri, vec![]); + let planner = LsmFtsSearchPlanner::new(collector, vec!["id".to_string()], schema); + + let err = planner + .plan_search( + "text", + FullTextSearchQuery::new("hello".to_string()), + 10, + None, + FtsScoringMode::local_with_global_rescore_default(), + ) + .await + .expect_err("rescore mode must error until phase 2 lands"); + let msg = format!("{err}"); + assert!( + msg.contains("LocalWithGlobalRescore"), + "error must name the mode the user asked for: {msg}" + ); + } + + #[tokio::test] + async fn local_mode_unions_base_and_active_with_consistent_score_schema() { + // Regression for the `_score` nullability mismatch between + // FtsIndexExec (active arm) and FTS_SCHEMA (base/flushed). 
The + // active-only test below would not catch this — UnionExec rejects + // schema-inequality, so we need at least one base + one active + // source to exercise that code path. + use crate::index::DatasetIndexExt; + use lance_index::IndexType; + use lance_index::scalar::inverted::tokenizer::InvertedIndexParams; + + let schema = fts_schema(); + let tmp = tempfile::tempdir().unwrap(); + + // Base Lance dataset with FTS index on the `text` column. + let base_uri = format!("{}/base", tmp.path().to_str().unwrap()); + let mut base_ds = write_dataset( + &base_uri, + vec![make_batch( + &schema, + &[1, 2], + &["lance rocks", "unrelated text"], + )], + ) + .await; + base_ds + .create_index( + &["text"], + IndexType::Inverted, + Some("text_fts".to_string()), + &InvertedIndexParams::default(), + false, + ) + .await + .unwrap(); + let base_ds = Arc::new(Dataset::open(&base_uri).await.unwrap()); + + // Active memtable with its own FTS index, containing a matching row. + let batch_store = Arc::new(BatchStore::with_capacity(16)); + let mut indexes = IndexStore::new(); + indexes.add_fts("text_fts".to_string(), 1, "text".to_string()); + let active_batch = make_batch( + &schema, + &[3, 4], + &["lance memwal goes fast", "completely unrelated"], + ); + batch_store.append(active_batch.clone()).unwrap(); + indexes + .insert_with_batch_position(&active_batch, 0, Some(0)) + .unwrap(); + let indexes = Arc::new(indexes); + + let collector = LsmDataSourceCollector::new(base_ds, vec![]).with_in_memory_memtables( + uuid::Uuid::new_v4(), + InMemoryMemTables { + active: InMemoryMemTableRef { + batch_store, + index_store: indexes, + schema: schema.clone(), + generation: 1, + }, + frozen: vec![], + }, + ); + + let planner = LsmFtsSearchPlanner::new(collector, vec!["id".to_string()], schema); + let plan = planner + .plan_search( + "text", + FullTextSearchQuery::new("lance".to_string()), + 10, + None, + FtsScoringMode::Local, + ) + .await + .expect("planner should produce a base+active union 
plan"); + + let ctx = datafusion::prelude::SessionContext::new(); + let stream = plan.execute(0, ctx.task_ctx()).unwrap(); + let batches: Vec = stream.try_collect().await.unwrap(); + let total: usize = batches.iter().map(|b| b.num_rows()).sum(); + // Both base id=1 ("lance rocks") and active id=3 ("lance memwal ...") + // should match. id=2 / id=4 do not contain "lance". + assert!( + total >= 2, + "expected at least the 2 'lance' rows from base+active, got {total}" + ); + + // Both sources must agree on _score nullability — verifies the fix. + let out = batches[0].schema(); + let score_field = out + .field_with_name(SCORE_COLUMN) + .expect("_score column missing from output"); + assert!( + score_field.is_nullable(), + "_score must be nullable to stay union-compatible across base+active" + ); + + // Sanity: ids contain at least one base hit (id=1) and one active hit (id=3). + let mut ids: Vec = Vec::new(); + for b in &batches { + let col = b + .column_by_name("id") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + for i in 0..b.num_rows() { + ids.push(col.value(i)); + } + } + assert!(ids.contains(&1), "missing base hit id=1; got ids={ids:?}"); + assert!(ids.contains(&3), "missing active hit id=3; got ids={ids:?}"); + } + + #[tokio::test] + async fn local_mode_active_memtable_only_returns_score_sorted_hits() { + let schema = fts_schema(); + let batch_store = Arc::new(BatchStore::with_capacity(16)); + let mut indexes = IndexStore::new(); + // text column has field_id 1 in fts_schema() + indexes.add_fts("text_fts".to_string(), 1, "text".to_string()); + let batch = make_batch( + &schema, + &[1, 2, 3, 4], + &[ + "lance is a columnar data format", + "memwal handles streaming writes", + "lance memwal lance lance", + "completely unrelated", + ], + ); + batch_store.append(batch.clone()).unwrap(); + indexes + .insert_with_batch_position(&batch, 0, Some(0)) + .unwrap(); + let indexes = Arc::new(indexes); + + let tmp = tempfile::tempdir().unwrap(); + let base_uri = 
format!("{}/base", tmp.path().to_str().unwrap()); + let collector = LsmDataSourceCollector::without_base_table(base_uri, vec![]) + .with_in_memory_memtables( + uuid::Uuid::new_v4(), + InMemoryMemTables { + active: InMemoryMemTableRef { + batch_store, + index_store: indexes, + schema: schema.clone(), + generation: 1, + }, + frozen: vec![], + }, + ); + + let planner = LsmFtsSearchPlanner::new(collector, vec!["id".to_string()], schema); + let plan = planner + .plan_search( + "text", + FullTextSearchQuery::new("lance".to_string()), + 10, + None, + FtsScoringMode::Local, + ) + .await + .expect("local mode planner should produce a plan"); + + // Plan executes and emits _score-sorted rows. + let ctx = datafusion::prelude::SessionContext::new(); + let stream = plan.execute(0, ctx.task_ctx()).unwrap(); + let batches: Vec = stream.try_collect().await.unwrap(); + let total: usize = batches.iter().map(|b| b.num_rows()).sum(); + assert!( + total >= 2, + "expected at least the 2 'lance' rows, got {total}" + ); + + // Schema must include _score and the PK id. + let out = batches[0].schema(); + assert!(out.field_with_name(SCORE_COLUMN).is_ok()); + assert!(out.field_with_name("id").is_ok()); + + // _score must be non-ascending across the result. 
+ let mut prev_score: Option = None; + for batch in &batches { + let score = batch + .column_by_name(SCORE_COLUMN) + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + for i in 0..batch.num_rows() { + let s = score.value(i); + if let Some(p) = prev_score { + assert!(p >= s, "scores not sorted DESC: {p} then {s}"); + } + prev_score = Some(s); + } + } + } +} From 55fcf32af7108489d2c435e68b5bb2c61b2205f3 Mon Sep 17 00:00:00 2001 From: Heng Ge Date: Wed, 20 May 2026 01:05:37 -0700 Subject: [PATCH 02/10] feat(mem_wal): wire LocalWithGlobalRescore mode in LsmFtsSearchPlanner MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implements wjones127's rescore proposal (discussion #6789) on the single-node MemWAL path. The planner synchronously: 1. Tokenizes the query against the first source's FTS tokenizer. 2. Resolves a `SourceHandle` for each LSM source — active memtable keeps its `FtsMemIndex` reference; Lance sources open the column's InvertedIndex once and reuse it. 3. Gathers `(N_i, sumdl_i, df_t_i)` from every source via `bm25_stats_for_terms` and folds them into one global MemBM25Scorer. 4. Runs each source's candidate search with LOCAL pruning (no base_scorer) at `K' = max(rescore_factor * k, MIN_CANDIDATES)`. 5. Rescores every candidate with the global scorer and picks the global top-k. 6. Materializes user columns per source (BatchStore for active, `take_rows` for Lance) and stitches them back into the rescored order via a transient `__lsm_fts_order` column. 7. Returns the pre-materialized batch as a MemorySourceConfig exec. Two new end-to-end tests cover (a) active-only rescore picks the highest-tf doc first and (b) base+active rescore produces identical scores for symmetric hits under the global stats. 
--- .../src/dataset/mem_wal/scanner/fts_search.rs | 797 +++++++++++++++++- 1 file changed, 778 insertions(+), 19 deletions(-) diff --git a/rust/lance/src/dataset/mem_wal/scanner/fts_search.rs b/rust/lance/src/dataset/mem_wal/scanner/fts_search.rs index 14f1f30478b..4b2ad3fe5fd 100644 --- a/rust/lance/src/dataset/mem_wal/scanner/fts_search.rs +++ b/rust/lance/src/dataset/mem_wal/scanner/fts_search.rs @@ -15,19 +15,24 @@ //! coordinates stats across sources. //! - [`FtsScoringMode::LocalWithGlobalRescore`] — each source returns //! top-K' candidates with the raw BM25 sufficient statistics -//! (`doc_len`, per-term frequencies), and a coordinator rescores -//! them with globally-aggregated stats. NOT YET IMPLEMENTED at the -//! planner level — returns a descriptive error today; will land -//! alongside the rescore-aware per-source exec nodes. +//! (`doc_len`, per-term frequencies); the planner aggregates +//! per-source `(N, sumdl, df_t)` into one global `MemBM25Scorer`, +//! rescores every candidate with the global stats, and returns the +//! pre-materialized top-k as a [`MemorySourceConfig`] exec. //! //! Staleness: per-source results are returned as-is. The same primary //! key may appear from multiple sources if it was updated across //! generations; the caller is responsible for dedup if they need it. //! This is the user-chosen behavior captured in `DESIGN.md §3`. 
+use std::collections::HashMap; use std::sync::Arc; +use arrow_array::{Array, Float32Array, RecordBatch, UInt32Array, UInt64Array}; use arrow_schema::{DataType, Field, Schema, SchemaRef, SortOptions}; +use arrow_select::concat::concat_batches; +use arrow_select::take::take; +use datafusion::datasource::memory::MemorySourceConfig; use datafusion::physical_expr::expressions::Column; use datafusion::physical_expr::{LexOrdering, PhysicalSortExpr}; use datafusion::physical_plan::ExecutionPlan; @@ -35,13 +40,23 @@ use datafusion::physical_plan::sorts::sort::SortExec; use datafusion::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; use datafusion::physical_plan::union::UnionExec; use lance_core::{Error, Result, is_system_column}; +use lance_index::metrics::NoOpMetricsCollector; +use lance_index::prefilter::NoFilter; use lance_index::scalar::FullTextSearchQuery; +use lance_index::scalar::inverted::document_tokenizer::DocType; +use lance_index::scalar::inverted::query::{ + FtsQuery as IndexFtsQuery, FtsSearchParams, Operator, Tokens, collect_query_tokens, +}; +use lance_index::scalar::inverted::{InvertedIndex, InvertedIndexCandidate, MemBM25Scorer, Scorer}; use tracing::instrument; use super::collector::LsmDataSourceCollector; use super::data_source::LsmDataSource; use super::projection::project_to_canonical; +use crate::Dataset; +use crate::dataset::mem_wal::index::FtsCandidate; use crate::dataset::mem_wal::memtable::scanner::MemTableScanner; +use crate::dataset::mem_wal::write::{BatchStore, IndexStore}; /// `_score` column name in FTS results — kept aligned with /// `lance_index::scalar::inverted::SCORE_COL` so this module doesn't @@ -148,14 +163,312 @@ impl LsmFtsSearchPlanner { ) -> Result> { match mode { FtsScoringMode::Local => self.plan_local(column, query, k, projection).await, - FtsScoringMode::LocalWithGlobalRescore { .. 
} => Err(Error::not_supported(format!( - "LocalWithGlobalRescore FTS planner not yet implemented; tracked under \ - ~/ai/analysis/lance/FTSRead/lsm-fts-search-with-global-rescore/PLAN.md \ - (T3 phase 2). Use FtsScoringMode::Local for now (k={k}, column={column})." - ))), + FtsScoringMode::LocalWithGlobalRescore { rescore_factor } => { + self.plan_rescore(column, query, k, projection, rescore_factor) + .await + } } } + /// Single-node implementation of wjones127's `LocalWithGlobalRescore` + /// mode (discussion #6789). Orchestrates synchronously: + /// + /// 1. Tokenize the query against the first available source's + /// tokenizer (we assume all sources share the same FTS params). + /// 2. Open the InvertedIndex for each Lance source and gather + /// `(N_i, sumdl_i, df_t_i)` from every source. + /// 3. Aggregate into a single global `MemBM25Scorer`. + /// 4. Run each source's candidate search with LOCAL stats (so each + /// segment uses its own WAND pruning thresholds). + /// 5. Rescore the union of candidates with the global scorer. + /// 6. Take the top-k by rescored `_score`. + /// 7. Materialize user columns (active arm reads BatchStore; Lance + /// arms `take_rows`), assemble the output RecordBatch, and + /// return it as a `MemorySourceConfig` exec. + /// + /// The output is pre-materialized rather than streaming because + /// rescore needs every candidate from every source in scope before + /// it can pick the global top-k — a buffered exec would be the same + /// shape under the hood. For the bench-relevant single-node case + /// this is a clear win on simplicity at no correctness cost. 
+ async fn plan_rescore( + &self, + column: &str, + query: FullTextSearchQuery, + k: usize, + projection: Option<&[String]>, + rescore_factor: u32, + ) -> Result> { + let sources = self.collector.collect()?; + let target_schema = self.canonical_fts_schema(projection); + if sources.is_empty() || k == 0 { + return self.empty_plan(&target_schema); + } + + // Step 1: pull a tokenizer + tokenize the query text. + let match_text = extract_match_text(&query)?; + let mut tokenizer = self.resolve_tokenizer(&sources, column).await?; + let tokens_obj = collect_query_tokens(&match_text, &mut tokenizer); + let token_strs: Vec = (0..tokens_obj.len()) + .map(|i| tokens_obj.get_token(i).to_owned()) + .collect(); + if token_strs.is_empty() { + return self.empty_plan(&target_schema); + } + + let k_prime = FtsScoringMode::LocalWithGlobalRescore { rescore_factor }.rescore_k_prime(k); + + // Step 2: resolve each source to a `SourceHandle` and gather its stats. + let mut handles: Vec = Vec::with_capacity(sources.len()); + let mut total_tokens: u64 = 0; + let mut num_docs: usize = 0; + let mut df_map: HashMap = + token_strs.iter().map(|t| (t.clone(), 0usize)).collect(); + for source in &sources { + let handle = self.resolve_handle(source, column).await?; + let (tt, nd, df_vec) = handle.stats_for_terms(&token_strs)?; + total_tokens += tt; + num_docs += nd; + for (t, c) in token_strs.iter().zip(df_vec.into_iter()) { + *df_map.get_mut(t).expect("df entry seeded above") += c; + } + handles.push(handle); + } + if num_docs == 0 { + return self.empty_plan(&target_schema); + } + let global_scorer = MemBM25Scorer::new(total_tokens, num_docs, df_map); + + // Step 3: per-source candidate search with LOCAL pruning (no base_scorer). 
+ let tokens_arc = Arc::new(Tokens::new(token_strs.clone(), DocType::Text)); + let params = Arc::new(FtsSearchParams::new().with_limit(Some(k_prime))); + let mut rescored: Vec = Vec::new(); + for (source_idx, handle) in handles.iter().enumerate() { + let candidates = handle + .candidate_search(&tokens_arc, &params, &token_strs) + .await?; + for c in candidates { + // Step 4: rescore with global scorer in-place. + let score = bm25_score(&global_scorer, &token_strs, &c.term_freqs, c.doc_len); + rescored.push(RescoredCandidate { + source_idx, + row_id: c.row_id, + score, + }); + } + } + + // Step 5: pick global top-k. + rescored.sort_by(|a, b| { + b.score + .partial_cmp(&a.score) + .unwrap_or(std::cmp::Ordering::Equal) + }); + rescored.truncate(k); + if rescored.is_empty() { + return self.empty_plan(&target_schema); + } + + // Step 6: materialize user columns per source, in the order + // determined by `rescored`. Each candidate carries the source it + // came from so we know which BatchStore / Dataset to take from. + let final_batch = self + .materialize_rescored(&handles, &rescored, projection, &target_schema) + .await?; + + // Step 7: wrap pre-computed batch in MemorySourceConfig. + let exec = + MemorySourceConfig::try_new_exec(&[vec![final_batch]], target_schema.clone(), None) + .map_err(|e| Error::internal(format!("MemorySourceConfig failed: {e}")))?; + Ok(exec) + } + + /// Acquire a tokenizer compatible with every source's FTS index. + /// + /// We assume FTS-indexed sources in an LSM hierarchy share their + /// `InvertedIndexParams` (otherwise their indexes wouldn't be + /// merge-compatible). Pulls the tokenizer from the first source + /// that has one; any later mismatch is the caller's bug. + async fn resolve_tokenizer( + &self, + sources: &[LsmDataSource], + column: &str, + ) -> Result> + { + for source in sources { + match source { + LsmDataSource::ActiveMemTable { index_store, ..
} => { + if let Some(idx) = index_store.get_fts_by_column(column) { + return idx.params().build(); + } + } + LsmDataSource::BaseTable { dataset } => { + if let Some(idx) = open_inverted_index(dataset, column).await? { + return Ok(idx.tokenizer()); + } + } + LsmDataSource::FlushedMemTable { path, .. } => { + let dataset = crate::dataset::DatasetBuilder::from_uri(path) + .load() + .await?; + if let Some(idx) = open_inverted_index(&dataset, column).await? { + return Ok(idx.tokenizer()); + } + } + } + } + Err(Error::invalid_input(format!( + "No source carries an FTS index on column '{column}'; \ + cannot tokenize the query for LocalWithGlobalRescore mode." + ))) + } + + async fn resolve_handle(&self, source: &LsmDataSource, column: &str) -> Result { + match source { + LsmDataSource::ActiveMemTable { + batch_store, + index_store, + schema, + .. + } => { + let _ = index_store.get_fts_by_column(column).ok_or_else(|| { + Error::invalid_input(format!( + "Active memtable is missing an FTS index on column '{column}'" + )) + })?; + Ok(SourceHandle::Active { + batch_store: batch_store.clone(), + index_store: index_store.clone(), + schema: schema.clone(), + column: column.to_string(), + }) + } + LsmDataSource::BaseTable { dataset } => { + let index = open_inverted_index(dataset, column).await?.ok_or_else(|| { + Error::invalid_input(format!( + "Base table is missing an FTS index on column '{column}'" + )) + })?; + Ok(SourceHandle::Lance { + dataset: dataset.clone(), + index, + }) + } + LsmDataSource::FlushedMemTable { path, .. 
} => { + let dataset = crate::dataset::DatasetBuilder::from_uri(path) + .load() + .await?; + let index = open_inverted_index(&dataset, column).await?.ok_or_else(|| { + Error::invalid_input(format!( + "Flushed memtable at {path} is missing an FTS index on column '{column}'" + )) + })?; + Ok(SourceHandle::Lance { + dataset: Arc::new(dataset), + index, + }) + } + } + } + + /// Materialize the rescored top-k into a single RecordBatch with + /// the canonical FTS schema. Groups by source so we can issue one + /// take per Lance source, then rebuilds the original row order. + async fn materialize_rescored( + &self, + handles: &[SourceHandle], + rescored: &[RescoredCandidate], + projection: Option<&[String]>, + target_schema: &SchemaRef, + ) -> Result { + let cols = self.fts_scanner_projection(projection); + // Group candidate positions by source. + let mut by_source: HashMap> = HashMap::new(); + for (i, r) in rescored.iter().enumerate() { + by_source + .entry(r.source_idx) + .or_default() + .push((i, r.row_id, r.score)); + } + + // For each source, materialize its rows into a partial batch + // that includes a synthetic `_order` column (the index in + // `rescored`) so we can re-sort at the end. + let mut partials: Vec = Vec::new(); + for (source_idx, mut entries) in by_source.into_iter() { + // Stable to preserve relative order; not strictly needed + // because we re-sort by `_order` below, but cheaper to keep + // related row ids adjacent for take. 
+ entries.sort_by_key(|(_, rid, _)| *rid); + let row_ids: Vec = entries.iter().map(|(_, rid, _)| *rid).collect(); + let scores: Vec = entries.iter().map(|(_, _, s)| *s).collect(); + let orders: Vec = entries.iter().map(|(i, _, _)| *i as u32).collect(); + + let materialized = handles[source_idx] + .materialize_rows(&row_ids, &cols) + .await?; + let mut columns: Vec> = materialized.columns().to_vec(); + let mut fields: Vec> = + materialized.schema().fields().iter().cloned().collect(); + columns.push(Arc::new(Float32Array::from(scores))); + fields.push(Arc::new(Field::new(SCORE_COLUMN, DataType::Float32, true))); + columns.push(Arc::new(UInt32Array::from(orders))); + fields.push(Arc::new(Field::new( + "__lsm_fts_order", + DataType::UInt32, + false, + ))); + + let schema = Arc::new(Schema::new( + fields.iter().map(|f| (**f).clone()).collect::>(), + )); + partials.push(RecordBatch::try_new(schema, columns)?); + } + + if partials.is_empty() { + // All sources returned 0 candidates after rescore. + return Ok(RecordBatch::new_empty(target_schema.clone())); + } + + // Concat across sources. + let stitch_schema = partials[0].schema(); + let stitched = concat_batches(&stitch_schema, &partials)?; + + // Sort by `__lsm_fts_order` ASC so the output reflects the + // top-k order from `rescored`. + let order_col = stitched + .column_by_name("__lsm_fts_order") + .expect("__lsm_fts_order present after materialize") + .as_any() + .downcast_ref::() + .expect("__lsm_fts_order is UInt32"); + let mut indices_with_order: Vec<(usize, u32)> = (0..order_col.len()) + .map(|i| (i, order_col.value(i))) + .collect(); + indices_with_order.sort_by_key(|(_, o)| *o); + let take_idx: UInt32Array = indices_with_order.iter().map(|(i, _)| *i as u32).collect(); + + // Drop `__lsm_fts_order` from the final output by projecting on + // the canonical schema's column names. 
+ let final_cols: Vec> = target_schema + .fields() + .iter() + .map(|f| { + let src = stitched.column_by_name(f.name()).ok_or_else(|| { + Error::internal(format!( + "rescore materialization missing column '{}'", + f.name() + )) + })?; + let taken = take(src.as_ref(), &take_idx, None).map_err(|e| { + Error::internal(format!("take failed on column '{}': {e}", f.name())) + })?; + Ok::<_, Error>(taken) + }) + .collect::>()?; + Ok(RecordBatch::try_new(target_schema.clone(), final_cols)?) + } + async fn plan_local( &self, column: &str, @@ -354,6 +667,294 @@ impl LsmFtsSearchPlanner { } } +/// One rescored hit threaded through the rescore orchestrator. +#[derive(Debug)] +struct RescoredCandidate { + /// Index into the `handles` slice — tells materialization which + /// source the `row_id` is relative to. + source_idx: usize, + /// Source-local row identifier. + /// + /// * Active arm: BatchStore row position. + /// * Lance arm: Lance row id. + row_id: u64, + /// Score under the globally-aggregated BM25 statistics. + score: f32, +} + +/// Pre-resolved handle for a single LSM source. Created once per +/// rescore plan so we don't reopen `Dataset` / `InvertedIndex` twice +/// (once for stats, once for candidates). +enum SourceHandle { + Active { + batch_store: Arc, + index_store: Arc, + schema: SchemaRef, + column: String, + }, + Lance { + dataset: Arc, + index: Arc, + }, +} + +impl SourceHandle { + fn stats_for_terms(&self, terms: &[String]) -> Result<(u64, usize, Vec)> { + match self { + Self::Active { + index_store, + column, + .. + } => { + let idx = index_store + .get_fts_by_column(column) + .expect("active handle invariant: FTS index present"); + Ok(idx.bm25_stats_for_terms(terms)) + } + Self::Lance { index, .. } => Ok(index.bm25_stats_for_terms(terms)), + } + } + + async fn candidate_search( + &self, + tokens: &Arc, + params: &Arc, + token_strs: &[String], + ) -> Result> { + match self { + Self::Active { + index_store, + column, + .. 
+ } => { + let idx = index_store + .get_fts_by_column(column) + .expect("active handle invariant: FTS index present"); + let k_prime = params.limit.unwrap_or(usize::MAX); + let candidates = idx.search_candidates(token_strs, k_prime); + Ok(candidates + .into_iter() + .map(UnifiedCandidate::from_fts_candidate) + .collect()) + } + Self::Lance { index, .. } => { + let prefilter = Arc::new(NoFilter); + let metrics = Arc::new(NoOpMetricsCollector); + let raw = index + .bm25_candidate_search( + tokens.clone(), + params.clone(), + Operator::Or, + prefilter, + metrics, + None, + ) + .await?; + Ok(raw + .into_iter() + .map(UnifiedCandidate::from_inverted_candidate) + .collect()) + } + } + } + + async fn materialize_rows(&self, row_ids: &[u64], cols: &[String]) -> Result { + match self { + Self::Active { + batch_store, + schema, + .. + } => active_materialize(batch_store, schema, row_ids, cols), + Self::Lance { dataset, .. } => { + // Project the dataset's Lance schema down to the requested + // columns by name. Unknown names are dropped (`take_rows` + // would otherwise error on schema construction). + let names: Vec<&str> = cols + .iter() + .filter(|n| dataset.schema().field(n).is_some()) + .map(|n| n.as_str()) + .collect(); + let projection = dataset.schema().project(&names)?; + Ok(dataset.take_rows(row_ids, Arc::new(projection)).await?) + } + } + } +} + +/// Common shape for one candidate, regardless of where it came from. +struct UnifiedCandidate { + row_id: u64, + doc_len: u32, + term_freqs: Vec, +} + +impl UnifiedCandidate { + fn from_fts_candidate(c: FtsCandidate) -> Self { + Self { + row_id: c.row_position, + doc_len: c.doc_len, + term_freqs: c.term_freqs, + } + } + + fn from_inverted_candidate(c: InvertedIndexCandidate) -> Self { + Self { + row_id: c.row_id, + doc_len: c.doc_length, + term_freqs: c.term_freqs, + } + } +} + +/// BM25 score from a scorer + raw per-doc stats. 
+fn bm25_score(scorer: &MemBM25Scorer, tokens: &[String], freqs: &[u32], doc_len: u32) -> f32 {
+    let mut score = 0f32;
+    // Pair by position; a short `freqs` means "term absent", never a panic.
+    for (tok, &f) in tokens.iter().zip(freqs) {
+        if f > 0 {
+            score += scorer.query_weight(tok) * scorer.doc_weight(f, doc_len);
+        }
+    }
+    score
+}
+
+/// Pull the raw text out of a `FullTextSearchQuery` for tokenization.
+///
+/// Today we only handle `MatchQuery`; other shapes return a clear
+/// `not_supported` error mirroring the Local-mode active-arm
+/// restriction. Lifting this requires plumbing structured query shapes
+/// through the rescore path, tracked in `PLAN.md`.
+fn extract_match_text(query: &FullTextSearchQuery) -> Result<String> {
+    match &query.query {
+        IndexFtsQuery::Match(m) => Ok(m.terms.clone()),
+        other => Err(Error::not_supported(format!(
+            "LocalWithGlobalRescore currently supports only MatchQuery; got: {other:?}"
+        ))),
+    }
+}
+
+/// Open the column's inverted index from a Lance dataset, or `None`
+/// if no FTS index exists for the column.
+async fn open_inverted_index(
+    dataset: &Dataset,
+    column: &str,
+) -> Result<Option<Arc<InvertedIndex>>> {
+    use crate::index::{DatasetIndexExt, DatasetIndexInternalExt};
+
+    // Resolve the column's field id so we can match against `IndexMetadata.fields`.
+    let field_id = match dataset.schema().field(column) {
+        Some(f) => f.id,
+        None => return Ok(None),
+    };
+
+    let indices = dataset.load_indices().await?;
+    for meta in indices.iter() {
+        if !meta.fields.contains(&field_id) {
+            continue;
+        }
+        // Mixed index types on the same field would be unusual but
+        // possible; the canonical filter is the downcast below.
+        let uuid = meta.uuid.to_string();
+        let Ok(opened) = dataset
+            .open_generic_index(column, &uuid, &lance_index::metrics::NoOpMetricsCollector)
+            .await
+        else {
+            continue;
+        };
+        if let Some(inv) = opened.as_any().downcast_ref::<InvertedIndex>() {
+            return Ok(Some(Arc::new(inv.clone())));
+        }
+    }
+    Ok(None)
+}
+
+/// Materialize user-projected columns from the active memtable's
+/// BatchStore for a sequence of BatchStore-row-position row ids.
+fn active_materialize(
+    batch_store: &Arc<BatchStore>,
+    schema: &SchemaRef,
+    row_ids: &[u64],
+    cols: &[String],
+) -> Result<RecordBatch> {
+    // Pre-compute half-open [start, end) ranges per batch so we can
+    // binary-search a row position to its batch.
+    struct BatchRange {
+        start: u64,
+        end: u64,
+        batch_id: usize,
+    }
+    let mut ranges: Vec<BatchRange> = Vec::new();
+    let mut cur: u64 = 0;
+    for (batch_id, stored) in batch_store.iter().enumerate() {
+        ranges.push(BatchRange {
+            start: cur,
+            end: cur + stored.num_rows as u64,
+            batch_id,
+        });
+        cur += stored.num_rows as u64;
+    }
+    let find = |row_pos: u64| -> Option<&BatchRange> {
+        let idx = ranges.partition_point(|r| r.end <= row_pos);
+        ranges
+            .get(idx)
+            .filter(|r| row_pos >= r.start && row_pos < r.end)
+    };
+
+    // For each row id, append the relevant column slice to a vector.
+    let col_indices: Vec<usize> = cols
+        .iter()
+        .map(|name| {
+            schema.index_of(name).map_err(|_| {
+                Error::internal(format!(
+                    "active materialize: column '{name}' missing from BatchStore schema"
+                ))
+            })
+        })
+        .collect::<Result<_>>()?;
+    let mut per_col: Vec<Vec<Arc<dyn Array>>> = vec![Vec::new(); col_indices.len()];
+    for &row_pos in row_ids {
+        let br = find(row_pos).ok_or_else(|| {
+            Error::internal(format!(
+                "active materialize: row position {row_pos} out of range"
+            ))
+        })?;
+        let stored = batch_store.get(br.batch_id).ok_or_else(|| {
+            Error::internal(format!(
+                "active materialize: batch {} missing from store",
+                br.batch_id
+            ))
+        })?;
+        // Batch-local offset, kept as u64: no lossy round trip through u32.
+        let take_idx = UInt64Array::from(vec![row_pos - br.start]);
+        for (slot, &src_idx) in per_col.iter_mut().zip(col_indices.iter()) {
+            let col = stored.data.column(src_idx);
+            let taken = take(col.as_ref(), &take_idx, None)?;
+            slot.push(taken);
+        }
+    }
+    let mut fields: Vec<Field> = Vec::with_capacity(col_indices.len());
+    let mut columns: Vec<Arc<dyn Array>> = Vec::with_capacity(col_indices.len());
+    for (name, slot) in cols.iter().zip(per_col.into_iter()) {
+        let src_field = schema.field_with_name(name).map_err(|e| {
+            Error::internal(format!(
+                "active materialize: field '{name}' lookup failed: {e}"
+            ))
+        })?;
+        fields.push(src_field.clone());
+        if slot.is_empty() {
+            // No rows to take — build an empty array of the right type.
+            let empty = arrow_array::new_empty_array(src_field.data_type());
+            columns.push(empty);
+        } else {
+            let refs: Vec<&dyn Array> = slot.iter().map(|a| a.as_ref()).collect();
+            let concatenated = arrow_select::concat::concat(&refs)?;
+            columns.push(concatenated);
+        }
+    }
+    let out_schema = Arc::new(Schema::new(fields));
+    Ok(RecordBatch::try_new(out_schema, columns)?)
+} + #[cfg(test)] mod tests { use super::*; @@ -410,29 +1011,187 @@ mod tests { } #[tokio::test] - async fn rescore_mode_returns_clear_not_implemented_error() { + async fn rescore_mode_unions_base_and_active_with_global_scores() { + // End-to-end smoke for LocalWithGlobalRescore: a base + active + // shape where the "lance" term appears in both. Score + // recomputation under the global scorer must yield identical + // scores for the two hits because both have freq=1 and dl=2 — + // the global stats are corpus-wide so they see the same + // (idf, avgdl) for both rows. + use crate::index::DatasetIndexExt; + use lance_index::IndexType; + use lance_index::scalar::inverted::tokenizer::InvertedIndexParams; + let schema = fts_schema(); let tmp = tempfile::tempdir().unwrap(); + + // Base Lance dataset with FTS index. let base_uri = format!("{}/base", tmp.path().to_str().unwrap()); - write_dataset(&base_uri, vec![make_batch(&schema, &[1], &["hello"])]).await; - let collector = LsmDataSourceCollector::without_base_table(base_uri, vec![]); - let planner = LsmFtsSearchPlanner::new(collector, vec!["id".to_string()], schema); + let mut base_ds = write_dataset( + &base_uri, + vec![make_batch( + &schema, + &[1, 2], + &["lance fast", "unrelated text"], + )], + ) + .await; + base_ds + .create_index( + &["text"], + IndexType::Inverted, + Some("text_fts".to_string()), + &InvertedIndexParams::default(), + false, + ) + .await + .unwrap(); + let base_ds = Arc::new(Dataset::open(&base_uri).await.unwrap()); + + // Active memtable with FTS index over a different row. 
+ let batch_store = Arc::new(BatchStore::with_capacity(16)); + let mut indexes = IndexStore::new(); + indexes.add_fts("text_fts".to_string(), 1, "text".to_string()); + let active_batch = make_batch(&schema, &[3, 4], &["lance quick", "completely unrelated"]); + batch_store.append(active_batch.clone()).unwrap(); + indexes + .insert_with_batch_position(&active_batch, 0, Some(0)) + .unwrap(); + let indexes = Arc::new(indexes); + + let collector = LsmDataSourceCollector::new(base_ds, vec![]).with_in_memory_memtables( + uuid::Uuid::new_v4(), + InMemoryMemTables { + active: InMemoryMemTableRef { + batch_store, + index_store: indexes, + schema: schema.clone(), + generation: 1, + }, + frozen: vec![], + }, + ); - let err = planner + let planner = LsmFtsSearchPlanner::new(collector, vec!["id".to_string()], schema); + let plan = planner .plan_search( "text", - FullTextSearchQuery::new("hello".to_string()), + FullTextSearchQuery::new("lance".to_string()), 10, None, FtsScoringMode::local_with_global_rescore_default(), ) .await - .expect_err("rescore mode must error until phase 2 lands"); - let msg = format!("{err}"); + .expect("rescore planner should produce a base+active plan"); + + let ctx = datafusion::prelude::SessionContext::new(); + let stream = plan.execute(0, ctx.task_ctx()).unwrap(); + let batches: Vec = stream.try_collect().await.unwrap(); + let total: usize = batches.iter().map(|b| b.num_rows()).sum(); + // Both base id=1 and active id=3 contain "lance" → 2 hits. + assert_eq!(total, 2, "expected exactly the 2 'lance' hits"); + + let out = batches[0].schema(); + assert!(out.field_with_name(SCORE_COLUMN).is_ok()); + assert!(out.field_with_name("id").is_ok()); + + // Collect (id, score) pairs. 
+ let mut hits: Vec<(i32, f32)> = Vec::new(); + for b in &batches { + let ids = b + .column_by_name("id") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + let scores = b + .column_by_name(SCORE_COLUMN) + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + for i in 0..b.num_rows() { + hits.push((ids.value(i), scores.value(i))); + } + } + // Both hits must be present. + let by_id: std::collections::HashMap = hits.iter().copied().collect(); + let s1 = *by_id.get(&1).expect("base hit id=1 missing"); + let s3 = *by_id.get(&3).expect("active hit id=3 missing"); + // Global stats see N=4 docs, df("lance")=2. Both id=1 and id=3 + // have freq=1 and doc_len=2 → identical BM25 under global stats. assert!( - msg.contains("LocalWithGlobalRescore"), - "error must name the mode the user asked for: {msg}" + (s1 - s3).abs() < 1e-5, + "global rescore should give identical scores for symmetric hits; got s1={s1}, s3={s3}" + ); + // Sort: scores descending. + for w in hits.windows(2) { + assert!(w[0].1 >= w[1].1); + } + } + + #[tokio::test] + async fn rescore_mode_active_only_runs_end_to_end() { + // Cheaper regression that doesn't need a base Lance dataset: + // just an active memtable. Validates the active-only candidate + // path + rescore math. 
+ let schema = fts_schema(); + let batch_store = Arc::new(BatchStore::with_capacity(16)); + let mut indexes = IndexStore::new(); + indexes.add_fts("text_fts".to_string(), 1, "text".to_string()); + let batch = make_batch( + &schema, + &[1, 2, 3], + &["lance lance lance", "lance once", "no match here"], ); + batch_store.append(batch.clone()).unwrap(); + indexes + .insert_with_batch_position(&batch, 0, Some(0)) + .unwrap(); + let indexes = Arc::new(indexes); + + let tmp = tempfile::tempdir().unwrap(); + let base_uri = format!("{}/base", tmp.path().to_str().unwrap()); + let collector = LsmDataSourceCollector::without_base_table(base_uri, vec![]) + .with_in_memory_memtables( + uuid::Uuid::new_v4(), + InMemoryMemTables { + active: InMemoryMemTableRef { + batch_store, + index_store: indexes, + schema: schema.clone(), + generation: 1, + }, + frozen: vec![], + }, + ); + + let planner = LsmFtsSearchPlanner::new(collector, vec!["id".to_string()], schema); + let plan = planner + .plan_search( + "text", + FullTextSearchQuery::new("lance".to_string()), + 10, + None, + FtsScoringMode::local_with_global_rescore_default(), + ) + .await + .expect("rescore planner should produce a plan"); + let ctx = datafusion::prelude::SessionContext::new(); + let stream = plan.execute(0, ctx.task_ctx()).unwrap(); + let batches: Vec = stream.try_collect().await.unwrap(); + let total: usize = batches.iter().map(|b| b.num_rows()).sum(); + // id=1 (3 occurrences) and id=2 (1 occurrence) match; id=3 doesn't. + assert_eq!(total, 2); + // id=1 should outrank id=2 because tf is higher and doc length is similar. 
+ let first_id = batches[0] + .column_by_name("id") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap() + .value(0); + assert_eq!(first_id, 1, "highest-tf doc should rank first"); } #[tokio::test] From e0f5f3f2964941bbbd80da8a0b27426b8cf3bdcc Mon Sep 17 00:00:00 2001 From: Heng Ge Date: Wed, 20 May 2026 01:11:20 -0700 Subject: [PATCH 03/10] =?UTF-8?q?bench(mem=5Fwal):=20add=20lsm=5Ffts=5Fmod?= =?UTF-8?q?es=20=E2=80=94=20Local=20vs=20Rescore=20on=20FineWeb?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit New bench `benches/mem_wal/fts/lsm_fts_modes.rs`, sibling of `mem_wal_fineweb_fts.rs`, sharing the same FineWeb loader shape and cache-dir convention. For a configurable LSM shape (balanced / memwal_skewed / growing_lsm) the bench: 1. Loads a HuggingFace FineWeb slice into a base Lance dataset plus several flushed-generation datasets plus one active in-memory memtable, each with its own FTS index. 2. Picks `num_queries` representative single-term queries from the 80–99 percentile band of corpus DF. 3. Runs both FtsScoringMode::Local and FtsScoringMode::LocalWithGlobalRescore through LsmScanner and records per-query latency. 4. Builds a single-merged-index baseline (the same FineWeb rows in one Lance dataset) and runs the same queries against it. 5. Reports mean / p50 / p95 / p99 latency per mode plus top-K Jaccard and `_score` Pearson for both LSM modes against each other and against the baseline. Output is JSON-pretty-printed to stdout and (optionally) written to `--output`. Bench is registered in Cargo.toml next to its sibling. 
--- rust/lance/Cargo.toml | 5 + .../benches/mem_wal/fts/lsm_fts_modes.rs | 939 ++++++++++++++++++ 2 files changed, 944 insertions(+) create mode 100644 rust/lance/benches/mem_wal/fts/lsm_fts_modes.rs diff --git a/rust/lance/Cargo.toml b/rust/lance/Cargo.toml index 2e16ef90979..816a1872685 100644 --- a/rust/lance/Cargo.toml +++ b/rust/lance/Cargo.toml @@ -258,6 +258,11 @@ name = "mem_wal_fts_bench" path = "benches/mem_wal/fts/mem_wal_fts_bench.rs" harness = false +[[bench]] +name = "lsm_fts_modes" +path = "benches/mem_wal/fts/lsm_fts_modes.rs" +harness = false + [[bench]] name = "manifest_commit" harness = false diff --git a/rust/lance/benches/mem_wal/fts/lsm_fts_modes.rs b/rust/lance/benches/mem_wal/fts/lsm_fts_modes.rs new file mode 100644 index 00000000000..5a688264a8a --- /dev/null +++ b/rust/lance/benches/mem_wal/fts/lsm_fts_modes.rs @@ -0,0 +1,939 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +//! Benchmark comparing `LsmScanner::full_text_search` scoring modes on +//! the LSM hierarchy with multiple flushed generations on a real +//! FineWeb corpus. +//! +//! Sibling of `mem_wal_fineweb_fts.rs`. Shares its FineWeb loader +//! shape (HF `sample/10BT` parquet shards, `--cache-dir` to amortize +//! downloads). +//! +//! Per shape × scoring mode the bench reports: +//! +//! * Wall-clock per query and aggregate latency percentiles. +//! * Top-K Jaccard vs a single-merged-index ground truth (the same +//! FineWeb rows loaded into a single Lance dataset with one FTS +//! index, queried via `scanner.full_text_search`). +//! * Pearson correlation of `_score` between LSM mode and ground +//! truth on the intersection. +//! +//! Example: +//! +//! ```bash +//! cargo bench -p lance --bench lsm_fts_modes -- \ +//! --shape memwal_skewed --k 100 --num-queries 100 \ +//! --rescore-factor 10 \ +//! --cache-dir /tmp/fineweb-cache --output result.json +//! 
```
+
+#![recursion_limit = "256"]
+#![allow(clippy::print_stdout, clippy::print_stderr)]
+
+use std::collections::{HashMap, HashSet};
+use std::path::PathBuf;
+use std::sync::Arc;
+use std::time::{Duration, Instant};
+
+use arrow_array::{Array, Int64Array, RecordBatch, RecordBatchIterator, StringArray};
+use arrow_schema::{DataType, Field, Schema as ArrowSchema};
+use futures::TryStreamExt;
+use lance::dataset::mem_wal::scanner::{
+    FtsScoringMode, InMemoryMemTableRef, InMemoryMemTables, LsmScanner,
+};
+use lance::dataset::mem_wal::write::{BatchStore, IndexStore};
+use lance::dataset::{Dataset, WriteParams};
+use lance::index::DatasetIndexExt;
+use lance_core::Result;
+use lance_index::IndexType;
+use lance_index::scalar::FullTextSearchQuery;
+use lance_index::scalar::inverted::tokenizer::InvertedIndexParams;
+use lance_tokenizer::TokenStream;
+use parquet::arrow::async_reader::ParquetRecordBatchStreamBuilder;
+use serde_json::json;
+use uuid::Uuid;
+
+const TEXT_COL: &str = "text";
+const FTS_INDEX_NAME: &str = "text_fts";
+const HF_API_LISTING: &str =
+    "https://huggingface.co/api/datasets/HuggingFaceFW/fineweb/tree/main/sample/10BT";
+const HF_FILE_BASE: &str = "https://huggingface.co/datasets/HuggingFaceFW/fineweb/resolve/main/";
+
+// ----------------------------------------------------------------------
+// Shape
+// ----------------------------------------------------------------------
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+enum Shape {
+    /// 100k base + 4 equal-sized (25k) flushed gens + 1 same-sized active.
+    /// Cross-source stats are similar so Local should already be close to Rescore.
+    Balanced,
+    /// 1 huge base + 4 tiny flushed gens + 1 tiny active. The case
+    /// where local-stats BM25 is most distorted vs a merged index.
+    MemwalSkewed,
+    /// 100k base + heterogeneous flushed sizes (1k+5k+25k+100k) + 25k active.
+    GrowingLsm,
+}
+
+impl Shape {
+    fn parse(value: &str) -> std::result::Result<Self, String> {
+        match value {
+            "balanced" => Ok(Self::Balanced),
+            "memwal_skewed" => Ok(Self::MemwalSkewed),
+            "growing_lsm" => Ok(Self::GrowingLsm),
+            other => Err(format!(
+                "unknown shape '{other}', expected balanced|memwal_skewed|growing_lsm"
+            )),
+        }
+    }
+
+    fn as_str(self) -> &'static str {
+        match self {
+            Self::Balanced => "balanced",
+            Self::MemwalSkewed => "memwal_skewed",
+            Self::GrowingLsm => "growing_lsm",
+        }
+    }
+
+    /// (base_rows, vec_of_flushed_gen_rows, active_rows). `base_rows = None`
+    /// means no base table (fresh-tier-only); every current shape has a base.
+    /// Totals differ per shape: 225k rows (balanced), 1.025M rows
+    /// (memwal_skewed), 256k rows (growing_lsm) -- see `total_rows`.
+    fn slicing(self) -> (Option<usize>, Vec<usize>, usize) {
+        match self {
+            Self::Balanced => (Some(100_000), vec![25_000; 4], 25_000),
+            Self::MemwalSkewed => (Some(1_000_000), vec![5_000; 4], 5_000),
+            Self::GrowingLsm => (Some(100_000), vec![1_000, 5_000, 25_000, 100_000], 25_000),
+        }
+    }
+
+    fn total_rows(self) -> usize {
+        let (base, gens, active) = self.slicing();
+        base.unwrap_or(0) + gens.iter().sum::<usize>() + active
+    }
+}
+
+// ----------------------------------------------------------------------
+// Args
+// ----------------------------------------------------------------------
+
+#[derive(Debug, Clone)]
+struct Args {
+    shape: Shape,
+    k: usize,
+    num_queries: usize,
+    rescore_factor: u32,
+    cache_dir: PathBuf,
+    work_dir: Option<PathBuf>,
+    output: Option<PathBuf>,
+    skip_baseline: bool,
+    tokio_threads: usize,
+    /// If `Some`, cap rows used per shape — useful for smoke testing
+    /// without downloading hundreds of MB.
+ max_corpus_rows: Option, +} + +impl Default for Args { + fn default() -> Self { + let threads = std::thread::available_parallelism().map_or(1, usize::from); + Self { + shape: Shape::Balanced, + k: 100, + num_queries: 100, + rescore_factor: 10, + cache_dir: std::env::temp_dir().join("mem_wal_fineweb_fts_cache"), + work_dir: None, + output: None, + skip_baseline: false, + tokio_threads: threads, + max_corpus_rows: None, + } + } +} + +fn parse(flag: &str, value: &str) -> Result +where + T: std::str::FromStr, + T::Err: std::fmt::Display, +{ + value + .parse::() + .map_err(|e| lance_core::Error::io(format!("flag {flag}: {e}"))) +} + +fn parse_args() -> Result { + let mut args = Args::default(); + let raw: Vec = std::env::args().skip(1).collect(); + let mut iter = raw.iter(); + while let Some(flag) = iter.next() { + match flag.as_str() { + "--shape" => { + args.shape = Shape::parse( + iter.next() + .ok_or_else(|| lance_core::Error::io("--shape needs value"))?, + ) + .map_err(lance_core::Error::io)? + } + "--k" => { + args.k = parse( + "--k", + iter.next() + .ok_or_else(|| lance_core::Error::io("--k needs value"))?, + )? + } + "--num-queries" => { + args.num_queries = parse( + "--num-queries", + iter.next() + .ok_or_else(|| lance_core::Error::io("--num-queries needs value"))?, + )? + } + "--rescore-factor" => { + args.rescore_factor = parse( + "--rescore-factor", + iter.next() + .ok_or_else(|| lance_core::Error::io("--rescore-factor needs value"))?, + )? 
+ } + "--cache-dir" => { + args.cache_dir = PathBuf::from( + iter.next() + .ok_or_else(|| lance_core::Error::io("--cache-dir needs value"))?, + ) + } + "--work-dir" => { + args.work_dir = + Some(PathBuf::from(iter.next().ok_or_else(|| { + lance_core::Error::io("--work-dir needs value") + })?)) + } + "--output" => { + args.output = + Some(PathBuf::from(iter.next().ok_or_else(|| { + lance_core::Error::io("--output needs value") + })?)) + } + "--skip-baseline" => args.skip_baseline = true, + "--tokio-threads" => { + args.tokio_threads = parse( + "--tokio-threads", + iter.next() + .ok_or_else(|| lance_core::Error::io("--tokio-threads needs value"))?, + )? + } + "--max-corpus-rows" => { + args.max_corpus_rows = Some(parse( + "--max-corpus-rows", + iter.next() + .ok_or_else(|| lance_core::Error::io("--max-corpus-rows needs value"))?, + )?) + } + // criterion-style noise we want to ignore so `cargo bench` + // can hand us nothing extra without erroring. + "--bench" | "--test" => {} + other => { + eprintln!("unknown flag: {other}"); + return Err(lance_core::Error::io(format!("unknown flag {other}"))); + } + } + } + Ok(args) +} + +// ---------------------------------------------------------------------- +// FineWeb loading (mirrors mem_wal_fineweb_fts.rs) +// ---------------------------------------------------------------------- + +#[derive(serde::Deserialize)] +struct HfTreeEntry { + #[serde(rename = "type")] + kind: String, + path: String, +} + +async fn list_shard_paths() -> Result> { + let entries: Vec = reqwest::get(HF_API_LISTING) + .await + .map_err(|e| lance_core::Error::io(format!("listing HTTP: {e}")))? 
+ .json() + .await + .map_err(|e| lance_core::Error::io(format!("listing JSON: {e}")))?; + let mut shards: Vec = entries + .into_iter() + .filter(|e| e.kind == "file" && e.path.ends_with(".parquet")) + .map(|e| e.path) + .collect(); + shards.sort(); + Ok(shards) +} + +async fn download_shard(rel_path: &str, dest: &std::path::Path) -> Result<()> { + if dest.exists() { + return Ok(()); + } + let url = format!("{HF_FILE_BASE}{rel_path}"); + let tmp = dest.with_extension("part"); + for attempt in 1..=5u32 { + println!("downloading {rel_path} (attempt {attempt}/5) ..."); + let result: Result = async { + let resp = reqwest::get(&url) + .await + .map_err(|e| lance_core::Error::io(format!("download HTTP: {e}")))?; + if !resp.status().is_success() { + return Err(lance_core::Error::io(format!( + "download {url} -> status {}", + resp.status() + ))); + } + resp.bytes() + .await + .map_err(|e| lance_core::Error::io(format!("read body: {e}"))) + } + .await; + match result { + Ok(bytes) => { + std::fs::write(&tmp, &bytes) + .map_err(|e| lance_core::Error::io(format!("write: {e}")))?; + std::fs::rename(&tmp, dest) + .map_err(|e| lance_core::Error::io(format!("rename: {e}")))?; + println!( + " wrote {:.1} MB to {}", + bytes.len() as f64 / 1024.0 / 1024.0, + dest.display() + ); + return Ok(()); + } + Err(e) if attempt < 5 => { + eprintln!(" attempt {attempt} failed: {e}; retrying"); + tokio::time::sleep(Duration::from_secs(2u64.pow(attempt))).await; + } + Err(e) => return Err(e), + } + } + unreachable!() +} + +async fn read_shard_text( + path: &std::path::Path, + out: &mut Vec, + max_rows: usize, +) -> Result { + let file = tokio::fs::File::open(path) + .await + .map_err(|e| lance_core::Error::io(format!("open parquet: {e}")))?; + let builder = ParquetRecordBatchStreamBuilder::new(file) + .await + .map_err(|e| lance_core::Error::io(format!("parquet builder: {e}")))?; + let mut stream = builder + .build() + .map_err(|e| lance_core::Error::io(format!("parquet stream: {e}")))?; + let 
mut taken = 0usize; + while taken < max_rows { + let Some(rb) = stream + .try_next() + .await + .map_err(|e| lance_core::Error::io(format!("parquet read: {e}")))? + else { + break; + }; + let col = rb + .column_by_name(TEXT_COL) + .ok_or_else(|| lance_core::Error::io("text column missing".to_string()))?; + let strs = col + .as_any() + .downcast_ref::() + .ok_or_else(|| lance_core::Error::io("text not StringArray".to_string()))?; + for i in 0..strs.len() { + if taken >= max_rows { + break; + } + if strs.is_null(i) { + continue; + } + out.push(strs.value(i).to_string()); + taken += 1; + } + } + Ok(taken) +} + +async fn load_corpus(needed_rows: usize, cache_dir: &std::path::Path) -> Result> { + std::fs::create_dir_all(cache_dir) + .map_err(|e| lance_core::Error::io(format!("mkdir cache: {e}")))?; + let shards = list_shard_paths().await?; + println!("fineweb sample/10BT: {} shards", shards.len()); + let mut buf: Vec = Vec::with_capacity(needed_rows); + for rel in &shards { + if buf.len() >= needed_rows { + break; + } + let name = rel.rsplit('/').next().unwrap_or(rel); + let local = cache_dir.join(name); + download_shard(rel, &local).await?; + let want = needed_rows - buf.len(); + let got = read_shard_text(&local, &mut buf, want).await?; + println!(" shard {name} -> {got} rows (cumulative {})", buf.len()); + } + if buf.len() < needed_rows { + return Err(lance_core::Error::io(format!( + "fineweb yielded only {} rows, need {needed_rows}", + buf.len() + ))); + } + Ok(buf) +} + +// ---------------------------------------------------------------------- +// Dataset shaping +// ---------------------------------------------------------------------- + +fn make_schema() -> Arc { + let mut id_meta = HashMap::new(); + id_meta.insert( + "lance-schema:unenforced-primary-key".to_string(), + "true".to_string(), + ); + let id_field = Field::new("id", DataType::Int64, false).with_metadata(id_meta); + Arc::new(ArrowSchema::new(vec![ + id_field, + Field::new(TEXT_COL, DataType::Utf8, 
true), + ])) +} + +fn slice_to_batch(schema: Arc, start_id: i64, texts: &[String]) -> RecordBatch { + let ids = Int64Array::from_iter_values(start_id..start_id + texts.len() as i64); + let text = StringArray::from_iter_values(texts.iter().map(String::as_str)); + RecordBatch::try_new(schema, vec![Arc::new(ids), Arc::new(text)]).unwrap() +} + +async fn write_lance(uri: &str, batches: Vec) -> Result { + let schema = batches[0].schema(); + let reader = RecordBatchIterator::new(batches.into_iter().map(Ok), schema); + Dataset::write(reader, uri, Some(WriteParams::default())).await +} + +async fn create_fts_index(ds: &mut Dataset) -> Result<()> { + ds.create_index( + &[TEXT_COL], + IndexType::Inverted, + Some(FTS_INDEX_NAME.to_string()), + &InvertedIndexParams::default(), + false, + ) + .await?; + Ok(()) +} + +/// Build the LSM shape: writes one Lance dataset per flushed gen (with +/// FTS index), and constructs an active in-memory memtable from the +/// active slice. Returns the collector inputs ready for `LsmScanner`. +async fn build_lsm_shape( + shape: Shape, + corpus: &[String], + work_dir: &std::path::Path, +) -> Result<( + Option>, + Vec, + Uuid, + InMemoryMemTables, +)> { + let schema = make_schema(); + let (base_rows, gen_rows, active_rows) = shape.slicing(); + let total = base_rows.unwrap_or(0) + gen_rows.iter().sum::() + active_rows; + assert!( + corpus.len() >= total, + "shape needs {total} rows, corpus has {}", + corpus.len() + ); + + let mut cursor: usize = 0; + let mut id_cursor: i64 = 0; + let shard_id = Uuid::new_v4(); + + // Base. + let base = if let Some(n) = base_rows { + let uri = format!("{}/base", work_dir.display()); + let mut ds = write_lance( + &uri, + vec![slice_to_batch( + schema.clone(), + id_cursor, + &corpus[cursor..cursor + n], + )], + ) + .await?; + create_fts_index(&mut ds).await?; + let ds = Arc::new(Dataset::open(&uri).await?); + cursor += n; + id_cursor += n as i64; + Some(ds) + } else { + None + }; + + // Flushed generations. 
+ let mut shard_snapshot = lance::dataset::mem_wal::scanner::ShardSnapshot::new(shard_id) + .with_current_generation((gen_rows.len() as u64).max(1)); + let base_uri = base + .as_ref() + .map(|d| d.uri().to_string()) + .unwrap_or_else(|| format!("{}/base", work_dir.display())); + for (i, &n) in gen_rows.iter().enumerate() { + let gen_num = (i + 1) as u64; + let rel = format!("gen_{gen_num}"); + let uri = format!("{base_uri}/_mem_wal/{shard_id}/{rel}"); + let mut ds = write_lance( + &uri, + vec![slice_to_batch( + schema.clone(), + id_cursor, + &corpus[cursor..cursor + n], + )], + ) + .await?; + create_fts_index(&mut ds).await?; + cursor += n; + id_cursor += n as i64; + shard_snapshot = shard_snapshot.with_flushed_generation(gen_num, rel); + } + + // Active memtable. + let batch_store = Arc::new(BatchStore::with_capacity(active_rows.max(16))); + let mut indexes = IndexStore::new(); + indexes.add_fts(FTS_INDEX_NAME.to_string(), 1, TEXT_COL.to_string()); + let active_batch = slice_to_batch( + schema.clone(), + id_cursor, + &corpus[cursor..cursor + active_rows], + ); + batch_store.append(active_batch.clone()).unwrap(); + indexes + .insert_with_batch_position(&active_batch, 0, Some(0)) + .unwrap(); + let indexes = Arc::new(indexes); + + let in_memory = InMemoryMemTables { + active: InMemoryMemTableRef { + batch_store, + index_store: indexes, + schema, + generation: (gen_rows.len() as u64) + 1, + }, + frozen: vec![], + }; + + Ok((base, vec![shard_snapshot], shard_id, in_memory)) +} + +/// Build a single Lance dataset containing the full corpus + one FTS +/// index. Used as the ground-truth reference for Jaccard / score Pearson. 
+async fn build_baseline(corpus: &[String], work_dir: &std::path::Path) -> Result> { + let schema = make_schema(); + let uri = format!("{}/baseline_merged", work_dir.display()); + let batch = slice_to_batch(schema, 0, corpus); + let mut ds = write_lance(&uri, vec![batch]).await?; + create_fts_index(&mut ds).await?; + Ok(Arc::new(Dataset::open(&uri).await?)) +} + +// ---------------------------------------------------------------------- +// Query selection +// ---------------------------------------------------------------------- + +/// Pick `n` representative single-term queries from the corpus. +/// +/// Tokenizes a sample of the corpus with the default English tokenizer, +/// counts term frequencies, and returns terms in the "long tail" — not +/// the absolute most frequent (those match nearly every doc and don't +/// produce interesting BM25 rankings) and not the rarest (those match +/// nothing and produce empty top-K). Window roughly between the 80th +/// and 99th percentile of df. +fn pick_queries(corpus: &[String], n: usize) -> Vec { + let sample_n = corpus.len().min(2_000); + let mut tokenizer = InvertedIndexParams::default().build().expect("tokenizer"); + let mut df: HashMap = HashMap::new(); + for text in corpus.iter().take(sample_n) { + let mut stream = tokenizer.token_stream_for_doc(text); + let mut seen: HashSet = HashSet::new(); + while let Some(tok) = stream.next() { + if seen.insert(tok.text.clone()) { + *df.entry(tok.text.clone()).or_insert(0) += 1; + } + } + } + let mut all: Vec<(String, usize)> = df.into_iter().collect(); + all.sort_by_key(|(_, c)| *c); + // Pull from the 80th–99th percentile window. 
+ let lo = (all.len() as f64 * 0.80) as usize; + let hi = (all.len() as f64 * 0.99) as usize; + let window: &[(String, usize)] = &all[lo.min(all.len())..hi.min(all.len())]; + if window.is_empty() { + return Vec::new(); + } + let stride = (window.len() / n.max(1)).max(1); + let mut out: Vec = Vec::with_capacity(n); + for (i, (term, _)) in window.iter().enumerate() { + if out.len() >= n { + break; + } + if i % stride == 0 { + out.push(term.clone()); + } + } + out +} + +// ---------------------------------------------------------------------- +// Mode runner +// ---------------------------------------------------------------------- + +#[derive(Debug)] +struct ModeRun { + /// Per-query top-k row id sets. + top_ids: Vec>, + /// Per-query (id → score) maps, for Pearson on the intersection. + scored: Vec>, + latencies_us: Vec, +} + +async fn run_mode( + scanner: &LsmScanner, + mode: FtsScoringMode, + queries: &[String], + k: usize, +) -> Result { + let mut top_ids = Vec::with_capacity(queries.len()); + let mut scored = Vec::with_capacity(queries.len()); + let mut latencies_us = Vec::with_capacity(queries.len()); + for q in queries { + let t = Instant::now(); + let plan = scanner + .full_text_search(TEXT_COL, FullTextSearchQuery::new(q.clone()), k, mode) + .await?; + let ctx = datafusion::prelude::SessionContext::new(); + let stream = plan + .execute(0, ctx.task_ctx()) + .map_err(|e| lance_core::Error::io(format!("plan execute for query '{q}': {e}")))?; + let batches: Vec = stream + .try_collect() + .await + .map_err(|e| lance_core::Error::io(format!("collect for query '{q}': {e}")))?; + latencies_us.push(t.elapsed().as_micros() as u64); + + let mut ids: HashSet = HashSet::new(); + let mut score_map: HashMap = HashMap::new(); + for b in &batches { + let id_col = b + .column_by_name("id") + .expect("id col") + .as_any() + .downcast_ref::() + .expect("id Int64"); + let score_col = b + .column_by_name("_score") + .expect("_score col") + .as_any() + .downcast_ref::() + 
.expect("_score Float32"); + for i in 0..b.num_rows() { + let id = id_col.value(i); + ids.insert(id); + score_map.insert(id, score_col.value(i)); + } + } + top_ids.push(ids); + scored.push(score_map); + } + Ok(ModeRun { + top_ids, + scored, + latencies_us, + }) +} + +/// Run the merged-index baseline (single Lance dataset). We reuse the +/// same `scanner.full_text_search` API on the merged dataset so the +/// score scale is comparable to the LSM result (`scanner` already uses +/// `build_global_bm25_scorer` for its internal multi-partition case). +async fn run_baseline(baseline: &Dataset, queries: &[String], k: usize) -> Result { + let mut top_ids = Vec::with_capacity(queries.len()); + let mut scored = Vec::with_capacity(queries.len()); + let mut latencies_us = Vec::with_capacity(queries.len()); + for q in queries { + let t = Instant::now(); + let mut scanner = baseline.scan(); + scanner.project(&["id", TEXT_COL])?; + scanner.full_text_search( + FullTextSearchQuery::new(q.clone()) + .with_column(TEXT_COL.to_string())? 
+ .limit(Some(k as i64)), + )?; + let batches: Vec = scanner.try_into_stream().await?.try_collect().await?; + latencies_us.push(t.elapsed().as_micros() as u64); + + let mut ids: HashSet = HashSet::new(); + let mut score_map: HashMap = HashMap::new(); + for b in &batches { + let id_col = b + .column_by_name("id") + .expect("id col") + .as_any() + .downcast_ref::() + .expect("id Int64"); + let score_col = b + .column_by_name("_score") + .expect("_score col") + .as_any() + .downcast_ref::() + .expect("_score Float32"); + for i in 0..b.num_rows() { + let id = id_col.value(i); + ids.insert(id); + score_map.insert(id, score_col.value(i)); + } + } + top_ids.push(ids); + scored.push(score_map); + } + Ok(ModeRun { + top_ids, + scored, + latencies_us, + }) +} + +// ---------------------------------------------------------------------- +// Metrics +// ---------------------------------------------------------------------- + +fn percentile(values: &[f64], pct: f64) -> f64 { + if values.is_empty() { + return 0.0; + } + let mut sorted = values.to_vec(); + sorted.sort_by(|a, b| a.partial_cmp(b).unwrap()); + let rank = (pct / 100.0) * (sorted.len() - 1) as f64; + let lo = rank.floor() as usize; + let hi = rank.ceil() as usize; + if lo == hi { + sorted[lo] + } else { + let frac = rank - lo as f64; + sorted[lo] * (1.0 - frac) + sorted[hi] * frac + } +} + +fn mean_jaccard(a: &[HashSet], b: &[HashSet]) -> f64 { + let pairs: Vec = a + .iter() + .zip(b.iter()) + .filter_map(|(x, y)| { + if x.is_empty() && y.is_empty() { + None + } else { + let inter = x.intersection(y).count() as f64; + let union = x.union(y).count() as f64; + Some(inter / union) + } + }) + .collect(); + if pairs.is_empty() { + 0.0 + } else { + pairs.iter().sum::() / pairs.len() as f64 + } +} + +/// Pearson correlation of scores on the intersection. Averaged over +/// queries that have at least 2 overlapping ids (Pearson is undefined +/// for fewer). 
+fn mean_pearson(a: &[HashMap], b: &[HashMap]) -> f64 { + let pairs: Vec = a + .iter() + .zip(b.iter()) + .filter_map(|(x, y)| { + let common: Vec = x.keys().filter(|k| y.contains_key(k)).copied().collect(); + if common.len() < 2 { + return None; + } + let xs: Vec = common.iter().map(|i| x[i] as f64).collect(); + let ys: Vec = common.iter().map(|i| y[i] as f64).collect(); + let mx = xs.iter().sum::() / xs.len() as f64; + let my = ys.iter().sum::() / ys.len() as f64; + let num: f64 = xs + .iter() + .zip(ys.iter()) + .map(|(a, b)| (a - mx) * (b - my)) + .sum(); + let dx: f64 = xs.iter().map(|a| (a - mx).powi(2)).sum::().sqrt(); + let dy: f64 = ys.iter().map(|b| (b - my).powi(2)).sum::().sqrt(); + if dx == 0.0 || dy == 0.0 { + None + } else { + Some(num / (dx * dy)) + } + }) + .collect(); + if pairs.is_empty() { + 0.0 + } else { + pairs.iter().sum::() / pairs.len() as f64 + } +} + +// ---------------------------------------------------------------------- +// Run +// ---------------------------------------------------------------------- + +async fn run(args: Args) -> Result<()> { + let needed = args + .max_corpus_rows + .unwrap_or_else(|| args.shape.total_rows()); + println!( + "shape={} needed_rows={} k={} num_queries={} rescore_factor={}", + args.shape.as_str(), + needed, + args.k, + args.num_queries, + args.rescore_factor + ); + + let corpus = load_corpus(needed, &args.cache_dir).await?; + let queries = pick_queries(&corpus, args.num_queries); + println!("picked {} query terms", queries.len()); + + let work_dir = if let Some(d) = &args.work_dir { + std::fs::create_dir_all(d).map_err(|e| lance_core::Error::io(format!("mkdir: {e}")))?; + d.clone() + } else { + tempfile::tempdir() + .map_err(|e| lance_core::Error::io(format!("tempdir: {e}")))? 
+ .keep() + }; + + let (base, shard_snapshots, shard_id, in_memory) = + build_lsm_shape(args.shape, &corpus, &work_dir).await?; + + let pk_columns = vec!["id".to_string()]; + let scanner = if let Some(b) = base.clone() { + LsmScanner::new(b, shard_snapshots.clone(), pk_columns.clone()) + } else { + let schema = make_schema(); + LsmScanner::without_base_table( + schema, + format!("{}/base", work_dir.display()), + shard_snapshots.clone(), + pk_columns.clone(), + ) + } + .with_in_memory_memtables(shard_id, in_memory); + + println!("running Local mode ..."); + let local = run_mode(&scanner, FtsScoringMode::Local, &queries, args.k).await?; + println!("running LocalWithGlobalRescore mode ..."); + let rescore = run_mode( + &scanner, + FtsScoringMode::LocalWithGlobalRescore { + rescore_factor: args.rescore_factor, + }, + &queries, + args.k, + ) + .await?; + + let baseline_run = if args.skip_baseline { + None + } else { + println!("building merged-index baseline ..."); + let baseline = build_baseline(&corpus, &work_dir).await?; + Some(run_baseline(&baseline, &queries, args.k).await?) + }; + + // Aggregate metrics. 
+ let lat = |run: &ModeRun| -> (f64, f64, f64, f64) { + let mut v: Vec = run + .latencies_us + .iter() + .map(|x| *x as f64 / 1000.0) + .collect(); + v.sort_by(|a, b| a.partial_cmp(b).unwrap()); + let mean = v.iter().sum::() / v.len() as f64; + ( + mean, + percentile(&v, 50.0), + percentile(&v, 95.0), + percentile(&v, 99.0), + ) + }; + let (local_mean, local_p50, local_p95, local_p99) = lat(&local); + let (rescore_mean, rescore_p50, rescore_p95, rescore_p99) = lat(&rescore); + + let jaccard_local_rescore = mean_jaccard(&local.top_ids, &rescore.top_ids); + let pearson_local_rescore = mean_pearson(&local.scored, &rescore.scored); + let (jaccard_local_baseline, pearson_local_baseline) = if let Some(b) = &baseline_run { + ( + mean_jaccard(&local.top_ids, &b.top_ids), + mean_pearson(&local.scored, &b.scored), + ) + } else { + (f64::NAN, f64::NAN) + }; + let (jaccard_rescore_baseline, pearson_rescore_baseline) = if let Some(b) = &baseline_run { + ( + mean_jaccard(&rescore.top_ids, &b.top_ids), + mean_pearson(&rescore.scored, &b.scored), + ) + } else { + (f64::NAN, f64::NAN) + }; + + let summary = json!({ + "shape": args.shape.as_str(), + "k": args.k, + "num_queries": queries.len(), + "rescore_factor": args.rescore_factor, + "local": { + "mean_ms": local_mean, + "p50_ms": local_p50, + "p95_ms": local_p95, + "p99_ms": local_p99, + }, + "rescore": { + "mean_ms": rescore_mean, + "p50_ms": rescore_p50, + "p95_ms": rescore_p95, + "p99_ms": rescore_p99, + }, + "jaccard": { + "local_vs_rescore": jaccard_local_rescore, + "local_vs_baseline": jaccard_local_baseline, + "rescore_vs_baseline": jaccard_rescore_baseline, + }, + "pearson_score": { + "local_vs_rescore": pearson_local_rescore, + "local_vs_baseline": pearson_local_baseline, + "rescore_vs_baseline": pearson_rescore_baseline, + }, + }); + + println!( + "\n=== Result ===\n{}", + serde_json::to_string_pretty(&summary).unwrap() + ); + if let Some(path) = &args.output { + std::fs::write(path, 
serde_json::to_string_pretty(&summary).unwrap()) + .map_err(|e| lance_core::Error::io(format!("write output: {e}")))?; + println!("\nwrote {}", path.display()); + } + Ok(()) +} + +fn main() -> Result<()> { + let args = parse_args()?; + let rt = tokio::runtime::Builder::new_multi_thread() + .enable_all() + .worker_threads(args.tokio_threads) + .build() + .map_err(|e| lance_core::Error::io(format!("tokio: {e}")))?; + rt.block_on(run(args)) +} From 757d8393982c6d2f34c51430ea64364f98666723 Mon Sep 17 00:00:00 2001 From: Heng Ge Date: Wed, 20 May 2026 22:14:59 -0700 Subject: [PATCH 04/10] bench(mem_wal): add mem_wal_fts_read_bench + storage sweep driver New CLI bench modeled on the vector / point-lookup read benches: `--phase prepare|search`, `--uri` with local/cloud detection, real FineWeb text payload, ShardWriter ingestion of flushed generations plus an active memtable, and the same JSON output contract. The scoring mode is the panel: each search invocation times the query set under both FtsScoringMode::Local and LocalWithGlobalRescore and reports per-mode p50/p95/p99/mean/qps plus the top-k Jaccard between the two modes. `run_fts_read_sweep.sh` drives the panel across local NVMe and an s3:// prefix for a configurable base-size / top-k matrix, mirrors each result.json to S3, and prints a summary table. Both registered in Cargo.toml next to the existing FineWeb FTS benches. 
--- rust/lance/Cargo.toml | 5 + .../mem_wal/fts/mem_wal_fts_read_bench.rs | 786 ++++++++++++++++++ .../benches/mem_wal/fts/run_fts_read_sweep.sh | 150 ++++ 3 files changed, 941 insertions(+) create mode 100644 rust/lance/benches/mem_wal/fts/mem_wal_fts_read_bench.rs create mode 100755 rust/lance/benches/mem_wal/fts/run_fts_read_sweep.sh diff --git a/rust/lance/Cargo.toml b/rust/lance/Cargo.toml index 816a1872685..2db049944aa 100644 --- a/rust/lance/Cargo.toml +++ b/rust/lance/Cargo.toml @@ -263,6 +263,11 @@ name = "lsm_fts_modes" path = "benches/mem_wal/fts/lsm_fts_modes.rs" harness = false +[[bench]] +name = "mem_wal_fts_read_bench" +path = "benches/mem_wal/fts/mem_wal_fts_read_bench.rs" +harness = false + [[bench]] name = "manifest_commit" harness = false diff --git a/rust/lance/benches/mem_wal/fts/mem_wal_fts_read_bench.rs b/rust/lance/benches/mem_wal/fts/mem_wal_fts_read_bench.rs new file mode 100644 index 00000000000..1e4fc6459b5 --- /dev/null +++ b/rust/lance/benches/mem_wal/fts/mem_wal_fts_read_bench.rs @@ -0,0 +1,786 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +//! Standalone CLI benchmark for FTS read across LSM levels. +//! +//! Sibling of `mem_wal_vector_bench.rs` / `mem_wal_point_lookup_bench.rs`: +//! same `--phase prepare|search` shape, same `ShardWriter`-based ingestion +//! of flushed generations + an active memtable, same `--uri` cloud/local +//! detection, and the same JSON output contract. The payload is real +//! HuggingFace FineWeb `text` and the query path is +//! [`LsmFtsSearchPlanner`] over the base table + flushed generations + +//! active memtable. +//! +//! The "panel" for FTS is the scoring mode: each invocation runs the same +//! query set under both [`FtsScoringMode::Local`] and +//! [`FtsScoringMode::LocalWithGlobalRescore`] and reports per-mode latency +//! plus the top-k Jaccard between the two (how much rescoring moves the +//! ranking). +//! +//! 
Two phases, selected with `--phase`: +//! +//! --phase prepare Load FineWeb text, write the base dataset, create an +//! inverted (FTS) index, and initialize MemWAL with the +//! index maintained. +//! --phase search Ingest rows across LSM levels via ShardWriter, then run +//! the FTS query panel under both scoring modes. +//! +//! Example: +//! +//! ```bash +//! cargo bench -p lance --bench mem_wal_fts_read_bench -- \ +//! --phase prepare --uri /tmp/fts_read_bench \ +//! --base-rows 1000000 --cache-dir /tmp/fineweb-cache +//! +//! cargo bench -p lance --bench mem_wal_fts_read_bench -- \ +//! --phase search --uri /tmp/fts_read_bench \ +//! --base-rows 1000000 --max-memtable-rows 100000 \ +//! --queries 200 --k 10 --rescore-factor 10 \ +//! --cache-dir /tmp/fineweb-cache --output result.json +//! ``` + +#![recursion_limit = "256"] +#![allow(clippy::print_stdout, clippy::print_stderr)] + +use std::collections::{HashMap, HashSet}; +use std::path::PathBuf; +use std::sync::Arc; +use std::time::{Duration, Instant}; + +use arrow_array::{Array, Int64Array, RecordBatch, RecordBatchIterator, StringArray}; +use arrow_schema::{DataType, Field, Schema as ArrowSchema}; +use datafusion::prelude::SessionContext; +use futures::TryStreamExt; +use lance::dataset::mem_wal::scanner::{ + FtsScoringMode, LsmDataSourceCollector, LsmFtsSearchPlanner, ShardSnapshot, +}; +use lance::dataset::mem_wal::{DatasetMemWalExt, ShardWriterConfig}; +use lance::dataset::{Dataset, WriteParams}; +use lance::index::DatasetIndexExt; +use lance_core::Result; +use lance_index::IndexType; +use lance_index::scalar::FullTextSearchQuery; +use lance_index::scalar::inverted::tokenizer::InvertedIndexParams; +use lance_tokenizer::TokenStream; +use parquet::arrow::async_reader::ParquetRecordBatchStreamBuilder; +use serde_json::json; +use uuid::Uuid; + +const TEXT_COL: &str = "text"; +const FTS_INDEX_NAME: &str = "text_fts"; +const HF_API_LISTING: &str = + 
"https://huggingface.co/api/datasets/HuggingFaceFW/fineweb/tree/main/sample/10BT"; +const HF_FILE_BASE: &str = "https://huggingface.co/datasets/HuggingFaceFW/fineweb/resolve/main/"; + +// ---------------------------------------------------------------------- +// Phase / Args +// ---------------------------------------------------------------------- + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum Phase { + Prepare, + Search, +} + +impl Phase { + fn parse(value: &str) -> std::result::Result { + match value { + "prepare" => Ok(Self::Prepare), + "search" => Ok(Self::Search), + _ => Err(format!("unknown phase '{value}', expected prepare|search")), + } + } + + fn as_str(self) -> &'static str { + match self { + Self::Prepare => "prepare", + Self::Search => "search", + } + } +} + +#[derive(Debug, Clone)] +struct Args { + phase: Phase, + uri: String, + base_rows: usize, + max_memtable_rows: usize, + flushed_generations: usize, + batch_rows: usize, + queries: usize, + k: usize, + rescore_factor: u32, + cache_dir: PathBuf, + output: Option, +} + +impl Default for Args { + fn default() -> Self { + Self { + phase: Phase::Search, + uri: String::new(), + base_rows: 1_000_000, + max_memtable_rows: 100_000, + flushed_generations: 2, + batch_rows: 1_000, + queries: 200, + k: 10, + rescore_factor: 10, + cache_dir: std::env::temp_dir().join("mem_wal_fineweb_fts_cache"), + output: None, + } + } +} + +fn parse_val(flag: &str, value: &str) -> Result +where + T: std::str::FromStr, + T::Err: std::fmt::Display, +{ + value + .parse() + .map_err(|e| lance_core::Error::invalid_input(format!("invalid {flag}: {value} ({e})"))) +} + +fn parse_args() -> Result { + let mut args = Args::default(); + let mut iter = std::env::args().skip(1); + let mut has_phase = false; + let mut has_uri = false; + while let Some(flag) = iter.next() { + if flag == "--bench" { + continue; + } + let value = iter + .next() + .ok_or_else(|| lance_core::Error::invalid_input(format!("missing value for {flag}")))?; + 
match flag.as_str() { + "--phase" => { + args.phase = Phase::parse(&value).map_err(lance_core::Error::invalid_input)?; + has_phase = true; + } + "--uri" => { + args.uri = value; + has_uri = true; + } + "--base-rows" => args.base_rows = parse_val(&flag, &value)?, + "--max-memtable-rows" => args.max_memtable_rows = parse_val(&flag, &value)?, + "--flushed-generations" => args.flushed_generations = parse_val(&flag, &value)?, + "--batch-rows" => args.batch_rows = parse_val(&flag, &value)?, + "--queries" => args.queries = parse_val(&flag, &value)?, + "--k" => args.k = parse_val(&flag, &value)?, + "--rescore-factor" => args.rescore_factor = parse_val(&flag, &value)?, + "--cache-dir" => args.cache_dir = PathBuf::from(value), + "--output" => args.output = Some(PathBuf::from(value)), + _ => { + return Err(lance_core::Error::invalid_input(format!( + "unknown argument: {flag}" + ))); + } + } + } + if !has_phase { + return Err(lance_core::Error::invalid_input( + "--phase is required (prepare|search)", + )); + } + if !has_uri { + return Err(lance_core::Error::invalid_input("--uri is required")); + } + if args.batch_rows == 0 || args.base_rows == 0 || args.max_memtable_rows == 0 { + return Err(lance_core::Error::invalid_input( + "base-rows, max-memtable-rows, batch-rows must be > 0", + )); + } + Ok(args) +} + +fn is_cloud_uri(uri: &str) -> bool { + uri.starts_with("s3://") || uri.starts_with("gs://") || uri.starts_with("az://") +} + +// ---------------------------------------------------------------------- +// FineWeb loading (mirrors mem_wal_fineweb_fts.rs) +// ---------------------------------------------------------------------- + +#[derive(serde::Deserialize)] +struct HfTreeEntry { + #[serde(rename = "type")] + kind: String, + path: String, +} + +async fn list_shard_paths() -> Result> { + let entries: Vec = reqwest::get(HF_API_LISTING) + .await + .map_err(|e| lance_core::Error::io(format!("listing HTTP: {e}")))? 
+ .json() + .await + .map_err(|e| lance_core::Error::io(format!("listing JSON: {e}")))?; + let mut shards: Vec = entries + .into_iter() + .filter(|e| e.kind == "file" && e.path.ends_with(".parquet")) + .map(|e| e.path) + .collect(); + shards.sort(); + Ok(shards) +} + +async fn download_shard(rel_path: &str, dest: &std::path::Path) -> Result<()> { + if dest.exists() { + return Ok(()); + } + let url = format!("{HF_FILE_BASE}{rel_path}"); + let tmp = dest.with_extension("part"); + for attempt in 1..=5u32 { + println!("downloading {rel_path} (attempt {attempt}/5) ..."); + let result: Result = async { + let resp = reqwest::get(&url) + .await + .map_err(|e| lance_core::Error::io(format!("download HTTP: {e}")))?; + if !resp.status().is_success() { + return Err(lance_core::Error::io(format!( + "download {url} -> status {}", + resp.status() + ))); + } + resp.bytes() + .await + .map_err(|e| lance_core::Error::io(format!("read body: {e}"))) + } + .await; + match result { + Ok(bytes) => { + std::fs::write(&tmp, &bytes) + .map_err(|e| lance_core::Error::io(format!("write: {e}")))?; + std::fs::rename(&tmp, dest) + .map_err(|e| lance_core::Error::io(format!("rename: {e}")))?; + return Ok(()); + } + Err(e) if attempt < 5 => { + eprintln!(" attempt {attempt} failed: {e}; retrying"); + tokio::time::sleep(Duration::from_secs(2u64.pow(attempt))).await; + } + Err(e) => return Err(e), + } + } + unreachable!() +} + +async fn read_shard_text( + path: &std::path::Path, + out: &mut Vec, + max_rows: usize, +) -> Result { + let file = tokio::fs::File::open(path) + .await + .map_err(|e| lance_core::Error::io(format!("open parquet: {e}")))?; + let builder = ParquetRecordBatchStreamBuilder::new(file) + .await + .map_err(|e| lance_core::Error::io(format!("parquet builder: {e}")))?; + let mut stream = builder + .build() + .map_err(|e| lance_core::Error::io(format!("parquet stream: {e}")))?; + let mut taken = 0usize; + while taken < max_rows { + let Some(rb) = stream + .try_next() + .await + 
.map_err(|e| lance_core::Error::io(format!("parquet read: {e}")))? + else { + break; + }; + let col = rb + .column_by_name(TEXT_COL) + .ok_or_else(|| lance_core::Error::io("text column missing".to_string()))?; + let strs = col + .as_any() + .downcast_ref::() + .ok_or_else(|| lance_core::Error::io("text not StringArray".to_string()))?; + for i in 0..strs.len() { + if taken >= max_rows { + break; + } + if strs.is_null(i) { + continue; + } + out.push(strs.value(i).to_string()); + taken += 1; + } + } + Ok(taken) +} + +async fn load_corpus(needed_rows: usize, cache_dir: &std::path::Path) -> Result> { + std::fs::create_dir_all(cache_dir) + .map_err(|e| lance_core::Error::io(format!("mkdir cache: {e}")))?; + let shards = list_shard_paths().await?; + println!("fineweb sample/10BT: {} shards", shards.len()); + let mut buf: Vec = Vec::with_capacity(needed_rows); + for rel in &shards { + if buf.len() >= needed_rows { + break; + } + let name = rel.rsplit('/').next().unwrap_or(rel); + let local = cache_dir.join(name); + download_shard(rel, &local).await?; + let want = needed_rows - buf.len(); + let got = read_shard_text(&local, &mut buf, want).await?; + println!(" shard {name} -> {got} rows (cumulative {})", buf.len()); + } + if buf.len() < needed_rows { + return Err(lance_core::Error::io(format!( + "fineweb yielded only {} rows, need {needed_rows}", + buf.len() + ))); + } + Ok(buf) +} + +// ---------------------------------------------------------------------- +// Schema / batch helpers +// ---------------------------------------------------------------------- + +fn make_schema() -> Arc { + let mut id_meta = HashMap::new(); + id_meta.insert( + "lance-schema:unenforced-primary-key".to_string(), + "true".to_string(), + ); + Arc::new(ArrowSchema::new(vec![ + Field::new("id", DataType::Int64, false).with_metadata(id_meta), + Field::new(TEXT_COL, DataType::Utf8, true), + ])) +} + +fn make_batch(schema: Arc, start_id: i64, texts: &[String]) -> RecordBatch { + let ids = 
Int64Array::from_iter_values(start_id..start_id + texts.len() as i64); + let text = StringArray::from_iter_values(texts.iter().map(String::as_str)); + RecordBatch::try_new(schema, vec![Arc::new(ids), Arc::new(text)]).unwrap() +} + +// ---------------------------------------------------------------------- +// Latency stats +// ---------------------------------------------------------------------- + +fn percentile(sorted: &[f64], pct: f64) -> f64 { + if sorted.is_empty() { + return f64::NAN; + } + let idx = ((pct / 100.0) * (sorted.len().saturating_sub(1)) as f64).round() as usize; + sorted[idx.min(sorted.len() - 1)] +} + +struct LatencyStats { + p50_us: u64, + p95_us: u64, + p99_us: u64, + mean_us: f64, + qps: f64, +} + +fn compute_stats(mut latencies_us: Vec) -> LatencyStats { + latencies_us.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)); + let mean = latencies_us.iter().sum::() / latencies_us.len().max(1) as f64; + let total_s = latencies_us.iter().sum::() / 1_000_000.0; + let qps = if total_s > 0.0 { + latencies_us.len() as f64 / total_s + } else { + 0.0 + }; + LatencyStats { + p50_us: percentile(&latencies_us, 50.0) as u64, + p95_us: percentile(&latencies_us, 95.0) as u64, + p99_us: percentile(&latencies_us, 99.0) as u64, + mean_us: mean, + qps, + } +} + +// ---------------------------------------------------------------------- +// Query set: mid-frequency single terms from the corpus +// ---------------------------------------------------------------------- + +const STOPWORDS: &[&str] = &[ + "the", "a", "an", "and", "or", "of", "to", "in", "on", "for", "with", "as", "by", "is", "was", + "are", "were", "be", "been", "being", "this", "that", "these", "those", "it", "its", "but", + "not", "no", "if", "then", "than", "so", "do", "does", "did", "have", "has", "had", "will", + "would", "should", "could", "can", "may", "might", "must", "i", "you", "he", "she", "we", + "they", "them", "his", "her", "their", "our", "us", "me", "my", "your", "him", 
"at", "from", +]; + +fn build_query_terms(sample: &[String], n: usize) -> Vec { + let mut tokenizer = InvertedIndexParams::default() + .build() + .expect("default tokenizer builds"); + let mut freq: HashMap = HashMap::new(); + for t in sample.iter().take(50_000) { + let mut stream = tokenizer.token_stream_for_doc(t); + while let Some(tok) = stream.next() { + if tok.text.len() < 3 || tok.text.len() > 24 || STOPWORDS.contains(&tok.text.as_str()) { + continue; + } + *freq.entry(tok.text.clone()).or_default() += 1; + } + } + let mut by_freq: Vec<(String, u64)> = freq.into_iter().collect(); + by_freq.sort_by(|a, b| b.1.cmp(&a.1).then(a.0.cmp(&b.0))); + // Skip the most-frequent tokens (near-ties in BM25), keep mid-frequency. + let skip = (by_freq.len() / 4).min(300); + by_freq + .into_iter() + .skip(skip) + .map(|(t, _)| t) + .take(n) + .collect() +} + +// ---------------------------------------------------------------------- +// Prepare phase +// ---------------------------------------------------------------------- + +async fn run_prepare(args: &Args) -> Result<()> { + let start = Instant::now(); + let corpus = load_corpus(args.base_rows, &args.cache_dir).await?; + let schema = make_schema(); + + let total_batches = corpus.len().div_ceil(args.batch_rows); + let mut batches = Vec::with_capacity(total_batches); + let mut lo = 0usize; + while lo < corpus.len() { + let hi = (lo + args.batch_rows).min(corpus.len()); + batches.push(Ok(make_batch(schema.clone(), lo as i64, &corpus[lo..hi]))); + lo = hi; + } + let reader = RecordBatchIterator::new(batches.into_iter(), schema.clone()); + let write_start = Instant::now(); + let mut dataset = Dataset::write(reader, &args.uri, Some(WriteParams::default())).await?; + println!( + "wrote {} base rows in {:.1}s", + args.base_rows, + write_start.elapsed().as_secs_f64() + ); + + let index_start = Instant::now(); + dataset + .create_index( + &[TEXT_COL], + IndexType::Inverted, + Some(FTS_INDEX_NAME.to_string()), + 
&InvertedIndexParams::default(), + true, + ) + .await?; + println!( + "created FTS index in {:.1}s", + index_start.elapsed().as_secs_f64() + ); + + dataset + .initialize_mem_wal() + .maintained_indexes([FTS_INDEX_NAME]) + .execute() + .await?; + println!( + "prepare complete in {:.1}s: uri={}", + start.elapsed().as_secs_f64(), + args.uri + ); + Ok(()) +} + +// ---------------------------------------------------------------------- +// Search phase +// ---------------------------------------------------------------------- + +/// Result row ids per query for one mode + the latency distribution. +struct ModeRun { + top_ids: Vec>, + latencies_us: Vec, +} + +async fn run_mode( + planner: &LsmFtsSearchPlanner, + mode: FtsScoringMode, + queries: &[String], + k: usize, +) -> Result { + let ctx = SessionContext::new(); + let mut top_ids = Vec::with_capacity(queries.len()); + let mut latencies_us = Vec::with_capacity(queries.len()); + for q in queries { + let t0 = Instant::now(); + let plan = planner + .plan_search(TEXT_COL, FullTextSearchQuery::new(q.clone()), k, None, mode) + .await?; + let stream = plan.execute(0, ctx.task_ctx())?; + let batches: Vec = stream.try_collect().await?; + latencies_us.push(t0.elapsed().as_micros() as f64); + + let mut ids: HashSet = HashSet::new(); + for b in &batches { + if let Some(col) = b + .column_by_name("id") + .and_then(|c| c.as_any().downcast_ref::()) + { + for i in 0..col.len() { + ids.insert(col.value(i)); + } + } + } + top_ids.push(ids); + } + Ok(ModeRun { + top_ids, + latencies_us, + }) +} + +fn mean_jaccard(a: &[HashSet], b: &[HashSet]) -> f64 { + let pairs: Vec = a + .iter() + .zip(b.iter()) + .filter_map(|(x, y)| { + if x.is_empty() && y.is_empty() { + None + } else { + let inter = x.intersection(y).count() as f64; + let union = x.union(y).count() as f64; + Some(inter / union) + } + }) + .collect(); + if pairs.is_empty() { + 0.0 + } else { + pairs.iter().sum::() / pairs.len() as f64 + } +} + +async fn run_search(args: &Args) -> 
Result { + let dataset = Arc::new(Dataset::open(&args.uri).await?); + let arrow_schema: Arc = Arc::new(ArrowSchema::from(dataset.schema())); + let schema = make_schema(); + + // Memtable text is drawn from FineWeb; row *ids* are assigned past the + // base slice (via `id_base`) so they don't collide with base-table ids, + // but the *text content* can reuse FineWeb rows freely — BM25 latency + // doesn't depend on content novelty. Load just enough rows once, + // covering both the memtable payload and the query-term sample, instead + // of re-reading the whole base corpus from parquet. + let active_rows = args.max_memtable_rows / 2; + let total_memtable_rows = args.flushed_generations * args.max_memtable_rows + active_rows; + let sample_rows = args.base_rows.min(50_000); + let load_rows = total_memtable_rows.max(sample_rows); + println!("loading {load_rows} FineWeb rows for memtable payload + query sample ..."); + let mt_corpus = load_corpus(load_rows, &args.cache_dir).await?; + let mt_text = &mt_corpus[..total_memtable_rows]; + + let shard_id = Uuid::new_v4(); + let row_bytes = 2048; // rough FineWeb text row size + let config = ShardWriterConfig { + shard_id, + shard_spec_id: 0, + durable_write: false, + sync_indexed_write: false, + max_memtable_size: args.max_memtable_rows * row_bytes * 2, + max_memtable_rows: args.max_memtable_rows, + max_unflushed_memtable_bytes: args.max_memtable_rows * row_bytes * 6, + max_wal_flush_interval: Some(Duration::from_secs(60)), + ..ShardWriterConfig::default() + }; + let writer = dataset.mem_wal_writer(shard_id, config).await?; + + let flush_wait = if is_cloud_uri(&args.uri) { + Duration::from_secs(5) + } else { + Duration::from_millis(500) + }; + + // Ingest flushed generations + 1 active (50% full). 
+ let mut gen_sizes: Vec = (0..args.flushed_generations) + .map(|_| args.max_memtable_rows) + .collect(); + gen_sizes.push(active_rows); + + let id_base = args.base_rows as i64; + let mut cursor = 0usize; + let ingest_start = Instant::now(); + for (gen_idx, &gen_rows) in gen_sizes.iter().enumerate() { + let mut written = 0usize; + while written < gen_rows { + let chunk = args.batch_rows.min(gen_rows - written); + let start = id_base + (cursor) as i64; + let slice = &mt_text[cursor..cursor + chunk]; + let batch = make_batch(schema.clone(), start, slice); + writer.put(vec![batch]).await?; + cursor += chunk; + written += chunk; + } + let is_flushed = gen_idx < args.flushed_generations; + println!( + " gen {}: wrote {} rows ({})", + gen_idx + 1, + gen_rows, + if is_flushed { "flushed" } else { "active" } + ); + if is_flushed { + tokio::time::sleep(flush_wait).await; + } + } + println!( + "ingested {} memtable rows in {:.1}s", + cursor, + ingest_start.elapsed().as_secs_f64() + ); + + let manifest = writer.manifest().await?; + let in_memory_refs = writer.in_memory_memtable_refs().await?; + let mut shard_snapshot = ShardSnapshot::new(shard_id); + if let Some(ref m) = manifest { + shard_snapshot = shard_snapshot.with_current_generation(m.current_generation); + for fg in &m.flushed_generations { + shard_snapshot = shard_snapshot.with_flushed_generation(fg.generation, fg.path.clone()); + } + } + let num_flushed = manifest + .as_ref() + .map(|m| m.flushed_generations.len()) + .unwrap_or(0); + println!("manifest: {num_flushed} flushed generations"); + + let collector = LsmDataSourceCollector::new(dataset.clone(), vec![shard_snapshot]) + .with_in_memory_memtables(shard_id, in_memory_refs); + let pk_columns = vec!["id".to_string()]; + let planner = LsmFtsSearchPlanner::new(collector, pk_columns, arrow_schema); + + // Query set: mid-frequency terms from a sample of the loaded corpus. 
+ println!("building query set ..."); + let sample_end = mt_corpus + .len() + .min(sample_rows.max(total_memtable_rows.min(50_000))); + let queries = build_query_terms(&mt_corpus[..sample_end], args.queries); + println!("picked {} query terms", queries.len()); + + // Run both modes. + println!( + "running Local mode ({} queries, k={}) ...", + queries.len(), + args.k + ); + let local = run_mode(&planner, FtsScoringMode::Local, &queries, args.k).await?; + println!( + "running LocalWithGlobalRescore mode (factor={}) ...", + args.rescore_factor + ); + let rescore = run_mode( + &planner, + FtsScoringMode::LocalWithGlobalRescore { + rescore_factor: args.rescore_factor, + }, + &queries, + args.k, + ) + .await?; + + // Keep writer alive so the active memtable stays reachable. + std::mem::forget(writer); + + let local_stats = compute_stats(local.latencies_us.clone()); + let rescore_stats = compute_stats(rescore.latencies_us.clone()); + let jaccard = mean_jaccard(&local.top_ids, &rescore.top_ids); + + println!( + "local: p50={}us p95={}us p99={}us mean={:.0}us qps={:.0}", + local_stats.p50_us, + local_stats.p95_us, + local_stats.p99_us, + local_stats.mean_us, + local_stats.qps + ); + println!( + "rescore: p50={}us p95={}us p99={}us mean={:.0}us qps={:.0}", + rescore_stats.p50_us, + rescore_stats.p95_us, + rescore_stats.p99_us, + rescore_stats.mean_us, + rescore_stats.qps + ); + println!("top-{} jaccard local-vs-rescore = {:.4}", args.k, jaccard); + + Ok(json!({ + "bench": "mem_wal_fts_read", + "phase": "search", + "uri_kind": if is_cloud_uri(&args.uri) { "cloud" } else { "local" }, + "base_rows": args.base_rows, + "max_memtable_rows": args.max_memtable_rows, + "flushed_generations": num_flushed, + "active_rows": active_rows, + "k": args.k, + "rescore_factor": args.rescore_factor, + "queries": queries.len(), + "jaccard_local_vs_rescore": jaccard, + "local": { + "p50_us": local_stats.p50_us, + "p95_us": local_stats.p95_us, + "p99_us": local_stats.p99_us, + "mean_us": 
local_stats.mean_us as u64, + "qps": local_stats.qps as u64, + }, + "rescore": { + "p50_us": rescore_stats.p50_us, + "p95_us": rescore_stats.p95_us, + "p99_us": rescore_stats.p99_us, + "mean_us": rescore_stats.mean_us as u64, + "qps": rescore_stats.qps as u64, + }, + })) +} + +// ---------------------------------------------------------------------- +// Entrypoint +// ---------------------------------------------------------------------- + +async fn run(args: Args) -> Result<()> { + println!( + "bench=mem_wal_fts_read phase={} uri={} base_rows={} max_memtable_rows={} flushed_generations={} queries={} k={} rescore_factor={}", + args.phase.as_str(), + args.uri, + args.base_rows, + args.max_memtable_rows, + args.flushed_generations, + args.queries, + args.k, + args.rescore_factor, + ); + + match args.phase { + Phase::Prepare => run_prepare(&args).await?, + Phase::Search => { + let result = run_search(&args).await?; + let text = serde_json::to_string_pretty(&result) + .map_err(|e| lance_core::Error::io(format!("serialize: {e}")))?; + println!("{text}"); + if let Some(path) = &args.output { + if let Some(parent) = path.parent() + && !parent.as_os_str().is_empty() + { + std::fs::create_dir_all(parent).ok(); + } + std::fs::write(path, text.as_bytes()) + .map_err(|e| lance_core::Error::io(format!("write {}: {e}", path.display())))?; + } + } + } + println!("=== DONE ==="); + Ok(()) +} + +fn main() -> Result<()> { + let args = parse_args()?; + let runtime = tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build() + .map_err(|e| lance_core::Error::io(format!("build runtime: {e}")))?; + runtime.block_on(run(args)) +} diff --git a/rust/lance/benches/mem_wal/fts/run_fts_read_sweep.sh b/rust/lance/benches/mem_wal/fts/run_fts_read_sweep.sh new file mode 100755 index 00000000000..83203573818 --- /dev/null +++ b/rust/lance/benches/mem_wal/fts/run_fts_read_sweep.sh @@ -0,0 +1,150 @@ +#!/usr/bin/env bash +# Driver for the LSM FTS read benchmark panel across storage 
backends. +# +# Sweeps the `mem_wal_fts_read_bench` over: +# - storage backend : local NVMe path and an s3:// prefix +# - base table size : configurable list (default 100k, 1M) +# - top-k : configurable list (default 10, 100) +# +# For each (backend, base_rows) the bench's `prepare` phase runs once to +# write the base dataset + FTS index + MemWAL; then for each k the `search` +# phase ingests flushed generations + an active memtable through ShardWriter +# and times the FTS query panel under both Local and LocalWithGlobalRescore +# scoring modes. +# +# Each config runs under a `timeout` watchdog so a hang costs one window. +# +# Usage: +# rust/lance/benches/mem_wal/fts/run_fts_read_sweep.sh [run_id] +# +# Env: +# NVME_PREFIX local scratch dir on fast disk (default /lsm-fts-nvme) +# S3_PREFIX s3:// dataset prefix (default s3://jack-devland-build/lsm-fts) +# CACHE_DIR FineWeb shard download cache (default /lance-fineweb-cache) +# BASE_ROWS_LIST space-separated base sizes (default "100000 1000000") +# K_LIST space-separated top-k values (default "10 100") +# MAX_MEMTABLE_ROWS active/flushed memtable cap (default 100000) +# FLUSHED_GENERATIONS number of flushed gens (default 2) +# QUERIES queries per mode (default 200) +# RESCORE_FACTOR K' multiplier for rescore (default 10) +# BACKENDS space-separated subset of "nvme s3" (default both) +# CONFIG_TIMEOUT per-config seconds (default 5400) + +set -uo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd -P)" +REPO_ROOT="$(git -C "$SCRIPT_DIR" rev-parse --show-toplevel)" +cd "$REPO_ROOT" + +RUN_ID="${1:-$(date -u +%Y%m%dT%H%M%SZ)}" +NVME_PREFIX="${NVME_PREFIX:-${TMPDIR:-/tmp}/lsm-fts-nvme}" +S3_PREFIX="${S3_PREFIX:-s3://jack-devland-build/lsm-fts}" +CACHE_DIR="${CACHE_DIR:-${TMPDIR:-/tmp}/lance-fineweb-cache}" +BASE_ROWS_LIST="${BASE_ROWS_LIST:-100000 1000000}" +K_LIST="${K_LIST:-10 100}" +MAX_MEMTABLE_ROWS="${MAX_MEMTABLE_ROWS:-100000}" +FLUSHED_GENERATIONS="${FLUSHED_GENERATIONS:-2}" 
+QUERIES="${QUERIES:-200}" +RESCORE_FACTOR="${RESCORE_FACTOR:-10}" +BACKENDS="${BACKENDS:-nvme s3}" +CONFIG_TIMEOUT="${CONFIG_TIMEOUT:-5400}" + +LOCAL_DIR="$REPO_ROOT/target/lsm-fts-read-results/${RUN_ID}" +mkdir -p "$LOCAL_DIR" "$CACHE_DIR" "$NVME_PREFIX" + +BENCH=mem_wal_fts_read_bench +BIN="$(find target/release/deps -maxdepth 1 -type f -perm -111 -name "${BENCH}-*" ! -name '*.d' -printf '%T@ %p\n' 2>/dev/null | sort -nr | head -1 | cut -d' ' -f2-)" +if [ -z "$BIN" ]; then + echo "building bench binary..." + cargo bench -p lance --bench "$BENCH" --no-run + BIN="$(find target/release/deps -maxdepth 1 -type f -perm -111 -name "${BENCH}-*" ! -name '*.d' -printf '%T@ %p\n' 2>/dev/null | sort -nr | head -1 | cut -d' ' -f2-)" +fi +echo "bench binary: $BIN" +echo "run id: $RUN_ID" +echo "backends: $BACKENDS" +echo "" + +backend_prefix() { + case "$1" in + nvme) echo "$NVME_PREFIX/$RUN_ID" ;; + s3) echo "$S3_PREFIX/$RUN_ID" ;; + *) echo "ERROR unknown backend $1" >&2; exit 1 ;; + esac +} + +run_phase() { + local name="$1"; shift + local log="$LOCAL_DIR/${name}.log" + echo ">>> $name" + timeout "$CONFIG_TIMEOUT" "$BIN" --bench "$@" > "$log" 2>&1 + local rc=$? + if [ "$rc" -eq 124 ]; then + echo " !!! TIMED OUT after ${CONFIG_TIMEOUT}s (see $log)" + return 1 + elif [ "$rc" -ne 0 ]; then + echo " !!! 
failed rc=$rc (see $log)" + return 1 + fi + echo " ok" + return 0 +} + +for backend in $BACKENDS; do + prefix="$(backend_prefix "$backend")" + for base_rows in $BASE_ROWS_LIST; do + case "$base_rows" in + 1000000) btag=1M ;; + 100000) btag=100k ;; + *) btag="${base_rows}" ;; + esac + uri="$prefix/base_${btag}" + + # prepare once per (backend, base_rows) + run_phase "prepare_${backend}_${btag}" \ + --phase prepare --uri "$uri" \ + --base-rows "$base_rows" --batch-rows 1000 \ + --cache-dir "$CACHE_DIR" || continue + + # search for each k + for k in $K_LIST; do + name="search_${backend}_${btag}_k${k}" + out="$LOCAL_DIR/${name}.json" + if [ -f "$out" ]; then + echo ">>> $name (already done, skipping)" + continue + fi + run_phase "$name" \ + --phase search --uri "$uri" \ + --base-rows "$base_rows" \ + --max-memtable-rows "$MAX_MEMTABLE_ROWS" \ + --flushed-generations "$FLUSHED_GENERATIONS" \ + --batch-rows 1000 \ + --queries "$QUERIES" --k "$k" \ + --rescore-factor "$RESCORE_FACTOR" \ + --cache-dir "$CACHE_DIR" \ + --output "$out" + # mirror result to s3 for durability regardless of backend + [ -f "$out" ] && aws s3 cp "$out" "$S3_PREFIX/$RUN_ID/results/${name}.json" >/dev/null 2>&1 + done + done +done + +echo "" +echo "=== summary ===" +python3 - "$LOCAL_DIR" <<'PY' +import glob, json, os, sys +d = sys.argv[1] +hdr = f"{'config':32s} {'local_p50':>10} {'rescore_p50':>12} {'local_qps':>10} {'rescore_qps':>12} {'jaccard':>8}" +print(hdr) +for p in sorted(glob.glob(os.path.join(d, "*.json"))): + try: + r = json.load(open(p)) + except Exception as e: + print(f" bad {p}: {e}"); continue + name = os.path.basename(p)[:-5] + lo = r.get("local", {}); re_ = r.get("rescore", {}) + print(f"{name:32s} {lo.get('p50_us','-'):>10} {re_.get('p50_us','-'):>12} " + f"{lo.get('qps','-'):>10} {re_.get('qps','-'):>12} {r.get('jaccard_local_vs_rescore',0):>8.4f}") +PY +echo "" +echo "results: $LOCAL_DIR + $S3_PREFIX/$RUN_ID/results/" From e6409447322bd1a6f9273a52d69abb2f2a57b99e Mon Sep 17 
00:00:00 2001 From: Heng Ge Date: Thu, 21 May 2026 08:56:31 -0700 Subject: [PATCH 05/10] fix(bench): make mem_wal_fts_read flush generations deterministically The memtable flush trigger is byte/batch-count based, not row-count, so the FineWeb text payload (variable row size) never reliably flushed at `max_memtable_rows`. Set `max_memtable_batches` to one generation's worth of batches so the batch store fills exactly at each generation boundary, and drain pending flushes via `wait_for_flush_drain()` before snapshotting the manifest so all flushed generations are visible to the planner. --- .../mem_wal/fts/mem_wal_fts_read_bench.rs | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/rust/lance/benches/mem_wal/fts/mem_wal_fts_read_bench.rs b/rust/lance/benches/mem_wal/fts/mem_wal_fts_read_bench.rs index 1e4fc6459b5..f6ea0d9e7dd 100644 --- a/rust/lance/benches/mem_wal/fts/mem_wal_fts_read_bench.rs +++ b/rust/lance/benches/mem_wal/fts/mem_wal_fts_read_bench.rs @@ -578,14 +578,24 @@ async fn run_search(args: &Args) -> Result { let shard_id = Uuid::new_v4(); let row_bytes = 2048; // rough FineWeb text row size + // The memtable flush trigger is `estimated_size >= max_memtable_size || + // batch_store_full`. FineWeb text rows vary in size, so a byte threshold + // is an unreliable way to flush exactly one generation per + // `max_memtable_rows`. Instead make the *batch-count* cap the trigger: + // set `max_memtable_batches` to one generation's worth of batches so the + // store fills (and flushes) precisely at each generation boundary, + // independent of text length. Keep `max_memtable_size` high so it never + // pre-empts the batch-count trigger. 
+ let batches_per_gen = (args.max_memtable_rows / args.batch_rows).max(1); let config = ShardWriterConfig { shard_id, shard_spec_id: 0, durable_write: false, sync_indexed_write: false, - max_memtable_size: args.max_memtable_rows * row_bytes * 2, + max_memtable_size: args.max_memtable_rows * row_bytes * 100, max_memtable_rows: args.max_memtable_rows, - max_unflushed_memtable_bytes: args.max_memtable_rows * row_bytes * 6, + max_memtable_batches: batches_per_gen, + max_unflushed_memtable_bytes: args.max_memtable_rows * row_bytes * 20, max_wal_flush_interval: Some(Duration::from_secs(60)), ..ShardWriterConfig::default() }; @@ -628,6 +638,10 @@ async fn run_search(args: &Args) -> Result { tokio::time::sleep(flush_wait).await; } } + // Wait for any triggered (sealed) memtable flushes to commit to the + // manifest before we snapshot it — otherwise the flushed generations + // race the read and may not all be visible yet. + writer.wait_for_flush_drain().await?; println!( "ingested {} memtable rows in {:.1}s", cursor, From b157b81773f22947124d37d135c3038cf2cca305 Mon Sep 17 00:00:00 2001 From: Heng Ge Date: Thu, 21 May 2026 09:03:50 -0700 Subject: [PATCH 06/10] fix(mem_wal): resolve flushed-gen FTS index via load_scalar_index criteria The rescore path's open_inverted_index used a manual field-id scan of load_indices() that failed to match the maintained FTS index on flushed generation datasets, so rescore errored with "missing an FTS index" even though Local mode (which uses scanner.full_text_search) resolved it fine. Switch to the same load_scalar_index(for_column().supports_fts()) criteria lookup the base-table FTS exec path uses. 
--- .../src/dataset/mem_wal/scanner/fts_search.rs | 44 +++++++++---------- 1 file changed, 20 insertions(+), 24 deletions(-) diff --git a/rust/lance/src/dataset/mem_wal/scanner/fts_search.rs b/rust/lance/src/dataset/mem_wal/scanner/fts_search.rs index 4b2ad3fe5fd..ee8b575cdf1 100644 --- a/rust/lance/src/dataset/mem_wal/scanner/fts_search.rs +++ b/rust/lance/src/dataset/mem_wal/scanner/fts_search.rs @@ -835,37 +835,33 @@ fn extract_match_text(query: &FullTextSearchQuery) -> Result { /// Open the column's inverted index from a Lance dataset, or `None` /// if no FTS index exists for the column. +/// +/// Uses the same criteria-based lookup as the base-table FTS exec path +/// (`load_scalar_index(... .for_column().supports_fts())`) rather than a +/// manual field-id scan, so flushed-generation datasets resolve their +/// maintained FTS index identically to how `scanner.full_text_search` +/// resolves it. async fn open_inverted_index( dataset: &Dataset, column: &str, ) -> Result>> { use crate::index::{DatasetIndexExt, DatasetIndexInternalExt}; + use lance_index::IndexCriteria; - // Resolve the column's field id so we can match against `IndexMetadata.fields`. - let field_id = match dataset.schema().field(column) { - Some(f) => f.id, - None => return Ok(None), + let Some(meta) = dataset + .load_scalar_index(IndexCriteria::default().for_column(column).supports_fts()) + .await? + else { + return Ok(None); }; - - let indices = dataset.load_indices().await?; - for meta in indices.iter() { - if !meta.fields.contains(&field_id) { - continue; - } - // Mixed index types on the same field would be unusual but - // possible; the canonical filter is the downcast below. 
- let uuid = meta.uuid.to_string(); - let Ok(opened) = dataset - .open_generic_index(column, &uuid, &lance_index::metrics::NoOpMetricsCollector) - .await - else { - continue; - }; - if let Some(inv) = opened.as_any().downcast_ref::() { - return Ok(Some(Arc::new(inv.clone()))); - } - } - Ok(None) + let uuid = meta.uuid.to_string(); + let opened = dataset + .open_generic_index(column, &uuid, &lance_index::metrics::NoOpMetricsCollector) + .await?; + Ok(opened + .as_any() + .downcast_ref::() + .map(|inv| Arc::new(inv.clone()))) } /// Materialize user-projected columns from the active memtable's From 9ce5054f03a3c679750cd82b90af05c2214a503b Mon Sep 17 00:00:00 2001 From: Heng Ge Date: Thu, 21 May 2026 09:19:59 -0700 Subject: [PATCH 07/10] feat(mem_wal): flat candidate path for index-less flushed gens in rescore MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Flushed memtable generations are written without an on-disk FTS index (the maintained index lives only in the active/frozen memtable), so the rescore path can't assume every Lance source has an InvertedIndex. Add a flat fallback: when open_inverted_index returns None, scan + tokenize the source's text column once to compute corpus stats (total_tokens, num_docs, df) and per-doc query-term frequencies, then score candidates with local stats for top-K' selection — mirroring the flat fallback Local mode gets from scanner.full_text_search. resolve_tokenizer is replaced by resolve_params so the same FTS params drive both query tokenization and the flat scans. Regression test covers indexed base + index-less flushed gen + indexed active in one rescore query. 
--- .../src/dataset/mem_wal/scanner/fts_search.rs | 395 ++++++++++++++++-- 1 file changed, 358 insertions(+), 37 deletions(-) diff --git a/rust/lance/src/dataset/mem_wal/scanner/fts_search.rs b/rust/lance/src/dataset/mem_wal/scanner/fts_search.rs index ee8b575cdf1..5f54809b3ea 100644 --- a/rust/lance/src/dataset/mem_wal/scanner/fts_search.rs +++ b/rust/lance/src/dataset/mem_wal/scanner/fts_search.rs @@ -28,7 +28,7 @@ use std::collections::HashMap; use std::sync::Arc; -use arrow_array::{Array, Float32Array, RecordBatch, UInt32Array, UInt64Array}; +use arrow_array::{Array, Float32Array, RecordBatch, StringArray, UInt32Array, UInt64Array}; use arrow_schema::{DataType, Field, Schema, SchemaRef, SortOptions}; use arrow_select::concat::concat_batches; use arrow_select::take::take; @@ -47,6 +47,7 @@ use lance_index::scalar::inverted::document_tokenizer::DocType; use lance_index::scalar::inverted::query::{ FtsQuery as IndexFtsQuery, FtsSearchParams, Operator, Tokens, collect_query_tokens, }; +use lance_index::scalar::inverted::tokenizer::InvertedIndexParams; use lance_index::scalar::inverted::{InvertedIndex, InvertedIndexCandidate, MemBM25Scorer, Scorer}; use tracing::instrument; @@ -205,9 +206,10 @@ impl LsmFtsSearchPlanner { return self.empty_plan(&target_schema); } - // Step 1: pull a tokenizer + tokenize the query text. + // Step 1: resolve shared FTS params + tokenize the query text. 
let match_text = extract_match_text(&query)?; - let mut tokenizer = self.resolve_tokenizer(&sources, column).await?; + let params = self.resolve_params(&sources, column).await?; + let mut tokenizer = params.build()?; let tokens_obj = collect_query_tokens(&match_text, &mut tokenizer); let token_strs: Vec = (0..tokens_obj.len()) .map(|i| tokens_obj.get_token(i).to_owned()) @@ -225,7 +227,9 @@ impl LsmFtsSearchPlanner { let mut df_map: HashMap = token_strs.iter().map(|t| (t.clone(), 0usize)).collect(); for source in &sources { - let handle = self.resolve_handle(source, column).await?; + let handle = self + .resolve_handle(source, column, &params, &token_strs) + .await?; let (tt, nd, df_vec) = handle.stats_for_terms(&token_strs)?; total_tokens += tt; num_docs += nd; @@ -283,28 +287,30 @@ impl LsmFtsSearchPlanner { Ok(exec) } - /// Acquire a tokenizer compatible with every source's FTS index. + /// Resolve the `InvertedIndexParams` shared by the LSM sources. /// /// We assume FTS-indexed sources in an LSM hierarchy share their - /// `InvertedIndexParams` (otherwise their indexes wouldn't be - /// merge-compatible). Pulls the tokenizer from the first source - /// that has one; any later mismatch is the caller's bug. - async fn resolve_tokenizer( + /// params (otherwise their indexes wouldn't be merge-compatible). + /// Pulls from the first source that carries an index (active + /// `FtsMemIndex` or a Lance dataset with an on-disk inverted index); + /// these params are then used both to tokenize the query and to + /// build tokenizers for flat-scanning index-less flushed + /// generations. + async fn resolve_params( &self, sources: &[LsmDataSource], column: &str, - ) -> Result> - { + ) -> Result { for source in sources { match source { LsmDataSource::ActiveMemTable { index_store, ..
} => { if let Some(idx) = index_store.get_fts_by_column(column) { - return idx.params().build(); + return Ok(idx.params().clone()); } } LsmDataSource::BaseTable { dataset } => { if let Some(idx) = open_inverted_index(dataset, column).await? { - return Ok(idx.tokenizer()); + return Ok(idx.params().clone()); } } LsmDataSource::FlushedMemTable { path, .. } => { @@ -312,7 +318,7 @@ impl LsmFtsSearchPlanner { .load() .await?; if let Some(idx) = open_inverted_index(&dataset, column).await? { - return Ok(idx.tokenizer()); + return Ok(idx.params().clone()); } } } @@ -323,7 +329,22 @@ impl LsmFtsSearchPlanner { ))) } - async fn resolve_handle(&self, source: &LsmDataSource, column: &str) -> Result { + /// Resolve a source to a `SourceHandle`. + /// + /// Lance sources (base + flushed generations) may or may not carry + /// an on-disk inverted index. Flushed memtable generations are + /// written without one (the maintained FTS index lives only in the + /// active/frozen memtable), so for those we fall back to a flat + /// scan-and-tokenize candidate path keyed off `params` + the query + /// `tokens` — mirroring the flat fallback `scanner.full_text_search` + /// uses for Local mode. + async fn resolve_handle( + &self, + source: &LsmDataSource, + column: &str, + params: &InvertedIndexParams, + tokens: &[String], + ) -> Result { match source { LsmDataSource::ActiveMemTable { batch_store, @@ -344,33 +365,38 @@ impl LsmFtsSearchPlanner { }) } LsmDataSource::BaseTable { dataset } => { - let index = open_inverted_index(dataset, column).await?.ok_or_else(|| { - Error::invalid_input(format!( - "Base table is missing an FTS index on column '{column}'" - )) - })?; - Ok(SourceHandle::Lance { - dataset: dataset.clone(), - index, - }) + Self::lance_handle(dataset.clone(), column, params, tokens).await } LsmDataSource::FlushedMemTable { path, .. 
} => { - let dataset = crate::dataset::DatasetBuilder::from_uri(path) - .load() - .await?; - let index = open_inverted_index(&dataset, column).await?.ok_or_else(|| { - Error::invalid_input(format!( - "Flushed memtable at {path} is missing an FTS index on column '{column}'" - )) - })?; - Ok(SourceHandle::Lance { - dataset: Arc::new(dataset), - index, - }) + let dataset = Arc::new( + crate::dataset::DatasetBuilder::from_uri(path) + .load() + .await?, + ); + Self::lance_handle(dataset, column, params, tokens).await } } } + /// Build a Lance `SourceHandle`: indexed when an on-disk inverted + /// index exists, otherwise a flat (scan + tokenize) handle. + async fn lance_handle( + dataset: Arc, + column: &str, + params: &InvertedIndexParams, + tokens: &[String], + ) -> Result { + if let Some(index) = open_inverted_index(&dataset, column).await? { + Ok(SourceHandle::Lance { dataset, index }) + } else { + let flat = FlatData::compute(&dataset, column, params, tokens).await?; + Ok(SourceHandle::LanceFlat { + dataset, + flat: Arc::new(flat), + }) + } + } + /// Materialize the rescored top-k into a single RecordBatch with /// the canonical FTS schema. Groups by source so we can issue one /// take per Lance source, then rebuilds the original row order. @@ -696,6 +722,13 @@ enum SourceHandle { dataset: Arc, index: Arc, }, + /// A Lance dataset with no on-disk FTS index (e.g., a flushed + /// memtable generation). Candidates + stats come from a one-shot + /// flat scan-and-tokenize computed at handle resolution. + LanceFlat { + dataset: Arc, + flat: Arc, + }, } impl SourceHandle { @@ -712,6 +745,10 @@ impl SourceHandle { Ok(idx.bm25_stats_for_terms(terms)) } Self::Lance { index, .. } => Ok(index.bm25_stats_for_terms(terms)), + Self::LanceFlat { flat, .. 
} => { + debug_assert_eq!(flat.df.len(), terms.len()); + Ok((flat.total_tokens, flat.num_docs, flat.df.clone())) + } } } @@ -755,6 +792,10 @@ impl SourceHandle { .map(UnifiedCandidate::from_inverted_candidate) .collect()) } + Self::LanceFlat { flat, .. } => { + let k_prime = params.limit.unwrap_or(usize::MAX); + Ok(flat.top_candidates(token_strs, k_prime)) + } } } @@ -765,7 +806,9 @@ impl SourceHandle { schema, .. } => active_materialize(batch_store, schema, row_ids, cols), - Self::Lance { dataset, .. } => { + // Both Lance variants materialize the same way: `row_id` is a + // Lance `_rowid`, so `take_rows` fetches the user columns. + Self::Lance { dataset, .. } | Self::LanceFlat { dataset, .. } => { // Project the dataset's Lance schema down to the requested // columns by name. Unknown names are dropped (`take_rows` // would otherwise error on schema construction). @@ -781,6 +824,169 @@ impl SourceHandle { } } +/// Flat (index-less) FTS state for one Lance source, computed by a +/// single scan-and-tokenize pass over the dataset's text column. +/// +/// Holds corpus-wide stats (`total_tokens`, `num_docs`, per-term `df`) +/// plus the candidate docs that contain at least one query term, each +/// carrying its `_rowid`, `doc_len`, and per-query-term frequencies. +struct FlatData { + /// Per query term (input order): number of docs containing it. + df: Vec, + total_tokens: u64, + num_docs: usize, + /// Candidate docs (those matching >=1 term): `_rowid`, `doc_len`, + /// and `term_freqs` in query-token order. + cand_row_ids: Vec, + cand_doc_lens: Vec, + cand_tfs: Vec>, +} + +impl FlatData { + /// Scan `dataset`'s `column`, tokenizing each doc with `params`, to + /// build corpus stats + per-doc query-term frequencies. 
+ async fn compute( + dataset: &Dataset, + column: &str, + params: &InvertedIndexParams, + tokens: &[String], + ) -> Result { + use futures::TryStreamExt; + + let term_to_idx: HashMap<&str, usize> = tokens + .iter() + .enumerate() + .map(|(i, t)| (t.as_str(), i)) + .collect(); + + let mut scanner = dataset.scan(); + scanner.project(&[column])?; + scanner.with_row_id(); + let mut stream = scanner.try_into_stream().await?; + + let mut tokenizer = params.build()?; + let mut df = vec![0usize; tokens.len()]; + let mut total_tokens: u64 = 0; + let mut num_docs: usize = 0; + let mut cand_row_ids: Vec = Vec::new(); + let mut cand_doc_lens: Vec = Vec::new(); + let mut cand_tfs: Vec> = Vec::new(); + + while let Some(batch) = stream.try_next().await? { + let rowid_col = batch + .column_by_name(lance_core::ROW_ID) + .ok_or_else(|| Error::internal("flat scan missing _rowid".to_string()))? + .as_any() + .downcast_ref::() + .ok_or_else(|| Error::internal("_rowid not UInt64".to_string()))?; + let text_col = batch + .column_by_name(column) + .ok_or_else(|| Error::internal(format!("flat scan missing '{column}'")))?; + let texts = string_values(text_col.as_ref())?; + + for (i, text) in texts.iter().enumerate() { + num_docs += 1; + let mut tfs = vec![0u32; tokens.len()]; + let mut doc_len: u32 = 0; + if let Some(text) = text { + let mut stream = tokenizer.token_stream_for_doc(text); + while let Some(tok) = stream.next() { + doc_len += 1; + if let Some(&ti) = term_to_idx.get(tok.text.as_str()) { + tfs[ti] += 1; + } + } + } + total_tokens += doc_len as u64; + if tfs.iter().any(|&f| f > 0) { + for (ti, &f) in tfs.iter().enumerate() { + if f > 0 { + df[ti] += 1; + } + } + cand_row_ids.push(rowid_col.value(i)); + cand_doc_lens.push(doc_len); + cand_tfs.push(tfs); + } + } + } + + Ok(Self { + df, + total_tokens, + num_docs, + cand_row_ids, + cand_doc_lens, + cand_tfs, + }) + } + + /// Local-stats top-`k_prime` candidates. 
Scores each candidate with a + /// scorer built from this source's own stats (matching the indexed + /// path's local-pruning semantics) and keeps the best `k_prime`. + fn top_candidates(&self, tokens: &[String], k_prime: usize) -> Vec { + if k_prime == 0 || self.cand_row_ids.is_empty() { + return Vec::new(); + } + let token_docs: HashMap = tokens + .iter() + .cloned() + .zip(self.df.iter().copied()) + .collect(); + let scorer = MemBM25Scorer::new(self.total_tokens, self.num_docs, token_docs); + + let mut scored: Vec<(usize, f32)> = (0..self.cand_row_ids.len()) + .map(|i| { + let s = bm25_score(&scorer, tokens, &self.cand_tfs[i], self.cand_doc_lens[i]); + (i, s) + }) + .collect(); + if scored.len() > k_prime { + scored.select_nth_unstable_by(k_prime, |a, b| { + b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal) + }); + scored.truncate(k_prime); + } + scored + .into_iter() + .map(|(i, _)| UnifiedCandidate { + row_id: self.cand_row_ids[i], + doc_len: self.cand_doc_lens[i], + term_freqs: self.cand_tfs[i].clone(), + }) + .collect() + } +} + +/// Extract optional UTF-8 strings from a Utf8 / LargeUtf8 / Utf8View array. +fn string_values(array: &dyn Array) -> Result>> { + use arrow_array::{LargeStringArray, StringViewArray}; + use arrow_schema::DataType; + match array.data_type() { + DataType::Utf8 => { + let a = array.as_any().downcast_ref::().unwrap(); + Ok((0..a.len()) + .map(|i| (!a.is_null(i)).then(|| a.value(i))) + .collect()) + } + DataType::LargeUtf8 => { + let a = array.as_any().downcast_ref::().unwrap(); + Ok((0..a.len()) + .map(|i| (!a.is_null(i)).then(|| a.value(i))) + .collect()) + } + DataType::Utf8View => { + let a = array.as_any().downcast_ref::().unwrap(); + Ok((0..a.len()) + .map(|i| (!a.is_null(i)).then(|| a.value(i))) + .collect()) + } + other => Err(Error::invalid_input(format!( + "flat FTS scan: column must be Utf8/LargeUtf8/Utf8View, got {other:?}" + ))), + } +} + /// Common shape for one candidate, regardless of where it came from. 
struct UnifiedCandidate { row_id: u64, @@ -1126,6 +1332,121 @@ mod tests { } } + #[tokio::test] + async fn rescore_mode_handles_indexless_flushed_generation() { + // Regression: flushed memtable generations are written WITHOUT an + // on-disk FTS index (only the active/frozen memtable carries the + // maintained index). The rescore path must flat-scan such sources + // for candidates + stats rather than erroring. Layout: indexed + // base + an index-less flushed gen + indexed active, all sharing + // the "lance" term. + use crate::index::DatasetIndexExt; + use lance_index::IndexType; + use lance_index::scalar::inverted::tokenizer::InvertedIndexParams; + use uuid::Uuid; + + let schema = fts_schema(); + let tmp = tempfile::tempdir().unwrap(); + + // Indexed base. + let base_uri = format!("{}/base", tmp.path().to_str().unwrap()); + let mut base_ds = write_dataset( + &base_uri, + vec![make_batch(&schema, &[1, 2], &["lance base", "noise"])], + ) + .await; + base_ds + .create_index( + &["text"], + IndexType::Inverted, + Some("text_fts".to_string()), + &InvertedIndexParams::default(), + false, + ) + .await + .unwrap(); + let base_ds = Arc::new(Dataset::open(&base_uri).await.unwrap()); + + // Index-less flushed generation at the collector's resolved path: + // {base_uri}/_mem_wal/{shard}/gen_1. NO create_index call. + let shard_id = Uuid::new_v4(); + let gen1_uri = format!("{base_uri}/_mem_wal/{shard_id}/gen_1"); + write_dataset( + &gen1_uri, + vec![make_batch( + &schema, + &[3, 4], + &["lance flushed gen", "unrelated"], + )], + ) + .await; + + // Indexed active memtable. 
+ let batch_store = Arc::new(BatchStore::with_capacity(16)); + let mut indexes = IndexStore::new(); + indexes.add_fts("text_fts".to_string(), 1, "text".to_string()); + let active_batch = make_batch(&schema, &[5, 6], &["lance active", "nothing"]); + batch_store.append(active_batch.clone()).unwrap(); + indexes + .insert_with_batch_position(&active_batch, 0, Some(0)) + .unwrap(); + let indexes = Arc::new(indexes); + + let shard_snapshot = crate::dataset::mem_wal::scanner::ShardSnapshot::new(shard_id) + .with_current_generation(2) + .with_flushed_generation(1, "gen_1".to_string()); + + let collector = LsmDataSourceCollector::new(base_ds, vec![shard_snapshot]) + .with_in_memory_memtables( + shard_id, + InMemoryMemTables { + active: InMemoryMemTableRef { + batch_store, + index_store: indexes, + schema: schema.clone(), + generation: 2, + }, + frozen: vec![], + }, + ); + + let planner = LsmFtsSearchPlanner::new(collector, vec!["id".to_string()], schema); + let plan = planner + .plan_search( + "text", + FullTextSearchQuery::new("lance".to_string()), + 10, + None, + FtsScoringMode::local_with_global_rescore_default(), + ) + .await + .expect("rescore must handle an index-less flushed generation via flat scan"); + + let ctx = datafusion::prelude::SessionContext::new(); + let stream = plan.execute(0, ctx.task_ctx()).unwrap(); + let batches: Vec = stream.try_collect().await.unwrap(); + + let mut ids: Vec = Vec::new(); + for b in &batches { + let col = b + .column_by_name("id") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + for i in 0..b.num_rows() { + ids.push(col.value(i)); + } + } + // One "lance" hit from each tier: base id=1, flushed id=3, active id=5. 
+ assert!(ids.contains(&1), "missing base hit; got {ids:?}"); + assert!( + ids.contains(&3), + "missing index-less flushed-gen hit (flat path broken); got {ids:?}" + ); + assert!(ids.contains(&5), "missing active hit; got {ids:?}"); + } + #[tokio::test] async fn rescore_mode_active_only_runs_end_to_end() { // Cheaper regression that doesn't need a base Lance dataset: From e0e9461f28fda5c8ac338e51a18377025164b575 Mon Sep 17 00:00:00 2001 From: Heng Ge Date: Thu, 21 May 2026 09:30:57 -0700 Subject: [PATCH 08/10] bench(mem_wal): index flushed generations in fts read bench MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Flushed memtable generations are written without an FTS index, so both scoring modes would flat-scan them per query — an O(rows*queries) artifact that swamps the scoring-mode signal the bench measures. Build an inverted index on each flushed generation after flush (modeling the realistic post-flush multi-segment FTS state) so Local and Rescore both use the fast indexed path. The index-less flat path remains covered by unit tests. --- .../mem_wal/fts/mem_wal_fts_read_bench.rs | 31 +++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/rust/lance/benches/mem_wal/fts/mem_wal_fts_read_bench.rs b/rust/lance/benches/mem_wal/fts/mem_wal_fts_read_bench.rs index f6ea0d9e7dd..6b353f09d24 100644 --- a/rust/lance/benches/mem_wal/fts/mem_wal_fts_read_bench.rs +++ b/rust/lance/benches/mem_wal/fts/mem_wal_fts_read_bench.rs @@ -663,6 +663,37 @@ async fn run_search(args: &Args) -> Result { .unwrap_or(0); println!("manifest: {num_flushed} flushed generations"); + // Build an FTS index on each flushed generation. Flushed memtables are + // written without one (the maintained index lives only in the active + // memtable), so without this both scoring modes would flat-scan the + // generations per query — an O(rows·queries) artifact that swamps the + // scoring-mode signal this bench measures. 
Indexing the generations + // models the realistic post-flush multi-segment FTS state (discussion + // #6789) and lets both Local and Rescore use the fast indexed path. + // (The index-less flat path itself is covered by unit tests.) + if let Some(ref m) = manifest { + let idx_start = Instant::now(); + for fg in &m.flushed_generations { + let gen_uri = format!("{}/_mem_wal/{}/{}", args.uri, shard_id, fg.path); + let mut gen_ds = Dataset::open(&gen_uri).await?; + gen_ds + .create_index( + &[TEXT_COL], + IndexType::Inverted, + Some(FTS_INDEX_NAME.to_string()), + &InvertedIndexParams::default(), + true, + ) + .await?; + } + if num_flushed > 0 { + println!( + "indexed {num_flushed} flushed generations in {:.1}s", + idx_start.elapsed().as_secs_f64() + ); + } + } + let collector = LsmDataSourceCollector::new(dataset.clone(), vec![shard_snapshot]) .with_in_memory_memtables(shard_id, in_memory_refs); let pk_columns = vec!["id".to_string()]; From 209f1d30ec3d9ce9c857ce19ffacbb8f69827679 Mon Sep 17 00:00:00 2001 From: Heng Ge Date: Thu, 21 May 2026 17:21:36 -0700 Subject: [PATCH 09/10] reconcile(bench): drop manual flushed-gen indexing (built on flush in #6901) lance #6901 makes the memtable flush handler build the shard's maintained secondary indexes on each flushed generation, so the FTS index now exists on every flushed gen without the bench creating it. Remove the manual create_index loop; both scoring modes still use the fast indexed path, and the rescore planner's flat fallback remains for the no-maintained-index case (covered by unit tests). 
--- .../mem_wal/fts/mem_wal_fts_read_bench.rs | 37 ++++--------------- 1 file changed, 7 insertions(+), 30 deletions(-) diff --git a/rust/lance/benches/mem_wal/fts/mem_wal_fts_read_bench.rs b/rust/lance/benches/mem_wal/fts/mem_wal_fts_read_bench.rs index 6b353f09d24..3a0f3e9e26d 100644 --- a/rust/lance/benches/mem_wal/fts/mem_wal_fts_read_bench.rs +++ b/rust/lance/benches/mem_wal/fts/mem_wal_fts_read_bench.rs @@ -663,36 +663,13 @@ async fn run_search(args: &Args) -> Result { .unwrap_or(0); println!("manifest: {num_flushed} flushed generations"); - // Build an FTS index on each flushed generation. Flushed memtables are - // written without one (the maintained index lives only in the active - // memtable), so without this both scoring modes would flat-scan the - // generations per query — an O(rows·queries) artifact that swamps the - // scoring-mode signal this bench measures. Indexing the generations - // models the realistic post-flush multi-segment FTS state (discussion - // #6789) and lets both Local and Rescore use the fast indexed path. - // (The index-less flat path itself is covered by unit tests.) - if let Some(ref m) = manifest { - let idx_start = Instant::now(); - for fg in &m.flushed_generations { - let gen_uri = format!("{}/_mem_wal/{}/{}", args.uri, shard_id, fg.path); - let mut gen_ds = Dataset::open(&gen_uri).await?; - gen_ds - .create_index( - &[TEXT_COL], - IndexType::Inverted, - Some(FTS_INDEX_NAME.to_string()), - &InvertedIndexParams::default(), - true, - ) - .await?; - } - if num_flushed > 0 { - println!( - "indexed {num_flushed} flushed generations in {:.1}s", - idx_start.elapsed().as_secs_f64() - ); - } - } + // Flushed generations carry the same maintained secondary indexes as + // the active memtable: the flush handler builds them during flush + // (lance #6901), so each generation already has the FTS index and + // both scoring modes use the fast indexed path. No manual indexing + // step is needed here. 
(The index-less flat fallback in the rescore + // planner is still exercised by unit tests for the no-maintained-index + // case.) let collector = LsmDataSourceCollector::new(dataset.clone(), vec![shard_snapshot]) .with_in_memory_memtables(shard_id, in_memory_refs); From 7f7c766864cc497b122853be486493275a18728a Mon Sep 17 00:00:00 2001 From: Heng Ge Date: Thu, 21 May 2026 18:13:25 -0700 Subject: [PATCH 10/10] reconcile(bench): remove lsm_fts_modes in favor of mem_wal_fts_read_bench The #6882 refactor replaced ad-hoc benches with standalone CLI+JSON benchmarks driven through the real ShardWriter ingestion path. mem_wal_fts_read_bench follows that template (and is what the EC2 sweep ran); lsm_fts_modes was an off-template synthetic-shape bench that built datasets manually. Drop it and its Cargo.toml entry to keep one template-aligned FTS read benchmark. --- rust/lance/Cargo.toml | 5 - .../benches/mem_wal/fts/lsm_fts_modes.rs | 939 ------------------ 2 files changed, 944 deletions(-) delete mode 100644 rust/lance/benches/mem_wal/fts/lsm_fts_modes.rs diff --git a/rust/lance/Cargo.toml b/rust/lance/Cargo.toml index 2db049944aa..330df20ce40 100644 --- a/rust/lance/Cargo.toml +++ b/rust/lance/Cargo.toml @@ -258,11 +258,6 @@ name = "mem_wal_fts_bench" path = "benches/mem_wal/fts/mem_wal_fts_bench.rs" harness = false -[[bench]] -name = "lsm_fts_modes" -path = "benches/mem_wal/fts/lsm_fts_modes.rs" -harness = false - [[bench]] name = "mem_wal_fts_read_bench" path = "benches/mem_wal/fts/mem_wal_fts_read_bench.rs" diff --git a/rust/lance/benches/mem_wal/fts/lsm_fts_modes.rs b/rust/lance/benches/mem_wal/fts/lsm_fts_modes.rs deleted file mode 100644 index 5a688264a8a..00000000000 --- a/rust/lance/benches/mem_wal/fts/lsm_fts_modes.rs +++ /dev/null @@ -1,939 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright The Lance Authors - -//! Benchmark comparing `LsmScanner::full_text_search` scoring modes on -//! 
the LSM hierarchy with multiple flushed generations on a real -//! FineWeb corpus. -//! -//! Sibling of `mem_wal_fineweb_fts.rs`. Shares its FineWeb loader -//! shape (HF `sample/10BT` parquet shards, `--cache-dir` to amortize -//! downloads). -//! -//! Per shape × scoring mode the bench reports: -//! -//! * Wall-clock per query and aggregate latency percentiles. -//! * Top-K Jaccard vs a single-merged-index ground truth (the same -//! FineWeb rows loaded into a single Lance dataset with one FTS -//! index, queried via `scanner.full_text_search`). -//! * Pearson correlation of `_score` between LSM mode and ground -//! truth on the intersection. -//! -//! Example: -//! -//! ```bash -//! cargo bench -p lance --bench lsm_fts_modes -- \ -//! --shape memwal_skewed --k 100 --num-queries 100 \ -//! --rescore-factor 10 \ -//! --cache-dir /tmp/fineweb-cache --output result.json -//! ``` - -#![recursion_limit = "256"] -#![allow(clippy::print_stdout, clippy::print_stderr)] - -use std::collections::{HashMap, HashSet}; -use std::path::PathBuf; -use std::sync::Arc; -use std::time::{Duration, Instant}; - -use arrow_array::{Array, Int64Array, RecordBatch, RecordBatchIterator, StringArray}; -use arrow_schema::{DataType, Field, Schema as ArrowSchema}; -use futures::TryStreamExt; -use lance::dataset::mem_wal::scanner::{ - FtsScoringMode, InMemoryMemTableRef, InMemoryMemTables, LsmScanner, -}; -use lance::dataset::mem_wal::write::{BatchStore, IndexStore}; -use lance::dataset::{Dataset, WriteParams}; -use lance::index::DatasetIndexExt; -use lance_core::Result; -use lance_index::IndexType; -use lance_index::scalar::FullTextSearchQuery; -use lance_index::scalar::inverted::tokenizer::InvertedIndexParams; -use lance_tokenizer::TokenStream; -use parquet::arrow::async_reader::ParquetRecordBatchStreamBuilder; -use serde_json::json; -use uuid::Uuid; - -const TEXT_COL: &str = "text"; -const FTS_INDEX_NAME: &str = "text_fts"; -const HF_API_LISTING: &str = - 
"https://huggingface.co/api/datasets/HuggingFaceFW/fineweb/tree/main/sample/10BT"; -const HF_FILE_BASE: &str = "https://huggingface.co/datasets/HuggingFaceFW/fineweb/resolve/main/"; - -// ---------------------------------------------------------------------- -// Shape -// ---------------------------------------------------------------------- - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -enum Shape { - /// 4 equal-sized flushed gens + 1 equal-sized active. Cross-source - /// stats are similar so Local should already be close to Rescore. - Balanced, - /// 1 huge base + 4 tiny flushed gens + 1 tiny active. The case - /// where local-stats BM25 is most distorted vs a merged index. - MemwalSkewed, - /// Heterogeneous flushed sizes (1k+5k+25k+100k) + 25k active. - GrowingLsm, -} - -impl Shape { - fn parse(value: &str) -> std::result::Result { - match value { - "balanced" => Ok(Self::Balanced), - "memwal_skewed" => Ok(Self::MemwalSkewed), - "growing_lsm" => Ok(Self::GrowingLsm), - other => Err(format!( - "unknown shape '{other}', expected balanced|memwal_skewed|growing_lsm" - )), - } - } - - fn as_str(self) -> &'static str { - match self { - Self::Balanced => "balanced", - Self::MemwalSkewed => "memwal_skewed", - Self::GrowingLsm => "growing_lsm", - } - } - - /// (base_rows, vec_of_flushed_gen_rows, active_rows). `base_rows = None` - /// means no base table (fresh-tier-only). Designed so total - /// `corpus_rows` is roughly comparable across shapes for fair Jaccard - /// vs a single-merged baseline. 
- fn slicing(self) -> (Option, Vec, usize) { - match self { - Self::Balanced => (Some(100_000), vec![25_000; 4], 25_000), - Self::MemwalSkewed => (Some(1_000_000), vec![5_000; 4], 5_000), - Self::GrowingLsm => (Some(100_000), vec![1_000, 5_000, 25_000, 100_000], 25_000), - } - } - - fn total_rows(self) -> usize { - let (base, gens, active) = self.slicing(); - base.unwrap_or(0) + gens.iter().sum::() + active - } -} - -// ---------------------------------------------------------------------- -// Args -// ---------------------------------------------------------------------- - -#[derive(Debug, Clone)] -struct Args { - shape: Shape, - k: usize, - num_queries: usize, - rescore_factor: u32, - cache_dir: PathBuf, - work_dir: Option, - output: Option, - skip_baseline: bool, - tokio_threads: usize, - /// If `Some`, cap rows used per shape — useful for smoke testing - /// without downloading hundreds of MB. - max_corpus_rows: Option, -} - -impl Default for Args { - fn default() -> Self { - let threads = std::thread::available_parallelism().map_or(1, usize::from); - Self { - shape: Shape::Balanced, - k: 100, - num_queries: 100, - rescore_factor: 10, - cache_dir: std::env::temp_dir().join("mem_wal_fineweb_fts_cache"), - work_dir: None, - output: None, - skip_baseline: false, - tokio_threads: threads, - max_corpus_rows: None, - } - } -} - -fn parse(flag: &str, value: &str) -> Result -where - T: std::str::FromStr, - T::Err: std::fmt::Display, -{ - value - .parse::() - .map_err(|e| lance_core::Error::io(format!("flag {flag}: {e}"))) -} - -fn parse_args() -> Result { - let mut args = Args::default(); - let raw: Vec = std::env::args().skip(1).collect(); - let mut iter = raw.iter(); - while let Some(flag) = iter.next() { - match flag.as_str() { - "--shape" => { - args.shape = Shape::parse( - iter.next() - .ok_or_else(|| lance_core::Error::io("--shape needs value"))?, - ) - .map_err(lance_core::Error::io)? 
- } - "--k" => { - args.k = parse( - "--k", - iter.next() - .ok_or_else(|| lance_core::Error::io("--k needs value"))?, - )? - } - "--num-queries" => { - args.num_queries = parse( - "--num-queries", - iter.next() - .ok_or_else(|| lance_core::Error::io("--num-queries needs value"))?, - )? - } - "--rescore-factor" => { - args.rescore_factor = parse( - "--rescore-factor", - iter.next() - .ok_or_else(|| lance_core::Error::io("--rescore-factor needs value"))?, - )? - } - "--cache-dir" => { - args.cache_dir = PathBuf::from( - iter.next() - .ok_or_else(|| lance_core::Error::io("--cache-dir needs value"))?, - ) - } - "--work-dir" => { - args.work_dir = - Some(PathBuf::from(iter.next().ok_or_else(|| { - lance_core::Error::io("--work-dir needs value") - })?)) - } - "--output" => { - args.output = - Some(PathBuf::from(iter.next().ok_or_else(|| { - lance_core::Error::io("--output needs value") - })?)) - } - "--skip-baseline" => args.skip_baseline = true, - "--tokio-threads" => { - args.tokio_threads = parse( - "--tokio-threads", - iter.next() - .ok_or_else(|| lance_core::Error::io("--tokio-threads needs value"))?, - )? - } - "--max-corpus-rows" => { - args.max_corpus_rows = Some(parse( - "--max-corpus-rows", - iter.next() - .ok_or_else(|| lance_core::Error::io("--max-corpus-rows needs value"))?, - )?) - } - // criterion-style noise we want to ignore so `cargo bench` - // can hand us nothing extra without erroring. 
- "--bench" | "--test" => {} - other => { - eprintln!("unknown flag: {other}"); - return Err(lance_core::Error::io(format!("unknown flag {other}"))); - } - } - } - Ok(args) -} - -// ---------------------------------------------------------------------- -// FineWeb loading (mirrors mem_wal_fineweb_fts.rs) -// ---------------------------------------------------------------------- - -#[derive(serde::Deserialize)] -struct HfTreeEntry { - #[serde(rename = "type")] - kind: String, - path: String, -} - -async fn list_shard_paths() -> Result> { - let entries: Vec = reqwest::get(HF_API_LISTING) - .await - .map_err(|e| lance_core::Error::io(format!("listing HTTP: {e}")))? - .json() - .await - .map_err(|e| lance_core::Error::io(format!("listing JSON: {e}")))?; - let mut shards: Vec = entries - .into_iter() - .filter(|e| e.kind == "file" && e.path.ends_with(".parquet")) - .map(|e| e.path) - .collect(); - shards.sort(); - Ok(shards) -} - -async fn download_shard(rel_path: &str, dest: &std::path::Path) -> Result<()> { - if dest.exists() { - return Ok(()); - } - let url = format!("{HF_FILE_BASE}{rel_path}"); - let tmp = dest.with_extension("part"); - for attempt in 1..=5u32 { - println!("downloading {rel_path} (attempt {attempt}/5) ..."); - let result: Result = async { - let resp = reqwest::get(&url) - .await - .map_err(|e| lance_core::Error::io(format!("download HTTP: {e}")))?; - if !resp.status().is_success() { - return Err(lance_core::Error::io(format!( - "download {url} -> status {}", - resp.status() - ))); - } - resp.bytes() - .await - .map_err(|e| lance_core::Error::io(format!("read body: {e}"))) - } - .await; - match result { - Ok(bytes) => { - std::fs::write(&tmp, &bytes) - .map_err(|e| lance_core::Error::io(format!("write: {e}")))?; - std::fs::rename(&tmp, dest) - .map_err(|e| lance_core::Error::io(format!("rename: {e}")))?; - println!( - " wrote {:.1} MB to {}", - bytes.len() as f64 / 1024.0 / 1024.0, - dest.display() - ); - return Ok(()); - } - Err(e) if attempt < 5 => 
{ - eprintln!(" attempt {attempt} failed: {e}; retrying"); - tokio::time::sleep(Duration::from_secs(2u64.pow(attempt))).await; - } - Err(e) => return Err(e), - } - } - unreachable!() -} - -async fn read_shard_text( - path: &std::path::Path, - out: &mut Vec, - max_rows: usize, -) -> Result { - let file = tokio::fs::File::open(path) - .await - .map_err(|e| lance_core::Error::io(format!("open parquet: {e}")))?; - let builder = ParquetRecordBatchStreamBuilder::new(file) - .await - .map_err(|e| lance_core::Error::io(format!("parquet builder: {e}")))?; - let mut stream = builder - .build() - .map_err(|e| lance_core::Error::io(format!("parquet stream: {e}")))?; - let mut taken = 0usize; - while taken < max_rows { - let Some(rb) = stream - .try_next() - .await - .map_err(|e| lance_core::Error::io(format!("parquet read: {e}")))? - else { - break; - }; - let col = rb - .column_by_name(TEXT_COL) - .ok_or_else(|| lance_core::Error::io("text column missing".to_string()))?; - let strs = col - .as_any() - .downcast_ref::() - .ok_or_else(|| lance_core::Error::io("text not StringArray".to_string()))?; - for i in 0..strs.len() { - if taken >= max_rows { - break; - } - if strs.is_null(i) { - continue; - } - out.push(strs.value(i).to_string()); - taken += 1; - } - } - Ok(taken) -} - -async fn load_corpus(needed_rows: usize, cache_dir: &std::path::Path) -> Result> { - std::fs::create_dir_all(cache_dir) - .map_err(|e| lance_core::Error::io(format!("mkdir cache: {e}")))?; - let shards = list_shard_paths().await?; - println!("fineweb sample/10BT: {} shards", shards.len()); - let mut buf: Vec = Vec::with_capacity(needed_rows); - for rel in &shards { - if buf.len() >= needed_rows { - break; - } - let name = rel.rsplit('/').next().unwrap_or(rel); - let local = cache_dir.join(name); - download_shard(rel, &local).await?; - let want = needed_rows - buf.len(); - let got = read_shard_text(&local, &mut buf, want).await?; - println!(" shard {name} -> {got} rows (cumulative {})", buf.len()); - } - 
if buf.len() < needed_rows { - return Err(lance_core::Error::io(format!( - "fineweb yielded only {} rows, need {needed_rows}", - buf.len() - ))); - } - Ok(buf) -} - -// ---------------------------------------------------------------------- -// Dataset shaping -// ---------------------------------------------------------------------- - -fn make_schema() -> Arc { - let mut id_meta = HashMap::new(); - id_meta.insert( - "lance-schema:unenforced-primary-key".to_string(), - "true".to_string(), - ); - let id_field = Field::new("id", DataType::Int64, false).with_metadata(id_meta); - Arc::new(ArrowSchema::new(vec![ - id_field, - Field::new(TEXT_COL, DataType::Utf8, true), - ])) -} - -fn slice_to_batch(schema: Arc, start_id: i64, texts: &[String]) -> RecordBatch { - let ids = Int64Array::from_iter_values(start_id..start_id + texts.len() as i64); - let text = StringArray::from_iter_values(texts.iter().map(String::as_str)); - RecordBatch::try_new(schema, vec![Arc::new(ids), Arc::new(text)]).unwrap() -} - -async fn write_lance(uri: &str, batches: Vec) -> Result { - let schema = batches[0].schema(); - let reader = RecordBatchIterator::new(batches.into_iter().map(Ok), schema); - Dataset::write(reader, uri, Some(WriteParams::default())).await -} - -async fn create_fts_index(ds: &mut Dataset) -> Result<()> { - ds.create_index( - &[TEXT_COL], - IndexType::Inverted, - Some(FTS_INDEX_NAME.to_string()), - &InvertedIndexParams::default(), - false, - ) - .await?; - Ok(()) -} - -/// Build the LSM shape: writes one Lance dataset per flushed gen (with -/// FTS index), and constructs an active in-memory memtable from the -/// active slice. Returns the collector inputs ready for `LsmScanner`. 
-async fn build_lsm_shape( - shape: Shape, - corpus: &[String], - work_dir: &std::path::Path, -) -> Result<( - Option>, - Vec, - Uuid, - InMemoryMemTables, -)> { - let schema = make_schema(); - let (base_rows, gen_rows, active_rows) = shape.slicing(); - let total = base_rows.unwrap_or(0) + gen_rows.iter().sum::() + active_rows; - assert!( - corpus.len() >= total, - "shape needs {total} rows, corpus has {}", - corpus.len() - ); - - let mut cursor: usize = 0; - let mut id_cursor: i64 = 0; - let shard_id = Uuid::new_v4(); - - // Base. - let base = if let Some(n) = base_rows { - let uri = format!("{}/base", work_dir.display()); - let mut ds = write_lance( - &uri, - vec![slice_to_batch( - schema.clone(), - id_cursor, - &corpus[cursor..cursor + n], - )], - ) - .await?; - create_fts_index(&mut ds).await?; - let ds = Arc::new(Dataset::open(&uri).await?); - cursor += n; - id_cursor += n as i64; - Some(ds) - } else { - None - }; - - // Flushed generations. - let mut shard_snapshot = lance::dataset::mem_wal::scanner::ShardSnapshot::new(shard_id) - .with_current_generation((gen_rows.len() as u64).max(1)); - let base_uri = base - .as_ref() - .map(|d| d.uri().to_string()) - .unwrap_or_else(|| format!("{}/base", work_dir.display())); - for (i, &n) in gen_rows.iter().enumerate() { - let gen_num = (i + 1) as u64; - let rel = format!("gen_{gen_num}"); - let uri = format!("{base_uri}/_mem_wal/{shard_id}/{rel}"); - let mut ds = write_lance( - &uri, - vec![slice_to_batch( - schema.clone(), - id_cursor, - &corpus[cursor..cursor + n], - )], - ) - .await?; - create_fts_index(&mut ds).await?; - cursor += n; - id_cursor += n as i64; - shard_snapshot = shard_snapshot.with_flushed_generation(gen_num, rel); - } - - // Active memtable. 
- let batch_store = Arc::new(BatchStore::with_capacity(active_rows.max(16))); - let mut indexes = IndexStore::new(); - indexes.add_fts(FTS_INDEX_NAME.to_string(), 1, TEXT_COL.to_string()); - let active_batch = slice_to_batch( - schema.clone(), - id_cursor, - &corpus[cursor..cursor + active_rows], - ); - batch_store.append(active_batch.clone()).unwrap(); - indexes - .insert_with_batch_position(&active_batch, 0, Some(0)) - .unwrap(); - let indexes = Arc::new(indexes); - - let in_memory = InMemoryMemTables { - active: InMemoryMemTableRef { - batch_store, - index_store: indexes, - schema, - generation: (gen_rows.len() as u64) + 1, - }, - frozen: vec![], - }; - - Ok((base, vec![shard_snapshot], shard_id, in_memory)) -} - -/// Build a single Lance dataset containing the full corpus + one FTS -/// index. Used as the ground-truth reference for Jaccard / score Pearson. -async fn build_baseline(corpus: &[String], work_dir: &std::path::Path) -> Result> { - let schema = make_schema(); - let uri = format!("{}/baseline_merged", work_dir.display()); - let batch = slice_to_batch(schema, 0, corpus); - let mut ds = write_lance(&uri, vec![batch]).await?; - create_fts_index(&mut ds).await?; - Ok(Arc::new(Dataset::open(&uri).await?)) -} - -// ---------------------------------------------------------------------- -// Query selection -// ---------------------------------------------------------------------- - -/// Pick `n` representative single-term queries from the corpus. -/// -/// Tokenizes a sample of the corpus with the default English tokenizer, -/// counts term frequencies, and returns terms in the "long tail" — not -/// the absolute most frequent (those match nearly every doc and don't -/// produce interesting BM25 rankings) and not the rarest (those match -/// nothing and produce empty top-K). Window roughly between the 80th -/// and 99th percentile of df. 
-fn pick_queries(corpus: &[String], n: usize) -> Vec { - let sample_n = corpus.len().min(2_000); - let mut tokenizer = InvertedIndexParams::default().build().expect("tokenizer"); - let mut df: HashMap = HashMap::new(); - for text in corpus.iter().take(sample_n) { - let mut stream = tokenizer.token_stream_for_doc(text); - let mut seen: HashSet = HashSet::new(); - while let Some(tok) = stream.next() { - if seen.insert(tok.text.clone()) { - *df.entry(tok.text.clone()).or_insert(0) += 1; - } - } - } - let mut all: Vec<(String, usize)> = df.into_iter().collect(); - all.sort_by_key(|(_, c)| *c); - // Pull from the 80th–99th percentile window. - let lo = (all.len() as f64 * 0.80) as usize; - let hi = (all.len() as f64 * 0.99) as usize; - let window: &[(String, usize)] = &all[lo.min(all.len())..hi.min(all.len())]; - if window.is_empty() { - return Vec::new(); - } - let stride = (window.len() / n.max(1)).max(1); - let mut out: Vec = Vec::with_capacity(n); - for (i, (term, _)) in window.iter().enumerate() { - if out.len() >= n { - break; - } - if i % stride == 0 { - out.push(term.clone()); - } - } - out -} - -// ---------------------------------------------------------------------- -// Mode runner -// ---------------------------------------------------------------------- - -#[derive(Debug)] -struct ModeRun { - /// Per-query top-k row id sets. - top_ids: Vec>, - /// Per-query (id → score) maps, for Pearson on the intersection. 
- scored: Vec>, - latencies_us: Vec, -} - -async fn run_mode( - scanner: &LsmScanner, - mode: FtsScoringMode, - queries: &[String], - k: usize, -) -> Result { - let mut top_ids = Vec::with_capacity(queries.len()); - let mut scored = Vec::with_capacity(queries.len()); - let mut latencies_us = Vec::with_capacity(queries.len()); - for q in queries { - let t = Instant::now(); - let plan = scanner - .full_text_search(TEXT_COL, FullTextSearchQuery::new(q.clone()), k, mode) - .await?; - let ctx = datafusion::prelude::SessionContext::new(); - let stream = plan - .execute(0, ctx.task_ctx()) - .map_err(|e| lance_core::Error::io(format!("plan execute for query '{q}': {e}")))?; - let batches: Vec = stream - .try_collect() - .await - .map_err(|e| lance_core::Error::io(format!("collect for query '{q}': {e}")))?; - latencies_us.push(t.elapsed().as_micros() as u64); - - let mut ids: HashSet = HashSet::new(); - let mut score_map: HashMap = HashMap::new(); - for b in &batches { - let id_col = b - .column_by_name("id") - .expect("id col") - .as_any() - .downcast_ref::() - .expect("id Int64"); - let score_col = b - .column_by_name("_score") - .expect("_score col") - .as_any() - .downcast_ref::() - .expect("_score Float32"); - for i in 0..b.num_rows() { - let id = id_col.value(i); - ids.insert(id); - score_map.insert(id, score_col.value(i)); - } - } - top_ids.push(ids); - scored.push(score_map); - } - Ok(ModeRun { - top_ids, - scored, - latencies_us, - }) -} - -/// Run the merged-index baseline (single Lance dataset). We reuse the -/// same `scanner.full_text_search` API on the merged dataset so the -/// score scale is comparable to the LSM result (`scanner` already uses -/// `build_global_bm25_scorer` for its internal multi-partition case). 
-async fn run_baseline(baseline: &Dataset, queries: &[String], k: usize) -> Result { - let mut top_ids = Vec::with_capacity(queries.len()); - let mut scored = Vec::with_capacity(queries.len()); - let mut latencies_us = Vec::with_capacity(queries.len()); - for q in queries { - let t = Instant::now(); - let mut scanner = baseline.scan(); - scanner.project(&["id", TEXT_COL])?; - scanner.full_text_search( - FullTextSearchQuery::new(q.clone()) - .with_column(TEXT_COL.to_string())? - .limit(Some(k as i64)), - )?; - let batches: Vec = scanner.try_into_stream().await?.try_collect().await?; - latencies_us.push(t.elapsed().as_micros() as u64); - - let mut ids: HashSet = HashSet::new(); - let mut score_map: HashMap = HashMap::new(); - for b in &batches { - let id_col = b - .column_by_name("id") - .expect("id col") - .as_any() - .downcast_ref::() - .expect("id Int64"); - let score_col = b - .column_by_name("_score") - .expect("_score col") - .as_any() - .downcast_ref::() - .expect("_score Float32"); - for i in 0..b.num_rows() { - let id = id_col.value(i); - ids.insert(id); - score_map.insert(id, score_col.value(i)); - } - } - top_ids.push(ids); - scored.push(score_map); - } - Ok(ModeRun { - top_ids, - scored, - latencies_us, - }) -} - -// ---------------------------------------------------------------------- -// Metrics -// ---------------------------------------------------------------------- - -fn percentile(values: &[f64], pct: f64) -> f64 { - if values.is_empty() { - return 0.0; - } - let mut sorted = values.to_vec(); - sorted.sort_by(|a, b| a.partial_cmp(b).unwrap()); - let rank = (pct / 100.0) * (sorted.len() - 1) as f64; - let lo = rank.floor() as usize; - let hi = rank.ceil() as usize; - if lo == hi { - sorted[lo] - } else { - let frac = rank - lo as f64; - sorted[lo] * (1.0 - frac) + sorted[hi] * frac - } -} - -fn mean_jaccard(a: &[HashSet], b: &[HashSet]) -> f64 { - let pairs: Vec = a - .iter() - .zip(b.iter()) - .filter_map(|(x, y)| { - if x.is_empty() && 
y.is_empty() { - None - } else { - let inter = x.intersection(y).count() as f64; - let union = x.union(y).count() as f64; - Some(inter / union) - } - }) - .collect(); - if pairs.is_empty() { - 0.0 - } else { - pairs.iter().sum::() / pairs.len() as f64 - } -} - -/// Pearson correlation of scores on the intersection. Averaged over -/// queries that have at least 2 overlapping ids (Pearson is undefined -/// for fewer). -fn mean_pearson(a: &[HashMap], b: &[HashMap]) -> f64 { - let pairs: Vec = a - .iter() - .zip(b.iter()) - .filter_map(|(x, y)| { - let common: Vec = x.keys().filter(|k| y.contains_key(k)).copied().collect(); - if common.len() < 2 { - return None; - } - let xs: Vec = common.iter().map(|i| x[i] as f64).collect(); - let ys: Vec = common.iter().map(|i| y[i] as f64).collect(); - let mx = xs.iter().sum::() / xs.len() as f64; - let my = ys.iter().sum::() / ys.len() as f64; - let num: f64 = xs - .iter() - .zip(ys.iter()) - .map(|(a, b)| (a - mx) * (b - my)) - .sum(); - let dx: f64 = xs.iter().map(|a| (a - mx).powi(2)).sum::().sqrt(); - let dy: f64 = ys.iter().map(|b| (b - my).powi(2)).sum::().sqrt(); - if dx == 0.0 || dy == 0.0 { - None - } else { - Some(num / (dx * dy)) - } - }) - .collect(); - if pairs.is_empty() { - 0.0 - } else { - pairs.iter().sum::() / pairs.len() as f64 - } -} - -// ---------------------------------------------------------------------- -// Run -// ---------------------------------------------------------------------- - -async fn run(args: Args) -> Result<()> { - let needed = args - .max_corpus_rows - .unwrap_or_else(|| args.shape.total_rows()); - println!( - "shape={} needed_rows={} k={} num_queries={} rescore_factor={}", - args.shape.as_str(), - needed, - args.k, - args.num_queries, - args.rescore_factor - ); - - let corpus = load_corpus(needed, &args.cache_dir).await?; - let queries = pick_queries(&corpus, args.num_queries); - println!("picked {} query terms", queries.len()); - - let work_dir = if let Some(d) = &args.work_dir { - 
std::fs::create_dir_all(d).map_err(|e| lance_core::Error::io(format!("mkdir: {e}")))?; - d.clone() - } else { - tempfile::tempdir() - .map_err(|e| lance_core::Error::io(format!("tempdir: {e}")))? - .keep() - }; - - let (base, shard_snapshots, shard_id, in_memory) = - build_lsm_shape(args.shape, &corpus, &work_dir).await?; - - let pk_columns = vec!["id".to_string()]; - let scanner = if let Some(b) = base.clone() { - LsmScanner::new(b, shard_snapshots.clone(), pk_columns.clone()) - } else { - let schema = make_schema(); - LsmScanner::without_base_table( - schema, - format!("{}/base", work_dir.display()), - shard_snapshots.clone(), - pk_columns.clone(), - ) - } - .with_in_memory_memtables(shard_id, in_memory); - - println!("running Local mode ..."); - let local = run_mode(&scanner, FtsScoringMode::Local, &queries, args.k).await?; - println!("running LocalWithGlobalRescore mode ..."); - let rescore = run_mode( - &scanner, - FtsScoringMode::LocalWithGlobalRescore { - rescore_factor: args.rescore_factor, - }, - &queries, - args.k, - ) - .await?; - - let baseline_run = if args.skip_baseline { - None - } else { - println!("building merged-index baseline ..."); - let baseline = build_baseline(&corpus, &work_dir).await?; - Some(run_baseline(&baseline, &queries, args.k).await?) - }; - - // Aggregate metrics. 
- let lat = |run: &ModeRun| -> (f64, f64, f64, f64) { - let mut v: Vec = run - .latencies_us - .iter() - .map(|x| *x as f64 / 1000.0) - .collect(); - v.sort_by(|a, b| a.partial_cmp(b).unwrap()); - let mean = v.iter().sum::() / v.len() as f64; - ( - mean, - percentile(&v, 50.0), - percentile(&v, 95.0), - percentile(&v, 99.0), - ) - }; - let (local_mean, local_p50, local_p95, local_p99) = lat(&local); - let (rescore_mean, rescore_p50, rescore_p95, rescore_p99) = lat(&rescore); - - let jaccard_local_rescore = mean_jaccard(&local.top_ids, &rescore.top_ids); - let pearson_local_rescore = mean_pearson(&local.scored, &rescore.scored); - let (jaccard_local_baseline, pearson_local_baseline) = if let Some(b) = &baseline_run { - ( - mean_jaccard(&local.top_ids, &b.top_ids), - mean_pearson(&local.scored, &b.scored), - ) - } else { - (f64::NAN, f64::NAN) - }; - let (jaccard_rescore_baseline, pearson_rescore_baseline) = if let Some(b) = &baseline_run { - ( - mean_jaccard(&rescore.top_ids, &b.top_ids), - mean_pearson(&rescore.scored, &b.scored), - ) - } else { - (f64::NAN, f64::NAN) - }; - - let summary = json!({ - "shape": args.shape.as_str(), - "k": args.k, - "num_queries": queries.len(), - "rescore_factor": args.rescore_factor, - "local": { - "mean_ms": local_mean, - "p50_ms": local_p50, - "p95_ms": local_p95, - "p99_ms": local_p99, - }, - "rescore": { - "mean_ms": rescore_mean, - "p50_ms": rescore_p50, - "p95_ms": rescore_p95, - "p99_ms": rescore_p99, - }, - "jaccard": { - "local_vs_rescore": jaccard_local_rescore, - "local_vs_baseline": jaccard_local_baseline, - "rescore_vs_baseline": jaccard_rescore_baseline, - }, - "pearson_score": { - "local_vs_rescore": pearson_local_rescore, - "local_vs_baseline": pearson_local_baseline, - "rescore_vs_baseline": pearson_rescore_baseline, - }, - }); - - println!( - "\n=== Result ===\n{}", - serde_json::to_string_pretty(&summary).unwrap() - ); - if let Some(path) = &args.output { - std::fs::write(path, 
serde_json::to_string_pretty(&summary).unwrap()) - .map_err(|e| lance_core::Error::io(format!("write output: {e}")))?; - println!("\nwrote {}", path.display()); - } - Ok(()) -} - -fn main() -> Result<()> { - let args = parse_args()?; - let rt = tokio::runtime::Builder::new_multi_thread() - .enable_all() - .worker_threads(args.tokio_threads) - .build() - .map_err(|e| lance_core::Error::io(format!("tokio: {e}")))?; - rt.block_on(run(args)) -}