Skip to content

Commit 9a29e33

Browse files
crm26Jefffreyalamb
authored
feat: add inner_product scalar function (#21861)
## Which issue does this PR close? Part of #21536 — split of #21371 into one-function-per-PR. ## Rationale for this change Adds `inner_product(array1, array2)` — the dot product of two equal-length numeric arrays, returning `Float64`. Computed as `sum(array1[i] * array2[i])`. ## What changes are included in this PR? Mirrors the structural pattern of merged #21542 (`cosine_distance`): - Same `coerce_types` for `List`/`LargeList`/`FixedSizeList` of any numeric inner type, with widening to `LargeList` when any input is `LargeList` (per the #21704 pattern) - Same NULL semantics: bare `NULL` → `NULL`, NULL row → NULL, NULL element in list → NULL - Same Arrow-idiomatic implementation: single `as_float64_array(list_array.values())` downcast, slice by `value_offsets()`, iterate via `ScalarBuffer<f64>` - No alias, no shared module — standalone, inline math The arithmetic is the only semantic divergence from `cosine_distance`: - `dot += a*b` (no magnitude or normalization) - Empty arrays return `0.0` (sum of empty set), not `NULL` - No zero-magnitude special case (`inner_product([0,0], [1,2])` returns `0`, which is well-defined for inner product) ## Are these changes tested? Yes. SLT covers: - Orthogonal, identical, opposite, general non-trivial vectors - Single zero vector, both zero vectors - Bare `NULL` in either or both positions - NULL element inside a list (returns NULL for that row) - Mismatched lengths (error) - `LargeList` inputs - Mixed `(List, LargeList)` in both orders - `(FixedSizeList, FixedSizeList)` and `(FixedSizeList, LargeList)` - `Float32` and `Int64` inner type coercion - Multi-row query with NULL row propagation - Empty arrays (returns `0`) - No-args error - Return-type assertion (`Float64`) ## Are there any user-facing changes? New scalar function `inner_product`, documented in `docs/source/user-guide/sql/scalar_functions.md`. --------- Co-authored-by: Jeffrey Vo <jeffrey.vo.australia@gmail.com> Co-authored-by: Andrew Lamb <andrew@nerdnetworks.org>
1 parent 059929d commit 9a29e33

4 files changed

Lines changed: 440 additions & 0 deletions

File tree

Lines changed: 214 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,214 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
//! [`ScalarUDFImpl`] definitions for inner_product function.
19+
20+
use crate::utils::make_scalar_function;
21+
use arrow::array::{Array, ArrayRef, Float64Array, OffsetSizeTrait};
22+
use arrow::datatypes::{
23+
DataType,
24+
DataType::{FixedSizeList, LargeList, List, Null},
25+
Field,
26+
};
27+
use datafusion_common::cast::{as_float64_array, as_generic_list_array};
28+
use datafusion_common::utils::{ListCoercion, coerced_type_with_base_type_only};
29+
use datafusion_common::{
30+
Result, exec_err, internal_err, plan_err, utils::take_function_args,
31+
};
32+
use datafusion_expr::{
33+
ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
34+
Volatility,
35+
};
36+
use datafusion_macros::user_doc;
37+
use std::sync::Arc;
38+
39+
make_udf_expr_and_func!(
40+
InnerProduct,
41+
inner_product,
42+
array1 array2,
43+
"returns the inner product (dot product) of two numeric arrays.",
44+
inner_product_udf
45+
);
46+
47+
#[user_doc(
48+
doc_section(label = "Array Functions"),
49+
description = "Returns the inner product (dot product) of two input arrays of equal length, computed as `sum(array1[i] * array2[i])`. Returns NULL if either array is NULL or contains NULL elements. Returns 0.0 for two empty arrays.",
50+
syntax_example = "inner_product(array1, array2)",
51+
sql_example = r#"```sql
52+
> select inner_product([1.0, 2.0, 3.0], [4.0, 5.0, 6.0]);
53+
+-------------------------------------------------------+
54+
| inner_product(List([1.0,2.0,3.0]),List([4.0,5.0,6.0])) |
55+
+-------------------------------------------------------+
56+
| 32.0 |
57+
+-------------------------------------------------------+
58+
```"#,
59+
argument(
60+
name = "array1",
61+
description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
62+
),
63+
argument(
64+
name = "array2",
65+
description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
66+
)
67+
)]
68+
#[derive(Debug, PartialEq, Eq, Hash)]
69+
pub struct InnerProduct {
70+
signature: Signature,
71+
aliases: Vec<String>,
72+
}
73+
74+
impl Default for InnerProduct {
75+
fn default() -> Self {
76+
Self::new()
77+
}
78+
}
79+
80+
impl InnerProduct {
81+
pub fn new() -> Self {
82+
Self {
83+
signature: Signature::user_defined(Volatility::Immutable),
84+
aliases: vec!["dot_product".to_string()],
85+
}
86+
}
87+
}
88+
89+
impl ScalarUDFImpl for InnerProduct {
90+
fn name(&self) -> &str {
91+
"inner_product"
92+
}
93+
94+
fn signature(&self) -> &Signature {
95+
&self.signature
96+
}
97+
98+
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
99+
Ok(DataType::Float64)
100+
}
101+
102+
fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
103+
let [_, _] = take_function_args(self.name(), arg_types)?;
104+
let coercion = Some(&ListCoercion::FixedSizedListToList);
105+
106+
for arg_type in arg_types {
107+
if !matches!(arg_type, Null | List(_) | LargeList(_) | FixedSizeList(..)) {
108+
return plan_err!("{} does not support type {arg_type}", self.name());
109+
}
110+
}
111+
112+
// If any input is `LargeList`, both sides must be widened to `LargeList`
113+
// so the runtime dispatch in `inner_product_inner` sees a homogeneous
114+
// pair. Follows the pattern in `ArrayConcat::coerce_types`.
115+
let any_large_list = arg_types.iter().any(|t| matches!(t, LargeList(_)));
116+
117+
let coerced = arg_types
118+
.iter()
119+
.map(|arg_type| {
120+
if matches!(arg_type, Null) {
121+
let field = Arc::new(Field::new_list_field(DataType::Float64, true));
122+
return if any_large_list {
123+
LargeList(field)
124+
} else {
125+
List(field)
126+
};
127+
}
128+
let coerced = coerced_type_with_base_type_only(
129+
arg_type,
130+
&DataType::Float64,
131+
coercion,
132+
);
133+
match coerced {
134+
List(field) if any_large_list => LargeList(field),
135+
other => other,
136+
}
137+
})
138+
.collect();
139+
140+
Ok(coerced)
141+
}
142+
143+
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
144+
make_scalar_function(inner_product_inner)(&args.args)
145+
}
146+
147+
fn aliases(&self) -> &[String] {
148+
&self.aliases
149+
}
150+
151+
fn documentation(&self) -> Option<&Documentation> {
152+
self.doc()
153+
}
154+
}
155+
156+
fn inner_product_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
157+
let [array1, array2] = take_function_args("inner_product", args)?;
158+
match (array1.data_type(), array2.data_type()) {
159+
(List(_), List(_)) => general_inner_product::<i32>(args),
160+
(LargeList(_), LargeList(_)) => general_inner_product::<i64>(args),
161+
(arg_type1, arg_type2) => internal_err!(
162+
"inner_product received unexpected types after coercion: {arg_type1} and {arg_type2}"
163+
),
164+
}
165+
}
166+
167+
fn general_inner_product<O: OffsetSizeTrait>(arrays: &[ArrayRef]) -> Result<ArrayRef> {
168+
let list_array1 = as_generic_list_array::<O>(&arrays[0])?;
169+
let list_array2 = as_generic_list_array::<O>(&arrays[1])?;
170+
171+
let values1 = as_float64_array(list_array1.values())?;
172+
let values2 = as_float64_array(list_array2.values())?;
173+
let offsets1 = list_array1.value_offsets();
174+
let offsets2 = list_array2.value_offsets();
175+
176+
let mut builder = Float64Array::builder(list_array1.len());
177+
for row in 0..list_array1.len() {
178+
if list_array1.is_null(row) || list_array2.is_null(row) {
179+
builder.append_null();
180+
continue;
181+
}
182+
183+
let start1 = offsets1[row].as_usize();
184+
let end1 = offsets1[row + 1].as_usize();
185+
let start2 = offsets2[row].as_usize();
186+
let end2 = offsets2[row + 1].as_usize();
187+
let len1 = end1 - start1;
188+
let len2 = end2 - start2;
189+
190+
if len1 != len2 {
191+
return exec_err!(
192+
"inner_product requires both list inputs to have the same length, got {len1} and {len2}"
193+
);
194+
}
195+
196+
let slice1 = values1.slice(start1, len1);
197+
let slice2 = values2.slice(start2, len2);
198+
if slice1.null_count() != 0 || slice2.null_count() != 0 {
199+
builder.append_null();
200+
continue;
201+
}
202+
203+
let vals1 = slice1.values();
204+
let vals2 = slice2.values();
205+
206+
let mut dot = 0.0;
207+
for i in 0..len1 {
208+
dot += vals1[i] * vals2[i];
209+
}
210+
builder.append_value(dot);
211+
}
212+
213+
Ok(Arc::new(builder.finish()) as ArrayRef)
214+
}

datafusion/functions-nested/src/lib.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ pub mod except;
5555
pub mod expr_ext;
5656
pub mod extract;
5757
pub mod flatten;
58+
pub mod inner_product;
5859
pub mod length;
5960
pub mod make_array;
6061
pub mod map;
@@ -107,6 +108,7 @@ pub mod expr_fn {
107108
pub use super::extract::array_pop_front;
108109
pub use super::extract::array_slice;
109110
pub use super::flatten::flatten;
111+
pub use super::inner_product::inner_product;
110112
pub use super::length::array_length;
111113
pub use super::make_array::make_array;
112114
pub use super::map_entries::map_entries;
@@ -163,6 +165,7 @@ pub fn all_default_nested_functions() -> Vec<Arc<ScalarUDF>> {
163165
empty::array_empty_udf(),
164166
length::array_length_udf(),
165167
cosine_distance::cosine_distance_udf(),
168+
inner_product::inner_product_udf(),
166169
distance::array_distance_udf(),
167170
flatten::flatten_udf(),
168171
min_max::array_max_udf(),

0 commit comments

Comments
 (0)