diff --git a/python/python/lance/dataset.py b/python/python/lance/dataset.py index 5737ec013b5..ee84467f373 100644 --- a/python/python/lance/dataset.py +++ b/python/python/lance/dataset.py @@ -3280,6 +3280,9 @@ def _create_index_impl( index_uuid: Optional[str] = None, *, target_partition_size: Optional[int] = None, + streaming_sample_rate: Optional[int] = None, + streaming_coreset_rate: Optional[int] = None, + streaming_refine_passes: Optional[int] = None, skip_transpose: bool = False, require_commit: bool = True, **kwargs, @@ -3491,6 +3494,12 @@ def _create_index_impl( kwargs["num_partitions"] = num_partitions if target_partition_size is not None: kwargs["target_partition_size"] = target_partition_size + if streaming_sample_rate is not None: + kwargs["streaming_sample_rate"] = streaming_sample_rate + if streaming_coreset_rate is not None: + kwargs["streaming_coreset_rate"] = streaming_coreset_rate + if streaming_refine_passes is not None: + kwargs["streaming_refine_passes"] = streaming_refine_passes if (precomputed_partition_dataset is not None) and (ivf_centroids is None): raise ValueError( @@ -3652,6 +3661,9 @@ def create_index( index_uuid: Optional[str] = None, *, target_partition_size: Optional[int] = None, + streaming_sample_rate: Optional[int] = None, + streaming_coreset_rate: Optional[int] = None, + streaming_refine_passes: Optional[int] = None, skip_transpose: bool = False, progress_callback: Optional[Callable[[IndexProgress], None]] = None, **kwargs, @@ -3740,6 +3752,19 @@ def create_index( The target partition size. If set, the number of partitions will be computed based on the target partition size. Otherwise, the target partition size will be set by index type. + streaming_sample_rate : int, optional + If set below ``sample_rate``, IVF kmeans trains incrementally and samples + at most ``num_partitions * streaming_sample_rate`` vectors per step. 
For + ``num_partitions > 256``, chunks are compressed into a weighted coreset + and final centroids are trained with weighted hierarchical kmeans. + streaming_coreset_rate : int, optional + If set, controls the final weighted coreset budget independently from + ``streaming_sample_rate``. The budget is + ``num_partitions * streaming_coreset_rate``. + streaming_refine_passes : int, optional + Number of extra streaming Lloyd refinement passes to run after streaming + coreset training. Each pass loads at most + ``num_partitions * streaming_sample_rate`` raw vectors at a time. kwargs : Parameters passed to the index building process. @@ -3861,6 +3886,9 @@ def create_index( fragment_ids=fragment_ids, index_uuid=index_uuid, target_partition_size=target_partition_size, + streaming_sample_rate=streaming_sample_rate, + streaming_coreset_rate=streaming_coreset_rate, + streaming_refine_passes=streaming_refine_passes, skip_transpose=skip_transpose, require_commit=True, **kwargs, @@ -3895,6 +3923,9 @@ def create_index_uncommitted( index_uuid: Optional[str] = None, *, target_partition_size: Optional[int] = None, + streaming_sample_rate: Optional[int] = None, + streaming_coreset_rate: Optional[int] = None, + streaming_refine_passes: Optional[int] = None, skip_transpose: bool = False, **kwargs, ) -> Index: @@ -3950,6 +3981,9 @@ def create_index_uncommitted( fragment_ids=fragment_ids, index_uuid=index_uuid, target_partition_size=target_partition_size, + streaming_sample_rate=streaming_sample_rate, + streaming_coreset_rate=streaming_coreset_rate, + streaming_refine_passes=streaming_refine_passes, skip_transpose=skip_transpose, require_commit=False, **kwargs, diff --git a/python/src/dataset.rs b/python/src/dataset.rs index c868504e87c..b8f4e5d6aab 100644 --- a/python/src/dataset.rs +++ b/python/src/dataset.rs @@ -4293,6 +4293,16 @@ fn prepare_vector_index_params( pq_params.max_iters = max_iters; } + if let Some(streaming_sample_rate) = kwargs.get_item("streaming_sample_rate")? 
{ + ivf_params.streaming_sample_rate = Some(streaming_sample_rate.extract()?); + } + if let Some(streaming_coreset_rate) = kwargs.get_item("streaming_coreset_rate")? { + ivf_params.streaming_coreset_rate = Some(streaming_coreset_rate.extract()?); + } + if let Some(streaming_refine_passes) = kwargs.get_item("streaming_refine_passes")? { + ivf_params.streaming_refine_passes = streaming_refine_passes.extract()?; + } + // Parse IVF params if let Some(n) = kwargs.get_item("num_partitions")? { ivf_params.num_partitions = Some(n.extract()?) diff --git a/rust/lance-index/src/vector/ivf/builder.rs b/rust/lance-index/src/vector/ivf/builder.rs index 72e05555441..b8b6e0ec4cc 100644 --- a/rust/lance-index/src/vector/ivf/builder.rs +++ b/rust/lance-index/src/vector/ivf/builder.rs @@ -39,6 +39,33 @@ pub struct IvfBuildParams { pub sample_rate: usize, + /// Optional per-step sample rate for streaming IVF kmeans training. + /// + /// When set, IVF training loads at most `num_partitions * streaming_sample_rate` + /// vectors at a time. For `num_partitions > 256`, each chunk is compressed into + /// a weighted coreset and final centroids are trained with weighted hierarchical + /// kmeans over the coreset. The coreset budget is also bounded by this rate by + /// default so large partition counts can control peak memory by lowering + /// `streaming_sample_rate`. The total number of sampled vectors remains bounded + /// by `num_partitions * sample_rate`. + pub streaming_sample_rate: Option, + + /// Optional coreset rate for streaming IVF kmeans training. + /// + /// When set, the final weighted coreset budget is + /// `num_partitions * streaming_coreset_rate`, independent of + /// `streaming_sample_rate`. The streaming chunk size is still controlled by + /// `streaming_sample_rate`. + pub streaming_coreset_rate: Option, + + /// Number of extra streaming Lloyd refinement passes to run after streaming + /// coreset training. 
+ /// + /// Each pass reuses the same sampled vectors and only loads + /// `num_partitions * streaming_sample_rate` raw vectors at a time. This is + /// experimental and defaults to 0 to preserve existing behavior. + pub streaming_refine_passes: usize, + /// Precomputed partitions file (row_id -> partition_id) /// mutually exclusive with `precomputed_shuffle_buffers` pub precomputed_partitions_file: Option, @@ -67,6 +94,9 @@ impl Default for IvfBuildParams { centroids: None, retrain: false, sample_rate: 256, // See faiss + streaming_sample_rate: None, + streaming_coreset_rate: None, + streaming_refine_passes: 0, precomputed_partitions_file: None, precomputed_shuffle_buffers: None, shuffle_partition_batches: 1024 * 10, diff --git a/rust/lance-index/src/vector/kmeans.rs b/rust/lance-index/src/vector/kmeans.rs index 4a610b41cf6..d0e86028b89 100644 --- a/rust/lance-index/src/vector/kmeans.rs +++ b/rust/lance-index/src/vector/kmeans.rs @@ -561,6 +561,24 @@ impl KMeansAlgo for KModeAlgo { } } +/// Cluster id assignment for each vector in a batch. +pub type KMeansMembership = Vec>; + +/// Distance from each vector to its assigned centroid. +pub type KMeansDistances = Vec>; + +/// Maximum assignment distance per centroid. +pub type KMeansClusterRadii = Vec; + +/// Sum of assignment distances per centroid. +pub type KMeansClusterLosses = Vec; + +/// Batch assignment results with per-centroid radii and losses. +pub type KMeansMembershipAndLoss = (KMeansMembership, KMeansClusterRadii, KMeansClusterLosses); + +/// Batch assignment results with per-vector distances. +pub type KMeansMembershipAndDistances = (KMeansMembership, KMeansDistances); + /// KMeans implementation for Apache Arrow Arrays. #[derive(Debug, Clone)] pub struct KMeans { @@ -636,6 +654,116 @@ impl KMeans { Self::new_with_params(data, k, ¶ms) } + /// Assign a batch of vectors to these centroids and return membership, radius, and loss. 
+ pub fn compute_membership_and_loss( + &self, + data: &FixedSizeListArray, + ) -> arrow::error::Result { + let (membership, distances) = self.compute_membership_and_distances(data)?; + let k = self.centroids.len() / self.dimension; + let mut cluster_radius: Vec = vec![0.0_f32; k]; + let mut losses = vec![0.0; k]; + for (cluster_id, dist) in membership.iter().zip(distances.iter()) { + if let (Some(cluster_id), Some(dist)) = (cluster_id, dist) { + let cluster_id = *cluster_id as usize; + cluster_radius[cluster_id] = cluster_radius[cluster_id].max(*dist); + losses[cluster_id] += *dist as f64; + } + } + Ok((membership, cluster_radius, losses)) + } + + /// Assign a batch of vectors to these centroids and return per-vector distances. + pub fn compute_membership_and_distances( + &self, + data: &FixedSizeListArray, + ) -> arrow::error::Result { + if data.value_length() as usize != self.dimension { + return Err(ArrowError::InvalidArgumentError(format!( + "KMeans: data dimension {} does not match centroid dimension {}", + data.value_length(), + self.dimension + ))); + } + + let index = SimpleIndex::may_train_index( + self.centroids.clone(), + self.dimension, + self.distance_type, + ) + .map_err(|e| ArrowError::ExternalError(Box::new(e)))?; + match ( + data.value_type(), + self.centroids.data_type(), + self.distance_type, + ) { + (DataType::Float16, DataType::Float16, _) => { + let data_values = data.values().as_primitive::().values(); + let centroids = self.centroids.as_primitive::().values(); + Ok(KMeansAlgoFloat::::compute_membership_and_dist( + centroids, + data_values, + self.dimension, + self.distance_type, + 0.0, + None, + index.as_ref(), + )) + } + (DataType::Float32, DataType::Float32, _) => { + let data_values = data.values().as_primitive::().values(); + let centroids = self.centroids.as_primitive::().values(); + Ok(KMeansAlgoFloat::::compute_membership_and_dist( + centroids, + data_values, + self.dimension, + self.distance_type, + 0.0, + None, + index.as_ref(), + 
)) + } + (DataType::Float64, DataType::Float64, _) => { + let data_values = data.values().as_primitive::().values(); + let centroids = self.centroids.as_primitive::().values(); + Ok(KMeansAlgoFloat::::compute_membership_and_dist( + centroids, + data_values, + self.dimension, + self.distance_type, + 0.0, + None, + index.as_ref(), + )) + } + (DataType::UInt8, DataType::UInt8, DistanceType::Hamming) => { + let data_values = data.values().as_primitive::().values(); + let centroids = self.centroids.as_primitive::().values(); + Ok(KModeAlgo::compute_membership_and_dist( + centroids, + data_values, + self.dimension, + self.distance_type, + 0.0, + None, + index.as_ref(), + )) + } + _ => Err(ArrowError::InvalidArgumentError(format!( + "KMeans: can not compute membership for data type {} with centroid type {} and distance type {}", + data.value_type(), + self.centroids.data_type(), + self.distance_type + ))), + } + } + + /// Compute the kmeans loss for a batch of vectors against these centroids. + pub fn compute_loss(&self, data: &FixedSizeListArray) -> arrow::error::Result { + let (_, _, losses) = self.compute_membership_and_loss(data)?; + Ok(losses.iter().sum()) + } + fn train_kmeans>( data: &FixedSizeListArray, k: usize, diff --git a/rust/lance/src/index/vector.rs b/rust/lance/src/index/vector.rs index 87b32344ec6..ff7d2383c67 100644 --- a/rust/lance/src/index/vector.rs +++ b/rust/lance/src/index/vector.rs @@ -1841,6 +1841,9 @@ fn derive_ivf_params(ivf_model: &IvfModel) -> IvfBuildParams { #[allow(deprecated)] retrain: false, // Don't retrain since we have centroids sample_rate: 256, // Default + streaming_sample_rate: None, + streaming_coreset_rate: None, + streaming_refine_passes: 0, precomputed_partitions_file: None, precomputed_shuffle_buffers: None, shuffle_partition_batches: 1024 * 10, // Default diff --git a/rust/lance/src/index/vector/ivf.rs b/rust/lance/src/index/vector/ivf.rs index 179ec96df4c..2d4086b3ad9 100644 --- a/rust/lance/src/index/vector/ivf.rs +++ 
b/rust/lance/src/index/vector/ivf.rs @@ -26,14 +26,16 @@ use crate::{ }, }; use crate::{dataset::builder::DatasetBuilder, index::vector::IndexFileVersion}; +use arrow::array::ArrayData; use arrow::datatypes::UInt8Type; use arrow_arith::numeric::sub; use arrow_array::Float32Array; use arrow_array::{ - Array, FixedSizeListArray, PrimitiveArray, RecordBatch, UInt32Array, + Array, ArrayRef, FixedSizeListArray, PrimitiveArray, RecordBatch, UInt32Array, cast::AsArray, types::{ArrowPrimitiveType, Float16Type, Float32Type, Float64Type}, }; +use arrow_buffer::MutableBuffer; use arrow_schema::{DataType, Schema}; use async_trait::async_trait; use datafusion::execution::SendableRecordBatchStream; @@ -69,7 +71,7 @@ use lance_index::vector::hnsw::HnswMetadata; use lance_index::vector::hnsw::builder::HNSW_METADATA_KEY; use lance_index::vector::ivf::storage::IVF_METADATA_KEY; use lance_index::vector::ivf::storage::IvfModel; -use lance_index::vector::kmeans::KMeansParams; +use lance_index::vector::kmeans::{KMeans, KMeansParams}; use lance_index::vector::pq::storage::transpose; use lance_index::vector::quantizer::QuantizationType; use lance_index::vector::v3::shuffler::create_ivf_shuffler; @@ -105,12 +107,15 @@ use lance_table::format::{IndexMetadata as TableIndexMetadata, list_index_files_ use log::{info, warn}; use object_store::path::Path; use prost::Message; +use rand::rngs::SmallRng; +use rand::{Rng, SeedableRng}; use roaring::RoaringBitmap; use serde::Serialize; use serde_json::json; use std::{ any::Any, collections::{HashMap, HashSet}, + ops::Range, sync::Arc, }; use tokio::sync::mpsc; @@ -1426,6 +1431,50 @@ pub async fn build_ivf_model( } let sample_size_hint = num_partitions * params.sample_rate; + if let Some(streaming_sample_rate) = params.streaming_sample_rate { + if streaming_sample_rate == 0 { + return Err(Error::invalid_input( + "streaming_sample_rate must be greater than 0".to_string(), + )); + } + if let Some(streaming_coreset_rate) = params.streaming_coreset_rate { 
+ if streaming_coreset_rate == 0 { + return Err(Error::invalid_input( + "streaming_coreset_rate must be greater than 0".to_string(), + )); + } + if streaming_coreset_rate > params.sample_rate { + return Err(Error::invalid_input(format!( + "streaming_coreset_rate ({streaming_coreset_rate}) must be less than or equal to sample_rate ({})", + params.sample_rate + ))); + } + } + if streaming_sample_rate < params.sample_rate { + info!( + "Start streaming IVF training. Total sample size: {}, per-step sample size: {}", + sample_size_hint, + num_partitions * streaming_sample_rate + ); + let start = std::time::Instant::now(); + let ivf = train_streaming_ivf_model( + dataset, + column, + dim, + metric_type, + params, + fragment_ids, + progress, + ) + .await?; + info!( + "Trained streaming IVF model in {:02} seconds", + start.elapsed().as_secs_f32() + ); + return Ok(ivf); + } + } + let start = std::time::Instant::now(); info!( "Loading training data for IVF. Sample size: {}", @@ -2603,225 +2652,1818 @@ where warn!("Progress worker join error during train_ivf: {e}"); } let kmeans = kmeans?; + let training_data = FixedSizeListArray::try_new_from_values( + Arc::new(data.clone()) as ArrayRef, + dimension as i32, + )?; + let loss = kmeans.compute_loss(&training_data)?; Ok(IvfModel::new( FixedSizeListArray::try_new_from_values(kmeans.centroids, dimension as i32)?, - Some(kmeans.loss), + Some(loss), )) } -/// Train IVF partitions using kmeans. 
-async fn train_ivf_model( - centroids: Option>, - data: &FixedSizeListArray, - distance_type: DistanceType, - params: &IvfBuildParams, - progress: std::sync::Arc, -) -> Result { - assert!( - distance_type != DistanceType::Cosine, - "Cosine metric should be done by normalized L2 distance", - ); - let values = data.values(); - let dim = data.value_length() as usize; - match (values.data_type(), distance_type) { - (DataType::Float16, _) => { - do_train_ivf_model::( - centroids, - values.as_primitive::(), - dim, - distance_type, - params, - progress.clone(), - ) - .await - } - (DataType::Float32, _) => { - do_train_ivf_model::( - centroids, - values.as_primitive::(), - dim, - distance_type, - params, - progress.clone(), - ) - .await - } - (DataType::Float64, _) => { - do_train_ivf_model::( - centroids, - values.as_primitive::(), - dim, - distance_type, - params, - progress.clone(), - ) - .await - } - (DataType::Int8, DistanceType::L2) - | (DataType::Int8, DistanceType::Dot) - | (DataType::Int8, DistanceType::Cosine) => { - do_train_ivf_model::( - centroids, - data.convert_to_floating_point()? 
- .values() - .as_primitive::(), - dim, - distance_type, - params, - progress.clone(), - ) - .await +async fn sample_ivf_training_chunk( + dataset: &Dataset, + column: &str, + sample_size_hint: usize, + metric_type: MetricType, + fragment_ids: Option<&[u32]>, +) -> Result<(FixedSizeListArray, MetricType)> { + let training_data = + maybe_sample_training_data(dataset, column, sample_size_hint, fragment_ids).await?; + let (training_data, mt) = if metric_type == MetricType::Cosine { + let training_data = normalize_fsl_owned(training_data)?; + (training_data, MetricType::L2) + } else { + (training_data, metric_type) + }; + Ok((filter_finite_training_data(training_data)?, mt)) +} + +#[derive(Debug, Clone)] +struct FixedIvfTrainingRanges { + ranges: Vec>, + num_rows: usize, +} + +impl FixedIvfTrainingRanges { + fn new(ranges: Vec>) -> Self { + let num_rows = ranges.iter().map(range_len).sum(); + Self { ranges, num_rows } + } + + fn num_rows(&self) -> usize { + self.num_rows + } + + fn chunk(&self, row_offset: usize, row_count: usize) -> Vec> { + if row_count == 0 || row_offset >= self.num_rows { + return Vec::new(); } - (DataType::UInt8, DistanceType::Hamming) => { - do_train_ivf_model::( - centroids, - values.as_primitive::(), - dim, - distance_type, - params, - progress.clone(), - ) - .await + + let mut remaining_skip = row_offset; + let mut remaining_take = row_count.min(self.num_rows - row_offset); + let mut chunk = Vec::new(); + for range in &self.ranges { + let range_len = range_len(range); + if remaining_skip >= range_len { + remaining_skip -= range_len; + continue; + } + + let start = range.start + remaining_skip as u64; + let available = range_len - remaining_skip; + let take = available.min(remaining_take); + chunk.push(start..start + take as u64); + remaining_take -= take; + remaining_skip = 0; + if remaining_take == 0 { + break; + } } - _ => Err(Error::index(format!( - "Unsupported data type {} with distance type {}", - values.data_type(), - distance_type - 
))), + chunk } } -#[cfg(test)] -mod tests { - use super::*; +fn range_len(range: &Range) -> usize { + (range.end - range.start) as usize +} - use std::collections::HashSet; - use std::iter::repeat_n; - use std::ops::Range; +const DEFAULT_STREAMING_IVF_TAKE_RANGE_ROWS: usize = 8192; +const DEFAULT_STREAMING_IVF_PREFETCH_DEPTH: usize = 1; +const DEFAULT_STREAMING_IVF_PROGRESS_INTERVAL: u64 = 128; +const STREAMING_IVF_PREFETCH_DEPTH_ENV: &str = "LANCE_STREAMING_IVF_PREFETCH_DEPTH"; +const STREAMING_IVF_TAKE_RANGE_ROWS_ENV: &str = "LANCE_STREAMING_IVF_TAKE_RANGE_ROWS"; +const STREAMING_IVF_PROGRESS_INTERVAL_ENV: &str = "LANCE_STREAMING_IVF_PROGRESS_INTERVAL"; + +fn streaming_ivf_prefetch_depth() -> usize { + std::env::var(STREAMING_IVF_PREFETCH_DEPTH_ENV) + .ok() + .and_then(|value| value.parse::().ok()) + .filter(|depth| *depth > 0) + .unwrap_or(DEFAULT_STREAMING_IVF_PREFETCH_DEPTH) +} - use arrow_array::types::UInt64Type; - use arrow_array::{ - FixedSizeListArray, Float16Array, Float32Array, RecordBatch, RecordBatchIterator, - RecordBatchReader, UInt64Array, make_array, - }; - use arrow_buffer::{BooleanBuffer, NullBuffer}; - use arrow_schema::{DataType, Field, Schema}; - use half::f16; - use itertools::Itertools; - use lance_core::ROW_ID; - use lance_core::utils::address::RowAddress; - use lance_core::utils::tempfile::TempStrDir; - use lance_datagen::{ArrayGeneratorExt, Dimension, RowCount, array, gen_batch}; - use lance_index::VECTOR_INDEX_VERSION; - use lance_index::metrics::NoOpMetricsCollector; - use lance_index::vector::sq::builder::SQBuildParams; - use lance_linalg::distance::l2_distance_batch; - use lance_testing::datagen::{ - generate_random_array, generate_random_array_with_range, generate_random_array_with_seed, - generate_scaled_random_array, sample_without_replacement, - }; - use rand::{rng, seq::SliceRandom}; - use rstest::rstest; +fn streaming_ivf_take_range_rows() -> usize { + std::env::var(STREAMING_IVF_TAKE_RANGE_ROWS_ENV) + .ok() + .and_then(|value| 
value.parse::().ok()) + .filter(|rows| *rows > 0) + .unwrap_or(DEFAULT_STREAMING_IVF_TAKE_RANGE_ROWS) +} - use crate::dataset::{InsertBuilder, WriteMode, WriteParams}; - use crate::index::prefilter::DatasetPreFilter; - use crate::index::vector::IndexFileVersion; - use crate::index::vector_index_details_default; - use crate::index::{DatasetIndexExt, DatasetIndexInternalExt, vector::VectorIndexParams}; - use crate::utils::test::copy_test_data_to_tmp; +fn streaming_ivf_progress_interval() -> u64 { + std::env::var(STREAMING_IVF_PROGRESS_INTERVAL_ENV) + .ok() + .and_then(|value| value.parse::().ok()) + .filter(|interval| *interval > 0) + .unwrap_or(DEFAULT_STREAMING_IVF_PROGRESS_INTERVAL) +} - const DIM: usize = 32; +fn should_report_streaming_ivf_progress(total: u64, interval: u64) -> bool { + total == 1 || total % interval.max(1) == 0 +} - // Verifies LANCE_INCLUDE_VECTOR_CENTROIDS env var is honored by - // maybe_centroids_for_stats. The env var is process-global, so this test - // is serialized against any other test that touches the same key. - #[test] - #[serial_test::serial(LANCE_INCLUDE_VECTOR_CENTROIDS)] - fn test_maybe_centroids_for_stats_env_var() { - let centroids = Float32Array::from(vec![1.0_f32, 2.0, 3.0, 4.0]); - let centroids = FixedSizeListArray::try_new_from_values(centroids, 2).unwrap(); - let expected = Some(vec![vec![1.0, 2.0], vec![3.0, 4.0]]); +fn split_ranges_by_row_count(ranges: &[Range], max_rows: usize) -> Vec> { + let max_rows = max_rows.max(1) as u64; + let mut split = Vec::new(); + for range in ranges { + let mut start = range.start; + while start < range.end { + let end = (start + max_rows).min(range.end); + split.push(start..end); + start = end; + } + } + split +} - // Save the original value so we can restore it afterwards. 
- let original = std::env::var(LANCE_INCLUDE_VECTOR_CENTROIDS_ENV).ok(); +fn generate_fixed_training_ranges( + num_rows: usize, + sample_size: usize, + block_size: usize, + byte_width: usize, +) -> FixedIvfTrainingRanges { + let sample_size = num_rows.min(sample_size); + if sample_size == 0 { + return FixedIvfTrainingRanges::new(Vec::new()); + } + if sample_size >= num_rows { + return FixedIvfTrainingRanges::new(vec![0..num_rows as u64]); + } - // Unset → centroids included (with one-time warning). - unsafe { - std::env::remove_var(LANCE_INCLUDE_VECTOR_CENTROIDS_ENV); - } - assert_eq!(maybe_centroids_for_stats(¢roids).unwrap(), expected); + let rows_per_range = 1.max(block_size / byte_width); + let num_bins = num_rows.div_ceil(rows_per_range); + let mut rng = SmallRng::seed_from_u64(0x1a6c_e5eed); - // Truthy values → centroids included. - for truthy in ["1", "true", "TRUE", "on", "yes", "y"] { - unsafe { - std::env::set_var(LANCE_INCLUDE_VECTOR_CENTROIDS_ENV, truthy); + let bins = if sample_size * 5 >= num_rows { + let mut bins = (0..num_bins).collect::>(); + for i in 0..num_bins { + let j = rng.random_range(i..num_bins); + bins.swap(i, j); + } + bins + } else { + let mut bins = Vec::with_capacity(sample_size.div_ceil(rows_per_range).saturating_add(1)); + let mut seen = HashSet::with_capacity(bins.capacity()); + while bins.len() * rows_per_range < sample_size { + let bin = rng.random_range(0..num_bins); + if seen.insert(bin) { + bins.push(bin); } - assert_eq!( - maybe_centroids_for_stats(¢roids).unwrap(), - expected, - "expected centroids to be included for {truthy:?}", - ); } + bins + }; - // Non-truthy values → centroids omitted. 
- for falsy in ["0", "false", "FALSE", "no", "off"] { - unsafe { - std::env::set_var(LANCE_INCLUDE_VECTOR_CENTROIDS_ENV, falsy); - } - assert_eq!( - maybe_centroids_for_stats(¢roids).unwrap(), - None, - "expected centroids to be omitted for {falsy:?}", - ); + let mut remaining = sample_size; + let mut ranges = Vec::new(); + for bin in bins { + if remaining == 0 { + break; + } + let bin_start = bin * rows_per_range; + let bin_end = ((bin + 1) * rows_per_range).min(num_rows); + let bin_len = bin_end - bin_start; + if bin_len == 0 { + continue; } - // Restore original value. - unsafe { - match original { - Some(v) => std::env::set_var(LANCE_INCLUDE_VECTOR_CENTROIDS_ENV, v), - None => std::env::remove_var(LANCE_INCLUDE_VECTOR_CENTROIDS_ENV), - } + let take = bin_len.min(remaining); + let offset = if take < bin_len { + rng.random_range(0..=bin_len - take) + } else { + 0 + }; + let start = bin_start + offset; + ranges.push(start as u64..(start + take) as u64); + remaining -= take; + } + + ranges.sort_unstable_by_key(|range| range.start); + let mut merged: Vec> = Vec::with_capacity(ranges.len()); + for range in ranges { + if range.is_empty() { + continue; + } + if let Some(last) = merged.last_mut() + && last.end >= range.start + { + last.end = last.end.max(range.end); + continue; } + merged.push(range); } + FixedIvfTrainingRanges::new(merged) +} - // Verifies that when centroids are omitted via the env var, the - // serialized stats JSON does not contain the `centroids` field at all - // (instead of an explicit null), since downstream code distinguishes - // missing from null. 
/// Upper bound applied to the coreset rate when the caller did not configure
/// one explicitly.
const MAX_DEFAULT_STREAMING_CORESET_RATE: usize = 64;

/// Default coreset rate: the smallest of the total sample rate, the
/// per-step streaming sample rate, and `MAX_DEFAULT_STREAMING_CORESET_RATE`.
fn default_streaming_coreset_rate(total_sample_rate: usize, streaming_sample_rate: usize) -> usize {
    total_sample_rate
        .min(streaming_sample_rate)
        .min(MAX_DEFAULT_STREAMING_CORESET_RATE)
}

/// Effective coreset rate.
///
/// Uses `configured_coreset_rate` when given, otherwise
/// `default_streaming_coreset_rate`. The value is then capped at
/// `total_sample_rate` and finally raised to at least 1, so the result is
/// never zero even when `total_sample_rate` is zero.
fn streaming_coreset_rate(
    total_sample_rate: usize,
    streaming_sample_rate: usize,
    configured_coreset_rate: Option<usize>,
) -> usize {
    configured_coreset_rate
        .unwrap_or_else(|| default_streaming_coreset_rate(total_sample_rate, streaming_sample_rate))
        .min(total_sample_rate)
        .max(1)
}

/// Number of local centroids to keep for one streaming step.
///
/// Without a decoupled coreset budget this is `num_partitions`, capped at
/// `step_sample_size`. With a decoupled budget the overall budget
/// `num_partitions * coreset_rate` (saturating) is divided, rounding up,
/// across `total_steps` (a step count of 0 is treated as 1); the share is
/// raised to at least `num_partitions` and then capped at `step_sample_size`.
fn streaming_local_coreset_k(
    num_partitions: usize,
    step_sample_size: usize,
    coreset_rate: usize,
    total_steps: usize,
    decoupled_coreset_budget: bool,
) -> usize {
    if !decoupled_coreset_budget {
        return num_partitions.min(step_sample_size);
    }
    num_partitions
        .saturating_mul(coreset_rate)
        .div_ceil(total_steps.max(1))
        .max(num_partitions)
        .min(step_sample_size)
}
+fn get_top_level_vector_column(batch: &RecordBatch, column: &str) -> Result { + batch.column_by_name(column).cloned().ok_or_else(|| { + Error::index(format!( + "Fixed streaming IVF sampling only supports top-level vector column '{}'", + column + )) + }) +} + +fn append_fsl_values( + values_buf: &mut MutableBuffer, + total_rows: &mut usize, + array: &ArrayRef, + byte_width: usize, +) -> Result<()> { + let fsl = array.as_fixed_size_list(); + let values = fsl.values(); + let values_data = values.to_data(); + let elem_size = byte_width / fsl.value_length() as usize; + let offset_bytes = values_data.offset() * elem_size; + let total_bytes = fsl.len() * byte_width; + let buf = &values_data.buffers()[0].as_slice()[offset_bytes..offset_bytes + total_bytes]; + values_buf.extend_from_slice(buf); + *total_rows += fsl.len(); + Ok(()) +} + +fn fsl_values_to_fixed_array( + vector_type: &DataType, + values_buf: MutableBuffer, + rows: usize, +) -> Result { + let DataType::FixedSizeList(field, dimension) = vector_type else { + return Err(Error::invalid_input(format!( + "expected FixedSizeList vector type, got {}", + vector_type + ))); + }; + let value_len = rows * *dimension as usize; + let values = arrow_array::make_array( + ArrayData::builder(field.data_type().clone()) + .len(value_len) + .add_buffer(values_buf.into()) + .build()?, + ); + Ok(FixedSizeListArray::try_new( + field.clone(), + *dimension, + values, + None, + )?) 
+} + +struct FixedIvfTrainingSampler<'a> { + dataset: &'a Dataset, + column: &'a str, + vector_type: DataType, + projection: Arc, + byte_width: usize, +} + +impl<'a> FixedIvfTrainingSampler<'a> { + fn try_new(dataset: &'a Dataset, column: &'a str) -> Result> { + let vector_field = dataset.schema().field(column).ok_or(Error::index(format!( + "Sample training data: column {} does not exist in schema", + column + )))?; + if vector_field.nullable + || !matches!(vector_field.data_type(), DataType::FixedSizeList(_, _)) + { + return Ok(None); + } + Ok(Some(Self { + dataset, + column, + vector_type: vector_field.data_type(), + projection: Arc::new(dataset.schema().project(&[column])?), + byte_width: vector_field + .data_type() + .byte_width_opt() + .unwrap_or(4 * 1024), + })) + } + + async fn sample_ranges( + &self, + ranges: &[Range], + metric_type: MetricType, + ) -> Result<(FixedSizeListArray, MetricType)> { + let rows = ranges.iter().map(range_len).sum::(); + let mut values_buf = MutableBuffer::with_capacity(rows * self.byte_width); + let mut total_rows = 0; + + let read_ranges = split_ranges_by_row_count(ranges, streaming_ivf_take_range_rows()); + let range_stream = stream::iter(read_ranges.into_iter().map(Ok)); + let batch_readahead = streaming_ivf_prefetch_depth(); + let mut batch_stream = self.dataset.take_scan( + Box::pin(range_stream), + self.projection.clone(), + batch_readahead, + ); + while let Some(batch) = batch_stream.try_next().await? 
{ + let array = get_top_level_vector_column(&batch, self.column)?; + append_fsl_values(&mut values_buf, &mut total_rows, &array, self.byte_width)?; + } + + let training_data = fsl_values_to_fixed_array(&self.vector_type, values_buf, total_rows)?; + let (training_data, mt) = if metric_type == MetricType::Cosine { + let training_data = normalize_fsl_owned(training_data)?; + (training_data, MetricType::L2) + } else { + (training_data, metric_type) + }; + Ok((filter_finite_training_data(training_data)?, mt)) + } +} + +type KMeansProgressCallback = Arc; + +struct KMeansStepOptions { + dimension: usize, + metric_type: MetricType, + num_partitions: usize, + sample_rate: usize, + max_iters: usize, + on_progress: KMeansProgressCallback, +} + +fn train_ivf_kmeans_step( + centroids: Option>, + data: &PrimitiveArray, + options: &KMeansStepOptions, +) -> Result +where + ::Native: Dot + L2 + Normalize, + PrimitiveArray: From>, +{ + let has_centroids = centroids.is_some(); + let mut kmeans_params = + KMeansParams::new(centroids, options.max_iters as u32, 1, options.metric_type) + .with_balance_factor(1.0) + .with_on_progress(options.on_progress.clone()); + if has_centroids { + // Incremental refinement already has the full centroid set. The + // hierarchical trainer bootstraps a smaller tree and is only suitable + // for the initial training pass. 
+ kmeans_params = kmeans_params.with_hierarchical_k(1); + } + lance_index::vector::kmeans::train_kmeans::( + data, + kmeans_params, + options.dimension, + options.num_partitions, + options.sample_rate, + ) +} + +fn train_ivf_kmeans_step_arrow_array_no_loss( + centroids: Option>, + data: &FixedSizeListArray, + metric_type: MetricType, + num_partitions: usize, + sample_rate: usize, + max_iters: usize, + on_progress: Arc, +) -> Result { + let dimension = data.value_length() as usize; + let values = data.values(); + let step_options = KMeansStepOptions { + dimension, + metric_type, + num_partitions, + sample_rate, + max_iters, + on_progress, + }; + let kmeans = match (values.data_type(), metric_type) { + (DataType::Float16, _) => train_ivf_kmeans_step::( + centroids, + values.as_primitive::(), + &step_options, + )?, + (DataType::Float32, _) => train_ivf_kmeans_step::( + centroids, + values.as_primitive::(), + &step_options, + )?, + (DataType::Float64, _) => train_ivf_kmeans_step::( + centroids, + values.as_primitive::(), + &step_options, + )?, + (DataType::Int8, DistanceType::L2) + | (DataType::Int8, DistanceType::Dot) + | (DataType::Int8, DistanceType::Cosine) => { + let data = data.convert_to_floating_point()?; + train_ivf_kmeans_step::( + centroids, + data.values().as_primitive::(), + &step_options, + )? 
+ } + (DataType::UInt8, DistanceType::Hamming) => train_ivf_kmeans_step::( + centroids, + values.as_primitive::(), + &step_options, + )?, + _ => Err(Error::index(format!( + "KMeans: can not train data type {} with distance type: {}", + values.data_type(), + metric_type + )))?, + }; + Ok(kmeans) +} + +fn accumulate_refine_assignments( + data: &FixedSizeListArray, + centroids: &FixedSizeListArray, + cluster_sums: &mut [f32], + cluster_weights: &mut [f64], +) -> Result { + let dimension = data.value_length() as usize; + let kmeans = KMeans::with_centroids( + centroids.values().clone(), + dimension, + DistanceType::L2, + f64::MAX, + ); + let (membership, distances) = kmeans.compute_membership_and_distances(data)?; + let data_values = data.values().as_primitive::().values(); + let mut loss = 0.0; + + for row_idx in 0..data.len() { + let (Some(cluster_id), Some(distance)) = (membership[row_idx], distances[row_idx]) else { + continue; + }; + let cluster_id = cluster_id as usize; + cluster_weights[cluster_id] += 1.0; + loss += distance as f64; + let vector = &data_values[row_idx * dimension..(row_idx + 1) * dimension]; + let sum = &mut cluster_sums[cluster_id * dimension..(cluster_id + 1) * dimension]; + for (sum, value) in sum.iter_mut().zip(vector) { + *sum += *value; + } + } + + Ok(loss) +} + +fn update_refined_centroids( + centroids: &FixedSizeListArray, + cluster_sums: &[f32], + cluster_weights: &[f64], +) -> Result { + let dimension = centroids.value_length() as usize; + let mut next = centroids + .values() + .as_primitive::() + .values() + .to_vec(); + for cluster_id in 0..centroids.len() { + let weight = cluster_weights[cluster_id]; + if weight <= 0.0 { + continue; + } + let centroid = &mut next[cluster_id * dimension..(cluster_id + 1) * dimension]; + let sum = &cluster_sums[cluster_id * dimension..(cluster_id + 1) * dimension]; + for (value, sum) in centroid.iter_mut().zip(sum) { + *value = *sum / weight as f32; + } + } + f32_fsl_from_values(next, dimension) +} + 
+async fn refine_streaming_f32_kmeans_with_sampler( + sampler: &FixedIvfTrainingSampler<'_>, + metric_type: MetricType, + streaming_sample_size: usize, + sample_ranges: &FixedIvfTrainingRanges, + initial_centroids: &FixedSizeListArray, + passes: usize, + on_progress: Arc, +) -> Result { + let dimension = initial_centroids.value_length() as usize; + let mut centroids = initial_centroids.clone(); + for pass in 1..=passes { + let mut cluster_sums = vec![0.0_f32; centroids.len() * dimension]; + let mut cluster_weights = vec![0.0_f64; centroids.len()]; + let mut loss = 0.0; + let mut row_offset = 0; + while row_offset < sample_ranges.num_rows() { + let ranges = sample_ranges.chunk(row_offset, streaming_sample_size.max(1)); + row_offset += ranges.iter().map(range_len).sum::(); + let (training_data, mt) = sampler.sample_ranges(&ranges, metric_type).await?; + let training_data = if training_data.value_type() == DataType::Float32 { + training_data + } else { + training_data.convert_to_floating_point()? 
+ }; + if mt != DistanceType::L2 { + return Err(Error::invalid_input(format!( + "streaming IVF refinement currently supports L2/Cosine training, got {}", + metric_type + ))); + } + loss += accumulate_refine_assignments( + &training_data, + ¢roids, + &mut cluster_sums, + &mut cluster_weights, + )?; + } + centroids = update_refined_centroids(¢roids, &cluster_sums, &cluster_weights)?; + on_progress(pass as u32, passes as u32); + info!( + "Streaming IVF raw-vector refinement pass {} / {} assigned {} vectors; pre-update loss={}", + pass, + passes, + cluster_weights.iter().sum::() as usize, + loss + ); + } + Ok(centroids) +} + +#[allow(clippy::too_many_arguments)] +async fn refine_streaming_f32_kmeans_with_resampling( + dataset: &Dataset, + column: &str, + metric_type: MetricType, + total_sample_rate: usize, + streaming_sample_rate: usize, + num_partitions: usize, + initial_centroids: &FixedSizeListArray, + fragment_ids: Option<&[u32]>, + passes: usize, + on_progress: Arc, +) -> Result { + let dimension = initial_centroids.value_length() as usize; + let mut centroids = initial_centroids.clone(); + for pass in 1..=passes { + let mut cluster_sums = vec![0.0_f32; centroids.len() * dimension]; + let mut cluster_weights = vec![0.0_f64; centroids.len()]; + let mut remaining_sample_rate = total_sample_rate; + let mut loss = 0.0; + while remaining_sample_rate > 0 { + let step_sample_rate = remaining_sample_rate.min(streaming_sample_rate); + let step_sample_size = num_partitions * step_sample_rate; + let (training_data, mt) = sample_ivf_training_chunk( + dataset, + column, + step_sample_size, + metric_type, + fragment_ids, + ) + .await?; + let training_data = if training_data.value_type() == DataType::Float32 { + training_data + } else { + training_data.convert_to_floating_point()? 
+ }; + if mt != DistanceType::L2 { + return Err(Error::invalid_input(format!( + "streaming IVF refinement currently supports L2/Cosine training, got {}", + metric_type + ))); + } + loss += accumulate_refine_assignments( + &training_data, + ¢roids, + &mut cluster_sums, + &mut cluster_weights, + )?; + remaining_sample_rate -= step_sample_rate; + } + centroids = update_refined_centroids(¢roids, &cluster_sums, &cluster_weights)?; + on_progress(pass as u32, passes as u32); + info!( + "Streaming IVF resampled raw-vector refinement pass {} / {} assigned {} vectors; pre-update loss={}", + pass, + passes, + cluster_weights.iter().sum::() as usize, + loss + ); + } + Ok(centroids) +} + +fn f32_fsl_from_values(values: Vec, dimension: usize) -> Result { + Ok(FixedSizeListArray::try_new_from_values( + Float32Array::from(values), + dimension as i32, + )?) +} + +struct WeightedCoreset { + values: Vec, + weights: Vec, + losses: Vec, +} + +impl WeightedCoreset { + fn new(dimension: usize, capacity: usize) -> Self { + Self { + values: Vec::with_capacity(capacity * dimension), + weights: Vec::with_capacity(capacity), + losses: Vec::with_capacity(capacity), + } + } + + fn len(&self) -> usize { + self.weights.len() + } + + fn push(&mut self, centroid: &[f32], weight: f64, loss: f64) { + if weight <= 0.0 { + return; + } + self.values.extend_from_slice(centroid); + self.weights.push(weight); + self.losses.push(loss); + } + + fn append(&mut self, other: Self) { + self.values.extend(other.values); + self.weights.extend(other.weights); + self.losses.extend(other.losses); + } + + fn into_fsl_parts(self, dimension: usize) -> Result<(FixedSizeListArray, Vec, Vec)> { + Ok(( + f32_fsl_from_values(self.values, dimension)?, + self.weights, + self.losses, + )) + } + + fn reduce_to_budget(&mut self, dimension: usize, budget: usize) { + if self.len() <= budget { + return; + } + let total_weight = self.weights.iter().sum::(); + if total_weight <= 0.0 { + *self = Self::new(dimension, budget); + return; 
+ } + + let mut weighted_sums = vec![0.0_f64; dimension]; + let mut weighted_square_sums = vec![0.0_f64; dimension]; + for (row_idx, vector) in self.values.chunks_exact(dimension).enumerate() { + let weight = self.weights[row_idx]; + for dim in 0..dimension { + let value = vector[dim] as f64; + weighted_sums[dim] += weight * value; + weighted_square_sums[dim] += weight * value * value; + } + } + let split_dim = (0..dimension) + .max_by(|left, right| { + let left_mean = weighted_sums[*left] / total_weight; + let right_mean = weighted_sums[*right] / total_weight; + let left_var = weighted_square_sums[*left] / total_weight - left_mean * left_mean; + let right_var = + weighted_square_sums[*right] / total_weight - right_mean * right_mean; + left_var + .partial_cmp(&right_var) + .unwrap_or(std::cmp::Ordering::Equal) + }) + .unwrap_or(0); + + let mut indices = (0..self.len()).collect::>(); + indices.sort_unstable_by(|left, right| { + self.values[left * dimension + split_dim] + .partial_cmp(&self.values[right * dimension + split_dim]) + .unwrap_or(std::cmp::Ordering::Equal) + .then_with(|| left.cmp(right)) + }); + + let mut reduced = Self::new(dimension, budget); + for group_idx in 0..budget { + let group_start = group_idx * indices.len() / budget; + let group_end = (group_idx + 1) * indices.len() / budget; + if group_start == group_end { + continue; + } + let mut weight_sum = 0.0; + let centroid_start = reduced.values.len(); + reduced.values.resize(centroid_start + dimension, 0.0); + { + let centroid = &mut reduced.values[centroid_start..centroid_start + dimension]; + for &idx in &indices[group_start..group_end] { + let weight = self.weights[idx]; + weight_sum += weight; + let vector = &self.values[idx * dimension..(idx + 1) * dimension]; + for (sum, value) in centroid.iter_mut().zip(vector) { + *sum += *value * weight as f32; + } + } + } + if weight_sum <= 0.0 { + reduced.values.truncate(centroid_start); + continue; + } + { + let centroid = &mut 
reduced.values[centroid_start..centroid_start + dimension]; + for value in centroid { + *value /= weight_sum as f32; + } + } + + let mut loss = 0.0; + let centroid = &reduced.values[centroid_start..centroid_start + dimension]; + for &idx in &indices[group_start..group_end] { + let vector = &self.values[idx * dimension..(idx + 1) * dimension]; + let dist = vector + .iter() + .zip(centroid) + .map(|(left, right)| { + let diff = left - right; + diff * diff + }) + .sum::() as f64; + loss += self.losses[idx] + self.weights[idx] * dist; + } + reduced.weights.push(weight_sum); + reduced.losses.push(loss); + } + *self = reduced; + } +} + +struct WeightedKMeansResult { + centroids: Vec, + membership: Vec>, + cluster_weights: Vec, + cluster_losses: Vec, + loss: f64, +} + +fn initialize_weighted_centroids( + data_values: &[f32], + dimension: usize, + k: usize, + n: usize, + weights: &[f64], +) -> Vec { + let mut rng = SmallRng::seed_from_u64(0x1f17_5eed); + let mut centroids = Vec::with_capacity(k * dimension); + let mut selected = vec![false; n]; + let total_weight = weights.iter().copied().sum::(); + let first = if total_weight > 0.0 { + let mut threshold = rng.random::() * total_weight; + let mut row_idx = 0; + for (idx, weight) in weights.iter().enumerate() { + threshold -= *weight; + if threshold <= 0.0 { + row_idx = idx; + break; + } + } + row_idx + } else { + 0 + }; + selected[first] = true; + centroids.extend_from_slice(&data_values[first * dimension..(first + 1) * dimension]); + + let mut min_distances = vec![f64::MAX; n]; + while centroids.len() / dimension < k { + let last_centroid = ¢roids[centroids.len() - dimension..centroids.len()]; + for row_idx in 0..n { + if selected[row_idx] { + min_distances[row_idx] = 0.0; + continue; + } + let vector = &data_values[row_idx * dimension..(row_idx + 1) * dimension]; + let distance = vector + .iter() + .zip(last_centroid) + .map(|(left, right)| { + let diff = left - right; + diff * diff + }) + .sum::() as f64; + 
min_distances[row_idx] = min_distances[row_idx].min(distance); + } + + let weighted_distance_sum = min_distances + .iter() + .zip(weights) + .map(|(distance, weight)| distance * weight) + .sum::(); + let next = if weighted_distance_sum > 0.0 { + let mut threshold = rng.random::() * weighted_distance_sum; + let mut row_idx = None; + for idx in 0..n { + if selected[idx] { + continue; + } + threshold -= min_distances[idx] * weights[idx]; + if threshold <= 0.0 { + row_idx = Some(idx); + break; + } + } + row_idx + } else { + None + } + .or_else(|| (0..n).find(|idx| !selected[*idx])); + + let Some(next) = next else { + break; + }; + selected[next] = true; + centroids.extend_from_slice(&data_values[next * dimension..(next + 1) * dimension]); + } + + while centroids.len() / dimension < k { + let row_idx = (centroids.len() / dimension) * n / k; + centroids.extend_from_slice(&data_values[row_idx * dimension..(row_idx + 1) * dimension]); + } + centroids +} + +fn assign_weighted_f32_points( + data: &FixedSizeListArray, + weights: &[f64], + base_losses: &[f64], + centroid_values: &[f32], + metric_type: MetricType, +) -> Result { + let dimension = data.value_length() as usize; + let k = centroid_values.len() / dimension; + let centroids = Arc::new(Float32Array::from(centroid_values.to_vec())) as ArrayRef; + let kmeans = KMeans::with_centroids(centroids, dimension, metric_type, f64::MAX); + let (membership, distances) = kmeans.compute_membership_and_distances(data)?; + let data_values = data.values().as_primitive::().values(); + let mut centroid_sums = vec![0.0_f32; k * dimension]; + let mut cluster_weights = vec![0.0; k]; + let mut cluster_losses = vec![0.0; k]; + + for row_idx in 0..data.len() { + let Some(cluster_id) = membership[row_idx] else { + continue; + }; + let Some(distance) = distances[row_idx] else { + continue; + }; + let cluster_id = cluster_id as usize; + let weight = weights[row_idx]; + cluster_weights[cluster_id] += weight; + cluster_losses[cluster_id] += 
base_losses[row_idx] + weight * distance as f64; + let vector = &data_values[row_idx * dimension..(row_idx + 1) * dimension]; + let centroid_sum = &mut centroid_sums[cluster_id * dimension..(cluster_id + 1) * dimension]; + for (sum, value) in centroid_sum.iter_mut().zip(vector) { + *sum += *value * weight as f32; + } + } + + let mut next_centroids = vec![0.0_f32; k * dimension]; + for cluster_id in 0..k { + let next_centroid = + &mut next_centroids[cluster_id * dimension..(cluster_id + 1) * dimension]; + if cluster_weights[cluster_id] > 0.0 { + let centroid_sum = ¢roid_sums[cluster_id * dimension..(cluster_id + 1) * dimension]; + for (value, sum) in next_centroid.iter_mut().zip(centroid_sum) { + *value = *sum / cluster_weights[cluster_id] as f32; + } + } else { + next_centroid.copy_from_slice( + ¢roid_values[cluster_id * dimension..(cluster_id + 1) * dimension], + ); + } + } + + let loss = cluster_losses.iter().sum(); + Ok(WeightedKMeansResult { + centroids: next_centroids, + membership, + cluster_weights, + cluster_losses, + loss, + }) +} + +fn train_weighted_f32_kmeans( + data: &FixedSizeListArray, + weights: &[f64], + base_losses: &[f64], + k: usize, + metric_type: MetricType, + max_iters: usize, + on_progress: Arc, +) -> Result { + if data.len() < k { + return Err(Error::invalid_input(format!( + "weighted kmeans requires at least {k} coreset rows, got {}", + data.len() + ))); + } + if weights.len() != data.len() || base_losses.len() != data.len() { + return Err(Error::invalid_input(format!( + "weighted kmeans input lengths do not match: data={}, weights={}, losses={}", + data.len(), + weights.len(), + base_losses.len() + ))); + } + + let dimension = data.value_length() as usize; + let data_values = data.values().as_primitive::().values(); + let mut centroids = + initialize_weighted_centroids(data_values, dimension, k, data.len(), weights); + let mut previous_loss = f64::MAX; + let max_iters = max_iters.max(1); + for iter in 1..=max_iters { + on_progress(iter as 
u32, max_iters as u32); + let mut result = + assign_weighted_f32_points(data, weights, base_losses, ¢roids, metric_type)?; + let converged = (previous_loss - result.loss).abs() < 1e-4 * result.loss.max(1.0); + previous_loss = result.loss; + if converged || iter == max_iters { + return Ok(result); + } + centroids = std::mem::take(&mut result.centroids); + } + unreachable!("weighted kmeans runs at least one iteration") +} + +fn refine_weighted_f32_kmeans( + data: &FixedSizeListArray, + weights: &[f64], + base_losses: &[f64], + initial_centroids: &FixedSizeListArray, + metric_type: MetricType, + max_iters: usize, + on_progress: Arc, +) -> Result { + let mut centroids = initial_centroids + .values() + .as_primitive::() + .values() + .to_vec(); + let mut previous_loss = f64::MAX; + let max_iters = max_iters.max(1); + for iter in 1..=max_iters { + on_progress(iter as u32, max_iters as u32); + let mut result = + assign_weighted_f32_points(data, weights, base_losses, ¢roids, metric_type)?; + let converged = (previous_loss - result.loss).abs() < 1e-4 * result.loss.max(1.0); + previous_loss = result.loss; + if converged || iter == max_iters { + return Ok(result); + } + centroids = std::mem::take(&mut result.centroids); + } + unreachable!("weighted kmeans refinement runs at least one iteration") +} + +fn append_local_coreset( + coreset: &mut WeightedCoreset, + data: &FixedSizeListArray, + metric_type: MetricType, + local_k: usize, + max_iters: usize, + on_progress: Arc, +) -> Result<()> { + let dimension = data.value_length() as usize; + let sample_rate = data.len().div_ceil(local_k).max(1); + let kmeans = train_ivf_kmeans_step_arrow_array_no_loss( + None, + data, + metric_type, + local_k, + sample_rate, + max_iters, + on_progress, + )?; + let centroids = FixedSizeListArray::try_new_from_values(kmeans.centroids, dimension as i32)?; + let kmeans = + KMeans::with_centroids(centroids.values().clone(), dimension, metric_type, f64::MAX); + let (membership, distances) = 
kmeans.compute_membership_and_distances(data)?; + let mut weights = vec![0.0; centroids.len()]; + let mut losses = vec![0.0; centroids.len()]; + for (member, distance) in membership.into_iter().zip(distances) { + let (Some(member), Some(distance)) = (member, distance) else { + continue; + }; + weights[member as usize] += 1.0; + losses[member as usize] += distance as f64; + } + + let centroid_values = centroids.values().as_primitive::().values(); + for centroid_idx in 0..centroids.len() { + coreset.push( + ¢roid_values[centroid_idx * dimension..(centroid_idx + 1) * dimension], + weights[centroid_idx], + losses[centroid_idx], + ); + } + Ok(()) +} + +#[derive(Clone, Debug)] +struct WeightedCluster { + id: usize, + indices: Vec, + centroid: Vec, + weight: f64, + loss: f64, + finalized: bool, +} + +impl Eq for WeightedCluster {} + +impl PartialEq for WeightedCluster { + fn eq(&self, other: &Self) -> bool { + self.loss == other.loss && self.weight == other.weight + } +} + +impl Ord for WeightedCluster { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + match (self.finalized, other.finalized) { + (false, true) => std::cmp::Ordering::Greater, + (true, false) => std::cmp::Ordering::Less, + _ => self + .loss + .partial_cmp(&other.loss) + .unwrap_or(std::cmp::Ordering::Equal) + .then_with(|| { + self.weight + .partial_cmp(&other.weight) + .unwrap_or(std::cmp::Ordering::Equal) + }), + } + } +} + +impl PartialOrd for WeightedCluster { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +fn weighted_subset( + data_values: &[f32], + weights: &[f64], + losses: &[f64], + indices: &[usize], + dimension: usize, +) -> Result<(FixedSizeListArray, Vec, Vec)> { + let mut values = Vec::with_capacity(indices.len() * dimension); + let mut subset_weights = Vec::with_capacity(indices.len()); + let mut subset_losses = Vec::with_capacity(indices.len()); + for &idx in indices { + values.extend_from_slice(&data_values[idx * dimension..(idx + 1) * dimension]); 
+ subset_weights.push(weights[idx]); + subset_losses.push(losses[idx]); + } + Ok(( + f32_fsl_from_values(values, dimension)?, + subset_weights, + subset_losses, + )) +} + +fn train_weighted_hierarchical_f32_kmeans( + data: &FixedSizeListArray, + weights: &[f64], + losses: &[f64], + dimension: usize, + target_k: usize, + metric_type: MetricType, + max_iters: usize, + on_progress: Arc, +) -> Result { + if data.len() == 0 { + return Err(Error::index("empty weighted coreset")); + } + if weights.len() != data.len() || losses.len() != data.len() { + return Err(Error::invalid_input(format!( + "weighted hierarchical kmeans input lengths do not match: data={}, weights={}, losses={}", + data.len(), + weights.len(), + losses.len() + ))); + } + + let initial_k = 16_usize.min(target_k).min(data.len()).max(1); + let initial = train_weighted_f32_kmeans( + data, + weights, + losses, + initial_k, + metric_type, + max_iters, + on_progress.clone(), + )?; + + let centroids = initial.centroids; + let mut heap = std::collections::BinaryHeap::new(); + let mut next_cluster_id = 0; + for cluster_id in 0..initial_k { + let mut indices = Vec::new(); + for (row_idx, member) in initial.membership.iter().enumerate() { + if member.is_some_and(|member| member as usize == cluster_id) { + indices.push(row_idx); + } + } + if !indices.is_empty() { + heap.push(WeightedCluster { + id: next_cluster_id, + indices, + centroid: centroids[cluster_id * dimension..(cluster_id + 1) * dimension].to_vec(), + weight: initial.cluster_weights[cluster_id], + loss: initial.cluster_losses[cluster_id], + finalized: false, + }); + next_cluster_id += 1; + } + } + + let data_values = data.values().as_primitive::().values(); + while heap.len() < target_k { + let mut cluster = heap + .pop() + .ok_or_else(|| Error::index("No weighted cluster can be further split"))?; + if cluster.finalized || cluster.indices.len() <= 1 { + cluster.finalized = true; + heap.push(cluster); + break; + } + + let remaining_k = target_k - 
heap.len(); + let cluster_k = if cluster.indices.len() <= 16 { + 2.min(remaining_k).min(cluster.indices.len()) + } else { + (cluster.indices.len() / 16).min(remaining_k).clamp(2, 16) + }; + let (sub_data, sub_weights, sub_losses) = + weighted_subset(data_values, weights, losses, &cluster.indices, dimension)?; + let split = train_weighted_f32_kmeans( + &sub_data, + &sub_weights, + &sub_losses, + cluster_k, + metric_type, + max_iters.min(20), + on_progress.clone(), + )?; + + let mut assignments = vec![Vec::new(); cluster_k]; + let mut first_member = None; + let mut all_same = true; + for (local_idx, member) in split.membership.iter().enumerate() { + let Some(member) = member else { + continue; + }; + if first_member.is_some_and(|first| first != *member) { + all_same = false; + } else if first_member.is_none() { + first_member = Some(*member); + } + assignments[*member as usize].push(cluster.indices[local_idx]); + } + if all_same { + cluster.finalized = true; + heap.push(cluster); + continue; + } + + for (child_id, child_indices) in assignments.into_iter().enumerate() { + if child_indices.is_empty() { + continue; + } + heap.push(WeightedCluster { + id: next_cluster_id, + indices: child_indices, + centroid: split.centroids[child_id * dimension..(child_id + 1) * dimension] + .to_vec(), + weight: split.cluster_weights[child_id], + loss: split.cluster_losses[child_id], + finalized: false, + }); + next_cluster_id += 1; + } + } + + let mut clusters = heap.into_vec(); + clusters.sort_by_key(|cluster| cluster.id); + while clusters.len() < target_k { + let duplicate = clusters + .iter() + .max_by(|left, right| { + left.weight + .partial_cmp(&right.weight) + .unwrap_or(std::cmp::Ordering::Equal) + }) + .cloned() + .ok_or_else(|| Error::index("No weighted clusters were trained"))?; + clusters.push(WeightedCluster { + id: next_cluster_id, + ..duplicate + }); + next_cluster_id += 1; + } + clusters.truncate(target_k); + + let mut values = Vec::with_capacity(target_k * dimension); + 
for cluster in clusters { + values.extend_from_slice(&cluster.centroid); + } + f32_fsl_from_values(values, dimension) +} + +async fn train_streaming_coreset_ivf_model( + dataset: &Dataset, + column: &str, + dimension: usize, + metric_type: MetricType, + params: &IvfBuildParams, + fragment_ids: Option<&[u32]>, + progress: std::sync::Arc, +) -> Result { + let num_partitions = params.num_partitions.unwrap_or(32); + let streaming_sample_rate = params.streaming_sample_rate.unwrap(); + let total_sample_rate = params.sample_rate; + let mut remaining_sample_rate = total_sample_rate; + let mut max_training_vectors = 0; + let mut total_training_vectors = 0; + let fixed_sampler = if fragment_ids.is_none() { + FixedIvfTrainingSampler::try_new(dataset, column)? + } else { + None + }; + let fixed_sample_ranges = if let Some(sampler) = &fixed_sampler { + let num_rows = dataset.count_rows(None).await?; + let sample_size = num_rows.min(num_partitions * total_sample_rate); + Some(generate_fixed_training_ranges( + num_rows, + sample_size, + dataset.object_store.as_ref().block_size(), + sampler.byte_width, + )) + } else { + None + }; + let mut sample_offset = 0; + + let (progress_tx, mut progress_rx) = mpsc::unbounded_channel::(); + let progress_worker = { + let progress = progress.clone(); + tokio::spawn(async move { + while let Some(iter) = progress_rx.recv().await { + if let Err(e) = progress.stage_progress("train_ivf", iter).await { + warn!("Progress callback error during train_ivf: {e}"); + } + } + }) + }; + + let on_progress: Arc = { + let progress_tx = progress_tx.clone(); + let cumulative_iters = std::sync::atomic::AtomicU64::new(0); + let progress_interval = streaming_ivf_progress_interval(); + Arc::new(move |_iter: u32, _max_iters: u32| { + let total = cumulative_iters.fetch_add(1, std::sync::atomic::Ordering::Relaxed) + 1; + if should_report_streaming_ivf_progress(total, progress_interval) { + let _ = progress_tx.send(total); + } + }) + }; + + let coreset_rate = 
streaming_coreset_rate( + total_sample_rate, + streaming_sample_rate, + params.streaming_coreset_rate, + ); + let coreset_budget = num_partitions + .saturating_mul(coreset_rate) + .max(num_partitions); + let total_steps = total_sample_rate.div_ceil(streaming_sample_rate); + let decoupled_coreset_budget = params.streaming_coreset_rate.is_some(); + let mut coreset = WeightedCoreset::new(dimension, coreset_budget.min(num_partitions * 16)); + let mut step = 0; + while remaining_sample_rate > 0 { + let step_sample_rate = remaining_sample_rate.min(streaming_sample_rate); + let step_sample_size = num_partitions * step_sample_rate; + step += 1; + info!( + "Streaming coreset IVF training: step {}, sample_rate={}, sample_size={}", + step, step_sample_rate, step_sample_size + ); + + let (training_data, mt) = if let (Some(sample_ranges), Some(sampler)) = + (&fixed_sample_ranges, &fixed_sampler) + { + let ranges = sample_ranges.chunk(sample_offset, step_sample_size); + sample_offset += ranges.iter().map(range_len).sum::(); + sampler.sample_ranges(&ranges, metric_type).await? + } else { + sample_ivf_training_chunk(dataset, column, step_sample_size, metric_type, fragment_ids) + .await? + }; + let training_data = if training_data.value_type() == DataType::Float32 { + training_data + } else { + training_data.convert_to_floating_point()? + }; + if mt != DistanceType::L2 { + return Err(Error::invalid_input(format!( + "streaming coreset IVF currently supports L2/Cosine training, got {}", + metric_type + ))); + } + if training_data.len() < num_partitions { + return Err(Error::index(format!( + "Not enough training vectors for streaming coreset IVF. 
Requires at least {} rows but sampled {} rows", + num_partitions, + training_data.len() + ))); + } + + max_training_vectors = max_training_vectors.max(training_data.len()); + total_training_vectors += training_data.len(); + let local_k = streaming_local_coreset_k( + num_partitions, + training_data.len(), + coreset_rate, + total_steps, + decoupled_coreset_budget, + ); + let mut chunk_coreset = WeightedCoreset::new(dimension, local_k); + append_local_coreset( + &mut chunk_coreset, + &training_data, + mt, + local_k, + params.max_iters, + on_progress.clone(), + )?; + coreset.append(chunk_coreset); + coreset.reduce_to_budget(dimension, coreset_budget); + info!( + "Streaming coreset IVF step {} compressed {} vectors into {} weighted centroids", + step, + total_training_vectors, + coreset.len() + ); + remaining_sample_rate -= step_sample_rate; + } + + let coreset_len = coreset.len(); + let (coreset_data, coreset_weights, coreset_losses) = coreset.into_fsl_parts(dimension)?; + let mut centroids = train_weighted_hierarchical_f32_kmeans( + &coreset_data, + &coreset_weights, + &coreset_losses, + dimension, + num_partitions, + DistanceType::L2, + params.max_iters, + on_progress.clone(), + )?; + let refine_iters = 3; + if refine_iters > 0 { + let refined = refine_weighted_f32_kmeans( + &coreset_data, + &coreset_weights, + &coreset_losses, + ¢roids, + DistanceType::L2, + refine_iters, + on_progress.clone(), + )?; + centroids = f32_fsl_from_values(refined.centroids, dimension)?; + } + if params.streaming_refine_passes > 0 { + info!( + "Running {} streaming raw-vector refinement pass(es)", + params.streaming_refine_passes + ); + centroids = + if let (Some(sample_ranges), Some(sampler)) = (&fixed_sample_ranges, &fixed_sampler) { + refine_streaming_f32_kmeans_with_sampler( + sampler, + metric_type, + num_partitions * streaming_sample_rate, + sample_ranges, + ¢roids, + params.streaming_refine_passes, + on_progress.clone(), + ) + .await? 
+ } else { + refine_streaming_f32_kmeans_with_resampling( + dataset, + column, + metric_type, + total_sample_rate, + streaming_sample_rate, + num_partitions, + ¢roids, + fragment_ids, + params.streaming_refine_passes, + on_progress.clone(), + ) + .await? + }; + } + + drop(progress_tx); + drop(on_progress); + if let Err(e) = progress_worker.await { + warn!("Progress worker join error during train_ivf: {e}"); + } + + info!( + "Streaming coreset IVF sampled {} vectors total; max in-memory training vectors per step: {}; coreset vectors: {}", + total_training_vectors, max_training_vectors, coreset_len + ); + + Ok(IvfModel::new(centroids, None)) +} + +async fn train_streaming_ivf_model( + dataset: &Dataset, + column: &str, + dimension: usize, + metric_type: MetricType, + params: &IvfBuildParams, + fragment_ids: Option<&[u32]>, + progress: std::sync::Arc, +) -> Result { + let num_partitions = params.num_partitions.unwrap_or(32); + if num_partitions > 256 { + return train_streaming_coreset_ivf_model( + dataset, + column, + dimension, + metric_type, + params, + fragment_ids, + progress, + ) + .await; + } + let streaming_sample_rate = params.streaming_sample_rate.unwrap(); + let total_sample_rate = params.sample_rate; + let mut remaining_sample_rate = total_sample_rate; + let mut centroids = params.centroids.clone(); + let mut max_training_vectors = 0; + let mut total_training_vectors = 0; + + let (progress_tx, mut progress_rx) = mpsc::unbounded_channel::(); + let progress_worker = { + let progress = progress.clone(); + tokio::spawn(async move { + while let Some(iter) = progress_rx.recv().await { + if let Err(e) = progress.stage_progress("train_ivf", iter).await { + warn!("Progress callback error during train_ivf: {e}"); + } + } + }) + }; + + let on_progress: Arc = { + let progress_tx = progress_tx.clone(); + let cumulative_iters = std::sync::atomic::AtomicU64::new(0); + Arc::new(move |_iter: u32, _max_iters: u32| { + let total = cumulative_iters.fetch_add(1, 
std::sync::atomic::Ordering::Relaxed) + 1; + let _ = progress_tx.send(total); + }) + }; + + let mut step = 0; + while remaining_sample_rate > 0 { + let step_sample_rate = remaining_sample_rate.min(streaming_sample_rate); + let step_sample_size = num_partitions * step_sample_rate; + step += 1; + info!( + "Streaming IVF training: step {}, sample_rate={}, sample_size={}", + step, step_sample_rate, step_sample_size + ); + + let (training_data, mt) = + sample_ivf_training_chunk(dataset, column, step_sample_size, metric_type, fragment_ids) + .await?; + if training_data.len() < num_partitions { + return Err(Error::index(format!( + "Not enough training vectors for streaming IVF. Requires at least {} rows but sampled {} rows", + num_partitions, + training_data.len() + ))); + } + + max_training_vectors = max_training_vectors.max(training_data.len()); + total_training_vectors += training_data.len(); + if params.sample_rate >= 1024 && training_data.value_type() == DataType::Float16 { + warn!( + "Large sample_rate ({} >= 1024) for float16 vectors is possible to result in all zeros cluster centroid", + params.sample_rate + ); + } + + let kmeans = train_ivf_kmeans_step_arrow_array_no_loss( + centroids.clone(), + &training_data, + mt, + num_partitions, + step_sample_rate, + params.max_iters, + on_progress.clone(), + )?; + let trained_centroids = Arc::new(FixedSizeListArray::try_new_from_values( + kmeans.centroids, + dimension as i32, + )?); + centroids = Some(trained_centroids); + + remaining_sample_rate -= step_sample_rate; + } + + drop(progress_tx); + drop(on_progress); + if let Err(e) = progress_worker.await { + warn!("Progress worker join error during train_ivf: {e}"); + } + + info!( + "Streaming IVF training sampled {} vectors total; max in-memory training vectors per step: {}", + total_training_vectors, max_training_vectors + ); + + let centroids = centroids.ok_or_else(|| Error::index("No IVF centroids trained"))?; + Ok(IvfModel::new((*centroids).clone(), None)) +} + +/// 
Train IVF partitions using kmeans. +async fn train_ivf_model( + centroids: Option>, + data: &FixedSizeListArray, + distance_type: DistanceType, + params: &IvfBuildParams, + progress: std::sync::Arc, +) -> Result { + assert!( + distance_type != DistanceType::Cosine, + "Cosine metric should be done by normalized L2 distance", + ); + let values = data.values(); + let dim = data.value_length() as usize; + match (values.data_type(), distance_type) { + (DataType::Float16, _) => { + do_train_ivf_model::( + centroids, + values.as_primitive::(), + dim, + distance_type, + params, + progress.clone(), + ) + .await + } + (DataType::Float32, _) => { + do_train_ivf_model::( + centroids, + values.as_primitive::(), + dim, + distance_type, + params, + progress.clone(), + ) + .await + } + (DataType::Float64, _) => { + do_train_ivf_model::( + centroids, + values.as_primitive::(), + dim, + distance_type, + params, + progress.clone(), + ) + .await + } + (DataType::Int8, DistanceType::L2) + | (DataType::Int8, DistanceType::Dot) + | (DataType::Int8, DistanceType::Cosine) => { + do_train_ivf_model::( + centroids, + data.convert_to_floating_point()? 
+ .values() + .as_primitive::(), + dim, + distance_type, + params, + progress.clone(), + ) + .await + } + (DataType::UInt8, DistanceType::Hamming) => { + do_train_ivf_model::( + centroids, + values.as_primitive::(), + dim, + distance_type, + params, + progress.clone(), + ) + .await + } + _ => Err(Error::index(format!( + "Unsupported data type {} with distance type {}", + values.data_type(), + distance_type + ))), + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use std::collections::HashSet; + use std::iter::repeat_n; + use std::ops::Range; + + use arrow_array::types::UInt64Type; + use arrow_array::{ + FixedSizeListArray, Float16Array, Float32Array, RecordBatch, RecordBatchIterator, + RecordBatchReader, UInt64Array, make_array, + }; + use arrow_buffer::{BooleanBuffer, NullBuffer}; + use arrow_schema::{DataType, Field, Schema}; + use half::f16; + use itertools::Itertools; + use lance_core::ROW_ID; + use lance_core::utils::address::RowAddress; + use lance_core::utils::tempfile::TempStrDir; + use lance_datagen::{ArrayGeneratorExt, BatchCount, Dimension, RowCount, array, gen_batch}; + use lance_index::VECTOR_INDEX_VERSION; + use lance_index::metrics::NoOpMetricsCollector; + use lance_index::vector::sq::builder::SQBuildParams; + use lance_linalg::distance::l2_distance_batch; + use lance_testing::datagen::{ + generate_random_array, generate_random_array_with_range, generate_random_array_with_seed, + generate_scaled_random_array, sample_without_replacement, + }; + use rand::{rng, seq::SliceRandom}; + use rstest::rstest; + + use crate::dataset::{InsertBuilder, WriteMode, WriteParams}; + use crate::index::prefilter::DatasetPreFilter; + use crate::index::vector::IndexFileVersion; + use crate::index::vector_index_details_default; + use crate::index::{DatasetIndexExt, DatasetIndexInternalExt, vector::VectorIndexParams}; + use crate::utils::test::copy_test_data_to_tmp; + + const DIM: usize = 32; + + async fn compute_test_ivf_loss(dataset: &Dataset, column: &str, ivf: 
&IvfModel) -> f64 { + let centroids = ivf + .centroids_array() + .expect("test IVF model should include centroids"); + let mut scanner = dataset.scan(); + scanner.project(&[column]).unwrap(); + let batch = scanner.try_into_batch().await.unwrap(); + let data = batch + .column_by_name(column) + .expect("test vector column should exist") + .as_fixed_size_list() + .clone(); + let kmeans = KMeans::with_centroids( + centroids.values().clone(), + centroids.value_length() as usize, + DistanceType::L2, + f64::MAX, + ); + kmeans.compute_loss(&data).unwrap() + } + + // Verifies LANCE_INCLUDE_VECTOR_CENTROIDS env var is honored by + // maybe_centroids_for_stats. The env var is process-global, so this test + // is serialized against any other test that touches the same key. + #[test] + #[serial_test::serial(LANCE_INCLUDE_VECTOR_CENTROIDS)] + fn test_maybe_centroids_for_stats_env_var() { + let centroids = Float32Array::from(vec![1.0_f32, 2.0, 3.0, 4.0]); + let centroids = FixedSizeListArray::try_new_from_values(centroids, 2).unwrap(); + let expected = Some(vec![vec![1.0, 2.0], vec![3.0, 4.0]]); + + // Save the original value so we can restore it afterwards. + let original = std::env::var(LANCE_INCLUDE_VECTOR_CENTROIDS_ENV).ok(); + + // Unset → centroids included (with one-time warning). + unsafe { + std::env::remove_var(LANCE_INCLUDE_VECTOR_CENTROIDS_ENV); + } + assert_eq!(maybe_centroids_for_stats(&centroids).unwrap(), expected); + + // Truthy values → centroids included. + for truthy in ["1", "true", "TRUE", "on", "yes", "y"] { + unsafe { + std::env::set_var(LANCE_INCLUDE_VECTOR_CENTROIDS_ENV, truthy); + } + assert_eq!( + maybe_centroids_for_stats(&centroids).unwrap(), + expected, + "expected centroids to be included for {truthy:?}", + ); + } + + // Non-truthy values → centroids omitted. 
+ for falsy in ["0", "false", "FALSE", "no", "off"] { + unsafe { + std::env::set_var(LANCE_INCLUDE_VECTOR_CENTROIDS_ENV, falsy); + } + assert_eq!( + maybe_centroids_for_stats(&centroids).unwrap(), + None, + "expected centroids to be omitted for {falsy:?}", + ); + } + + // Restore original value. + unsafe { + match original { + Some(v) => std::env::set_var(LANCE_INCLUDE_VECTOR_CENTROIDS_ENV, v), + None => std::env::remove_var(LANCE_INCLUDE_VECTOR_CENTROIDS_ENV), + } + } + } + + // Verifies that when centroids are omitted via the env var, the + // serialized stats JSON does not contain the `centroids` field at all + // (instead of an explicit null), since downstream code distinguishes + // missing from null. + #[test] + #[serial_test::serial(LANCE_INCLUDE_VECTOR_CENTROIDS)] + fn test_stats_centroids_omitted_when_disabled() { + let original = std::env::var(LANCE_INCLUDE_VECTOR_CENTROIDS_ENV).ok(); + + unsafe { + std::env::set_var(LANCE_INCLUDE_VECTOR_CENTROIDS_ENV, "false"); + } + let stats = IvfIndexStatistics { + index_type: "IVF_PQ".to_string(), + uuid: "uuid".to_string(), + uri: "uri".to_string(), + metric_type: "l2".to_string(), + num_partitions: 0, + sub_index: serde_json::Value::Null, + partitions: vec![], + centroids: None, + loss: None, + index_file_version: IndexFileVersion::V3, + }; + let json = serde_json::to_value(&stats).unwrap(); + assert!(json.get("centroids").is_none()); + + unsafe { + match original { + Some(v) => std::env::set_var(LANCE_INCLUDE_VECTOR_CENTROIDS_ENV, v), + None => std::env::remove_var(LANCE_INCLUDE_VECTOR_CENTROIDS_ENV), + } + } + } + + /// This goal of this function is to generate data that behaves in a very deterministic way so that + /// we can evaluate the correctness of an IVF_PQ implementation. Currently it is restricted to the + /// L2 distance metric. /// /// First, we generate a set of centroids. These are generated randomly but we ensure that is /// sufficient distance between each of the centroids. 
@@ -3746,6 +5388,180 @@ mod tests { } } + #[tokio::test] + async fn test_build_ivf_model_streaming_training() { + let test_dir = TempStrDir::default(); + let uri = format!("{}/ds", test_dir.as_str()); + let reader = gen_batch() + .col("id", array::step::()) + .col("vector", array::rand_vec::(32.into())) + .into_reader_rows(RowCount::from(512), BatchCount::from(2)); + let dataset = Dataset::write(reader, &uri, None).await.unwrap(); + + let mut params = IvfBuildParams::new(8); + params.sample_rate = 16; + params.streaming_sample_rate = Some(4); + params.streaming_refine_passes = 1; + params.max_iters = 2; + + let ivf_model = build_ivf_model( + &dataset, + "vector", + 32, + MetricType::L2, + &params, + None, + lance_index::progress::noop_progress(), + ) + .await + .unwrap(); + + assert_eq!(ivf_model.num_partitions(), 8); + assert_eq!(ivf_model.dimension(), 32); + assert!(ivf_model.loss().is_none()); + assert!( + compute_test_ivf_loss(&dataset, "vector", &ivf_model) + .await + .is_finite() + ); + } + + #[tokio::test] + async fn test_build_ivf_model_streaming_training_large_partitions() { + let test_dir = TempStrDir::default(); + let uri = format!("{}/ds", test_dir.as_str()); + let reader = gen_batch() + .col("id", array::step::()) + .col("vector", array::rand_vec::(8.into())) + .into_reader_rows(RowCount::from(640), BatchCount::from(2)); + let dataset = Dataset::write(reader, &uri, None).await.unwrap(); + + let mut params = IvfBuildParams::new(320); + params.sample_rate = 2; + params.streaming_sample_rate = Some(1); + params.max_iters = 1; + + let ivf_model = build_ivf_model( + &dataset, + "vector", + 8, + MetricType::L2, + &params, + None, + lance_index::progress::noop_progress(), + ) + .await + .unwrap(); + + assert_eq!(ivf_model.num_partitions(), 320); + assert_eq!(ivf_model.dimension(), 8); + assert!(ivf_model.loss().is_none()); + assert!( + compute_test_ivf_loss(&dataset, "vector", &ivf_model) + .await + .is_finite() + ); + } + + #[test] + fn 
test_fixed_training_ranges_are_sorted_and_bounded() { + let ranges = generate_fixed_training_ranges(10_000, 1_234, 1_024, 16); + assert_eq!(ranges.num_rows(), 1_234); + assert!(ranges.ranges.iter().all(|range| { + range.start < range.end && range.end <= 10_000 && range_len(range) <= 1_234 + })); + assert!( + ranges + .ranges + .windows(2) + .all(|pair| pair[0].end < pair[1].start) + ); + + let all_rows = generate_fixed_training_ranges(128, 256, 1_024, 16); + assert_eq!(all_rows.ranges, vec![0..128]); + assert_eq!(all_rows.num_rows(), 128); + } + + #[test] + fn test_fixed_training_ranges_chunk_splits_ranges() { + let ranges = FixedIvfTrainingRanges::new(vec![10..20, 30..45]); + assert_eq!(ranges.num_rows(), 25); + assert_eq!(ranges.chunk(0, 5), vec![10..15]); + assert_eq!(ranges.chunk(5, 12), vec![15..20, 30..37]); + assert_eq!(ranges.chunk(20, 10), vec![40..45]); + assert!(ranges.chunk(25, 10).is_empty()); + } + + #[test] + fn test_split_ranges_by_row_count() { + assert_eq!( + split_ranges_by_row_count(&[10..25, 30..33], 8), + vec![10..18, 18..25, 30..33] + ); + assert_eq!( + split_ranges_by_row_count(&[5..8], 0), + vec![5..6, 6..7, 7..8] + ); + assert!(split_ranges_by_row_count(&[], 8).is_empty()); + } + + #[test] + fn test_streaming_ivf_progress_throttle() { + assert!(should_report_streaming_ivf_progress(1, 64)); + assert!(!should_report_streaming_ivf_progress(63, 64)); + assert!(should_report_streaming_ivf_progress(64, 64)); + assert!(should_report_streaming_ivf_progress(128, 64)); + assert!(should_report_streaming_ivf_progress(2, 0)); + } + + #[test] + fn test_streaming_coreset_default_rate_is_bounded_by_stream_rate() { + assert_eq!(default_streaming_coreset_rate(256, 1), 1); + assert_eq!(default_streaming_coreset_rate(256, 16), 16); + assert_eq!(default_streaming_coreset_rate(256, 128), 64); + assert_eq!(default_streaming_coreset_rate(8, 128), 8); + assert_eq!(streaming_coreset_rate(256, 128, Some(16)), 16); + assert_eq!( + streaming_local_coreset_k(1024, 1024 
* 128, 16, 2, true), + 1024 * 8 + ); + assert_eq!( + streaming_local_coreset_k(1024, 1024 * 128, 16, 2, false), + 1024 + ); + } + + #[test] + fn test_weighted_coreset_reduction_groups_nearby_centroids() { + let mut coreset = WeightedCoreset::new(1, 4); + coreset.push(&[0.0], 1.0, 0.0); + coreset.push(&[100.0], 1.0, 0.0); + coreset.push(&[1.0], 1.0, 0.0); + coreset.push(&[101.0], 1.0, 0.0); + + coreset.reduce_to_budget(1, 2); + + assert_eq!(coreset.len(), 2); + assert!((coreset.values[0] - 0.5).abs() < 1e-6); + assert!((coreset.values[1] - 100.5).abs() < 1e-6); + assert_eq!(coreset.weights, vec![2.0, 2.0]); + assert!((coreset.losses.iter().sum::() - 1.0).abs() < 1e-6); + } + + #[test] + fn test_weighted_kmeanspp_initialization_selects_distant_centroids() { + let values = vec![0.0, 0.1, 100.0, 101.0]; + let weights = vec![1.0; 4]; + let centroids = initialize_weighted_centroids(&values, 1, 2, 4, &weights); + + assert_eq!(centroids.len(), 2); + assert!( + (centroids[0] - centroids[1]).abs() > 10.0, + "weighted kmeans++ should seed distant coreset regions, got {:?}", + centroids + ); + } + #[tokio::test] async fn test_create_ivf_pq_f16() { let test_dir = TempStrDir::default();