diff --git a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
index 798099bcbd2..7649af01612 100644
--- a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
+++ b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
@@ -1371,6 +1371,8 @@ class CelebornConf(loadDefaults: Boolean) extends Cloneable with Logging with Se
     get(WORKER_GRACEFUL_SHUTDOWN_SAVE_COMMITTED_FILEINFO_SYNC)
   def workerGracefulShutdownDbDeleteFailurePolicy: String =
     get(WORKER_GRACEFUL_SHUTDOWN_DB_DELETE_FAILURE_POLICY)
+  def workerGracefulShutdownCommitUncommittedPartitionsEnabled: Boolean =
+    get(WORKER_GRACEFUL_SHUTDOWN_COMMIT_UNCOMMITTED_PARTITIONS_ENABLED)
 
   // //////////////////////////////////////////////////////
   // Flusher //
@@ -4003,6 +4005,15 @@ object CelebornConf extends Logging {
       .checkValues(Set("THROW", "EXIT", "IGNORE"))
       .createWithDefault("IGNORE")
 
+  val WORKER_GRACEFUL_SHUTDOWN_COMMIT_UNCOMMITTED_PARTITIONS_ENABLED: ConfigEntry[Boolean] =
+    buildConf("celeborn.worker.graceful.shutdown.commitUncommittedPartitions.enabled")
+      .categories("worker")
+      .doc("When true, during graceful shutdown the worker commits uncommitted " +
+        "partitions instead of waiting for LifecycleManager to send CommitFiles RPCs.")
+      .version("0.7.0")
+      .booleanConf
+      .createWithDefault(false)
+
   val WORKER_DISKTIME_SLIDINGWINDOW_SIZE: ConfigEntry[Int] =
     buildConf("celeborn.worker.flusher.diskTime.slidingWindow.size")
       .withAlternative("celeborn.worker.flusher.avgFlushTime.slidingWindow.size")
diff --git a/common/src/main/scala/org/apache/celeborn/common/meta/WorkerPartitionLocationInfo.scala b/common/src/main/scala/org/apache/celeborn/common/meta/WorkerPartitionLocationInfo.scala
index 373f3656583..84f11446e41 100644
--- a/common/src/main/scala/org/apache/celeborn/common/meta/WorkerPartitionLocationInfo.scala
+++ b/common/src/main/scala/org/apache/celeborn/common/meta/WorkerPartitionLocationInfo.scala
@@ -173,6 +173,32 @@ class WorkerPartitionLocationInfo extends Logging {
     } else null
   }
 
