Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
352 changes: 352 additions & 0 deletions rust/lance-index/src/scalar/inverted/index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,46 @@ impl PartitionCandidates {
}
}

/// A single hit from [`InvertedIndex::bm25_candidate_search`] decorated with
/// the BM25 sufficient statistics needed to recompute the score against a
/// different (e.g. globally-aggregated) corpus.
///
/// `term_freqs[i]` is the frequency of the caller-supplied query token at
/// input position `i`. A `0` means the term did not appear in the doc.
///
/// Equality is field-wise. `local_score` is compared as a plain `f32`, so a
/// candidate whose score is NaN is never equal to anything (which is why the
/// type implements `PartialEq` but not `Eq`).
#[derive(Debug, Clone, PartialEq)]
pub struct InvertedIndexCandidate {
    /// Row id of the matching document.
    pub row_id: u64,
    /// Document length reported for this row; it is the length input that was
    /// fed to the scorer alongside each term frequency.
    pub doc_length: u32,
    /// Per-query-token frequencies, indexed by the caller's input token order.
    pub term_freqs: Vec<u32>,
    /// Score this segment assigned under the scorer that drove pruning.
    pub local_score: f32,
}

/// Bounded top-K' heap entry for [`InvertedIndex::bm25_candidate_search`].
///
/// Equality and ordering look at `score` only; `inner` is payload that rides
/// along. Wrapped in `Reverse`, the entry turns a `BinaryHeap` into a
/// min-heap whose root is the lowest-scoring candidate, i.e. the one to evict
/// when the heap overflows.
struct ScoredCandidate {
    score: crate::vector::graph::OrderedFloat,
    inner: InvertedIndexCandidate,
}

impl PartialEq for ScoredCandidate {
    /// Two entries are equal when their scores are, regardless of payload.
    fn eq(&self, rhs: &Self) -> bool {
        PartialEq::eq(&self.score, &rhs.score)
    }
}

impl Eq for ScoredCandidate {}

impl PartialOrd for ScoredCandidate {
    /// Always `Some`: defers to the total order provided by [`Ord`].
    fn partial_cmp(&self, rhs: &Self) -> Option<std::cmp::Ordering> {
        Some(Ord::cmp(self, rhs))
    }
}

impl Ord for ScoredCandidate {
    /// Orders entries by score alone.
    fn cmp(&self, rhs: &Self) -> std::cmp::Ordering {
        Ord::cmp(&self.score, &rhs.score)
    }
}

#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash, Default)]
pub enum TokenSetFormat {
Arrow,
Expand Down Expand Up @@ -677,6 +717,171 @@ impl InvertedIndex {
.unzip())
}

