Skip to content

Commit 059929d

Browse files
authored
feat: Improve InListExpr types, flatten dict haystacks and validate in try_new_from_array (#21402)
## Which issue does this PR close? - Closes #20969. ## Rationale for this change To make needle/haystack more clear, add comments and with flattened we can use optimizations. ## What changes are included in this PR? - flatten dictionary haystacks - add detailed comments for needle/haystack and dictionary handling - type validation to `try_new_from_array` ## Are these changes tested? Yes. I've added new tests and existing pass test. ## Are there any user-facing changes? Yes. `try_new_from_array` now expects data types to be logically equal and it requires an additional parameter with this. Call sites in downstream will be affected because of this
1 parent dcd364a commit 059929d

4 files changed

Lines changed: 274 additions & 36 deletions

File tree

datafusion/physical-expr/src/expressions/in_list.rs

Lines changed: 252 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,23 @@ fn try_evaluate_constant_list(
123123
}
124124
}
125125

126+
/// Asserts that the InList expression's data type matches a list element's
127+
/// data type. `DataType::Null` list elements are accepted unconditionally so
128+
/// that null literals and `NullArray` haystacks remain compatible with any
129+
/// expression type.
130+
fn assert_inlist_data_types_match(
131+
expr_data_type: &DataType,
132+
list_data_type: &DataType,
133+
) -> Result<()> {
134+
if *list_data_type != DataType::Null {
135+
assert_or_internal_err!(
136+
DFSchema::datatype_is_logically_equal(expr_data_type, list_data_type),
137+
"The data type inlist should be same, the value type is {expr_data_type}, one of list expr type is {list_data_type}"
138+
);
139+
}
140+
Ok(())
141+
}
142+
126143
impl InListExpr {
127144
/// Create a new InList expression
128145
fn new(
@@ -164,20 +181,28 @@ impl InListExpr {
164181

165182
/// Create a new InList expression directly from an array, bypassing expression evaluation.
166183
///
167-
/// This is more efficient than `in_list()` when you already have the list as an array,
168-
/// as it avoids the conversion: `ArrayRef -> Vec<PhysicalExpr> -> ArrayRef -> StaticFilter`.
169-
/// Instead it goes directly: `ArrayRef -> StaticFilter`.
184+
/// This is more efficient than [`InListExpr::try_new`] when you already have the list
185+
/// as an array, as it builds the static filter directly from the array instead of
186+
/// reconstructing an intermediate array from literal expressions.
187+
///
188+
/// The `list` field is populated with literal expressions extracted from
189+
/// the array, and the array is used to build a static filter for
190+
/// efficient set membership evaluation.
170191
///
171-
/// The `list` field will be empty when using this constructor, as the array is stored
172-
/// directly in the static filter.
192+
/// The `array` may be dictionary-encoded — it will be flattened to its
193+
/// value type such that specialized filters are used.
173194
///
174-
/// This does not make the expression any more performant at runtime, but it does make it slightly
175-
/// cheaper to build.
195+
/// Returns an error if the expression's data type and the array's data type
196+
/// are not logically equal. Null arrays are always accepted.
176197
pub fn try_new_from_array(
177198
expr: Arc<dyn PhysicalExpr>,
178199
array: ArrayRef,
179200
negated: bool,
201+
schema: &Schema,
180202
) -> Result<Self> {
203+
let expr_data_type = expr.data_type(schema)?;
204+
assert_inlist_data_types_match(&expr_data_type, array.data_type())?;
205+
181206
let list = (0..array.len())
182207
.map(|i| {
183208
let scalar = ScalarValue::try_from_array(array.as_ref(), i)?;
@@ -210,13 +235,7 @@ impl InListExpr {
210235
let expr_data_type = expr.data_type(schema)?;
211236
for list_expr in list.iter() {
212237
let list_expr_data_type = list_expr.data_type(schema)?;
213-
assert_or_internal_err!(
214-
DFSchema::datatype_is_logically_equal(
215-
&expr_data_type,
216-
&list_expr_data_type
217-
),
218-
"The data type inlist should be same, the value type is {expr_data_type}, one of list expr type is {list_expr_data_type}"
219-
);
238+
assert_inlist_data_types_match(&expr_data_type, &list_expr_data_type)?;
220239
}
221240

222241
// Try to create a static filter if all list expressions are constants
@@ -1835,6 +1854,7 @@ mod tests {
18351854
Arc::clone(&col_a),
18361855
array,
18371856
false,
1857+
&schema,
18381858
)?) as Arc<dyn PhysicalExpr>;
18391859

18401860
// Create test data: [1, 2, 3, 4, null]
@@ -1964,6 +1984,7 @@ mod tests {
19641984
Arc::clone(&col_a),
19651985
null_array,
19661986
false,
1987+
&schema,
19671988
)?) as Arc<dyn PhysicalExpr>;
19681989
let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;
19691990
let result = expr.evaluate(&batch)?.into_array(batch.num_rows())?;
@@ -1992,6 +2013,7 @@ mod tests {
19922013
Arc::clone(&col_a),
19932014
null_array,
19942015
false,
2016+
&schema,
19952017
)?) as Arc<dyn PhysicalExpr>;
19962018

19972019
let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;
@@ -3428,8 +3450,9 @@ mod tests {
34283450
let schema =
34293451
Schema::new(vec![Field::new("a", needle.data_type().clone(), false)]);
34303452
let col_a = col("a", &schema)?;
3431-
let expr = Arc::new(InListExpr::try_new_from_array(col_a, in_array, false)?)
3432-
as Arc<dyn PhysicalExpr>;
3453+
let expr = Arc::new(InListExpr::try_new_from_array(
3454+
col_a, in_array, false, &schema,
3455+
)?) as Arc<dyn PhysicalExpr>;
34333456
let batch = RecordBatch::try_new(Arc::new(schema), vec![needle])?;
34343457
let result = expr.evaluate(&batch)?.into_array(batch.num_rows())?;
34353458
Ok(as_boolean_array(&result).clone())
@@ -3562,41 +3585,237 @@ mod tests {
35623585
Ok(())
35633586
}
35643587

3588+
fn make_int32_dict_array(values: Vec<Option<i32>>) -> ArrayRef {
3589+
let mut builder = PrimitiveDictionaryBuilder::<Int8Type, Int32Type>::new();
3590+
for v in values {
3591+
match v {
3592+
Some(val) => builder.append_value(val),
3593+
None => builder.append_null(),
3594+
}
3595+
}
3596+
Arc::new(builder.finish())
3597+
}
3598+
3599+
fn make_f64_dict_array(values: Vec<Option<f64>>) -> ArrayRef {
3600+
let mut builder = PrimitiveDictionaryBuilder::<Int8Type, Float64Type>::new();
3601+
for v in values {
3602+
match v {
3603+
Some(val) => builder.append_value(val),
3604+
None => builder.append_null(),
3605+
}
3606+
}
3607+
Arc::new(builder.finish())
3608+
}
3609+
3610+
#[test]
3611+
fn test_try_new_from_array_dict_haystack_int32() -> Result<()> {
3612+
let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
3613+
let needle = Int32Array::from(vec![1, 2, 3, 4]);
3614+
let batch =
3615+
RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(needle)])?;
3616+
3617+
let haystack = make_int32_dict_array(vec![Some(1), None, Some(3)]);
3618+
3619+
let col_a = col("a", &schema)?;
3620+
let expr = InListExpr::try_new_from_array(col_a, haystack, false, &schema)?;
3621+
let result = expr.evaluate(&batch)?.into_array(batch.num_rows())?;
3622+
let result = as_boolean_array(&result);
3623+
assert_eq!(
3624+
result,
3625+
&BooleanArray::from(vec![Some(true), None, Some(true), None])
3626+
);
3627+
3628+
Ok(())
3629+
}
3630+
35653631
#[test]
35663632
fn test_in_list_from_array_type_mismatch_errors() -> Result<()> {
3567-
// Utf8 needle, Dict(Utf8) in_array
3568-
let err = eval_in_list_from_array(
3569-
Arc::new(StringArray::from(vec!["a", "d", "b"])),
3570-
wrap_in_dict(Arc::new(StringArray::from(vec!["a", "b", "c"]))),
3571-
)
3572-
.unwrap_err()
3573-
.to_string();
3574-
assert!(
3575-
err.contains("Can't compare arrays of different types"),
3576-
"{err}"
3633+
// Utf8 needle, Dict(Utf8) in_array: now works with dict haystack support
3634+
assert_eq!(
3635+
BooleanArray::from(vec![Some(true), Some(false), Some(true)]),
3636+
eval_in_list_from_array(
3637+
Arc::new(StringArray::from(vec!["a", "d", "b"])),
3638+
wrap_in_dict(Arc::new(StringArray::from(vec!["a", "b", "c"]))),
3639+
)?
35773640
);
35783641

3579-
// Dict(Utf8) needle, Int64 in_array: specialized Int64StaticFilter
3580-
// rejects the Utf8 dictionary values at construction time
3642+
// Dict(Utf8) needle, Int64 in_array: type validation rejects at construction
35813643
let err = eval_in_list_from_array(
35823644
wrap_in_dict(Arc::new(StringArray::from(vec!["a", "d", "b"]))),
35833645
Arc::new(Int64Array::from(vec![1, 2, 3])),
35843646
)
35853647
.unwrap_err()
35863648
.to_string();
3587-
assert!(err.contains("Failed to downcast"), "{err}");
3649+
assert!(err.contains("The data type inlist should be same"), "{err}");
35883650

35893651
// Dict(Int64) needle, Dict(Utf8) in_array: both Dict but different
3590-
// value types, make_comparator rejects the comparison
3652+
// value types, type validation rejects at construction
35913653
let err = eval_in_list_from_array(
35923654
wrap_in_dict(Arc::new(Int64Array::from(vec![1, 4, 2]))),
35933655
wrap_in_dict(Arc::new(StringArray::from(vec!["a", "b", "c"]))),
35943656
)
35953657
.unwrap_err()
35963658
.to_string();
3597-
assert!(
3598-
err.contains("Can't compare arrays of different types"),
3599-
"{err}"
3659+
assert!(err.contains("The data type inlist should be same"), "{err}");
3660+
3661+
Ok(())
3662+
}
3663+
3664+
#[test]
3665+
fn test_try_new_from_array_dict_haystack_negated() -> Result<()> {
3666+
let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
3667+
let needle = Int32Array::from(vec![1, 2, 3, 4]);
3668+
let batch =
3669+
RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(needle)])?;
3670+
3671+
let haystack = make_int32_dict_array(vec![Some(1), None, Some(3)]);
3672+
3673+
let col_a = col("a", &schema)?;
3674+
let expr = InListExpr::try_new_from_array(col_a, haystack, true, &schema)?;
3675+
let result = expr.evaluate(&batch)?.into_array(batch.num_rows())?;
3676+
let result = as_boolean_array(&result);
3677+
assert_eq!(
3678+
result,
3679+
&BooleanArray::from(vec![Some(false), None, Some(false), None])
3680+
);
3681+
3682+
Ok(())
3683+
}
3684+
3685+
#[test]
3686+
fn test_try_new_from_array_dict_haystack_utf8() -> Result<()> {
3687+
let schema = Schema::new(vec![Field::new("a", DataType::Utf8, false)]);
3688+
let needle = StringArray::from(vec!["a", "b", "c"]);
3689+
let batch =
3690+
RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(needle)])?;
3691+
3692+
let dict_builder = StringDictionaryBuilder::<Int8Type>::new();
3693+
let mut builder = dict_builder;
3694+
builder.append_value("a");
3695+
builder.append_value("c");
3696+
let haystack: ArrayRef = Arc::new(builder.finish());
3697+
3698+
let col_a = col("a", &schema)?;
3699+
let expr = InListExpr::try_new_from_array(col_a, haystack, false, &schema)?;
3700+
let result = expr.evaluate(&batch)?.into_array(batch.num_rows())?;
3701+
let result = as_boolean_array(&result);
3702+
assert_eq!(
3703+
result,
3704+
&BooleanArray::from(vec![Some(true), Some(false), Some(true)])
3705+
);
3706+
3707+
Ok(())
3708+
}
3709+
3710+
#[test]
3711+
fn test_try_new_from_array_dict_needle_and_plain_haystack() -> Result<()> {
3712+
let schema = Schema::new(vec![Field::new(
3713+
"a",
3714+
DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Int32)),
3715+
false,
3716+
)]);
3717+
3718+
let needle = make_int32_dict_array(vec![Some(1), Some(2), Some(3), Some(4)]);
3719+
let batch =
3720+
RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::clone(&needle)])?;
3721+
3722+
let haystack: ArrayRef = Arc::new(Int32Array::from(vec![1, 3]));
3723+
let col_a = col("a", &schema)?;
3724+
let expr = InListExpr::try_new_from_array(col_a, haystack, false, &schema)?;
3725+
let result = expr.evaluate(&batch)?.into_array(batch.num_rows())?;
3726+
let result = as_boolean_array(&result);
3727+
assert_eq!(
3728+
result,
3729+
&BooleanArray::from(vec![Some(true), Some(false), Some(true), Some(false)])
3730+
);
3731+
3732+
Ok(())
3733+
}
3734+
3735+
#[test]
3736+
fn test_try_new_from_array_dict_haystack_float64() -> Result<()> {
3737+
let schema = Schema::new(vec![Field::new("a", DataType::Float64, false)]);
3738+
let needle = Float64Array::from(vec![1.0, 2.0, 3.0]);
3739+
let batch =
3740+
RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(needle)])?;
3741+
3742+
let haystack = make_f64_dict_array(vec![Some(1.0), Some(3.0)]);
3743+
3744+
let col_a = col("a", &schema)?;
3745+
let expr = InListExpr::try_new_from_array(col_a, haystack, false, &schema)?;
3746+
let result = expr.evaluate(&batch)?.into_array(batch.num_rows())?;
3747+
let result = as_boolean_array(&result);
3748+
assert_eq!(
3749+
result,
3750+
&BooleanArray::from(vec![Some(true), Some(false), Some(true)])
3751+
);
3752+
3753+
Ok(())
3754+
}
3755+
3756+
#[test]
3757+
fn test_try_new_from_array_type_mismatch_rejects() -> Result<()> {
3758+
let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
3759+
let col_a = col("a", &schema)?;
3760+
let haystack: ArrayRef = Arc::new(Float64Array::from(vec![1.0, 2.0]));
3761+
3762+
let result = InListExpr::try_new_from_array(col_a, haystack, false, &schema);
3763+
assert!(result.is_err());
3764+
Ok(())
3765+
}
3766+
3767+
#[test]
3768+
fn test_try_new_from_array_struct_haystack() -> Result<()> {
3769+
let struct_fields = Fields::from(vec![
3770+
Field::new("x", DataType::Int32, false),
3771+
Field::new("y", DataType::Utf8, false),
3772+
]);
3773+
let struct_dt = DataType::Struct(struct_fields.clone());
3774+
let schema = Schema::new(vec![Field::new("a", struct_dt, true)]);
3775+
3776+
// Needle: [{1,"a"}, {2,"b"}, NULL, {4,"d"}]
3777+
let needle = Arc::new(StructArray::new(
3778+
struct_fields.clone(),
3779+
vec![
3780+
Arc::new(Int32Array::from(vec![1, 2, 3, 4])),
3781+
Arc::new(StringArray::from(vec!["a", "b", "c", "d"])),
3782+
],
3783+
Some(vec![true, true, false, true].into()),
3784+
));
3785+
let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![needle])?;
3786+
3787+
// Haystack: [{1,"a"}, {4,"d"}]
3788+
let haystack: ArrayRef = Arc::new(StructArray::new(
3789+
struct_fields,
3790+
vec![
3791+
Arc::new(Int32Array::from(vec![1, 4])),
3792+
Arc::new(StringArray::from(vec!["a", "d"])),
3793+
],
3794+
None,
3795+
));
3796+
3797+
let col_a = col("a", &schema)?;
3798+
let expr = InListExpr::try_new_from_array(
3799+
Arc::clone(&col_a),
3800+
Arc::clone(&haystack),
3801+
false,
3802+
&schema,
3803+
)?;
3804+
let result = expr.evaluate(&batch)?.into_array(batch.num_rows())?;
3805+
let result = as_boolean_array(&result);
3806+
// {1,"a"} -> true, {2,"b"} -> false, NULL -> NULL, {4,"d"} -> true
3807+
assert_eq!(
3808+
result,
3809+
&BooleanArray::from(vec![Some(true), Some(false), None, Some(true)])
3810+
);
3811+
3812+
// Negated path
3813+
let expr = InListExpr::try_new_from_array(col_a, haystack, true, &schema)?;
3814+
let result = expr.evaluate(&batch)?.into_array(batch.num_rows())?;
3815+
let result = as_boolean_array(&result);
3816+
assert_eq!(
3817+
result,
3818+
&BooleanArray::from(vec![Some(false), Some(true), None, Some(false)])
36003819
);
36013820

36023821
Ok(())

datafusion/physical-expr/src/expressions/in_list/static_filter.rs

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,20 @@
1818
use arrow::array::{Array, BooleanArray};
1919
use datafusion_common::Result;
2020

21-
/// Trait for InList static filters
21+
/// Trait for InList static filters.
22+
///
23+
/// Static filters store a pre-computed set of values (the haystack) and check
24+
/// whether needle values are contained in that set. The haystack is always
25+
/// represented in its non-dictionary (value) type. Dictionary haystacks are
26+
/// flattened via `cast()` before construction.
27+
///
28+
/// Dictionary-encoded needles are unwrapped inside `contains()` and
29+
/// evaluated against the dictionary's values.
2230
pub(super) trait StaticFilter {
2331
fn null_count(&self) -> usize;
2432

25-
/// Checks if values in `v` are contained in the filter
33+
/// Checks if values in `v` (needle) are contained in this filter's
34+
/// haystack. `v` may be dictionary-encoded, in which case the
35+
/// implementation unwraps the dictionary and operates on its values.
2636
fn contains(&self, v: &dyn Array, negated: bool) -> Result<BooleanArray>;
2737
}

0 commit comments

Comments
 (0)