Skip to content
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 66 additions & 7 deletions rust/examples/src/ivf_hnsw.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,20 @@
//!
//! run with `cargo run --release --example hnsw`
#![allow(clippy::print_stdout)]
use std::sync::Arc;

use arrow::array::AsArray;
use arrow::array::types::Float32Type;
use arrow::array::{FixedSizeListBuilder, Float32Builder, UInt64Array};
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::{RecordBatch, RecordBatchIterator};
use clap::Parser;
use futures::TryStreamExt;
use lance::Dataset;
use lance::dataset::ProjectionRequest;
use lance::dataset::{ProjectionRequest, WriteMode, WriteParams};
use lance::index::DatasetIndexExt;
use lance::index::vector::VectorIndexParams;
use lance_core::utils::tempfile::TempStrDir;
use lance_index::IndexType;
use lance_index::vector::hnsw::builder::HnswBuildParams;
use lance_index::vector::ivf::IvfBuildParams;
Expand All @@ -22,8 +28,8 @@
#[derive(Parser, Debug)]
#[command(version, about, long_about = None)]
struct Args {
/// Dataset URI
uri: String,
/// Dataset URI. If omitted, a local temporary vector dataset is generated.
uri: Option<String>,

/// Vector column name
#[arg(short, long, value_name = "NAME", default_value = "vector")]
Expand Down Expand Up @@ -52,6 +58,44 @@
metric_type: String,
}

/// Generates a small synthetic vector dataset and writes it to `data_path`.
///
/// The dataset holds a single batch of 4096 rows with two non-nullable columns:
/// a `UInt64` column named `id` (values `0..4096`) and a fixed-size-list column
/// named `column` whose 64 `Float32` items are derived deterministically from
/// the row and dimension indices.
///
/// The dataset is written with [`WriteMode::Overwrite`].
///
/// # Errors
///
/// Returns an error if the record batch cannot be assembled or if writing the
/// dataset fails.
async fn create_test_vector_dataset(
    data_path: &str,
    column: &str,
) -> Result<(), Box<dyn std::error::Error>> {
    const NUM_ROWS: usize = 4096;
    const DIM: i32 = 64;

    // Fill the vector column first: one fixed-size list of DIM floats per row.
    let mut vectors = FixedSizeListBuilder::new(Float32Builder::new(), DIM);
    for row in 0..NUM_ROWS {
        // Each component is a value in [0, 1) that depends only on (row, d),
        // so repeated runs produce the same dataset.
        (0..DIM as usize)
            .map(|d| ((row * 31 + d * 17) % 997) as f32 / 997.0)
            .for_each(|component| vectors.values().append_value(component));
        vectors.append(true);
    }
    let vector_array = vectors.finish();

    // Sequential row identifiers.
    let id_values: Vec<u64> = (0..NUM_ROWS as u64).collect();
    let id_array = UInt64Array::from(id_values);

    // The list's child field is named "item" and is nullable.
    let list_type = DataType::FixedSizeList(
        Arc::new(Field::new("item", DataType::Float32, true)),
        DIM,
    );
    let schema = Arc::new(Schema::new(vec![
        Field::new("id", DataType::UInt64, false),
        Field::new(column, list_type, false),
    ]));

    let batch = RecordBatch::try_new(
        schema.clone(),
        vec![Arc::new(id_array), Arc::new(vector_array)],
    )?;
    let reader = RecordBatchIterator::new([Ok(batch)], schema);

    // Overwrite mode replaces whatever already exists at `data_path`.
    Dataset::write(
        reader,
        data_path,
        Some(WriteParams {
            mode: WriteMode::Overwrite,
            ..Default::default()
        }),
    )
    .await?;

    Ok(())
}

#[cfg(test)]
fn ground_truth(mat: &MatrixView<Float32Type>, query: &[f32], k: usize) -> HashSet<u32> {
let mut dists = vec![];
Expand All @@ -68,20 +112,35 @@
async fn main() {
env_logger::init();
let args = Args::parse();
let tempdir;
let column = args.column.as_deref().unwrap_or("vector");

let mut dataset = Dataset::open(&args.uri)
let uri = match args.uri.as_deref() {
Some(uri) => uri,
None => {
tempdir = TempStrDir::default();
let data_path = tempdir.as_ref();
create_test_vector_dataset(data_path, column)
.await
.expect("Failed to create test vector dataset");
println!("Generated test vector dataset at {}", data_path);
data_path
}

Check warning on line 128 in rust/examples/src/ivf_hnsw.rs

View workflow job for this annotation

GitHub Actions / format

Diff in /home/runner/work/lance/lance/rust/examples/src/ivf_hnsw.rs
};

let mut dataset = Dataset::open(uri)
.await
.expect("Failed to open dataset");
println!("Dataset schema: {:#?}", dataset.schema());

let column = args.column.as_deref().unwrap_or("vector");
let metric_type = MetricType::try_from(args.metric_type.as_str()).unwrap();

let mut ivf_params = IvfBuildParams::new(128);
ivf_params.sample_rate = 20480;
let hnsw_params = HnswBuildParams::default()
.ef_construction(100)
.num_edges(15);
.ef_construction(args.ef)
.max_level(args.max_level)
.num_edges(args.max_edges);
let pq_params = SQBuildParams::default();
let params =
VectorIndexParams::with_ivf_hnsw_sq_params(metric_type, ivf_params, hnsw_params, pq_params);
Expand Down
Loading