Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
208 changes: 208 additions & 0 deletions rust/lance/src/dataset/scanner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4424,6 +4424,14 @@ impl Scanner {
} else {
input
};
let retain_vector = if self.is_batch_nearest {
let vector_field_id = self.dataset.schema().field_id(q.column.as_str())?;
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does this work with nested vector field? can we add a test for this?

self.projection_plan
.physical_projection
.contains_field_id(vector_field_id)
} else {
false
};
let flat_dist = Arc::new(KNNVectorDistanceExec::try_new_batch(
input,
&q.column,
Expand All @@ -4435,6 +4443,7 @@ impl Scanner {
lower_bound: q.lower_bound,
upper_bound: q.upper_bound,
distance_type: metric_type,
retain_vector,
},
)?);

Expand Down Expand Up @@ -5871,6 +5880,60 @@ mod test {
(queries, query_values)
}

async fn nested_vector_test_dataset(dim: u32) -> (TempStrDir, Dataset) {
let path = TempStrDir::default();
let vec_field = ArrowField::new(
"vec",
DataType::FixedSizeList(
Arc::new(ArrowField::new("item", DataType::Float32, true)),
dim as i32,
),
true,
);
let payload_field = ArrowField::new(
"payload",
DataType::Struct(vec![vec_field.clone()].into()),
true,
);
let schema = Arc::new(ArrowSchema::new(vec![
ArrowField::new("i", DataType::Int32, true),
payload_field.clone(),
]));

let batches: Vec<RecordBatch> = (0..5)
.map(|batch_idx| {
let vector_values: Float32Array = (0..dim * 80).map(|v| v as f32).collect();
let vectors =
FixedSizeListArray::try_new_from_values(vector_values, dim as i32).unwrap();
let payload = StructArray::from(vec![(
Arc::new(vec_field.clone()),
Arc::new(vectors) as ArrayRef,
)]);
RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(Int32Array::from_iter_values(
batch_idx * 80..(batch_idx + 1) * 80,
)),
Arc::new(payload),
],
)
.unwrap()
})
.collect();

let params = WriteParams {
max_rows_per_group: 10,
max_rows_per_file: 200,
data_storage_version: Some(LanceFileVersion::Stable),
enable_stable_row_ids: true,
..Default::default()
};
let reader = RecordBatchIterator::new(batches.into_iter().map(Ok), schema);
let dataset = Dataset::write(reader, &path, Some(params)).await.unwrap();
(path, dataset)
}

