diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala index 9ce400de56e4..d63928527df5 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala @@ -108,7 +108,7 @@ object VeloxRuleApi { offloads)) // Legacy: Post-transform rules. - injector.injectPostTransform(_ => AppendBatchResizeForShuffleInputAndOutput()) + injector.injectPostTransform(c => AppendBatchResizeForShuffleInputAndOutput(c.caller.isAqe())) injector.injectPostTransform(_ => GpuBufferBatchResizeForShuffleInputOutput()) injector.injectPostTransform(_ => UnionTransformerRule()) injector.injectPostTransform(_ => PartialFallbackRules()) diff --git a/backends-velox/src/main/scala/org/apache/gluten/extension/AppendBatchResizeForShuffleInputAndOutput.scala b/backends-velox/src/main/scala/org/apache/gluten/extension/AppendBatchResizeForShuffleInputAndOutput.scala index fcce64d65222..a7309c341a9b 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/extension/AppendBatchResizeForShuffleInputAndOutput.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/extension/AppendBatchResizeForShuffleInputAndOutput.scala @@ -22,17 +22,18 @@ import org.apache.gluten.execution.VeloxResizeBatchesExec import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.{ColumnarShuffleExchangeExec, SparkPlan} import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, ShuffleQueryStageExec} -import org.apache.spark.sql.execution.exchange.ReusedExchangeExec /** * Try to append [[VeloxResizeBatchesExec]] for shuffle input and output to make the batch sizes in * good shape. */ -case class AppendBatchResizeForShuffleInputAndOutput() extends Rule[SparkPlan] { +case class AppendBatchResizeForShuffleInputAndOutput(isAdaptiveContext: Boolean) + extends Rule[SparkPlan] { override def apply(plan: SparkPlan): SparkPlan = { if (VeloxConfig.get.enableColumnarCudf) { return plan } + val resizeBatchesShuffleInputEnabled = VeloxConfig.get.veloxResizeBatchesShuffleInput val resizeBatchesShuffleOutputEnabled = VeloxConfig.get.veloxResizeBatchesShuffleOutput if (!resizeBatchesShuffleInputEnabled && !resizeBatchesShuffleOutputEnabled) { @@ -41,65 +42,59 @@ case class AppendBatchResizeForShuffleInputAndOutput() extends Rule[SparkPlan] { val range = VeloxConfig.get.veloxResizeBatchesShuffleInputOutputRange val preferredBatchBytes = VeloxConfig.get.veloxPreferredBatchBytes + + val newPlan = if (resizeBatchesShuffleInputEnabled) { + addResizeBatchesForShuffleInput(plan, range.min, range.max, preferredBatchBytes) + } else { + plan + } + + val resultPlan = if (resizeBatchesShuffleOutputEnabled) { + addResizeBatchesForShuffleOutput(newPlan, range.min, range.max, preferredBatchBytes) + } else { + newPlan + } + + resultPlan + } + + private def addResizeBatchesForShuffleInput( + plan: SparkPlan, + min: Int, + max: Int, + preferredBatchBytes: Long): SparkPlan = { plan.transformUp { case shuffle: ColumnarShuffleExchangeExec - if resizeBatchesShuffleInputEnabled && - shuffle.shuffleWriterType.requiresResizingShuffleInput => + if shuffle.shuffleWriterType.requiresResizingShuffleInput => val appendBatches = - VeloxResizeBatchesExec(shuffle.child, range.min, range.max, preferredBatchBytes) + VeloxResizeBatchesExec(shuffle.child, min, max, preferredBatchBytes) shuffle.withNewChildren(Seq(appendBatches)) - case a @ AQEShuffleReadExec( - ShuffleQueryStageExec(_, shuffle: ColumnarShuffleExchangeExec, _), - _) - if resizeBatchesShuffleOutputEnabled && - shuffle.shuffleWriterType.requiresResizingShuffleOutput => - VeloxResizeBatchesExec(a, range.min, range.max, preferredBatchBytes) - case a @ AQEShuffleReadExec( - ShuffleQueryStageExec( - _, - ReusedExchangeExec(_, shuffle: ColumnarShuffleExchangeExec), - _), - _) - if resizeBatchesShuffleOutputEnabled && - shuffle.shuffleWriterType.requiresResizingShuffleOutput => - VeloxResizeBatchesExec(a, range.min, range.max, preferredBatchBytes) - // Since it's transformed in a bottom to up order, so we may first encounter - // ShuffeQueryStageExec, which is transformed to VeloxResizeBatchesExec(ShuffeQueryStageExec), - // then we see AQEShuffleReadExec - case a @ AQEShuffleReadExec( - VeloxResizeBatchesExec( - s @ ShuffleQueryStageExec(_, shuffle: ColumnarShuffleExchangeExec, _), - _, - _, - _), - _) - if resizeBatchesShuffleOutputEnabled && - shuffle.shuffleWriterType.requiresResizingShuffleOutput => - VeloxResizeBatchesExec(a.copy(child = s), range.min, range.max, preferredBatchBytes) - case a @ AQEShuffleReadExec( - VeloxResizeBatchesExec( - s @ ShuffleQueryStageExec( - _, - ReusedExchangeExec(_, shuffle: ColumnarShuffleExchangeExec), - _), - _, - _, - _), - _) - if resizeBatchesShuffleOutputEnabled && - shuffle.shuffleWriterType.requiresResizingShuffleOutput => - VeloxResizeBatchesExec(a.copy(child = s), range.min, range.max, preferredBatchBytes) - case s @ ShuffleQueryStageExec(_, shuffle: ColumnarShuffleExchangeExec, _) - if resizeBatchesShuffleOutputEnabled && - shuffle.shuffleWriterType.requiresResizingShuffleOutput => - VeloxResizeBatchesExec(s, range.min, range.max, preferredBatchBytes) - case s @ ShuffleQueryStageExec( - _, - ReusedExchangeExec(_, shuffle: ColumnarShuffleExchangeExec), - _) - if resizeBatchesShuffleOutputEnabled && - shuffle.shuffleWriterType.requiresResizingShuffleOutput => - VeloxResizeBatchesExec(s, range.min, range.max, preferredBatchBytes) + } + } + + private def addResizeBatchesForShuffleOutput( + plan: SparkPlan, + min: Int, + max: Int, + preferredBatchBytes: Long): SparkPlan = { + plan match { + case s: ShuffleQueryStageExec if requiresResizingShuffleOutput(s) => + VeloxResizeBatchesExec(s, min, max, preferredBatchBytes) + case a @ AQEShuffleReadExec(s @ ShuffleQueryStageExec(_, _, _), _) + if requiresResizingShuffleOutput(s) => + VeloxResizeBatchesExec(a, min, max, preferredBatchBytes) + case other => + other.withNewChildren(other.children.map( + p => addResizeBatchesForShuffleOutput(p, min, max, preferredBatchBytes))) + } + } + + private def requiresResizingShuffleOutput(s: ShuffleQueryStageExec): Boolean = { + s.shuffle match { + case c: ColumnarShuffleExchangeExec + if c.shuffleWriterType.requiresResizingShuffleOutput => + true + case _ => false } } }