Skip to content

Commit e514a01

Browse files
lyne7-scblaginin
and authored
perf: optimize retract_batch for median and percentile_cont (#21894)
## Which issue does this PR close? <!-- We generally require a GitHub issue to be filed for all bug fixes and enhancements and this helps us generate change logs for our releases. You can link an issue to this PR using the GitHub syntax. For example `Closes #123` indicates that this PR will close issue #123. --> - Closes #. ## Rationale for this change For sliding window aggregation, `retract_batch` removes outgoing rows from the aggregate state on every window slide. `median` and `percentile_cont` store primitive numeric values internally, but their retract paths converted values through `ScalarValue` before matching them. This PR keeps retract matching on native Arrow values, reducing conversion and hashing overhead in that hot path. <!-- Why are you proposing this change? If this is already explained clearly in the issue then this section is not needed. Explaining clearly why changes are proposed helps reviewers understand your changes and offer better suggestions for fixes. --> ## What changes are included in this PR? - Optimize `median` and `percentile_cont` `retract_batch` using `Hashable<T::Native>` keys. - Add sliding-window benchmarks for `median` and `percentile_cont` with window sizes `256`, `4096`, and `16384`. ### Benchmarks ``` group main optimized ----- ---- --------- median sliding_window f64 no_nulls window_size=16384 2.38 3.3±0.06ms ? ?/sec 1.00 1396.6±36.31µs ? ?/sec median sliding_window f64 no_nulls window_size=256 2.73 781.3±20.80µs ? ?/sec 1.00 286.0±10.52µs ? ?/sec median sliding_window f64 no_nulls window_size=4096 2.11 1052.2±27.13µs ? ?/sec 1.00 499.3±19.44µs ? ?/sec median sliding_window f64 with_nulls window_size=16384 2.52 3.0±0.06ms ? ?/sec 1.00 1173.1±36.86µs ? ?/sec median sliding_window f64 with_nulls window_size=256 2.67 728.6±20.07µs ? ?/sec 1.00 272.8±12.90µs ? ?/sec median sliding_window f64 with_nulls window_size=4096 2.11 954.8±27.37µs ? ?/sec 1.00 452.6±13.08µs ? 
?/sec percentile_cont sliding_window f64 no_nulls window_size=16384 3.86 10.7±0.24ms ? ?/sec 1.00 2.8±0.05ms ? ?/sec percentile_cont sliding_window f64 no_nulls window_size=256 2.49 797.8±25.51µs ? ?/sec 1.00 320.1±58.86µs ? ?/sec percentile_cont sliding_window f64 no_nulls window_size=4096 3.44 3.2±0.12ms ? ?/sec 1.00 928.2±42.15µs ? ?/sec percentile_cont sliding_window f64 with_nulls window_size=16384 3.72 6.7±0.90ms ? ?/sec 1.00 1790.9±22.20µs ? ?/sec percentile_cont sliding_window f64 with_nulls window_size=256 2.51 721.0±25.52µs ? ?/sec 1.00 286.7±30.34µs ? ?/sec percentile_cont sliding_window f64 with_nulls window_size=4096 3.34 2.2±0.14ms ? ?/sec 1.00 667.1±20.87µs ? ?/sec ``` <!-- There is no need to duplicate the description in the issue here but it is sometimes worth providing a summary of the individual changes in this PR. --> ## Are these changes tested? Yes. existed slt passed. <!-- We typically require tests for all PRs in order to: 1. Prevent the code from being accidentally broken by subsequent changes 2. Serve as another way to document the expected behavior of the code If tests are not included in your PR, please explain why (for example, are they covered by existing tests)? --> ## Are there any user-facing changes? No. <!-- If there are user-facing changes then we may require documentation to be updated before approving the PR. --> <!-- If there are any breaking changes to public APIs, please add the `api change` label. --> --------- Co-authored-by: Dmitrii Blaginin <dmitrii@blaginin.me>
1 parent f3cebc5 commit e514a01

5 files changed

Lines changed: 271 additions & 22 deletions

File tree

datafusion/functions-aggregate/Cargo.toml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,3 +87,11 @@ harness = false
8787
[[bench]]
8888
name = "count_distinct"
8989
harness = false
90+
91+
[[bench]]
92+
name = "median"
93+
harness = false
94+
95+
[[bench]]
96+
name = "percentile_cont"
97+
harness = false
Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use std::hint::black_box;
19+
use std::sync::Arc;
20+
21+
use arrow::array::{ArrayRef, Float64Array};
22+
use arrow::datatypes::{DataType, Field, Schema};
23+
use criterion::{BatchSize, Criterion, criterion_group, criterion_main};
24+
use datafusion_expr::function::AccumulatorArgs;
25+
use datafusion_expr::{Accumulator, AggregateUDFImpl};
26+
use datafusion_functions_aggregate::median::Median;
27+
use datafusion_physical_expr::expressions::col;
28+
29+
// Rows retracted and appended on each window slide.
const STEP_SIZE: usize = 128;
// Number of retract/update/evaluate cycles performed per benchmark iteration.
const SLIDES_PER_ITER: usize = 32;
// Sliding-window sizes to benchmark (small, medium, large).
const WINDOW_SIZES: [usize; 3] = [256, 4096, 16384];
32+
33+
fn prepare_accumulator() -> Box<dyn Accumulator> {
34+
let schema = Arc::new(Schema::new(vec![Field::new("f", DataType::Float64, true)]));
35+
let expr = col("f", &schema).unwrap();
36+
let accumulator_args = AccumulatorArgs {
37+
return_field: Field::new("f", DataType::Float64, true).into(),
38+
schema: &schema,
39+
expr_fields: &[expr.return_field(&schema).unwrap()],
40+
ignore_nulls: false,
41+
order_bys: &[],
42+
is_reversed: false,
43+
name: "median(f)",
44+
is_distinct: false,
45+
exprs: &[expr],
46+
};
47+
Median::new().accumulator(accumulator_args).unwrap()
48+
}
49+
50+
fn stream_array(len: usize, null_stride: Option<usize>) -> ArrayRef {
51+
let values = (0..len)
52+
.map(|idx| {
53+
if null_stride.is_some_and(|stride| idx % stride == 0) {
54+
None
55+
} else {
56+
Some(idx as f64)
57+
}
58+
})
59+
.collect::<Vec<_>>();
60+
Arc::new(Float64Array::from(values)) as ArrayRef
61+
}
62+
63+
/// Benchmark the sliding window cycle: retract + update + evaluate
64+
fn sliding_window_bench(
65+
c: &mut Criterion,
66+
name: &str,
67+
window_size: usize,
68+
stream: &ArrayRef,
69+
) {
70+
c.bench_function(name, |b| {
71+
b.iter_batched(
72+
|| {
73+
let mut accumulator = prepare_accumulator();
74+
let initial = stream.slice(0, window_size);
75+
accumulator
76+
.update_batch(std::slice::from_ref(&initial))
77+
.unwrap();
78+
accumulator
79+
},
80+
|mut accumulator| {
81+
for slide in 0..SLIDES_PER_ITER {
82+
let offset = slide * STEP_SIZE;
83+
let retract = stream.slice(offset, STEP_SIZE);
84+
let update = stream.slice(offset + window_size, STEP_SIZE);
85+
accumulator
86+
.retract_batch(std::slice::from_ref(&retract))
87+
.unwrap();
88+
accumulator
89+
.update_batch(std::slice::from_ref(&update))
90+
.unwrap();
91+
black_box(accumulator.evaluate().unwrap());
92+
}
93+
},
94+
BatchSize::SmallInput,
95+
)
96+
});
97+
}
98+
99+
fn median_benchmark(c: &mut Criterion) {
100+
for window_size in WINDOW_SIZES {
101+
let stream_len = window_size + STEP_SIZE * SLIDES_PER_ITER;
102+
let stream_no_nulls = stream_array(stream_len, None);
103+
let stream_with_nulls = stream_array(stream_len, Some(10));
104+
105+
sliding_window_bench(
106+
c,
107+
&format!("median sliding_window f64 no_nulls window_size={window_size}"),
108+
window_size,
109+
&stream_no_nulls,
110+
);
111+
112+
sliding_window_bench(
113+
c,
114+
&format!("median sliding_window f64 with_nulls window_size={window_size}"),
115+
window_size,
116+
&stream_with_nulls,
117+
);
118+
}
119+
}
120+
121+
// Criterion entry points: the crate sets `harness = false`, so criterion_main!
// generates the `main` function that runs this benchmark group.
criterion_group!(benches, median_benchmark);
criterion_main!(benches);
Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use std::hint::black_box;
19+
use std::sync::Arc;
20+
21+
use arrow::array::{ArrayRef, Float64Array};
22+
use arrow::datatypes::{DataType, Field, Schema};
23+
use criterion::{BatchSize, Criterion, criterion_group, criterion_main};
24+
use datafusion_expr::function::AccumulatorArgs;
25+
use datafusion_expr::{Accumulator, AggregateUDFImpl};
26+
use datafusion_functions_aggregate::percentile_cont::PercentileCont;
27+
use datafusion_physical_expr::expressions::{col, lit};
28+
29+
// Rows retracted and appended on each window slide.
const STEP_SIZE: usize = 128;
// Number of retract/update/evaluate cycles performed per benchmark iteration.
const SLIDES_PER_ITER: usize = 32;
// Sliding-window sizes to benchmark (small, medium, large).
const WINDOW_SIZES: [usize; 3] = [256, 4096, 16384];
32+
33+
fn prepare_accumulator() -> Box<dyn Accumulator> {
34+
let schema = Arc::new(Schema::new(vec![Field::new("f", DataType::Float64, true)]));
35+
let value_expr = col("f", &schema).unwrap();
36+
let percentile_expr = lit(0.5_f64);
37+
let value_field = value_expr.return_field(&schema).unwrap();
38+
let percentile_field = percentile_expr.return_field(&schema).unwrap();
39+
let accumulator_args = AccumulatorArgs {
40+
return_field: Field::new("f", DataType::Float64, true).into(),
41+
schema: &schema,
42+
expr_fields: &[value_field, percentile_field],
43+
ignore_nulls: false,
44+
order_bys: &[],
45+
is_reversed: false,
46+
name: "percentile_cont(f, 0.5)",
47+
is_distinct: false,
48+
exprs: &[value_expr, percentile_expr],
49+
};
50+
PercentileCont::new().accumulator(accumulator_args).unwrap()
51+
}
52+
53+
fn stream_array(len: usize, null_stride: Option<usize>) -> ArrayRef {
54+
let values = (0..len)
55+
.map(|idx| {
56+
if null_stride.is_some_and(|stride| idx % stride == 0) {
57+
None
58+
} else {
59+
Some(idx as f64)
60+
}
61+
})
62+
.collect::<Vec<_>>();
63+
Arc::new(Float64Array::from(values)) as ArrayRef
64+
}
65+
66+
/// Benchmark the sliding window cycle: retract + update + evaluate
67+
fn sliding_window_bench(
68+
c: &mut Criterion,
69+
name: &str,
70+
window_size: usize,
71+
stream: &ArrayRef,
72+
) {
73+
c.bench_function(name, |b| {
74+
b.iter_batched(
75+
|| {
76+
let mut accumulator = prepare_accumulator();
77+
let initial = stream.slice(0, window_size);
78+
accumulator
79+
.update_batch(std::slice::from_ref(&initial))
80+
.unwrap();
81+
accumulator
82+
},
83+
|mut accumulator| {
84+
for slide in 0..SLIDES_PER_ITER {
85+
let offset = slide * STEP_SIZE;
86+
let retract = stream.slice(offset, STEP_SIZE);
87+
let update = stream.slice(offset + window_size, STEP_SIZE);
88+
accumulator
89+
.retract_batch(std::slice::from_ref(&retract))
90+
.unwrap();
91+
accumulator
92+
.update_batch(std::slice::from_ref(&update))
93+
.unwrap();
94+
black_box(accumulator.evaluate().unwrap());
95+
}
96+
},
97+
BatchSize::SmallInput,
98+
)
99+
});
100+
}
101+
102+
fn percentile_cont_benchmark(c: &mut Criterion) {
103+
for window_size in WINDOW_SIZES {
104+
let stream_len = window_size + STEP_SIZE * SLIDES_PER_ITER;
105+
let stream_no_nulls = stream_array(stream_len, None);
106+
let stream_with_nulls = stream_array(stream_len, Some(10));
107+
108+
sliding_window_bench(
109+
c,
110+
&format!(
111+
"percentile_cont sliding_window f64 no_nulls window_size={window_size}"
112+
),
113+
window_size,
114+
&stream_no_nulls,
115+
);
116+
117+
sliding_window_bench(
118+
c,
119+
&format!(
120+
"percentile_cont sliding_window f64 with_nulls window_size={window_size}"
121+
),
122+
window_size,
123+
&stream_with_nulls,
124+
);
125+
}
126+
}
127+
128+
// Criterion entry points: the crate sets `harness = false`, so criterion_main!
// generates the `main` function that runs this benchmark group.
criterion_group!(benches, percentile_cont_benchmark);
criterion_main!(benches);

datafusion/functions-aggregate/src/median.rs

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ use datafusion_expr::{
5151
use datafusion_expr::{EmitTo, GroupsAccumulator};
5252
use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::accumulate;
5353
use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::filtered_null_mask;
54-
use datafusion_functions_aggregate_common::utils::GenericDistinctBuffer;
54+
use datafusion_functions_aggregate_common::utils::{GenericDistinctBuffer, Hashable};
5555
use datafusion_macros::user_doc;
5656
use std::collections::HashMap;
5757

@@ -285,24 +285,17 @@ impl<T: ArrowNumericType> Accumulator for MedianAccumulator<T> {
285285
size_of_val(self) + self.all_values.capacity() * size_of::<T::Native>()
286286
}
287287

288-
#[allow(clippy::allow_attributes, clippy::mutable_key_type)] // ScalarValue has interior mutability but is intentionally used as hash key
289288
fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
290-
let mut to_remove: HashMap<ScalarValue, usize> = HashMap::new();
289+
let mut to_remove: HashMap<Hashable<T::Native>, usize> = HashMap::new();
291290

292-
let arr = &values[0];
293-
for i in 0..arr.len() {
294-
let v = ScalarValue::try_from_array(arr, i)?;
295-
if !v.is_null() {
296-
*to_remove.entry(v).or_default() += 1;
297-
}
291+
let arr = values[0].as_primitive::<T>();
292+
for value in arr.iter().flatten() {
293+
*to_remove.entry(Hashable(value)).or_default() += 1;
298294
}
299295

300296
let mut i = 0;
301297
while i < self.all_values.len() {
302-
let k = ScalarValue::new_primitive::<T>(
303-
Some(self.all_values[i]),
304-
&self.data_type,
305-
)?;
298+
let k = Hashable(self.all_values[i]);
306299
if let Some(count) = to_remove.get_mut(&k)
307300
&& *count > 0
308301
{

datafusion/functions-aggregate/src/percentile_cont.rs

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -440,20 +440,17 @@ where
440440
size_of_val(self) + self.all_values.capacity() * size_of::<T::Native>()
441441
}
442442

443-
#[allow(clippy::allow_attributes, clippy::mutable_key_type)] // ScalarValue has interior mutability but is intentionally used as hash key
444443
fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
445-
let mut to_remove: HashMap<ScalarValue, usize> = HashMap::new();
446-
for i in 0..values[0].len() {
447-
let v = ScalarValue::try_from_array(&values[0], i)?;
448-
if !v.is_null() {
449-
*to_remove.entry(v).or_default() += 1;
450-
}
444+
let mut to_remove: HashMap<Hashable<T::Native>, usize> = HashMap::new();
445+
446+
let arr = values[0].as_primitive::<T>();
447+
for value in arr.iter().flatten() {
448+
*to_remove.entry(Hashable(value)).or_default() += 1;
451449
}
452450

453451
let mut i = 0;
454452
while i < self.all_values.len() {
455-
let k =
456-
ScalarValue::new_primitive::<T>(Some(self.all_values[i]), &T::DATA_TYPE)?;
453+
let k = Hashable(self.all_values[i]);
457454
if let Some(count) = to_remove.get_mut(&k)
458455
&& *count > 0
459456
{

0 commit comments

Comments
 (0)