@@ -24,12 +24,12 @@ import org.apache.gluten.memory.arrow.alloc.ArrowBufferAllocators
import org.apache.gluten.runtime.Runtimes
import org.apache.gluten.vectorized.{ArrowWritableColumnVector, NativeColumnarToRowInfo, NativeColumnarToRowJniWrapper, NativePartitioning}

import org.apache.spark.{Partitioner, RangePartitioner, ShuffleDependency}
import org.apache.spark.{KeyGroupedPartitioner, Partitioner, RangePartitioner, ShuffleDependency}
import org.apache.spark.rdd.RDD
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.{ColumnarShuffleDependency, GlutenShuffleUtils}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.catalyst.expressions.{Attribute, BindReferences, BoundReference, UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.SQLExecution
@@ -98,8 +98,8 @@ object ExecUtil {
executionId,
metrics("numPartitions") :: Nil)
// scalastyle:on argcount
// only used for fallback range partitioning
val rangePartitioner: Option[Partitioner] = newPartitioning match {
// Partitioner for JVM-side partition ID computation (Range and KeyGrouped)
val fallbackPartitioner: Option[Partitioner] = newPartitioning match {
case RangePartitioning(sortingExpressions, numPartitions) =>
// Extract only fields used for sorting to avoid collecting large fields that do not
// affect sorting result when deciding partition bounds in RangePartitioner
@@ -128,10 +128,22 @@
ascending = true,
samplePointsPerPartitionHint = SQLConf.get.rangeExchangeSampleSizePerPartition)
Some(part)
case k @ KeyGroupedPartitioning(_, n, _, _) =>
// Build a lookup map from partition key values to partition indices.
// KeyGroupedPartitioner.getPartition uses getOrElseUpdate: if a key is not found
// in the map, it falls back to hash-based assignment (nonNegativeMod). This is a
// best-effort fallback - all expected keys should be present in uniquePartitionValues.
val dataTypes = k.expressions.map(_.dataType)
val valueMap = scala.collection.mutable.Map.empty[Seq[Any], Int]
k.uniquePartitionValues.zipWithIndex.foreach {
case (partition, index) =>
valueMap.update(partition.toSeq(dataTypes), index)
}
Some(new KeyGroupedPartitioner(valueMap, n))
case _ => None
}
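A minimal sketch of the lookup behavior the comment above describes, assuming only what it states about KeyGroupedPartitioner.getPartition: a map hit returns the recorded partition index, and an unseen key falls back to a memoized hash-based assignment. The class and method below are illustrative, not Spark's implementation.

import scala.collection.mutable

class KeyLookupPartitionerSketch(valueMap: mutable.Map[Seq[Any], Int], numPartitions: Int) {
  private def nonNegativeMod(x: Int, mod: Int): Int = {
    val raw = x % mod
    raw + (if (raw < 0) mod else 0)
  }

  def getPartition(key: Any): Int = {
    val keySeq = key.asInstanceOf[Seq[Any]]
    // Known keys resolve to their recorded index; unseen keys get a memoized hash-based partition.
    valueMap.getOrElseUpdate(keySeq, nonNegativeMod(keySeq.hashCode(), numPartitions))
  }
}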

// only used for fallback range partitioning
// Used for JVM-side partition ID computation (Range and KeyGrouped)
def computeAndAddPartitionId(
cbIter: Iterator[ColumnarBatch],
partitionKeyExtractor: InternalRow => Any): Iterator[(Int, ColumnarBatch)] = {
@@ -146,7 +158,8 @@
.head
convertColumnarToRow(cb).zipWithIndex.foreach {
case (row, i) =>
val pid = rangePartitioner.get.getPartition(partitionKeyExtractor(row))
val pid = fallbackPartitioner.get
.getPartition(partitionKeyExtractor(row))
pidVec.putInt(i, pid)
}
val pidBatch = VeloxColumnarBatches.toVeloxBatch(
@@ -172,6 +185,11 @@
// range partitioning fall back to row-based partition id computation
case RangePartitioning(orders, n) =>
new NativePartitioning(GlutenShuffleUtils.RangePartitioningShortName, n)
// Key grouped partitioning reuses RangePartitioning's native short name because both
// use the same JVM-side partition ID computation pattern: the pid column is prepended
// on the JVM side and the native shuffle writer reads it rather than computing IDs natively.
case KeyGroupedPartitioning(_, n, _, _) =>
new NativePartitioning(GlutenShuffleUtils.RangePartitioningShortName, n)
}
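The shared JVM-side pattern that lets key grouped partitioning reuse the range short name can be sketched as follows: partition ids are computed per row with whichever Partitioner was built above (RangePartitioner or KeyGroupedPartitioner), and those ids fill the pid column that computeAndAddPartitionId prepends to each batch. The helper name is hypothetical; Partitioner and InternalRow are the Spark types already imported in this file.

// Hypothetical helper showing the per-row pid computation both fallbacks share; the real
// code writes these values into pidVec before building the pid batch.
def partitionIdsForBatch(
    rows: Iterator[InternalRow],
    partitioner: Partitioner,
    keyOf: InternalRow => Any): Array[Int] =
  rows.map(row => partitioner.getPartition(keyOf(row))).toArray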

val isRoundRobin = newPartitioning.isInstanceOf[RoundRobinPartitioning] &&
@@ -197,6 +215,26 @@
},
isOrderSensitive = isOrderSensitive
)
case KeyGroupedPartitioning(expressions, _, _, _) =>
rdd.mapPartitionsWithIndexInternal(
(_, cbIter) => {
val partitionKeyExtractor: InternalRow => Any = {
val boundExprs = BindReferences
.bindReferences(expressions, outputAttributes)
row => {
val keyValues = new Array[Any](boundExprs.length)
var i = 0
while (i < boundExprs.length) {
keyValues(i) = boundExprs(i).eval(row)
i += 1
}
keyValues.toImmutableArraySeq
}
}
computeAndAddPartitionId(cbIter, partitionKeyExtractor)
},
isOrderSensitive = isOrderSensitive
)
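The map lookup only hits when the key Seq built by this extractor equals the Seq stored from uniquePartitionValues: Seq equality is element-wise, so the evaluated bound expressions must yield the same internal values (e.g. UTF8String for string columns) as partition.toSeq(dataTypes). A toy illustration of that equality, with made-up values:

// Toy values only: the left side stands in for partition.toSeq(dataTypes), the right side
// for the Array of evaluated bound expressions converted to an immutable Seq.
val storedKey: Seq[Any] = Vector(2024, 7L)
val extractedKey: Seq[Any] = Array[Any](2024, 7L).toSeq
assert(storedKey == extractedKey) // element-wise Seq equality => the partition id is found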
case _ =>
rdd.mapPartitionsWithIndexInternal(
(_, cbIter) => cbIter.map(cb => (0, cb)),