diff --git a/worker/src/main/java/org/apache/celeborn/service/deploy/worker/storage/FileResolvedCallback.java b/worker/src/main/java/org/apache/celeborn/service/deploy/worker/storage/FileResolvedCallback.java new file mode 100644 index 00000000000..857af8e886f --- /dev/null +++ b/worker/src/main/java/org/apache/celeborn/service/deploy/worker/storage/FileResolvedCallback.java @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.celeborn.service.deploy.worker.storage; + +import org.apache.celeborn.common.meta.FileInfo; + +/** Callback through which PartitionFilesSorter#getSortedFileInfo reports the outcome of resolving a file, instead of handing the result back as a return value. */ public interface FileResolvedCallback { + /** Called when the file has been resolved. NOTE(review): PartitionFilesSorter#notifyPendingSortCallbacks invokes its queued callbacks with a null argument, so fileInfo is not guaranteed to be non-null for every implementation - confirm before dereferencing. */ void onSuccess(FileInfo fileInfo); + + /** Called with the cause when the file could not be resolved. */ void onFailure(Throwable e); +} diff --git a/worker/src/main/java/org/apache/celeborn/service/deploy/worker/storage/PartitionFilesSorter.java b/worker/src/main/java/org/apache/celeborn/service/deploy/worker/storage/PartitionFilesSorter.java index cdfb95c1e59..d05e734d698 100644 --- a/worker/src/main/java/org/apache/celeborn/service/deploy/worker/storage/PartitionFilesSorter.java +++ b/worker/src/main/java/org/apache/celeborn/service/deploy/worker/storage/PartitionFilesSorter.java @@ -31,6 +31,7 @@ import java.util.Set; import java.util.TreeMap; import java.util.concurrent.*; +import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.locks.Condition; @@ -78,6 +79,8 @@ public class PartitionFilesSorter extends ShuffleRecoverHelper { JavaUtils.newConcurrentHashMap(); private final ConcurrentHashMap> sortingShuffleFiles = JavaUtils.newConcurrentHashMap(); + /* Keyed by fileId: callbacks of requests that arrived while a sort of that file was already in flight; drained by notifyPendingSortCallbacks. */ private final ConcurrentHashMap> pendingSortCallbacks = + JavaUtils.newConcurrentHashMap(); private final Cache>> indexCache; private final Map> indexCacheNames = JavaUtils.newConcurrentHashMap(); @@ -205,7 +208,12 @@ public long getSortedSize() { // 3. If the sorted file is generated, it returns the sorted FileInfo. 
// This method will generate temporary file info for this shuffle read public FileInfo getSortedFileInfo( - String shuffleKey, String fileName, FileInfo fileInfo, int startMapIndex, int endMapIndex) + String shuffleKey, + String fileName, + FileInfo fileInfo, + int startMapIndex, + int endMapIndex, + FileResolvedCallback fileResolvedCallback) throws IOException { if (fileInfo instanceof MemoryFileInfo) { MemoryFileInfo memoryFileInfo = ((MemoryFileInfo) fileInfo); @@ -227,11 +235,13 @@ public FileInfo getSortedFileInfo( memoryFileInfo.getSortedBuffer(), targetBuffer, shuffleChunkSize); - return new MemoryFileInfo( - memoryFileInfo.getUserIdentifier(), - memoryFileInfo.isPartitionSplitEnabled(), - reduceFileMeta, - targetBuffer); + FileInfo sortedFileInfo = + new MemoryFileInfo( + memoryFileInfo.getUserIdentifier(), + memoryFileInfo.isPartitionSplitEnabled(), + reduceFileMeta, + targetBuffer); + fileResolvedCallback.onSuccess(sortedFileInfo); } else { DiskFileInfo diskFileInfo = ((DiskFileInfo) fileInfo); String fileId = shuffleKey + "-" + fileName; @@ -243,13 +253,87 @@ public FileInfo getSortedFileInfo( String sortedFilePath = Utils.getSortedFilePath(diskFileInfo.getFilePath()); String indexFilePath = Utils.getIndexFilePath(diskFileInfo.getFilePath()); - boolean fileSorting = true; synchronized (sorting) { if (sorted.contains(fileId)) { - fileSorting = false; - } else if (!sorting.contains(fileId)) { try { - FileSorter fileSorter = new FileSorter(diskFileInfo, fileId, shuffleKey); + FileInfo sortedFileInfo = + resolve( + shuffleKey, + fileId, + userIdentifier, + sortedFilePath, + indexFilePath, + startMapIndex, + endMapIndex); + fileResolvedCallback.onSuccess(sortedFileInfo); + } catch (Throwable e) { + fileResolvedCallback.onFailure(e); + } + return null; + } else if (sorting.contains(fileId)) { + FileResolvedCallback pendingCallback = + new FileResolvedCallback() { + @Override + public void onSuccess(FileInfo ignored) { + try { + FileInfo sortedFileInfo = + 
resolve( + shuffleKey, + fileId, + userIdentifier, + sortedFilePath, + indexFilePath, + startMapIndex, + endMapIndex); + fileResolvedCallback.onSuccess(sortedFileInfo); + } catch (Throwable e) { + fileResolvedCallback.onFailure(e); + } + } + + @Override + public void onFailure(Throwable e) { + fileResolvedCallback.onFailure(e); + } + }; + pendingSortCallbacks + .computeIfAbsent(fileId, k -> new CopyOnWriteArrayList<>()) + .add(pendingCallback); + } else { + FileSortedCallback fileSortedCallback = + new FileSortedCallback() { + @Override + public void onSuccess() { + try { + FileInfo sortedFileInfo = + resolve( + shuffleKey, + fileId, + userIdentifier, + sortedFilePath, + indexFilePath, + startMapIndex, + endMapIndex); + fileResolvedCallback.onSuccess(sortedFileInfo); + } catch (Throwable e) { + fileResolvedCallback.onFailure(e); + } finally { + notifyPendingSortCallbacks(fileId, null); + } + } + + @Override + public void onFailure(Throwable e) { + try { + fileResolvedCallback.onFailure(e); + } finally { + notifyPendingSortCallbacks(fileId, e); + } + } + }; + try { + FileSorter fileSorter = + new FileSorter(diskFileInfo, fileId, shuffleKey, fileSortedCallback); sorting.add(fileId); logger.debug( "Adding sorter to sort queue shuffle key {}, file name {}", shuffleKey, fileName); @@ -257,59 +341,33 @@ public FileInfo getSortedFileInfo( } catch (InterruptedException e) { logger.error( "Sorter scheduler thread is interrupted means worker is shutting down.", e); - throw new IOException( - "Sort scheduler thread is interrupted means worker is shutting down.", e); + fileResolvedCallback.onFailure( + new IOException( + "Sort scheduler thread is interrupted means worker is shutting down.", e)); } catch (IOException e) { logger.error("File sorter access DFS failed.", e); - throw new IOException("File sorter access DFS failed.", e); + fileResolvedCallback.onFailure(new IOException("File sorter access DFS failed.", e)); } } } + } + return null; + } - if (fileSorting) { - long 
sortStartTime = System.currentTimeMillis(); - while (!sorted.contains(fileId)) { - if (sorting.contains(fileId)) { - try { - Thread.sleep(50); - if (System.currentTimeMillis() - sortStartTime > sortTimeout) { - String msg = - String.format( - "Sorting file %s path %s length %s timeout after %dms", - fileId, - diskFileInfo.getFilePath(), - diskFileInfo.getFileLength(), - sortTimeout); - logger.error(msg); - throw new IOException(msg); - } - } catch (InterruptedException e) { - logger.error( - "Sorter scheduler thread is interrupted means worker is shutting down.", e); - throw new IOException( - "Sorter scheduler thread is interrupted means worker is shutting down.", e); - } + /** Removes the callbacks queued for fileId from pendingSortCallbacks and notifies each of them: onFailure(error) when error is non-null, otherwise onSuccess(null). An Exception thrown by one callback is logged and does not stop the remaining notifications. NOTE(review): the removed lines around this method dropped the sortTimeout polling loop and nothing visible here bounds how long a queued callback may wait - confirm a timeout is enforced elsewhere. */ private void notifyPendingSortCallbacks(String fileId, Throwable error) { + List callbacks = pendingSortCallbacks.remove(fileId); + if (callbacks != null) { + for (FileResolvedCallback cb : callbacks) { + try { + if (error != null) { + cb.onFailure(error); } else { - logger.debug( - "Sorting shuffle file for {} {} failed.", shuffleKey, diskFileInfo.getFilePath()); - throw new IOException( - "Sorting shuffle file for " - + shuffleKey - + " " - + diskFileInfo.getFilePath() - + " failed."); + /* The callbacks queued by getSortedFileInfo ignore this argument and call resolve themselves. */ cb.onSuccess(null); } + } catch (Exception e) { + logger.error("Error notifying pending sort callback for {}", fileId, e); } } - - return resolve( - shuffleKey, - fileId, - userIdentifier, - sortedFilePath, - indexFilePath, - startMapIndex, - endMapIndex); } } @@ -686,8 +744,14 @@ class FileSorter { private FileChannel originFileChannel = null; private FileChannel sortedFileChannel = null; private FileSystem hadoopFs; - - FileSorter(DiskFileInfo fileInfo, String fileId, String shuffleKey) throws IOException { + /* Receives the outcome of sort(). */ private FileSortedCallback fileSortedCallback; + + FileSorter( + DiskFileInfo fileInfo, + String fileId, + String shuffleKey, + FileSortedCallback fileSortedCallback) + throws IOException { this.originFileInfo = fileInfo; this.originFilePath = fileInfo.getFilePath(); this.sortedFilePath = 
Utils.getSortedFilePath(originFilePath); @@ -700,6 +764,7 @@ class FileSorter { this.fileId = fileId; this.shuffleKey = shuffleKey; this.indexFilePath = Utils.getIndexFilePath(originFilePath); + this.fileSortedCallback = fileSortedCallback; if (!isDfs) { File sortedFile = new File(this.sortedFilePath); if (sortedFile.exists()) { @@ -728,6 +793,7 @@ class FileSorter { public void sort() { source.startTimer(WorkerSource.SORT_TIME(), fileId); + /* Cleared in the catch block below so that onSuccess is skipped after a failure. */ boolean success = true; long sortStartTime = -1; if (sortTimeLogThreshold > 0) { sortStartTime = System.nanoTime(); @@ -804,6 +870,8 @@ public void sort() { } catch (Exception e) { logger.error( "Sorting shuffle file for " + fileId + " " + originFilePath + " failed, detail: ", e); + success = false; + /* NOTE(review): the failure callback fires here, before the finally block below removes fileId from the sorting set, whereas the success callback fires after that removal. A request arriving in between may still see the file as sorting and queue a pending callback after the pending list was already drained - confirm this window cannot leave a caller unanswered. */ fileSortedCallback.onFailure(e); } finally { closeFiles(); Set sorting = sortingShuffleFiles.get(shuffleKey); @@ -811,6 +879,9 @@ public void sort() { sorting.remove(fileId); } } + /* Success is reported only after the cleanup in the finally block above has run. */ if (success) { + fileSortedCallback.onSuccess(); + } if (sortTimeLogThreshold > 0) { long sortDuration = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - sortStartTime); if (sortDuration > sortTimeLogThreshold) { @@ -1000,3 +1071,9 @@ public void close() { cleaner.shutdownNow(); } } + +/** Package-private callback notified by FileSorter#sort() with the outcome of sorting one file. */ interface FileSortedCallback { + /** Invoked when sort() finished without catching an exception. */ void onSuccess(); + + /** Invoked with the exception caught by sort(). */ void onFailure(Throwable e); +} diff --git a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala index ca9138e8fbc..3a8815f2017 100644 --- a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala +++ b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala @@ -20,9 +20,11 @@ package org.apache.celeborn.service.deploy.worker import java.io.{FileNotFoundException, IOException} import java.nio.charset.StandardCharsets import java.util -import java.util.concurrent.atomic.AtomicBoolean +import java.util.concurrent.atomic.{AtomicBoolean, AtomicInteger} 
import java.util.function.Consumer +import scala.collection.mutable.ArrayBuffer + import com.google.common.base.Throwables import com.google.protobuf.GeneratedMessageV3 import io.netty.util.concurrent.{Future, GenericFutureListener} @@ -39,8 +41,8 @@ import org.apache.celeborn.common.network.server.BaseMessageHandler import org.apache.celeborn.common.network.util.{NettyUtils, TransportConf} import org.apache.celeborn.common.protocol.{MessageType, PbBufferStreamEnd, PbChunkFetchRequest, PbNotifyRequiredSegment, PbOpenStream, PbOpenStreamList, PbOpenStreamListResponse, PbReadAddCredit, PbStreamHandler, PbStreamHandlerOpt, StreamType} import org.apache.celeborn.common.protocol.message.StatusCode -import org.apache.celeborn.common.util.{ExceptionUtils, Utils} -import org.apache.celeborn.service.deploy.worker.storage.{ChunkStreamManager, CreditStreamManager, PartitionFilesSorter, StorageManager} +import org.apache.celeborn.common.util.{ExceptionUtils, JavaUtils, Utils} +import org.apache.celeborn.service.deploy.worker.storage.{ChunkStreamManager, CreditStreamManager, FileResolvedCallback, PartitionFilesSorter, StorageManager} class FetchHandler( val conf: CelebornConf, @@ -143,36 +145,92 @@ class FetchHandler( val startIndices = openStreamList.getStartIndexList val endIndices = openStreamList.getEndIndexList val readLocalFlags = openStreamList.getReadLocalShuffleList - val pbOpenStreamListResponse = PbOpenStreamListResponse.newBuilder() checkAuth(client, Utils.splitShuffleKey(shuffleKey)._1) val openStreamRequestId = Utils.makeOpenStreamRequestId( shuffleKey, client.getChannel.id().toString, rpcRequest.requestId) workerSource.startTimer(WorkerSource.OPEN_STREAM_TIME, openStreamRequestId) - try { - 0 until files.size() foreach { idx => - val pbStreamHandlerOpt = handleReduceOpenStreamInternal( - client, + val totalFiles = files.size() + val results = new Array[PbStreamHandlerOpt](totalFiles) + val completedCount = new AtomicInteger(0) + val replied = new 
AtomicBoolean(false) + + def trySendBatchResponse(): Unit = { + if (completedCount.get() >= totalFiles && replied.compareAndSet(false, true)) { + val pbOpenStreamListResponse = PbOpenStreamListResponse.newBuilder() + results.foreach(pbOpenStreamListResponse.addStreamHandlerOpt) + workerSource.stopTimer(WorkerSource.OPEN_STREAM_TIME, openStreamRequestId) + client.getChannel.writeAndFlush(new RpcResponse( + rpcRequest.requestId, + new NioManagedBuffer(new TransportMessage( + MessageType.BATCH_OPEN_STREAM_RESPONSE, + pbOpenStreamListResponse.build().toByteArray).toByteBuffer))) + } + } + + 0 until totalFiles foreach { idx => + val fileName = files.get(idx) + val startMapIndex = startIndices.get(idx) + val endMapIndex = endIndices.get(idx) + val readLocalFlag = readLocalFlags.get(idx) + val streamId = chunkStreamManager.nextStreamId() + try { + val fileInfo = getRawFileInfo(shuffleKey, fileName) + openReduceStreamAsync( shuffleKey, - files.get(idx), - startIndices.get(idx), - endIndices.get(idx), - readLocalFlags.get(idx)) - if (pbStreamHandlerOpt.getStatus != StatusCode.SUCCESS.getValue) { + fileName, + fileInfo, + startMapIndex, + endMapIndex, + streamId, + new FileResolvedCallback { + override def onSuccess(sortedFileInfo: FileInfo): Unit = { + try { + results(idx) = registerAndHandleStream( + client, + shuffleKey, + fileName, + startMapIndex, + endMapIndex, + readLocalFlag, + sortedFileInfo, + streamId) + if (results(idx).getStatus != StatusCode.SUCCESS.getValue) { + workerSource.incCounter(WorkerSource.OPEN_STREAM_FAIL_COUNT) + } + } catch { + case t: Throwable => + onFailure(t) + return + } + completedCount.incrementAndGet() + trySendBatchResponse() + } + + override def onFailure(e: Throwable): Unit = { + val msg = s"Read file: $fileName with shuffleKey: $shuffleKey error, " + + s"Exception: ${e.getMessage}" + results(idx) = PbStreamHandlerOpt.newBuilder() + .setStatus(StatusCode.OPEN_STREAM_FAILED.getValue) + .setErrorMsg(msg).build() + 
workerSource.incCounter(WorkerSource.OPEN_STREAM_FAIL_COUNT) + completedCount.incrementAndGet() + trySendBatchResponse() + } + }) + } catch { + case e: IOException => + val msg = s"Read file: $fileName with shuffleKey: $shuffleKey error, " + + s"Exception: ${e.getMessage}" + results(idx) = PbStreamHandlerOpt.newBuilder() + .setStatus(StatusCode.OPEN_STREAM_FAILED.getValue) + .setErrorMsg(msg).build() workerSource.incCounter(WorkerSource.OPEN_STREAM_FAIL_COUNT) - } - pbOpenStreamListResponse.addStreamHandlerOpt(pbStreamHandlerOpt) + completedCount.incrementAndGet() + trySendBatchResponse() } - } finally { - workerSource.stopTimer(WorkerSource.OPEN_STREAM_TIME, openStreamRequestId) } - - client.getChannel.writeAndFlush(new RpcResponse( - rpcRequest.requestId, - new NioManagedBuffer(new TransportMessage( - MessageType.BATCH_OPEN_STREAM_RESPONSE, - pbOpenStreamListResponse.build().toByteArray).toByteBuffer))) case bufferStreamEnd: PbBufferStreamEnd => handleEndStreamFromClient( client, @@ -245,33 +303,16 @@ class FetchHandler( } - private def handleReduceOpenStreamInternal( + private def registerAndHandleStream( client: TransportClient, shuffleKey: String, fileName: String, startIndex: Int, endIndex: Int, - readLocalShuffle: Boolean = false): PbStreamHandlerOpt = { + readLocalShuffle: Boolean, + fileInfo: FileInfo, + streamId: Long): PbStreamHandlerOpt = { try { - logDebug(s"Received open stream request $shuffleKey $fileName $startIndex " + - s"$endIndex get file name $fileName from client channel " + - s"${NettyUtils.getRemoteAddress(client.getChannel)}") - - var fileInfo = getRawFileInfo(shuffleKey, fileName) - val streamId = chunkStreamManager.nextStreamId() - // we must get sorted fileInfo for the following cases. - // 1. when the current request is a non-range openStream, but the original unsorted file - // has been deleted by another range's openStream request. - // 2. when the current request is a range openStream request. 
- if ((endIndex != Int.MaxValue && endIndex != -1 && endIndex >= startIndex) || (endIndex == Int.MaxValue && !fileInfo.addStream( - streamId))) { - fileInfo = partitionsSorter.getSortedFileInfo( - shuffleKey, - fileName, - fileInfo, - startIndex, - endIndex) - } val meta = fileInfo.getReduceFileMeta val streamHandler = if (readLocalShuffle && !fileInfo.isInstanceOf[MemoryFileInfo]) { @@ -286,37 +327,23 @@ class FetchHandler( fileInfo.asInstanceOf[DiskFileInfo].getFilePath) } else fileInfo match { case info: DiskFileInfo if info.isHdfs => - chunkStreamManager.registerStream( - streamId, - shuffleKey, - fileName) + chunkStreamManager.registerStream(streamId, shuffleKey, fileName) makeStreamHandler(streamId, numChunks = 0) case info: DiskFileInfo if info.isS3 => - chunkStreamManager.registerStream( - streamId, - shuffleKey, - fileName) + chunkStreamManager.registerStream(streamId, shuffleKey, fileName) makeStreamHandler(streamId, numChunks = 0) case info: DiskFileInfo if info.isOSS => - chunkStreamManager.registerStream( - streamId, - shuffleKey, - fileName) + chunkStreamManager.registerStream(streamId, shuffleKey, fileName) makeStreamHandler(streamId, numChunks = 0) case _ => val managedBuffer = fileInfo match { - case df: DiskFileInfo => - new FileChunkBuffers(df, transportConf) - case mf: MemoryFileInfo => - new MemoryChunkBuffers(mf) + case df: DiskFileInfo => new FileChunkBuffers(df, transportConf) + case mf: MemoryFileInfo => new MemoryChunkBuffers(mf) + } + val fetchTimeMetric = fileInfo match { + case info: DiskFileInfo => storageManager.getFetchTimeMetric(info.getFile) + case _ => null } - val fetchTimeMetric = - fileInfo match { - case info: DiskFileInfo => - storageManager.getFetchTimeMetric(info.getFile) - case _ => - null - } chunkStreamManager.registerStream( streamId, shuffleKey, @@ -331,9 +358,7 @@ class FetchHandler( s"StreamId $streamId, fileName $fileName, numChunks ${meta.getNumChunks}, " + s"mapRange [$startIndex-$endIndex]. 
Received from client channel " + s"${NettyUtils.getRemoteAddress(client.getChannel)}") - makeStreamHandler( - streamId, - meta.getNumChunks) + makeStreamHandler(streamId, meta.getNumChunks) } workerSource.incCounter(WorkerSource.OPEN_STREAM_SUCCESS_COUNT) PbStreamHandlerOpt.newBuilder().setStreamHandler(streamHandler) @@ -349,6 +374,35 @@ class FetchHandler( } } + /** + * Decides whether sorting is needed and either triggers async sort or invokes callback directly. + * Sorting is required when: + * 1. A range openStream request (startIndex..endIndex) is specified. + * 2. A non-range openStream where the raw file has already been deleted by another + * range's sort (addStream returns false). + */ + private def openReduceStreamAsync( + shuffleKey: String, + fileName: String, + fileInfo: FileInfo, + startIndex: Int, + endIndex: Int, + streamId: Long, + callback: FileResolvedCallback): Unit = { + if ((endIndex != Int.MaxValue && endIndex != -1 && endIndex >= startIndex) || + (endIndex == Int.MaxValue && !fileInfo.addStream(streamId))) { + partitionsSorter.getSortedFileInfo( + shuffleKey, + fileName, + fileInfo, + startIndex, + endIndex, + callback) + } else { + callback.onSuccess(fileInfo) + } + } + private def handleOpenStreamInternal( client: TransportClient, shuffleKey: String, @@ -367,23 +421,70 @@ class FetchHandler( client.getChannel.id().toString, rpcRequestId) workerSource.startTimer(WorkerSource.OPEN_STREAM_TIME, requestId) + var asyncReplied = false try { val fileInfo = getRawFileInfo(shuffleKey, fileName) fileInfo.getFileMeta match { case _: ReduceFileMeta => - val pbStreamHandlerOpt = - handleReduceOpenStreamInternal( - client, - shuffleKey, - fileName, - startIndex, - endIndex, - readLocalShuffle) + val streamId = chunkStreamManager.nextStreamId() + openReduceStreamAsync( + shuffleKey, + fileName, + fileInfo, + startIndex, + endIndex, + streamId, + new FileResolvedCallback { + private var timerStopped = false + + override def onSuccess(sortedFileInfo: 
FileInfo): Unit = { + try { + val pbStreamHandlerOpt = + registerAndHandleStream( + client, + shuffleKey, + fileName, + startIndex, + endIndex, + readLocalShuffle, + sortedFileInfo, + streamId) + if (pbStreamHandlerOpt.getStatus != StatusCode.SUCCESS.getValue) { + throw new CelebornIOException(pbStreamHandlerOpt.getErrorMsg) + } + replyStreamHandler( + client, + rpcRequestId, + pbStreamHandlerOpt.getStreamHandler, + isLegacy) + } catch { + case t: Throwable => + onFailure(t) + } finally { + stopTimerIfNeeded() + } + } - if (pbStreamHandlerOpt.getStatus != StatusCode.SUCCESS.getValue) { - throw new CelebornIOException(pbStreamHandlerOpt.getErrorMsg) - } - replyStreamHandler(client, rpcRequestId, pbStreamHandlerOpt.getStreamHandler, isLegacy) + override def onFailure(e: Throwable): Unit = { + workerSource.incCounter(WorkerSource.OPEN_STREAM_FAIL_COUNT) + handleRpcIOException( + client, + rpcRequestId, + shuffleKey, + fileName, + ExceptionUtils.wrapThrowableToIOException(e), + callback) + stopTimerIfNeeded() + } + + private def stopTimerIfNeeded(): Unit = { + if (!timerStopped) { + workerSource.stopTimer(WorkerSource.OPEN_STREAM_TIME, requestId) + timerStopped = true + } + } + }) + asyncReplied = true case _: MapFileMeta => val creditStreamHandler = new Consumer[java.lang.Long] { @@ -395,7 +496,6 @@ class FetchHandler( replyStreamHandler(client, rpcRequestId, pbStreamHandler, isLegacy) } } - creditStreamManager.registerStream( creditStreamHandler, client.getChannel, @@ -404,14 +504,16 @@ class FetchHandler( startIndex, endIndex, fileInfo.asInstanceOf[DiskFileInfo]) + workerSource.incCounter(WorkerSource.OPEN_STREAM_SUCCESS_COUNT) } - workerSource.incCounter(WorkerSource.OPEN_STREAM_SUCCESS_COUNT) } catch { case e: IOException => workerSource.incCounter(WorkerSource.OPEN_STREAM_FAIL_COUNT) handleRpcIOException(client, rpcRequestId, shuffleKey, fileName, e, callback) } finally { - workerSource.stopTimer(WorkerSource.OPEN_STREAM_TIME, requestId) + if (!asyncReplied) { + 
workerSource.stopTimer(WorkerSource.OPEN_STREAM_TIME, requestId) + } } } diff --git a/worker/src/test/java/org/apache/celeborn/service/deploy/worker/FetchHandlerSuiteJ.java b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/FetchHandlerSuiteJ.java index b54a0b99a56..09efed0b2f2 100644 --- a/worker/src/test/java/org/apache/celeborn/service/deploy/worker/FetchHandlerSuiteJ.java +++ b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/FetchHandlerSuiteJ.java @@ -63,10 +63,14 @@ import org.apache.celeborn.common.protocol.PbBufferStreamEnd; import org.apache.celeborn.common.protocol.PbChunkFetchRequest; import org.apache.celeborn.common.protocol.PbOpenStream; +import org.apache.celeborn.common.protocol.PbOpenStreamList; +import org.apache.celeborn.common.protocol.PbOpenStreamListResponse; import org.apache.celeborn.common.protocol.PbStreamChunkSlice; import org.apache.celeborn.common.protocol.PbStreamHandler; +import org.apache.celeborn.common.protocol.PbStreamHandlerOpt; import org.apache.celeborn.common.protocol.StreamType; import org.apache.celeborn.common.protocol.TransportModuleConstants; +import org.apache.celeborn.common.protocol.message.StatusCode; import org.apache.celeborn.common.unsafe.Platform; import org.apache.celeborn.common.util.JavaUtils; import org.apache.celeborn.common.util.Utils; @@ -161,7 +165,7 @@ public static void afterAll() { } @Test - public void testFetchOriginFile() throws IOException { + public void testFetchOriginFile() throws Exception { FileInfo fileInfo = null; try { // total write: 32 * 50 * 256k = 400m @@ -180,7 +184,7 @@ public void testFetchOriginFile() throws IOException { } @Test - public void testFetchSortFile() throws IOException { + public void testFetchSortFile() throws Exception { FileInfo fileInfo = null; try { // total write size: 32 * 50 * 256k = 400m @@ -198,7 +202,7 @@ public void testFetchSortFile() throws IOException { } @Test - public void testLegacyOpenStream() throws IOException { + public 
void testLegacyOpenStream() throws Exception { FileInfo fileInfo = null; try { // total write size: 32 * 50 * 256k = 400m @@ -215,7 +219,7 @@ public void testLegacyOpenStream() throws IOException { } @Test - public void testWorkerReadSortFileOnceOriginalFileBeDeleted() throws IOException { + public void testWorkerReadSortFileOnceOriginalFileBeDeleted() throws Exception { FileInfo fileInfo = null; try { // total write size: 32 * 50 * 256k = 400m @@ -237,7 +241,7 @@ public void testWorkerReadSortFileOnceOriginalFileBeDeleted() throws IOException } @Test - public void testLocalReadSortFileOnceOriginalFileBeDeleted() throws IOException { + public void testLocalReadSortFileOnceOriginalFileBeDeleted() throws Exception { FileInfo fileInfo = null; try { // total write size: 32 * 50 * 256k = 400m @@ -255,7 +259,7 @@ public void testLocalReadSortFileOnceOriginalFileBeDeleted() throws IOException } @Test - public void testDoNotDeleteOriginalFileWhenNonRangeWorkerReadWorkInProgress() throws IOException { + public void testDoNotDeleteOriginalFileWhenNonRangeWorkerReadWorkInProgress() throws Exception { FileInfo fileInfo = null; try { // total write size: 32 * 50 * 256k = 400m @@ -280,7 +284,7 @@ public void testDoNotDeleteOriginalFileWhenNonRangeWorkerReadWorkInProgress() th } @Test - public void testDoNotDeleteOriginalFileWhenNonRangeLocalReadWorkInProgress() throws IOException { + public void testDoNotDeleteOriginalFileWhenNonRangeLocalReadWorkInProgress() throws Exception { FileInfo fileInfo = null; try { // total write size: 32 * 50 * 256k = 400m @@ -302,6 +306,104 @@ public void testDoNotDeleteOriginalFileWhenNonRangeLocalReadWorkInProgress() thr } } + @Test + public void testBatchOpenStream() throws Exception { + FileInfo fileInfo = null; + try { + fileInfo = prepare(32); + EmbeddedChannel channel = new EmbeddedChannel(); + TransportClient client = new TransportClient(channel, mock(TransportResponseHandler.class)); + FetchHandler fetchHandler = mockFetchHandler(fileInfo); 
+ + /* Batch of three identical requests for the same file: start index 5, end index 10, readLocalShuffle false. */ PbOpenStreamList.Builder builder = PbOpenStreamList.newBuilder().setShuffleKey(shuffleKey); + int batchSize = 3; + for (int i = 0; i < batchSize; i++) { + builder.addFileName(fileName); + builder.addStartIndex(5); + builder.addEndIndex(10); + builder.addReadLocalShuffle(false); + } + ByteBuffer batchOpenStreamBuffer = + new TransportMessage(MessageType.BATCH_OPEN_STREAM, builder.build().toByteArray()) + .toByteBuffer(); + fetchHandler.receive( + client, + new RpcRequest(dummyRequestId, new NioManagedBuffer(batchOpenStreamBuffer)), + createRpcResponseCallback(channel)); + + /* The response is polled for (timeout 30_000 ms) rather than read once. */ RpcResponse result = readOutboundWithTimeout(channel, 30_000); + PbOpenStreamListResponse response = + TransportMessage.fromByteBuffer(result.body().nioByteBuffer()).getParsedPayload(); + assertEquals(batchSize, response.getStreamHandlerOptCount()); + for (int i = 0; i < batchSize; i++) { + PbStreamHandlerOpt opt = response.getStreamHandlerOpt(i); + assertEquals(StatusCode.SUCCESS.getValue(), opt.getStatus()); + assertEquals(10 - 5, opt.getStreamHandler().getNumChunks()); + } + } finally { + cleanup(fileInfo); + } + } + + /** Sends a BATCH_OPEN_STREAM request for three files where the middle one makes getRawFileInfo throw FileNotFoundException, and expects the statuses SUCCESS, OPEN_STREAM_FAILED, SUCCESS at indices 0, 1 and 2 of the response. */ @Test + public void testBatchOpenStreamPartialFailure() throws Exception { + FileInfo fileInfo = null; + try { + fileInfo = prepare(32); + EmbeddedChannel channel = new EmbeddedChannel(); + TransportClient client = new TransportClient(channel, mock(TransportResponseHandler.class)); + + WorkerSource workerSource = mock(WorkerSource.class); + TransportConf transportConf = + Utils.fromCelebornConf(conf, TransportModuleConstants.FETCH_MODULE, 4); + FetchHandler fetchHandler0 = new FetchHandler(conf, transportConf, workerSource); + Worker worker = mock(Worker.class); + PartitionFilesSorter partitionFilesSorter = + new PartitionFilesSorter(MemoryManager.instance(), conf, workerSource); + StorageManager storageManager = mock(StorageManager.class); + Mockito.doReturn(storageManager).when(worker).storageManager(); + Mockito.doReturn(workerSource).when(worker).workerSource(); + 
Mockito.doReturn(partitionFilesSorter).when(worker).partitionsSorter(); + fetchHandler0.init(worker); + FetchHandler fetchHandler = spy(fetchHandler0); + + /* getRawFileInfo is stubbed on the spy: existingFile yields the prepared fileInfo, missingFile throws FileNotFoundException. */ String existingFile = "existingFile"; + String missingFile = "missingFile"; + Mockito.doReturn(fileInfo).when(fetchHandler).getRawFileInfo(shuffleKey, existingFile); + Mockito.doAnswer( + invocation -> { + throw new java.io.FileNotFoundException("Not found"); + }) + .when(fetchHandler) + .getRawFileInfo(shuffleKey, missingFile); + + PbOpenStreamList.Builder builder = PbOpenStreamList.newBuilder().setShuffleKey(shuffleKey); + builder.addFileName(existingFile).addStartIndex(5).addEndIndex(10).addReadLocalShuffle(false); + builder.addFileName(missingFile).addStartIndex(0).addEndIndex(10).addReadLocalShuffle(false); + builder.addFileName(existingFile).addStartIndex(5).addEndIndex(10).addReadLocalShuffle(false); + + ByteBuffer batchBuffer = + new TransportMessage(MessageType.BATCH_OPEN_STREAM, builder.build().toByteArray()) + .toByteBuffer(); + fetchHandler.receive( + client, + new RpcRequest(dummyRequestId, new NioManagedBuffer(batchBuffer)), + createRpcResponseCallback(channel)); + + RpcResponse result = readOutboundWithTimeout(channel, 30_000); + PbOpenStreamListResponse response = + TransportMessage.fromByteBuffer(result.body().nioByteBuffer()).getParsedPayload(); + assertEquals(3, response.getStreamHandlerOptCount()); + + /* One entry per requested file, checked by index: only the middle one failed. */ assertEquals(StatusCode.SUCCESS.getValue(), response.getStreamHandlerOpt(0).getStatus()); + assertEquals( + StatusCode.OPEN_STREAM_FAILED.getValue(), response.getStreamHandlerOpt(1).getStatus()); + assertEquals(StatusCode.SUCCESS.getValue(), response.getStreamHandlerOpt(2).getStatus()); + } finally { + cleanup(fileInfo); + } + } + private FetchHandler mockFetchHandler(FileInfo fileInfo) { WorkerSource workerSource = mock(WorkerSource.class); TransportConf transportConf = @@ -325,6 +427,21 @@ private FetchHandler mockFetchHandler(FileInfo fileInfo) { private final String fileName = 
"dummyFileName"; private final long dummyRequestId = 0; + /** Polls channel.readOutbound() until it yields a message, running the pending tasks of the channel and sleeping 50 ms between attempts; calls fail once timeoutMs milliseconds have passed. NOTE(review): the method returns T but no type-parameter list is visible in this listing - presumably stripped during extraction, confirm against the real file. */ @SuppressWarnings("unchecked") + private T readOutboundWithTimeout(EmbeddedChannel channel, long timeoutMs) + throws InterruptedException { + long deadline = System.currentTimeMillis() + timeoutMs; + T result; + while ((result = (T) channel.readOutbound()) == null) { + if (System.currentTimeMillis() > deadline) { + fail("Timed out waiting for outbound message"); + } + channel.runPendingTasks(); + Thread.sleep(50); + } + return result; + } + @Deprecated private void legacyOpenStreamAndCheck( TransportClient client, @@ -354,7 +471,7 @@ private PbStreamHandler openStreamAndCheck( FetchHandler fetchHandler, int startIndex, int endIndex) - throws IOException { + throws Exception { return openStreamAndCheck(client, channel, fetchHandler, startIndex, endIndex, false); } @@ -365,7 +482,7 @@ private PbStreamHandler openStreamAndCheck( int startIndex, int endIndex, Boolean readLocalShuffle) - throws IOException { + throws Exception { ByteBuffer openStreamByteBuffer = new TransportMessage( MessageType.OPEN_STREAM, @@ -382,7 +499,7 @@ private PbStreamHandler openStreamAndCheck( client, new RpcRequest(dummyRequestId, new NioManagedBuffer(openStreamByteBuffer)), createRpcResponseCallback(channel)); - RpcResponse result = channel.readOutbound(); + RpcResponse result = readOutboundWithTimeout(channel, 30_000); PbStreamHandler streamHandler = TransportMessage.fromByteBuffer(result.body().nioByteBuffer()).getParsedPayload(); if (endIndex == Integer.MAX_VALUE) { diff --git a/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/local/DiskPartitionFilesSorterSuiteJ.java b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/local/DiskPartitionFilesSorterSuiteJ.java index ce092633db0..ad249db3b43 100644 --- a/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/local/DiskPartitionFilesSorterSuiteJ.java +++ 
b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/local/DiskPartitionFilesSorterSuiteJ.java @@ -27,6 +27,9 @@ import java.util.HashMap; import java.util.Map; import java.util.Random; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; import org.junit.Assert; import org.junit.Test; @@ -45,6 +48,7 @@ import org.apache.celeborn.common.util.Utils; import org.apache.celeborn.service.deploy.worker.WorkerSource; import org.apache.celeborn.service.deploy.worker.memory.MemoryManager; +import org.apache.celeborn.service.deploy.worker.storage.FileResolvedCallback; import org.apache.celeborn.service.deploy.worker.storage.PartitionDataWriter; import org.apache.celeborn.service.deploy.worker.storage.PartitionFilesSorter; @@ -136,7 +140,44 @@ public void clean() throws IOException { JavaUtils.deleteRecursively(new File(shuffleFile.getPath() + ".index")); } - private void check(int mapCount, int startMapIndex, int endMapIndex) throws IOException { + private FileInfo getSortedFileInfoAsync( + PartitionFilesSorter sorter, + String shuffleKey, + String fileName, + FileInfo fileInfo, + int startMapIndex, + int endMapIndex) + throws Exception { + CountDownLatch latch = new CountDownLatch(1); + AtomicReference resultRef = new AtomicReference<>(); + AtomicReference errorRef = new AtomicReference<>(); + sorter.getSortedFileInfo( + shuffleKey, + fileName, + fileInfo, + startMapIndex, + endMapIndex, + new FileResolvedCallback() { + @Override + public void onSuccess(FileInfo sortedFileInfo) { + resultRef.set(sortedFileInfo); + latch.countDown(); + } + + @Override + public void onFailure(Throwable e) { + errorRef.set(e); + latch.countDown(); + } + }); + Assert.assertTrue("Sort timed out", latch.await(60, TimeUnit.SECONDS)); + if (errorRef.get() != null) { + throw new IOException("Sort failed", errorRef.get()); + } + return resultRef.get(); + } + + private void check(int mapCount, int 
startMapIndex, int endMapIndex) throws Exception { try { long[] partitionSize = prepare(mapCount); CelebornConf conf = new CelebornConf(); @@ -144,7 +185,8 @@ private void check(int mapCount, int startMapIndex, int endMapIndex) throws IOEx PartitionFilesSorter partitionFilesSorter = new PartitionFilesSorter(MemoryManager.instance(), conf, new WorkerSource(conf)); FileInfo info = - partitionFilesSorter.getSortedFileInfo( + getSortedFileInfoAsync( + partitionFilesSorter, "application-1", originFileName, partitionDataWriter.getDiskFileInfo(), @@ -168,19 +210,83 @@ private void check(int mapCount, int startMapIndex, int endMapIndex) throws IOEx } @Test - public void testSmallFile() throws IOException { + public void testSmallFile() throws Exception { int startMapIndex = random.nextInt(5); int endMapIndex = startMapIndex + random.nextInt(5) + 5; check(1000, startMapIndex, endMapIndex); } @Test - public void testLargeFile() throws IOException { + public void testLargeFile() throws Exception { int startMapIndex = random.nextInt(5); int endMapIndex = startMapIndex + random.nextInt(5) + 5; check(15000, startMapIndex, endMapIndex); } + @Test + public void testConcurrentSortReaders() throws Exception { + try { + long[] partitionSize = prepare(1000); + CelebornConf conf = new CelebornConf(); + conf.set(CelebornConf.SHUFFLE_CHUNK_SIZE().key(), "8m"); + PartitionFilesSorter sorter = + new PartitionFilesSorter(MemoryManager.instance(), conf, new WorkerSource(conf)); + + int startMapIndex = 3; + int endMapIndex = 8; + int numReaders = 5; + CountDownLatch allDone = new CountDownLatch(numReaders); + @SuppressWarnings("unchecked") + AtomicReference[] results = new AtomicReference[numReaders]; + @SuppressWarnings("unchecked") + AtomicReference[] errors = new AtomicReference[numReaders]; + + for (int i = 0; i < numReaders; i++) { + results[i] = new AtomicReference<>(); + errors[i] = new AtomicReference<>(); + final int idx = i; + sorter.getSortedFileInfo( + "application-1", + 
originFileName, + partitionDataWriter.getDiskFileInfo(), + startMapIndex, + endMapIndex, + new FileResolvedCallback() { + @Override + public void onSuccess(FileInfo sortedFileInfo) { + results[idx].set(sortedFileInfo); + allDone.countDown(); + } + + @Override + public void onFailure(Throwable e) { + errors[idx].set(e); + allDone.countDown(); + } + }); + } + + Assert.assertTrue("Concurrent readers timed out", allDone.await(60, TimeUnit.SECONDS)); + + long totalSizeToFetch = 0; + for (int i = startMapIndex; i < endMapIndex; i++) { + totalSizeToFetch += partitionSize[i]; + } + + for (int i = 0; i < numReaders; i++) { + Assert.assertNull("Reader " + i + " failed: " + errors[i].get(), errors[i].get()); + FileInfo info = results[i].get(); + Assert.assertNotNull("Reader " + i + " got null result", info); + long actualTotalChunkSize = + ((ReduceFileMeta) info.getFileMeta()).getLastChunkOffset() + - ((ReduceFileMeta) info.getFileMeta()).getChunkOffsets().get(0); + Assert.assertEquals(totalSizeToFetch, actualTotalChunkSize); + } + } finally { + clean(); + } + } + @Test public void testLevelDB() { if (Utils.isMacOnAppleSilicon()) { diff --git a/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/local/DiskReducePartitionDataWriterSuiteJ.java b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/local/DiskReducePartitionDataWriterSuiteJ.java index 536b4aab364..257727cf1fa 100644 --- a/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/local/DiskReducePartitionDataWriterSuiteJ.java +++ b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/local/DiskReducePartitionDataWriterSuiteJ.java @@ -165,9 +165,14 @@ public boolean checkRegistered() { } }; PartitionFilesSorter sorter = mock(PartitionFilesSorter.class); - Mockito.doReturn(info) + Mockito.doAnswer( + invocation -> { + FileResolvedCallback callback = invocation.getArgument(5); + callback.onSuccess(info); + return null; + }) .when(sorter) - 
.getSortedFileInfo(anyString(), anyString(), eq(info), anyInt(), anyInt()); + .getSortedFileInfo(anyString(), anyString(), eq(info), anyInt(), anyInt(), any()); handler.setPartitionsSorter(sorter); transportContext = new TransportContext(transConf, handler); server = transportContext.createServer(); diff --git a/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/memory/MemoryPartitionFilesSorterSuiteJ.java b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/memory/MemoryPartitionFilesSorterSuiteJ.java index aa8a4593ed7..ddce90d17cf 100644 --- a/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/memory/MemoryPartitionFilesSorterSuiteJ.java +++ b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/memory/MemoryPartitionFilesSorterSuiteJ.java @@ -19,11 +19,13 @@ import static org.mockito.Mockito.when; -import java.io.IOException; import java.nio.ByteBuffer; import java.util.HashMap; import java.util.Map; import java.util.Random; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; import io.netty.buffer.ByteBufAllocator; import io.netty.buffer.CompositeByteBuf; @@ -41,6 +43,7 @@ import org.apache.celeborn.common.unsafe.Platform; import org.apache.celeborn.service.deploy.worker.WorkerSource; import org.apache.celeborn.service.deploy.worker.memory.MemoryManager; +import org.apache.celeborn.service.deploy.worker.storage.FileResolvedCallback; import org.apache.celeborn.service.deploy.worker.storage.PartitionDataWriter; import org.apache.celeborn.service.deploy.worker.storage.PartitionFilesSorter; import org.apache.celeborn.service.deploy.worker.storage.StorageManager; @@ -123,19 +126,40 @@ public long[] prepare(int mapCount) { return partitionSize; } - private void check(int mapCount, int startMapIndex, int endMapIndex) throws IOException { + private void check(int mapCount, int startMapIndex, int endMapIndex) 
throws Exception { long[] partitionSize = prepare(mapCount); CelebornConf conf = new CelebornConf(); conf.set(CelebornConf.SHUFFLE_CHUNK_SIZE().key(), "8m"); PartitionFilesSorter partitionFilesSorter = new PartitionFilesSorter(MemoryManager.instance(), conf, new WorkerSource(conf)); - FileInfo info = - partitionFilesSorter.getSortedFileInfo( - "application-1", - "", - partitionDataWriter.getMemoryFileInfo(), - startMapIndex, - endMapIndex); + + CountDownLatch latch = new CountDownLatch(1); + AtomicReference resultRef = new AtomicReference<>(); + AtomicReference errorRef = new AtomicReference<>(); + partitionFilesSorter.getSortedFileInfo( + "application-1", + "", + partitionDataWriter.getMemoryFileInfo(), + startMapIndex, + endMapIndex, + new FileResolvedCallback() { + @Override + public void onSuccess(FileInfo sortedFileInfo) { + resultRef.set(sortedFileInfo); + latch.countDown(); + } + + @Override + public void onFailure(Throwable e) { + errorRef.set(e); + latch.countDown(); + } + }); + + Assert.assertTrue("Sort timed out", latch.await(60, TimeUnit.SECONDS)); + Assert.assertNull("Sort failed: " + errorRef.get(), errorRef.get()); + FileInfo info = resultRef.get(); + long totalSizeToFetch = 0; for (int i = startMapIndex; i < endMapIndex; i++) { totalSizeToFetch += partitionSize[i]; @@ -152,7 +176,7 @@ private void check(int mapCount, int startMapIndex, int endMapIndex) throws IOEx } @Test - public void testSortMemoryShuffleFile() throws IOException { + public void testSortMemoryShuffleFile() throws Exception { int startMapIndex = random.nextInt(5); int endMapIndex = startMapIndex + random.nextInt(5) + 5; check(1000, startMapIndex, endMapIndex); diff --git a/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/memory/MemoryReducePartitionDataWriterSuiteJ.java b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/memory/MemoryReducePartitionDataWriterSuiteJ.java index 92d0fd5d416..cb30446e17e 100644 --- 
a/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/memory/MemoryReducePartitionDataWriterSuiteJ.java +++ b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/memory/MemoryReducePartitionDataWriterSuiteJ.java @@ -192,9 +192,14 @@ public boolean checkRegistered() { } }; PartitionFilesSorter sorter = mock(PartitionFilesSorter.class); - Mockito.doReturn(info) + Mockito.doAnswer( + invocation -> { + FileResolvedCallback callback = invocation.getArgument(5); + callback.onSuccess(info); + return null; + }) .when(sorter) - .getSortedFileInfo(anyString(), anyString(), eq(info), anyInt(), anyInt()); + .getSortedFileInfo(anyString(), anyString(), eq(info), anyInt(), anyInt(), any()); handler.setPartitionsSorter(sorter); TransportContext context = new TransportContext(transConf, handler); server = context.createServer();