fn assert_query_index_field(batch: &RecordBatch) {
let schema = batch.schema();
let field = schema.field(0);
Expand All @@ -5879,6 +5942,14 @@ mod test {
assert!(!field.is_nullable());
}

fn assert_batch_knn_output_has_no_vector(batch: &RecordBatch, vector_column: &str) {
assert!(
batch.schema().column_with_name(vector_column).is_none(),
"batch flat KNN output must not include vector column '{vector_column}' when it is not projected; columns: {:?}",
batch.schema().field_names()
);
}

async fn assert_batch_matches_single_queries(
dataset: &Dataset,
batch: &RecordBatch,
Expand Down Expand Up @@ -5953,6 +6024,7 @@ mod test {

let batch = scan.try_into_batch().await.unwrap();
assert_query_index_field(&batch);
assert_batch_knn_output_has_no_vector(&batch, "vec");
assert_eq!(
batch.num_rows(),
2 * k,
Expand All @@ -5975,6 +6047,25 @@ mod test {
}
assert_batch_matches_single_queries(dataset, &batch, &query_values, k, false, None).await;

let mut scan_with_vec = dataset.scan();
scan_with_vec.nearest("vec", &queries, k).unwrap();
scan_with_vec.use_index(false);
scan_with_vec.project(&["i", "vec"]).unwrap();
let batch_with_vec = scan_with_vec.try_into_batch().await.unwrap();
assert!(
batch_with_vec.schema().column_with_name("vec").is_some(),
"batch flat KNN should return vector column when projected"
);
assert_batch_matches_single_queries(
dataset,
&batch_with_vec,
&query_values,
k,
false,
None,
)
.await;

let query_values_one = (32..64).map(|v| v as f32).collect::<Vec<_>>();
let queries_one = FixedSizeListArray::try_new_from_values(
Float32Array::from(query_values_one.clone()),
Expand All @@ -6000,12 +6091,129 @@ mod test {

let batch = scan.try_into_batch().await.unwrap();
assert_query_index_field(&batch);
assert_batch_knn_output_has_no_vector(&batch, "vec");
assert_eq!(
batch[QUERY_INDEX_COL].as_primitive::<Int32Type>().values(),
&[0, 0]
);
}

#[tokio::test]
async fn test_batch_knn_flat_omits_vector_without_projection() {
let test_ds = TestVectorDataset::new(LanceFileVersion::Stable, true)
.await
.unwrap();
let dataset = &test_ds.dataset;
let k = 2;
let (queries, query_values) = batch_knn_two_queries();

let mut scan = dataset.scan();
scan.nearest("vec", &queries, k).unwrap();
scan.use_index(false);
scan.project(&["i"]).unwrap();
let batch = scan.try_into_batch().await.unwrap();
assert_batch_knn_output_has_no_vector(&batch, "vec");
assert_query_index_field(&batch);
assert!(batch.schema().column_with_name("i").is_some());
assert!(batch.schema().column_with_name(DIST_COL).is_some());
assert_batch_matches_single_queries(dataset, &batch, &query_values, k, false, None).await;

let mut scan_rowid_only = dataset.scan();
scan_rowid_only.nearest("vec", &queries, k).unwrap();
scan_rowid_only.use_index(false);
scan_rowid_only.project(&[ROW_ID]).unwrap();
let batch_rowid_only = scan_rowid_only.try_into_batch().await.unwrap();
assert_batch_knn_output_has_no_vector(&batch_rowid_only, "vec");
assert!(batch_rowid_only.schema().column_with_name(ROW_ID).is_some());
assert!(batch_rowid_only.schema().column_with_name("i").is_none());

let mut scan_with_vec = dataset.scan();
scan_with_vec.nearest("vec", &queries, k).unwrap();
scan_with_vec.use_index(false);
scan_with_vec.project(&["vec"]).unwrap();
let batch_with_vec = scan_with_vec.try_into_batch().await.unwrap();
assert!(
batch_with_vec.schema().column_with_name("vec").is_some(),
"batch flat KNN must include vector column when vec is projected"
);
}

#[tokio::test]
async fn test_batch_knn_flat_filter_keeps_non_vector_columns() {
let test_ds = TestVectorDataset::new(LanceFileVersion::Stable, true)
.await
.unwrap();
let dataset = &test_ds.dataset;
let k = 2;
let (queries, query_values) = batch_knn_two_queries();

let mut scan = dataset.scan();
scan.nearest("vec", &queries, k).unwrap();
scan.use_index(false);
scan.filter("i >= 0").unwrap();
scan.project(&["i"]).unwrap();
let batch = scan.try_into_batch().await.unwrap();

assert_query_index_field(&batch);
assert_batch_knn_output_has_no_vector(&batch, "vec");
assert!(batch.schema().column_with_name("i").is_some());

let query_indices = batch[QUERY_INDEX_COL].as_primitive::<Int32Type>();
for query_index in 0..2 {
let query =
Float32Array::from(query_values[query_index * 32..(query_index + 1) * 32].to_vec());
let mut single = dataset.scan();
single.nearest("vec", &query, k).unwrap();
single.use_index(false);
single.filter("i >= 0").unwrap();
single.project(&["i"]).unwrap();
let single_batch = single.try_into_batch().await.unwrap();

let mask = BooleanArray::from_iter(
query_indices
.iter()
.map(|value| value.map(|value| value == query_index as i32)),
);
let batch_slice = arrow::compute::filter_record_batch(&batch, &mask).unwrap();
assert_eq!(
batch_slice["i"].as_primitive::<Int32Type>().values(),
single_batch["i"].as_primitive::<Int32Type>().values()
);
}
}

#[tokio::test]
async fn test_batch_knn_flat_nested_vector_projection() {
const VECTOR_COLUMN: &str = "payload.vec";
let (_tmp, dataset) = nested_vector_test_dataset(32).await;
let k = 2;
let (queries, _query_values) = batch_knn_two_queries();

let mut scan = dataset.scan();
scan.nearest(VECTOR_COLUMN, &queries, k).unwrap();
scan.use_index(false);
scan.project(&["i"]).unwrap();
let batch = scan.try_into_batch().await.unwrap();
assert_query_index_field(&batch);
assert_batch_knn_output_has_no_vector(&batch, VECTOR_COLUMN);
assert_eq!(batch.num_rows(), 2 * k);
assert!(batch.schema().column_with_name("i").is_some());

let mut scan_with_vec = dataset.scan();
scan_with_vec.nearest(VECTOR_COLUMN, &queries, k).unwrap();
scan_with_vec.use_index(false);
scan_with_vec.project(&[VECTOR_COLUMN]).unwrap();
let batch_with_vec = scan_with_vec.try_into_batch().await.unwrap();
assert!(
batch_with_vec
.schema()
.column_with_name(VECTOR_COLUMN)
.is_some(),
"batch flat KNN must include nested vector column when projected; columns: {:?}",
batch_with_vec.schema().field_names()
);
}

#[tokio::test]
async fn test_primitive_query_length_multiple_of_dim_is_rejected() {
let test_ds = TestVectorDataset::new(LanceFileVersion::Stable, true)
Expand Down
Loading
Loading