From dfe744fc764a423f8bacd06ead05c44dfaca0096 Mon Sep 17 00:00:00 2001 From: Xianming Lei Date: Thu, 17 Jul 2025 21:21:02 +0800 Subject: [PATCH 1/4] [CELEBORN-2065] Worker should support wait partition sort asynchronously --- .../worker/storage/FileResolvedCallback.java | 26 ++ .../worker/storage/PartitionFilesSorter.java | 176 +++++++++---- .../service/deploy/worker/FetchHandler.scala | 243 +++++++++++------- 3 files changed, 298 insertions(+), 147 deletions(-) create mode 100644 worker/src/main/java/org/apache/celeborn/service/deploy/worker/storage/FileResolvedCallback.java diff --git a/worker/src/main/java/org/apache/celeborn/service/deploy/worker/storage/FileResolvedCallback.java b/worker/src/main/java/org/apache/celeborn/service/deploy/worker/storage/FileResolvedCallback.java new file mode 100644 index 00000000000..857af8e886f --- /dev/null +++ b/worker/src/main/java/org/apache/celeborn/service/deploy/worker/storage/FileResolvedCallback.java @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.celeborn.service.deploy.worker.storage; + +import org.apache.celeborn.common.meta.FileInfo; + +public interface FileResolvedCallback { + void onSuccess(FileInfo fileInfo); + + void onFailure(Throwable e); +} diff --git a/worker/src/main/java/org/apache/celeborn/service/deploy/worker/storage/PartitionFilesSorter.java b/worker/src/main/java/org/apache/celeborn/service/deploy/worker/storage/PartitionFilesSorter.java index cdfb95c1e59..dfece2ad317 100644 --- a/worker/src/main/java/org/apache/celeborn/service/deploy/worker/storage/PartitionFilesSorter.java +++ b/worker/src/main/java/org/apache/celeborn/service/deploy/worker/storage/PartitionFilesSorter.java @@ -31,6 +31,7 @@ import java.util.Set; import java.util.TreeMap; import java.util.concurrent.*; +import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.locks.Condition; @@ -78,6 +79,8 @@ public class PartitionFilesSorter extends ShuffleRecoverHelper { JavaUtils.newConcurrentHashMap(); private final ConcurrentHashMap> sortingShuffleFiles = JavaUtils.newConcurrentHashMap(); + private final ConcurrentHashMap> pendingSortCallbacks = + JavaUtils.newConcurrentHashMap(); private final Cache>> indexCache; private final Map> indexCacheNames = JavaUtils.newConcurrentHashMap(); @@ -205,7 +208,12 @@ public long getSortedSize() { // 3. If the sorted file is generated, it returns the sorted FileInfo. // This method will generate temporary file info for this shuffle read public FileInfo getSortedFileInfo( - String shuffleKey, String fileName, FileInfo fileInfo, int startMapIndex, int endMapIndex) + String shuffleKey, + String fileName, + FileInfo fileInfo, + int startMapIndex, + int endMapIndex, + FileResolvedCallback fileResolvedCallback) throws IOException { if (fileInfo instanceof MemoryFileInfo) { MemoryFileInfo memoryFileInfo = ((MemoryFileInfo) fileInfo); @@ -227,11 +235,13 @@ public FileInfo getSortedFileInfo( memoryFileInfo.getSortedBuffer(), targetBuffer, shuffleChunkSize); - return new MemoryFileInfo( - memoryFileInfo.getUserIdentifier(), - memoryFileInfo.isPartitionSplitEnabled(), - reduceFileMeta, - targetBuffer); + FileInfo sortedFileInfo = + new MemoryFileInfo( + memoryFileInfo.getUserIdentifier(), + memoryFileInfo.isPartitionSplitEnabled(), + reduceFileMeta, + targetBuffer); + fileResolvedCallback.onSuccess(sortedFileInfo); } else { DiskFileInfo diskFileInfo = ((DiskFileInfo) fileInfo); String fileId = shuffleKey + "-" + fileName; @@ -243,13 +253,76 @@ public FileInfo getSortedFileInfo( String sortedFilePath = Utils.getSortedFilePath(diskFileInfo.getFilePath()); String indexFilePath = Utils.getIndexFilePath(diskFileInfo.getFilePath()); - boolean fileSorting = true; synchronized (sorting) { if (sorted.contains(fileId)) { - fileSorting = false; - } else if (!sorting.contains(fileId)) { + return resolve( + shuffleKey, + fileId, + userIdentifier, + sortedFilePath, + indexFilePath, + startMapIndex, + endMapIndex); + } else if (sorting.contains(fileId)) { + FileResolvedCallback pendingCallback = + new FileResolvedCallback() { + @Override + public void onSuccess(FileInfo ignored) { + try { + FileInfo sortedFileInfo = + resolve( + shuffleKey, + fileId, + userIdentifier, + sortedFilePath, + indexFilePath, + startMapIndex, + endMapIndex); + fileResolvedCallback.onSuccess(sortedFileInfo); + } catch (Throwable e) { + fileResolvedCallback.onFailure(e); + } + } + + @Override + public void onFailure(Throwable e) { + fileResolvedCallback.onFailure(e); + } + }; + pendingSortCallbacks + .computeIfAbsent(fileId, k -> new CopyOnWriteArrayList<>()) + .add(pendingCallback); + } else { + FileSortedCallback fileSortedCallback = + new FileSortedCallback() { + @Override + public void onSuccess() { + try { + FileInfo sortedFileInfo = + resolve( + shuffleKey, + fileId, + userIdentifier, + sortedFilePath, + indexFilePath, + startMapIndex, + endMapIndex); + fileResolvedCallback.onSuccess(sortedFileInfo); + } catch (Throwable e) { + fileResolvedCallback.onFailure(e); + } + notifyPendingSortCallbacks(fileId, null); + } + + @Override + public void onFailure(Throwable e) { + fileResolvedCallback.onFailure(e); + notifyPendingSortCallbacks(fileId, e); + } + }; try { - FileSorter fileSorter = new FileSorter(diskFileInfo, fileId, shuffleKey); + FileSorter fileSorter = + new FileSorter(diskFileInfo, fileId, shuffleKey, fileSortedCallback); sorting.add(fileId); logger.debug( "Adding sorter to sort queue shuffle key {}, file name {}", shuffleKey, fileName); @@ -257,59 +330,33 @@ public FileInfo getSortedFileInfo( } catch (InterruptedException e) { logger.error( "Sorter scheduler thread is interrupted means worker is shutting down.", e); - throw new IOException( - "Sort scheduler thread is interrupted means worker is shutting down.", e); + fileResolvedCallback.onFailure( + new IOException( + "Sort scheduler thread is interrupted means worker is shutting down.", e)); } catch (IOException e) { logger.error("File sorter access DFS failed.", e); - throw new IOException("File sorter access DFS failed.", e); + fileResolvedCallback.onFailure(new IOException("File sorter access DFS failed.", e)); } } } + } + return null; + } - if (fileSorting) { - long sortStartTime = System.currentTimeMillis(); - while (!sorted.contains(fileId)) { - if (sorting.contains(fileId)) { - try { - Thread.sleep(50); - if (System.currentTimeMillis() - sortStartTime > sortTimeout) { - String msg = - String.format( - "Sorting file %s path %s length %s timeout after %dms", - fileId, - diskFileInfo.getFilePath(), - diskFileInfo.getFileLength(), - sortTimeout); - logger.error(msg); - throw new IOException(msg); - } - } catch (InterruptedException e) { - logger.error( - "Sorter scheduler thread is interrupted means worker is shutting down.", e); - throw new IOException( - "Sorter scheduler thread is interrupted means worker is shutting down.", e); - } + private void notifyPendingSortCallbacks(String fileId, Throwable error) { + List callbacks = pendingSortCallbacks.remove(fileId); + if (callbacks != null) { + for (FileResolvedCallback cb : callbacks) { + try { + if (error != null) { + cb.onFailure(error); } else { - logger.debug( - "Sorting shuffle file for {} {} failed.", shuffleKey, diskFileInfo.getFilePath()); - throw new IOException( - "Sorting shuffle file for " - + shuffleKey - + " " - + diskFileInfo.getFilePath() - + " failed."); + cb.onSuccess(null); } + } catch (Exception e) { + logger.error("Error notifying pending sort callback for {}", fileId, e); } } - - return resolve( - shuffleKey, - fileId, - userIdentifier, - sortedFilePath, - indexFilePath, - startMapIndex, - endMapIndex); } } @@ -686,8 +733,14 @@ class FileSorter { private FileChannel originFileChannel = null; private FileChannel sortedFileChannel = null; private FileSystem hadoopFs; - - FileSorter(DiskFileInfo fileInfo, String fileId, String shuffleKey) throws IOException { + private FileSortedCallback fileSortedCallback; + + FileSorter( + DiskFileInfo fileInfo, + String fileId, + String shuffleKey, + FileSortedCallback fileSortedCallback) + throws IOException { this.originFileInfo = fileInfo; this.originFilePath = fileInfo.getFilePath(); this.sortedFilePath = Utils.getSortedFilePath(originFilePath); @@ -700,6 +753,7 @@ class FileSorter { this.fileId = fileId; this.shuffleKey = shuffleKey; this.indexFilePath = Utils.getIndexFilePath(originFilePath); + this.fileSortedCallback = fileSortedCallback; if (!isDfs) { File sortedFile = new File(this.sortedFilePath); if (sortedFile.exists()) { @@ -728,6 +782,7 @@ class FileSorter { public void sort() { source.startTimer(WorkerSource.SORT_TIME(), fileId); + boolean success = true; long sortStartTime = -1; if (sortTimeLogThreshold > 0) { sortStartTime = System.nanoTime(); @@ -804,6 +859,8 @@ public void sort() { } catch (Exception e) { logger.error( "Sorting shuffle file for " + fileId + " " + originFilePath + " failed, detail: ", e); + success = false; + fileSortedCallback.onFailure(e); } finally { closeFiles(); Set sorting = sortingShuffleFiles.get(shuffleKey); @@ -811,6 +868,9 @@ public void sort() { sorting.remove(fileId); } } + if (success) { + fileSortedCallback.onSuccess(); + } if (sortTimeLogThreshold > 0) { long sortDuration = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - sortStartTime); if (sortDuration > sortTimeLogThreshold) { @@ -1000,3 +1060,9 @@ public void close() { cleaner.shutdownNow(); } } + +interface FileSortedCallback { + void onSuccess(); + + void onFailure(Throwable e); +} diff --git a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala index ca9138e8fbc..0014590aeae 100644 --- a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala +++ b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala @@ -20,9 +20,11 @@ package org.apache.celeborn.service.deploy.worker import java.io.{FileNotFoundException, IOException} import java.nio.charset.StandardCharsets import java.util -import java.util.concurrent.atomic.AtomicBoolean +import java.util.concurrent.atomic.{AtomicBoolean, AtomicInteger} import java.util.function.Consumer +import scala.collection.mutable.ArrayBuffer + import com.google.common.base.Throwables import com.google.protobuf.GeneratedMessageV3 import io.netty.util.concurrent.{Future, GenericFutureListener} @@ -39,8 +41,8 @@ import org.apache.celeborn.common.network.server.BaseMessageHandler import org.apache.celeborn.common.network.util.{NettyUtils, TransportConf} import org.apache.celeborn.common.protocol.{MessageType, PbBufferStreamEnd, PbChunkFetchRequest, PbNotifyRequiredSegment, PbOpenStream, PbOpenStreamList, PbOpenStreamListResponse, PbReadAddCredit, PbStreamHandler, PbStreamHandlerOpt, StreamType} import org.apache.celeborn.common.protocol.message.StatusCode -import org.apache.celeborn.common.util.{ExceptionUtils, Utils} -import org.apache.celeborn.service.deploy.worker.storage.{ChunkStreamManager, CreditStreamManager, PartitionFilesSorter, StorageManager} +import org.apache.celeborn.common.util.{ExceptionUtils, JavaUtils, Utils} +import org.apache.celeborn.service.deploy.worker.storage.{ChunkStreamManager, CreditStreamManager, FileResolvedCallback, PartitionFilesSorter, StorageManager} class FetchHandler( val conf: CelebornConf, @@ -143,36 +145,75 @@ class FetchHandler( val startIndices = openStreamList.getStartIndexList val endIndices = openStreamList.getEndIndexList val readLocalFlags = openStreamList.getReadLocalShuffleList - val pbOpenStreamListResponse = PbOpenStreamListResponse.newBuilder() checkAuth(client, Utils.splitShuffleKey(shuffleKey)._1) val openStreamRequestId = Utils.makeOpenStreamRequestId( shuffleKey, client.getChannel.id().toString, rpcRequest.requestId) workerSource.startTimer(WorkerSource.OPEN_STREAM_TIME, openStreamRequestId) - try { - 0 until files.size() foreach { idx => - val pbStreamHandlerOpt = handleReduceOpenStreamInternal( - client, - shuffleKey, - files.get(idx), - startIndices.get(idx), - endIndices.get(idx), - readLocalFlags.get(idx)) - if (pbStreamHandlerOpt.getStatus != StatusCode.SUCCESS.getValue) { - workerSource.incCounter(WorkerSource.OPEN_STREAM_FAIL_COUNT) - } - pbOpenStreamListResponse.addStreamHandlerOpt(pbStreamHandlerOpt) + val totalFiles = files.size() + val results = new Array[PbStreamHandlerOpt](totalFiles) + val completedCount = new AtomicInteger(0) + val replied = new AtomicBoolean(false) + + def trySendBatchResponse(): Unit = { + if (completedCount.get() >= totalFiles && replied.compareAndSet(false, true)) { + val pbOpenStreamListResponse = PbOpenStreamListResponse.newBuilder() + results.foreach(pbOpenStreamListResponse.addStreamHandlerOpt) + workerSource.stopTimer(WorkerSource.OPEN_STREAM_TIME, openStreamRequestId) + client.getChannel.writeAndFlush(new RpcResponse( + rpcRequest.requestId, + new NioManagedBuffer(new TransportMessage( + MessageType.BATCH_OPEN_STREAM_RESPONSE, + pbOpenStreamListResponse.build().toByteArray).toByteBuffer))) } - } finally { - workerSource.stopTimer(WorkerSource.OPEN_STREAM_TIME, openStreamRequestId) } - client.getChannel.writeAndFlush(new RpcResponse( - rpcRequest.requestId, - new NioManagedBuffer(new TransportMessage( - MessageType.BATCH_OPEN_STREAM_RESPONSE, - pbOpenStreamListResponse.build().toByteArray).toByteBuffer))) + 0 until totalFiles foreach { idx => + val fileName = files.get(idx) + val startMapIndex = startIndices.get(idx) + val endMapIndex = endIndices.get(idx) + val readLocalFlag = readLocalFlags.get(idx) + val streamId = chunkStreamManager.nextStreamId() + try { + val fileInfo = getRawFileInfo(shuffleKey, fileName) + openReduceStreamAsync( + shuffleKey, fileName, fileInfo, startMapIndex, endMapIndex, streamId, + new FileResolvedCallback { + override def onSuccess(sortedFileInfo: FileInfo): Unit = { + results(idx) = registerAndHandleStream( + client, shuffleKey, fileName, + startMapIndex, endMapIndex, readLocalFlag, sortedFileInfo, streamId) + if (results(idx).getStatus != StatusCode.SUCCESS.getValue) { + workerSource.incCounter(WorkerSource.OPEN_STREAM_FAIL_COUNT) + } + completedCount.incrementAndGet() + trySendBatchResponse() + } + + override def onFailure(e: Throwable): Unit = { + val msg = s"Read file: $fileName with shuffleKey: $shuffleKey error, " + + s"Exception: ${e.getMessage}" + results(idx) = PbStreamHandlerOpt.newBuilder() + .setStatus(StatusCode.OPEN_STREAM_FAILED.getValue) + .setErrorMsg(msg).build() + workerSource.incCounter(WorkerSource.OPEN_STREAM_FAIL_COUNT) + completedCount.incrementAndGet() + trySendBatchResponse() + } + }) + } catch { + case e: IOException => + val msg = s"Read file: $fileName with shuffleKey: $shuffleKey error, " + + s"Exception: ${e.getMessage}" + results(idx) = PbStreamHandlerOpt.newBuilder() + .setStatus(StatusCode.OPEN_STREAM_FAILED.getValue) + .setErrorMsg(msg).build() + workerSource.incCounter(WorkerSource.OPEN_STREAM_FAIL_COUNT) + completedCount.incrementAndGet() + trySendBatchResponse() + } + } case bufferStreamEnd: PbBufferStreamEnd => handleEndStreamFromClient( client, @@ -245,33 +286,16 @@ class FetchHandler( } - private def handleReduceOpenStreamInternal( + private def registerAndHandleStream( client: TransportClient, shuffleKey: String, fileName: String, startIndex: Int, endIndex: Int, - readLocalShuffle: Boolean = false): PbStreamHandlerOpt = { + readLocalShuffle: Boolean, + fileInfo: FileInfo, + streamId: Long): PbStreamHandlerOpt = { try { - logDebug(s"Received open stream request $shuffleKey $fileName $startIndex " + - s"$endIndex get file name $fileName from client channel " + - s"${NettyUtils.getRemoteAddress(client.getChannel)}") - - var fileInfo = getRawFileInfo(shuffleKey, fileName) - val streamId = chunkStreamManager.nextStreamId() - // we must get sorted fileInfo for the following cases. - // 1. when the current request is a non-range openStream, but the original unsorted file - // has been deleted by another range's openStream request. - // 2. when the current request is a range openStream request. - if ((endIndex != Int.MaxValue && endIndex != -1 && endIndex >= startIndex) || (endIndex == Int.MaxValue && !fileInfo.addStream( - streamId))) { - fileInfo = partitionsSorter.getSortedFileInfo( - shuffleKey, - fileName, - fileInfo, - startIndex, - endIndex) - } val meta = fileInfo.getReduceFileMeta val streamHandler = if (readLocalShuffle && !fileInfo.isInstanceOf[MemoryFileInfo]) { @@ -286,43 +310,25 @@ class FetchHandler( fileInfo.asInstanceOf[DiskFileInfo].getFilePath) } else fileInfo match { case info: DiskFileInfo if info.isHdfs => - chunkStreamManager.registerStream( - streamId, - shuffleKey, - fileName) + chunkStreamManager.registerStream(streamId, shuffleKey, fileName) makeStreamHandler(streamId, numChunks = 0) case info: DiskFileInfo if info.isS3 => - chunkStreamManager.registerStream( - streamId, - shuffleKey, - fileName) + chunkStreamManager.registerStream(streamId, shuffleKey, fileName) makeStreamHandler(streamId, numChunks = 0) case info: DiskFileInfo if info.isOSS => - chunkStreamManager.registerStream( - streamId, - shuffleKey, - fileName) + chunkStreamManager.registerStream(streamId, shuffleKey, fileName) makeStreamHandler(streamId, numChunks = 0) case _ => val managedBuffer = fileInfo match { - case df: DiskFileInfo => - new FileChunkBuffers(df, transportConf) - case mf: MemoryFileInfo => - new MemoryChunkBuffers(mf) + case df: DiskFileInfo => new FileChunkBuffers(df, transportConf) + case mf: MemoryFileInfo => new MemoryChunkBuffers(mf) + } + val fetchTimeMetric = fileInfo match { + case info: DiskFileInfo => storageManager.getFetchTimeMetric(info.getFile) + case _ => null } - val fetchTimeMetric = - fileInfo match { - case info: DiskFileInfo => - storageManager.getFetchTimeMetric(info.getFile) - case _ => - null - } chunkStreamManager.registerStream( - streamId, - shuffleKey, - managedBuffer, - fileName, - fetchTimeMetric) + streamId, shuffleKey, managedBuffer, fileName, fetchTimeMetric) if (meta.getNumChunks == 0) logDebug(s"StreamId $streamId, fileName $fileName, mapRange " + s"[$startIndex-$endIndex] is empty. Received from client channel " + @@ -331,9 +337,7 @@ class FetchHandler( s"StreamId $streamId, fileName $fileName, numChunks ${meta.getNumChunks}, " + s"mapRange [$startIndex-$endIndex]. Received from client channel " + s"${NettyUtils.getRemoteAddress(client.getChannel)}") - makeStreamHandler( - streamId, - meta.getNumChunks) + makeStreamHandler(streamId, meta.getNumChunks) } workerSource.incCounter(WorkerSource.OPEN_STREAM_SUCCESS_COUNT) PbStreamHandlerOpt.newBuilder().setStreamHandler(streamHandler) @@ -349,6 +353,30 @@ class FetchHandler( } } + /** + * Decides whether sorting is needed and either triggers async sort or invokes callback directly. + * Sorting is required when: + * 1. A range openStream request (startIndex..endIndex) is specified. + * 2. A non-range openStream where the raw file has already been deleted by another + * range's sort (addStream returns false). + */ + private def openReduceStreamAsync( + shuffleKey: String, + fileName: String, + fileInfo: FileInfo, + startIndex: Int, + endIndex: Int, + streamId: Long, + callback: FileResolvedCallback): Unit = { + if ((endIndex != Int.MaxValue && endIndex != -1 && endIndex >= startIndex) || + (endIndex == Int.MaxValue && !fileInfo.addStream(streamId))) { + partitionsSorter.getSortedFileInfo( + shuffleKey, fileName, fileInfo, startIndex, endIndex, callback) + } else { + callback.onSuccess(fileInfo) + } + } + private def handleOpenStreamInternal( client: TransportClient, shuffleKey: String, @@ -367,23 +395,53 @@ class FetchHandler( client.getChannel.id().toString, rpcRequestId) workerSource.startTimer(WorkerSource.OPEN_STREAM_TIME, requestId) + var asyncReplied = false try { val fileInfo = getRawFileInfo(shuffleKey, fileName) fileInfo.getFileMeta match { case _: ReduceFileMeta => - val pbStreamHandlerOpt = - handleReduceOpenStreamInternal( - client, - shuffleKey, - fileName, - startIndex, - endIndex, - readLocalShuffle) - - if (pbStreamHandlerOpt.getStatus != StatusCode.SUCCESS.getValue) { - throw new CelebornIOException(pbStreamHandlerOpt.getErrorMsg) - } - replyStreamHandler(client, rpcRequestId, pbStreamHandlerOpt.getStreamHandler, isLegacy) + val streamId = chunkStreamManager.nextStreamId() + openReduceStreamAsync( + shuffleKey, fileName, fileInfo, startIndex, endIndex, streamId, + new FileResolvedCallback { + private var timerStopped = false + + override def onSuccess(sortedFileInfo: FileInfo): Unit = { + try { + val pbStreamHandlerOpt = + registerAndHandleStream( + client, shuffleKey, fileName, + startIndex, endIndex, readLocalShuffle, + sortedFileInfo, streamId) + if (pbStreamHandlerOpt.getStatus != StatusCode.SUCCESS.getValue) { + throw new CelebornIOException(pbStreamHandlerOpt.getErrorMsg) + } + replyStreamHandler( + client, rpcRequestId, pbStreamHandlerOpt.getStreamHandler, isLegacy) + } catch { + case t: Throwable => + onFailure(t) + } finally { + stopTimerIfNeeded() + } + } + + override def onFailure(e: Throwable): Unit = { + workerSource.incCounter(WorkerSource.OPEN_STREAM_FAIL_COUNT) + handleRpcIOException( + client, rpcRequestId, shuffleKey, fileName, + ExceptionUtils.wrapThrowableToIOException(e), callback) + stopTimerIfNeeded() + } + + private def stopTimerIfNeeded(): Unit = { + if (!timerStopped) { + workerSource.stopTimer(WorkerSource.OPEN_STREAM_TIME, requestId) + timerStopped = true + } + } + }) + asyncReplied = true case _: MapFileMeta => val creditStreamHandler = new Consumer[java.lang.Long] { @@ -395,7 +453,6 @@ class FetchHandler( replyStreamHandler(client, rpcRequestId, pbStreamHandler, isLegacy) } } - creditStreamManager.registerStream( creditStreamHandler, client.getChannel, @@ -404,14 +461,16 @@ class FetchHandler( startIndex, endIndex, fileInfo.asInstanceOf[DiskFileInfo]) + workerSource.incCounter(WorkerSource.OPEN_STREAM_SUCCESS_COUNT) } - workerSource.incCounter(WorkerSource.OPEN_STREAM_SUCCESS_COUNT) } catch { case e: IOException => workerSource.incCounter(WorkerSource.OPEN_STREAM_FAIL_COUNT) handleRpcIOException(client, rpcRequestId, shuffleKey, fileName, e, callback) } finally { - workerSource.stopTimer(WorkerSource.OPEN_STREAM_TIME, requestId) + if (!asyncReplied) { + workerSource.stopTimer(WorkerSource.OPEN_STREAM_TIME, requestId) + } } } From 606f09fc4ad1618117ad470ef2a2dbef5ff6e336 Mon Sep 17 00:00:00 2001 From: Xianming Lei Date: Tue, 7 Apr 2026 19:16:50 +0800 Subject: [PATCH 2/4] fix --- .../worker/storage/PartitionFilesSorter.java | 23 ++- .../deploy/worker/FetchHandlerSuiteJ.java | 145 ++++++++++++++++-- .../local/DiskPartitionFilesSorterSuiteJ.java | 116 +++++++++++++- .../DiskReducePartitionDataWriterSuiteJ.java | 10 +- .../MemoryPartitionFilesSorterSuiteJ.java | 43 ++++-- ...MemoryReducePartitionDataWriterSuiteJ.java | 10 +- 6 files changed, 312 insertions(+), 35 deletions(-) diff --git a/worker/src/main/java/org/apache/celeborn/service/deploy/worker/storage/PartitionFilesSorter.java b/worker/src/main/java/org/apache/celeborn/service/deploy/worker/storage/PartitionFilesSorter.java index dfece2ad317..a0d3684eebc 100644 --- a/worker/src/main/java/org/apache/celeborn/service/deploy/worker/storage/PartitionFilesSorter.java +++ b/worker/src/main/java/org/apache/celeborn/service/deploy/worker/storage/PartitionFilesSorter.java @@ -255,14 +255,21 @@ public FileInfo getSortedFileInfo( String indexFilePath = Utils.getIndexFilePath(diskFileInfo.getFilePath()); synchronized (sorting) { if (sorted.contains(fileId)) { - return resolve( - shuffleKey, - fileId, - userIdentifier, - sortedFilePath, - indexFilePath, - startMapIndex, - endMapIndex); + try { + FileInfo sortedFileInfo = + resolve( + shuffleKey, + fileId, + userIdentifier, + sortedFilePath, + indexFilePath, + startMapIndex, + endMapIndex); + fileResolvedCallback.onSuccess(sortedFileInfo); + } catch (Throwable e) { + fileResolvedCallback.onFailure(e); + } + return null; } else if (sorting.contains(fileId)) { FileResolvedCallback pendingCallback = new FileResolvedCallback() { diff --git a/worker/src/test/java/org/apache/celeborn/service/deploy/worker/FetchHandlerSuiteJ.java b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/FetchHandlerSuiteJ.java index b54a0b99a56..9b9e21b76c2 100644 --- a/worker/src/test/java/org/apache/celeborn/service/deploy/worker/FetchHandlerSuiteJ.java +++ b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/FetchHandlerSuiteJ.java @@ -63,10 +63,14 @@ import org.apache.celeborn.common.protocol.PbBufferStreamEnd; import org.apache.celeborn.common.protocol.PbChunkFetchRequest; import org.apache.celeborn.common.protocol.PbOpenStream; +import org.apache.celeborn.common.protocol.PbOpenStreamList; +import org.apache.celeborn.common.protocol.PbOpenStreamListResponse; import org.apache.celeborn.common.protocol.PbStreamChunkSlice; import org.apache.celeborn.common.protocol.PbStreamHandler; +import org.apache.celeborn.common.protocol.PbStreamHandlerOpt; import org.apache.celeborn.common.protocol.StreamType; import org.apache.celeborn.common.protocol.TransportModuleConstants; +import org.apache.celeborn.common.protocol.message.StatusCode; import org.apache.celeborn.common.unsafe.Platform; import org.apache.celeborn.common.util.JavaUtils; import org.apache.celeborn.common.util.Utils; @@ -161,7 +165,7 @@ public static void afterAll() { } @Test - public void testFetchOriginFile() throws IOException { + public void testFetchOriginFile() throws Exception { FileInfo fileInfo = null; try { // total write: 32 * 50 * 256k = 400m @@ -180,7 +184,7 @@ public void testFetchOriginFile() throws IOException { } @Test - public void testFetchSortFile() throws IOException { + public void testFetchSortFile() throws Exception { FileInfo fileInfo = null; try { // total write size: 32 * 50 * 256k = 400m @@ -198,7 +202,7 @@ public void testFetchSortFile() throws IOException { } @Test - public void testLegacyOpenStream() throws IOException { + public void testLegacyOpenStream() throws Exception { FileInfo fileInfo = null; try { // total write size: 32 * 50 * 256k = 400m @@ -215,7 +219,7 @@ public void testLegacyOpenStream() throws IOException { } @Test - public void testWorkerReadSortFileOnceOriginalFileBeDeleted() throws IOException { + public void testWorkerReadSortFileOnceOriginalFileBeDeleted() throws Exception { FileInfo fileInfo = null; try { // total write size: 32 * 50 * 256k = 400m @@ -237,7 +241,7 @@ public void testWorkerReadSortFileOnceOriginalFileBeDeleted() throws IOException } @Test - public void testLocalReadSortFileOnceOriginalFileBeDeleted() throws IOException { + public void testLocalReadSortFileOnceOriginalFileBeDeleted() throws Exception { FileInfo fileInfo = null; try { // total write size: 32 * 50 * 256k = 400m @@ -255,7 +259,7 @@ public void testLocalReadSortFileOnceOriginalFileBeDeleted() throws IOException } @Test - public void testDoNotDeleteOriginalFileWhenNonRangeWorkerReadWorkInProgress() throws IOException { + public void testDoNotDeleteOriginalFileWhenNonRangeWorkerReadWorkInProgress() throws Exception { FileInfo fileInfo = null; try { // total write size: 32 * 50 * 256k = 400m @@ -280,7 +284,7 @@ public void testDoNotDeleteOriginalFileWhenNonRangeWorkerReadWorkInProgress() th } @Test - public void testDoNotDeleteOriginalFileWhenNonRangeLocalReadWorkInProgress() throws IOException { + public void testDoNotDeleteOriginalFileWhenNonRangeLocalReadWorkInProgress() throws Exception { FileInfo fileInfo = null; try { // total write size: 32 * 50 * 256k = 400m @@ -302,6 +306,112 @@ public void testDoNotDeleteOriginalFileWhenNonRangeLocalReadWorkInProgress() thr } } + @Test + public void testBatchOpenStream() throws Exception { + FileInfo fileInfo = null; + try { + fileInfo = prepare(32); + EmbeddedChannel channel = new EmbeddedChannel(); + TransportClient client = new TransportClient(channel, mock(TransportResponseHandler.class)); + FetchHandler fetchHandler = mockFetchHandler(fileInfo); + + PbOpenStreamList.Builder builder = PbOpenStreamList.newBuilder().setShuffleKey(shuffleKey); + int batchSize = 3; + for (int i = 0; i < batchSize; i++) { + builder.addFileName(fileName); + builder.addStartIndex(5); + builder.addEndIndex(10); + builder.addReadLocalShuffle(false); + } + ByteBuffer batchOpenStreamBuffer = + new TransportMessage( + MessageType.BATCH_OPEN_STREAM, builder.build().toByteArray()) + .toByteBuffer(); + fetchHandler.receive( + client, + new RpcRequest(dummyRequestId, new NioManagedBuffer(batchOpenStreamBuffer)), + createRpcResponseCallback(channel)); + + RpcResponse result = readOutboundWithTimeout(channel, 30_000); + PbOpenStreamListResponse response = + TransportMessage.fromByteBuffer(result.body().nioByteBuffer()).getParsedPayload(); + assertEquals(batchSize, response.getStreamHandlerOptCount()); + for (int i = 0; i < batchSize; i++) { + PbStreamHandlerOpt opt = response.getStreamHandlerOpt(i); + assertEquals(StatusCode.SUCCESS.getValue(), opt.getStatus()); + assertEquals(10 - 5, opt.getStreamHandler().getNumChunks()); + } + } finally { + cleanup(fileInfo); + } + } + + @Test + public void testBatchOpenStreamPartialFailure() throws Exception { + FileInfo fileInfo = null; + try { + fileInfo = prepare(32); + EmbeddedChannel channel = new EmbeddedChannel(); + TransportClient client = new TransportClient(channel, mock(TransportResponseHandler.class)); + + WorkerSource workerSource = mock(WorkerSource.class); + TransportConf transportConf = + Utils.fromCelebornConf(conf, TransportModuleConstants.FETCH_MODULE, 4); + FetchHandler fetchHandler0 = new FetchHandler(conf, transportConf, workerSource); + Worker worker = mock(Worker.class); + PartitionFilesSorter partitionFilesSorter = + new PartitionFilesSorter(MemoryManager.instance(), conf, workerSource); + StorageManager storageManager = mock(StorageManager.class); + Mockito.doReturn(storageManager).when(worker).storageManager(); + Mockito.doReturn(workerSource).when(worker).workerSource(); + Mockito.doReturn(partitionFilesSorter).when(worker).partitionsSorter(); + fetchHandler0.init(worker); + FetchHandler fetchHandler = spy(fetchHandler0); + + String existingFile = "existingFile"; + String missingFile = "missingFile"; + Mockito.doReturn(fileInfo).when(fetchHandler).getRawFileInfo(shuffleKey, existingFile); + Mockito.doAnswer( + invocation -> { + throw new java.io.FileNotFoundException("Not found"); + }) + .when(fetchHandler) + .getRawFileInfo(shuffleKey, missingFile); + + PbOpenStreamList.Builder builder = PbOpenStreamList.newBuilder().setShuffleKey(shuffleKey); + builder.addFileName(existingFile).addStartIndex(5).addEndIndex(10).addReadLocalShuffle(false); + builder.addFileName(missingFile).addStartIndex(0).addEndIndex(10).addReadLocalShuffle(false); + builder + .addFileName(existingFile) + .addStartIndex(5) + .addEndIndex(10) + .addReadLocalShuffle(false); + + ByteBuffer batchBuffer = + new TransportMessage( + MessageType.BATCH_OPEN_STREAM, builder.build().toByteArray()) + .toByteBuffer(); + fetchHandler.receive( + client, + new RpcRequest(dummyRequestId, new NioManagedBuffer(batchBuffer)), + createRpcResponseCallback(channel)); + + RpcResponse result = readOutboundWithTimeout(channel, 30_000); + PbOpenStreamListResponse response = + TransportMessage.fromByteBuffer(result.body().nioByteBuffer()).getParsedPayload(); + assertEquals(3, response.getStreamHandlerOptCount()); + + assertEquals( + StatusCode.SUCCESS.getValue(), response.getStreamHandlerOpt(0).getStatus()); + assertEquals( + StatusCode.OPEN_STREAM_FAILED.getValue(), response.getStreamHandlerOpt(1).getStatus()); + assertEquals( + StatusCode.SUCCESS.getValue(), response.getStreamHandlerOpt(2).getStatus()); + } finally { + cleanup(fileInfo); + } + } + private FetchHandler mockFetchHandler(FileInfo fileInfo) { WorkerSource workerSource = mock(WorkerSource.class); TransportConf transportConf = @@ -325,6 +435,21 @@ private FetchHandler mockFetchHandler(FileInfo fileInfo) { private final String fileName = "dummyFileName"; private final long dummyRequestId = 0; + @SuppressWarnings("unchecked") + private T readOutboundWithTimeout(EmbeddedChannel channel, long timeoutMs) + throws InterruptedException { + long deadline = System.currentTimeMillis() + timeoutMs; + T result; + while ((result = (T) channel.readOutbound()) == null) { + if (System.currentTimeMillis() > deadline) { + fail("Timed out waiting for outbound message"); + } + channel.runPendingTasks(); + Thread.sleep(50); + } + return result; + } + @Deprecated private void legacyOpenStreamAndCheck( TransportClient client, @@ -354,7 +479,7 @@ private PbStreamHandler openStreamAndCheck( FetchHandler fetchHandler, int startIndex, int endIndex) - throws IOException { + throws Exception { return openStreamAndCheck(client, channel, fetchHandler, startIndex, endIndex, false); } @@ -365,7 +490,7 @@ private PbStreamHandler openStreamAndCheck( int startIndex, int endIndex, Boolean readLocalShuffle) - throws IOException { + throws Exception { ByteBuffer openStreamByteBuffer = new TransportMessage( MessageType.OPEN_STREAM, @@ -382,7 +507,7 @@ private PbStreamHandler openStreamAndCheck( client, new RpcRequest(dummyRequestId, new NioManagedBuffer(openStreamByteBuffer)), createRpcResponseCallback(channel)); - RpcResponse result = channel.readOutbound(); + RpcResponse result = readOutboundWithTimeout(channel, 30_000); PbStreamHandler streamHandler = TransportMessage.fromByteBuffer(result.body().nioByteBuffer()).getParsedPayload(); if (endIndex == Integer.MAX_VALUE) { diff --git a/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/local/DiskPartitionFilesSorterSuiteJ.java b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/local/DiskPartitionFilesSorterSuiteJ.java index ce092633db0..1bb9599dbf2 100644 --- a/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/local/DiskPartitionFilesSorterSuiteJ.java +++ b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/local/DiskPartitionFilesSorterSuiteJ.java @@ -27,6 +27,9 @@ import java.util.HashMap; import java.util.Map; import java.util.Random; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; import org.junit.Assert; import org.junit.Test; @@ -45,6 +48,7 @@ import org.apache.celeborn.common.util.Utils; import org.apache.celeborn.service.deploy.worker.WorkerSource; import org.apache.celeborn.service.deploy.worker.memory.MemoryManager; +import org.apache.celeborn.service.deploy.worker.storage.FileResolvedCallback; import org.apache.celeborn.service.deploy.worker.storage.PartitionDataWriter; import org.apache.celeborn.service.deploy.worker.storage.PartitionFilesSorter; @@ -136,7 +140,44 @@ public void clean() throws IOException { JavaUtils.deleteRecursively(new File(shuffleFile.getPath() + ".index")); } - private void check(int mapCount, int startMapIndex, int endMapIndex) throws IOException { + private FileInfo getSortedFileInfoAsync( + PartitionFilesSorter sorter, + String shuffleKey, + String fileName, + FileInfo fileInfo, + int startMapIndex, + int endMapIndex) + throws Exception { + CountDownLatch latch = new CountDownLatch(1); + AtomicReference resultRef = new AtomicReference<>(); + AtomicReference errorRef = new AtomicReference<>(); + sorter.getSortedFileInfo( + shuffleKey, + fileName, + fileInfo, + startMapIndex, + endMapIndex, + new FileResolvedCallback() { + @Override + public void onSuccess(FileInfo sortedFileInfo) { + resultRef.set(sortedFileInfo); + latch.countDown(); + } + + @Override + public void onFailure(Throwable e) { + errorRef.set(e); + latch.countDown(); + } + }); + Assert.assertTrue("Sort timed out", latch.await(60, TimeUnit.SECONDS)); + if (errorRef.get() != null) { + throw new IOException("Sort failed", errorRef.get()); + } + return resultRef.get(); + } + + private void check(int mapCount, int startMapIndex, int endMapIndex) throws Exception { try { long[] partitionSize = prepare(mapCount); CelebornConf conf = new CelebornConf(); @@ -144,7 +185,8 @@ private void check(int mapCount, int startMapIndex, int endMapIndex) throws IOEx PartitionFilesSorter partitionFilesSorter = new PartitionFilesSorter(MemoryManager.instance(), conf, new WorkerSource(conf)); FileInfo info = - partitionFilesSorter.getSortedFileInfo( + getSortedFileInfoAsync( + partitionFilesSorter, "application-1", originFileName, partitionDataWriter.getDiskFileInfo(), @@ -168,19 +210,85 @@ private void check(int mapCount, int startMapIndex, int endMapIndex) throws IOEx } @Test - public void testSmallFile() throws IOException { + public void testSmallFile() throws Exception { int startMapIndex = random.nextInt(5); int endMapIndex = startMapIndex + random.nextInt(5) + 5; check(1000, startMapIndex, endMapIndex); } @Test - public void testLargeFile() throws IOException { + public void testLargeFile() throws Exception { int startMapIndex = random.nextInt(5); int endMapIndex = startMapIndex + random.nextInt(5) + 5; check(15000, startMapIndex, endMapIndex); } + @Test + public void testConcurrentSortReaders() throws Exception { + try { + long[] partitionSize = prepare(1000); + CelebornConf conf = new CelebornConf(); + conf.set(CelebornConf.SHUFFLE_CHUNK_SIZE().key(), "8m"); + PartitionFilesSorter sorter = + new PartitionFilesSorter(MemoryManager.instance(), conf, new WorkerSource(conf)); + + int startMapIndex = 3; + int endMapIndex = 8; + int numReaders = 5; + CountDownLatch allDone = new CountDownLatch(numReaders); + @SuppressWarnings("unchecked") + AtomicReference[] results = new AtomicReference[numReaders]; + @SuppressWarnings("unchecked") + AtomicReference[] errors = new AtomicReference[numReaders]; + + for (int i = 0; i < numReaders; i++) { + results[i] = new AtomicReference<>(); + errors[i] = new AtomicReference<>(); + final int idx = i; + sorter.getSortedFileInfo( + "application-1", + originFileName, + partitionDataWriter.getDiskFileInfo(), + startMapIndex, + endMapIndex, + new FileResolvedCallback() { + @Override + public void onSuccess(FileInfo sortedFileInfo) { + results[idx].set(sortedFileInfo); + allDone.countDown(); + } + + @Override + public void onFailure(Throwable e) { + errors[idx].set(e); + allDone.countDown(); + } + }); + } + + Assert.assertTrue( + "Concurrent readers timed out", allDone.await(60, TimeUnit.SECONDS)); + + long totalSizeToFetch = 0; + for (int i = startMapIndex; i < endMapIndex; i++) { + totalSizeToFetch += partitionSize[i]; + } + + for (int i = 0; i < numReaders; i++) { + Assert.assertNull( + "Reader " + i + " failed: " + errors[i].get(), errors[i].get()); + FileInfo info = results[i].get(); + Assert.assertNotNull("Reader " + i + " got null result", info); + long actualTotalChunkSize = + ((ReduceFileMeta) info.getFileMeta()).getLastChunkOffset() + - ((ReduceFileMeta) info.getFileMeta()).getChunkOffsets().get(0); + Assert.assertEquals(totalSizeToFetch, actualTotalChunkSize); + } + } finally { + clean(); + } + } + @Test public void testLevelDB() { if (Utils.isMacOnAppleSilicon()) { diff --git a/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/local/DiskReducePartitionDataWriterSuiteJ.java b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/local/DiskReducePartitionDataWriterSuiteJ.java index 536b4aab364..4212e53cea8 100644 --- a/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/local/DiskReducePartitionDataWriterSuiteJ.java +++ b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/local/DiskReducePartitionDataWriterSuiteJ.java @@ -165,9 +165,15 @@ public boolean checkRegistered() { } }; PartitionFilesSorter sorter = mock(PartitionFilesSorter.class); - Mockito.doReturn(info) + Mockito.doAnswer( + invocation -> { + FileResolvedCallback callback = invocation.getArgument(5); + callback.onSuccess(info); + return null; + }) .when(sorter) - .getSortedFileInfo(anyString(), anyString(), eq(info), anyInt(), anyInt()); + .getSortedFileInfo( + anyString(), anyString(), eq(info), anyInt(), anyInt(), any()); handler.setPartitionsSorter(sorter); transportContext = new TransportContext(transConf, handler); server = transportContext.createServer(); diff --git a/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/memory/MemoryPartitionFilesSorterSuiteJ.java b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/memory/MemoryPartitionFilesSorterSuiteJ.java index aa8a4593ed7..8fe1a7487b2 100644 --- a/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/memory/MemoryPartitionFilesSorterSuiteJ.java +++ b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/memory/MemoryPartitionFilesSorterSuiteJ.java @@ -24,6 +24,9 @@ import java.util.HashMap; import java.util.Map; import java.util.Random; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; import io.netty.buffer.ByteBufAllocator; import io.netty.buffer.CompositeByteBuf; @@ -41,6 +44,7 @@ import org.apache.celeborn.common.unsafe.Platform; import org.apache.celeborn.service.deploy.worker.WorkerSource; import org.apache.celeborn.service.deploy.worker.memory.MemoryManager; +import org.apache.celeborn.service.deploy.worker.storage.FileResolvedCallback; import org.apache.celeborn.service.deploy.worker.storage.PartitionDataWriter; import org.apache.celeborn.service.deploy.worker.storage.PartitionFilesSorter; import org.apache.celeborn.service.deploy.worker.storage.StorageManager; @@ -123,19 +127,40 @@ public long[] prepare(int mapCount) { return partitionSize; } - private void check(int mapCount, int startMapIndex, int endMapIndex) throws IOException { + private void check(int mapCount, int startMapIndex, int endMapIndex) throws Exception { long[] partitionSize = prepare(mapCount); CelebornConf conf = new CelebornConf(); conf.set(CelebornConf.SHUFFLE_CHUNK_SIZE().key(), "8m"); PartitionFilesSorter partitionFilesSorter = new PartitionFilesSorter(MemoryManager.instance(), conf, new WorkerSource(conf)); - FileInfo info = - partitionFilesSorter.getSortedFileInfo( - "application-1", - "", - partitionDataWriter.getMemoryFileInfo(), - startMapIndex, - endMapIndex); + + CountDownLatch latch = new CountDownLatch(1); + AtomicReference resultRef = new AtomicReference<>(); + AtomicReference errorRef = new AtomicReference<>(); + partitionFilesSorter.getSortedFileInfo( + "application-1", + "", + partitionDataWriter.getMemoryFileInfo(), + startMapIndex, + endMapIndex, + new FileResolvedCallback() { + @Override + public void onSuccess(FileInfo sortedFileInfo) { + resultRef.set(sortedFileInfo); + latch.countDown(); + } + + @Override + public void onFailure(Throwable e) { + errorRef.set(e); + latch.countDown(); + } + }); + + Assert.assertTrue("Sort timed out", latch.await(60, TimeUnit.SECONDS)); + Assert.assertNull("Sort failed: " + errorRef.get(), errorRef.get()); + FileInfo info = resultRef.get(); + long totalSizeToFetch = 0; for (int i = startMapIndex; i < endMapIndex; i++) { totalSizeToFetch += partitionSize[i]; @@ -152,7 +177,7 @@ private void check(int mapCount, int startMapIndex, int endMapIndex) throws IOEx } @Test - public void testSortMemoryShuffleFile() throws IOException { + public void testSortMemoryShuffleFile() throws Exception { int startMapIndex = random.nextInt(5); int endMapIndex = startMapIndex + random.nextInt(5) + 5; check(1000, startMapIndex, endMapIndex); diff --git a/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/memory/MemoryReducePartitionDataWriterSuiteJ.java b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/memory/MemoryReducePartitionDataWriterSuiteJ.java index 92d0fd5d416..0440bed5192 100644 --- a/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/memory/MemoryReducePartitionDataWriterSuiteJ.java +++ b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/memory/MemoryReducePartitionDataWriterSuiteJ.java @@ -192,9 +192,15 @@ public boolean checkRegistered() { } }; PartitionFilesSorter sorter = mock(PartitionFilesSorter.class); - Mockito.doReturn(info) + Mockito.doAnswer( + invocation -> { + FileResolvedCallback callback = invocation.getArgument(5); + callback.onSuccess(info); + return null; + }) .when(sorter) - .getSortedFileInfo(anyString(), anyString(), eq(info), anyInt(), anyInt()); + .getSortedFileInfo( + anyString(), anyString(), eq(info), anyInt(), anyInt(), any()); handler.setPartitionsSorter(sorter); TransportContext context = new TransportContext(transConf, handler); server = context.createServer(); From 8781a16a2db34b0554d49ce5efa0ebfb43953749 Mon Sep 17 00:00:00 2001 From: Xianming Lei Date: Wed, 8 Apr 2026 15:32:34 +0800 Subject: [PATCH 3/4] Fix spotless --- .../service/deploy/worker/FetchHandler.scala | 61 +++++++++++++++---- .../deploy/worker/FetchHandlerSuiteJ.java | 18 ++---- .../local/DiskPartitionFilesSorterSuiteJ.java | 6 +- .../DiskReducePartitionDataWriterSuiteJ.java | 3 +- .../MemoryPartitionFilesSorterSuiteJ.java | 1 - ...MemoryReducePartitionDataWriterSuiteJ.java | 3 +- 6 files changed, 58 insertions(+), 34 deletions(-) diff --git a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala index 0014590aeae..dfc732696a5 100644 --- a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala +++ b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala @@ -178,12 +178,23 @@ class FetchHandler( try { val fileInfo = getRawFileInfo(shuffleKey, fileName) openReduceStreamAsync( - shuffleKey, fileName, fileInfo, startMapIndex, endMapIndex, streamId, + shuffleKey, + fileName, + fileInfo, + startMapIndex, + endMapIndex, + streamId, new FileResolvedCallback { override def onSuccess(sortedFileInfo: FileInfo): Unit = { results(idx) = registerAndHandleStream( - client, shuffleKey, fileName, - startMapIndex, endMapIndex, readLocalFlag, sortedFileInfo, streamId) + client, + shuffleKey, + fileName, + startMapIndex, + endMapIndex, + readLocalFlag, + sortedFileInfo, + streamId) if (results(idx).getStatus != StatusCode.SUCCESS.getValue) { workerSource.incCounter(WorkerSource.OPEN_STREAM_FAIL_COUNT) } @@ -328,7 +339,11 @@ class FetchHandler( case _ => null } chunkStreamManager.registerStream( - streamId, shuffleKey, managedBuffer, fileName, fetchTimeMetric) + streamId, + shuffleKey, + managedBuffer, + fileName, + fetchTimeMetric) if (meta.getNumChunks == 0) logDebug(s"StreamId $streamId, fileName $fileName, mapRange " + s"[$startIndex-$endIndex] is empty. Received from client channel " + @@ -371,7 +386,12 @@ class FetchHandler( if ((endIndex != Int.MaxValue && endIndex != -1 && endIndex >= startIndex) || (endIndex == Int.MaxValue && !fileInfo.addStream(streamId))) { partitionsSorter.getSortedFileInfo( - shuffleKey, fileName, fileInfo, startIndex, endIndex, callback) + shuffleKey, + fileName, + fileInfo, + startIndex, + endIndex, + callback) } else { callback.onSuccess(fileInfo) } @@ -402,7 +422,12 @@ class FetchHandler( case _: ReduceFileMeta => val streamId = chunkStreamManager.nextStreamId() openReduceStreamAsync( - shuffleKey, fileName, fileInfo, startIndex, endIndex, streamId, + shuffleKey, + fileName, + fileInfo, + startIndex, + endIndex, + streamId, new FileResolvedCallback { private var timerStopped = false @@ -410,14 +435,22 @@ class FetchHandler( try { val pbStreamHandlerOpt = registerAndHandleStream( - client, shuffleKey, fileName, - startIndex, endIndex, readLocalShuffle, - sortedFileInfo, streamId) + client, + shuffleKey, + fileName, + startIndex, + endIndex, + readLocalShuffle, + sortedFileInfo, + streamId) if (pbStreamHandlerOpt.getStatus != StatusCode.SUCCESS.getValue) { throw new CelebornIOException(pbStreamHandlerOpt.getErrorMsg) } replyStreamHandler( - client, rpcRequestId, pbStreamHandlerOpt.getStreamHandler, isLegacy) + client, + rpcRequestId, + pbStreamHandlerOpt.getStreamHandler, + isLegacy) } catch { case t: Throwable => onFailure(t) @@ -429,8 +462,12 @@ class FetchHandler( override def onFailure(e: Throwable): Unit = { workerSource.incCounter(WorkerSource.OPEN_STREAM_FAIL_COUNT) handleRpcIOException( - client, rpcRequestId, shuffleKey, fileName, - ExceptionUtils.wrapThrowableToIOException(e), callback) + client, + rpcRequestId, + shuffleKey, + fileName, + ExceptionUtils.wrapThrowableToIOException(e), + callback) stopTimerIfNeeded() } diff --git a/worker/src/test/java/org/apache/celeborn/service/deploy/worker/FetchHandlerSuiteJ.java b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/FetchHandlerSuiteJ.java index 9b9e21b76c2..09efed0b2f2 100644 --- a/worker/src/test/java/org/apache/celeborn/service/deploy/worker/FetchHandlerSuiteJ.java +++ b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/FetchHandlerSuiteJ.java @@ -324,8 +324,7 @@ public void testBatchOpenStream() throws Exception { builder.addReadLocalShuffle(false); } ByteBuffer batchOpenStreamBuffer = - new TransportMessage( - MessageType.BATCH_OPEN_STREAM, builder.build().toByteArray()) + new TransportMessage(MessageType.BATCH_OPEN_STREAM, builder.build().toByteArray()) .toByteBuffer(); fetchHandler.receive( client, @@ -381,15 +380,10 @@ public void testBatchOpenStreamPartialFailure() throws Exception { PbOpenStreamList.Builder builder = PbOpenStreamList.newBuilder().setShuffleKey(shuffleKey); builder.addFileName(existingFile).addStartIndex(5).addEndIndex(10).addReadLocalShuffle(false); builder.addFileName(missingFile).addStartIndex(0).addEndIndex(10).addReadLocalShuffle(false); - builder - .addFileName(existingFile) - .addStartIndex(5) - .addEndIndex(10) - .addReadLocalShuffle(false); + builder.addFileName(existingFile).addStartIndex(5).addEndIndex(10).addReadLocalShuffle(false); ByteBuffer batchBuffer = - new TransportMessage( - MessageType.BATCH_OPEN_STREAM, builder.build().toByteArray()) + new TransportMessage(MessageType.BATCH_OPEN_STREAM, builder.build().toByteArray()) .toByteBuffer(); fetchHandler.receive( client, @@ -401,12 +395,10 @@ public void testBatchOpenStreamPartialFailure() throws Exception { TransportMessage.fromByteBuffer(result.body().nioByteBuffer()).getParsedPayload(); assertEquals(3, response.getStreamHandlerOptCount()); - assertEquals( - StatusCode.SUCCESS.getValue(), response.getStreamHandlerOpt(0).getStatus()); + assertEquals(StatusCode.SUCCESS.getValue(), response.getStreamHandlerOpt(0).getStatus()); assertEquals( StatusCode.OPEN_STREAM_FAILED.getValue(), response.getStreamHandlerOpt(1).getStatus()); - assertEquals( - StatusCode.SUCCESS.getValue(), response.getStreamHandlerOpt(2).getStatus()); + assertEquals(StatusCode.SUCCESS.getValue(), response.getStreamHandlerOpt(2).getStatus()); } finally { cleanup(fileInfo); } diff --git a/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/local/DiskPartitionFilesSorterSuiteJ.java b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/local/DiskPartitionFilesSorterSuiteJ.java index 1bb9599dbf2..ad249db3b43 100644 --- a/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/local/DiskPartitionFilesSorterSuiteJ.java +++ b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/local/DiskPartitionFilesSorterSuiteJ.java @@ -266,8 +266,7 @@ public void onFailure(Throwable e) { }); } - Assert.assertTrue( - "Concurrent readers timed out", allDone.await(60, TimeUnit.SECONDS)); + Assert.assertTrue("Concurrent readers timed out", allDone.await(60, TimeUnit.SECONDS)); long totalSizeToFetch = 0; for (int i = startMapIndex; i < endMapIndex; i++) { @@ -275,8 +274,7 @@ public void onFailure(Throwable e) { } for (int i = 0; i < numReaders; i++) { - Assert.assertNull( - "Reader " + i + " failed: " + errors[i].get(), errors[i].get()); + Assert.assertNull("Reader " + i + " failed: " + errors[i].get(), errors[i].get()); FileInfo info = results[i].get(); Assert.assertNotNull("Reader " + i + " got null result", info); long actualTotalChunkSize = diff --git a/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/local/DiskReducePartitionDataWriterSuiteJ.java b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/local/DiskReducePartitionDataWriterSuiteJ.java index 4212e53cea8..257727cf1fa 100644 --- a/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/local/DiskReducePartitionDataWriterSuiteJ.java +++ b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/local/DiskReducePartitionDataWriterSuiteJ.java @@ -172,8 +172,7 @@ public boolean checkRegistered() { return null; }) .when(sorter) - .getSortedFileInfo( - anyString(), anyString(), eq(info), anyInt(), anyInt(), any()); + .getSortedFileInfo(anyString(), anyString(), eq(info), anyInt(), anyInt(), any()); handler.setPartitionsSorter(sorter); transportContext = new TransportContext(transConf, handler); server = transportContext.createServer(); diff --git a/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/memory/MemoryPartitionFilesSorterSuiteJ.java b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/memory/MemoryPartitionFilesSorterSuiteJ.java index 8fe1a7487b2..ddce90d17cf 100644 --- a/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/memory/MemoryPartitionFilesSorterSuiteJ.java +++ b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/memory/MemoryPartitionFilesSorterSuiteJ.java @@ -19,7 +19,6 @@ import static org.mockito.Mockito.when; -import java.io.IOException; import java.nio.ByteBuffer; import java.util.HashMap; import java.util.Map; diff --git a/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/memory/MemoryReducePartitionDataWriterSuiteJ.java b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/memory/MemoryReducePartitionDataWriterSuiteJ.java index 0440bed5192..cb30446e17e 100644 --- a/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/memory/MemoryReducePartitionDataWriterSuiteJ.java +++ b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/memory/MemoryReducePartitionDataWriterSuiteJ.java @@ -199,8 +199,7 @@ public boolean checkRegistered() { return null; }) .when(sorter) - .getSortedFileInfo( - anyString(), anyString(), eq(info), anyInt(), anyInt(), any()); + .getSortedFileInfo(anyString(), anyString(), eq(info), anyInt(), anyInt(), any()); handler.setPartitionsSorter(sorter); TransportContext context = new TransportContext(transConf, handler); server = context.createServer(); From 93b0879887446ba844ae60d8b1110284f6be2cbf Mon Sep 17 00:00:00 2001 From: Xianming Lei Date: Thu, 9 Apr 2026 15:04:33 +0800 Subject: [PATCH 4/4] fix --- .../worker/storage/PartitionFilesSorter.java | 10 +++++-- .../service/deploy/worker/FetchHandler.scala | 28 +++++++++++-------- 2 files changed, 24 insertions(+), 14 deletions(-) diff --git a/worker/src/main/java/org/apache/celeborn/service/deploy/worker/storage/PartitionFilesSorter.java b/worker/src/main/java/org/apache/celeborn/service/deploy/worker/storage/PartitionFilesSorter.java index a0d3684eebc..d05e734d698 100644 --- a/worker/src/main/java/org/apache/celeborn/service/deploy/worker/storage/PartitionFilesSorter.java +++ b/worker/src/main/java/org/apache/celeborn/service/deploy/worker/storage/PartitionFilesSorter.java @@ -317,14 +317,18 @@ public void onSuccess() { fileResolvedCallback.onSuccess(sortedFileInfo); } catch (Throwable e) { fileResolvedCallback.onFailure(e); + } finally { + notifyPendingSortCallbacks(fileId, null); } - notifyPendingSortCallbacks(fileId, null); } @Override public void onFailure(Throwable e) { - fileResolvedCallback.onFailure(e); - notifyPendingSortCallbacks(fileId, e); + try { + fileResolvedCallback.onFailure(e); + } finally { + notifyPendingSortCallbacks(fileId, e); + } } }; try { diff --git a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala index dfc732696a5..3a8815f2017 100644 --- a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala +++ b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala @@ -186,17 +186,23 @@ class FetchHandler( streamId, new FileResolvedCallback { override def onSuccess(sortedFileInfo: FileInfo): Unit = { - results(idx) = registerAndHandleStream( - client, - shuffleKey, - fileName, - startMapIndex, - endMapIndex, - readLocalFlag, - sortedFileInfo, - streamId) - if (results(idx).getStatus != StatusCode.SUCCESS.getValue) { - workerSource.incCounter(WorkerSource.OPEN_STREAM_FAIL_COUNT) + try { + results(idx) = registerAndHandleStream( + client, + shuffleKey, + fileName, + startMapIndex, + endMapIndex, + readLocalFlag, + sortedFileInfo, + streamId) + if (results(idx).getStatus != StatusCode.SUCCESS.getValue) { + workerSource.incCounter(WorkerSource.OPEN_STREAM_FAIL_COUNT) + } + } catch { + case t: Throwable => + onFailure(t) + return } completedCount.incrementAndGet() trySendBatchResponse()