Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,8 @@ class VeloxCelebornColumnarShuffleWriter[K, V](
GlutenShuffleUtils.getStartPartitionId(dep.nativePartitioning, context.partitionId),
nativeBufferSize,
GlutenConfig.get.columnarShuffleReallocThreshold,
partitionWriterHandle
partitionWriterHandle,
false
)
case SortShuffleWriterType =>
shuffleWriterJniWrapper.createSortShuffleWriter(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,8 @@ protected void writeImpl(Iterator<Product2<K, V>> records) {
columnarDep.nativePartitioning(), partitionId),
nativeBufferSize,
reallocThreshold,
partitionWriterHandle);
partitionWriterHandle,
false);
}

runtime
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ class ColumnarShuffleWriter[K, V](

private val blockManager = SparkEnv.get.blockManager

private val rowBasedChecksumEnabled: Boolean = GlutenMapStatusUtil.isRowBasedChecksumEnabled

// Are we in the process of stopping? Because map tasks can call stop() with success = true
// and then call stop() with success = false if they get an exception, we want to make sure
// we don't try deleting files, etc twice.
Expand Down Expand Up @@ -192,7 +194,8 @@ class ColumnarShuffleWriter[K, V](
taskContext.partitionId),
nativeBufferSize,
reallocThreshold,
partitionWriterHandle
partitionWriterHandle,
rowBasedChecksumEnabled
)
}

Expand Down Expand Up @@ -277,11 +280,16 @@ class ColumnarShuffleWriter[K, V](
}
}

// The partitionLengths reported here are much larger than vanilla Spark's
// partitionLengths — roughly 3x. These values are sensitive inputs to AQE
// rules such as OptimizeSkewedJoin and DynamicJoinSelection, and may
// therefore affect the final plan.
mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths, mapId)
// Use native row-based checksums (order-independent, per-row XXH64) for MapStatus.
val rowChecksums = splitResult.getRowBasedChecksums
val aggregatedChecksum = if (rowChecksums != null && rowChecksums.nonEmpty) {
rowChecksums.foldLeft(0L)((acc, c) => acc * 31L + c)
} else 0L
mapStatus = GlutenMapStatusUtil.createMapStatus(
blockManager.shuffleServerId,
partitionLengths,
mapId,
aggregatedChecksum)
}

private def handleEmptyInput(): Unit = {
Expand All @@ -292,7 +300,11 @@ class ColumnarShuffleWriter[K, V](
partitionLengths,
Array[Long](),
null)
mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths, mapId)
mapStatus = GlutenMapStatusUtil.createMapStatus(
blockManager.shuffleServerId,
partitionLengths,
mapId,
0L)
}

@throws[IOException]
Expand Down
16 changes: 13 additions & 3 deletions cpp/core/jni/JniWrapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ jint JNI_OnLoad(JavaVM* vm, void* reserved) {
jniByteInputStreamClose = getMethodIdOrError(env, jniByteInputStreamClass, "close", "()V");

splitResultClass = createGlobalClassReferenceOrError(env, "Lorg/apache/gluten/vectorized/GlutenSplitResult;");
splitResultConstructor = getMethodIdOrError(env, splitResultClass, "<init>", "(JJJJJJJJJJDJ[J[J)V");
splitResultConstructor = getMethodIdOrError(env, splitResultClass, "<init>", "(JJJJJJJJJJDJ[J[J[J)V");

metricsBuilderClass = createGlobalClassReferenceOrError(env, "Lorg/apache/gluten/metrics/Metrics;");

Expand Down Expand Up @@ -990,7 +990,8 @@ JNIEXPORT jlong JNICALL Java_org_apache_gluten_vectorized_ShuffleWriterJniWrappe
jint startPartitionId,
jint splitBufferSize,
jdouble splitBufferReallocThreshold,
jlong partitionWriterHandle) {
jlong partitionWriterHandle,
jboolean rowBasedChecksumEnabled) {
JNI_METHOD_START
const auto ctx = getRuntime(env, wrapper);

Expand All @@ -1005,6 +1006,7 @@ JNIEXPORT jlong JNICALL Java_org_apache_gluten_vectorized_ShuffleWriterJniWrappe
startPartitionId,
splitBufferSize,
splitBufferReallocThreshold);
shuffleWriterOptions->rowBasedChecksumEnabled = rowBasedChecksumEnabled;

return ctx->saveObject(ctx->createShuffleWriter(numPartitions, partitionWriter, shuffleWriterOptions));
JNI_METHOD_END(kInvalidObjectHandle)
Expand Down Expand Up @@ -1159,6 +1161,13 @@ JNIEXPORT jobject JNICALL Java_org_apache_gluten_vectorized_ShuffleWriterJniWrap
auto rawSrc = reinterpret_cast<const jlong*>(rawPartitionLengths.data());
env->SetLongArrayRegion(rawPartitionLengthArr, 0, rawPartitionLengths.size(), rawSrc);

