Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ public class DiskFileInfo extends FileInfo {
private static final Logger logger = LoggerFactory.getLogger(DiskFileInfo.class);
private final String filePath;
private final StorageInfo.Type storageType;
private final boolean isSortedDiskFileInfo;

public DiskFileInfo(
UserIdentifier userIdentifier,
Expand All @@ -49,6 +50,7 @@ public DiskFileInfo(
super(userIdentifier, partitionSplitEnabled, fileMeta);
this.filePath = filePath;
this.storageType = storageType;
this.isSortedDiskFileInfo = false;
}

// only called when restore from pb or in UT
Expand All @@ -67,6 +69,7 @@ public DiskFileInfo(
this.storageType = StorageInfo.Type.HDD;
}
this.bytesFlushed = bytesFlushed;
this.isSortedDiskFileInfo = false;
}

@VisibleForTesting
Expand All @@ -79,10 +82,11 @@ public DiskFileInfo(File file, UserIdentifier userIdentifier, CelebornConf conf)
StorageInfo.Type.HDD);
}

public DiskFileInfo(UserIdentifier userIdentifier, FileMeta fileMeta, String filePath) {
public DiskFileInfo(UserIdentifier userIdentifier, FileMeta fileMeta, String filePath, boolean isSortedDiskFileInfo) {
super(userIdentifier, true, fileMeta);
this.filePath = filePath;
this.storageType = StorageInfo.Type.HDD;
this.isSortedDiskFileInfo = isSortedDiskFileInfo;
}

public File getFile() {
Expand Down Expand Up @@ -175,4 +179,8 @@ public boolean isDFS() {
public StorageInfo.Type getStorageType() {
return storageType;
}

public boolean isSortedDiskFileInfo() {
return isSortedDiskFileInfo;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,13 @@

public class FileChunkBuffers extends ChunkBuffers {
private final File file;
private final boolean isSortedFileInfo;
private final TransportConf conf;

public FileChunkBuffers(DiskFileInfo fileInfo, TransportConf conf) {
super(fileInfo.getReduceFileMeta());
file = fileInfo.getFile();
isSortedFileInfo = fileInfo.isSortedDiskFileInfo();
this.conf = conf;
}

Expand All @@ -39,4 +41,8 @@ public ManagedBuffer chunk(int chunkIndex, int offset, int len) {
Tuple2<Long, Long> offsetLen = getChunkOffsetLength(chunkIndex, offset, len);
return new FileSegmentManagedBuffer(conf, file, offsetLen._1, offsetLen._2);
}

public boolean isSortedFileInfo() {
return isSortedFileInfo;
}
}
5 changes: 5 additions & 0 deletions common/src/main/proto/TransportMessages.proto
Original file line number Diff line number Diff line change
Expand Up @@ -643,6 +643,11 @@ message PbSortedShuffleFileSet {
repeated string files = 1;
}

message PbRegisteredStream {
string shuffleKey = 1;
string fileName = 2;
}

message PbStoreVersion {
int32 major = 1;
int32 minor = 2;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,18 @@ object PbSerDeUtils {
.build
.toByteArray

def toPbRegisteredStream(shuffleKey: String, fileName: String, isBufferBacked: Boolean): Array[Byte] =
PbRegisteredStream.newBuilder
.setShuffleKey(shuffleKey)
.setFileName(fileName)
.setIsBufferBacked(isBufferBacked)
.build.toByteArray

@throws[InvalidProtocolBufferException]
def fromPbRegisteredStream(data: Array[Byte]): PbRegisteredStream = {
val pbRegisteredStream = PbRegisteredStream.parseFrom(data)
}

@throws[InvalidProtocolBufferException]
def fromPbStoreVersion(data: Array[Byte]): util.ArrayList[Integer] = {
val pbStoreVersion = PbStoreVersion.parseFrom(data)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,24 @@

package org.apache.celeborn.service.deploy.worker.storage;

import java.io.File;
import java.nio.ByteBuffer;
import java.util.Random;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;

import com.google.common.annotations.VisibleForTesting;
import org.apache.celeborn.common.CelebornConf;
import org.apache.celeborn.common.meta.DiskFileInfo;
import org.apache.celeborn.common.network.buffer.FileChunkBuffers;
import org.apache.celeborn.common.network.util.TransportConf;
import org.apache.celeborn.common.util.CelebornExitKind;
import org.apache.celeborn.common.util.PbSerDeUtils;
import org.apache.celeborn.service.deploy.worker.shuffledb.DB;
import org.apache.celeborn.service.deploy.worker.shuffledb.DBBackend;
import org.apache.celeborn.service.deploy.worker.shuffledb.DBProvider;
import org.apache.celeborn.service.deploy.worker.shuffledb.StoreVersion;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand All @@ -37,12 +49,17 @@
*/
public class ChunkStreamManager {
private static final Logger logger = LoggerFactory.getLogger(ChunkStreamManager.class);
private static final StoreVersion CURRENT_VERSION = new StoreVersion(1, 0);
private static final String RECOVERY_REGISTERED_STREAMS = "registeredStreams";

private final AtomicLong nextStreamId;
// StreamId -> StreamState
protected final ConcurrentHashMap<Long, StreamState> streams;
// ShuffleKey -> StreamId
protected final ConcurrentHashMap<String, Set<Long>> shuffleStreamIds;
private final CelebornConf conf;
private File recoverFile;
private DB registeredStreamsDb;

/** State of a single stream. */
public static class StreamState {
Expand All @@ -62,12 +79,63 @@ public static class StreamState {
}
}

public ChunkStreamManager() {
public ChunkStreamManager(CelebornConf conf) {
// For debugging purposes, start with a random stream id to help identifying different streams.
// This does not need to be globally unique, only unique to this class.
nextStreamId = new AtomicLong((long) new Random().nextInt(Integer.MAX_VALUE) * 1000);
streams = JavaUtils.newConcurrentHashMap();
shuffleStreamIds = JavaUtils.newConcurrentHashMap();
this.conf = conf;
boolean gracefulShutdown = conf.workerGracefulShutdown();
if (gracefulShutdown) {
try {
String recoverPath = conf.workerGracefulShutdownRecoverPath();
DBBackend dbBackend = DBBackend.byName(conf.workerGracefulShutdownRecoverDbBackend());
String recoveryRegisteredStreams = dbBackend.fileName(RECOVERY_REGISTERED_STREAMS);
this.recoverFile = new File(recoverPath, recoveryRegisteredStreams);
this.registeredStreamsDb = DBProvider.initDB(dbBackend, recoverFile, CURRENT_VERSION);
} catch (Exception e) {
throw new IllegalStateException(
"Failed to reload DB for sorted shuffle files from: " + recoverFile, e);
}
}
}

public void init(StorageManager storageManager, TransportConf transportConf) {
reloadRegisteredStreams(storageManager, transportConf);
}

private void reloadRegisteredStreams(StorageManager storageManager, TransportConf transportConf) {
registeredStreamsDb.iterator().forEachRemaining(
entry -> {
try {
long streamId = ByteBuffer.wrap(entry.getKey()).getLong();
PbRegisteredStream pbRegisteredStream =
PbSerDeUtils.fromPbRegisteredStream(entry.getValue());
String shuffleKey = pbRegisteredStream.getShuffleKey();
String fileName = pbRegisteredStream.getFileName();
Boolean isBufferBacked = pbRegisteredStream.getIsBufferBacked();

ChunkBuffers buffers = null;
TimeWindow fetchTimeMetric = null;
Boolean isValidRestore = true;
if (isBufferBacked) {
DiskFileInfo diskFileInfo = (DiskFileInfo) storageManager.getFileInfo(shuffleKey, fileName);
if (diskFileInfo != null) {
buffers = new FileChunkBuffers(diskFileInfo, transportConf);
fetchTimeMetric = storageManager.getFetchTimeMetric(diskFileInfo.getFile());
} else {
isValidRestore = false;
}
}

if (isValidRestore) {
registerStream(streamId, shuffleKey, buffers, fileName, fetchTimeMetric);
}
} catch (Exception e) {
logger.error("Failed to reload registered stream from DB entry: " + entry, e);
}
});
}

public ManagedBuffer getChunk(long streamId, int chunkIndex, int offset, int len) {
Expand Down Expand Up @@ -171,7 +239,12 @@ public long registerStream(
}

public long nextStreamId() {
return nextStreamId.getAndIncrement();
long currentId = nextStreamId.getAndIncrement();
while (streams.containsKey(currentId)) {
currentId = nextStreamId.getAndIncrement();;
}

return currentId;
}

public void cleanupExpiredShuffleKey(Set<String> expiredShuffleKeys) {
Expand Down Expand Up @@ -208,4 +281,42 @@ public int getStreamsCount() {
public long numShuffleSteams() {
return shuffleStreamIds.values().stream().mapToLong(Set::size).sum();
}

private void persisteRegisteredStreams() {
streams.forEach(
(streamId, streamState) -> {
// Only need to persist FileChunkBuffers since MemoryChunkBuffers aren't restored anyways
// Skipping sorted file info since these are not restored
if (streamState.buffers == null || (streamState.buffers instanceof FileChunkBuffers &&
!((FileChunkBuffers) streamState.buffers).isSortedFileInfo())) {
ByteBuffer keyBuffer = ByteBuffer.allocate(Long.BYTES);
keyBuffer.putLong(0, streamId);
keyBuffer.flip();
try {
registeredStreamsDb.put(keyBuffer.array(),
PbSerDeUtils.toPbRegisteredStream(streamState.shuffleKey, streamState.fileName,
streamState.buffers != null));
} catch (Exception e) {
logger.error("Failed to persist stream state for streamId: " + streamId, e);
}
}
});
}

public void close(int exitKind) {
logger.info("Closing {}", this.getClass().getSimpleName());
if (exitKind == CelebornExitKind.WORKER_GRACEFUL_SHUTDOWN() && registeredStreamsDb != null) {
try {
persisteRegisteredStreams();
} catch (Exception e) {
logger.error("Failed to persist registered streams to DB: " + recoverFile, e);
} finally {
try {
registeredStreamsDb.close();
} catch (Exception e) {
logger.error("Failed to close DB for registered streams: " + recoverFile, e);
}
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -664,7 +664,7 @@ public DiskFileInfo resolve(
ShuffleBlockInfoUtils.getChunkOffsetsFromShuffleBlockInfos(
startMapIndex, endMapIndex, shuffleChunkSize, indexMap, false),
shuffleChunkSize);
return new DiskFileInfo(userIdentifier, reduceFileMeta, sortedFilePath);
return new DiskFileInfo(userIdentifier, reduceFileMeta, sortedFilePath, true);
}

class FileSorter {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ class FetchHandler(
var registered: Option[AtomicBoolean] = None

def init(worker: Worker): Unit = {
chunkStreamManager.init(worker.storageManager, transportConf)
workerSource.addGauge(WorkerSource.ACTIVE_CHUNK_STREAM_COUNT) { () =>
chunkStreamManager.getStreamsCount
}
Expand All @@ -79,7 +80,7 @@ class FetchHandler(
}

def getChunkStreamManager: ChunkStreamManager = {
new ChunkStreamManager()
new ChunkStreamManager(conf)
}

def getRawFileInfo(
Expand Down Expand Up @@ -668,4 +669,8 @@ class FetchHandler(
def setPartitionsSorter(partitionFilesSorter: PartitionFilesSorter): Unit = {
this.partitionsSorter = partitionFilesSorter
}

def close(exitKind: Int): Unit = {
chunkStreamManager.close(exitKind)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -665,6 +665,7 @@ private[celeborn] class Worker(
partitionsSorter.close(exitKind)
storageManager.close(exitKind)
memoryManager.close()
fetchHandler.close(exitKind)
Option(CongestionController.instance()).foreach(_.close())

masterClient.close()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ public boolean checkRegistered() {
public void furtherRequestsDelay() throws Exception {
final byte[] response = new byte[16];
final ChunkStreamManager manager =
new ChunkStreamManager() {
new ChunkStreamManager(new CelebornConf()) {
@Override
public ManagedBuffer getChunk(long streamId, int chunkIndex, int offset, int len) {
Uninterruptibles.sleepUninterruptibly(FOREVER, TimeUnit.MILLISECONDS);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ static void initialize(CelebornConf celebornConf) throws Exception {
fileChunk = new FileSegmentManagedBuffer(conf, testFile, 10, testFile.length() - 25);

chunkStreamManager =
new ChunkStreamManager() {
new ChunkStreamManager(new CelebornConf()) {
@Override
public ManagedBuffer getChunk(long streamId, int chunkIndex, int offset, int len) {
assertEquals(STREAM_ID, streamId);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import java.util.Arrays;
import java.util.HashSet;

import org.apache.celeborn.common.CelebornConf;
import org.junit.Assert;
import org.junit.Test;
import org.mockito.Mockito;
Expand All @@ -29,7 +30,7 @@
public class ChunkStreamManagerSuiteJ {
@Test
public void testStreamRegisterAndCleanup() {
ChunkStreamManager manager = new ChunkStreamManager();
ChunkStreamManager manager = new ChunkStreamManager(new CelebornConf());

@SuppressWarnings("unchecked")
FileChunkBuffers buffers = Mockito.mock(FileChunkBuffers.class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ object NettyTransportBenchmark extends BenchmarkBase {
conf: TransportConf,
streamId: Long,
files: Seq[File]): ChunkStreamManager = {
val streamManager = new ChunkStreamManager() {
val streamManager = new ChunkStreamManager(new CelebornConf()) {
override def getChunk(
streamId: Long,
chunkIndex: Int,
Expand Down
Loading