Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 31 additions & 1 deletion datafusion/optimizer/src/propagate_empty_relation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ use datafusion_common::JoinType;
use datafusion_common::tree_node::Transformed;
use datafusion_common::{Column, DFSchemaRef, Result, ScalarValue, plan_err};
use datafusion_expr::logical_plan::LogicalPlan;
use datafusion_expr::{EmptyRelation, Expr, Projection, Union, cast, lit};
use datafusion_expr::{EmptyRelation, Expr, GroupingSet, Projection, Union, cast, lit};

use crate::optimizer::ApplyOrder;
use crate::{OptimizerConfig, OptimizerRule};
Expand Down Expand Up @@ -174,7 +174,13 @@ impl OptimizerRule for PropagateEmptyRelation {
}
}
LogicalPlan::Aggregate(ref agg) => {
// An aggregate over an empty input can be eliminated only when
// there is no empty grouping set. An empty grouping set `()`
// (from `GROUPING SETS(())`, `ROLLUP(...)`, or `CUBE(...)`)
// always produces exactly one row even on empty input, so it
// must not be replaced by an empty relation.
if !agg.group_expr.is_empty()
&& !has_empty_grouping_set(&agg.group_expr)
&& let Some(empty_plan) = empty_child(&plan)?
{
return Ok(Transformed::yes(empty_plan));
Expand Down Expand Up @@ -315,6 +321,30 @@ fn build_null_padded_projection(
)?))
}

/// Reports whether the GROUP BY list contains the empty grouping set `()`.
///
/// The empty grouping set is the "grand total" group: an aggregate must emit
/// **exactly one row** for it even if its input has no rows, so such an
/// aggregate must never be rewritten into an empty relation.
///
/// Only the first expression of `group_expr` is inspected. The result is:
/// - `GROUPING SETS (...)`: `true` iff one of the listed sets is empty.
/// - `ROLLUP(exprs)` / `CUBE(exprs)`: always `true`, because both expansions
///   include `()`.
/// - anything else (including an empty list): `false`.
fn has_empty_grouping_set(group_expr: &[Expr]) -> bool {
    // Anything other than a leading grouping-set expression cannot contribute `()`.
    let Some(Expr::GroupingSet(grouping)) = group_expr.first() else {
        return false;
    };

    match grouping {
        // Explicit list: look for a literally empty set.
        GroupingSet::GroupingSets(sets) => sets.iter().any(Vec::is_empty),
        // ROLLUP and CUBE both expand to a list that includes `()`.
        GroupingSet::Rollup(_) | GroupingSet::Cube(_) => true,
    }
}

