-
Notifications
You must be signed in to change notification settings - Fork 436
[CELEBORN-2230] SparkUtils#shouldReportShuffleFetchFailure method should retrieve the number of task failures from TaskSetManager #3650
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 3 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -383,6 +383,11 @@ public static void cancelShuffle(int shuffleId, String reason) { | |
| .hiddenImpl(TaskSetManager.class, "taskInfos") | ||
| .defaultAlwaysNull() | ||
| .build(); | ||
| private static final DynFields.UnboundField<int[]> TASK_FAILURES = | ||
| DynFields.builder() | ||
| .hiddenImpl(TaskSetManager.class, "numFailures") | ||
| .defaultAlwaysNull() | ||
| .build(); | ||
|
|
||
| /** | ||
| * TaskSetManager - it is not designed to be used outside the spark scheduler. Please be careful. | ||
|
|
@@ -420,6 +425,39 @@ protected static Tuple2<TaskInfo, List<TaskInfo>> getTaskAttempts( | |
| } | ||
| } | ||
|
|
||
| /** | ||
| * Gets the number of task attempts that have already failed for the given task index. Note: This | ||
| * count does NOT include the current failure. To get the total failure count including the | ||
| * current attempt, you need to add 1 to the returned value. | ||
| * | ||
| * @param taskSetManager the TaskSetManager to query | ||
| * @param index the task index | ||
| * @return the number of previous failed attempts, or -1 if an error occurs | ||
| */ | ||
| @VisibleForTesting | ||
| protected static int getTaskFailureCount(TaskSetManager taskSetManager, int index) { | ||
| if (taskSetManager == null) { | ||
| LOG.error("TaskSetManager is null for task index: {}", index); | ||
| return -1; | ||
| } | ||
|
|
||
| int[] numFailures = TASK_FAILURES.bind(taskSetManager).get(); | ||
| if (numFailures == null) { | ||
| LOG.error("Failed to get numFailures array from TaskSetManager for task index: {}", index); | ||
| return -1; | ||
| } | ||
|
|
||
| if (index < 0 || index >= numFailures.length) { | ||
| LOG.error( | ||
| "Task index {} is out of bounds for numFailures array (length: {})", | ||
| index, | ||
| numFailures.length); | ||
| return -1; | ||
| } | ||
|
|
||
| return numFailures[index]; | ||
| } | ||
|
|
||
| protected static Map<String, Set<Long>> reportedStageShuffleFetchFailureTaskIds = | ||
| JavaUtils.newConcurrentHashMap(); | ||
|
|
||
|
|
@@ -480,7 +518,6 @@ public static boolean shouldReportShuffleFetchFailure(long taskId) { | |
| if (taskAttempts == null) return true; | ||
|
|
||
| TaskInfo taskInfo = taskAttempts._1(); | ||
| int failedTaskAttempts = 1; | ||
| boolean hasRunningAttempt = false; | ||
| for (TaskInfo ti : taskAttempts._2()) { | ||
| if (ti.taskId() != taskId) { | ||
|
|
@@ -492,7 +529,6 @@ public static boolean shouldReportShuffleFetchFailure(long taskId) { | |
| taskId, | ||
| taskInfo.attemptNumber(), | ||
| ti.attemptNumber()); | ||
| failedTaskAttempts += 1; | ||
| } else if (ti.successful()) { | ||
| LOG.info( | ||
| "StageId={} index={} taskId={} attempt={} another attempt {} is successful.", | ||
|
|
@@ -511,36 +547,32 @@ public static boolean shouldReportShuffleFetchFailure(long taskId) { | |
| taskInfo.attemptNumber(), | ||
| ti.attemptNumber()); | ||
| hasRunningAttempt = true; | ||
| } else if ("FAILED".equals(ti.status()) || "UNKNOWN".equals(ti.status())) { | ||
| // For KILLED state task, Spark does not count the number of failures | ||
| // For UNKNOWN state task, Spark does count the number of failures | ||
| // For FAILED state task, Spark decides whether to count the failure based on the | ||
| // different failure reasons. Since we cannot obtain the failure | ||
| // reason here, we will count all tasks in FAILED state. | ||
| LOG.info( | ||
| "StageId={} index={} taskId={} attempt={} another attempt {} status={}.", | ||
| stageId, | ||
| taskInfo.index(), | ||
| taskId, | ||
| taskInfo.attemptNumber(), | ||
| ti.attemptNumber(), | ||
| ti.status()); | ||
| failedTaskAttempts += 1; | ||
| } | ||
| } | ||
| } | ||
| // The following situations should trigger a FetchFailed exception: | ||
| // 1. If failedTaskAttempts >= maxTaskFails | ||
| // 2. If no other taskAttempts are running | ||
| if (failedTaskAttempts >= maxTaskFails || !hasRunningAttempt) { | ||
| // 1. If total failures (previous failures + current failure) >= maxTaskFails | ||
| // 2. If no other taskAttempts are running, trigger a FetchFailed exception | ||
| // to keep the same behavior as Spark. | ||
| // Note: previousFailureCount does NOT include the current failure, | ||
| // so (previousFailureCount + 1) represents the total failure count. | ||
| int previousFailureCount = getTaskFailureCount(taskSetManager, taskInfo.index()); | ||
| // Fail-safe: if failure count cannot be determined, conservatively trigger | ||
| // FetchFailed to avoid silently swallowing the error. | ||
| if (previousFailureCount < 0) { | ||
| return true; | ||
| } | ||
|
Comment on lines
+559
to
+587
|
||
| if (previousFailureCount + 1 >= maxTaskFails || !hasRunningAttempt) { | ||
| LOG.warn( | ||
| "StageId={}, index={}, taskId={}, attemptNumber={}: Task failure count {} reached " | ||
| + "maximum allowed failures {} or no running attempt exists.", | ||
| "StageId={}, index={}, taskId={}, attemptNumber={}: Previous failure count {} " | ||
| + "(total with current: {}) reached maximum allowed failures {} " | ||
| + "or no running attempt exists.", | ||
| stageId, | ||
| taskInfo.index(), | ||
| taskId, | ||
| taskInfo.attemptNumber(), | ||
| failedTaskAttempts, | ||
| previousFailureCount, | ||
| previousFailureCount + 1, | ||
| maxTaskFails); | ||
| return true; | ||
| } else { | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -19,7 +19,7 @@ package org.apache.spark.shuffle.celeborn | |
|
|
||
| import scala.collection.JavaConverters._ | ||
|
|
||
| import org.apache.spark.SparkConf | ||
| import org.apache.spark.{SparkConf, TaskContext} | ||
| import org.apache.spark.scheduler.{TaskSchedulerImpl, TaskSetManager} | ||
| import org.apache.spark.sql.SparkSession | ||
| import org.scalatest.BeforeAndAfterEach | ||
|
|
@@ -216,6 +216,107 @@ class SparkUtilsSuite extends AnyFunSuite | |
| } | ||
| } | ||
|
|
||
| test("getTaskFailureCount") { | ||
| assert(SparkUtils.getTaskFailureCount(null, 0) == -1) | ||
|
|
||
| if (Spark3OrNewer) { | ||
| val sparkConf = new SparkConf().setAppName("rss-demo").setMaster("local[2,3]") | ||
| val sparkSession = SparkSession.builder() | ||
| .config(updateSparkConf(sparkConf, ShuffleMode.HASH)) | ||
| .config("spark.sql.shuffle.partitions", 2) | ||
| .config("spark.celeborn.shuffle.forceFallback.partition.enabled", false) | ||
| .config("spark.celeborn.client.spark.stageRerun.enabled", "true") | ||
| .config( | ||
| "spark.shuffle.manager", | ||
| "org.apache.spark.shuffle.celeborn.TestCelebornShuffleManager") | ||
| .getOrCreate() | ||
|
|
||
| try { | ||
| val sc = sparkSession.sparkContext | ||
| val jobThread = new Thread { | ||
| override def run(): Unit = { | ||
| try { | ||
| sc.parallelize(1 to 100, 2) | ||
| .repartition(1) | ||
| .mapPartitions { iter => | ||
| Thread.sleep(3000) | ||
| iter | ||
| }.collect() | ||
| } catch { | ||
| case _: InterruptedException => | ||
| } | ||
| } | ||
| } | ||
| jobThread.start() | ||
|
|
||
| val taskScheduler = sc.taskScheduler.asInstanceOf[TaskSchedulerImpl] | ||
| eventually(timeout(3.seconds), interval(100.milliseconds)) { | ||
| val taskId = 0 | ||
| val taskSetManager = SparkUtils.getTaskSetManager(taskScheduler, taskId) | ||
| assert(taskSetManager != null) | ||
| assert(SparkUtils.getTaskFailureCount(taskSetManager, 0) == 0) | ||
| assert(SparkUtils.getTaskFailureCount(taskSetManager, -1) == -1) | ||
| assert(SparkUtils.getTaskFailureCount(taskSetManager, Int.MaxValue) == -1) | ||
| } | ||
|
|
||
| sparkSession.sparkContext.cancelAllJobs() | ||
| jobThread.interrupt() | ||
| } finally { | ||
| sparkSession.stop() | ||
| } | ||
| } | ||
| } | ||
|
|
||
| test("getTaskFailureCount after real task failures") { | ||
| if (Spark3OrNewer) { | ||
| // local[1,4]: 1 core (sequential execution), max 4 task failures before stage abort | ||
| val sparkConf = new SparkConf().setAppName("rss-demo").setMaster("local[1,4]") | ||
| val sparkSession = SparkSession.builder() | ||
| .config(updateSparkConf(sparkConf, ShuffleMode.HASH)) | ||
| .config("spark.sql.shuffle.partitions", 2) | ||
| .config("spark.celeborn.shuffle.forceFallback.partition.enabled", false) | ||
| .config("spark.celeborn.client.spark.stageRerun.enabled", "true") | ||
| .config( | ||
| "spark.shuffle.manager", | ||
| "org.apache.spark.shuffle.celeborn.TestCelebornShuffleManager") | ||
| .getOrCreate() | ||
|
|
||
| try { | ||
| val sc = sparkSession.sparkContext | ||
|
|
||
| val jobThread = new Thread { | ||
| override def run(): Unit = { | ||
| try { | ||
| sc.parallelize(1 to 10, 1).mapPartitions { iter => | ||
| if (TaskContext.get().attemptNumber() < 2) { | ||
| throw new RuntimeException("Simulated task failure") | ||
| } | ||
| Thread.sleep(10000) | ||
| iter | ||
| }.collect() | ||
| } catch { | ||
| case _: Exception => | ||
| } | ||
| } | ||
| } | ||
| jobThread.start() | ||
|
|
||
| val taskScheduler = sc.taskScheduler.asInstanceOf[TaskSchedulerImpl] | ||
| eventually(timeout(10.seconds), interval(100.milliseconds)) { | ||
| // taskId 0,1 failed and removed; taskId 2 is the surviving 3rd attempt | ||
| val taskSetManager = SparkUtils.getTaskSetManager(taskScheduler, 2) | ||
| assert(taskSetManager != null) | ||
| assert(SparkUtils.getTaskFailureCount(taskSetManager, 0) == 2) | ||
|
Comment on lines
+304
to
+311
|
||
| } | ||
|
|
||
| sparkSession.sparkContext.cancelAllJobs() | ||
| jobThread.interrupt() | ||
| } finally { | ||
| sparkSession.stop() | ||
| } | ||
| } | ||
| } | ||
|
|
||
| test("serialize/deserialize GetReducerFileGroupResponse with broadcast") { | ||
| val sparkConf = new SparkConf().setAppName("rss-demo").setMaster("local[2,3]") | ||
| val sparkSession = SparkSession.builder() | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When getTaskFailureCount() fails (returns < 0), shouldReportShuffleFetchFailure() immediately returns true. This makes the pre-check aggressively report FetchFailed even if there are other running attempts and the retry limit has not been reached, which can reintroduce premature stage reruns if reflective access to TaskSetManager.numFailures fails on some Spark builds. Consider falling back to the previous attempt/status-based counting (or at least gating on !hasRunningAttempt) rather than unconditional true, and avoid logging this as an error on every call if the field is unavailable.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this is unreasonable. It shouldn't check
`!hasRunningAttempt`, but should directly return `true`. Otherwise, `FetchFailed` won't be triggered and the app will fail.