diff --git a/common/src/main/java/org/apache/celeborn/common/meta/DiskFileInfo.java b/common/src/main/java/org/apache/celeborn/common/meta/DiskFileInfo.java index d4571fa4bbe..637a1458636 100644 --- a/common/src/main/java/org/apache/celeborn/common/meta/DiskFileInfo.java +++ b/common/src/main/java/org/apache/celeborn/common/meta/DiskFileInfo.java @@ -39,6 +39,7 @@ public class DiskFileInfo extends FileInfo { private static final Logger logger = LoggerFactory.getLogger(DiskFileInfo.class); private final String filePath; private final StorageInfo.Type storageType; + private final boolean isSortedDiskFileInfo; public DiskFileInfo( UserIdentifier userIdentifier, @@ -49,6 +50,7 @@ public DiskFileInfo( super(userIdentifier, partitionSplitEnabled, fileMeta); this.filePath = filePath; this.storageType = storageType; + this.isSortedDiskFileInfo = false; } // only called when restore from pb or in UT @@ -67,6 +69,7 @@ public DiskFileInfo( this.storageType = StorageInfo.Type.HDD; } this.bytesFlushed = bytesFlushed; + this.isSortedDiskFileInfo = false; } @VisibleForTesting @@ -79,10 +82,11 @@ public DiskFileInfo(File file, UserIdentifier userIdentifier, CelebornConf conf) StorageInfo.Type.HDD); } - public DiskFileInfo(UserIdentifier userIdentifier, FileMeta fileMeta, String filePath) { + public DiskFileInfo(UserIdentifier userIdentifier, FileMeta fileMeta, String filePath, boolean isSortedDiskFileInfo) { super(userIdentifier, true, fileMeta); this.filePath = filePath; this.storageType = StorageInfo.Type.HDD; + this.isSortedDiskFileInfo = isSortedDiskFileInfo; } public File getFile() { @@ -175,4 +179,8 @@ public boolean isDFS() { public StorageInfo.Type getStorageType() { return storageType; } + + public boolean isSortedDiskFileInfo() { + return isSortedDiskFileInfo; + } } diff --git a/common/src/main/java/org/apache/celeborn/common/network/buffer/FileChunkBuffers.java b/common/src/main/java/org/apache/celeborn/common/network/buffer/FileChunkBuffers.java index ed4969d3ad0..d5fedf1e25d 100644 --- a/common/src/main/java/org/apache/celeborn/common/network/buffer/FileChunkBuffers.java +++ b/common/src/main/java/org/apache/celeborn/common/network/buffer/FileChunkBuffers.java @@ -26,11 +26,13 @@ public class FileChunkBuffers extends ChunkBuffers { private final File file; + private final boolean isSortedFileInfo; private final TransportConf conf; public FileChunkBuffers(DiskFileInfo fileInfo, TransportConf conf) { super(fileInfo.getReduceFileMeta()); file = fileInfo.getFile(); + isSortedFileInfo = fileInfo.isSortedDiskFileInfo(); this.conf = conf; } @@ -39,4 +41,8 @@ public ManagedBuffer chunk(int chunkIndex, int offset, int len) { Tuple2 offsetLen = getChunkOffsetLength(chunkIndex, offset, len); return new FileSegmentManagedBuffer(conf, file, offsetLen._1, offsetLen._2); } + + public boolean isSortedFileInfo() { + return isSortedFileInfo; + } } diff --git a/common/src/main/proto/TransportMessages.proto b/common/src/main/proto/TransportMessages.proto index a813a9e5015..b3bc01e922b 100644 --- a/common/src/main/proto/TransportMessages.proto +++ b/common/src/main/proto/TransportMessages.proto @@ -643,6 +643,11 @@ message PbSortedShuffleFileSet { repeated string files = 1; } +message PbRegisteredStream { + string shuffleKey = 1; + string fileName = 2; +} + message PbStoreVersion { int32 major = 1; int32 minor = 2; diff --git a/common/src/main/scala/org/apache/celeborn/common/util/PbSerDeUtils.scala b/common/src/main/scala/org/apache/celeborn/common/util/PbSerDeUtils.scala index e9c407ce80e..5743a6ed732 100644 --- a/common/src/main/scala/org/apache/celeborn/common/util/PbSerDeUtils.scala +++ b/common/src/main/scala/org/apache/celeborn/common/util/PbSerDeUtils.scala @@ -56,6 +56,18 @@ object PbSerDeUtils { .build .toByteArray + def toPbRegisteredStream(shuffleKey: String, fileName: String, isBufferBacked: Boolean): Array[Byte] = + PbRegisteredStream.newBuilder + .setShuffleKey(shuffleKey) + .setFileName(fileName) + .setIsBufferBacked(isBufferBacked) + .build.toByteArray + + @throws[InvalidProtocolBufferException] + def fromPbRegisteredStream(data: Array[Byte]): PbRegisteredStream = { + val pbRegisteredStream = PbRegisteredStream.parseFrom(data) + } + @throws[InvalidProtocolBufferException] def fromPbStoreVersion(data: Array[Byte]): util.ArrayList[Integer] = { val pbStoreVersion = PbStoreVersion.parseFrom(data) diff --git a/worker/src/main/java/org/apache/celeborn/service/deploy/worker/storage/ChunkStreamManager.java b/worker/src/main/java/org/apache/celeborn/service/deploy/worker/storage/ChunkStreamManager.java index 86ecfbe6a5e..db3e96c49a7 100644 --- a/worker/src/main/java/org/apache/celeborn/service/deploy/worker/storage/ChunkStreamManager.java +++ b/worker/src/main/java/org/apache/celeborn/service/deploy/worker/storage/ChunkStreamManager.java @@ -17,12 +17,24 @@ package org.apache.celeborn.service.deploy.worker.storage; +import java.io.File; +import java.nio.ByteBuffer; import java.util.Random; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicLong; import com.google.common.annotations.VisibleForTesting; +import org.apache.celeborn.common.CelebornConf; +import org.apache.celeborn.common.meta.DiskFileInfo; +import org.apache.celeborn.common.network.buffer.FileChunkBuffers; +import org.apache.celeborn.common.network.util.TransportConf; +import org.apache.celeborn.common.util.CelebornExitKind; +import org.apache.celeborn.common.util.PbSerDeUtils; +import org.apache.celeborn.service.deploy.worker.shuffledb.DB; +import org.apache.celeborn.service.deploy.worker.shuffledb.DBBackend; +import org.apache.celeborn.service.deploy.worker.shuffledb.DBProvider; +import org.apache.celeborn.service.deploy.worker.shuffledb.StoreVersion; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -37,12 +49,17 @@ */ public class ChunkStreamManager { private static final Logger logger = LoggerFactory.getLogger(ChunkStreamManager.class); + private static final StoreVersion CURRENT_VERSION = new StoreVersion(1, 0); + private static final String RECOVERY_REGISTERED_STREAMS = "registeredStreams"; private final AtomicLong nextStreamId; // StreamId -> StreamState protected final ConcurrentHashMap streams; // ShuffleKey -> StreamId protected final ConcurrentHashMap> shuffleStreamIds; + private final CelebornConf conf; + private File recoverFile; + private DB registeredStreamsDb; /** State of a single stream. */ public static class StreamState { @@ -62,12 +79,63 @@ public static class StreamState { } } - public ChunkStreamManager() { + public ChunkStreamManager(CelebornConf conf) { // For debugging purposes, start with a random stream id to help identifying different streams. // This does not need to be globally unique, only unique to this class. nextStreamId = new AtomicLong((long) new Random().nextInt(Integer.MAX_VALUE) * 1000); streams = JavaUtils.newConcurrentHashMap(); shuffleStreamIds = JavaUtils.newConcurrentHashMap(); + this.conf = conf; + boolean gracefulShutdown = conf.workerGracefulShutdown(); + if (gracefulShutdown) { + try { + String recoverPath = conf.workerGracefulShutdownRecoverPath(); + DBBackend dbBackend = DBBackend.byName(conf.workerGracefulShutdownRecoverDbBackend()); + String recoveryRegisteredStreams = dbBackend.fileName(RECOVERY_REGISTERED_STREAMS); + this.recoverFile = new File(recoverPath, recoveryRegisteredStreams); + this.registeredStreamsDb = DBProvider.initDB(dbBackend, recoverFile, CURRENT_VERSION); + } catch (Exception e) { + throw new IllegalStateException( + "Failed to reload DB for sorted shuffle files from: " + recoverFile, e); + } + } + } + + public void init(StorageManager storageManager, TransportConf transportConf) { + reloadRegisteredStreams(storageManager, transportConf); + } + + private void reloadRegisteredStreams(StorageManager storageManager, TransportConf transportConf) { + registeredStreamsDb.iterator().forEachRemaining( + entry -> { + try { + long streamId = ByteBuffer.wrap(entry.getKey()).getLong(); + PbRegisteredStream pbRegisteredStream = + PbSerDeUtils.fromPbRegisteredStream(entry.getValue()); + String shuffleKey = pbRegisteredStream.getShuffleKey(); + String fileName = pbRegisteredStream.getFileName(); + Boolean isBufferBacked = pbRegisteredStream.getIsBufferBacked(); + + ChunkBuffers buffers = null; + TimeWindow fetchTimeMetric = null; + Boolean isValidRestore = true; + if (isBufferBacked) { + DiskFileInfo diskFileInfo = (DiskFileInfo) storageManager.getFileInfo(shuffleKey, fileName); + if (diskFileInfo != null) { + buffers = new FileChunkBuffers(diskFileInfo, transportConf); + fetchTimeMetric = storageManager.getFetchTimeMetric(diskFileInfo.getFile()); + } else { + isValidRestore = false; + } + } + + if (isValidRestore) { + registerStream(streamId, shuffleKey, buffers, fileName, fetchTimeMetric); + } + } catch (Exception e) { + logger.error("Failed to reload registered stream from DB entry: " + entry, e); + } + }); } public ManagedBuffer getChunk(long streamId, int chunkIndex, int offset, int len) { @@ -171,7 +239,12 @@ public long registerStream( } public long nextStreamId() { - return nextStreamId.getAndIncrement(); + long currentId = nextStreamId.getAndIncrement(); + while (streams.containsKey(currentId)) { + currentId = nextStreamId.getAndIncrement();; + } + + return currentId; } public void cleanupExpiredShuffleKey(Set expiredShuffleKeys) { @@ -208,4 +281,42 @@ public int getStreamsCount() { public long numShuffleSteams() { return shuffleStreamIds.values().stream().mapToLong(Set::size).sum(); } + + private void persisteRegisteredStreams() { + streams.forEach( + (streamId, streamState) -> { + // Only need to persist FileChunkBuffers since MemoryChunkBuffers aren't restored anyways + // Skipping sorted file info since these are not restored + if (streamState.buffers == null || (streamState.buffers instanceof FileChunkBuffers && + !((FileChunkBuffers) streamState.buffers).isSortedFileInfo())) { + ByteBuffer keyBuffer = ByteBuffer.allocate(Long.BYTES); + keyBuffer.putLong(0, streamId); + keyBuffer.flip(); + try { + registeredStreamsDb.put(keyBuffer.array(), + PbSerDeUtils.toPbRegisteredStream(streamState.shuffleKey, streamState.fileName, + streamState.buffers != null)); + } catch (Exception e) { + logger.error("Failed to persist stream state for streamId: " + streamId, e); + } + } + }); + } + + public void close(int exitKind) { + logger.info("Closing {}", this.getClass().getSimpleName()); + if (exitKind == CelebornExitKind.WORKER_GRACEFUL_SHUTDOWN() && registeredStreamsDb != null) { + try { + persisteRegisteredStreams(); + } catch (Exception e) { + logger.error("Failed to persist registered streams to DB: " + recoverFile, e); + } finally { + try { + registeredStreamsDb.close(); + } catch (Exception e) { + logger.error("Failed to close DB for registered streams: " + recoverFile, e); + } + } + } + } } diff --git a/worker/src/main/java/org/apache/celeborn/service/deploy/worker/storage/PartitionFilesSorter.java b/worker/src/main/java/org/apache/celeborn/service/deploy/worker/storage/PartitionFilesSorter.java index cdfb95c1e59..5ccbba6b42d 100644 --- a/worker/src/main/java/org/apache/celeborn/service/deploy/worker/storage/PartitionFilesSorter.java +++ b/worker/src/main/java/org/apache/celeborn/service/deploy/worker/storage/PartitionFilesSorter.java @@ -664,7 +664,7 @@ public DiskFileInfo resolve( ShuffleBlockInfoUtils.getChunkOffsetsFromShuffleBlockInfos( startMapIndex, endMapIndex, shuffleChunkSize, indexMap, false), shuffleChunkSize); - return new DiskFileInfo(userIdentifier, reduceFileMeta, sortedFilePath); + return new DiskFileInfo(userIdentifier, reduceFileMeta, sortedFilePath, true); } class FileSorter { diff --git a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala index 7ad990e2bf4..b652e51114c 100644 --- a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala +++ b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala @@ -61,6 +61,7 @@ class FetchHandler( var registered: Option[AtomicBoolean] = None def init(worker: Worker): Unit = { + chunkStreamManager.init(worker.storageManager, transportConf) workerSource.addGauge(WorkerSource.ACTIVE_CHUNK_STREAM_COUNT) { () => chunkStreamManager.getStreamsCount } @@ -79,7 +80,7 @@ class FetchHandler( } def getChunkStreamManager: ChunkStreamManager = { - new ChunkStreamManager() + new ChunkStreamManager(conf) } def getRawFileInfo( @@ -668,4 +669,8 @@ class FetchHandler( def setPartitionsSorter(partitionFilesSorter: PartitionFilesSorter): Unit = { this.partitionsSorter = partitionFilesSorter } + + def close(exitKind: Int): Unit = { + chunkStreamManager.close(exitKind) + } } diff --git a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Worker.scala b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Worker.scala index da2cab1c349..5b34c09b4b0 100644 --- a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Worker.scala +++ b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Worker.scala @@ -665,6 +665,7 @@ private[celeborn] class Worker( partitionsSorter.close(exitKind) storageManager.close(exitKind) memoryManager.close() + fetchHandler.close(exitKind) Option(CongestionController.instance()).foreach(_.close()) masterClient.close() diff --git a/worker/src/test/java/org/apache/celeborn/service/deploy/worker/network/RequestTimeoutIntegrationSuiteJ.java b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/network/RequestTimeoutIntegrationSuiteJ.java index d9494c42978..262ee054d06 100644 --- a/worker/src/test/java/org/apache/celeborn/service/deploy/worker/network/RequestTimeoutIntegrationSuiteJ.java +++ b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/network/RequestTimeoutIntegrationSuiteJ.java @@ -215,7 +215,7 @@ public boolean checkRegistered() { public void furtherRequestsDelay() throws Exception { final byte[] response = new byte[16]; final ChunkStreamManager manager = - new ChunkStreamManager() { + new ChunkStreamManager(new CelebornConf()) { @Override public ManagedBuffer getChunk(long streamId, int chunkIndex, int offset, int len) { Uninterruptibles.sleepUninterruptibly(FOREVER, TimeUnit.MILLISECONDS); diff --git a/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/ChunkFetchIntegrationSuiteJ.java b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/ChunkFetchIntegrationSuiteJ.java index f4eaf618ecb..e7cd2b71369 100644 --- a/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/ChunkFetchIntegrationSuiteJ.java +++ b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/ChunkFetchIntegrationSuiteJ.java @@ -100,7 +100,7 @@ static void initialize(CelebornConf celebornConf) throws Exception { fileChunk = new FileSegmentManagedBuffer(conf, testFile, 10, testFile.length() - 25); chunkStreamManager = - new ChunkStreamManager() { + new ChunkStreamManager(new CelebornConf()) { @Override public ManagedBuffer getChunk(long streamId, int chunkIndex, int offset, int len) { assertEquals(STREAM_ID, streamId); diff --git a/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/ChunkStreamManagerSuiteJ.java b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/ChunkStreamManagerSuiteJ.java index 5a055b7b46c..d48279d3419 100644 --- a/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/ChunkStreamManagerSuiteJ.java +++ b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/ChunkStreamManagerSuiteJ.java @@ -20,6 +20,7 @@ import java.util.Arrays; import java.util.HashSet; +import org.apache.celeborn.common.CelebornConf; import org.junit.Assert; import org.junit.Test; import org.mockito.Mockito; @@ -29,7 +30,7 @@ public class ChunkStreamManagerSuiteJ { @Test public void testStreamRegisterAndCleanup() { - ChunkStreamManager manager = new ChunkStreamManager(); + ChunkStreamManager manager = new ChunkStreamManager(new CelebornConf()); @SuppressWarnings("unchecked") FileChunkBuffers buffers = Mockito.mock(FileChunkBuffers.class); diff --git a/worker/src/test/scala/org/apache/celeborn/service/deploy/network/NettyTransportBenchmark.scala b/worker/src/test/scala/org/apache/celeborn/service/deploy/network/NettyTransportBenchmark.scala index 8567e251044..3781ee07836 100644 --- a/worker/src/test/scala/org/apache/celeborn/service/deploy/network/NettyTransportBenchmark.scala +++ b/worker/src/test/scala/org/apache/celeborn/service/deploy/network/NettyTransportBenchmark.scala @@ -213,7 +213,7 @@ object NettyTransportBenchmark extends BenchmarkBase { conf: TransportConf, streamId: Long, files: Seq[File]): ChunkStreamManager = { - val streamManager = new ChunkStreamManager() { + val streamManager = new ChunkStreamManager(new CelebornConf()) { override def getChunk( streamId: Long, chunkIndex: Int,