Skip to content
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,18 @@
import org.apache.celeborn.common.unsafe.Platform;
import org.apache.celeborn.common.util.Utils;

/**
* Sort-based pusher for MapReduce shuffle data to Celeborn.
*
* This implementation uses primitive int arrays to store record metadata
* (offsets, key lengths, value lengths) during data collection to minimize
* object allocation. During flush, temporary Record objects are created
* for sorting and immediately garbage collected in Young Gen.
*
* To prevent memory pressure during sorting, the implementation triggers
* early spill when record count exceeds a threshold (5M records by default),
* ensuring temporary objects fit within Young Gen capacity.
*/
public class CelebornSortBasedPusher<K, V> extends OutputStream {
private final Logger logger = LoggerFactory.getLogger(CelebornSortBasedPusher.class);
private final int mapId;
Expand All @@ -48,7 +60,7 @@ public class CelebornSortBasedPusher<K, V> extends OutputStream {
private final AtomicReference<Exception> exception = new AtomicReference<>();
private final Counters.Counter mapOutputByteCounter;
private final Counters.Counter mapOutputRecordCounter;
private final Map<Integer, List<SerializedKV>> partitionedKVs;
private final Map<Integer, KVBufferInfo> partitionedKVBuffers;
private int writePos;
private byte[] serializedKV;
private final int maxPushDataSize;
Expand Down Expand Up @@ -79,7 +91,7 @@ public CelebornSortBasedPusher(
this.mapOutputRecordCounter = mapOutputRecordCounter;
this.comparator = comparator;
this.shuffleClient = shuffleClient;
partitionedKVs = new HashMap<>();
partitionedKVBuffers = new HashMap<>();
serializedKV = new byte[maxIOBufferSize];
maxPushDataSize = (int) celebornConf.clientMrMaxPushData();
logger.info(
Expand All @@ -102,6 +114,7 @@ public CelebornSortBasedPusher(

public void insert(K key, V value, int partition) {
try {
// Check if we should spill based on buffer size
if (writePos >= spillIOBufferSize) {
// needs to sort and flush data
if (logger.isDebugEnabled()) {
Expand All @@ -114,6 +127,22 @@ public void insert(K key, V value, int partition) {
sortKVs();
sendKVAndUpdateWritePos();
}

// Additional check: limit total record count to avoid memory pressure during sort
// If total records exceed safe threshold, force an early spill
int totalRecords = getTotalRecordCount();
final int MAX_RECORDS_BEFORE_SPILL = 5_000_000; // 5M records = ~120MB temporary objects
Copy link

Copilot AI Apr 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Calling getTotalRecordCount() on every insert() makes each record insertion O(#partitions) due to iterating partitionedKVBuffers.values(). This can add significant overhead at tens of millions of records. Consider maintaining a totalRecordCount field that increments in insertRecordInternal(...) (after bufferInfo.add(...)) and resets when spilling/clearing (e.g., in sendKVAndUpdateWritePos() and close()).

Copilot uses AI. Check for mistakes.

if (totalRecords >= MAX_RECORDS_BEFORE_SPILL && writePos > 0) {
if (logger.isDebugEnabled()) {
logger.debug(
"Record count {} exceeds safe threshold {}, forcing early spill",
totalRecords, MAX_RECORDS_BEFORE_SPILL);
Copy link

Copilot AI Apr 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The 5M threshold is duplicated (MAX_RECORDS_BEFORE_SPILL and MAX_SORT_RECORDS) and could drift over time. Consider defining a single class-level constant (or deriving it from configuration) and using it consistently for both early-spill and sorting limits. This also makes it easier to document/adjust the policy in one place.

Suggested change
final int MAX_RECORDS_BEFORE_SPILL = 5_000_000; // 5M records = ~120MB temporary objects
if (totalRecords >= MAX_RECORDS_BEFORE_SPILL && writePos > 0) {
if (logger.isDebugEnabled()) {
logger.debug(
"Record count {} exceeds safe threshold {}, forcing early spill",
totalRecords, MAX_RECORDS_BEFORE_SPILL);
if (totalRecords >= MAX_SORT_RECORDS && writePos > 0) {
if (logger.isDebugEnabled()) {
logger.debug(
"Record count {} exceeds safe threshold {}, forcing early spill",
totalRecords, MAX_SORT_RECORDS);

Copilot uses AI. Check for mistakes.
}
sortKVs();
sendKVAndUpdateWritePos();
}

int dataLen = insertRecordInternal(key, value, partition);
if (logger.isDebugEnabled()) {
logger.debug(
Expand All @@ -126,46 +155,64 @@ public void insert(K key, V value, int partition) {
}
}

/**
 * Counts the records currently buffered across every partition.
 *
 * <p>Walks all {@link KVBufferInfo} values held in {@code partitionedKVBuffers} and adds up
 * their {@code count} fields, so the cost grows with the number of partitions that currently
 * have a buffer.
 *
 * @return the sum of {@code count} over all partition buffers; 0 when no buffer exists
 */
private int getTotalRecordCount() {
  int recordSum = 0;
  Iterator<KVBufferInfo> buffers = partitionedKVBuffers.values().iterator();
  while (buffers.hasNext()) {
    recordSum += buffers.next().count;
  }
  return recordSum;
}
Copy link

Copilot AI Apr 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Calling getTotalRecordCount() on every insert() makes each record insertion O(#partitions) due to iterating partitionedKVBuffers.values(). This can add significant overhead at tens of millions of records. Consider maintaining a totalRecordCount field that increments in insertRecordInternal(...) (after bufferInfo.add(...)) and resets when spilling/clearing (e.g., in sendKVAndUpdateWritePos() and close()).

Copilot uses AI. Check for mistakes.

private void sendKVAndUpdateWritePos() throws IOException {
Iterator<Map.Entry<Integer, List<SerializedKV>>> entryIter =
partitionedKVs.entrySet().iterator();
Iterator<Map.Entry<Integer, KVBufferInfo>> entryIter =
partitionedKVBuffers.entrySet().iterator();
while (entryIter.hasNext()) {
Map.Entry<Integer, List<SerializedKV>> entry = entryIter.next();
Map.Entry<Integer, KVBufferInfo> entry = entryIter.next();
entryIter.remove();
int partition = entry.getKey();
List<SerializedKV> kvs = entry.getValue();
List<SerializedKV> localKVs = new ArrayList<>();
KVBufferInfo bufferInfo = entry.getValue();
int partitionKVTotalLen = 0;
// process buffers for specific partition
for (SerializedKV kv : kvs) {
partitionKVTotalLen += kv.kLen + kv.vLen;
localKVs.add(kv);
int batchStartIdx = 0;
// process buffers for specific partition (arrays are already sorted in-place)
for (int i = 0; i < bufferInfo.count; i++) {
partitionKVTotalLen += bufferInfo.keyLens[i] + bufferInfo.valueLens[i];
if (partitionKVTotalLen > maxPushDataSize) {
Copy link

Copilot AI Apr 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This batching logic sends a batch after partitionKVTotalLen has already exceeded maxPushDataSize, which can produce push payloads larger than the configured limit (contradicting the comment about limiting max size). To actually enforce the limit, structure the loop so that if adding the next record would exceed the threshold, you send the previous batch first, then start a new batch with the current record. Also handle the edge case where a single record is larger than maxPushDataSize (so you avoid an empty-batch send or an infinite loop).

Copilot uses AI. Check for mistakes.
// limit max size of pushdata to avoid possible memory issue in Celeborn worker
// data layout
// pushdata header (16) + pushDataLen(4) +
// [varKeyLen+varValLen+serializedRecord(x)][...]
sendSortedBuffersPartition(partition, localKVs, partitionKVTotalLen);
localKVs.clear();
sendSortedBuffersPartition(partition, bufferInfo, batchStartIdx, i - batchStartIdx + 1, partitionKVTotalLen);
// move batch start
partitionKVTotalLen = 0;
batchStartIdx = i + 1;
}
}
Copy link

Copilot AI Apr 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This batching logic sends a batch after partitionKVTotalLen has already exceeded maxPushDataSize, which can produce push payloads larger than the configured limit (contradicting the comment about limiting max size). To actually enforce the limit, structure the loop so that if adding the next record would exceed the threshold, you send the previous batch first, then start a new batch with the current record. Also handle the edge case where a single record is larger than maxPushDataSize (so you avoid an empty-batch send or an infinite loop).

Copilot uses AI. Check for mistakes.
if (!localKVs.isEmpty()) {
sendSortedBuffersPartition(partition, localKVs, partitionKVTotalLen);
// send remaining records
if (batchStartIdx < bufferInfo.count) {
// recalculate total length for remaining records
partitionKVTotalLen = 0;
for (int i = batchStartIdx; i < bufferInfo.count; i++) {
partitionKVTotalLen += bufferInfo.keyLens[i] + bufferInfo.valueLens[i];
}
sendSortedBuffersPartition(partition, bufferInfo, batchStartIdx, bufferInfo.count - batchStartIdx, partitionKVTotalLen);
}
kvs.clear();
// Clear buffer info for reuse
bufferInfo.clear();
}
// all data sent
partitionedKVs.clear();
partitionedKVBuffers.clear();
writePos = 0;
}

private void sendSortedBuffersPartition(
int partition, List<SerializedKV> localKVs, int partitionKVTotalLen) throws IOException {
int partition, KVBufferInfo bufferInfo, int startIdx, int count, int partitionKVTotalLen) throws IOException {
int extraSize = 0;
for (SerializedKV localKV : localKVs) {
extraSize += WritableUtils.getVIntSize(localKV.kLen);
extraSize += WritableUtils.getVIntSize(localKV.vLen);
for (int i = startIdx; i < startIdx + count; i++) {
extraSize += WritableUtils.getVIntSize(bufferInfo.keyLens[i]);
extraSize += WritableUtils.getVIntSize(bufferInfo.valueLens[i]);
}
// copied from hadoop logic
extraSize += WritableUtils.getVIntSize(-1);
Expand All @@ -174,14 +221,16 @@ private void sendSortedBuffersPartition(
byte[] pkvs = new byte[4 + extraSize + partitionKVTotalLen];
int pkvsPos = 4;
Platform.putInt(pkvs, Platform.BYTE_ARRAY_OFFSET, partitionKVTotalLen + extraSize);
for (SerializedKV kv : localKVs) {
int recordLen = kv.kLen + kv.vLen;
for (int i = startIdx; i < startIdx + count; i++) {
int kLen = bufferInfo.keyLens[i];
int vLen = bufferInfo.valueLens[i];
int recordLen = kLen + vLen;
// write key len
pkvsPos = writeVLong(pkvs, pkvsPos, kv.kLen);
pkvsPos = writeVLong(pkvs, pkvsPos, kLen);
// write value len
pkvsPos = writeVLong(pkvs, pkvsPos, kv.vLen);
pkvsPos = writeVLong(pkvs, pkvsPos, vLen);
// write serialized record
System.arraycopy(serializedKV, kv.offset, pkvs, pkvsPos, recordLen);
System.arraycopy(serializedKV, bufferInfo.offsets[i], pkvs, pkvsPos, recordLen);
pkvsPos += recordLen;
}
// finally write -1 two times
Expand Down Expand Up @@ -245,13 +294,119 @@ private int writeVLong(byte[] data, int offset, long dataInt) {
}

private void sortKVs() {
for (Map.Entry<Integer, List<SerializedKV>> partitionKVEntry : partitionedKVs.entrySet()) {
partitionKVEntry
.getValue()
.sort(
(o1, o2) ->
comparator.compare(
serializedKV, o1.offset, o1.kLen, serializedKV, o2.offset, o2.kLen));
// Maximum number of temporary Record objects to create at once.
// This limits Young Gen pressure and prevents Full GC.
// Record size ~24 bytes, so 5M records = 120MB at peak
final int MAX_SORT_RECORDS = 5_000_000;
Copy link

Copilot AI Apr 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The 5M threshold is duplicated (MAX_RECORDS_BEFORE_SPILL and MAX_SORT_RECORDS) and could drift over time. Consider defining a single class-level constant (or deriving it from configuration) and using it consistently for both early-spill and sorting limits. This also makes it easier to document/adjust the policy in one place.

Copilot uses AI. Check for mistakes.

for (Map.Entry<Integer, KVBufferInfo> partitionKVEntry : partitionedKVBuffers.entrySet()) {
KVBufferInfo bufferInfo = partitionKVEntry.getValue();
if (bufferInfo.count <= 1) {
continue;
}

// If too many records, split into batches
if (bufferInfo.count > MAX_SORT_RECORDS) {
// Sort and flush each batch immediately
int remaining = bufferInfo.count;
int start = 0;

while (remaining > 0) {
int batchSize = Math.min(remaining, MAX_SORT_RECORDS);
int end = start + batchSize;

// Sort this batch
sortBatch(bufferInfo, start, end);

// Send this batch immediately to free memory
sendPartialBuffer(bufferInfo, start, batchSize);

start = end;
remaining -= batchSize;
}

// After all batches sent, clear the buffer
bufferInfo.clear();
} else {
// Full sort (single batch)
sortBatch(bufferInfo, 0, bufferInfo.count);
}
Copy link

Copilot AI Apr 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sortKVs() calls sendPartialBuffer(...) when bufferInfo.count > MAX_SORT_RECORDS, but sendPartialBuffer currently always throws UnsupportedOperationException. That makes large buffers fail at runtime rather than spilling safely. If the intent is to rely on early-spill to guarantee count <= MAX_SORT_RECORDS, consider removing this branch and enforcing the invariant (e.g., via a clear exception that should never happen). Otherwise, implement partial sending by integrating batch boundaries into sendKVAndUpdateWritePos() (so the batching logic that exists there is reused rather than throwing).

Suggested change
// If too many records, split into batches
if (bufferInfo.count > MAX_SORT_RECORDS) {
// Sort and flush each batch immediately
int remaining = bufferInfo.count;
int start = 0;
while (remaining > 0) {
int batchSize = Math.min(remaining, MAX_SORT_RECORDS);
int end = start + batchSize;
// Sort this batch
sortBatch(bufferInfo, start, end);
// Send this batch immediately to free memory
sendPartialBuffer(bufferInfo, start, batchSize);
start = end;
remaining -= batchSize;
}
// After all batches sent, clear the buffer
bufferInfo.clear();
} else {
// Full sort (single batch)
sortBatch(bufferInfo, 0, bufferInfo.count);
}
// Enforce invariant that early spilling keeps record count within bounds.
if (bufferInfo.count > MAX_SORT_RECORDS) {
throw new IllegalStateException(
"KVBufferInfo.count (" + bufferInfo.count
+ ") exceeds MAX_SORT_RECORDS (" + MAX_SORT_RECORDS
+ "). Early spill should prevent this; partial buffer sending "
+ "is not supported in sortKVs().");
}
// Full sort (single batch)
sortBatch(bufferInfo, 0, bufferInfo.count);

Copilot uses AI. Check for mistakes.
}
}

/**
 * Sorts the metadata entries in {@code [start, end)} of one partition buffer.
 *
 * <p>The three parallel arrays of {@code bufferInfo} are reordered together: each entry in the
 * range is wrapped in a temporary {@link Record}, the wrappers are sorted with
 * {@link Arrays#sort(Object[])} using {@code Record}'s natural ordering, and the sorted
 * triples are copied back into the same range. Entries outside the range are not touched.
 *
 * @param bufferInfo partition buffer whose metadata arrays are reordered in place
 * @param start index of the first entry to sort (inclusive)
 * @param end index one past the last entry to sort (exclusive)
 */
private void sortBatch(KVBufferInfo bufferInfo, int start, int end) {
  // One temporary wrapper per entry in the range.
  Record[] batch = new Record[end - start];
  for (int src = start, dst = 0; src < end; src++, dst++) {
    batch[dst] =
        new Record(
            serializedKV,
            comparator,
            bufferInfo.offsets[src],
            bufferInfo.keyLens[src],
            bufferInfo.valueLens[src]);
  }

  Arrays.sort(batch);

  // Copy the sorted triples back over the same index range.
  int slot = start;
  for (Record sorted : batch) {
    bufferInfo.offsets[slot] = sorted.offset;
    bufferInfo.keyLens[slot] = sorted.kLen;
    bufferInfo.valueLens[slot] = sorted.vLen;
    slot++;
  }
}

/**
 * Placeholder for sending one already-sorted slice of a partition buffer.
 *
 * <p>Partial sends are not implemented: this method ignores its arguments and always throws.
 * Supporting them would require reworking {@code sendKVAndUpdateWritePos}.
 *
 * @param bufferInfo partition buffer the slice belongs to (unused)
 * @param start index of the first record in the slice (unused)
 * @param count number of records in the slice (unused)
 * @throws UnsupportedOperationException always
 */
private void sendPartialBuffer(KVBufferInfo bufferInfo, int start, int count) {
  // TODO: implement once sendKVAndUpdateWritePos can handle partial sends.
  // Until then the failure message points the user at the relevant configuration.
  final String message =
      "Buffer too large for single batch sorting. "
          + "Please reduce mapreduce.task.io.sort.mb or increase mapreduce.map.java.opts heap.";
  throw new UnsupportedOperationException(message);
}
Copy link

Copilot AI Apr 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The PR description states the fix is fully backward compatible and requires no configuration changes, but this code path explicitly instructs users to change mapreduce.task.io.sort.mb or heap settings and fails the task. Either remove/avoid this user-facing failure mode (preferred, consistent with the PR description) or update the PR description to reflect the new behavior and conditions under which it can occur.

Copilot uses AI. Check for mistakes.

/**
 * Temporary sort handle for one serialized key/value record.
 *
 * <p>Holds the record's position and lengths inside a shared byte buffer, plus the comparator
 * used to order records. Ordering is by key bytes only: {@link #compareTo(Record)} passes the
 * offset and key length of each record to the comparator and never looks at the value bytes.
 *
 * <p>Declared static so instances do not keep a reference to the enclosing pusher.
 */
private static final class Record implements Comparable<Record> {
  /** Buffer containing the serialized key followed by the serialized value. */
  private final byte[] serializedKV;

  /** Byte-level key comparator; wildcard-typed because only the raw-bytes compare is used. */
  private final RawComparator<?> comparator;

  /** Offset of the key's first byte within {@code serializedKV}. */
  final int offset;

  /** Length of the serialized key in bytes. */
  final int kLen;

  /** Length of the serialized value in bytes. */
  final int vLen;

  /**
   * @param serializedKV buffer holding the serialized record
   * @param comparator comparator applied to the serialized key bytes
   * @param offset offset of the key within {@code serializedKV}
   * @param kLen serialized key length
   * @param vLen serialized value length
   */
  Record(byte[] serializedKV, RawComparator<?> comparator, int offset, int kLen, int vLen) {
    this.serializedKV = serializedKV;
    this.comparator = comparator;
    this.offset = offset;
    this.kLen = kLen;
    this.vLen = vLen;
  }

  /**
   * Orders this record relative to {@code other} by comparing serialized key bytes with this
   * record's comparator.
   *
   * @param other record to compare against
   * @return the comparator's result for the two key byte ranges
   */
  @Override
  public int compareTo(Record other) {
    return comparator.compare(
        serializedKV, offset, kLen, other.serializedKV, other.offset, other.kLen);
  }
}

Expand All @@ -263,17 +418,18 @@ private int insertRecordInternal(K key, V value, int partition) throws IOExcepti
keyLen = writePos - offset;
vSer.serialize(value);
valLen = writePos - keyLen - offset;
List<SerializedKV> serializedKVs =
partitionedKVs.computeIfAbsent(partition, v -> new ArrayList<>());
serializedKVs.add(new SerializedKV(offset, keyLen, valLen));
KVBufferInfo bufferInfo = partitionedKVBuffers.computeIfAbsent(partition,
v -> new KVBufferInfo(1024)); // Initial capacity: 1024 records
// Store metadata directly in primitive arrays, no object allocation
bufferInfo.add(offset, keyLen, valLen);
if (logger.isDebugEnabled()) {
logger.debug(
"Pusher insert into buffer partition:{} offset:{} keyLen:{} valueLen:{} size:{}",
partition,
offset,
keyLen,
valLen,
partitionedKVs.size());
partitionedKVBuffers.size());
}
return keyLen + valLen;
}
Expand Down Expand Up @@ -320,19 +476,64 @@ public void close() {
} catch (IOException e) {
exception.compareAndSet(null, e);
}
partitionedKVs.clear();
partitionedKVBuffers.clear();
serializedKV = null;
}

static class SerializedKV {
final int offset;
final int kLen;
final int vLen;
/**
* Buffer info to manage serialized key-value records for each partition.
* Uses primitive int arrays to store metadata instead of object arrays,
* significantly reducing memory overhead.
*
* Memory comparison for 1 million records:
* - ArrayList<SerializedKV>: ~32MB (8MB references + 24MB objects)
* - This approach: ~12MB (4MB×3 int arrays)
*
* Saves 62.5% memory!
*/
static class KVBufferInfo {
int[] offsets; // Store key offset in serializedKV buffer
int[] keyLens; // Store key length
int[] valueLens; // Store value length
int count;
int capacity;

public SerializedKV(int offset, int kLen, int vLen) {
this.offset = offset;
this.kLen = kLen;
this.vLen = vLen;
/**
 * Creates an empty buffer whose three metadata arrays each have room for
 * {@code initialCapacity} records.
 *
 * @param initialCapacity number of record slots to allocate up front
 */
KVBufferInfo(int initialCapacity) {
  count = 0;
  capacity = initialCapacity;
  offsets = new int[initialCapacity];
  keyLens = new int[initialCapacity];
  valueLens = new int[initialCapacity];
}

/**
 * Appends the metadata of one record, growing the backing arrays when they are full.
 *
 * <p>Growth doubles the capacity. A capacity of 0 grows to 1 (plain doubling would stay at 0
 * and the store below would fail), and doubling is clamped so the new size cannot overflow
 * {@code int}.
 *
 * @param offset offset of the record's key in the serialized buffer
 * @param kLen serialized key length
 * @param vLen serialized value length
 * @throws IllegalStateException if the buffer is already at its maximum capacity
 */
void add(int offset, int kLen, int vLen) {
  if (count >= capacity) {
    // Stay a little below Integer.MAX_VALUE; array sizes that close to it are
    // commonly rejected by the VM.
    final int maxCapacity = Integer.MAX_VALUE - 8;
    if (capacity >= maxCapacity) {
      throw new IllegalStateException(
          "KVBufferInfo cannot grow beyond " + maxCapacity + " records");
    }
    int newCapacity = capacity > maxCapacity / 2 ? maxCapacity : Math.max(1, capacity * 2);

    offsets = Arrays.copyOf(offsets, newCapacity);
    keyLens = Arrays.copyOf(keyLens, newCapacity);
    valueLens = Arrays.copyOf(valueLens, newCapacity);
    capacity = newCapacity;
  }
  offsets[count] = offset;
  keyLens[count] = kLen;
  valueLens[count] = vLen;
  count++;
}

/**
 * Marks the buffer as empty by resetting the record count.
 *
 * <p>The metadata arrays keep their size and their old contents; later {@code add} calls
 * overwrite the slots starting from index 0.
 */
void clear() {
  this.count = 0;
}
}
}
Loading