+  /**
+   * Snapshot uncommitted partition unique IDs grouped by shuffle key.
+   * The returned snapshot is a best-effort view because ConcurrentHashMap
+   * iteration is weakly consistent — concurrent mutations may or may not
+   * be visible.
+   *
+   * @return (primaryIds, replicaIds) — each a Map[shuffleKey, List[uniqueId]]
+   */
+  def snapshotUncommittedUniqueIds
+      : (Map[String, util.List[String]], Map[String, util.List[String]]) =
+    (snapshotIds(primaryPartitionLocations), snapshotIds(replicaPartitionLocations))
+
+  /**
+   * Copies the unique IDs of every shuffle in `partInfo`, omitting shuffles with no IDs.
+   * Emptiness is tested on the copy, not on the live map: a shuffle drained by a concurrent
+   * removal between the check and the copy would otherwise surface as an empty entry.
+   */
+  private def snapshotIds(partInfo: PartitionInfo): Map[String, util.List[String]] =
+    partInfo.asScala.iterator
+      .map { case (shuffleKey, partMap) =>
+        val ids: util.List[String] = new util.ArrayList[String](partMap.keySet())
+        shuffleKey -> ids
+      }
+      .filter { case (_, ids) => !ids.isEmpty }
+      .toMap
+
   def isEmpty: Boolean = {
     (primaryPartitionLocations.isEmpty ||
       primaryPartitionLocations.asScala.values.forall(_.isEmpty)) &&
diff --git a/common/src/test/scala/org/apache/celeborn/common/meta/WorkerPartitionLocationInfoSuite.scala b/common/src/test/scala/org/apache/celeborn/common/meta/WorkerPartitionLocationInfoSuite.scala
index b4825cba2b0..49734331ed7 100644
--- a/common/src/test/scala/org/apache/celeborn/common/meta/WorkerPartitionLocationInfoSuite.scala
+++ b/common/src/test/scala/org/apache/celeborn/common/meta/WorkerPartitionLocationInfoSuite.scala
@@ -64,6 +64,68 @@ class WorkerPartitionLocationInfoSuite extends CelebornFunSuite {
     assertEquals(workerPartitionLocationInfo.isEmpty, true)
   }
 
+  test("snapshotUncommittedUniqueIds - empty info returns empty maps") {
+    val info = new WorkerPartitionLocationInfo
+    val (primary, replica) = info.snapshotUncommittedUniqueIds
+    assert(primary.isEmpty)
+    assert(replica.isEmpty)
+  }
+
+  test("snapshotUncommittedUniqueIds - captures correct IDs across shuffles") {
+    val info = new WorkerPartitionLocationInfo
+    val shuffle1 = "app1-0"
+    val shuffle2 = "app2-1"
+    val locs1 = new util.ArrayList[PartitionLocation]()
+    locs1.add(mockPartition(0, 0))
+    locs1.add(mockPartition(1, 0))
+    info.addPrimaryPartitions(shuffle1, locs1)
+    val locs2 = new util.ArrayList[PartitionLocation]()
+    locs2.add(mockPartition(2, 0))
+    info.addPrimaryPartitions(shuffle2, locs2)
+    val replicaLocs = new util.ArrayList[PartitionLocation]()
+    replicaLocs.add(mockPartition(3, 0))
+    info.addReplicaPartitions(shuffle1, replicaLocs)
+    val (primary, replica) = info.snapshotUncommittedUniqueIds
+    assert(primary.size == 2)
+    assert(primary(shuffle1).size() == 2)
+    assert(primary(shuffle1).contains("0-0"))
+    assert(primary(shuffle1).contains("1-0"))
+    assert(primary(shuffle2).size() == 1)
+    assert(primary(shuffle2).contains("2-0"))
+    assert(replica.size == 1)
+    assert(replica(shuffle1).size() == 1)
+    assert(replica(shuffle1).contains("3-0"))
+  }
+
+  test("snapshotUncommittedUniqueIds - filters empty shuffle keys") {
+    val info = new WorkerPartitionLocationInfo
+    val shuffleKey = "app1-0"
+    val locs = new util.ArrayList[PartitionLocation]()
+    locs.add(mockPartition(0, 0))
+    locs.add(mockPartition(1, 0))
+    info.addPrimaryPartitions(shuffleKey, locs)
+    info.removePrimaryPartitions(shuffleKey, locs.asScala.map(_.getUniqueId).asJava)
+    val (primary, _) = info.snapshotUncommittedUniqueIds
+    assert(!primary.contains(shuffleKey))
+  }
+
+  test("snapshotUncommittedUniqueIds - snapshot is a point-in-time copy") {
+    val info = new WorkerPartitionLocationInfo
+    val shuffleKey = "app1-0"
+    val locs = new util.ArrayList[PartitionLocation]()
+    locs.add(mockPartition(0, 0))
+    info.addPrimaryPartitions(shuffleKey, locs)
+    val (primary, _) = info.snapshotUncommittedUniqueIds
+    assert(primary(shuffleKey).size() == 1)
+    // Add more partitions after snapshot
+    val moreLocs = new util.ArrayList[PartitionLocation]()
+    moreLocs.add(mockPartition(1, 0))
+    moreLocs.add(mockPartition(2, 0))
+    info.addPrimaryPartitions(shuffleKey, moreLocs)
+    // Snapshot remains unchanged
+    assert(primary(shuffleKey).size() == 1)
+  }
+
   private def mockPartition(partitionId: Int, epoch: Int): PartitionLocation = {
     new PartitionLocation(
       partitionId,
diff --git a/docs/configuration/worker.md b/docs/configuration/worker.md
index bb2cec89bc3..704ed235b65 100644
--- a/docs/configuration/worker.md
+++ b/docs/configuration/worker.md
@@ -102,6 +102,7 @@ license: |
 | celeborn.worker.flusher.threads | 16 | false | Flusher's thread count per disk for unknown-type disks. | 0.2.0 | |
 | celeborn.worker.graceful.shutdown.checkSlotsFinished.interval | 1s | false | The wait interval of checking whether all released slots to be committed or destroyed during worker graceful shutdown | 0.2.0 | |
 | celeborn.worker.graceful.shutdown.checkSlotsFinished.timeout | 480s | false | The wait time of waiting for the released slots to be committed or destroyed during worker graceful shutdown. | 0.2.0 | |
+| celeborn.worker.graceful.shutdown.commitUncommittedPartitions.enabled | false | false | When true, during graceful shutdown the worker commits uncommitted partitions instead of waiting for LifecycleManager to send CommitFiles RPCs. | 0.7.0 | |
 | celeborn.worker.graceful.shutdown.dbDeleteFailurePolicy | IGNORE | false | Policy for handling DB delete failures during graceful shutdown. THROW: throw exception, EXIT: trigger graceful shutdown, IGNORE: log error and continue (default). | 0.7.0 | |
 | celeborn.worker.graceful.shutdown.enabled | false | false | When true, during worker shutdown, the worker will wait for all released slots to be committed or destroyed. | 0.2.0 | |
 | celeborn.worker.graceful.shutdown.partitionSorter.shutdownTimeout | 120s | false | The wait time of waiting for sorting partition files during worker graceful shutdown.
| 0.2.0 | |
diff --git a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Controller.scala b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Controller.scala
index ee959e4d6da..35b7a96c00f 100644
--- a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Controller.scala
+++ b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Controller.scala
@@ -459,6 +459,120 @@ private[deploy] class Controller(
     }
   }
 
+  /**
+   * Proactively commits all uncommitted partitions during graceful shutdown.
+   *
+   * Commit results are tracked per-shuffle because uniqueId (`partitionId-epoch`)
+   * is not namespaced by shuffleKey — different shuffles can share the same uniqueId.
+   *
+   * Only successfully committed or empty-file partitions are removed and their slots
+   * released. Failed or in-flight (timed-out) partitions are retained for the passive
+   * LifecycleManager CommitFiles retry path.
+   *
+   * The wait for all commit futures is bounded by `shuffleCommitTimeout` milliseconds in
+   * total. If the wait ends abnormally (timeout, interruption or a failed future), the
+   * outstanding futures and tasks are cancelled and the actual cause is logged. An
+   * interruption is re-asserted on the calling thread instead of being swallowed.
+   */
+  private[worker] def commitUncommittedPartitions(): Unit = {
+    val (primarySnapshot, replicaSnapshot) = partitionLocationInfo.snapshotUncommittedUniqueIds
+    if (primarySnapshot.isEmpty && replicaSnapshot.isEmpty) {
+      logInfo("No uncommitted partitions.")
+      return
+    }
+    val shuffleKeys = primarySnapshot.keySet ++ replicaSnapshot.keySet
+    val primaryTotal = primarySnapshot.values.map(_.size()).sum
+    val replicaTotal = replicaSnapshot.values.map(_.size()).sum
+    logInfo(s"Committing uncommitted partitions across ${shuffleKeys.size} shuffles ($primaryTotal primary, $replicaTotal replica).")
+    val emptyIds = java.util.Collections.emptyList[String]()
+    val futures = ArrayBuffer[CompletableFuture[Void]]()
+    val tasks = ArrayBuffer[CompletableFuture[Void]]()
+    val committedPerShuffle = JavaUtils.newConcurrentHashMap[String, jSet[String]]()
+    val emptyPerShuffle = JavaUtils.newConcurrentHashMap[String, jSet[String]]()
+    for (shuffleKey <- shuffleKeys) {
+      // Fresh result holders per shuffle; only the committed/empty sets are read back below.
+      val committedIds = ConcurrentHashMap.newKeySet[String]()
+      val emptyFileIds = ConcurrentHashMap.newKeySet[String]()
+      val failedIds = ConcurrentHashMap.newKeySet[String]()
+      val storageInfos = JavaUtils.newConcurrentHashMap[String, StorageInfo]()
+      val mapIdBitMap = JavaUtils.newConcurrentHashMap[String, RoaringBitmap]()
+      val partitionSizes = new LinkedBlockingQueue[Long]()
+      committedPerShuffle.put(shuffleKey, committedIds)
+      emptyPerShuffle.put(shuffleKey, emptyFileIds)
+      val primaryIds = primarySnapshot.getOrElse(shuffleKey, emptyIds)
+      val replicaIds = replicaSnapshot.getOrElse(shuffleKey, emptyIds)
+      val (primaryFuture, primaryTasks) = commitFiles(
+        shuffleKey,
+        primaryIds,
+        committedIds,
+        emptyFileIds,
+        failedIds,
+        storageInfos,
+        mapIdBitMap,
+        partitionSizes)
+      val (replicaFuture, replicaTasks) = commitFiles(
+        shuffleKey,
+        replicaIds,
+        committedIds,
+        emptyFileIds,
+        failedIds,
+        storageInfos,
+        mapIdBitMap,
+        partitionSizes,
+        isPrimary = false)
+      if (primaryFuture != null) { futures += primaryFuture }
+      if (replicaFuture != null) { futures += replicaFuture }
+      tasks ++= primaryTasks
+      tasks ++= replicaTasks
+    }
+    if (futures.nonEmpty) {
+      try {
+        CompletableFuture.allOf(futures.toArray: _*).get(
+          shuffleCommitTimeout,
+          TimeUnit.MILLISECONDS)
+      } catch {
+        case e: Exception =>
+          futures.foreach(_.cancel(true))
+          tasks.foreach(_.cancel(true))
+          val scope =
+            s"across ${shuffleKeys.size} shuffles: ${shuffleKeys.mkString(", ")}"
+          // Report the real cause: `get` also throws on interruption and on a failed future,
+          // not only on timeout.
+          e match {
+            case _: java.util.concurrent.TimeoutException =>
+              logWarning(s"Commit timed out after ${shuffleCommitTimeout}ms $scope", e)
+            case _: InterruptedException =>
+              // The interrupt status was cleared when this was thrown; set it again for the caller.
+              Thread.currentThread().interrupt()
+              logWarning(s"Commit interrupted $scope", e)
+            case _ =>
+              logWarning(s"Commit failed $scope", e)
+          }
+      }
+    }
+    var primaryCommitted = 0
+    var replicaCommitted = 0
+    for (shuffleKey <- shuffleKeys) {
+      val committed = committedPerShuffle.get(shuffleKey)
+      val empty = emptyPerShuffle.get(shuffleKey)
+      def isCommitted(id: String): Boolean = committed.contains(id) || empty.contains(id)
+      val primaryToRemove = primarySnapshot.getOrElse(shuffleKey, emptyIds)
+        .asScala.filter(isCommitted).asJava
+      val replicaToRemove = replicaSnapshot.getOrElse(shuffleKey, emptyIds)
+        .asScala.filter(isCommitted).asJava
+      val (primarySlots, _) =
+        partitionLocationInfo.removePrimaryPartitions(shuffleKey, primaryToRemove)
+      val (replicaSlots, _) =
+        partitionLocationInfo.removeReplicaPartitions(shuffleKey, replicaToRemove)
+      workerInfo.releaseSlots(shuffleKey, primarySlots)
+      workerInfo.releaseSlots(shuffleKey, replicaSlots)
+      primaryCommitted += primaryToRemove.size()
+      replicaCommitted += replicaToRemove.size()
+    }
+    logInfo(
+      s"Committed ${primaryCommitted + replicaCommitted} partitions ($primaryCommitted primary, $replicaCommitted replica) across ${shuffleKeys.size} shuffles.")
+  }
+
   private def handleCommitFiles(
       context: RpcCallContext,
       shuffleKey: String,
diff --git a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Worker.scala
b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Worker.scala
index da2cab1c349..fb369c66140 100644
--- a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Worker.scala
+++ b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Worker.scala
@@ -985,6 +985,18 @@ private[celeborn] class Worker(
             e)
       }
       shutdown.set(true)
+
+      if (conf.workerGracefulShutdownCommitUncommittedPartitionsEnabled) {
+        // Commit uncommitted partitions instead of waiting for LifecycleManager to send
+        // CommitFiles RPCs. Best effort: any failure is logged and shutdown carries on.
+        try {
+          controller.commitUncommittedPartitions()
+        } catch {
+          case e: Throwable =>
+            logError("Failed to commit uncommitted partitions during graceful shutdown", e)
+        }
+      }
+
       val interval = conf.workerGracefulShutdownCheckSlotsFinishedInterval
       val timeout = conf.workerGracefulShutdownCheckSlotsFinishedTimeoutMs
       var waitTimes = 0
diff --git a/worker/src/test/scala/org/apache/celeborn/service/deploy/worker/WorkerSuite.scala b/worker/src/test/scala/org/apache/celeborn/service/deploy/worker/WorkerSuite.scala
index 2e13ef1d6bf..34946eb8b2a 100644
--- a/worker/src/test/scala/org/apache/celeborn/service/deploy/worker/WorkerSuite.scala
+++ b/worker/src/test/scala/org/apache/celeborn/service/deploy/worker/WorkerSuite.scala
@@ -17,7 +17,7 @@
 
 package org.apache.celeborn.service.deploy.worker
 
-import java.io.File
+import java.io.{File, IOException}
 import java.nio.file.{Files, Paths}
 import java.util
 import java.util.{HashSet => JHashSet}
@@ -33,7 +33,7 @@ import org.scalatest.funsuite.AnyFunSuite
 
 import org.apache.celeborn.common.CelebornConf
 import org.apache.celeborn.common.identity.UserIdentifier
-import org.apache.celeborn.common.protocol.{PartitionLocation, PartitionSplitMode, PartitionType}
+import org.apache.celeborn.common.protocol.{PartitionLocation, PartitionSplitMode, PartitionType, StorageInfo}
 import org.apache.celeborn.common.protocol.message.ControlMessages.CommitFilesResponse
 import org.apache.celeborn.common.protocol.message.StatusCode
 import org.apache.celeborn.common.quota.ResourceConsumption
@@ -303,4 +303,174 @@ class WorkerSuite extends AnyFunSuite with BeforeAndAfterEach {
     assert(shuffleCommitTime.get(shuffleKey).get(epoch2) == null)
     assert(epochCommitMap.get(epoch2).response.status == StatusCode.SUCCESS)
   }
+
+  test("commitUncommittedPartitions - commits primary and replica partitions") {
+    val controller = initController()
+    val shuffleKey = "app1-0"
+    val writer1 = mockWriter(100L)
+    val writer2 = mockWriter(200L)
+    val writer3 = mockWriter(50L)
+    val primaryLocs = new util.ArrayList[PartitionLocation]()
+    primaryLocs.add(mockWorkingPartition(0, writer1))
+    primaryLocs.add(mockWorkingPartition(1, writer2))
+    worker.partitionLocationInfo.addPrimaryPartitions(shuffleKey, primaryLocs)
+    val replicaLocs = new util.ArrayList[PartitionLocation]()
+    replicaLocs.add(mockWorkingPartition(2, writer3, PartitionLocation.Mode.REPLICA))
+    worker.partitionLocationInfo.addReplicaPartitions(shuffleKey, replicaLocs)
+    assert(!worker.partitionLocationInfo.isEmpty)
+    controller.commitUncommittedPartitions()
+    verify(writer1).close()
+    verify(writer2).close()
+    verify(writer3).close()
+    assert(worker.partitionLocationInfo.isEmpty)
+  }
+
+  test("commitUncommittedPartitions - no-op when no partitions") {
+    val controller = initController()
+    assert(worker.partitionLocationInfo.isEmpty)
+    controller.commitUncommittedPartitions()
+    assert(worker.partitionLocationInfo.isEmpty)
+  }
+
+  test("commitUncommittedPartitions - idempotent on double call") {
+    val controller = initController()
+    val shuffleKey = "app1-0"
+    val writer = mockWriter(100L)
+    val locs = new util.ArrayList[PartitionLocation]()
+    locs.add(mockWorkingPartition(0, writer))
+    worker.partitionLocationInfo.addPrimaryPartitions(shuffleKey, locs)
+    controller.commitUncommittedPartitions()
+    assert(worker.partitionLocationInfo.isEmpty)
+    // Second call — no partitions remain, verify close only called once
+    controller.commitUncommittedPartitions()
+    assert(worker.partitionLocationInfo.isEmpty)
+    verify(writer, times(1)).close()
+  }
+
+  test("commitUncommittedPartitions - retains failed partitions for passive wait") {
+    val controller = initController()
+    val shuffleKey = "app1-0"
+    val successWriter = mockWriter(100L)
+    val failWriter = mockFailingWriter()
+    val locs = new util.ArrayList[PartitionLocation]()
+    locs.add(mockWorkingPartition(0, successWriter))
+    locs.add(mockWorkingPartition(1, failWriter))
+    worker.partitionLocationInfo.addPrimaryPartitions(shuffleKey, locs)
+    controller.commitUncommittedPartitions()
+    // Successful partition (0-0) removed, failed partition (1-0) retained for LifecycleManager retry
+    assert(worker.partitionLocationInfo.getPrimaryLocation(shuffleKey, "1-0") != null)
+    assert(worker.partitionLocationInfo.getPrimaryLocation(shuffleKey, "0-0") == null)
+  }
+
+  test("commitUncommittedPartitions - commits across multiple shuffle keys") {
+    val controller = initController()
+    val shuffle1 = "app1-0"
+    val shuffle2 = "app2-1"
+    val writer1 = mockWriter(100L)
+    val writer2 = mockWriter(200L)
+    val writer3 = mockWriter(50L)
+    val locs1 = new util.ArrayList[PartitionLocation]()
+    locs1.add(mockWorkingPartition(0, writer1))
+    worker.partitionLocationInfo.addPrimaryPartitions(shuffle1, locs1)
+    val locs2 = new util.ArrayList[PartitionLocation]()
+    locs2.add(mockWorkingPartition(1, writer2))
+    worker.partitionLocationInfo.addPrimaryPartitions(shuffle2, locs2)
+    val replicaLocs = new util.ArrayList[PartitionLocation]()
+    replicaLocs.add(mockWorkingPartition(2, writer3, PartitionLocation.Mode.REPLICA))
+    worker.partitionLocationInfo.addReplicaPartitions(shuffle1, replicaLocs)
+    assert(!worker.partitionLocationInfo.isEmpty)
+    controller.commitUncommittedPartitions()
+    verify(writer1).close()
+    verify(writer2).close()
+    verify(writer3).close()
+    assert(worker.partitionLocationInfo.isEmpty)
+  }
+
+  test("commitUncommittedPartitions - no cross-shuffle uniqueId collision") {
+    val controller = initController()
+    val shuffle1 = "app1-0"
+    val shuffle2 = "app2-1"
+    // Both shuffles have partition 0 (uniqueId "0-0")
+    val writer1 = mockWriter(100L)
+    val writer2 = mockWriter(200L)
+    val locs1 = new util.ArrayList[PartitionLocation]()
+    locs1.add(mockWorkingPartition(0, writer1))
+    worker.partitionLocationInfo.addPrimaryPartitions(shuffle1, locs1)
+    val locs2 = new util.ArrayList[PartitionLocation]()
+    locs2.add(mockWorkingPartition(0, writer2))
+    worker.partitionLocationInfo.addPrimaryPartitions(shuffle2, locs2)
+    controller.commitUncommittedPartitions()
+    verify(writer1).close()
+    verify(writer2).close()
+    // Both shuffles' partitions should be removed independently
+    assert(worker.partitionLocationInfo.isEmpty)
+    assert(worker.partitionLocationInfo.getPrimaryLocation(shuffle1, "0-0") == null)
+    assert(worker.partitionLocationInfo.getPrimaryLocation(shuffle2, "0-0") == null)
+  }
+
+  test("commitUncommittedPartitions - cross-shuffle collision with partial failure") {
+    val controller = initController()
+    val shuffle1 = "app1-0"
+    val shuffle2 = "app2-1"
+    // Both shuffles have partition 0 (uniqueId "0-0")
+    val successWriter = mockWriter(100L)
+    val failWriter = mockFailingWriter()
+    val locs1 = new util.ArrayList[PartitionLocation]()
+    locs1.add(mockWorkingPartition(0, successWriter))
+    worker.partitionLocationInfo.addPrimaryPartitions(shuffle1, locs1)
+    val locs2 = new util.ArrayList[PartitionLocation]()
+    locs2.add(mockWorkingPartition(0, failWriter))
+    worker.partitionLocationInfo.addPrimaryPartitions(shuffle2, locs2)
+    controller.commitUncommittedPartitions()
+    // shuffle1's 0-0 succeeded — should be removed
+    assert(worker.partitionLocationInfo.getPrimaryLocation(shuffle1, "0-0") == null)
+    // shuffle2's 0-0 failed — should be retained for LifecycleManager retry
+    assert(worker.partitionLocationInfo.getPrimaryLocation(shuffle2, "0-0") != null)
+  }
+
+  // Writer stub whose close() succeeds and returns `bytesOnClose`.
+  private def mockWriter(bytesOnClose: Long): PartitionDataWriter = {
+    val writer = mock[PartitionDataWriter]
+    when(writer.close()).thenReturn(bytesOnClose)
+    when(writer.getStorageInfo).thenReturn(new StorageInfo("/tmp", StorageInfo.Type.HDD, 1))
+    when(writer.getMapIdBitMap).thenReturn(null)
+    when(writer.getMetaHandler).thenReturn(null)
+    writer
+  }
+
+  // Writer stub whose close() throws an IOException, simulating a failed commit.
+  private def mockFailingWriter(): PartitionDataWriter = {
+    val writer = mock[PartitionDataWriter]
+    when(writer.close()).thenThrow(new IOException("disk error"))
+    when(writer.getStorageInfo).thenReturn(new StorageInfo("/tmp", StorageInfo.Type.HDD, 1))
+    when(writer.getMapIdBitMap).thenReturn(null)
+    when(writer.getMetaHandler).thenReturn(null)
+    writer
+  }
+
+  // Wraps `writer` in a WorkingPartition; the tests address it by unique id "<partitionId>-0".
+  private def mockWorkingPartition(
+      partitionId: Int,
+      writer: PartitionDataWriter,
+      mode: PartitionLocation.Mode = PartitionLocation.Mode.PRIMARY): WorkingPartition = {
+    val location = new PartitionLocation(
+      partitionId,
+      0,
+      "host",
+      0,
+      0,
+      0,
+      0,
+      mode)
+    new WorkingPartition(location, writer)
+  }
+
+  // Creates a Worker with its storage dir set to /tmp and returns its initialised controller.
+  private def initController(): Controller = {
+    conf.set(CelebornConf.WORKER_STORAGE_DIRS.key, "/tmp")
+    worker = new Worker(conf, workerArgs)
+    val controller = worker.controller
+    controller.init(worker)
+    controller
+  }
 }