diff --git a/rust/examples/src/ivf_hnsw.rs b/rust/examples/src/ivf_hnsw.rs
index c1898e10682..02b5b7490de 100644
--- a/rust/examples/src/ivf_hnsw.rs
+++ b/rust/examples/src/ivf_hnsw.rs
@@ -5,14 +5,20 @@
 //!
 //! run with `cargo run --release --example hnsw`
 #![allow(clippy::print_stdout)]
+use std::sync::Arc;
+
 use arrow::array::AsArray;
 use arrow::array::types::Float32Type;
+use arrow::array::{FixedSizeListBuilder, Float32Builder, UInt64Array};
+use arrow::datatypes::{DataType, Field, Schema};
+use arrow::record_batch::{RecordBatch, RecordBatchIterator};
 use clap::Parser;
 use futures::TryStreamExt;
 use lance::Dataset;
-use lance::dataset::ProjectionRequest;
+use lance::dataset::{ProjectionRequest, WriteMode, WriteParams};
 use lance::index::DatasetIndexExt;
 use lance::index::vector::VectorIndexParams;
+use lance_core::utils::tempfile::TempStrDir;
 use lance_index::IndexType;
 use lance_index::vector::hnsw::builder::HnswBuildParams;
 use lance_index::vector::ivf::IvfBuildParams;
@@ -22,8 +28,8 @@ use lance_linalg::distance::MetricType;
 #[derive(Parser, Debug)]
 #[command(version, about, long_about = None)]
 struct Args {
-    /// Dataset URI
-    uri: String,
+    /// Dataset URI. If omitted, a local temporary vector dataset is generated.
+    uri: Option<String>,
 
     /// Vector column name
     #[arg(short, long, value_name = "NAME", default_value = "vector")]
@@ -52,6 +58,44 @@ struct Args {
     metric_type: String,
 }
 
+async fn create_test_vector_dataset(
+    data_path: &str,
+    column: &str,
+) -> Result<(), Box<dyn std::error::Error>> {
+    const NUM_ROWS: usize = 4096;
+    const DIM: i32 = 64;
+
+    let item_field = Arc::new(Field::new("item", DataType::Float32, true));
+    let schema = Arc::new(Schema::new(vec![
+        Field::new("id", DataType::UInt64, false),
+        Field::new(column, DataType::FixedSizeList(item_field, DIM), false),
+    ]));
+
+    let ids = UInt64Array::from((0..NUM_ROWS as u64).collect::<Vec<_>>());
+    let values = Float32Builder::new();
+    let mut vector_builder = FixedSizeListBuilder::new(values, DIM);
+    for row_id in 0..NUM_ROWS {
+        for dim in 0..DIM as usize {
+            let value = ((row_id * 31 + dim * 17) % 997) as f32 / 997.0;
+            vector_builder.values().append_value(value);
+        }
+        vector_builder.append(true);
+    }
+
+    let batch = RecordBatch::try_new(
+        schema.clone(),
+        vec![Arc::new(ids), Arc::new(vector_builder.finish())],
+    )?;
+    let batches = RecordBatchIterator::new([Ok(batch)], schema);
+    let write_params = WriteParams {
+        mode: WriteMode::Overwrite,
+        ..Default::default()
+    };
+
+    Dataset::write(batches, data_path, Some(write_params)).await?;
+    Ok(())
+}
+
 #[cfg(test)]
 fn ground_truth(mat: &MatrixView<Float32Type>, query: &[f32], k: usize) -> HashSet<u32> {
     let mut dists = vec![];
@@ -68,20 +112,33 @@ fn ground_truth(mat: &MatrixView<Float32Type>, query: &[f32], k: usize) -> HashS
 async fn main() {
     env_logger::init();
     let args = Args::parse();
+    let tempdir;
+    let column = args.column.as_deref().unwrap_or("vector");
 
-    let mut dataset = Dataset::open(&args.uri)
-        .await
-        .expect("Failed to open dataset");
+    let uri = match args.uri.as_deref() {
+        Some(uri) => uri,
+        None => {
+            tempdir = TempStrDir::default();
+            let data_path = tempdir.as_ref();
+            create_test_vector_dataset(data_path, column)
+                .await
+                .expect("Failed to create test vector dataset");
+            println!("Generated test vector dataset at {}", data_path);
+            data_path
+        }
+    };
+
+    let mut dataset = Dataset::open(uri).await.expect("Failed to open dataset");
     println!("Dataset schema: {:#?}", dataset.schema());
 
-    let column = args.column.as_deref().unwrap_or("vector");
     let metric_type = MetricType::try_from(args.metric_type.as_str()).unwrap();
 
     let mut ivf_params = IvfBuildParams::new(128);
     ivf_params.sample_rate = 20480;
     let hnsw_params = HnswBuildParams::default()
-        .ef_construction(100)
-        .num_edges(15);
+        .ef_construction(args.ef)
+        .max_level(args.max_level)
+        .num_edges(args.max_edges);
     let pq_params = SQBuildParams::default();
     let params =
         VectorIndexParams::with_ivf_hnsw_sq_params(metric_type, ivf_params, hnsw_params, pq_params);