const auto& rowBasedChecksums = shuffleWriter->rowBasedChecksums();
auto rowBasedChecksumArr = env->NewLongArray(rowBasedChecksums.size());
if (!rowBasedChecksums.empty()) {
auto checksumSrc = reinterpret_cast<const jlong*>(rowBasedChecksums.data());
env->SetLongArrayRegion(rowBasedChecksumArr, 0, rowBasedChecksums.size(), checksumSrc);
}

jobject splitResult = env->NewObject(
splitResultClass,
splitResultConstructor,
Expand All @@ -1175,7 +1184,8 @@ JNIEXPORT jobject JNICALL Java_org_apache_gluten_vectorized_ShuffleWriterJniWrap
shuffleWriter->avgDictionaryFields(),
shuffleWriter->dictionarySize(),
partitionLengthArr,
rawPartitionLengthArr);
rawPartitionLengthArr,
rowBasedChecksumArr);

return splitResult;
JNI_METHOD_END(nullptr)
Expand Down
2 changes: 2 additions & 0 deletions cpp/core/shuffle/Options.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ struct ShuffleWriterOptions {
ShuffleWriterType shuffleWriterType;
Partitioning partitioning = Partitioning::kRoundRobin;
int32_t startPartitionId = 0;
bool rowBasedChecksumEnabled = false;

ShuffleWriterOptions(ShuffleWriterType shuffleWriterType) : shuffleWriterType(shuffleWriterType) {}

Expand Down Expand Up @@ -224,5 +225,6 @@ struct ShuffleWriterMetrics {
int64_t dictionarySize{0};
std::vector<int64_t> partitionLengths{};
std::vector<int64_t> rawPartitionLengths{}; // Uncompressed size.
std::vector<int64_t> rowBasedChecksums{}; // Per-partition row-based checksums.
};
} // namespace gluten
4 changes: 4 additions & 0 deletions cpp/core/shuffle/ShuffleWriter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,10 @@ const std::vector<int64_t>& ShuffleWriter::rawPartitionLengths() const {
return metrics_.rawPartitionLengths;
}

// Returns the per-partition row-based checksums recorded in the writer's
// metrics. May be empty when row-based checksum computation is disabled —
// TODO confirm against the concrete writer implementation.
const std::vector<int64_t>& ShuffleWriter::rowBasedChecksums() const {
  return metrics_.rowBasedChecksums;
}

