diff --git a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java index d04f4b0eb16..9cc1b41774c 100644 --- a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java +++ b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java @@ -247,6 +247,11 @@ public static void cancelShuffle(int shuffleId, String reason) { .hiddenImpl(TaskSetManager.class, "taskInfos") .defaultAlwaysNull() .build(); + private static final DynFields.UnboundField TASK_FAILURES = + DynFields.builder() + .hiddenImpl(TaskSetManager.class, "numFailures") + .defaultAlwaysNull() + .build(); /** * TaskSetManager - it is not designed to be used outside the spark scheduler. Please be careful. @@ -284,6 +289,39 @@ protected static Tuple2> getTaskAttempts( } } + /** + * Gets the number of task attempts that have already failed for the given task index. Note: This + * count does NOT include the current failure. To get the total failure count including the + * current attempt, you need to add 1 to the returned value. + * + * @param taskSetManager the TaskSetManager to query + * @param index the task index + * @return the number of previous failed attempts, or -1 if an error occurs + */ + @VisibleForTesting + protected static int getTaskFailureCount(TaskSetManager taskSetManager, int index) { + if (taskSetManager == null) { + logger.error("TaskSetManager is null for task index: {}", index); + return -1; + } + + int[] numFailures = TASK_FAILURES.bind(taskSetManager).get(); + if (numFailures == null) { + logger.error("Failed to get numFailures array from TaskSetManager for task index: {}", index); + return -1; + } + + if (index < 0 || index >= numFailures.length) { + logger.error( + "Task index {} is out of bounds for numFailures array (length: {})", + index, + numFailures.length); + return -1; + } + + return numFailures[index]; + } + protected static Map> reportedStageShuffleFetchFailureTaskIds = JavaUtils.newConcurrentHashMap(); @@ -344,7 +382,6 @@ public static boolean shouldReportShuffleFetchFailure(long taskId) { if (taskAttempts == null) return true; TaskInfo taskInfo = taskAttempts._1(); - int failedTaskAttempts = 1; boolean hasRunningAttempt = false; for (TaskInfo ti : taskAttempts._2()) { if (ti.taskId() != taskId) { @@ -356,7 +393,6 @@ public static boolean shouldReportShuffleFetchFailure(long taskId) { taskId, taskInfo.attemptNumber(), ti.attemptNumber()); - failedTaskAttempts += 1; } else if (ti.successful()) { logger.info( "StageId={} index={} taskId={} attempt={} another attempt {} is successful.", @@ -375,36 +411,55 @@ public static boolean shouldReportShuffleFetchFailure(long taskId) { taskInfo.attemptNumber(), ti.attemptNumber()); hasRunningAttempt = true; - } else if ("FAILED".equals(ti.status()) || "UNKNOWN".equals(ti.status())) { - // For KILLED state task, Spark does not count the number of failures - // For UNKNOWN state task, Spark does count the number of failures - // For FAILED state task, Spark decides whether to count the failure based on the - // different failure reasons. Since we cannot obtain the failure - // reason here, we will count all tasks in FAILED state. - logger.info( - "StageId={} index={} taskId={} attempt={} another attempt {} status={}.", - stageId, - taskInfo.index(), - taskId, - taskInfo.attemptNumber(), - ti.attemptNumber(), - ti.status()); - failedTaskAttempts += 1; } } } // The following situations should trigger a FetchFailed exception: - // 1. If failedTaskAttempts >= maxTaskFails - // 2. If no other taskAttempts are running - if (failedTaskAttempts >= maxTaskFails || !hasRunningAttempt) { + // 1. If total failures (previous failures + current failure) >= maxTaskFails + // 2. If no other taskAttempts are running, trigger a FetchFailed exception + // to keep the same behavior as Spark. + // Note: previousFailureCount does NOT include the current failure, + // so (previousFailureCount + 1) represents the total failure count. + int previousFailureCount = getTaskFailureCount(taskSetManager, taskInfo.index()); + // If failure count cannot be determined, fall back to attempt status based + // behavior instead of aggressively reporting FetchFailed. This avoids + // premature stage reruns when reflective access to failure counts is + // unavailable, while still reporting the failure when no other attempt is + // running. + if (previousFailureCount < 0) { + if (!hasRunningAttempt) { + LOG.warn( + "StageId={}, index={}, taskId={}, attemptNumber={}: Unable to determine " + + "previous failure count, and no other running attempt exists. " + + "Reporting shuffle fetch failure.", + stageId, + taskInfo.index(), + taskId, + taskInfo.attemptNumber()); + return true; + } else { + LOG.warn( + "StageId={}, index={}, taskId={}, attemptNumber={}: Unable to determine " + + "previous failure count, but another attempt is still running. " + + "Deferring shuffle fetch failure report.", + stageId, + taskInfo.index(), + taskId, + taskInfo.attemptNumber()); + return false; + } + } + if (previousFailureCount + 1 >= maxTaskFails || !hasRunningAttempt) { logger.warn( - "StageId={}, index={}, taskId={}, attemptNumber={}: Task failure count {} reached " - + "maximum allowed failures {} or no running attempt exists.", + "StageId={}, index={}, taskId={}, attemptNumber={}: Previous failure count {} " + + "(total with current: {}) reached maximum allowed failures {} " + + "or no running attempt exists.", stageId, taskInfo.index(), taskId, taskInfo.attemptNumber(), - failedTaskAttempts, + previousFailureCount, + previousFailureCount + 1, maxTaskFails); return true; } else { diff --git a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java index 0e68a4b46f0..0b4e2d90ac5 100644 --- a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java +++ b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java @@ -383,6 +383,11 @@ public static void cancelShuffle(int shuffleId, String reason) { .hiddenImpl(TaskSetManager.class, "taskInfos") .defaultAlwaysNull() .build(); + private static final DynFields.UnboundField TASK_FAILURES = + DynFields.builder() + .hiddenImpl(TaskSetManager.class, "numFailures") + .defaultAlwaysNull() + .build(); /** * TaskSetManager - it is not designed to be used outside the spark scheduler. Please be careful. @@ -420,6 +425,39 @@ protected static Tuple2> getTaskAttempts( } } + /** + * Gets the number of task attempts that have already failed for the given task index. Note: This + * count does NOT include the current failure. To get the total failure count including the + * current attempt, you need to add 1 to the returned value. + * + * @param taskSetManager the TaskSetManager to query + * @param index the task index + * @return the number of previous failed attempts, or -1 if an error occurs + */ + @VisibleForTesting + protected static int getTaskFailureCount(TaskSetManager taskSetManager, int index) { + if (taskSetManager == null) { + LOG.error("TaskSetManager is null for task index: {}", index); + return -1; + } + + int[] numFailures = TASK_FAILURES.bind(taskSetManager).get(); + if (numFailures == null) { + LOG.error("Failed to get numFailures array from TaskSetManager for task index: {}", index); + return -1; + } + + if (index < 0 || index >= numFailures.length) { + LOG.error( + "Task index {} is out of bounds for numFailures array (length: {})", + index, + numFailures.length); + return -1; + } + + return numFailures[index]; + } + protected static Map> reportedStageShuffleFetchFailureTaskIds = JavaUtils.newConcurrentHashMap(); @@ -480,7 +518,6 @@ public static boolean shouldReportShuffleFetchFailure(long taskId) { if (taskAttempts == null) return true; TaskInfo taskInfo = taskAttempts._1(); - int failedTaskAttempts = 1; boolean hasRunningAttempt = false; for (TaskInfo ti : taskAttempts._2()) { if (ti.taskId() != taskId) { @@ -492,7 +529,6 @@ public static boolean shouldReportShuffleFetchFailure(long taskId) { taskId, taskInfo.attemptNumber(), ti.attemptNumber()); - failedTaskAttempts += 1; } else if (ti.successful()) { LOG.info( "StageId={} index={} taskId={} attempt={} another attempt {} is successful.", @@ -511,36 +547,55 @@ public static boolean shouldReportShuffleFetchFailure(long taskId) { taskInfo.attemptNumber(), ti.attemptNumber()); hasRunningAttempt = true; - } else if ("FAILED".equals(ti.status()) || "UNKNOWN".equals(ti.status())) { - // For KILLED state task, Spark does not count the number of failures - // For UNKNOWN state task, Spark does count the number of failures - // For FAILED state task, Spark decides whether to count the failure based on the - // different failure reasons. Since we cannot obtain the failure - // reason here, we will count all tasks in FAILED state. - LOG.info( - "StageId={} index={} taskId={} attempt={} another attempt {} status={}.", - stageId, - taskInfo.index(), - taskId, - taskInfo.attemptNumber(), - ti.attemptNumber(), - ti.status()); - failedTaskAttempts += 1; } } } // The following situations should trigger a FetchFailed exception: - // 1. If failedTaskAttempts >= maxTaskFails - // 2. If no other taskAttempts are running - if (failedTaskAttempts >= maxTaskFails || !hasRunningAttempt) { + // 1. If total failures (previous failures + current failure) >= maxTaskFails + // 2. If no other taskAttempts are running, trigger a FetchFailed exception + // to keep the same behavior as Spark. + // Note: previousFailureCount does NOT include the current failure, + // so (previousFailureCount + 1) represents the total failure count. + int previousFailureCount = getTaskFailureCount(taskSetManager, taskInfo.index()); + // If failure count cannot be determined, fall back to attempt status based + // behavior instead of aggressively reporting FetchFailed. This avoids + // premature stage reruns when reflective access to failure counts is + // unavailable, while still reporting the failure when no other attempt is + // running. + if (previousFailureCount < 0) { + if (!hasRunningAttempt) { + LOG.warn( + "StageId={}, index={}, taskId={}, attemptNumber={}: Unable to determine " + + "previous failure count, and no other running attempt exists. " + + "Reporting shuffle fetch failure.", + stageId, + taskInfo.index(), + taskId, + taskInfo.attemptNumber()); + return true; + } else { + LOG.warn( + "StageId={}, index={}, taskId={}, attemptNumber={}: Unable to determine " + + "previous failure count, but another attempt is still running. " + + "Deferring shuffle fetch failure report.", + stageId, + taskInfo.index(), + taskId, + taskInfo.attemptNumber()); + return false; + } + } + if (previousFailureCount + 1 >= maxTaskFails || !hasRunningAttempt) { LOG.warn( - "StageId={}, index={}, taskId={}, attemptNumber={}: Task failure count {} reached " - + "maximum allowed failures {} or no running attempt exists.", + "StageId={}, index={}, taskId={}, attemptNumber={}: Previous failure count {} " + + "(total with current: {}) reached maximum allowed failures {} " + + "or no running attempt exists.", stageId, taskInfo.index(), taskId, taskInfo.attemptNumber(), - failedTaskAttempts, + previousFailureCount, + previousFailureCount + 1, maxTaskFails); return true; } else { diff --git a/tests/spark-it/src/test/scala/org/apache/spark/shuffle/celeborn/SparkUtilsSuite.scala b/tests/spark-it/src/test/scala/org/apache/spark/shuffle/celeborn/SparkUtilsSuite.scala index 85a1d815fc5..5bdfa1c2a1a 100644 --- a/tests/spark-it/src/test/scala/org/apache/spark/shuffle/celeborn/SparkUtilsSuite.scala +++ b/tests/spark-it/src/test/scala/org/apache/spark/shuffle/celeborn/SparkUtilsSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.shuffle.celeborn import scala.collection.JavaConverters._ -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, TaskContext} import org.apache.spark.scheduler.{TaskSchedulerImpl, TaskSetManager} import org.apache.spark.sql.SparkSession import org.scalatest.BeforeAndAfterEach @@ -216,6 +216,109 @@ class SparkUtilsSuite extends AnyFunSuite } } + test("getTaskFailureCount") { + assert(SparkUtils.getTaskFailureCount(null, 0) == -1) + + if (Spark3OrNewer) { + val sparkConf = new SparkConf().setAppName("rss-demo").setMaster("local[2,3]") + val sparkSession = SparkSession.builder() + .config(updateSparkConf(sparkConf, ShuffleMode.HASH)) + .config("spark.sql.shuffle.partitions", 2) + .config("spark.celeborn.shuffle.forceFallback.partition.enabled", false) + .config("spark.celeborn.client.spark.stageRerun.enabled", "true") + .config( + "spark.shuffle.manager", + "org.apache.spark.shuffle.celeborn.TestCelebornShuffleManager") + .getOrCreate() + + try { + val sc = sparkSession.sparkContext + val jobThread = new Thread { + override def run(): Unit = { + try { + sc.parallelize(1 to 100, 2) + .repartition(1) + .mapPartitions { iter => + Thread.sleep(3000) + iter + }.collect() + } catch { + case _: InterruptedException => + } + } + } + jobThread.start() + + val taskScheduler = sc.taskScheduler.asInstanceOf[TaskSchedulerImpl] + eventually(timeout(3.seconds), interval(100.milliseconds)) { + val taskId = 0 + val taskSetManager = SparkUtils.getTaskSetManager(taskScheduler, taskId) + assert(taskSetManager != null) + assert(SparkUtils.getTaskFailureCount(taskSetManager, 0) == 0) + assert(SparkUtils.getTaskFailureCount(taskSetManager, -1) == -1) + assert(SparkUtils.getTaskFailureCount(taskSetManager, Int.MaxValue) == -1) + } + + sparkSession.sparkContext.cancelAllJobs() + jobThread.interrupt() + } finally { + sparkSession.stop() + } + } + } + + test("getTaskFailureCount after real task failures") { + if (Spark3OrNewer) { + // local[1,4]: 1 core (sequential execution), max 4 task failures before stage abort + val sparkConf = new SparkConf().setAppName("rss-demo").setMaster("local[1,4]") + val sparkSession = SparkSession.builder() + .config(updateSparkConf(sparkConf, ShuffleMode.HASH)) + .config("spark.sql.shuffle.partitions", 2) + .config("spark.celeborn.shuffle.forceFallback.partition.enabled", false) + .config("spark.celeborn.client.spark.stageRerun.enabled", "true") + .config( + "spark.shuffle.manager", + "org.apache.spark.shuffle.celeborn.TestCelebornShuffleManager") + .getOrCreate() + + try { + val sc = sparkSession.sparkContext + + val jobThread = new Thread { + override def run(): Unit = { + try { + sc.parallelize(1 to 10, 1).mapPartitions { iter => + if (TaskContext.get().attemptNumber() < 2) { + throw new RuntimeException("Simulated task failure") + } + Thread.sleep(10000) + iter + }.collect() + } catch { + case _: Exception => + } + } + } + jobThread.start() + + val taskScheduler = sc.taskScheduler.asInstanceOf[TaskSchedulerImpl] + eventually(timeout(10.seconds), interval(100.milliseconds)) { + // Task IDs are globally assigned; find the active TaskSetManager dynamically + // rather than assuming a specific taskId. + val taskSetManager = (0L to 10L).map(id => + SparkUtils.getTaskSetManager(taskScheduler, id)).find(_ != null).orNull + assert(taskSetManager != null) + assert(SparkUtils.getTaskFailureCount(taskSetManager, 0) == 2) + } + + sparkSession.sparkContext.cancelAllJobs() + jobThread.interrupt() + } finally { + sparkSession.stop() + } + } + } + test("serialize/deserialize GetReducerFileGroupResponse with broadcast") { val sparkConf = new SparkConf().setAppName("rss-demo").setMaster("local[2,3]") val sparkSession = SparkSession.builder()