/// Per-candidate variant of [`Self::bm25_search`] that returns the raw BM25
/// sufficient statistics each hit was scored from, instead of just
/// `(row_id, score)`.
///
/// Callers building a multi-segment query plan (`Local` mode + `Global`
/// rescore) need to combine per-segment top-K' results into a single
/// globally-comparable ranking. With only `(row_id, score)` they cannot
/// recompute scores against globally-aggregated stats; with
/// `(row_id, doc_length, term_freqs)` they can.
///
/// Pruning and ranking inside this segment still use `base_scorer` (or
/// the local IndexBM25Scorer when `None`) — same selection policy as
/// [`Self::bm25_search`], so the returned set is the same as that method
/// would return for the same `params.limit`, just decorated with sufficient
/// stats. `term_freqs` is sized to `tokens.len()` and indexed by the
/// caller's input token order; a `0` means the term did not appear in
/// the doc (e.g., because the query is OR and only some terms matched).
///
/// The result is ordered best `local_score` first and holds at most
/// `params.limit` entries; a limit of `0` returns an empty vector without
/// touching any partition, and `None` leaves the merged result unbounded.
///
/// # Errors
///
/// Propagates any error raised while loading posting lists or while running
/// the per-partition search.
#[instrument(level = "debug", skip_all)]
pub async fn bm25_candidate_search(
    &self,
    tokens: Arc<Tokens>,
    params: Arc<FtsSearchParams>,
    operator: Operator,
    prefilter: Arc<dyn PreFilter>,
    metrics: Arc<dyn MetricsCollector>,
    base_scorer: Option<&MemBM25Scorer>,
) -> Result<Vec<InvertedIndexCandidate>> {
    // `local_scorer` is declared up front so the borrow taken in the `else`
    // arm outlives the `if let`.
    let local_scorer;
    let scorer: &dyn Scorer = if let Some(base_scorer) = base_scorer {
        base_scorer
    } else {
        local_scorer = IndexBM25Scorer::new(self.partitions.iter().map(|part| part.as_ref()));
        &local_scorer
    };

    let limit = params.limit.unwrap_or(usize::MAX);
    if limit == 0 {
        return Ok(Vec::new());
    }
    let mask = prefilter.mask();
    let num_query_terms = tokens.len();

    // Min-heap of size `limit` keyed by local_score so we can prune the
    // smallest score in O(log limit) when a new candidate arrives.
    let mut heap: BinaryHeap<Reverse<ScoredCandidate>> = BinaryHeap::new();

    let parts = self
        .partitions
        .iter()
        .map(|part| {
            // Each per-partition future owns its own handles; they are moved
            // (not cloned again) into the CPU task below.
            let part = part.clone();
            let tokens = tokens.clone();
            let params = params.clone();
            let mask = mask.clone();
            let metrics = metrics.clone();
            async move {
                let postings = part
                    .load_posting_lists(tokens.as_ref(), params.as_ref(), metrics.as_ref())
                    .await?;
                if postings.is_empty() {
                    return Result::Ok(PartitionCandidates::empty());
                }
                // Dense `term_index -> token` table for this partition;
                // positions without a posting keep an empty string.
                let max_position = postings
                    .iter()
                    .map(|posting| posting.term_index() as usize)
                    .max()
                    .unwrap_or_default();
                let mut tokens_by_position = vec![String::new(); max_position + 1];
                for posting in &postings {
                    let idx = posting.term_index() as usize;
                    tokens_by_position[idx] = posting.token().to_owned();
                }
                spawn_cpu(move || {
                    let candidates = part.bm25_search(
                        params.as_ref(),
                        operator,
                        mask,
                        postings,
                        metrics.as_ref(),
                    )?;
                    Ok(PartitionCandidates {
                        tokens_by_position,
                        candidates,
                    })
                })
                .await
            }
        })
        .collect::<Vec<_>>();
    let mut parts = stream::iter(parts).buffer_unordered(get_num_compute_intensive_cpus());
    // Query weights are looked up once per distinct token across partitions.
    let mut idf_cache: HashMap<String, f32> = HashMap::new();
    while let Some(res) = parts.try_next().await? {
        if res.candidates.is_empty() {
            continue;
        }
        // Map this partition's `term_index` → the caller's query-token
        // position. `tokens_by_position` is partition-local; reorder so
        // returned `term_freqs[input_idx]` matches the caller's input.
        let term_index_to_input_idx: Vec<Option<usize>> = res
            .tokens_by_position
            .iter()
            .map(|tok| tokens.token_index(tok))
            .collect();
        let mut idf_by_position = Vec::with_capacity(res.tokens_by_position.len());
        for token in &res.tokens_by_position {
            let idf_weight = match idf_cache.get(token) {
                Some(weight) => *weight,
                None => {
                    let weight = scorer.query_weight(token);
                    idf_cache.insert(token.clone(), weight);
                    weight
                }
            };
            idf_by_position.push(idf_weight);
        }
        for DocCandidate {
            row_id,
            freqs,
            doc_length,
        } in res.candidates
        {
            let mut term_freqs = vec![0u32; num_query_terms];
            let mut score = 0.0;
            for (term_index, freq) in freqs.into_iter() {
                let pos = term_index as usize;
                debug_assert!(pos < idf_by_position.len());
                score += idf_by_position[pos] * scorer.doc_weight(freq, doc_length);
                if let Some(input_idx) = term_index_to_input_idx[pos] {
                    term_freqs[input_idx] = freq;
                }
            }
            if heap.len() >= limit {
                // Heap is full: only a strictly better score gets in, and it
                // displaces the current minimum.
                let beats_weakest = heap
                    .peek()
                    .is_some_and(|weakest| weakest.0.score.0 < score);
                if !beats_weakest {
                    continue;
                }
                heap.pop();
            }
            heap.push(Reverse(ScoredCandidate {
                score: crate::vector::graph::OrderedFloat(score),
                inner: InvertedIndexCandidate {
                    row_id,
                    doc_length,
                    term_freqs,
                    local_score: score,
                },
            }));
        }
    }
    // `BinaryHeap<Reverse<_>>` is a min-heap on score, so
    // `into_sorted_vec` returns ascending by `Reverse(score)`, which is
    // descending by `score` — exactly the "best first" order callers
    // expect, matching `bm25_search`'s convention.
    Ok(heap
        .into_sorted_vec()
        .into_iter()
        .map(|Reverse(s)| s.inner)
        .collect())
}

