From 1aee9d68488e0eb61cc6b1f2cacdbe245218ada6 Mon Sep 17 00:00:00 2001 From: Minoru OSUKA Date: Thu, 4 Jun 2026 13:43:42 +0900 Subject: [PATCH] fix(search): use total_cmp for NaN-safe float ordering (#667) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Search, scoring, ranking, and segment-maintenance code ordered floats with partial_cmp(...).unwrap_or(Ordering::Equal) (and a few .unwrap()). That makes the comparator non-total: a NaN compares Equal to everything, which sort_unstable_by / BinaryHeap forbid (silent reorder, or a panic on recent std), and the .unwrap() sites panic on NaN outright (e.g. deletion_ratio is 0/0 = NaN when total_docs == 0, in a background merge). Replace all NaN-unsafe float comparators with f32/f64::total_cmp (a real IEEE-754 total order) across 70 sites in 22 files, preserving each sort's direction and tie-break structure. For non-NaN values total_cmp matches partial_cmp exactly, so behaviour is unchanged except NaN is now ordered deterministically instead of reordering or panicking. The HNSW Candidate (min-heap) / ResultCandidate (max-heap) Ord impls — the "float heaps" the issue names — are covered; new tests assert the heaps handle a NaN distance without panic or loss. 
Closes #667 --- laurus/src/engine.rs | 2 +- .../inverted/maintenance/optimization.rs | 2 +- laurus/src/lexical/index/inverted/searcher.rs | 25 ++----- .../lexical/index/inverted/segment/manager.rs | 5 +- .../index/inverted/segment/merge_engine.rs | 5 +- .../index/inverted/segment/merge_policy.rs | 4 +- laurus/src/lexical/query/advanced_query.rs | 2 +- laurus/src/lexical/query/collector.rs | 39 ++++------ laurus/src/lexical/query/geo.rs | 20 ++---- laurus/src/lexical/query/geo3d.rs | 28 +++----- laurus/src/lexical/search/features/facet.rs | 6 +- .../src/lexical/search/features/highlight.rs | 14 ++-- laurus/src/spelling/suggest.rs | 5 +- laurus/src/vector/index/config.rs | 14 ++-- laurus/src/vector/index/flat/searcher.rs | 8 +-- laurus/src/vector/index/hnsw/searcher.rs | 71 +++++++++++++++---- laurus/src/vector/index/hnsw/writer.rs | 30 +++----- laurus/src/vector/index/ivf/searcher.rs | 7 +- laurus/src/vector/index/segmented_field.rs | 5 +- laurus/src/vector/search/searcher.rs | 22 ++---- laurus/src/vector/store.rs | 21 ++---- laurus/src/vector/store/memory.rs | 3 +- 22 files changed, 137 insertions(+), 201 deletions(-) diff --git a/laurus/src/engine.rs b/laurus/src/engine.rs index 49d577d6..a949b34b 100644 --- a/laurus/src/engine.rs +++ b/laurus/src/engine.rs @@ -1530,7 +1530,7 @@ impl Engine { .collect(); // Sort by fused score descending - intermediate.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)); + intermediate.sort_by(|a, b| b.1.total_cmp(&a.1)); // Limit results if intermediate.len() > limit { diff --git a/laurus/src/lexical/index/inverted/maintenance/optimization.rs b/laurus/src/lexical/index/inverted/maintenance/optimization.rs index 5d6124be..b7531d19 100644 --- a/laurus/src/lexical/index/inverted/maintenance/optimization.rs +++ b/laurus/src/lexical/index/inverted/maintenance/optimization.rs @@ -227,7 +227,7 @@ impl IndexOptimizer { // Get all segments sorted by deletion ratio (highest first) let mut segments = 
segment_manager.get_segments(); - segments.sort_by(|a, b| b.deletion_ratio().partial_cmp(&a.deletion_ratio()).unwrap()); + segments.sort_by(|a, b| b.deletion_ratio().total_cmp(&a.deletion_ratio())); // Merge segments with any deletions first let segments_with_deletions: Vec<_> = segments diff --git a/laurus/src/lexical/index/inverted/searcher.rs b/laurus/src/lexical/index/inverted/searcher.rs index 352334d5..a0481903 100644 --- a/laurus/src/lexical/index/inverted/searcher.rs +++ b/laurus/src/lexical/index/inverted/searcher.rs @@ -814,9 +814,7 @@ impl InvertedIndexSearcher { SortField::Score => { // Default behavior: already sorted by score from collector // Re-sort to ensure descending order - hits.sort_unstable_by(|a, b| { - b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal) - }); + hits.sort_unstable_by(|a, b| b.score.total_cmp(&a.score)); } SortField::Field { name, order } => { // Sort by field value @@ -853,16 +851,13 @@ impl InvertedIndexSearcher { // Same type comparisons (Text(a_str), Text(b_str)) => a_str.cmp(b_str), (Integer(a_int), Integer(b_int)) => a_int.cmp(b_int), - (Float(a_float), Float(b_float)) => { - a_float.partial_cmp(b_float).unwrap_or(Ordering::Equal) - } + (Float(a_float), Float(b_float)) => a_float.total_cmp(b_float), (Boolean(a_bool), Boolean(b_bool)) => a_bool.cmp(b_bool), (DateTime(a_dt), DateTime(b_dt)) => a_dt.cmp(b_dt), (Geo(a), Geo(b)) => a .lat - .partial_cmp(&b.lat) - .unwrap_or(Ordering::Equal) - .then_with(|| a.lon.partial_cmp(&b.lon).unwrap_or(Ordering::Equal)), + .total_cmp(&b.lat) + .then_with(|| a.lon.total_cmp(&b.lon)), (Bytes(_, a_bytes), Bytes(_, b_bytes)) => a_bytes.cmp(b_bytes), (Null, Null) => Ordering::Equal, @@ -1259,21 +1254,13 @@ mod tests { .search(LexicalSearchRequest::new(make_query()).limit(usize::MAX)) .unwrap(); let mut hits: Vec<_> = big.hits.into_iter().map(|h| (h.doc_id, h.score)).collect(); - hits.sort_by(|x, y| { - y.1.partial_cmp(&x.1) - .unwrap_or(Ordering::Equal) - .then(x.0.cmp(&y.0)) - }); + 
hits.sort_by(|x, y| y.1.total_cmp(&x.1).then(x.0.cmp(&y.0))); hits.truncate(10); hits }; let mut bmw_hits: Vec<_> = bmw.hits.iter().map(|h| (h.doc_id, h.score)).collect(); - bmw_hits.sort_by(|x, y| { - y.1.partial_cmp(&x.1) - .unwrap_or(Ordering::Equal) - .then(x.0.cmp(&y.0)) - }); + bmw_hits.sort_by(|x, y| y.1.total_cmp(&x.1).then(x.0.cmp(&y.0))); assert_eq!(bmw_hits.len(), reference.len(), "result count differs"); for (idx, (x, y)) in bmw_hits.iter().zip(reference.iter()).enumerate() { assert_eq!(x.0, y.0, "rank {idx}: doc_id mismatch"); diff --git a/laurus/src/lexical/index/inverted/segment/manager.rs b/laurus/src/lexical/index/inverted/segment/manager.rs index 437b5cec..1ffe1d0e 100644 --- a/laurus/src/lexical/index/inverted/segment/manager.rs +++ b/laurus/src/lexical/index/inverted/segment/manager.rs @@ -696,8 +696,7 @@ impl SegmentManager { .filter(|s| s.deletion_ratio() > self.config.max_deletion_ratio / 2.0) .collect(); - high_deletion_segments - .sort_by(|a, b| b.deletion_ratio().partial_cmp(&a.deletion_ratio()).unwrap()); + high_deletion_segments.sort_by(|a, b| b.deletion_ratio().total_cmp(&a.deletion_ratio())); // Group high-deletion segments for chunk in high_deletion_segments.chunks(self.config.segments_per_tier) { @@ -765,7 +764,7 @@ impl SegmentManager { all_candidates.extend(self.generate_time_based_candidates(segments)); // Sort by priority and remove duplicates - all_candidates.sort_by(|a, b| b.priority.partial_cmp(&a.priority).unwrap()); + all_candidates.sort_by(|a, b| b.priority.total_cmp(&a.priority)); all_candidates.dedup_by(|a, b| a.segments == b.segments); // Take top candidates diff --git a/laurus/src/lexical/index/inverted/segment/merge_engine.rs b/laurus/src/lexical/index/inverted/segment/merge_engine.rs index 72abaf73..4af62e64 100644 --- a/laurus/src/lexical/index/inverted/segment/merge_engine.rs +++ b/laurus/src/lexical/index/inverted/segment/merge_engine.rs @@ -212,8 +212,7 @@ impl MergeEngine { ) -> Result { // Sort by deletion ratio 
(highest first for better compaction) let mut sorted_segments = segments.to_vec(); - sorted_segments - .sort_by(|a, b| b.deletion_ratio().partial_cmp(&a.deletion_ratio()).unwrap()); + sorted_segments.sort_by(|a, b| b.deletion_ratio().total_cmp(&a.deletion_ratio())); self.perform_merge(&sorted_segments, new_segment_id) } @@ -251,7 +250,7 @@ impl MergeEngine { .collect(); // Sort by composite score (highest first) - scored_segments.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap()); + scored_segments.sort_by(|a, b| b.1.total_cmp(&a.1)); let sorted_segments: Vec<_> = scored_segments.into_iter().map(|(seg, _)| seg).collect(); diff --git a/laurus/src/lexical/index/inverted/segment/merge_policy.rs b/laurus/src/lexical/index/inverted/segment/merge_policy.rs index 572a5a18..f7db422b 100644 --- a/laurus/src/lexical/index/inverted/segment/merge_policy.rs +++ b/laurus/src/lexical/index/inverted/segment/merge_policy.rs @@ -213,7 +213,7 @@ impl MergePolicy for TieredMergePolicy { } // Sort by priority (highest first) - all_candidates.sort_by(|a, b| b.priority.partial_cmp(&a.priority).unwrap()); + all_candidates.sort_by(|a, b| b.priority.total_cmp(&a.priority)); all_candidates } @@ -336,7 +336,7 @@ impl MergePolicy for LogStructuredMergePolicy { } } - candidates.sort_by(|a, b| b.priority.partial_cmp(&a.priority).unwrap()); + candidates.sort_by(|a, b| b.priority.total_cmp(&a.priority)); candidates } diff --git a/laurus/src/lexical/query/advanced_query.rs b/laurus/src/lexical/query/advanced_query.rs index 6a4c7074..374bc99b 100644 --- a/laurus/src/lexical/query/advanced_query.rs +++ b/laurus/src/lexical/query/advanced_query.rs @@ -210,7 +210,7 @@ impl AdvancedQuery { } // Sort by score descending - results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap()); + results.sort_by(|a, b| b.score.total_cmp(&a.score)); Ok(results) } diff --git a/laurus/src/lexical/query/collector.rs b/laurus/src/lexical/query/collector.rs index 958c31e8..4dac7c87 100644 --- 
a/laurus/src/lexical/query/collector.rs +++ b/laurus/src/lexical/query/collector.rs @@ -238,17 +238,15 @@ impl<'a> Collector for TopFieldCollector<'a> { sorted_docs.sort_unstable_by(|a, b| match (&a.field_value, &b.field_value) { (FieldValue::Text(av), FieldValue::Text(bv)) => av.cmp(bv), (FieldValue::Int64(av), FieldValue::Int64(bv)) => av.cmp(bv), - (FieldValue::Float64(av), FieldValue::Float64(bv)) => { - av.partial_cmp(bv).unwrap_or(Ordering::Equal) - } + (FieldValue::Float64(av), FieldValue::Float64(bv)) => av.total_cmp(bv), (FieldValue::Bool(av), FieldValue::Bool(bv)) => av.cmp(bv), (FieldValue::DateTime(av), FieldValue::DateTime(bv)) => av.cmp(bv), (FieldValue::Geo(a), FieldValue::Geo(b)) => { - let lat_cmp = a.lat.partial_cmp(&b.lat).unwrap_or(Ordering::Equal); + let lat_cmp = a.lat.total_cmp(&b.lat); if lat_cmp != Ordering::Equal { lat_cmp } else { - a.lon.partial_cmp(&b.lon).unwrap_or(Ordering::Equal) + a.lon.total_cmp(&b.lon) } } (FieldValue::Bytes(av, _), FieldValue::Bytes(bv, _)) => av.cmp(bv), @@ -262,17 +260,15 @@ impl<'a> Collector for TopFieldCollector<'a> { sorted_docs.sort_unstable_by(|a, b| match (&a.field_value, &b.field_value) { (FieldValue::Text(av), FieldValue::Text(bv)) => bv.cmp(av), (FieldValue::Int64(av), FieldValue::Int64(bv)) => bv.cmp(av), - (FieldValue::Float64(av), FieldValue::Float64(bv)) => { - bv.partial_cmp(av).unwrap_or(Ordering::Equal) - } + (FieldValue::Float64(av), FieldValue::Float64(bv)) => bv.total_cmp(av), (FieldValue::Bool(av), FieldValue::Bool(bv)) => bv.cmp(av), (FieldValue::DateTime(av), FieldValue::DateTime(bv)) => bv.cmp(av), (FieldValue::Geo(a), FieldValue::Geo(b)) => { - let lat_cmp = b.lat.partial_cmp(&a.lat).unwrap_or(Ordering::Equal); + let lat_cmp = b.lat.total_cmp(&a.lat); if lat_cmp != Ordering::Equal { lat_cmp } else { - b.lon.partial_cmp(&a.lon).unwrap_or(Ordering::Equal) + b.lon.total_cmp(&a.lon) } } (FieldValue::Bytes(av, _), FieldValue::Bytes(bv, _)) => bv.cmp(av), @@ -342,17 +338,15 @@ impl Ord for 
FieldScoredDoc { match (&self.field_value, &other.field_value) { (FieldValue::Text(a), FieldValue::Text(b)) => b.cmp(a), (FieldValue::Int64(a), FieldValue::Int64(b)) => b.cmp(a), - (FieldValue::Float64(a), FieldValue::Float64(b)) => { - b.partial_cmp(a).unwrap_or(Ordering::Equal) - } + (FieldValue::Float64(a), FieldValue::Float64(b)) => b.total_cmp(a), (FieldValue::Bool(a), FieldValue::Bool(b)) => b.cmp(a), (FieldValue::DateTime(a), FieldValue::DateTime(b)) => b.cmp(a), (FieldValue::Geo(a), FieldValue::Geo(b)) => { - let lat_cmp = b.lat.partial_cmp(&a.lat).unwrap_or(Ordering::Equal); + let lat_cmp = b.lat.total_cmp(&a.lat); if lat_cmp != Ordering::Equal { lat_cmp } else { - b.lon.partial_cmp(&a.lon).unwrap_or(Ordering::Equal) + b.lon.total_cmp(&a.lon) } } (FieldValue::Bytes(a, _), FieldValue::Bytes(b, _)) => b.cmp(a), @@ -366,17 +360,15 @@ impl Ord for FieldScoredDoc { match (&self.field_value, &other.field_value) { (FieldValue::Text(a), FieldValue::Text(b)) => a.cmp(b), (FieldValue::Int64(a), FieldValue::Int64(b)) => a.cmp(b), - (FieldValue::Float64(a), FieldValue::Float64(b)) => { - a.partial_cmp(b).unwrap_or(Ordering::Equal) - } + (FieldValue::Float64(a), FieldValue::Float64(b)) => a.total_cmp(b), (FieldValue::Bool(a), FieldValue::Bool(b)) => a.cmp(b), (FieldValue::DateTime(a), FieldValue::DateTime(b)) => a.cmp(b), (FieldValue::Geo(a), FieldValue::Geo(b)) => { - let lat_cmp = a.lat.partial_cmp(&b.lat).unwrap_or(Ordering::Equal); + let lat_cmp = a.lat.total_cmp(&b.lat); if lat_cmp != Ordering::Equal { lat_cmp } else { - a.lon.partial_cmp(&b.lon).unwrap_or(Ordering::Equal) + a.lon.total_cmp(&b.lon) } } (FieldValue::Bytes(a, _), FieldValue::Bytes(b, _)) => a.cmp(b), @@ -411,8 +403,7 @@ impl Ord for ScoredDoc { // Min-heap: lower scores come first other .score - .partial_cmp(&self.score) - .unwrap_or(Ordering::Equal) + .total_cmp(&self.score) .then_with(|| other.doc_id.cmp(&self.doc_id)) } } @@ -496,7 +487,7 @@ impl Collector for TopDocsCollector { .collect(); // 
Sort by score descending - results.sort_unstable_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal)); + results.sort_unstable_by(|a, b| b.score.total_cmp(&a.score)); results } @@ -685,7 +676,7 @@ impl Collector for AllDocsCollector { return cached.clone(); } let mut results = self.hits.clone(); - results.sort_unstable_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal)); + results.sort_unstable_by(|a, b| b.score.total_cmp(&a.score)); *cache = Some(results.clone()); results } diff --git a/laurus/src/lexical/query/geo.rs b/laurus/src/lexical/query/geo.rs index 00ef53a5..7d2af985 100644 --- a/laurus/src/lexical/query/geo.rs +++ b/laurus/src/lexical/query/geo.rs @@ -206,13 +206,8 @@ impl GeoDistanceQuery { // Sort by distance (closest first), then by relevance score matches.sort_by(|a, b| { a.distance_m - .partial_cmp(&b.distance_m) - .unwrap_or(std::cmp::Ordering::Equal) - .then_with(|| { - b.relevance_score - .partial_cmp(&a.relevance_score) - .unwrap_or(std::cmp::Ordering::Equal) - }) + .total_cmp(&b.distance_m) + .then_with(|| b.relevance_score.total_cmp(&a.relevance_score)) }); Ok(matches) @@ -575,9 +570,8 @@ impl GeoBoundingBoxQuery { // Sort by relevance score (highest first), then by distance to center matches.sort_by(|a, b| { b.relevance_score - .partial_cmp(&a.relevance_score) - .unwrap() - .then_with(|| a.distance_m.partial_cmp(&b.distance_m).unwrap()) + .total_cmp(&a.relevance_score) + .then_with(|| a.distance_m.total_cmp(&b.distance_m)) }); Ok(matches) @@ -831,11 +825,7 @@ impl GeoMatcher { /// Create a new geo matcher. 
pub fn new(mut matches: Vec) -> Self { // Sort matches by distance (closest first) - matches.sort_by(|a, b| { - a.distance_m - .partial_cmp(&b.distance_m) - .unwrap_or(std::cmp::Ordering::Equal) - }); + matches.sort_by(|a, b| a.distance_m.total_cmp(&b.distance_m)); GeoMatcher { matches, diff --git a/laurus/src/lexical/query/geo3d.rs b/laurus/src/lexical/query/geo3d.rs index 388ef65b..92d6f7d1 100644 --- a/laurus/src/lexical/query/geo3d.rs +++ b/laurus/src/lexical/query/geo3d.rs @@ -136,19 +136,13 @@ impl Geo3dDistanceQuery { // Multi-segment readers can produce duplicates; keep the closest. matches.sort_by(|a, b| { - a.doc_id.cmp(&b.doc_id).then_with(|| { - a.distance_m - .partial_cmp(&b.distance_m) - .unwrap_or(std::cmp::Ordering::Equal) - }) + a.doc_id + .cmp(&b.doc_id) + .then_with(|| a.distance_m.total_cmp(&b.distance_m)) }); matches.dedup_by_key(|m| m.doc_id); // Final order: distance ascending. - matches.sort_by(|a, b| { - a.distance_m - .partial_cmp(&b.distance_m) - .unwrap_or(std::cmp::Ordering::Equal) - }); + matches.sort_by(|a, b| a.distance_m.total_cmp(&b.distance_m)); Ok(matches) } @@ -718,11 +712,9 @@ impl Geo3dNearestQuery { // duplicates) to get an accurate "unique candidates" count. let mut deduped = current.hits.clone(); deduped.sort_by(|a, b| { - a.doc_id.cmp(&b.doc_id).then_with(|| { - a.distance_sq - .partial_cmp(&b.distance_sq) - .unwrap_or(std::cmp::Ordering::Equal) - }) + a.doc_id + .cmp(&b.doc_id) + .then_with(|| a.distance_sq.total_cmp(&b.distance_sq)) }); deduped.dedup_by_key(|h| h.doc_id); let unique_count = deduped.len(); @@ -761,11 +753,7 @@ impl Geo3dNearestQuery { // Final sort by distance ascending and truncation to top-k. 
let mut hits = visitor.hits; - hits.sort_by(|a, b| { - a.distance_sq - .partial_cmp(&b.distance_sq) - .unwrap_or(std::cmp::Ordering::Equal) - }); + hits.sort_by(|a, b| a.distance_sq.total_cmp(&b.distance_sq)); hits.truncate(self.k); // Normalize scores against the farthest distance in the returned diff --git a/laurus/src/lexical/search/features/facet.rs b/laurus/src/lexical/search/features/facet.rs index ec5fb6b7..c34ce9f8 100644 --- a/laurus/src/lexical/search/features/facet.rs +++ b/laurus/src/lexical/search/features/facet.rs @@ -1,6 +1,5 @@ //! Faceted search functionality for categorizing and filtering search results. -use std::cmp::Ordering; use std::collections::HashMap; use serde::{Deserialize, Serialize}; @@ -623,7 +622,7 @@ impl FacetedSearchEngine { } // Sort hits by score - hits.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal)); + hits.sort_by(|a, b| b.score.total_cmp(&a.score)); // Finalize facet collection let facet_results = facet_collector.finalize()?; @@ -776,8 +775,7 @@ impl SearchGroup { /// Sort documents in this group by score. pub fn sort_by_score(&mut self) { - self.documents - .sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal)); + self.documents.sort_by(|a, b| b.score.total_cmp(&a.score)); } /// Limit the number of documents in this group. diff --git a/laurus/src/lexical/search/features/highlight.rs b/laurus/src/lexical/search/features/highlight.rs index 0daa3660..a0b9ae27 100644 --- a/laurus/src/lexical/search/features/highlight.rs +++ b/laurus/src/lexical/search/features/highlight.rs @@ -145,11 +145,9 @@ impl FieldHighlight { /// Get the best fragment (highest score). pub fn best_fragment(&self) -> Option<&HighlightFragment> { - self.fragments.iter().max_by(|a, b| { - a.score - .partial_cmp(&b.score) - .unwrap_or(std::cmp::Ordering::Equal) - }) + self.fragments + .iter() + .max_by(|a, b| a.score.total_cmp(&b.score)) } /// Combine all fragments into a single string. 
@@ -441,11 +439,7 @@ impl Highlighter { } // Sort fragments by score (highest first) - fragments.sort_by(|a, b| { - b.score - .partial_cmp(&a.score) - .unwrap_or(std::cmp::Ordering::Equal) - }); + fragments.sort_by(|a, b| b.score.total_cmp(&a.score)); // Limit number of fragments fragments.truncate(self.config.max_fragments); diff --git a/laurus/src/spelling/suggest.rs b/laurus/src/spelling/suggest.rs index 121f4f3e..425042a5 100644 --- a/laurus/src/spelling/suggest.rs +++ b/laurus/src/spelling/suggest.rs @@ -52,10 +52,7 @@ impl Eq for Suggestion {} impl Ord for Suggestion { fn cmp(&self, other: &Self) -> Ordering { // Higher scores come first - other - .score - .partial_cmp(&self.score) - .unwrap_or(Ordering::Equal) + other.score.total_cmp(&self.score) } } diff --git a/laurus/src/vector/index/config.rs b/laurus/src/vector/index/config.rs index bcd05bb5..973d2cef 100644 --- a/laurus/src/vector/index/config.rs +++ b/laurus/src/vector/index/config.rs @@ -115,16 +115,10 @@ pub mod utils { } VectorNormalization::MinMax => { for vector in vectors.iter_mut() { - if let (Some(&min_val), Some(&max_val)) = - ( - vector.data.iter().min_by(|a, b| { - a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal) - }), - vector.data.iter().max_by(|a, b| { - a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal) - }), - ) - { + if let (Some(&min_val), Some(&max_val)) = ( + vector.data.iter().min_by(|a, b| a.total_cmp(b)), + vector.data.iter().max_by(|a, b| a.total_cmp(b)), + ) { let range = max_val - min_val; if range > 0.0 { for value in Arc::make_mut(&mut vector.data) { diff --git a/laurus/src/vector/index/flat/searcher.rs b/laurus/src/vector/index/flat/searcher.rs index 266553e5..26d2a73d 100644 --- a/laurus/src/vector/index/flat/searcher.rs +++ b/laurus/src/vector/index/flat/searcher.rs @@ -130,9 +130,7 @@ impl VectorIndexSearcher for FlatVectorSearcher { results.candidates_examined = candidates.len(); } - candidates.sort_unstable_by(|a, b| { - 
b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal) - }); + candidates.sort_unstable_by(|a, b| b.1.total_cmp(&a.1)); let top_k = request.params.top_k.min(candidates.len()); for (doc_id, similarity, distance, vector) in candidates.into_iter().take(top_k) { @@ -199,9 +197,7 @@ impl VectorIndexSearcher for FlatVectorSearcher { results.candidates_examined = candidates.len(); } - candidates.sort_unstable_by(|a, b| { - b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal) - }); + candidates.sort_unstable_by(|a, b| b.2.total_cmp(&a.2)); let top_k = request.params.top_k.min(candidates.len()); for (doc_id, field_name, similarity, distance, vector) in diff --git a/laurus/src/vector/index/hnsw/searcher.rs b/laurus/src/vector/index/hnsw/searcher.rs index 37c38bce..bca3dd43 100644 --- a/laurus/src/vector/index/hnsw/searcher.rs +++ b/laurus/src/vector/index/hnsw/searcher.rs @@ -293,7 +293,7 @@ impl VectorIndexSearcher for HnswSearcher { } } - candidates.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal)); + candidates.sort_by(|a, b| b.1.total_cmp(&a.1)); let top_k = request.params.top_k.min(candidates.len()); for (doc_id, similarity, distance, vector) in candidates.into_iter().take(top_k) { @@ -333,7 +333,7 @@ impl VectorIndexSearcher for HnswSearcher { } } - candidates.sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(Ordering::Equal)); + candidates.sort_by(|a, b| b.2.total_cmp(&a.2)); let top_k = request.params.top_k.min(candidates.len()); for (doc_id, field_name, similarity, distance, vector) in @@ -432,10 +432,7 @@ impl Ord for Candidate { // Then BinaryHeap is MaxHeap (largest at top). 
// This impl makes BinaryHeap a MIN-HEAP (smallest distance at top) - other - .distance - .partial_cmp(&self.distance) - .unwrap_or(Ordering::Equal) + other.distance.total_cmp(&self.distance) } } @@ -455,9 +452,7 @@ impl Eq for ResultCandidate {} impl Ord for ResultCandidate { fn cmp(&self, other: &Self) -> Ordering { // Max-heap: larger distance at top (to remove worst) - self.distance - .partial_cmp(&other.distance) - .unwrap_or(Ordering::Equal) + self.distance.total_cmp(&other.distance) } } impl PartialOrd for ResultCandidate { @@ -943,11 +938,7 @@ impl HnswSearcher { } // Sort results (similarity descending) - final_results.sort_by(|a, b| { - b.similarity - .partial_cmp(&a.similarity) - .unwrap_or(Ordering::Equal) - }); + final_results.sort_by(|a, b| b.similarity.total_cmp(&a.similarity)); // Top K let top_k = request.params.top_k.min(final_results.len()); @@ -1127,3 +1118,55 @@ mod ef_search_tests { assert_eq!(s.effective_ef(&req(10, None, Some(0))), 50); } } + +#[cfg(test)] +mod nan_ordering_tests { + //! Issue #667: the HNSW candidate / result heaps order by an `f32` + //! `distance`. The previous `partial_cmp(...).unwrap_or(Equal)` made a + //! NaN distance compare equal to everything — a non-total order, which + //! `BinaryHeap` / `sort_unstable` forbid (silent reorder, or a panic on + //! recent std). `total_cmp` restores a total order so a NaN is handled + //! deterministically without losing or misordering the finite entries. + + use super::{Candidate, ResultCandidate}; + use std::collections::BinaryHeap; + + #[test] + fn candidate_min_heap_handles_nan_without_panic() { + // `Candidate` is a min-heap by distance (nearest pops first). 
+ let mut heap = BinaryHeap::new(); + for d in [3.0_f32, 1.0, f32::NAN, 2.0] { + heap.push(Candidate { id: 0, distance: d }); + } + let popped: Vec<f32> = std::iter::from_fn(|| heap.pop().map(|c| c.distance)).collect(); + assert_eq!(popped.len(), 4, "no candidate is lost"); + let finite: Vec<f32> = popped.iter().copied().filter(|d| !d.is_nan()).collect(); + assert_eq!( + finite, + vec![1.0, 2.0, 3.0], + "finite distances pop nearest-first regardless of the NaN" + ); + assert_eq!( + popped.iter().filter(|d| d.is_nan()).count(), + 1, + "the NaN is retained, not silently dropped" + ); + } + + #[test] + fn result_candidate_max_heap_handles_nan_without_panic() { + // `ResultCandidate` is a max-heap by distance (furthest pops first). + let mut heap = BinaryHeap::new(); + for d in [3.0_f32, 1.0, f32::NAN, 2.0] { + heap.push(ResultCandidate { id: 0, distance: d }); + } + let popped: Vec<f32> = std::iter::from_fn(|| heap.pop().map(|c| c.distance)).collect(); + assert_eq!(popped.len(), 4, "no candidate is lost"); + let finite: Vec<f32> = popped.iter().copied().filter(|d| !d.is_nan()).collect(); + assert_eq!( + finite, + vec![3.0, 2.0, 1.0], + "finite distances pop furthest-first regardless of the NaN" + ); + } +} diff --git a/laurus/src/vector/index/hnsw/writer.rs b/laurus/src/vector/index/hnsw/writer.rs index 9877ac75..eebb96f9 100644 --- a/laurus/src/vector/index/hnsw/writer.rs +++ b/laurus/src/vector/index/hnsw/writer.rs @@ -182,9 +182,7 @@ impl Ord for Candidate { // Wait, for BinaryHeap in Rust, it's a max-heap. // If we want smallest distance at top, we need reverse. // If we want largest distance at top (to remove worst candidate), we use standard. 
- self.distance - .partial_cmp(&other.distance) - .unwrap_or(Ordering::Equal) + self.distance.total_cmp(&other.distance) } } @@ -769,11 +767,10 @@ impl HnswIndexWriter { let candidates = writer_ref.search_layer(&graph, curr_obj, vector, ef_construction, lc)?; - if let Some(min_cand) = candidates.iter().min_by(|a, b| { - a.distance - .partial_cmp(&b.distance) - .unwrap_or(Ordering::Equal) - }) { + if let Some(min_cand) = candidates + .iter() + .min_by(|a, b| a.distance.total_cmp(&b.distance)) + { curr_obj = min_cand.id; } @@ -873,10 +870,7 @@ impl HnswIndexWriter { impl Ord for VisitorCandidate { fn cmp(&self, other: &Self) -> Ordering { // Min-heap: smaller distance > larger distance - other - .distance - .partial_cmp(&self.distance) - .unwrap_or(Ordering::Equal) + other.distance.total_cmp(&self.distance) } } impl PartialOrd for VisitorCandidate { @@ -954,11 +948,7 @@ impl HnswIndexWriter { // Simple heuristic: take M nearest. // Collect without cloning the heap, then sort by ascending distance. let mut sorted: Vec<_> = candidates.iter().cloned().collect(); - sorted.sort_unstable_by(|a, b| { - a.distance - .partial_cmp(&b.distance) - .unwrap_or(Ordering::Equal) - }); + sorted.sort_unstable_by(|a, b| a.distance.total_cmp(&b.distance)); sorted.truncate(m); sorted.into_iter().map(|c| c.id).collect() } @@ -993,11 +983,7 @@ impl HnswIndexWriter { } // We want to keep nearest. Move to min-heap or just sort. 
- candidates.sort_by(|a, b| { - a.distance - .partial_cmp(&b.distance) - .unwrap_or(Ordering::Equal) - }); + candidates.sort_by(|a, b| a.distance.total_cmp(&b.distance)); candidates.truncate(max_conn); Ok(candidates.into_iter().map(|c| c.id).collect()) diff --git a/laurus/src/vector/index/ivf/searcher.rs b/laurus/src/vector/index/ivf/searcher.rs index 98a18b12..f5ba41d2 100644 --- a/laurus/src/vector/index/ivf/searcher.rs +++ b/laurus/src/vector/index/ivf/searcher.rs @@ -145,9 +145,7 @@ impl IvfSearcher { return Ok(Vec::new()); } if n < centroid_distances.len() { - centroid_distances.select_nth_unstable_by(n - 1, |a, b| { - a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal) - }); + centroid_distances.select_nth_unstable_by(n - 1, |a, b| a.1.total_cmp(&b.1)); } // Collect vector IDs from the `n` nearest clusters. @@ -277,8 +275,7 @@ impl VectorIndexSearcher for IvfSearcher { )?; // Sort by similarity (descending) - candidates - .sort_unstable_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal)); + candidates.sort_unstable_by(|a, b| b.2.total_cmp(&a.2)); // Take top_k results let candidates_len = candidates.len(); diff --git a/laurus/src/vector/index/segmented_field.rs b/laurus/src/vector/index/segmented_field.rs index 06129bdb..6acec861 100644 --- a/laurus/src/vector/index/segmented_field.rs +++ b/laurus/src/vector/index/segmented_field.rs @@ -26,7 +26,6 @@ use crate::vector::search::searcher::{ }; use crate::vector::store::config::VectorFieldConfig; use crate::vector::writer::VectorIndexWriterConfig; -use std::cmp::Ordering; /// A vector field implementation that partitions data into segments. 
/// @@ -361,7 +360,7 @@ impl SegmentedVectorField { } // Sort by similarity descending - candidates.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal)); + candidates.sort_unstable_by(|a, b| b.1.total_cmp(&a.1)); let hits = candidates .into_iter() @@ -508,7 +507,7 @@ impl VectorFieldReader for SegmentedVectorField { } let mut hits: Vec = merged.into_values().collect(); - hits.sort_unstable_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal)); + hits.sort_unstable_by(|a, b| b.score.total_cmp(&a.score)); if hits.len() > request.limit { hits.truncate(request.limit); } diff --git a/laurus/src/vector/search/searcher.rs b/laurus/src/vector/search/searcher.rs index 5f1dbe52..8960c0e7 100644 --- a/laurus/src/vector/search/searcher.rs +++ b/laurus/src/vector/search/searcher.rs @@ -312,20 +312,14 @@ impl VectorIndexQueryResults { /// Sort results by similarity (descending). pub fn sort_by_similarity(&mut self) { - self.results.sort_by(|a, b| { - b.similarity - .partial_cmp(&a.similarity) - .unwrap_or(std::cmp::Ordering::Equal) - }); + self.results + .sort_by(|a, b| b.similarity.total_cmp(&a.similarity)); } /// Sort results by distance (ascending). pub fn sort_by_distance(&mut self) { - self.results.sort_by(|a, b| { - a.distance - .partial_cmp(&b.distance) - .unwrap_or(std::cmp::Ordering::Equal) - }); + self.results + .sort_by(|a, b| a.distance.total_cmp(&b.distance)); } /// Take the top k results. @@ -343,11 +337,9 @@ impl VectorIndexQueryResults { /// Get the best (highest similarity) result. 
pub fn best_result(&self) -> Option<&VectorIndexQueryResult> { - self.results.iter().max_by(|a, b| { - a.similarity - .partial_cmp(&b.similarity) - .unwrap_or(std::cmp::Ordering::Equal) - }) + self.results + .iter() + .max_by(|a, b| a.similarity.total_cmp(&b.similarity)) } } diff --git a/laurus/src/vector/store.rs b/laurus/src/vector/store.rs index ba6ffd12..304670b5 100644 --- a/laurus/src/vector/store.rs +++ b/laurus/src/vector/store.rs @@ -659,23 +659,11 @@ impl VectorStore { // than the requested limit. let limit = request.params.limit.min(hits.len()); if limit > 0 && limit < hits.len() { - hits.select_nth_unstable_by(limit - 1, |a, b| { - b.score - .partial_cmp(&a.score) - .unwrap_or(std::cmp::Ordering::Equal) - }); + hits.select_nth_unstable_by(limit - 1, |a, b| b.score.total_cmp(&a.score)); hits.truncate(limit); - hits.sort_unstable_by(|a, b| { - b.score - .partial_cmp(&a.score) - .unwrap_or(std::cmp::Ordering::Equal) - }); + hits.sort_unstable_by(|a, b| b.score.total_cmp(&a.score)); } else if !hits.is_empty() { - hits.sort_unstable_by(|a, b| { - b.score - .partial_cmp(&a.score) - .unwrap_or(std::cmp::Ordering::Equal) - }); + hits.sort_unstable_by(|a, b| b.score.total_cmp(&a.score)); } return Ok(VectorSearchResults { hits }); @@ -738,8 +726,7 @@ impl VectorStore { hits.sort_by(|a, b| { b.score - .partial_cmp(&a.score) - .unwrap_or(std::cmp::Ordering::Equal) + .total_cmp(&a.score) .then_with(|| a.doc_id.cmp(&b.doc_id)) }); diff --git a/laurus/src/vector/store/memory.rs b/laurus/src/vector/store/memory.rs index c12b63f6..7ba786ee 100644 --- a/laurus/src/vector/store/memory.rs +++ b/laurus/src/vector/store/memory.rs @@ -2,7 +2,6 @@ //! //! 
このモジュールはインメモリでベクトルを管理するフィールド実装を提供する。 -use std::cmp::Ordering as CmpOrdering; use std::collections::HashMap; use std::collections::hash_map::Entry; use std::sync::Arc; @@ -428,7 +427,7 @@ impl VectorFieldReader for InMemoryFieldReader { } let mut hits: Vec = merged.into_values().collect(); - hits.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(CmpOrdering::Equal)); + hits.sort_by(|a, b| b.score.total_cmp(&a.score)); if hits.len() > limit { hits.truncate(limit); }