#[cfg(test)]
mod tests {

Expand Down
9 changes: 4 additions & 5 deletions datafusion/physical-plan/src/aggregates/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2063,11 +2063,11 @@ fn evaluate_optional(
/// The integer type is chosen to be the smallest `UInt8 / UInt16 / UInt32 /
/// UInt64` that can represent both parts. It matches the type returned by
/// [`Aggregate::grouping_id_type`].
fn group_id_array(
pub(crate) fn group_id_array(
group: &[bool],
ordinal: usize,
max_ordinal: usize,
batch: &RecordBatch,
num_rows: usize,
) -> Result<ArrayRef> {
let n = group.len();
if n > 64 {
Expand All @@ -2087,7 +2087,6 @@ fn group_id_array(
(acc << 1) | if is_null { 1 } else { 0 }
});
let full_id = semantic_id | ((ordinal as u64) << n);
let num_rows = batch.num_rows();
if total_bits <= 8 {
Ok(Arc::new(UInt8Array::from(vec![full_id as u8; num_rows])))
} else if total_bits <= 16 {
Expand All @@ -2106,7 +2105,7 @@ fn group_id_array(
/// ordinal 0, the second gets 1, and so on. If the same `Vec<bool>` appears
/// three times the ordinals are 0, 1, 2 and this function returns 2.
/// Returns 0 when no grouping set is duplicated.
fn max_duplicate_ordinal(groups: &[Vec<bool>]) -> usize {
pub(crate) fn max_duplicate_ordinal(groups: &[Vec<bool>]) -> usize {
let mut counts: HashMap<&[bool], usize> = HashMap::new();
for group in groups {
*counts.entry(group).or_insert(0) += 1;
Expand Down Expand Up @@ -2160,7 +2159,7 @@ pub fn evaluate_group_by(
group,
current_ordinal,
max_ordinal,
batch,
batch.num_rows(),
)?);
}
Ok(group_values)
Expand Down
104 changes: 103 additions & 1 deletion datafusion/physical-plan/src/aggregates/row_hash.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ use crate::aggregates::group_values::{GroupByMetrics, GroupValues, new_group_val
use crate::aggregates::order::GroupOrderingFull;
use crate::aggregates::{
AggregateInputMode, AggregateMode, AggregateOutputMode, PhysicalGroupBy,
create_schema, evaluate_group_by, evaluate_many, evaluate_optional,
create_schema, evaluate_group_by, evaluate_many, evaluate_optional, group_id_array,
max_duplicate_ordinal,
};
use crate::metrics::{BaselineMetrics, MetricBuilder, MetricCategory, RecordOutput};
use crate::sorts::streaming_merge::{SortedSpillFile, StreamingMergeBuilder};
Expand Down Expand Up @@ -360,6 +361,7 @@ pub(crate) struct GroupedHashAggregateStream {
// the execution.
// ========================================================================
schema: SchemaRef,
input_schema: SchemaRef,
input: SendableRecordBatchStream,
mode: AggregateMode,

Expand Down Expand Up @@ -661,6 +663,7 @@ impl GroupedHashAggregateStream {

Ok(GroupedHashAggregateStream {
schema: agg_schema,
input_schema: agg.input().schema(),
input,
mode: agg.mode,
accumulators,
Expand Down Expand Up @@ -1125,6 +1128,104 @@ impl GroupedHashAggregateStream {
Ok(Some(batch))
}

/// Creates the group for every empty grouping set when nothing has been grouped yet.
///
/// Standard SQL requires the empty grouping set `()` to yield one "grand
/// total" row even for an empty input. This applies to
/// `GROUP BY GROUPING SETS (())` as well as to the `()` member of a mixed list
/// such as `GROUPING SETS (a, ())`.
///
/// The method does nothing unless the group-by uses grouping sets and
/// `group_values` is still empty (the signal used here for "no input rows were
/// seen"). Otherwise it interns one all-NULL key per empty grouping set and
/// then primes the accumulators so they report their zero-row values
/// (e.g. `NULL` for `SUM`, `0` for `COUNT`).
///
/// # Errors
///
/// Propagates failures from building the group schema or grouping id,
/// interning the keys, resolving argument types, or the accumulators.
fn init_empty_grouping_sets(&mut self) -> Result<()> {
    if !(self.group_by.has_grouping_set() && self.group_values.is_empty()) {
        return Ok(());
    }

    let max_ordinal = max_duplicate_ordinal(self.group_by.groups());
    let group_schema = self.group_by.group_schema(&self.input_schema)?;
    let key_width = self.group_by.expr().len();

    // Occurrences seen so far per grouping-set mask; gives each duplicate its ordinal.
    let mut seen: std::collections::HashMap<&[bool], usize> =
        std::collections::HashMap::new();
    let mut registered_any = false;

    for mask in self.group_by.groups() {
        // Count every occurrence, including sets skipped below, so ordinals of
        // later duplicates stay consistent.
        let slot = seen.entry(mask.as_slice()).or_insert(0);
        let ordinal = *slot;
        *slot += 1;

        // Only a set whose every expression is nulled out is the empty set `()`.
        if mask.contains(&false) {
            continue;
        }

        // Group key: a single NULL for each group-by expression, followed by
        // the grouping id column.
        let mut key: Vec<ArrayRef> = Vec::with_capacity(key_width + 1);
        key.extend(
            group_schema
                .fields()
                .iter()
                .take(key_width)
                .map(|field| new_null_array(field.data_type(), 1)),
        );
        key.push(group_id_array(mask, ordinal, max_ordinal, 1)?);

        let groups_before = self.group_values.len();
        self.group_values
            .intern(&key, &mut self.current_group_indices)?;
        let groups_after = self.group_values.len();
        // Tell the ordering tracker only about groups that were actually created.
        if groups_after > groups_before {
            self.group_ordering.new_groups(
                &key,
                &self.current_group_indices,
                groups_after,
            )?;
        }
        registered_any = true;
    }

    if !registered_any {
        return Ok(());
    }

    // Prime every accumulator for the registered groups without feeding data.
    //
    // Each aggregate argument becomes a 1-row NULL array and is passed with an
    // all-false filter, so no row is accumulated and every group keeps its
    // initial "zero" state (NULL for SUM/AVG/MIN/MAX, 0 for COUNT).
    //
    // One row is used instead of zero rows to avoid a fast path in
    // `NullState::accumulate` that treats "0 nulls in a 0-row array" as "all
    // groups have been seen", which would make SUM return 0 instead of NULL.
    //
    // Argument types come straight from the expressions, so no `RecordBatch`
    // has to be constructed.
    let total_groups = self.group_values.len();
    let mut null_args: Vec<Vec<ArrayRef>> =
        Vec::with_capacity(self.aggregate_arguments.len());
    for exprs in &self.aggregate_arguments {
        let mut arrays = Vec::with_capacity(exprs.len());
        for expr in exprs {
            let data_type = expr.data_type(&self.input_schema)?;
            arrays.push(new_null_array(&data_type, 1));
        }
        null_args.push(arrays);
    }

    let no_rows = BooleanArray::from(vec![false]);
    let raw_input = self.mode.input_mode() == AggregateInputMode::Raw;
    for (acc, args) in self.accumulators.iter_mut().zip(&null_args) {
        if raw_input {
            acc.update_batch(args, &[0], Some(&no_rows), total_groups)?;
        } else {
            acc.merge_batch(args, &[0], Some(&no_rows), total_groups)?;
        }
    }

    Ok(())
}

/// Emit all intermediate aggregation states, sort them, and store them on disk.
/// This process helps in reducing memory pressure by allowing the data to be
/// read back with streaming merge.
Expand Down Expand Up @@ -1223,6 +1324,7 @@ impl GroupedHashAggregateStream {
let timer = elapsed_compute.timer();
self.exec_state = if self.spill_state.spills.is_empty() {
// Input has been entirely processed without spilling to disk.
self.init_empty_grouping_sets()?;

// Flush any remaining group values.
let batch = self.emit(EmitTo::All, false)?;
Expand Down
19 changes: 19 additions & 0 deletions datafusion/sqllogictest/test_files/grouping.slt
Original file line number Diff line number Diff line change
Expand Up @@ -224,3 +224,22 @@ query I
SELECT SUM(v1) FROM generate_series(10) AS t1(v1) GROUP BY GROUPING SETS(())
----
55

# grouping_sets_empty_input: the empty grouping set `()` always defines exactly
# one group (standard SQL), so GROUPING SETS (()) over an empty input must still
# return a single row; SUM over zero rows is NULL.
query I
SELECT SUM(v1) FROM generate_series(10) AS t1(v1) WHERE false GROUP BY GROUPING SETS(())
----
NULL

# grouping_sets_empty_input_count: same empty input, but COUNT reports 0 for the
# single empty-set group rather than the row being missing.
query I
SELECT COUNT(*) FROM generate_series(10) AS t1(v1) WHERE false GROUP BY GROUPING SETS(())
----
0

# grouping_sets_mixed_empty_and_non_empty: on empty input only the empty set `()`
# contributes a row; the `(v1)` set has no input values and therefore no groups,
# so exactly one row is expected.
query II
SELECT SUM(v1), COUNT(*) FROM generate_series(10) AS t1(v1) WHERE false GROUP BY GROUPING SETS((), (v1))
----
NULL 0
Loading