diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedPusher.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedPusher.java index 013785ecc6c..10725340e3b 100644 --- a/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedPusher.java +++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedPusher.java @@ -221,6 +221,8 @@ public long pushData(boolean growThreshold) throws IOException { if (currentPartition == -1) { currentPartition = partition; } else { + shuffleClient.computeBatchCRC( + shuffleId, mapId, attemptNumber, currentPartition, dataBuf, 0, offSet); int bytesWritten = shuffleClient.mergeData( shuffleId, @@ -246,6 +248,8 @@ public long pushData(boolean growThreshold) throws IOException { if (offSet + recordSize > dataBuf.length) { try { + shuffleClient.computeBatchCRC( + shuffleId, mapId, attemptNumber, partition, dataBuf, 0, offSet); dataPusher.addTask(partition, dataBuf, offSet); memoryThresholdManager.updateStats(offSet, true); } catch (InterruptedException e) { @@ -261,6 +265,8 @@ public long pushData(boolean growThreshold) throws IOException { } if (offSet > 0) { try { + shuffleClient.computeBatchCRC( + shuffleId, mapId, attemptNumber, currentPartition, dataBuf, 0, offSet); dataPusher.addTask(currentPartition, dataBuf, offSet); memoryThresholdManager.updateStats(offSet, offSet == pushBufferMaxSize); } catch (InterruptedException e) { diff --git a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java index de7d00b4e47..3d299621d6a 100644 --- a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java +++ b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java @@ -277,6 +277,8 @@ private byte[] getOrCreateBuffer(int partitionId) { private void pushGiantRecord(int partitionId, byte[] buffer, int numBytes) throws IOException { logger.debug("Push giant record for partition {}, size {}.", partitionId, numBytes); + shuffleClient.computeBatchCRC( + shuffleId, mapId, encodedAttemptId, partitionId, buffer, 0, numBytes); int bytesWritten = shuffleClient.pushData( shuffleId, @@ -318,6 +320,7 @@ private void flushSendBuffer(int partitionId, byte[] buffer, int size) throws IOException, InterruptedException { long start = System.nanoTime(); logger.debug("Flush buffer for partition {}, size {}.", partitionId, size); + shuffleClient.computeBatchCRC(shuffleId, mapId, encodedAttemptId, partitionId, buffer, 0, size); dataPusher.addTask(partitionId, buffer, size); writeMetrics.incWriteTime(System.nanoTime() - start); } @@ -338,6 +341,8 @@ private void close() throws IOException, InterruptedException { for (int i = 0; i < sendBuffers.length; i++) { final int size = sendOffsets[i]; if (size > 0) { + shuffleClient.computeBatchCRC( + shuffleId, mapId, encodedAttemptId, i, sendBuffers[i], 0, size); int bytesWritten = shuffleClient.mergeData( shuffleId, diff --git a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java index 9ba908ade1a..495e2770562 100644 --- a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java +++ b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java @@ -283,6 +283,8 @@ private void write0(scala.collection.Iterator iterator) throws IOException { private void pushGiantRecord(int partitionId, byte[] buffer, int numBytes) throws IOException { logger.debug("Push giant record, size {}.", Utils.bytesToString(numBytes)); + shuffleClient.computeBatchCRC( + shuffleId, mapId, encodedAttemptId, partitionId, buffer, 0, numBytes); int bytesWritten = shuffleClient.pushData( shuffleId, diff --git a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java index 49c6d6954f4..cce54ff0396 100644 --- a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java +++ b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java @@ -278,6 +278,8 @@ private byte[] getOrCreateBuffer(int partitionId) { protected void pushGiantRecord(int partitionId, byte[] buffer, int numBytes) throws IOException { logger.debug("Push giant record, size {}.", numBytes); long start = System.nanoTime(); + shuffleClient.computeBatchCRC( + shuffleId, mapId, encodedAttemptId, partitionId, buffer, 0, numBytes); int bytesWritten = shuffleClient.pushData( shuffleId, @@ -321,6 +323,7 @@ private void flushSendBuffer(int partitionId, byte[] buffer, int size) throws IOException, InterruptedException { long start = System.nanoTime(); if (logger.isDebugEnabled()) logger.debug("Flush buffer, size {}.", Utils.bytesToString(size)); + shuffleClient.computeBatchCRC(shuffleId, mapId, encodedAttemptId, partitionId, buffer, 0, size); dataPusher.addTask(partitionId, buffer, size); writeMetrics.incWriteTime(System.nanoTime() - start); } @@ -332,6 +335,8 @@ protected void closeWrite() throws IOException { for (int i = 0; i < numPartitions; i++) { final int size = sendOffsets[i]; if (size > 0) { + shuffleClient.computeBatchCRC( + shuffleId, mapId, encodedAttemptId, i, sendBuffers[i], 0, size); mergeData(i, sendBuffers[i], 0, size); } } diff --git a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java index e413ce42fdb..f79c082d473 100644 --- a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java +++ b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java @@ -348,6 +348,8 @@ private void pushGiantRecord(int partitionId, byte[] buffer, int numBytes) throw if (logger.isDebugEnabled()) logger.debug("Push giant record, size {}.", Utils.bytesToString(numBytes)); long start = System.nanoTime(); + shuffleClient.computeBatchCRC( + shuffleId, mapId, encodedAttemptId, partitionId, buffer, 0, numBytes); int bytesWritten = shuffleClient.pushData( shuffleId, diff --git a/client/src/main/java/org/apache/celeborn/client/DummyShuffleClient.java b/client/src/main/java/org/apache/celeborn/client/DummyShuffleClient.java index 6ca3406be7f..9044f965cfb 100644 --- a/client/src/main/java/org/apache/celeborn/client/DummyShuffleClient.java +++ b/client/src/main/java/org/apache/celeborn/client/DummyShuffleClient.java @@ -92,6 +92,16 @@ public int pushData( return length; } + @Override + public void computeBatchCRC( + int shuffleId, + int mapId, + int attemptId, + int partitionId, + byte[] data, + int offset, + int length) {} + @Override public int mergeData( int shuffleId, diff --git a/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java b/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java index 7a89b051d4c..d1870779942 100644 --- a/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java +++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java @@ -188,6 +188,20 @@ public abstract int pushData( int numPartitions) throws IOException; + /** + * Pre-compute CRC for a batch immediately after assembly in the writer, before the data enters + * the async push pipeline. This is the sole CRC accumulation path when shuffle integrity check is + * enabled; {@link #pushOrMergeData} does not perform CRC computation. + */ + public abstract void computeBatchCRC( + int shuffleId, + int mapId, + int attemptId, + int partitionId, + byte[] data, + int offset, + int length); + public abstract int mergeData( int shuffleId, int mapId, diff --git a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java index be2bdf87d11..8c7c24d113a 100644 --- a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java +++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java @@ -1044,12 +1044,6 @@ public int pushOrMergeData( // increment batchId final int nextBatchId = pushState.nextBatchId(); - // Track commit metadata if shuffle compression and integrity check are enabled and this request - // is not for pushing metadata itself. - if (shuffleIntegrityCheckEnabled) { - pushState.addDataWithOffsetAndLength(partitionId, data, offset, length); - } - if (shuffleCompressionEnabled && !skipCompress) { // compress data final Compressor compressor = compressorThreadLocal.get(); @@ -1404,6 +1398,23 @@ public int mergeData( false); } + @Override + public void computeBatchCRC( + int shuffleId, + int mapId, + int attemptId, + int partitionId, + byte[] data, + int offset, + int length) { + if (!shuffleIntegrityCheckEnabled) { + return; + } + final String mapKey = Utils.makeMapKey(shuffleId, mapId, attemptId); + PushState pushState = getPushState(mapKey); + pushState.addDataWithOffsetAndLength(partitionId, data, offset, length); + } + @Override public void pushMergedData(int shuffleId, int mapId, int attemptId) throws IOException { final String mapKey = Utils.makeMapKey(shuffleId, mapId, attemptId); diff --git a/client/src/test/java/org/apache/celeborn/client/ShuffleClientSuiteJ.java b/client/src/test/java/org/apache/celeborn/client/ShuffleClientSuiteJ.java index e6d450d87f2..f583cf06a25 100644 --- a/client/src/test/java/org/apache/celeborn/client/ShuffleClientSuiteJ.java +++ b/client/src/test/java/org/apache/celeborn/client/ShuffleClientSuiteJ.java @@ -47,6 +47,7 @@ import org.apache.celeborn.client.compress.Compressor; import org.apache.celeborn.common.CelebornConf; +import org.apache.celeborn.common.CommitMetadata; import org.apache.celeborn.common.exception.CelebornIOException; import org.apache.celeborn.common.identity.UserIdentifier; import org.apache.celeborn.common.network.client.TransportClient; @@ -61,6 +62,8 @@ import org.apache.celeborn.common.protocol.message.StatusCode; import org.apache.celeborn.common.rpc.RpcEndpointRef; import org.apache.celeborn.common.rpc.RpcTimeoutException; +import org.apache.celeborn.common.util.Utils; +import org.apache.celeborn.common.write.PushState; public class ShuffleClientSuiteJ { @@ -709,4 +712,43 @@ public void testCorrectParametersPassedInRequest() throws IOException { assertEquals(crc32, capturedRequest.getCrc32()); assertEquals(bytesWritten, capturedRequest.getBytesWritten()); } + + @Test + public void testComputeBatchCRCAccumulatesCorrectly() { + CelebornConf conf = new CelebornConf(); + conf.set("celeborn.client.shuffle.integrityCheck.enabled", "true"); + shuffleClient = + new ShuffleClientImpl(TEST_APPLICATION_ID, conf, new UserIdentifier("mock", "mock")); + shuffleClient.setupLifecycleManagerRef(endpointRef); + + byte[] batch0 = "hello world".getBytes(StandardCharsets.UTF_8); + byte[] batch1a = "foo".getBytes(StandardCharsets.UTF_8); + byte[] batch1b = "bar".getBytes(StandardCharsets.UTF_8); + + shuffleClient.computeBatchCRC( + TEST_SHUFFLE_ID, TEST_MAP_ID, TEST_ATTEMPT_ID, 0, batch0, 0, batch0.length); + shuffleClient.computeBatchCRC( + TEST_SHUFFLE_ID, TEST_MAP_ID, TEST_ATTEMPT_ID, 1, batch1a, 0, batch1a.length); + shuffleClient.computeBatchCRC( + TEST_SHUFFLE_ID, TEST_MAP_ID, TEST_ATTEMPT_ID, 1, batch1b, 0, batch1b.length); + + PushState pushState = + shuffleClient.getPushState(Utils.makeMapKey(TEST_SHUFFLE_ID, TEST_MAP_ID, TEST_ATTEMPT_ID)); + + int numPartitions = 2; + int[] crcPerPartition = pushState.getCRC32PerPartition(true, numPartitions); + long[] bytesPerPartition = pushState.getBytesWrittenPerPartition(true, numPartitions); + + // compute expected values via CommitMetadata — same code path as production + CommitMetadata expected0 = new CommitMetadata(); + expected0.addDataWithOffsetAndLength(batch0, 0, batch0.length); + assertEquals(expected0.getChecksum(), crcPerPartition[0]); + assertEquals(expected0.getBytes(), bytesPerPartition[0]); + + CommitMetadata expected1 = new CommitMetadata(); + expected1.addDataWithOffsetAndLength(batch1a, 0, batch1a.length); + expected1.addDataWithOffsetAndLength(batch1b, 0, batch1b.length); + assertEquals(expected1.getChecksum(), crcPerPartition[1]); + assertEquals(expected1.getBytes(), bytesPerPartition[1]); + } }