ShuffleWriter::ShuffleWriter(int32_t numPartitions, Partitioning partitioning)
: numPartitions_(numPartitions), partitioning_(partitioning) {}
} // namespace gluten
2 changes: 2 additions & 0 deletions cpp/core/shuffle/ShuffleWriter.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ class ShuffleWriter : public Reclaimable {

const std::vector<int64_t>& rawPartitionLengths() const;

const std::vector<int64_t>& rowBasedChecksums() const;

protected:
ShuffleWriter(int32_t numPartitions, Partitioning partitioning);

Expand Down
64 changes: 64 additions & 0 deletions cpp/velox/shuffle/VeloxHashShuffleWriter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
#include "utils/VeloxArrowUtils.h"
#include "velox/buffer/Buffer.h"
#include "velox/common/base/Nulls.h"
#include "velox/external/xxhash/xxhash.h"
#include "velox/row/UnsafeRowFast.h"
#include "velox/type/HugeInt.h"
#include "velox/type/Timestamp.h"
#include "velox/type/Type.h"
Expand Down Expand Up @@ -182,6 +184,11 @@ arrow::Status VeloxHashShuffleWriter::init() {

partitionBufferBase_.resize(numPartitions_);

if (rowBasedChecksumEnabled_) {
checksumXor_.resize(numPartitions_, 0);
checksumSum_.resize(numPartitions_, 0);
}

return arrow::Status::OK();
}

Expand Down Expand Up @@ -361,6 +368,17 @@ arrow::Status VeloxHashShuffleWriter::stop() {

stat();

// Populate row-based checksums into metrics.
if (rowBasedChecksumEnabled_) {
metrics_.rowBasedChecksums.resize(numPartitions_);
for (auto pid = 0; pid < numPartitions_; ++pid) {
int64_t xorVal = checksumXor_[pid];
int64_t sumVal = checksumSum_[pid];
int64_t rotated = (static_cast<uint64_t>(sumVal) << 27) | (static_cast<uint64_t>(sumVal) >> 37);
metrics_.rowBasedChecksums[pid] = xorVal ^ rotated;
}
}

return arrow::Status::OK();
}

Expand Down Expand Up @@ -422,6 +440,7 @@ void VeloxHashShuffleWriter::setSplitState(SplitState state) {
arrow::Status VeloxHashShuffleWriter::doSplit(const facebook::velox::RowVector& rv, int64_t memLimit) {
auto rowNum = rv.size();
RETURN_NOT_OK(buildPartition2Row(rowNum));
computeRowBasedChecksums(rv);
RETURN_NOT_OK(updateInputHasNull(rv));

{
Expand Down Expand Up @@ -1503,4 +1522,49 @@ bool VeloxHashShuffleWriter::isExtremelyLargeBatch(facebook::velox::RowVectorPtr
return (rv->size() > maxBatchSize_ && maxBatchSize_ > 0);
}

// Computes a per-row XXH64 hash over each input row's UnsafeRow serialization
// and folds it into the per-partition checksum state (checksumXor_ /
// checksumSum_). No-op unless row-based checksums are enabled.
//
// The hash input deliberately excludes the partition-id column (when the
// partitioner provides one) so the checksum covers only payload data.
void VeloxHashShuffleWriter::computeRowBasedChecksums(const facebook::velox::RowVector& rv) {
  if (!rowBasedChecksumEnabled_) {
    return;
  }

  const auto numRows = rv.size();
  VELOX_DCHECK(rv.nulls() == nullptr, "RowVector with top-level nulls not supported for checksum");

  // Build the RowVector to serialize (strip the pid column if present).
  auto rowType = std::dynamic_pointer_cast<const facebook::velox::RowType>(rv.type());
  facebook::velox::RowVectorPtr dataVector;
  if (partitioner_->hasPid()) {
    // The first child carries the precomputed partition id; exclude it.
    std::vector<std::string> names(rowType->names().begin() + 1, rowType->names().end());
    std::vector<facebook::velox::TypePtr> types(rowType->children().begin() + 1, rowType->children().end());
    std::vector<facebook::velox::VectorPtr> children(rv.children().begin() + 1, rv.children().end());
    auto dataType = facebook::velox::ROW(std::move(names), std::move(types));
    dataVector =
        std::make_shared<facebook::velox::RowVector>(rv.pool(), dataType, nullptr, numRows, std::move(children));
  } else {
    dataVector = std::make_shared<facebook::velox::RowVector>(rv.pool(), rowType, nullptr, numRows, rv.children());
  }

  facebook::velox::row::UnsafeRowFast fast(dataVector);
  auto dataType = std::dynamic_pointer_cast<const facebook::velox::RowType>(dataVector->type());
  auto fixedSize = facebook::velox::row::UnsafeRowFast::fixedRowSize(dataType);
  // Pre-size the scratch buffer: exact for fixed-size rows, heuristic otherwise
  // (grown below on demand for larger variable-size rows).
  const int32_t bufSize = fixedSize.value_or(1024);
  if (checksumBuffer_.size() < static_cast<size_t>(bufSize)) {
    checksumBuffer_.resize(bufSize);
  }

  for (uint32_t row = 0; row < numRows; ++row) {
    const auto pid = row2Partition_[row];
    const auto size = fast.rowSize(row);
    // Compare in size_t to avoid truncating the buffer size to int32_t.
    if (static_cast<size_t>(size) > checksumBuffer_.size()) {
      checksumBuffer_.resize(size);
    }
    fast.serialize(row, checksumBuffer_.data());

    const auto hash = static_cast<int64_t>(XXH64(checksumBuffer_.data(), size, 0));
    checksumXor_[pid] ^= hash;
    // Accumulate in unsigned arithmetic: summing many int64_t hashes can
    // overflow, which is undefined behavior for signed types. Unsigned
    // addition wraps modulo 2^64, giving a well-defined, deterministic sum.
    checksumSum_[pid] =
        static_cast<int64_t>(static_cast<uint64_t>(checksumSum_[pid]) + static_cast<uint64_t>(hash));
  }
}

} // namespace gluten
11 changes: 10 additions & 1 deletion cpp/velox/shuffle/VeloxHashShuffleWriter.h
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,8 @@ class VeloxHashShuffleWriter : public VeloxShuffleWriter {
MemoryManager* memoryManager)
: VeloxShuffleWriter(numPartitions, partitionWriter, options, memoryManager),
splitBufferSize_(options->splitBufferSize),
splitBufferReallocThreshold_(options->splitBufferReallocThreshold) {
splitBufferReallocThreshold_(options->splitBufferReallocThreshold),
rowBasedChecksumEnabled_(options->rowBasedChecksumEnabled) {
arenas_.resize(numPartitions);
}

Expand Down Expand Up @@ -436,6 +437,14 @@ class VeloxHashShuffleWriter : public VeloxShuffleWriter {
std::optional<uint32_t> partitionBufferInUse_{std::nullopt};

std::vector<std::unique_ptr<facebook::velox::StreamArena>> arenas_;

// Row-based checksum state (per-partition XOR + SUM aggregation).
bool rowBasedChecksumEnabled_{false};
std::vector<int64_t> checksumXor_;
std::vector<int64_t> checksumSum_;
std::vector<char> checksumBuffer_;

void computeRowBasedChecksums(const facebook::velox::RowVector& rv);
}; // class VeloxHashBasedShuffleWriter

} // namespace gluten
1 change: 1 addition & 0 deletions cpp/velox/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ add_velox_test(spark_functions_test SOURCES SparkFunctionTest.cc
add_velox_test(runtime_test SOURCES RuntimeTest.cc)
add_velox_test(velox_memory_test SOURCES MemoryManagerTest.cc)
add_velox_test(buffer_outputstream_test SOURCES BufferOutputStreamTest.cc)
add_velox_test(row_based_checksum_test SOURCES RowBasedChecksumTest.cc)
if(BUILD_EXAMPLES)
add_velox_test(my_udf_test SOURCES MyUdfTest.cc)
endif()
Expand Down
Loading
Loading