async fn load_legacy_index(
store: Arc<dyn IndexStore>,
frag_reuse_index: Option<Arc<FragReuseIndex>>,
Expand Down Expand Up @@ -5455,6 +5660,153 @@ mod tests {
}
}

#[tokio::test]
async fn test_bm25_candidate_search_matches_bm25_search() {
    use std::collections::{HashMap, HashSet};

    // Reuses the 2-partition layout of test_bm25_search_uses_global_idf and
    // runs both search entry points over it. Expectations:
    //   - both return the same set of row ids
    //   - doc_length / term_freqs reflect what was written to the fixture
    //   - local_score agrees with the score bm25_search reports per row
    let tmpdir = TempObjDir::default();
    let store = Arc::new(LanceIndexStore::new(
        ObjectStore::local().into(),
        tmpdir.clone(),
        Arc::new(LanceCache::no_cache()),
    ));

    // Partition 0: tokens "alpha" and "beta", one posting list each.
    let mut part0 = InnerBuilder::new(0, false, TokenSetFormat::default());
    for token in ["alpha", "beta"] {
        part0.tokens.add(token.to_owned());
        part0.posting_lists.push(PostingListBuilder::new(false));
    }
    // Doc 0 carries "alpha" with freq 2 and length 5. It is the only posting
    // whose freq is not 1, which shows term_freqs is read from the posting.
    part0.posting_lists[0].add(0, PositionRecorder::Count(2));
    part0.posting_lists[1].add(1, PositionRecorder::Count(1));
    part0.posting_lists[1].add(2, PositionRecorder::Count(1));
    for (row_id, doc_len) in [(100, 5), (101, 1), (102, 1)] {
        part0.docs.append(row_id, doc_len);
    }
    part0.write(store.as_ref()).await.unwrap();

    // Partition 1: a single "alpha" doc with freq 1 and length 1.
    let mut part1 = InnerBuilder::new(1, false, TokenSetFormat::default());
    part1.tokens.add("alpha".to_owned());
    part1.posting_lists.push(PostingListBuilder::new(false));
    part1.posting_lists[0].add(0, PositionRecorder::Count(1));
    part1.docs.append(200, 1);
    part1.write(store.as_ref()).await.unwrap();

    // Index-level metadata file naming both partitions.
    let partitions_json = serde_json::to_string(&vec![0u64, 1u64]).unwrap();
    let params_json = serde_json::to_string(&InvertedIndexParams::default()).unwrap();
    let metadata = HashMap::from([
        ("partitions".to_owned(), partitions_json),
        ("params".to_owned(), params_json),
        (
            TOKEN_SET_FORMAT_KEY.to_owned(),
            TokenSetFormat::default().to_string(),
        ),
    ]);
    let mut writer = store
        .new_index_file(METADATA_FILE, Arc::new(arrow_schema::Schema::empty()))
        .await
        .unwrap();
    writer.finish_with_metadata(metadata).await.unwrap();

    let cache = Arc::new(LanceCache::with_capacity(4096));
    let index = InvertedIndex::load(store.clone(), None, cache.as_ref())
        .await
        .unwrap();

    // Two query terms; the unknown "missing" must be tolerated and show up
    // as a zero frequency.
    let query_terms = vec!["alpha".to_string(), "missing".to_string()];
    let tokens = Arc::new(Tokens::new(query_terms, DocType::Text));
    let params = Arc::new(FtsSearchParams::new().with_limit(Some(10)));
    let prefilter = Arc::new(NoFilter);
    let metrics = Arc::new(NoOpMetricsCollector);

    let (row_ids, scores) = index
        .bm25_search(
            tokens.clone(),
            params.clone(),
            Operator::Or,
            prefilter.clone(),
            metrics.clone(),
            None,
        )
        .await
        .unwrap();

    let candidates = index
        .bm25_candidate_search(
            tokens.clone(),
            params,
            Operator::Or,
            prefilter,
            metrics,
            None,
        )
        .await
        .unwrap();

    // Both entry points must hit the same rows.
    let expected_rows: HashSet<u64> = row_ids.iter().copied().collect();
    let candidate_rows: HashSet<u64> = candidates.iter().map(|hit| hit.row_id).collect();
    assert_eq!(expected_rows, candidate_rows);

    // Best local_score first.
    for pair in candidates.windows(2) {
        let (earlier, later) = (&pair[0], &pair[1]);
        assert!(
            earlier.local_score >= later.local_score,
            "candidates not sorted: {:?} then {:?}",
            earlier.local_score,
            later.local_score
        );
    }

    // term_freqs follows the input order ["alpha", "missing"]; nothing
    // contains "missing", so slot 1 is always 0.
    for hit in &candidates {
        assert_eq!(hit.term_freqs.len(), 2);
        assert_eq!(hit.term_freqs[1], 0);
    }
    let by_row: HashMap<u64, &InvertedIndexCandidate> =
        candidates.iter().map(|hit| (hit.row_id, hit)).collect();

    // Row 100: "alpha" freq 2, doc length 5.
    let hit_100 = by_row.get(&100).expect("row 100");
    assert_eq!(hit_100.term_freqs[0], 2);
    assert_eq!(hit_100.doc_length, 5);

    // Row 200: "alpha" freq 1, doc length 1.
    let hit_200 = by_row.get(&200).expect("row 200");
    assert_eq!(hit_200.term_freqs[0], 1);
    assert_eq!(hit_200.doc_length, 1);

    // Per-row score parity with bm25_search.
    let bm25_by_row: HashMap<u64, f32> = row_ids
        .iter()
        .zip(scores.iter())
        .map(|(&row_id, &score)| (row_id, score))
        .collect();
    for hit in &candidates {
        let want = bm25_by_row[&hit.row_id];
        assert!(
            (hit.local_score - want).abs() < 1e-6,
            "row {} local_score={} vs bm25_search {}",
            hit.row_id,
            hit.local_score,
            want,
        );
    }
}

#[tokio::test]
async fn test_phrase_query_reads_legacy_per_doc_positions() {
let tmpdir = TempObjDir::default();
Expand Down
5 changes: 5 additions & 0 deletions rust/lance/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,11 @@ name = "mem_wal_fts_bench"
path = "benches/mem_wal/fts/mem_wal_fts_bench.rs"
harness = false

[[bench]]
name = "mem_wal_fts_read_bench"
path = "benches/mem_wal/fts/mem_wal_fts_read_bench.rs"
harness = false

[[bench]]
name = "manifest_commit"
harness = false
Expand Down
Loading
Loading