Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,11 @@ public static void cancelShuffle(int shuffleId, String reason) {
.hiddenImpl(TaskSetManager.class, "taskInfos")
.defaultAlwaysNull()
.build();
private static final DynFields.UnboundField<int[]> TASK_FAILURES =
DynFields.builder()
.hiddenImpl(TaskSetManager.class, "numFailures")
.defaultAlwaysNull()
.build();

/**
* TaskSetManager - it is not designed to be used outside the spark scheduler. Please be careful.
Expand Down Expand Up @@ -284,6 +289,39 @@ protected static Tuple2<TaskInfo, List<TaskInfo>> getTaskAttempts(
}
}

/**
* Gets the number of task attempts that have already failed for the given task index. Note: This
* count does NOT include the current failure. To get the total failure count including the
* current attempt, you need to add 1 to the returned value.
*
* @param taskSetManager the TaskSetManager to query
* @param index the task index
* @return the number of previous failed attempts, or -1 if an error occurs
*/
@VisibleForTesting
protected static int getTaskFailureCount(TaskSetManager taskSetManager, int index) {
if (taskSetManager == null) {
logger.error("TaskSetManager is null for task index: {}", index);
return -1;
}

int[] numFailures = TASK_FAILURES.bind(taskSetManager).get();
if (numFailures == null) {
logger.error("Failed to get numFailures array from TaskSetManager for task index: {}", index);
return -1;
}

if (index < 0 || index >= numFailures.length) {
logger.error(
"Task index {} is out of bounds for numFailures array (length: {})",
index,
numFailures.length);
return -1;
}

return numFailures[index];
}

protected static Map<String, Set<Long>> reportedStageShuffleFetchFailureTaskIds =
JavaUtils.newConcurrentHashMap();

Expand Down Expand Up @@ -344,7 +382,6 @@ public static boolean shouldReportShuffleFetchFailure(long taskId) {
if (taskAttempts == null) return true;

TaskInfo taskInfo = taskAttempts._1();
int failedTaskAttempts = 1;
boolean hasRunningAttempt = false;
for (TaskInfo ti : taskAttempts._2()) {
if (ti.taskId() != taskId) {
Expand All @@ -356,7 +393,6 @@ public static boolean shouldReportShuffleFetchFailure(long taskId) {
taskId,
taskInfo.attemptNumber(),
ti.attemptNumber());
failedTaskAttempts += 1;
} else if (ti.successful()) {
logger.info(
"StageId={} index={} taskId={} attempt={} another attempt {} is successful.",
Expand All @@ -375,36 +411,32 @@ public static boolean shouldReportShuffleFetchFailure(long taskId) {
taskInfo.attemptNumber(),
ti.attemptNumber());
hasRunningAttempt = true;
} else if ("FAILED".equals(ti.status()) || "UNKNOWN".equals(ti.status())) {
// For KILLED state task, Spark does not count the number of failures
// For UNKNOWN state task, Spark does count the number of failures
// For FAILED state task, Spark decides whether to count the failure based on the
// different failure reasons. Since we cannot obtain the failure
// reason here, we will count all tasks in FAILED state.
logger.info(
"StageId={} index={} taskId={} attempt={} another attempt {} status={}.",
stageId,
taskInfo.index(),
taskId,
taskInfo.attemptNumber(),
ti.attemptNumber(),
ti.status());
failedTaskAttempts += 1;
}
}
}
// The following situations should trigger a FetchFailed exception:
// 1. If failedTaskAttempts >= maxTaskFails
// 2. If no other taskAttempts are running
if (failedTaskAttempts >= maxTaskFails || !hasRunningAttempt) {
// 1. If total failures (previous failures + current failure) >= maxTaskFails
// 2. If no other taskAttempts are running, trigger a FetchFailed exception
// to keep the same behavior as Spark.
// Note: previousFailureCount does NOT include the current failure,
// so (previousFailureCount + 1) represents the total failure count.
int previousFailureCount = getTaskFailureCount(taskSetManager, taskInfo.index());
// Fail-safe: if failure count cannot be determined, conservatively trigger
// FetchFailed to avoid silently swallowing the error.
if (previousFailureCount < 0) {
return true;
Copy link

Copilot AI Apr 10, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When getTaskFailureCount() fails (returns < 0), shouldReportShuffleFetchFailure() immediately returns true. This makes the pre-check aggressively report FetchFailed even if there are other running attempts and the retry limit has not been reached, which can reintroduce premature stage reruns if reflective access to TaskSetManager.numFailures fails on some Spark builds. Consider falling back to the previous attempt/status-based counting (or at least gating on !hasRunningAttempt) rather than unconditional true, and avoid logging this as an error on every call if the field is unavailable.

Suggested change
// Fail-safe: if failure count cannot be determined, conservatively trigger
// FetchFailed to avoid silently swallowing the error.
if (previousFailureCount < 0) {
return true;
// If failure count cannot be determined, fall back to attempt status based
// behavior instead of aggressively reporting FetchFailed. This avoids
// premature stage reruns when reflective access to failure counts is
// unavailable, while still reporting the failure when no other attempt is
// running.
if (previousFailureCount < 0) {
if (!hasRunningAttempt) {
logger.warn(
"StageId={}, index={}, taskId={}, attemptNumber={}: Unable to determine "
+ "previous failure count, and no other running attempt exists. "
+ "Reporting shuffle fetch failure.",
stageId,
taskInfo.index(),
taskId,
taskInfo.attemptNumber());
return true;
} else {
logger.warn(
"StageId={}, index={}, taskId={}, attemptNumber={}: Unable to determine "
+ "previous failure count, but another attempt is still running. "
+ "Deferring shuffle fetch failure report.",
stageId,
taskInfo.index(),
taskId,
taskInfo.attemptNumber());
return false;
}

Copilot uses AI. Check for mistakes.
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is unreasonable. It shouldn't check !hasRunningAttempt, but should directly return true. Otherwise, FetchFailed won't be triggered and app will fail.

}
if (previousFailureCount + 1 >= maxTaskFails || !hasRunningAttempt) {
logger.warn(
"StageId={}, index={}, taskId={}, attemptNumber={}: Task failure count {} reached "
+ "maximum allowed failures {} or no running attempt exists.",
"StageId={}, index={}, taskId={}, attemptNumber={}: Previous failure count {} "
+ "(total with current: {}) reached maximum allowed failures {} "
+ "or no running attempt exists.",
stageId,
taskInfo.index(),
taskId,
taskInfo.attemptNumber(),
failedTaskAttempts,
previousFailureCount,
previousFailureCount + 1,
maxTaskFails);
return true;
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,11 @@ public static void cancelShuffle(int shuffleId, String reason) {
.hiddenImpl(TaskSetManager.class, "taskInfos")
.defaultAlwaysNull()
.build();
private static final DynFields.UnboundField<int[]> TASK_FAILURES =
DynFields.builder()
.hiddenImpl(TaskSetManager.class, "numFailures")
.defaultAlwaysNull()
.build();

/**
* TaskSetManager - it is not designed to be used outside the spark scheduler. Please be careful.
Expand Down Expand Up @@ -420,6 +425,39 @@ protected static Tuple2<TaskInfo, List<TaskInfo>> getTaskAttempts(
}
}

/**
* Gets the number of task attempts that have already failed for the given task index. Note: This
* count does NOT include the current failure. To get the total failure count including the
* current attempt, you need to add 1 to the returned value.
*
* @param taskSetManager the TaskSetManager to query
* @param index the task index
* @return the number of previous failed attempts, or -1 if an error occurs
*/
@VisibleForTesting
protected static int getTaskFailureCount(TaskSetManager taskSetManager, int index) {
if (taskSetManager == null) {
LOG.error("TaskSetManager is null for task index: {}", index);
return -1;
}

int[] numFailures = TASK_FAILURES.bind(taskSetManager).get();
if (numFailures == null) {
LOG.error("Failed to get numFailures array from TaskSetManager for task index: {}", index);
return -1;
}

if (index < 0 || index >= numFailures.length) {
LOG.error(
"Task index {} is out of bounds for numFailures array (length: {})",
index,
numFailures.length);
return -1;
}

return numFailures[index];
}

protected static Map<String, Set<Long>> reportedStageShuffleFetchFailureTaskIds =
JavaUtils.newConcurrentHashMap();

Expand Down Expand Up @@ -480,7 +518,6 @@ public static boolean shouldReportShuffleFetchFailure(long taskId) {
if (taskAttempts == null) return true;

TaskInfo taskInfo = taskAttempts._1();
int failedTaskAttempts = 1;
boolean hasRunningAttempt = false;
for (TaskInfo ti : taskAttempts._2()) {
if (ti.taskId() != taskId) {
Expand All @@ -492,7 +529,6 @@ public static boolean shouldReportShuffleFetchFailure(long taskId) {
taskId,
taskInfo.attemptNumber(),
ti.attemptNumber());
failedTaskAttempts += 1;
} else if (ti.successful()) {
LOG.info(
"StageId={} index={} taskId={} attempt={} another attempt {} is successful.",
Expand All @@ -511,36 +547,32 @@ public static boolean shouldReportShuffleFetchFailure(long taskId) {
taskInfo.attemptNumber(),
ti.attemptNumber());
hasRunningAttempt = true;
} else if ("FAILED".equals(ti.status()) || "UNKNOWN".equals(ti.status())) {
// For KILLED state task, Spark does not count the number of failures
// For UNKNOWN state task, Spark does count the number of failures
// For FAILED state task, Spark decides whether to count the failure based on the
// different failure reasons. Since we cannot obtain the failure
// reason here, we will count all tasks in FAILED state.
LOG.info(
"StageId={} index={} taskId={} attempt={} another attempt {} status={}.",
stageId,
taskInfo.index(),
taskId,
taskInfo.attemptNumber(),
ti.attemptNumber(),
ti.status());
failedTaskAttempts += 1;
}
}
}
// The following situations should trigger a FetchFailed exception:
// 1. If failedTaskAttempts >= maxTaskFails
// 2. If no other taskAttempts are running
if (failedTaskAttempts >= maxTaskFails || !hasRunningAttempt) {
// 1. If total failures (previous failures + current failure) >= maxTaskFails
// 2. If no other taskAttempts are running, trigger a FetchFailed exception
// to keep the same behavior as Spark.
// Note: previousFailureCount does NOT include the current failure,
// so (previousFailureCount + 1) represents the total failure count.
int previousFailureCount = getTaskFailureCount(taskSetManager, taskInfo.index());
// Fail-safe: if failure count cannot be determined, conservatively trigger
// FetchFailed to avoid silently swallowing the error.
if (previousFailureCount < 0) {
return true;
}
Comment on lines +559 to +587
Copy link

Copilot AI Apr 10, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When getTaskFailureCount() fails (returns < 0), shouldReportShuffleFetchFailure() immediately returns true. This makes the pre-check aggressively report FetchFailed even if there are other running attempts and the retry limit has not been reached, which can reintroduce premature stage reruns in exactly the scenarios this change is trying to avoid (e.g., if reflective access to TaskSetManager.numFailures breaks on some Spark builds). Consider falling back to the previous attempt/status-based counting (or at least gating on !hasRunningAttempt) instead of unconditional true, and log at WARN once to avoid error spam if the field is unavailable.

Copilot uses AI. Check for mistakes.
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is unreasonable. It shouldn't check !hasRunningAttempt, but should directly return true. Otherwise, FetchFailed won't be triggered and app will fail.

if (previousFailureCount + 1 >= maxTaskFails || !hasRunningAttempt) {
LOG.warn(
"StageId={}, index={}, taskId={}, attemptNumber={}: Task failure count {} reached "
+ "maximum allowed failures {} or no running attempt exists.",
"StageId={}, index={}, taskId={}, attemptNumber={}: Previous failure count {} "
+ "(total with current: {}) reached maximum allowed failures {} "
+ "or no running attempt exists.",
stageId,
taskInfo.index(),
taskId,
taskInfo.attemptNumber(),
failedTaskAttempts,
previousFailureCount,
previousFailureCount + 1,
maxTaskFails);
return true;
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package org.apache.spark.shuffle.celeborn

import scala.collection.JavaConverters._

import org.apache.spark.SparkConf
import org.apache.spark.{SparkConf, TaskContext}
import org.apache.spark.scheduler.{TaskSchedulerImpl, TaskSetManager}
import org.apache.spark.sql.SparkSession
import org.scalatest.BeforeAndAfterEach
Expand Down Expand Up @@ -216,6 +216,107 @@ class SparkUtilsSuite extends AnyFunSuite
}
}

test("getTaskFailureCount") {
assert(SparkUtils.getTaskFailureCount(null, 0) == -1)

if (Spark3OrNewer) {
val sparkConf = new SparkConf().setAppName("rss-demo").setMaster("local[2,3]")
val sparkSession = SparkSession.builder()
.config(updateSparkConf(sparkConf, ShuffleMode.HASH))
.config("spark.sql.shuffle.partitions", 2)
.config("spark.celeborn.shuffle.forceFallback.partition.enabled", false)
.config("spark.celeborn.client.spark.stageRerun.enabled", "true")
.config(
"spark.shuffle.manager",
"org.apache.spark.shuffle.celeborn.TestCelebornShuffleManager")
.getOrCreate()

try {
val sc = sparkSession.sparkContext
val jobThread = new Thread {
override def run(): Unit = {
try {
sc.parallelize(1 to 100, 2)
.repartition(1)
.mapPartitions { iter =>
Thread.sleep(3000)
iter
}.collect()
} catch {
case _: InterruptedException =>
}
}
}
jobThread.start()

val taskScheduler = sc.taskScheduler.asInstanceOf[TaskSchedulerImpl]
eventually(timeout(3.seconds), interval(100.milliseconds)) {
val taskId = 0
val taskSetManager = SparkUtils.getTaskSetManager(taskScheduler, taskId)
assert(taskSetManager != null)
assert(SparkUtils.getTaskFailureCount(taskSetManager, 0) == 0)
assert(SparkUtils.getTaskFailureCount(taskSetManager, -1) == -1)
assert(SparkUtils.getTaskFailureCount(taskSetManager, Int.MaxValue) == -1)
}

sparkSession.sparkContext.cancelAllJobs()
jobThread.interrupt()
} finally {
sparkSession.stop()
}
}
}

test("getTaskFailureCount after real task failures") {
if (Spark3OrNewer) {
// local[1,4]: 1 core (sequential execution), max 4 task failures before stage abort
val sparkConf = new SparkConf().setAppName("rss-demo").setMaster("local[1,4]")
val sparkSession = SparkSession.builder()
.config(updateSparkConf(sparkConf, ShuffleMode.HASH))
.config("spark.sql.shuffle.partitions", 2)
.config("spark.celeborn.shuffle.forceFallback.partition.enabled", false)
.config("spark.celeborn.client.spark.stageRerun.enabled", "true")
.config(
"spark.shuffle.manager",
"org.apache.spark.shuffle.celeborn.TestCelebornShuffleManager")
.getOrCreate()

try {
val sc = sparkSession.sparkContext

val jobThread = new Thread {
override def run(): Unit = {
try {
sc.parallelize(1 to 10, 1).mapPartitions { iter =>
if (TaskContext.get().attemptNumber() < 2) {
throw new RuntimeException("Simulated task failure")
}
Thread.sleep(10000)
iter
}.collect()
} catch {
case _: Exception =>
}
}
}
jobThread.start()

val taskScheduler = sc.taskScheduler.asInstanceOf[TaskSchedulerImpl]
eventually(timeout(10.seconds), interval(100.milliseconds)) {
// taskId 0,1 failed and removed; taskId 2 is the surviving 3rd attempt
val taskSetManager = SparkUtils.getTaskSetManager(taskScheduler, 2)
assert(taskSetManager != null)
assert(SparkUtils.getTaskFailureCount(taskSetManager, 0) == 2)
Comment on lines +304 to +311
Copy link

Copilot AI Apr 10, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test assumes the third attempt’s taskId will be exactly 2 (after two failures). Spark task IDs are globally assigned within a SparkContext and aren’t guaranteed to align with attempt count if any other tasks/stages run (including internal ones), which can make the test brittle across Spark versions/configs. Consider deriving the taskId dynamically (e.g., capturing TaskContext.taskAttemptId() via an accumulator/Promise, or scanning taskScheduler’s taskIdToTaskSetManager for the active TaskSetManager) instead of hardcoding 2.

Copilot uses AI. Check for mistakes.
}

sparkSession.sparkContext.cancelAllJobs()
jobThread.interrupt()
} finally {
sparkSession.stop()
}
}
}

test("serialize/deserialize GetReducerFileGroupResponse with broadcast") {
val sparkConf = new SparkConf().setAppName("rss-demo").setMaster("local[2,3]")
val sparkSession = SparkSession.builder()
Expand Down
Loading