diff --git a/backends-velox/src/main/scala/org/apache/spark/sql/execution/utils/ExecUtil.scala b/backends-velox/src/main/scala/org/apache/spark/sql/execution/utils/ExecUtil.scala index dcfa0ee525c0..d3e71d7aa650 100644 --- a/backends-velox/src/main/scala/org/apache/spark/sql/execution/utils/ExecUtil.scala +++ b/backends-velox/src/main/scala/org/apache/spark/sql/execution/utils/ExecUtil.scala @@ -24,12 +24,12 @@ import org.apache.gluten.memory.arrow.alloc.ArrowBufferAllocators import org.apache.gluten.runtime.Runtimes import org.apache.gluten.vectorized.{ArrowWritableColumnVector, NativeColumnarToRowInfo, NativeColumnarToRowJniWrapper, NativePartitioning} -import org.apache.spark.{Partitioner, RangePartitioner, ShuffleDependency} +import org.apache.spark.{KeyGroupedPartitioner, Partitioner, RangePartitioner, ShuffleDependency} import org.apache.spark.rdd.RDD import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.{ColumnarShuffleDependency, GlutenShuffleUtils} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.{Attribute, BindReferences, BoundReference, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.SQLExecution @@ -98,8 +98,8 @@ object ExecUtil { executionId, metrics("numPartitions") :: Nil) // scalastyle:on argcount - // only used for fallback range partitioning - val rangePartitioner: Option[Partitioner] = newPartitioning match { + // Partitioner for JVM-side partition ID computation (Range and KeyGrouped) + val fallbackPartitioner: Option[Partitioner] = newPartitioning match { case RangePartitioning(sortingExpressions, numPartitions) => // Extract only fields used for sorting to avoid collecting large fields that does not // affect sorting result when deciding partition bounds in RangePartitioner @@ -128,10 +128,22 @@ object ExecUtil { ascending = true, samplePointsPerPartitionHint = SQLConf.get.rangeExchangeSampleSizePerPartition) Some(part) + case k @ KeyGroupedPartitioning(_, n, _, _) => + // Build a lookup map from partition key values to partition indices. + // KeyGroupedPartitioner.getPartition uses getOrElseUpdate: if a key is not found + // in the map, it falls back to hash-based assignment (nonNegativeMod). This is a + // best-effort fallback - all expected keys should be present in uniquePartitionValues. 
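+      // Illustrative example (hypothetical values): for keys of type (INT, STRING)
+      // with unique partition values (1, "a") and (2, "b"), the map built below is
+      //   Map(Seq(1, "a") -> 0, Seq(2, "b") -> 1)
+      // so a row whose extracted key is Seq(2, "b") is routed to partition 1.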
+ val dataTypes = k.expressions.map(_.dataType) + val valueMap = scala.collection.mutable.Map.empty[Seq[Any], Int] + k.uniquePartitionValues.zipWithIndex.foreach { + case (partition, index) => + valueMap.update(partition.toSeq(dataTypes), index) + } + Some(new KeyGroupedPartitioner(valueMap, n)) case _ => None } - // only used for fallback range partitioning + // Used for JVM-side partition ID computation (Range and KeyGrouped) def computeAndAddPartitionId( cbIter: Iterator[ColumnarBatch], partitionKeyExtractor: InternalRow => Any): Iterator[(Int, ColumnarBatch)] = { @@ -146,7 +158,8 @@ object ExecUtil { .head convertColumnarToRow(cb).zipWithIndex.foreach { case (row, i) => - val pid = rangePartitioner.get.getPartition(partitionKeyExtractor(row)) + val pid = fallbackPartitioner.get + .getPartition(partitionKeyExtractor(row)) pidVec.putInt(i, pid) } val pidBatch = VeloxColumnarBatches.toVeloxBatch( @@ -172,6 +185,11 @@ object ExecUtil { // range partitioning fall back to row-based partition id computation case RangePartitioning(orders, n) => new NativePartitioning(GlutenShuffleUtils.RangePartitioningShortName, n) + // Key grouped partitioning reuses RangePartitioning's native short name because both + // use the same JVM-side partition ID computation pattern: the pid column is prepended + // on the JVM side and the native shuffle writer reads it rather than computing IDs natively. + case KeyGroupedPartitioning(_, n, _, _) => + new NativePartitioning(GlutenShuffleUtils.RangePartitioningShortName, n) } val isRoundRobin = newPartitioning.isInstanceOf[RoundRobinPartitioning] && @@ -197,6 +215,26 @@ object ExecUtil { }, isOrderSensitive = isOrderSensitive ) + case KeyGroupedPartitioning(expressions, _, _, _) => + rdd.mapPartitionsWithIndexInternal( + (_, cbIter) => { + val partitionKeyExtractor: InternalRow => Any = { + val boundExprs = BindReferences + .bindReferences(expressions, outputAttributes) + row => { + val keyValues = new Array[Any](boundExprs.length) + var i = 0 + while (i < boundExprs.length) { + keyValues(i) = boundExprs(i).eval(row) + i += 1 + } + keyValues.toImmutableArraySeq + } + } + computeAndAddPartitionId(cbIter, partitionKeyExtractor) + }, + isOrderSensitive = isOrderSensitive + ) case _ => rdd.mapPartitionsWithIndexInternal( (_, cbIter) => cbIter.map(cb => (0, cb)), diff --git a/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxShufflePartitioningSuite.scala b/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxShufflePartitioningSuite.scala new file mode 100644 index 000000000000..0c88a00bd593 --- /dev/null +++ b/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxShufflePartitioningSuite.scala @@ -0,0 +1,317 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.gluten.execution + +import org.apache.spark.{KeyGroupedPartitioner, SparkConf} +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, BindReferences, GenericInternalRow} +import org.apache.spark.sql.types.{DoubleType, IntegerType, LongType, StringType} +import org.apache.spark.unsafe.types.UTF8String + +/** + * SxS test suite for columnar shuffle exchange partitioning. + * + * Tests cover hash, range, round-robin, and single partitioning via columnar shuffle. + * KeyGroupedPartitioning is validated at the unit level (key extraction, partitioner construction) + * but cannot be exercised end-to-end without V2 data source connectors (Iceberg/Paimon) which + * require test infrastructure not available in this module. + */ +class VeloxShufflePartitioningSuite extends VeloxWholeStageTransformerSuite { + + override protected val resourcePath: String = "/tpch-data-parquet" + override protected val fileFormat: String = "parquet" + + override protected def sparkConf: SparkConf = { + super.sparkConf + .set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.ColumnarShuffleManager") + .set("spark.sql.shuffle.partitions", "4") + .set("spark.memory.offHeap.size", "2g") + .set("spark.sql.ansi.enabled", "false") + } + + override def beforeAll(): Unit = { + super.beforeAll() + createTPCHNotNullTables() + } + + // === A: Basic Correctness (Hash Partitioning) === + + test("A1: SxS hash shuffle via single-key GROUP BY") { + runQueryAndCompare("""SELECT l_returnflag, count(*) as cnt + |FROM lineitem + |GROUP BY l_returnflag""".stripMargin) { _ => } + } + + test("A2: SxS hash shuffle via multi-key GROUP BY") { + runQueryAndCompare("""SELECT l_returnflag, l_linestatus, + | sum(l_quantity) as total_qty + |FROM lineitem + |GROUP BY l_returnflag, l_linestatus""".stripMargin) { _ => } + } + + test("A3: SxS hash shuffle via column repartition") { + compareDfResultsAgainstVanillaSpark( + () => + spark + .sql("SELECT * FROM lineitem") + .repartition( + 4, + spark + .sql("SELECT * FROM lineitem") + .col("l_orderkey")), + compareResult = true, + customCheck = { _ => }, + noFallBack = true + ) + } + + test("A4: SxS hash shuffle via equi-JOIN") { + runQueryAndCompare("""SELECT l.l_orderkey, l.l_partkey + |FROM lineitem l + |JOIN orders o ON l.l_orderkey = o.o_orderkey + |WHERE o.o_orderpriority = '1-URGENT'""".stripMargin) { _ => } + } + + // === B: Range Partitioning === + + test("B1: SxS range shuffle via ORDER BY single column") { + runQueryAndCompare("""SELECT l_orderkey, l_quantity + |FROM lineitem + |ORDER BY l_quantity, l_orderkey""".stripMargin) { _ => } + } + + test("B2: SxS range shuffle via ORDER BY multiple columns") { + runQueryAndCompare("""SELECT l_returnflag, l_linestatus, + | l_quantity, l_orderkey + |FROM lineitem + |ORDER BY l_returnflag, l_linestatus, + | l_quantity, l_orderkey""".stripMargin) { _ => } + } + + test("B3: SxS range shuffle ORDER BY with LIMIT") { + runQueryAndCompare("""SELECT l_orderkey, l_extendedprice + |FROM lineitem + |ORDER BY l_extendedprice DESC, l_orderkey + |LIMIT 50""".stripMargin) { _ => } + } + + // === C: Round-Robin Partitioning === + + test("C1: SxS round-robin repartition") { + compareDfResultsAgainstVanillaSpark( + () => + spark + .sql("""SELECT l_orderkey, l_partkey, l_quantity + |FROM lineitem""".stripMargin) + .repartition(3), + compareResult = true, + customCheck = { _ => }, + noFallBack = true + ) + } + + test("C2: SxS round-robin with different partition count") { + compareDfResultsAgainstVanillaSpark( + () => + spark + 
.sql("SELECT l_orderkey FROM lineitem") + .repartition(7), + compareResult = true, + customCheck = { _ => }, + noFallBack = true) + } + + // === D: Single Partitioning === + + test("D1: SxS single partition via coalesce(1)") { + compareDfResultsAgainstVanillaSpark( + () => + spark + .sql("""SELECT l_orderkey, l_quantity + |FROM lineitem""".stripMargin) + .coalesce(1), + compareResult = true, + customCheck = { _ => }, + noFallBack = true + ) + } + + test("D2: SxS single partition via global aggregation") { + runQueryAndCompare("""SELECT count(*) as cnt, + | sum(l_quantity) as total_qty, + | avg(l_extendedprice) as avg_price + |FROM lineitem""".stripMargin) { _ => } + } + + // === E: Null Semantics === + + test("E1: SxS hash shuffle GROUP BY with NULLs") { + runQueryAndCompare("""SELECT l_comment, count(*) as cnt + |FROM lineitem + |GROUP BY l_comment""".stripMargin) { _ => } + } + + test("E2: SxS hash shuffle JOIN with NULL keys") { + runQueryAndCompare("""SELECT l.l_orderkey, l.l_partkey + |FROM lineitem l + |LEFT JOIN orders o + | ON l.l_orderkey = o.o_orderkey + |ORDER BY l.l_orderkey, l.l_partkey""".stripMargin) { _ => } + } + + // === F: Data Type Coverage === + + test("F1: SxS hash shuffle with decimal types") { + runQueryAndCompare("""SELECT l_extendedprice, + | sum(l_extendedprice * l_discount) as revenue + |FROM lineitem + |GROUP BY l_extendedprice""".stripMargin) { _ => } + } + + test("F2: SxS range shuffle with string ordering") { + runQueryAndCompare("""SELECT l_returnflag, l_linestatus, l_orderkey + |FROM lineitem + |ORDER BY l_returnflag, l_linestatus, + | l_orderkey""".stripMargin) { _ => } + } + + // === G: Boundary Cases === + + test("G1: SxS hash shuffle single partition") { + withSQLConf("spark.sql.shuffle.partitions" -> "1") { + runQueryAndCompare("""SELECT l_returnflag, count(*) as cnt + |FROM lineitem + |GROUP BY l_returnflag""".stripMargin) { _ => } + } + } + + test("G2: SxS hash shuffle many partitions") { + withSQLConf("spark.sql.shuffle.partitions" -> "32") { + runQueryAndCompare("""SELECT l_returnflag, count(*) as cnt + |FROM lineitem + |GROUP BY l_returnflag""".stripMargin) { _ => } + } + } + + // === H: KeyGroupedPartitioning Unit Tests === + // These test the key extraction logic used when + // KeyGroupedPartitioning is triggered by V2 connectors. 
+ + test("H1: key extractor single-column integer") { + val keyAttr = AttributeReference("key", IntegerType)() + val valAttr = AttributeReference("val", StringType)() + val outputAttrs = Seq(keyAttr, valAttr) + val boundExprs = BindReferences.bindReferences( + Seq(keyAttr.asInstanceOf[org.apache.spark.sql.catalyst.expressions.Expression]), + outputAttrs) + val row = new GenericInternalRow(Array[Any](42, UTF8String.fromString("hello"))) + val extracted = boundExprs.map(_.eval(row)) + assert(extracted == Seq(42)) + } + + test("H2: key extractor multi-column composite key") { + val k1 = AttributeReference("k1", IntegerType)() + val k2 = AttributeReference("k2", StringType)() + val valAttr = AttributeReference("val", IntegerType)() + val outputAttrs = Seq(k1, k2, valAttr) + val boundExprs = BindReferences.bindReferences( + Seq( + k1.asInstanceOf[org.apache.spark.sql.catalyst.expressions.Expression], + k2.asInstanceOf[org.apache.spark.sql.catalyst.expressions.Expression]), + outputAttrs + ) + val row = new GenericInternalRow(Array[Any](10, UTF8String.fromString("abc"), 999)) + val extracted = boundExprs.map(_.eval(row)) + assert(extracted == Seq(10, UTF8String.fromString("abc"))) + } + + test("H3: key extractor with null values") { + val k1 = AttributeReference("k1", IntegerType, nullable = true)() + val valAttr = AttributeReference("val", StringType)() + val outputAttrs = Seq(k1, valAttr) + val boundExprs = BindReferences.bindReferences( + Seq(k1.asInstanceOf[org.apache.spark.sql.catalyst.expressions.Expression]), + outputAttrs) + val row = new GenericInternalRow(Array[Any](null, UTF8String.fromString("hello"))) + val extracted = boundExprs.map(_.eval(row)) + assert(extracted == Seq(null)) + } + + test("H4: key extractor with long and double types") { + val k1 = AttributeReference("k1", LongType)() + val k2 = AttributeReference("k2", DoubleType)() + val valAttr = AttributeReference("val", IntegerType)() + val outputAttrs = Seq(k1, k2, valAttr) + val boundExprs = BindReferences.bindReferences( + Seq( + k1.asInstanceOf[org.apache.spark.sql.catalyst.expressions.Expression], + k2.asInstanceOf[org.apache.spark.sql.catalyst.expressions.Expression]), + outputAttrs + ) + val row = new GenericInternalRow(Array[Any](Long.MaxValue, 3.14d, 1)) + val extracted = boundExprs.map(_.eval(row)) + assert(extracted == Seq(Long.MaxValue, 3.14d)) + } + + test("H5: KeyGroupedPartitioner maps keys to correct partition IDs") { + val valueMap = scala.collection.mutable.Map.empty[Seq[Any], Int] + valueMap.update(Seq(1, "a"), 0) + valueMap.update(Seq(2, "b"), 1) + valueMap.update(Seq(3, "c"), 2) + val partitioner = new KeyGroupedPartitioner(valueMap, 3) + + assert(partitioner.getPartition(Seq(1, "a")) == 0) + assert(partitioner.getPartition(Seq(2, "b")) == 1) + assert(partitioner.getPartition(Seq(3, "c")) == 2) + assert(partitioner.numPartitions == 3) + } + + test("H6: KeyGroupedPartitioner end-to-end with key extraction") { + val k1 = AttributeReference("k1", IntegerType)() + val k2 = AttributeReference("k2", StringType)() + val valAttr = AttributeReference("val", IntegerType)() + val outputAttrs = Seq(k1, k2, valAttr) + val boundExprs = BindReferences.bindReferences( + Seq( + k1.asInstanceOf[org.apache.spark.sql.catalyst.expressions.Expression], + k2.asInstanceOf[org.apache.spark.sql.catalyst.expressions.Expression]), + outputAttrs + ) + + // Build partitioner with known keys + val valueMap = scala.collection.mutable.Map.empty[Seq[Any], Int] + valueMap.update(Seq(10, UTF8String.fromString("abc")), 0) + 
valueMap.update(Seq(20, UTF8String.fromString("def")), 1) + val partitioner = new KeyGroupedPartitioner(valueMap, 2) + + // Extract key from row and look up partition + val row = new GenericInternalRow(Array[Any](10, UTF8String.fromString("abc"), 999)) + val key = boundExprs.map(_.eval(row)).toSeq + val pid = partitioner.getPartition(key) + assert(pid == 0) + + val row2 = new GenericInternalRow(Array[Any](20, UTF8String.fromString("def"), 123)) + val key2 = boundExprs.map(_.eval(row2)).toSeq + val pid2 = partitioner.getPartition(key2) + assert(pid2 == 1) + } + + // === Helpers === + // SxS tests use runQueryAndCompare / compareDfResultsAgainstVanillaSpark + // from GlutenQueryComparisonTest which automatically verifies: + // 1. Result correctness (checkAnswer against vanilla Spark) + // 2. No fallback (FallbackUtil.hasFallback check on executed plan) +} diff --git a/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExecBase.scala b/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExecBase.scala index 24db757720be..391aaff991f1 100644 --- a/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExecBase.scala +++ b/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExecBase.scala @@ -104,6 +104,7 @@ abstract class ColumnarShuffleExchangeExecBase( case _: RangePartitioning => ValidationResult.succeeded case SinglePartition => ValidationResult.succeeded case _: RoundRobinPartitioning => ValidationResult.succeeded + case _: KeyGroupedPartitioning => ValidationResult.succeeded case _ => ValidationResult.failed( s"Unsupported partitioning ${outputPartitioning.getClass.getSimpleName}")