Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,14 @@ pub(super) struct BufferedBatch {
pub null_joined: Vec<usize>,
/// Size estimation used for reserving / releasing memory
pub size_estimation: usize,
/// Actual amount tracked in the memory reservation for this batch.
///
/// - `InMemory`: equals `size_estimation` (full batch + join_arrays + metadata)
/// - `Spilled`: equals join_arrays memory, which is force-grown unconditionally via `grow` after the spill
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice catch fixing the accounting path here 👍

One small thing: the reserved_amount field doc still says spilled batches only track join-array memory if try_grow succeeds, otherwise 0. Since the implementation now uses unconditional grow(join_arrays_mem), I think the doc comment should be updated to match the new behavior.

///
/// Invariant: `free_reservation()` shrinks by exactly this amount, so we never
/// shrink by more than we grew.
pub reserved_amount: usize,
/// Tracks filter outcomes for buffered rows in full outer joins.
/// Indexed by absolute row position within the batch. See [`FilterState`].
pub join_filter_status: Vec<FilterState>,
Expand Down Expand Up @@ -274,10 +282,20 @@ impl BufferedBatch {
join_arrays,
null_joined: vec![],
size_estimation,
reserved_amount: 0,
join_filter_status: vec![FilterState::Unvisited; num_rows],
num_rows,
}
}

/// Sums `get_array_memory_size()` over every entry in `join_arrays`.
///
/// The join key arrays stay resident in memory even once the main batch
/// has been spilled to disk, so this is the footprint that remains after
/// a spill.
fn join_arrays_mem(&self) -> usize {
    let mut total_bytes: usize = 0;
    for key_array in &self.join_arrays {
        total_bytes += key_array.get_array_memory_size();
    }
    total_bytes
}
}

// TODO: Spill join arrays (https://github.com/apache/datafusion/pull/17429)
Expand Down Expand Up @@ -948,17 +966,17 @@ impl MaterializingSortMergeJoinStream {
}

