diff --git a/datafusion/physical-plan/src/joins/sort_merge_join/materializing_stream.rs b/datafusion/physical-plan/src/joins/sort_merge_join/materializing_stream.rs index 4840b56f55fff..069e94d0a9fd6 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join/materializing_stream.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join/materializing_stream.rs @@ -235,6 +235,18 @@ pub(super) struct BufferedBatch { pub null_joined: Vec<usize>, /// Size estimation used for reserving / releasing memory pub size_estimation: usize, + /// Memory footprint of `join_arrays` cached at construction time. + /// Used during spill to track the residual memory that remains after + /// the main batch is written to disk. + pub join_arrays_mem: usize, + /// Actual amount tracked in the memory reservation for this batch. + /// + /// - `InMemory`: equals `size_estimation` (full batch + join_arrays + metadata) + /// - `Spilled`: equals `join_arrays_mem` (join key arrays stay in memory) + /// + /// Invariant: `free_reservation()` shrinks by exactly this amount, so we never + /// shrink by more than we grew. + pub reserved_amount: usize, /// Tracks filter outcomes for buffered rows in full outer joins. /// Indexed by absolute row position within the batch. See [`FilterState`]. 
pub join_filter_status: Vec<FilterState>, @@ -258,11 +270,13 @@ impl BufferedBatch { // + worst case null_joined (as vector capacity * element size) // + Range size // + size of this estimation + let join_arrays_mem: usize = join_arrays + .iter() + .map(|arr| arr.get_array_memory_size()) + .sum(); + let size_estimation = batch.get_array_memory_size() - + join_arrays - .iter() - .map(|arr| arr.get_array_memory_size()) - .sum::<usize>() + + join_arrays_mem + batch.num_rows().next_power_of_two() * size_of::<usize>() + size_of::<Range<usize>>() + size_of::<usize>(); @@ -274,6 +288,8 @@ impl BufferedBatch { join_arrays, null_joined: vec![], size_estimation, + join_arrays_mem, + reserved_amount: 0, join_filter_status: vec![FilterState::Unvisited; num_rows], num_rows, } @@ -947,18 +963,16 @@ impl MaterializingSortMergeJoinStream { } } - fn free_reservation(&mut self, buffered_batch: &BufferedBatch) -> Result<()> { - // Shrink memory usage for in-memory batches only - if let BufferedBatchState::InMemory(_) = buffered_batch.batch { - self.reservation - .try_shrink(buffered_batch.size_estimation)?; + fn free_reservation(&mut self, buffered_batch: &BufferedBatch) { + if buffered_batch.reserved_amount > 0 { + self.reservation.shrink(buffered_batch.reserved_amount); } - Ok(()) } fn allocate_reservation(&mut self, mut buffered_batch: BufferedBatch) -> Result<()> { match self.reservation.try_grow(buffered_batch.size_estimation) { Ok(_) => { + buffered_batch.reserved_amount = buffered_batch.size_estimation; self.join_metrics .peak_mem_used() .set_max(self.reservation.size()); @@ -978,6 +992,22 @@ impl MaterializingSortMergeJoinStream { .unwrap(); // Operation only return None if no batches are spilled, here we ensure that at least one batch is spilled buffered_batch.batch = BufferedBatchState::Spilled(spill_file); + + // Join key arrays remain in memory after the batch is + // spilled — the comparator needs them for key boundary + // detection. 
Force-grow the reservation so the pool + // reflects actual memory usage even if this pushes + // pool.reserved() above the configured limit. This is + // safe because the memory is physically consumed and + // not tracking it would let other operators over-allocate + // against a stale pool view. + let join_arrays_mem = buffered_batch.join_arrays_mem; + self.reservation.grow(join_arrays_mem); + buffered_batch.reserved_amount = join_arrays_mem; + self.join_metrics + .peak_mem_used() + .set_max(self.reservation.size()); + Ok(()) } _ => internal_err!("Buffered batch has empty body"), @@ -1006,7 +1036,7 @@ impl MaterializingSortMergeJoinStream { self.buffered_data.batches.pop_front() { self.produce_buffered_not_matched(&mut buffered_batch)?; - self.free_reservation(&buffered_batch)?; + self.free_reservation(&buffered_batch); head_changed = true; } } else { diff --git a/datafusion/physical-plan/src/joins/sort_merge_join/tests.rs b/datafusion/physical-plan/src/joins/sort_merge_join/tests.rs index 5d70530528728..bc34c351c5e21 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join/tests.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join/tests.rs @@ -2487,6 +2487,223 @@ async fn overallocation_multi_batch_spill() -> Result<()> { Ok(()) } +/// Verifies that `peak_mem_used` reflects join_arrays memory on the spill path. +/// +/// Uses a memory limit smaller than a single batch's `size_estimation` so that +/// every batch spills — the `Ok` arm of `allocate_reservation` is never hit. +/// Before the fix, `peak_mem_used` would stay 0 because `set_max` was only +/// called in the `Ok` arm. After the fix, the spill path calls +/// `grow(join_arrays_mem)` + `set_max`, so `peak_mem_used > 0`. 
+#[tokio::test] +async fn spill_join_arrays_memory_accounting() -> Result<()> { + use arrow::array::Array; + + let left_batch = build_table_i32( + ("a1", &vec![0, 1]), + ("b1", &vec![1, 1]), + ("c1", &vec![4, 5]), + ); + let size_estimation = left_batch.get_array_memory_size() + + Int32Array::from(vec![1, 1]).get_array_memory_size() + + 2usize.next_power_of_two() * size_of::<usize>() + + size_of::<Range<usize>>() + + size_of::<usize>(); + let join_arrays_mem = Int32Array::from(vec![1, 1]).get_array_memory_size(); + + // Memory limit: too small for a full batch, large enough for join_arrays. + // Every batch hits the Err arm → spills → grow(join_arrays_mem). + let memory_limit = (size_estimation + join_arrays_mem) / 2; + assert!( + memory_limit < size_estimation && memory_limit > join_arrays_mem, + "limit {memory_limit} must be between join_arrays_mem {join_arrays_mem} \ + and size_estimation {size_estimation}" + ); + + let left_batches: Vec<RecordBatch> = (0..4) + .map(|i| { + build_table_i32( + ("a1", &vec![i * 2, i * 2 + 1]), + ("b1", &vec![1, 1]), + ("c1", &vec![100 + i, 101 + i]), + ) + }) + .collect(); + let left = build_table_from_batches(left_batches); + + let right_batches: Vec<RecordBatch> = (0..2) + .map(|i| { + build_table_i32( + ("a2", &vec![i * 2, i * 2 + 1]), + ("b2", &vec![1, 1]), + ("c2", &vec![200 + i, 201 + i]), + ) + }) + .collect(); + let right = build_table_from_batches(right_batches); + + let on = vec![( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b2", &right.schema())?) 
as _, + )]; + let sort_options = vec![SortOptions::default(); on.len()]; + + let runtime = RuntimeEnvBuilder::new() + .with_memory_limit(memory_limit, 1.0) + .with_disk_manager_builder( + DiskManagerBuilder::default().with_mode(DiskManagerMode::OsTmpDirectory), + ) + .build_arc()?; + + let session_config = SessionConfig::default().with_batch_size(50); + let task_ctx = Arc::new( + TaskContext::default() + .with_session_config(session_config) + .with_runtime(Arc::clone(&runtime)), + ); + + let join = join_with_options( + Arc::clone(&left), + Arc::clone(&right), + on.clone(), + Inner, + sort_options, + NullEquality::NullEqualsNothing, + )?; + + let stream = join.execute(0, task_ctx)?; + let result = common::collect(stream).await.unwrap(); + + assert!(!result.is_empty(), "Expected non-empty join result"); + + let metrics = join.metrics().unwrap(); + assert!( + metrics.spill_count().unwrap() > 0, + "Expected spilling to occur" + ); + + // Before the fix, peak_mem_used was 0 here because set_max was only + // called in the Ok arm of allocate_reservation, which is never reached + // when every batch spills. After the fix, the spill path calls + // grow(join_arrays_mem) + set_max unconditionally. + let peak_mem = metrics + .sum_by_name("peak_mem_used") + .map(|m| m.as_usize()) + .unwrap_or(0); + assert!( + peak_mem >= join_arrays_mem, + "peak_mem_used ({peak_mem}) should be >= join_arrays_mem ({join_arrays_mem})" + ); + + // All memory must be released (grow/shrink balanced, no underflow) + assert_eq!( + runtime.memory_pool.reserved(), + 0, + "All memory should be released after join completes" + ); + + Ok(()) +} + +/// Test the no-headroom scenario: pool is so tight that even +/// join_arrays_mem exceeds the pool limit. With force-grow, the +/// reservation still tracks the join_arrays unconditionally so the +/// pool reflects actual memory usage. 
+#[tokio::test] +async fn spill_join_arrays_no_headroom() -> Result<()> { + use arrow::array::Array; + + let join_arrays_mem = Int32Array::from(vec![1, 1]).get_array_memory_size(); + + // Pool smaller than join_arrays_mem: try_grow(size_estimation) fails → spill. + // Force-grow(join_arrays_mem) succeeds unconditionally → reserved_amount > 0. + let memory_limit = join_arrays_mem / 2; + assert!( + memory_limit < join_arrays_mem, + "limit {memory_limit} must be smaller than join_arrays_mem {join_arrays_mem}" + ); + + let left_batches: Vec<RecordBatch> = (0..4) + .map(|i| { + build_table_i32( + ("a1", &vec![i * 2, i * 2 + 1]), + ("b1", &vec![1, 1]), + ("c1", &vec![100 + i, 101 + i]), + ) + }) + .collect(); + let left = build_table_from_batches(left_batches); + + let right_batches: Vec<RecordBatch> = (0..2) + .map(|i| { + build_table_i32( + ("a2", &vec![i * 2, i * 2 + 1]), + ("b2", &vec![1, 1]), + ("c2", &vec![200 + i, 201 + i]), + ) + }) + .collect(); + let right = build_table_from_batches(right_batches); + + let on = vec![( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b2", &right.schema())?) 
as _, + )]; + let sort_options = vec![SortOptions::default(); on.len()]; + + let runtime = RuntimeEnvBuilder::new() + .with_memory_limit(memory_limit, 1.0) + .with_disk_manager_builder( + DiskManagerBuilder::default().with_mode(DiskManagerMode::OsTmpDirectory), + ) + .build_arc()?; + + let session_config = SessionConfig::default().with_batch_size(50); + let task_ctx = Arc::new( + TaskContext::default() + .with_session_config(session_config) + .with_runtime(Arc::clone(&runtime)), + ); + + let join = join_with_options( + Arc::clone(&left), + Arc::clone(&right), + on.clone(), + Inner, + sort_options, + NullEquality::NullEqualsNothing, + )?; + + let stream = join.execute(0, task_ctx)?; + let result = common::collect(stream).await.unwrap(); + + assert!(!result.is_empty(), "Expected non-empty join result"); + + let metrics = join.metrics().unwrap(); + assert!( + metrics.spill_count().unwrap() > 0, + "Expected spilling to occur" + ); + + // Force-grow means peak_mem_used is always tracked, even when pool is tight. + let peak_mem = metrics + .sum_by_name("peak_mem_used") + .map(|m| m.as_usize()) + .unwrap_or(0); + assert!( + peak_mem >= join_arrays_mem, + "peak_mem_used ({peak_mem}) should be >= join_arrays_mem ({join_arrays_mem})" + ); + + // Pool should be fully released (grow/shrink balanced) + assert_eq!( + runtime.memory_pool.reserved(), + 0, + "All memory should be released after join completes" + ); + + Ok(()) +} + /// Build a c1 < c2 filter on the third column of each side. fn build_c1_lt_c2_filter(left_schema: &Schema, right_schema: &Schema) -> JoinFilter { JoinFilter::new(