fn free_reservation(&mut self, buffered_batch: &BufferedBatch) -> Result<()> {
// Shrink memory usage for in-memory batches only
if let BufferedBatchState::InMemory(_) = buffered_batch.batch {
if buffered_batch.reserved_amount > 0 {
self.reservation
.try_shrink(buffered_batch.size_estimation)?;
.try_shrink(buffered_batch.reserved_amount)?;
}
Ok(())
}

fn allocate_reservation(&mut self, mut buffered_batch: BufferedBatch) -> Result<()> {
match self.reservation.try_grow(buffered_batch.size_estimation) {
Ok(_) => {
buffered_batch.reserved_amount = buffered_batch.size_estimation;
self.join_metrics
.peak_mem_used()
.set_max(self.reservation.size());
Expand All @@ -978,6 +996,21 @@ impl MaterializingSortMergeJoinStream {
.unwrap(); // Operation only return None if no batches are spilled, here we ensure that at least one batch is spilled

buffered_batch.batch = BufferedBatchState::Spilled(spill_file);

// Join key arrays remain in memory after the batch is
// spilled — the comparator needs them for key boundary
// detection. Force-grow the reservation so the pool
// reflects actual memory usage. This is unconditional
// because the memory is physically consumed regardless
// and not tracking it would let other operators
// over-allocate against a stale pool view.
let join_arrays_mem = buffered_batch.join_arrays_mem();
self.reservation.grow(join_arrays_mem);
buffered_batch.reserved_amount = join_arrays_mem;
self.join_metrics
.peak_mem_used()
.set_max(self.reservation.size());

Ok(())
}
_ => internal_err!("Buffered batch has empty body"),
Expand Down
217 changes: 217 additions & 0 deletions datafusion/physical-plan/src/joins/sort_merge_join/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2487,6 +2487,223 @@ async fn overallocation_multi_batch_spill() -> Result<()> {
Ok(())
}

/// Verifies that `peak_mem_used` reflects join_arrays memory on the spill path.
///
/// Uses a memory limit smaller than a single batch's `size_estimation` so that
/// every batch spills — the `Ok` arm of `allocate_reservation` is never hit.
/// Before the fix, `peak_mem_used` would stay 0 because `set_max` was only
/// called in the `Ok` arm. After the fix, the spill path calls
/// `grow(join_arrays_mem)` + `set_max`, so `peak_mem_used > 0`.
#[tokio::test]
async fn spill_join_arrays_memory_accounting() -> Result<()> {
use arrow::array::Array;

let left_batch = build_table_i32(
("a1", &vec![0, 1]),
("b1", &vec![1, 1]),
("c1", &vec![4, 5]),
);
let size_estimation = left_batch.get_array_memory_size()
+ Int32Array::from(vec![1, 1]).get_array_memory_size()
+ 2usize.next_power_of_two() * size_of::<usize>()
+ size_of::<std::ops::Range<usize>>()
+ size_of::<usize>();
let join_arrays_mem = Int32Array::from(vec![1, 1]).get_array_memory_size();

// Memory limit: too small for a full batch, large enough for join_arrays.
// Every batch hits the Err arm → spills → grow(join_arrays_mem).
let memory_limit = (size_estimation + join_arrays_mem) / 2;
assert!(
memory_limit < size_estimation && memory_limit > join_arrays_mem,
"limit {memory_limit} must be between join_arrays_mem {join_arrays_mem} \
and size_estimation {size_estimation}"
);

let left_batches: Vec<RecordBatch> = (0..4)
.map(|i| {
build_table_i32(
("a1", &vec![i * 2, i * 2 + 1]),
("b1", &vec![1, 1]),
("c1", &vec![100 + i, 101 + i]),
)
})
.collect();
let left = build_table_from_batches(left_batches);

let right_batches: Vec<RecordBatch> = (0..2)
.map(|i| {
build_table_i32(
("a2", &vec![i * 2, i * 2 + 1]),
("b2", &vec![1, 1]),
("c2", &vec![200 + i, 201 + i]),
)
})
.collect();
let right = build_table_from_batches(right_batches);

let on = vec![(
Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
)];
let sort_options = vec![SortOptions::default(); on.len()];

let runtime = RuntimeEnvBuilder::new()
.with_memory_limit(memory_limit, 1.0)
.with_disk_manager_builder(
DiskManagerBuilder::default().with_mode(DiskManagerMode::OsTmpDirectory),
)
.build_arc()?;

let session_config = SessionConfig::default().with_batch_size(50);
let task_ctx = Arc::new(
TaskContext::default()
.with_session_config(session_config)
.with_runtime(Arc::clone(&runtime)),
);

let join = join_with_options(
Arc::clone(&left),
Arc::clone(&right),
on.clone(),
Inner,
sort_options,
NullEquality::NullEqualsNothing,
)?;

let stream = join.execute(0, task_ctx)?;
let result = common::collect(stream).await.unwrap();

assert!(!result.is_empty(), "Expected non-empty join result");

let metrics = join.metrics().unwrap();
assert!(
metrics.spill_count().unwrap() > 0,
"Expected spilling to occur"
);

// Before the fix, peak_mem_used was 0 here because set_max was only
// called in the Ok arm of allocate_reservation, which is never reached
// when every batch spills. After the fix, the spill path calls
// grow(join_arrays_mem) + set_max unconditionally.
let peak_mem = metrics
.sum_by_name("peak_mem_used")
.map(|m| m.as_usize())
.unwrap_or(0);
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The new tests look good and peak_mem > 0 definitely catches the old missing-accounting path.

As a possible follow-up improvement, it might be worth tightening this to something like peak_mem >= join_arrays_mem if that value is available in scope. That would also help catch partial accounting regressions. Non-blocking though.

assert!(
peak_mem > 0,
"peak_mem_used should reflect join_arrays tracked on spill path"
);

// All memory must be released (grow/shrink balanced, no underflow)
assert_eq!(
runtime.memory_pool.reserved(),
0,
"All memory should be released after join completes"
);

Ok(())
}

/// No-headroom scenario: the pool limit is set to half of `join_arrays_mem`
/// (the size of a two-element `Int32Array`, standing in for one batch's join
/// key array), so the pool is too small even for the join arrays alone.
///
/// The initial `try_grow(size_estimation)` is expected to fail and trigger a
/// spill; the spill path then force-grows the reservation by the join-array
/// footprint. The test checks that a spill happened, that `peak_mem_used`
/// is non-zero, and that the pool is back to zero once the join finishes.
#[tokio::test]
async fn spill_join_arrays_no_headroom() -> Result<()> {
    use arrow::array::Array;

    let join_arrays_mem = Int32Array::from(vec![1, 1]).get_array_memory_size();

    // Half of the join-array footprint: try_grow(size_estimation) must fail,
    // while the force-grow of join_arrays_mem still goes through.
    let memory_limit = join_arrays_mem / 2;
    assert!(
        memory_limit < join_arrays_mem,
        "limit {memory_limit} must be smaller than join_arrays_mem {join_arrays_mem}"
    );

    // Four two-row batches on the left; every row carries join key 1.
    let left = build_table_from_batches(
        (0..4)
            .map(|i| {
                build_table_i32(
                    ("a1", &vec![i * 2, i * 2 + 1]),
                    ("b1", &vec![1, 1]),
                    ("c1", &vec![100 + i, 101 + i]),
                )
            })
            .collect::<Vec<RecordBatch>>(),
    );

    // Two two-row batches on the right, same join key.
    let right = build_table_from_batches(
        (0..2)
            .map(|i| {
                build_table_i32(
                    ("a2", &vec![i * 2, i * 2 + 1]),
                    ("b2", &vec![1, 1]),
                    ("c2", &vec![200 + i, 201 + i]),
                )
            })
            .collect::<Vec<RecordBatch>>(),
    );

    let join_on = vec![(
        Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
        Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
    )];
    let sort_options = vec![SortOptions::default(); join_on.len()];

    // Spilling needs a disk manager backed by the OS temp directory.
    let disk_manager =
        DiskManagerBuilder::default().with_mode(DiskManagerMode::OsTmpDirectory);
    let runtime = RuntimeEnvBuilder::new()
        .with_memory_limit(memory_limit, 1.0)
        .with_disk_manager_builder(disk_manager)
        .build_arc()?;

    let task_ctx = Arc::new(
        TaskContext::default()
            .with_session_config(SessionConfig::default().with_batch_size(50))
            .with_runtime(Arc::clone(&runtime)),
    );

    let join = join_with_options(
        Arc::clone(&left),
        Arc::clone(&right),
        join_on.clone(),
        Inner,
        sort_options,
        NullEquality::NullEqualsNothing,
    )?;

    let result = common::collect(join.execute(0, task_ctx)?).await.unwrap();
    assert!(!result.is_empty(), "Expected non-empty join result");

    let metrics = join.metrics().unwrap();
    let spill_count = metrics.spill_count().unwrap();
    assert!(spill_count > 0, "Expected spilling to occur");

    // Force-grow means peak_mem_used is tracked even when the pool is tight.
    let peak_mem = metrics
        .sum_by_name("peak_mem_used")
        .map_or(0, |m| m.as_usize());
    assert!(
        peak_mem > 0,
        "peak_mem_used should reflect force-grown join_arrays"
    );

    // Grow and shrink must balance: nothing may remain reserved in the pool.
    assert_eq!(
        runtime.memory_pool.reserved(),
        0,
        "All memory should be released after join completes"
    );

    Ok(())
}

/// Build a c1 < c2 filter on the third column of each side.
fn build_c1_lt_c2_filter(left_schema: &Schema, right_schema: &Schema) -> JoinFilter {
JoinFilter::new(
Expand Down
Loading