diff --git a/.gitignore b/.gitignore
index a48d1f65f8..66a44247c8 100644
--- a/.gitignore
+++ b/.gitignore
@@ -320,4 +320,8 @@ global.json
tests/llm/*/
# agent instructions
-.github/instructions/
\ No newline at end of file
+.github/instructions/
+
+# AI
+CLAUDE.md
+AGENTS.md
diff --git a/modules/Nncase.Modules.NTT/CodeGen/CPU/CSourceConvertVisitor.cs b/modules/Nncase.Modules.NTT/CodeGen/CPU/CSourceConvertVisitor.cs
index 3ebc69ab4a..038db85702 100644
--- a/modules/Nncase.Modules.NTT/CodeGen/CPU/CSourceConvertVisitor.cs
+++ b/modules/Nncase.Modules.NTT/CodeGen/CPU/CSourceConvertVisitor.cs
@@ -119,7 +119,7 @@ public void IndWrite(string? value)
///
/// convert single prim function to c source.
///
-public abstract class CSourceConvertVisitor : ExprFunctor
+public class CSourceConvertVisitor : ExprFunctor
{
protected readonly Dictionary _exprMemo = new(ReferenceEqualityComparer.Instance);
diff --git a/modules/Nncase.Modules.NTT/CodeGen/CPU/KernelCSourceConvertVisitor.cs b/modules/Nncase.Modules.NTT/CodeGen/CPU/KernelCSourceConvertVisitor.cs
index c3f3c1db12..346e4bffa7 100644
--- a/modules/Nncase.Modules.NTT/CodeGen/CPU/KernelCSourceConvertVisitor.cs
+++ b/modules/Nncase.Modules.NTT/CodeGen/CPU/KernelCSourceConvertVisitor.cs
@@ -401,18 +401,12 @@ protected override CSymbol VisitCall(Call expr)
break;
case TIR.NTT.Matmul matmul:
{
- var dimInfo = IR.NTT.VectorizedMatMul.GetDimInfo(matmul.TransposeA, matmul.TransposeB, args[0].CheckedShape.Rank, args[1].CheckedShape.Rank);
WriteWithProfiler(
RazorTemplateEngine.RenderAsync("~/CodeGen/CPU/Templates/Kernels/Matmul.cshtml", new TypedKernelTemplateModel(matmul)
{
Arguments = args.Select(x => new KernelArgument { Symbol = VisitBuffer(x, local: true) }).ToArray(),
}).Result,
"matmul");
- if (args[0] is TIR.Buffer a && a.DistributedType?.AxisPolicies[dimInfo.Lk] is SBPSplit s)
- {
- var reduceKind = "tar::reduce_kind::" + string.Join("_", Enumerable.Range(0, TargetOptions.HierarchyNames.Length).Select(i => (s.Axes.Contains(i) ? "r" : string.Empty) + TargetOptions.HierarchyNames[i]));
- WriteIndWithProfiler($"tac::tensor_reduce_sync({VisitBuffer(args[2], local: true).Name}, {VisitBuffer(args[2], local: true).Name});\n");
- }
}
break;
@@ -428,18 +422,12 @@ protected override CSymbol VisitCall(Call expr)
break;
case TIR.NTT.PackedMatMul matmul:
{
- var dimInfo = IR.NTT.VectorizedMatMul.GetDimInfo(false, true, args[0].CheckedShape.Rank, args[1].CheckedShape.Rank);
WriteWithProfiler(
RazorTemplateEngine.RenderAsync("~/CodeGen/CPU/Templates/Kernels/PackedMatMul.cshtml", new TypedKernelTemplateModel(matmul)
{
Arguments = args.Select(x => new KernelArgument { Symbol = VisitBuffer(x, local: true) }).ToArray(),
}).Result,
"packed_matmul");
- if (args[0] is TIR.Buffer a && a.DistributedType?.AxisPolicies[dimInfo.Lk] is SBPSplit s)
- {
- var reduceKind = "tar::reduce_kind::" + string.Join("_", Enumerable.Range(0, TargetOptions.HierarchyNames.Length).Select(i => (s.Axes.Contains(i) ? "r" : string.Empty) + TargetOptions.HierarchyNames[i]));
- WriteIndWithProfiler($"tac::tensor_reduce_sync({VisitBuffer(args[2], local: true).Name}, {VisitBuffer(args[2], local: true).Name});\n");
- }
}
break;
@@ -449,11 +437,6 @@ protected override CSymbol VisitCall(Call expr)
case TIR.NTT.Gather gather:
{
WriteWithProfiler($"gather({VisitBuffer(args[0], local: false).Name}, {VisitBuffer(args[1], local: true).Name}, {VisitBuffer(args[2], local: true).Name}, {gather.Axis}_dim);\n");
- if (args[0] is TIR.Buffer b && b.DistributedType?.AxisPolicies[gather.Axis] is SBPSplit s)
- {
- var reduceKind = "tar::reduce_kind::" + string.Join("_", Enumerable.Range(0, TargetOptions.HierarchyNames.Length).Select(i => (s.Axes.Contains(i) ? "r" : string.Empty) + TargetOptions.HierarchyNames[i]));
- WriteIndWithProfiler($"tac::tensor_reduce_sync({VisitBuffer(args[2], local: true).Name}, {VisitBuffer(args[2], local: true).Name});\n");
- }
}
break;
@@ -534,12 +517,11 @@ protected override CSymbol VisitCall(Call expr)
break;
case TIR.NTT.GatherReduceScatter grs:
{
- if (grs.InType.AxisPolicies.Any(s => s is SBPPartial))
+ if (grs.InType.Partial is not null)
{
- // deprecated
- var sbpPartial = (SBPPartial)grs.InType.AxisPolicies.Where(s => s is SBPPartial).Distinct().First();
- var reduceKind = "tar::reduce_kind::" + string.Join("_", grs.InType.AxisPolicies.Select((s, i) => (s is SBPPartial ? "r" : string.Empty) + TargetOptions.HierarchyNames[i]));
- WriteIndWithProfiler($"tac::tensor_reduce_sync({VisitBuffer(args[0], local: true).Name}, {VisitBuffer(args[1], local: true).Name});\n");
+ var sbpPartial = grs.InType.Partial;
+ var reduceKind = "tar::reduce_kind::" + string.Join("_", Enumerable.Range(0, TargetOptions.HierarchyNames.Length).Select(i => (sbpPartial.Axes.Contains(i) ? "r" : string.Empty) + TargetOptions.HierarchyNames[i]));
+ WriteIndWithProfiler($"tac::tensor_reduce_sync({VisitBuffer(args[0], local: false).Name}, {VisitBuffer(args[1], local: false).Name});\n");
}
else
{
diff --git a/modules/Nncase.Modules.NTT/CodeGen/CPU/KernelUtility.cs b/modules/Nncase.Modules.NTT/CodeGen/CPU/KernelUtility.cs
index d47e5a2749..664070e994 100644
--- a/modules/Nncase.Modules.NTT/CodeGen/CPU/KernelUtility.cs
+++ b/modules/Nncase.Modules.NTT/CodeGen/CPU/KernelUtility.cs
@@ -35,7 +35,7 @@ public static string SBPToC(this SBP value)
{
if (value is SBPSplit s)
{
- return $"S<{string.Join(", ", s.Axes)}>()";
+ return $"S<{string.Join(", ", s.Axes)}>({new CSourceConvertVisitor().Visit(s.Granularity as BaseExpr ?? None.Default).Name})";
}
else
{
diff --git a/modules/Nncase.Modules.NTT/CodeGen/CPU/Templates/topo_aware_runtime.cshtml b/modules/Nncase.Modules.NTT/CodeGen/CPU/Templates/topo_aware_runtime.cshtml
index 0ad5b88376..90dea517ad 100644
--- a/modules/Nncase.Modules.NTT/CodeGen/CPU/Templates/topo_aware_runtime.cshtml
+++ b/modules/Nncase.Modules.NTT/CodeGen/CPU/Templates/topo_aware_runtime.cshtml
@@ -58,8 +58,6 @@ enum reduce_kind {
};
constexpr std::array Hierarchy = {@(string.Join(", ", hierarchy))};
-auto src_ptr_tensor = nncase::ntt::make_tensor(nncase::ntt::fixed_shape_v<@(string.Join(",", hierarchy))>);
-auto dest_ptr_tensor = nncase::ntt::make_tensor(nncase::ntt::fixed_shape_v<@(string.Join(",", hierarchy))>);
template static std::byte *get_cache_address() {
return reinterpret_cast(
@@ -159,8 +157,8 @@ class tensor_reduce_sync_impl {
var cur_index = string.Join(", ", Enumerable.Range(0, hierarchy.Length).Select(i => "ntt::distributed::" + hierarchyNames[i] + "id()"));
}
- template
- void reduce_impl(TSliceIn &local, TSliceIn &remote, TSliceOut &dest) {
+ template
+ void reduce_impl(TSliceIn1 &local, TSliceIn2 &remote, TSliceOut &dest) {
if constexpr (Op == ntt::reduce_op::max) {
ntt::binary(local, remote, dest);
} else if constexpr (Op == ntt::reduce_op::sum ||
@@ -177,112 +175,141 @@ class tensor_reduce_sync_impl {
// collect all tensors pointer for access tensor from other nodes.
using TElem = typename TIn::element_type;
using TOutBase = std::decay_t;
+ static_assert(ShardedTensor, "dest must be sharded tensor");
constexpr size_t Rank = TIn::rank();
constexpr auto group_hierarchy = group_hierarchy_getter::group_hierarchy;
auto cur_index = ntt::make_shape(@(cur_index));
auto cur_index_g = index_global2group(cur_index);
- tar::src_ptr_tensor(cur_index) =
- reinterpret_cast(src.elements().data());
- tar::dest_ptr_tensor(cur_index) =
- reinterpret_cast(dest.elements().data());
reduce_group_sync();
// according to the group size split the tensor.
// todo should using better split strategy.
constexpr auto group_size = ntt::fixed_dim_v;
- const auto axis = [&] {
- dim_t axis = -1;
- loop([&](auto i) {
- if (axis == -1 && src.shape()[i] >= group_size) {
- axis = i;
- }
- });
- if (axis == -1) {
- axis = 0;
- }
- return axis;
- }();
-
- auto remain = src.shape()[axis] % (group_size);
- auto frac = src.shape()[axis] / (group_size);
- auto node_number_g = ntt::linear_offset(cur_index_g, group_hierarchy);
-
- // reduce-scatter, communicate (group_size - 1) times
- for (auto i = 0; i < group_size - 1; i++)
- {
- auto new_shape = ntt::generate_shape([&](auto j) {
- if (j == axis) {
- return ntt::where(node_number_g == group_size - 1, frac + remain, frac);
+ if (!src.shape().aggregate(true, [&](auto same, auto s, auto i){ return same && std::is_same_v(src.sharding().axis_policies))>, std::decay_t(dest.sharding().axis_policies))>>; })) {
+ using mesh_type = typename TOutBase::mesh_type;
+ const auto local_shard_index = mesh_type::local_index();
+ auto node_number_g = ntt::linear_offset(cur_index_g, group_hierarchy);
+ auto new_shape = dest.local().shape();
+ auto starts_src = src.sharding().global_offset(src.shape(), local_shard_index);
+ auto starts_dest = dest.sharding().global_offset(dest.shape(), local_shard_index);
+ auto starts = ntt::generate_shape([&](auto axis) { return starts_dest[axis] - starts_src[axis]; });
+ auto viewed_src1_tensor = src.local().view(starts, new_shape);
+ auto viewed_dest_tensor = dest.local();
+
+ // reduce-scatter, communicate (group_size - 1) times
+ for (auto i = 0; i < group_size - 1; i++)
+ {
+ auto next_index_g = ntt::unravel_index((node_number_g + i + 1) % group_size, group_hierarchy);
+
+ // keep the non-reduce axis invariant.
+ auto next_index = index_group2global(next_index_g, cur_index);
+
+ auto src2_tensor = src.template remote(next_index);
+ auto viewed_src2_tensor = src2_tensor.view(starts, new_shape);
+
+ if (i == 0) {
+ reduce_impl(viewed_src1_tensor, viewed_src2_tensor,
+ viewed_dest_tensor);
} else {
- return (dim_t)src.shape()[j];
+ reduce_impl(viewed_dest_tensor, viewed_src2_tensor,
+ viewed_dest_tensor);
}
- });
- auto starts = ntt::generate_shape([&](auto j) {
- if (j == axis) {
- return node_number_g * frac;
- } else {
- return (dim_t)0;
- }
- });
- auto viewed_src1_tensor = src.view(starts, new_shape);
- auto viewed_dest_tensor = dest.view(starts, new_shape);
-
- auto next_index_g = ntt::unravel_index((node_number_g + i + 1) % group_size, group_hierarchy);
-
- // keep the non-reduce axis invariant.
- auto next_index = index_group2global(next_index_g, cur_index);
-
- auto src2_tensor = ntt::make_tensor_view_from_address(
- (TElem *)tar::src_ptr_tensor(next_index), src.shape(),
- src.strides());
- auto viewed_src2_tensor = src2_tensor.view(starts, new_shape);
-
- if (i == 0) {
- reduce_impl(viewed_src1_tensor, viewed_src2_tensor,
- viewed_dest_tensor);
- } else {
- reduce_impl(viewed_dest_tensor, viewed_src2_tensor,
- viewed_dest_tensor);
}
- }
-
- reduce_group_sync();
- // all gather
- for (size_t i = 0; i < group_size - 1; i++) {
- auto offset = (node_number_g + i + 1) % (group_size);
- auto src_index_g = ntt::unravel_index(offset % group_size, group_hierarchy);
- auto src_index = index_group2global(src_index_g, cur_index);
-
- auto src_tensor = ntt::make_tensor_view_from_address(
- (TElem *)tar::dest_ptr_tensor(src_index), dest.shape(),
- dest.strides());
- auto starts = ntt::generate_shape([&](auto j) {
- if (j == axis) {
- return offset * frac;
- } else {
- return (dim_t)0;
+ ntt::tensor_copy_wait();
+ reduce_group_sync();
+ } else {
+ const auto axis = [&] {
+ dim_t axis = -1;
+ loop([&](auto i) {
+ if (axis == -1 && src.local().shape()[i] >= group_size) {
+ axis = i;
+ }
+ });
+ if (axis == -1) {
+ axis = 0;
}
- });
- auto new_shape = ntt::generate_shape([&](auto j) {
- if (j == axis) {
- return ntt::where(offset == group_size - 1, frac + remain, frac);
+ return axis;
+ }();
+
+ auto remain = src.local().shape()[axis] % (group_size);
+ auto frac = src.local().shape()[axis] / (group_size);
+
+ auto node_number_g = ntt::linear_offset(cur_index_g, group_hierarchy);
+
+ // reduce-scatter, communicate (group_size - 1) times
+ for (auto i = 0; i < group_size - 1; i++)
+ {
+ auto new_shape = ntt::generate_shape([&](auto j) {
+ if (j == axis) {
+ return ntt::where(node_number_g == group_size - 1, frac + remain, frac);
+ } else {
+ return (dim_t)src.local().shape()[j];
+ }
+ });
+ auto starts = ntt::generate_shape([&](auto j) {
+ if (j == axis) {
+ return node_number_g * frac;
+ } else {
+ return (dim_t)0;
+ }
+ });
+ auto viewed_src1_tensor = src.local().view(starts, new_shape);
+ auto viewed_dest_tensor = dest.local().view(starts, new_shape);
+
+ auto next_index_g = ntt::unravel_index((node_number_g + i + 1) % group_size, group_hierarchy);
+
+ // keep the non-reduce axis invariant.
+ auto next_index = index_group2global(next_index_g, cur_index);
+
+ auto src2_tensor = src.template remote(next_index);
+ auto viewed_src2_tensor = src2_tensor.view(starts, new_shape);
+
+ if (i == 0) {
+ reduce_impl(viewed_src1_tensor, viewed_src2_tensor,
+ viewed_dest_tensor);
} else {
- return (dim_t)src.shape()[j];
+ reduce_impl(viewed_dest_tensor, viewed_src2_tensor,
+ viewed_dest_tensor);
}
- });
- auto viewed_src_tensor = src_tensor.view(starts, new_shape);
- auto viewed_dest_tensor = dest.view(starts, new_shape);
- ntt::tensor_copy_async(viewed_src_tensor, viewed_dest_tensor);
- }
+ }
- ntt::tensor_copy_wait();
- reduce_group_sync();
+ reduce_group_sync();
+
+ // all gather
+ for (size_t i = 0; i < group_size - 1; i++) {
+ auto offset = (node_number_g + i + 1) % (group_size);
+ auto src_index_g = ntt::unravel_index(offset % group_size, group_hierarchy);
+ auto src_index = index_group2global(src_index_g, cur_index);
+
+ auto src_tensor = dest.template remote(src_index);
+ auto starts = ntt::generate_shape([&](auto j) {
+ if (j == axis) {
+ return offset * frac;
+ } else {
+ return (dim_t)0;
+ }
+ });
+ auto new_shape = ntt::generate_shape([&](auto j) {
+ if (j == axis) {
+ return ntt::where(offset == group_size - 1, frac + remain, frac);
+ } else {
+ return (dim_t)src.local().shape()[j];
+ }
+ });
+ auto viewed_src_tensor = src_tensor.view(starts, new_shape);
+ auto viewed_dest_tensor = dest.local().view(starts, new_shape);
+ ntt::tensor_copy_async(viewed_src_tensor, viewed_dest_tensor);
+ }
- if (Op == ntt::reduce_op::mean) {
- auto numerator = (element_or_scalar_t)(size_t)group_size;
- ntt::binary(dest, ntt::make_tensor_view_from_address(&numerator, ntt::fixed_shape_v<>), dest);
+ ntt::tensor_copy_wait();
+ reduce_group_sync();
+
+ if (Op == ntt::reduce_op::mean) {
+ auto numerator = (element_or_scalar_t)(size_t)group_size;
+ ntt::binary(dest.local(), ntt::make_tensor_view_from_address(&numerator, ntt::fixed_shape_v<>), dest.local());
+ }
}
}
};
diff --git a/modules/Nncase.Modules.NTT/Evaluator/CustomOp/NTT/LayerNorm.cs b/modules/Nncase.Modules.NTT/Evaluator/CustomOp/NTT/LayerNorm.cs
index 7f7e01d6e9..5ca68de7ff 100644
--- a/modules/Nncase.Modules.NTT/Evaluator/CustomOp/NTT/LayerNorm.cs
+++ b/modules/Nncase.Modules.NTT/Evaluator/CustomOp/NTT/LayerNorm.cs
@@ -2,6 +2,7 @@
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
using System;
+using System.Diagnostics;
using System.Linq;
using System.Runtime.CompilerServices;
using DryIoc.ImTools;
@@ -47,7 +48,7 @@ public IRType Visit(ITypeInferenceContext context, LayerNorm target)
{
return (input, scale, bias) switch
{
- (DistributedType a, DistributedType b, DistributedType c) => new DistributedType((TensorType)VisitTensorType(target, a.TensorType, b.TensorType, c.TensorType), target.OutSBPs, a.Placement),
+ (DistributedType a, DistributedType b, DistributedType c) => VisitDistributedType(target, a, b, c),
(TensorType a, TensorType b, TensorType c) => VisitTensorType(target, a, b, c),
_ => new InvalidType($"{input} {scale} {bias} not support"),
};
@@ -68,17 +69,17 @@ private bool CheckCustomSBP(IRType input, IRType scale, IRType bias, LayerNorm l
{
if (input is DistributedType a && scale is DistributedType b && bias is DistributedType c)
{
- if (Enumerable.Range(0, a.TensorType.Shape.Rank).Any(i => a.AxisPolicies[i] != layerNorm.InSBPs[i]))
+ if (Enumerable.Range(0, a.TensorType.Shape.Rank).Any(i => !DistributedUtility.IsSamePolicy(a.AxisPolicies[i], layerNorm.InSBPs[i], false)))
{
return false;
}
- if (Enumerable.Range(0, b.TensorType.Shape.Rank).Any(i => b.AxisPolicies[i] != layerNorm.ScaleSBPs[i]))
+ if (Enumerable.Range(0, b.TensorType.Shape.Rank).Any(i => !DistributedUtility.IsSamePolicy(b.AxisPolicies[i], layerNorm.ScaleSBPs[i], false)))
{
return false;
}
- if (Enumerable.Range(0, c.TensorType.Shape.Rank).Any(i => c.AxisPolicies[i] != layerNorm.BiasSBPs[i]))
+ if (Enumerable.Range(0, c.TensorType.Shape.Rank).Any(i => !DistributedUtility.IsSamePolicy(c.AxisPolicies[i], layerNorm.BiasSBPs[i], false)))
{
return false;
}
@@ -114,4 +115,24 @@ private IRType VisitTensorType(LayerNorm target, TensorType input, TensorType sc
return new TensorType(target.OutputDataType, input.Shape);
}
}
+
+ private IRType VisitDistributedType(LayerNorm target, DistributedType input, DistributedType scale, DistributedType bias)
+ {
+ var tensorType = (TensorType)VisitTensorType(target, input.TensorType, scale.TensorType, bias.TensorType);
+
+ var ndsbps = new SBP[tensorType.Shape.Rank];
+ for (var i = 0; i < ndsbps.Length; i++)
+ {
+ if (i == target.VectorizedAxes[0] && input.AxisPolicies[i] is SBPSplit split)
+ {
+ ndsbps[i] = SBP.S(split.Axes, split.Granularity is null ? null : split.Granularity * ((VectorType)input.TensorType.DType).Lanes[0] / ((VectorType)tensorType.DType).Lanes[0]);
+ }
+ else
+ {
+ ndsbps[i] = input.AxisPolicies[i];
+ }
+ }
+
+ return new DistributedType(tensorType, ndsbps, input.Placement);
+ }
}
diff --git a/modules/Nncase.Modules.NTT/Evaluator/CustomOp/NTT/Matmul.cs b/modules/Nncase.Modules.NTT/Evaluator/CustomOp/NTT/Matmul.cs
index 52fda68dfe..8f639988c6 100644
--- a/modules/Nncase.Modules.NTT/Evaluator/CustomOp/NTT/Matmul.cs
+++ b/modules/Nncase.Modules.NTT/Evaluator/CustomOp/NTT/Matmul.cs
@@ -52,7 +52,7 @@ public IRType Visit(ITypeInferenceContext context, MatMul target)
{
return (lhs, rhs) switch
{
- (DistributedType a, DistributedType b) => new DistributedType((TensorType)VisitTensorType(target, a.TensorType, b.TensorType, true, dimInfo), target.OutSBPs, a.Placement),
+ (DistributedType a, DistributedType b) => VisitDistributedType(target, a, b, true, dimInfo),
(TensorType a, TensorType b) => VisitTensorType(target, a, b, true, dimInfo),
_ => new InvalidType($"{lhs} {rhs} not support"),
};
@@ -78,12 +78,12 @@ private bool CheckCustomSBP(IRType lhs, IRType rhs, IRType extra, MatMul matmul)
if (lhs is DistributedType a && rhs is DistributedType b)
{
- if (Enumerable.Range(0, a.TensorType.Shape.Rank).Any(i => a.AxisPolicies[i] != matmul.LhsSBPs[i]))
+ if (Enumerable.Range(0, a.TensorType.Shape.Rank).Any(i => !DistributedUtility.IsSamePolicy(a.AxisPolicies[i], matmul.LhsSBPs[i], false)))
{
return false;
}
- if (Enumerable.Range(0, b.TensorType.Shape.Rank).Any(i => b.AxisPolicies[i] != matmul.RhsSBPs[i]))
+ if (Enumerable.Range(0, b.TensorType.Shape.Rank).Any(i => !DistributedUtility.IsSamePolicy(b.AxisPolicies[i], matmul.RhsSBPs[i], false)))
{
return false;
}
@@ -156,4 +156,25 @@ private IRType VisitTensorType(MatMul target, TensorType lhs, TensorType rhs, bo
return new TensorType(dtype, front.Concat(end).ToArray());
}
+
+ private IRType VisitDistributedType(MatMul target, DistributedType lhs, DistributedType rhs, bool vectorizeK, MatMulDimInfo dimInfo)
+ {
+ var tensorType = (TensorType)VisitTensorType(target, lhs.TensorType, rhs.TensorType, vectorizeK, dimInfo);
+
+ // FIXME: support rank>=2, and only support vectorize N of output.
+ var policyN = rhs.AxisPolicies[dimInfo!.Rn];
+ if (policyN is SBPSplit split)
+ {
+ policyN = SBP.S(split.Axes, split.Granularity is null ? null : split.Granularity / ((VectorType)tensorType.DType).Lanes[0]);
+ }
+
+ var policyM = lhs.AxisPolicies[dimInfo!.Lm];
+ var ndsbps = (target.TransposeA || target.TransposeB) ? new[] { policyN, policyM } : new[] { policyM, policyN };
+ if (DistributedUtility.AreSamePolicies(ndsbps, target.OutSBPs, false))
+ {
+ return new DistributedType(tensorType, ndsbps, lhs.Placement);
+ }
+
+ return new InvalidType("Please Check SBP Scheme.");
+ }
}
diff --git a/modules/Nncase.Modules.NTT/Evaluator/CustomOp/NTT/SparseExperts.cs b/modules/Nncase.Modules.NTT/Evaluator/CustomOp/NTT/SparseExperts.cs
index 41f14343eb..b77aea008a 100644
--- a/modules/Nncase.Modules.NTT/Evaluator/CustomOp/NTT/SparseExperts.cs
+++ b/modules/Nncase.Modules.NTT/Evaluator/CustomOp/NTT/SparseExperts.cs
@@ -165,25 +165,25 @@ private bool CheckCustomSBP(
{
if (q is DistributedType a && gate is DistributedType b && down is DistributedType c && up is DistributedType d)
{
- if (Enumerable.Range(0, a.TensorType.Shape.Rank).Any(i => a.AxisPolicies[i] != se.QSBPs[i]))
+ if (Enumerable.Range(0, a.TensorType.Shape.Rank).Any(i => !DistributedUtility.IsSamePolicy(a.AxisPolicies[i], se.QSBPs[i], checkGranularity: false)))
{
Console.WriteLine($"[SparseExperts] Q SBP not match: {string.Join(",", a.AxisPolicies.Select(p => p.ToString()))} != {string.Join(",", se.QSBPs.Select(p => p.ToString()))}");
return false;
}
- if (Enumerable.Range(0, b.TensorType.Shape.Rank).Any(i => b.AxisPolicies[i] != se.GateSBPs[i]))
+ if (Enumerable.Range(0, b.TensorType.Shape.Rank).Any(i => !DistributedUtility.IsSamePolicy(b.AxisPolicies[i], se.GateSBPs[i], checkGranularity: false)))
{
Console.WriteLine($"[SparseExperts] Gate SBP not match: {string.Join(",", b.AxisPolicies.Select(p => p.ToString()))} != {string.Join(",", se.GateSBPs.Select(p => p.ToString()))}");
return false;
}
- if (Enumerable.Range(0, c.TensorType.Shape.Rank).Any(i => c.AxisPolicies[i] != se.DownSBPs[i]))
+ if (Enumerable.Range(0, c.TensorType.Shape.Rank).Any(i => !DistributedUtility.IsSamePolicy(c.AxisPolicies[i], se.DownSBPs[i], checkGranularity: false)))
{
Console.WriteLine($"[SparseExperts] Down SBP not match: {string.Join(",", c.AxisPolicies.Select(p => p.ToString()))} != {string.Join(",", se.DownSBPs.Select(p => p.ToString()))}");
return false;
}
- if (Enumerable.Range(0, d.TensorType.Shape.Rank).Any(i => d.AxisPolicies[i] != se.UpSBPs[i]))
+ if (Enumerable.Range(0, d.TensorType.Shape.Rank).Any(i => !DistributedUtility.IsSamePolicy(d.AxisPolicies[i], se.UpSBPs[i], checkGranularity: false)))
{
Console.WriteLine($"[SparseExperts] Up SBP not match: {string.Join(",", d.AxisPolicies.Select(p => p.ToString()))} != {string.Join(",", se.UpSBPs.Select(p => p.ToString()))}");
return false;
diff --git a/modules/Nncase.Modules.NTT/Evaluator/CustomOp/NTT/Unary.cs b/modules/Nncase.Modules.NTT/Evaluator/CustomOp/NTT/Unary.cs
index 31f57017fc..bb9f5a1fdb 100644
--- a/modules/Nncase.Modules.NTT/Evaluator/CustomOp/NTT/Unary.cs
+++ b/modules/Nncase.Modules.NTT/Evaluator/CustomOp/NTT/Unary.cs
@@ -6,6 +6,7 @@
using Nncase.CostModel;
using Nncase.IR;
using Nncase.IR.CustomNTT;
+using Nncase.Utilities;
using OrtKISharp;
namespace Nncase.Evaluator.CustomNTT;
@@ -43,7 +44,7 @@ private bool CheckCustomSBP(IRType input, Unary target)
{
if (input is DistributedType inType)
{
- if (Enumerable.Range(0, inType.TensorType.Shape.Rank).Any(i => inType.AxisPolicies[i] != target.InSBPs[i]))
+ if (Enumerable.Range(0, inType.TensorType.Shape.Rank).Any(i => !DistributedUtility.IsSamePolicy(inType.AxisPolicies[i], target.InSBPs[i], false)))
{
return false;
}
diff --git a/modules/Nncase.Modules.NTT/Evaluator/Distributed/Boxing.cs b/modules/Nncase.Modules.NTT/Evaluator/Distributed/Boxing.cs
index 79ab794779..ee6bb6770e 100644
--- a/modules/Nncase.Modules.NTT/Evaluator/Distributed/Boxing.cs
+++ b/modules/Nncase.Modules.NTT/Evaluator/Distributed/Boxing.cs
@@ -27,32 +27,67 @@ IRType VisitD2D(DistributedType inv, DistributedType outv)
if (inv.TensorType != outv.TensorType)
{
- if (!inv.AxisPolicies.Any(sbp => sbp is SBPPartial))
+ if (!inv.AxisPolicies.Any(sbp => sbp is SBPPartial) && inv.Partial is null)
{
return outv;
}
+ else
+ {
+ return new InvalidType("Not Support Partial when shape changes.");
+ }
+ }
+
+ if (inv.AxisPolicies.Any(sbp => sbp is SBPPartial) || outv.AxisPolicies.Any(sbp => sbp is SBPPartial))
+ {
+ return new InvalidType("Not Support Partial in Policeis.");
}
- var ndsbpsA = DistributedUtility.AxisPolicesToNDSBP(inv.AxisPolicies, inv.Placement.Rank);
- var ndsbpsB = DistributedUtility.AxisPolicesToNDSBP(outv.AxisPolicies, outv.Placement.Rank);
- for (int i = 0; i < ndsbpsA.Count; i++)
+ var partialDims = new List();
+ if (inv.Partial is not null)
{
- switch (ndsbpsA[i], ndsbpsB[i])
+ for (int i = 0; i < inv.AxisPolicies.Count; i++)
{
- case (SBPPartial, SBPSplit):
- return new InvalidType("partial to split");
- case (not SBPPartial, SBPPartial):
- return new InvalidType("split/broadcast to partial");
+ if (inv.AxisPolicies[i] is SBPSplit && outv.AxisPolicies[i] is SBPBroadCast)
+ {
+ return new InvalidType("Not supported input is BroadCast output is Split");
+ }
+
+ if (outv.AxisPolicies[i] is SBPSplit s)
+ {
+ if (inv.AxisPolicies[i] is SBPSplit splitIn)
+ {
+ if (splitIn.Axes.Except(s.Axes).Any())
+ {
+ return new InvalidType("Not Supported Split-> Split.");
+ }
+ }
+
+ if (s.Axes.Except(inv.Partial.Axes).ToArray() != s.Axes)
+ {
+ partialDims.Add(i);
+ }
+ }
+ }
+
+ var ndspsIn = DistributedUtility.AxisPolicesToNDSBP(inv.AxisPolicies, inv.Placement.Rank);
+ var ndspsOut = DistributedUtility.AxisPolicesToNDSBP(outv.AxisPolicies, outv.Placement.Rank);
+ if (Enumerable.Range(0, ndspsIn.Count).Any(i => ndspsIn[i] is SBPSplit si && (ndspsOut[i] is SBPBroadCast || (ndspsOut[i] is SBPSplit so && so.Axes[0] != si.Axes[0]))))
+ {
+ return new InvalidType("Not Supported Split-> Broadcast.");
}
}
+ if (partialDims.Count > 0 && !Enumerable.Range(0, inv.AxisPolicies.Count).Except(partialDims.ToArray()).All(i => DistributedUtility.IsSamePolicy(inv.AxisPolicies[i], outv.AxisPolicies[i])))
+ {
+ return new InvalidType("Not Supported Partial.");
+ }
+
return outv;
}
IRType VisitD2T(DistributedType inv, TensorType outv)
{
- var ndsbpsA = DistributedUtility.AxisPolicesToNDSBP(inv.AxisPolicies, inv.Placement.Rank);
- if (ndsbpsA.Any(s => s is SBPPartial))
+ if (inv.AxisPolicies.Any(s => s is SBPPartial) || inv.Partial is not null)
{
return new InvalidType("Not supported input is Partial output is Unshard");
}
@@ -62,8 +97,7 @@ IRType VisitD2T(DistributedType inv, TensorType outv)
IRType VisitT2D(TensorType inv, DistributedType outv)
{
- var ndsbpsB = DistributedUtility.AxisPolicesToNDSBP(outv.AxisPolicies, outv.Placement.Rank);
- if (ndsbpsB.Any(s => s is SBPPartial))
+ if (outv.AxisPolicies.Any(s => s is SBPPartial) || outv.Partial is not null)
{
return new InvalidType("Not supported input is Unshard output is Partial");
}
@@ -308,6 +342,12 @@ public Cost Visit(ICostEvaluateContext context, Boxing target)
case (SBPBroadCast, SBPSplit splitOut):
splitOut.Axes.ToArray().ForEach(s => scatterPart *= hierarchyPenalty[s]);
break;
+ case (SBPPartial, SBPSplit splitOut):
+ // actually partial to split needs gather.
+ break;
+ case (SBPPartial sBPPartial, SBPBroadCast):
+ sBPPartial.Axes.ToArray().ForEach(s => gatherPart *= hierarchyPenalty[s]);
+ break;
default:
throw new NotSupportedException($"{a} to {b}");
}
diff --git a/modules/Nncase.Modules.NTT/Evaluator/Distributed/ForceBoxing.cs b/modules/Nncase.Modules.NTT/Evaluator/Distributed/ForceBoxing.cs
index 9218bd558c..4a0d4a7383 100644
--- a/modules/Nncase.Modules.NTT/Evaluator/Distributed/ForceBoxing.cs
+++ b/modules/Nncase.Modules.NTT/Evaluator/Distributed/ForceBoxing.cs
@@ -21,7 +21,7 @@ IRType VisitD2D(DistributedType inv, DistributedType outv)
var ndsbpsB = DistributedUtility.AxisPolicesToNDSBP(outv.AxisPolicies, outv.Placement.Rank).ToArray();
// TODO: add more invalid cases
- if (ndsbpsA.Distinct().Count() == 1 && ndsbpsB.Distinct().Count() == 1 && ndsbpsA[0] == ndsbpsB[0])
+ if (ndsbpsA.Distinct().Count() == 1 && ndsbpsB.Distinct().Count() == 1 && ndsbpsA[0] == ndsbpsB[0] && inv.Partial == outv.Partial)
{
return new InvalidType("Same NDSBP");
}
@@ -90,7 +90,8 @@ public IValue Visit(IEvaluateContext context, ForceBoxing target)
var inTenor = context.GetArgumentValueAsTensor(target, ForceBoxing.Input);
var input = inTenor.ToOrtTensor();
var output = input - input;
- var repeat = target.NewType.AxisPolicies.Select((x, i) => (x is SBPPartial) ? target.NewType.Placement.Hierarchy[i] : 1).Aggregate(1, (x, i) => x * i);
+ var ndsbps = DistributedUtility.AxisPolicesToNDSBP(target.NewType.AxisPolicies, target.NewType.Placement.Rank).ToArray();
+ var repeat = ndsbps.Select((x, i) => (x is SBPPartial) ? target.NewType.Placement.Hierarchy[i] : 1).Aggregate(1, (x, i) => x * i);
for (int i = 0; i < repeat; i++)
{
output += input;
diff --git a/modules/Nncase.Modules.NTT/Evaluator/NTT/Im2col.cs b/modules/Nncase.Modules.NTT/Evaluator/NTT/Im2col.cs
index 149e881f56..c9c0e8e37b 100644
--- a/modules/Nncase.Modules.NTT/Evaluator/NTT/Im2col.cs
+++ b/modules/Nncase.Modules.NTT/Evaluator/NTT/Im2col.cs
@@ -106,7 +106,6 @@ private IRType Visit(DistributedType dt, Im2col target)
return new InvalidType("im2col typeinfer failed");
}
- var outShape = tensorType.Shape.ToArray();
var ndsbp = new SBP[tensorType.Shape.Rank];
for (int i = 0; i < dt.AxisPolicies.Count; i++)
@@ -117,7 +116,6 @@ private IRType Visit(DistributedType dt, Im2col target)
switch (sbp)
{
case SBPSplit split:
- outShape[1] /= split.Axes.Select(a => dt.Placement.Hierarchy[a]).Aggregate(1, (a, b) => a * b);
ndsbp[1] = split;
break;
case SBPPartial:
@@ -132,7 +130,6 @@ private IRType Visit(DistributedType dt, Im2col target)
switch (sbp)
{
case SBPSplit split:
- outShape[0] /= split.Axes.Select(a => dt.Placement.Hierarchy[a]).Aggregate(1, (a, b) => a * b);
ndsbp[0] = split;
break;
case SBPPartial:
@@ -146,7 +143,7 @@ private IRType Visit(DistributedType dt, Im2col target)
{
switch (sbp)
{
- case SBPSplit split:
+ case SBPSplit:
return new InvalidType($"can't split on {i}");
case SBPPartial:
return new InvalidType($"can't be partial sum");
diff --git a/modules/Nncase.Modules.NTT/Evaluator/NTT/VectorizedCast.cs b/modules/Nncase.Modules.NTT/Evaluator/NTT/VectorizedCast.cs
index 2fa8111046..753b0afbd4 100644
--- a/modules/Nncase.Modules.NTT/Evaluator/NTT/VectorizedCast.cs
+++ b/modules/Nncase.Modules.NTT/Evaluator/NTT/VectorizedCast.cs
@@ -134,7 +134,7 @@ private IRType Visit(VectorizedCast target, DistributedType inType)
{
var invalid = new InvalidType(inType.ToString());
var outType = Visit(target, inType.TensorType);
- var ndsbp = new SBP[inType.TensorType.Shape.Rank];
+ var ndsbp = inType.AxisPolicies.ToArray();
var shape = CompilerServices.GetMaxShape(inType.TensorType.Shape);
for (int i = 0; i < ndsbp.Length; i++)
{
@@ -153,10 +153,13 @@ private IRType Visit(VectorizedCast target, DistributedType inType)
{
return invalid;
}
+ else
+ {
+ var scale = 1f * outShape[i] / shape[i];
+ ndsbp[i] = SBP.S(split.Axes, split.Granularity is not null ? (scale >= 1 ? split.Granularity * (long)scale : split.Granularity / (long)(1f / scale)) : null);
+ }
}
}
-
- ndsbp[i] = inType.AxisPolicies[i];
}
return new DistributedType((TensorType)outType, ndsbp, inType.Placement);
diff --git a/modules/Nncase.Modules.NTT/Passes/AffineSelection/MatMul.cs b/modules/Nncase.Modules.NTT/Passes/AffineSelection/MatMul.cs
index b2b1b0c8b4..1cead83cce 100644
--- a/modules/Nncase.Modules.NTT/Passes/AffineSelection/MatMul.cs
+++ b/modules/Nncase.Modules.NTT/Passes/AffineSelection/MatMul.cs
@@ -6,6 +6,7 @@
using Nncase.IR.Affine;
using Nncase.TIR;
using Nncase.TIR.NTT;
+using Nncase.Utilities;
namespace Nncase.Passes;
@@ -27,8 +28,8 @@ public Expr SelectMatMul(Op op, Call call, Expr output)
var dinfo = pmm.GetDimInfo(dta.TensorType.Shape.Rank, dtb.TensorType.Shape.Rank);
if (dta.AxisPolicies[^2..].AsValueEnumerable().All(x => x is SBPSplit) &&
dtb.AxisPolicies[^2..].AsValueEnumerable().All(x => x is SBPSplit) &&
- dta.AxisPolicies[dinfo.Lk] == dtb.AxisPolicies[dinfo.Rn] &&
- dta.AxisPolicies[dinfo.Lm] == dtb.AxisPolicies[dinfo.Rk])
+ DistributedUtility.IsSamePolicy(dta.AxisPolicies[dinfo.Lk], dtb.AxisPolicies[dinfo.Rn], false) &&
+ DistributedUtility.IsSamePolicy(dta.AxisPolicies[dinfo.Lm], dtb.AxisPolicies[dinfo.Rk], false))
{
return call;
}
@@ -42,8 +43,8 @@ public Expr SelectMatMul(Op op, Call call, Expr output)
{
if (dta.AxisPolicies[^2..].AsValueEnumerable().All(x => x is SBPSplit) &&
dtb.AxisPolicies[^2..].AsValueEnumerable().All(x => x is SBPSplit) &&
- dta.AxisPolicies[^2] == dtb.AxisPolicies[^2] &&
- dta.AxisPolicies[^1] == dtb.AxisPolicies[^1])
+ DistributedUtility.IsSamePolicy(dta.AxisPolicies[^2], dtb.AxisPolicies[^2], false) &&
+ DistributedUtility.IsSamePolicy(dta.AxisPolicies[^1], dtb.AxisPolicies[^1], false))
{
return call;
}
@@ -57,8 +58,8 @@ public Expr SelectMatMul(Op op, Call call, Expr output)
{
if (dta.AxisPolicies[^2..].AsValueEnumerable().All(x => x is SBPSplit) &&
dtb.AxisPolicies[^2..].AsValueEnumerable().All(x => x is SBPSplit) &&
- dta.AxisPolicies[^2] == dtb.AxisPolicies[^1] &&
- dta.AxisPolicies[^1] == dtb.AxisPolicies[^2])
+ DistributedUtility.IsSamePolicy(dta.AxisPolicies[^2], dtb.AxisPolicies[^1], false) &&
+ DistributedUtility.IsSamePolicy(dta.AxisPolicies[^1], dtb.AxisPolicies[^2], false))
{
return call;
}
diff --git a/modules/Nncase.Modules.NTT/Passes/NTTTIRSelectionPass.cs b/modules/Nncase.Modules.NTT/Passes/NTTTIRSelectionPass.cs
index 40bf45e71c..552771a4e3 100644
--- a/modules/Nncase.Modules.NTT/Passes/NTTTIRSelectionPass.cs
+++ b/modules/Nncase.Modules.NTT/Passes/NTTTIRSelectionPass.cs
@@ -65,8 +65,8 @@ protected override Expr SelectCall(Call call, IReadOnlyList arguments,
var dinfo = vectorizedMatMul.GetDimInfo(dta.TensorType.Shape.Rank, dtb.TensorType.Shape.Rank);
if (dta.AxisPolicies[^2..].AsValueEnumerable().All(x => x is SBPSplit) &&
dtb.AxisPolicies[^2..].AsValueEnumerable().All(x => x is SBPSplit) &&
- dta.AxisPolicies[dinfo.Lk] == dtb.AxisPolicies[dinfo.Rn] &&
- dta.AxisPolicies[dinfo.Lm] == dtb.AxisPolicies[dinfo.Rk])
+ DistributedUtility.IsSamePolicy(dta.AxisPolicies[dinfo.Lk], dtb.AxisPolicies[dinfo.Rn], false) &&
+ DistributedUtility.IsSamePolicy(dta.AxisPolicies[dinfo.Lm], dtb.AxisPolicies[dinfo.Rk], false))
{
return TIR.F.NTT.SUMMA((Expr)arguments[0], (Expr)arguments[1], output, None.Default, (Expr)call[IR.NTT.VectorizedMatMul.Scale], vectorizedMatMul.LhsVectorizedAxes, vectorizedMatMul.RhsVectorizedAxes, vectorizedMatMul.TransposeA, vectorizedMatMul.TransposeB);
}
@@ -78,8 +78,8 @@ protected override Expr SelectCall(Call call, IReadOnlyList arguments,
case IR.Math.MatMul when GetArgumentType(arguments[0]) is DistributedType dta && GetArgumentType(arguments[1]) is DistributedType dtb:
if (dta.AxisPolicies[^2..].AsValueEnumerable().All(x => x is SBPSplit) &&
dtb.AxisPolicies[^2..].AsValueEnumerable().All(x => x is SBPSplit) &&
- dta.AxisPolicies[^2] == dtb.AxisPolicies[^2] &&
- dta.AxisPolicies[^1] == dtb.AxisPolicies[^1])
+ DistributedUtility.IsSamePolicy(dta.AxisPolicies[^2], dtb.AxisPolicies[^2], false) &&
+ DistributedUtility.IsSamePolicy(dta.AxisPolicies[^1], dtb.AxisPolicies[^1], false))
{
return TIR.F.NTT.SUMMA((Expr)arguments[0], (Expr)arguments[1], output, None.Default, (Expr)call[IR.Math.MatMul.Scale]);
}
@@ -237,7 +237,7 @@ private Expr GenerateReshape(Expr input, ref Expr output, bool sequeeze = false)
var bitcast = outBuffer.DistributedType is DistributedType ? false : true;
if (!bitcast)
{
- if (inBuffer.DistributedType!.AxisPolicies.Where(sbp => sbp is not SBPBroadCast).ToArray().SequenceEqual(outBuffer.DistributedType!.AxisPolicies.Where(sbp => sbp is not SBPBroadCast).ToArray()))
+ if (DistributedUtility.AreSamePolicies(inBuffer.DistributedType!.AxisPolicies.Where(sbp => sbp is not SBPBroadCast).ToArray(), outBuffer.DistributedType!.AxisPolicies.Where(sbp => sbp is not SBPBroadCast).ToArray(), false))
{
bitcast = true;
}
@@ -246,7 +246,19 @@ private Expr GenerateReshape(Expr input, ref Expr output, bool sequeeze = false)
// If the size is not same, we cannot bitcast.
if ((inBuffer.MemSpan.Size == outBuffer.MemSpan.Size) && (bitcast || sequeeze))
{
- output = inBuffer.With(name: outBuffer.Name, elemType: outBuffer.ElemType, dimensions: outBuffer.Dimensions.ToArray(), strides: outBuffer.Strides.ToArray(), distributedType: outBuffer.DistributedType);
+ var newStrides = outBuffer.Strides.ToArray();
+ if (inBuffer.MemSpan.Buffer.Location == MemoryLocation.BlockLocalData && bitcast)
+ {
+ var outType = outBuffer.DistributedType!;
+ var threadAxis = outType.Placement.Rank - 1;
+ int splitAxis = outType.AxisPolicies.ToArray().IndexOf(x => x is SBPSplit split && split.Axes.Contains(threadAxis));
+ if (splitAxis >= 0)
+ {
+ Enumerable.Range(0, splitAxis).ToArray().ForEach(i => newStrides[i] *= outType.Placement.Hierarchy[threadAxis]);
+ }
+ }
+
+ output = inBuffer.With(name: outBuffer.Name, elemType: outBuffer.ElemType, dimensions: outBuffer.Dimensions.ToArray(), strides: newStrides, distributedType: outBuffer.DistributedType);
return T.Nop();
}
else
@@ -288,7 +300,7 @@ private Expr GenerateBitcast(Expr input, ref Expr output, DataType newType)
}
}
- var distributedType = inBuffer.DistributedType is DistributedType dt
+ var distributedType = ((TIR.Buffer)output).DistributedType is DistributedType dt
? dt with { TensorType = new TensorType(newType, newDimensions) }
: null;
output = inBuffer.With(name: ((TIR.Buffer)output).Name, elemType: newType, dimensions: newDimensions, strides: newStrides, distributedType: distributedType);
@@ -325,8 +337,6 @@ private Expr GenerateBoxing(Call call, IR.Distributed.Boxing boxing, IReadOnlyLi
private Expr GenerateReshard(Expr input, ref Expr output, DistributedType inType, DistributedType outType)
{
- // FIXME: re-balance issue.
-#if false
if (input is TIR.Buffer inBuffer)
{
if (TryGenerateGatherThreadsReshard(inBuffer, ref output, inType, outType, out var newCall))
@@ -338,19 +348,24 @@ private Expr GenerateReshard(Expr input, ref Expr output, DistributedType inType
return newCall;
}
}
-#endif
return TIR.F.NTT.GatherReduceScatter(input, output, inType, outType);
}
private bool TryGenerateGatherThreadsReshard(TIR.Buffer inBuffer, ref Expr output, DistributedType inType, DistributedType outType, [MaybeNullWhen(false)] out Expr newCall)
{
+ if (inType.Partial is not null || inBuffer.Users.Any(u => u is Call { Target: TIR.PrimFunction }))
+ {
+ newCall = null;
+ return false;
+ }
+
var threadAxis = inType.Placement.Rank - 1;
PhysicalBuffer? oldPhysicalBuffer = null;
// S -> B
var reducedInPolices = inType.AxisPolicies.Select(sbp => sbp is SBPSplit split && split.Axes.Contains(threadAxis) ? (split.Axes.Count == 1 ? (SBP)SBP.B : SBP.S(split.Axes.Except([threadAxis]).ToArray())) : sbp);
- if (reducedInPolices.ToArray().SequenceEqual(outType.AxisPolicies.ToArray()))
+ if (DistributedUtility.AreSamePolicies(reducedInPolices.ToArray(), outType.AxisPolicies.ToArray(), false))
{
oldPhysicalBuffer = inBuffer.MemSpan.Buffer;
}
@@ -422,7 +437,7 @@ private bool TryGenerateGatherThreadsReshard(TIR.Buffer inBuffer, ref Expr outpu
private bool TryGenerateSplitThreadsReshard(TIR.Buffer inBuffer, ref Expr output, DistributedType inType, DistributedType outType, [MaybeNullWhen(false)] out Expr newCall)
{
// Avoid P -> B -> S
- if (inType.Partial)
+ if (inType.Partial is not null || inBuffer.Users.Any(u => u is Call { Target: TIR.PrimFunction }))
{
newCall = null;
return false;
@@ -434,7 +449,7 @@ private bool TryGenerateSplitThreadsReshard(TIR.Buffer inBuffer, ref Expr output
int splitAxis = -1;
// B -> S
- if (reducedOutPolices.ToArray().SequenceEqual(inType.AxisPolicies.ToArray()))
+ if (DistributedUtility.AreSamePolicies(reducedOutPolices.ToArray(), inType.AxisPolicies.ToArray(), false))
{
oldPhysicalBuffer = inBuffer.MemSpan.Buffer;
splitAxis = outType.AxisPolicies.ToArray().IndexOf(x => x is SBPSplit split && split.Axes.Contains(threadAxis));
diff --git a/modules/Nncase.Modules.NTT/Passes/Rules/Mutators/ToBlockLocalData.cs b/modules/Nncase.Modules.NTT/Passes/Rules/Mutators/ToBlockLocalData.cs
new file mode 100644
index 0000000000..50caa14920
--- /dev/null
+++ b/modules/Nncase.Modules.NTT/Passes/Rules/Mutators/ToBlockLocalData.cs
@@ -0,0 +1,107 @@
+// Copyright (c) Canaan Inc. All rights reserved.
+// Licensed under the Apache license. See LICENSE file in the project root for full license information.
+
+using System;
+using System.Collections.Generic;
+using System.IO;
+using System.Linq;
+using System.Reactive;
+using Nncase.IR;
+using Nncase.TIR;
+
+namespace Nncase.Passes.Mutators;
+
+///
+/// Flatten the multi-dimensional BufferLoad and BufferStore to single dimensional Load/Store.
+///
+public sealed class ToBlockLocalData : ExprRewriter
+{
+ private readonly Dictionary _bufferMemo = new(ReferenceEqualityComparer.Instance);
+
+ protected override BaseExpr VisitSequential(Sequential expr, Unit context)
+ {
+ for (int i = 0; i < expr.Fields.Length - 1; i++)
+ {
+ {
+ if (expr.Fields[i] is Call { Target: TIR.NTT.VectorizedLayerNorm ln } lnCall
+ && expr.Fields[i + 1] is Call { Target: TIR.NTT.Matmul mm } mmCall
+ && ln.CSourcePath is not null && mm.CSourcePath is not null)
+ {
+ var lnOutput = lnCall.Arguments[TIR.NTT.VectorizedLayerNorm.Output.Index];
+ var mmLhs = mmCall.Arguments[TIR.NTT.Matmul.Lhs.Index];
+ if (lnOutput is TIR.Buffer a && mmLhs is TIR.Buffer b && a.MemSpan.Buffer == b.MemSpan.Buffer && b.MemSpan.Buffer.Location == MemoryLocation.Data)
+ {
+ var newPhysicalBuffer = a.MemSpan.Buffer.With(location: MemoryLocation.BlockLocalData);
+ _bufferMemo.TryAdd(lnOutput, newPhysicalBuffer);
+ _bufferMemo.TryAdd(mmLhs, newPhysicalBuffer);
+ }
+ }
+ }
+
+ {
+ if (expr.Fields[i] is Call { Target: TIR.NTT.Matmul mm } mmCall
+ && expr.Fields[i + 1] is Call { Target: TIR.NTT.VectorizedLayerNorm ln } lnCall
+ && ln.CSourcePath is not null && mm.CSourcePath is not null)
+ {
+ var lnInput = lnCall.Arguments[TIR.NTT.VectorizedLayerNorm.Input.Index];
+ var mmOutput = mmCall.Arguments[TIR.NTT.Matmul.Output.Index];
+ if (lnInput is TIR.Buffer a && mmOutput is TIR.Buffer b && a.MemSpan.Buffer == b.MemSpan.Buffer && b.MemSpan.Buffer.Location == MemoryLocation.Data)
+ {
+ var newPhysicalBuffer = a.MemSpan.Buffer.With(location: MemoryLocation.BlockLocalData);
+ _bufferMemo.TryAdd(lnInput, newPhysicalBuffer);
+ _bufferMemo.TryAdd(mmOutput, newPhysicalBuffer);
+ }
+ }
+ }
+
+ {
+ if (expr.Fields[i] is Call { Target: TIR.NTT.Matmul mm } mmCall
+ && expr.Fields[i + 1] is Call { Target: TIR.NTT.SynchronizeThreads }
+ && expr.Fields[i + 2] is Call { Target: TIR.NTT.UpdatePagedAttentionKVCache }
+ && expr.Fields[i + 3] is Call { Target: TIR.NTT.UpdatePagedAttentionKVCache } upkvCall2
+ && mm.CSourcePath is not null)
+ {
+ var mmOutput = mmCall.Arguments[TIR.NTT.Matmul.Output.Index];
+ var upkvInput = upkvCall2.Arguments[TIR.NTT.UpdatePagedAttentionKVCache.Slots.Index];
+ if (upkvInput is TIR.Buffer a && mmOutput is TIR.Buffer b && a.MemSpan.Buffer == b.MemSpan.Buffer && b.MemSpan.Buffer.Location == MemoryLocation.Data)
+ {
+ var newPhysicalBuffer = a.MemSpan.Buffer.With(location: MemoryLocation.BlockLocalData);
+ _bufferMemo.TryAdd(upkvInput, newPhysicalBuffer);
+ _bufferMemo.TryAdd(mmOutput, newPhysicalBuffer);
+ }
+ }
+ }
+
+ // FIXME: need to add sync after primfunc.
+ // {
+ // if (expr.Fields[i] is Call { Target: TIR.NTT.Matmul mm } mmCall
+ // && expr.Fields[i + 1] is Call { Target: TIR.PrimFunction } fnCall
+ // && mm.CSourcePath is not null)
+ // {
+ // var mmOutput = mmCall.Arguments[TIR.NTT.Matmul.Output.Index];
+ // foreach (var arg in fnCall.Arguments)
+ // {
+ // if (arg is TIR.Buffer a && mmOutput is TIR.Buffer b && a.MemSpan.Buffer == b.MemSpan.Buffer && b.MemSpan.Buffer.Location == MemoryLocation.Data)
+ // {
+ // var newPhysicalBuffer = a.MemSpan.Buffer.With(location: MemoryLocation.BlockLocalData);
+ // _bufferMemo.TryAdd(arg, newPhysicalBuffer);
+ // _bufferMemo.TryAdd(mmOutput, newPhysicalBuffer);
+ // }
+ // }
+ // }
+ // }
+ }
+
+ return base.VisitSequential(expr, context);
+ }
+
+ protected override BaseExpr RewriteLeafBuffer(TIR.Buffer expr)
+ {
+ if (_bufferMemo.TryGetValue(expr, out var buffer))
+ {
+ return expr.With(memSpan: expr.MemSpan.With(buffer: buffer));
+ }
+
+ return expr;
+ }
+}
diff --git a/modules/Nncase.Modules.NTT/Targets/CPUTarget.cs b/modules/Nncase.Modules.NTT/Targets/CPUTarget.cs
index 04312ec46d..2bb9d04d09 100644
--- a/modules/Nncase.Modules.NTT/Targets/CPUTarget.cs
+++ b/modules/Nncase.Modules.NTT/Targets/CPUTarget.cs
@@ -119,6 +119,11 @@ public override void RegisterAutoVectorizeRules(IRulesAddable pass, CompileOptio
public override void RegisterTIRSelectionPass(IPassManager passManager, CompileOptions optionsÍ)
{
passManager.Add();
+ passManager.AddWithName("ToBlockLocalData").Configure(p =>
+ {
+ p.Add();
+ p.Add();
+ });
}
public override void RegisterPostAutoVectorizePass(IPassManager passManager, CompileOptions options)
diff --git a/ntt/include/nncase/ntt/arch/cpu/remote_tensor.h b/ntt/include/nncase/ntt/arch/cpu/remote_tensor.h
index f339dabd13..392abfe57d 100644
--- a/ntt/include/nncase/ntt/arch/cpu/remote_tensor.h
+++ b/ntt/include/nncase/ntt/arch/cpu/remote_tensor.h
@@ -20,7 +20,10 @@
namespace nncase::ntt::distributed {
namespace detail {
extern decltype(nncase::ntt::make_tensor>(
- nncase::ntt::distributed::topology_shape)) global_local_data_ptr;
+ nncase::ntt::distributed::topology_shape)) global_thread_local_data_ptr;
+
+extern decltype(nncase::ntt::make_tensor>(
+ nncase::ntt::distributed::topology_shape)) global_block_local_data_ptr;
extern decltype(nncase::ntt::make_tensor>(
nncase::ntt::distributed::topology_shape)) global_thread_local_rdata_ptr;
@@ -37,9 +40,9 @@ template = end) {
start = (size_t)global_thread_local_rdata_ptr(local_program_ids)(0_dim);
end = (size_t)global_thread_local_rdata_ptr(local_program_ids)(1_dim);
@@ -48,8 +51,15 @@ static auto get_remote_address(const TLocalProgramIds &local_program_ids,
if ((uintptr_t)local_address < start ||
(uintptr_t)local_address >= end) {
start = (size_t)global_block_local_rdata_ptr(local_program_ids)(0_dim);
+ end = (size_t)global_block_local_rdata_ptr(local_program_ids)(1_dim);
remote_address =
(size_t)global_block_local_rdata_ptr(remote_program_ids)(0_dim);
+ if ((uintptr_t)local_address < start ||
+ (uintptr_t)local_address >= end) {
+ start = (size_t)global_block_local_data_ptr(local_program_ids)(0_dim);
+ remote_address =
+ (size_t)global_block_local_data_ptr(remote_program_ids)(0_dim);
+ }
}
}
diff --git a/ntt/include/nncase/ntt/arch/xpu/remote_tensor.h b/ntt/include/nncase/ntt/arch/xpu/remote_tensor.h
index 354d5abb4c..751483084d 100644
--- a/ntt/include/nncase/ntt/arch/xpu/remote_tensor.h
+++ b/ntt/include/nncase/ntt/arch/xpu/remote_tensor.h
@@ -25,7 +25,10 @@ namespace nncase::ntt::distributed {
namespace detail {
#if not defined(SYS_MODE)
extern decltype(nncase::ntt::make_tensor>(
- nncase::ntt::distributed::topology_shape)) global_local_data_ptr;
+ nncase::ntt::distributed::topology_shape)) global_thread_local_data_ptr;
+
+extern decltype(nncase::ntt::make_tensor>(
+ nncase::ntt::distributed::topology_shape)) global_block_local_data_ptr;
extern decltype(nncase::ntt::make_tensor>(
nncase::ntt::distributed::topology_shape)) global_thread_local_rdata_ptr;
@@ -46,9 +49,9 @@ static auto get_remote_address(const TLocalProgramIds &local_program_ids,
const TRemoteProgramIds &remote_program_ids,
T *local_address) {
#if not defined(SYS_MODE)
- auto start = global_local_data_ptr(local_program_ids)(0_dim);
- auto end = global_local_data_ptr(local_program_ids)(1_dim);
- auto remote_address = global_local_data_ptr(remote_program_ids)(0_dim);
+ auto start = global_thread_local_data_ptr(local_program_ids)(0_dim);
+ auto end = global_thread_local_data_ptr(local_program_ids)(1_dim);
+ auto remote_address = global_thread_local_data_ptr(remote_program_ids)(0_dim);
if ((uintptr_t)local_address < start || (uintptr_t)local_address >= end) {
start = global_thread_local_rdata_ptr(local_program_ids)(0_dim);
end = global_thread_local_rdata_ptr(local_program_ids)(1_dim);
@@ -57,8 +60,15 @@ static auto get_remote_address(const TLocalProgramIds &local_program_ids,
if ((uintptr_t)local_address < start ||
(uintptr_t)local_address >= end) {
start = global_block_local_rdata_ptr(local_program_ids)(0_dim);
+ end = global_block_local_rdata_ptr(local_program_ids)(1_dim);
remote_address =
global_block_local_rdata_ptr(remote_program_ids)(0_dim);
+ if ((uintptr_t)local_address < start ||
+ (uintptr_t)local_address >= end) {
+ start = global_block_local_data_ptr(local_program_ids)(0_dim);
+ remote_address =
+ global_block_local_data_ptr(remote_program_ids)(0_dim);
+ }
}
}
diff --git a/ntt/include/nncase/ntt/caching.h b/ntt/include/nncase/ntt/caching.h
index 6108f5ed16..cee33abf7e 100644
--- a/ntt/include/nncase/ntt/caching.h
+++ b/ntt/include/nncase/ntt/caching.h
@@ -145,7 +145,7 @@ template class attention_kv_cache {
namespace detail {
template
constexpr auto
-kv_dim(const distributed::shard_policy::split &split) noexcept {
+kv_dim(distributed::shard_policy::split split) noexcept {
return split.template divider();
}
@@ -184,8 +184,8 @@ constexpr auto origin_kv_cache_one_block_shape() noexcept {
// num_blocks is not sharded, so we just return
return last_shape;
} else {
- auto dim =
- axis_policy_t::template try_shard_dim_without_shard_index<
+ auto policy = axis_policy_t{};
+ auto dim = policy.template try_shard_dim_without_shard_index<
Mesh>(last_shape[sharding_axis]);
static_assert(dim != -1_dim,
"Only uniform shard dim is supported.");
diff --git a/ntt/include/nncase/ntt/distributed/sharding.h b/ntt/include/nncase/ntt/distributed/sharding.h
index 13919a27b1..55e3327551 100644
--- a/ntt/include/nncase/ntt/distributed/sharding.h
+++ b/ntt/include/nncase/ntt/distributed/sharding.h
@@ -54,17 +54,17 @@ struct broadcast {
inline constexpr broadcast B;
// Split
-template struct split {
+template struct split {
+ TGranularity granularity;
static constexpr auto axes = fixed_shape_v;
- template static constexpr auto divider() {
+ template constexpr auto divider() {
return Mesh::shape.select(axes).length();
}
template TShardIndex>
- static constexpr auto
- global_offset(const TDim &global_dim,
- const TShardIndex &shard_index) noexcept {
+ constexpr auto global_offset(const TDim &global_dim,
+ const TShardIndex &shard_index) noexcept {
constexpr auto submesh_shape =
fixed_shape_v]...>;
auto subshard_index = make_shape(shard_index[fixed_dim_v]...);
@@ -76,16 +76,21 @@ template struct split {
}
template
- static constexpr auto
+ constexpr auto
try_shard_dim_without_shard_index(const TDim &global_dim) noexcept {
- const auto remainder = global_dim % divider();
- return ntt::where(remainder == dim_zero, global_dim / divider(),
- -1_dim);
+ if constexpr (std::is_same_v) {
+ const auto remainder = global_dim % divider();
+ return ntt::where(remainder == dim_zero,
+ global_dim / divider(), -1_dim);
+ } else {
+ return ntt::where(global_dim == granularity * divider(),
+ granularity, -1_dim);
+ }
}
template TShardIndex>
- static constexpr auto shard_dim(const TDim &global_dim,
- const TShardIndex &shard_index) noexcept {
+ constexpr auto shard_dim(const TDim &global_dim,
+ const TShardIndex &shard_index) noexcept {
const auto shard_dim_v =
try_shard_dim_without_shard_index(global_dim);
return ntt::where(shard_dim_v != -1_dim, shard_dim_v, [&] {
@@ -96,20 +101,31 @@ template struct split {
}
template
- static constexpr auto max_shard_dim(const TDim &global_dim) noexcept {
- return ntt::ceil_div(global_dim, divider());
+ constexpr auto max_shard_dim(const TDim &global_dim) noexcept {
+ if constexpr (std::is_same_v) {
+ return ntt::ceil_div(global_dim, divider());
+ } else {
+ return granularity;
+ }
}
+
+ constexpr split() requires(std::is_same_v)
+ : granularity(nullptr) {}
+
+ constexpr split(TGranularity granularity_) : granularity(granularity_) {}
};
-template constexpr auto S() noexcept {
- return split{};
+template
+constexpr auto S(TGranularity granularity = nullptr) noexcept {
+ return split{granularity};
}
} // namespace shard_policy
template struct is_split_shard_policy : std::false_type {};
-template
-struct is_split_shard_policy> : std::true_type {};
+template
+struct is_split_shard_policy>
+ : std::true_type {};
template
concept SplitShardPolicy = is_split_shard_policy::value;
@@ -134,9 +150,9 @@ template struct sharding {
global_offset(const GlobalShape &global_shape,
const TShardIndex &shard_index) const noexcept {
auto get_dim = [&, this] {
- return std::get(axis_policies)
- .template global_offset(global_shape[fixed_dim_v],
- shard_index);
+ auto policy = std::get(axis_policies);
+ return policy.template global_offset(
+ global_shape[fixed_dim_v], shard_index);
};
auto get_all_dims = [&](std::index_sequence) {
return make_shape(get_dim.template operator()()...);
@@ -148,9 +164,9 @@ template struct sharding {
constexpr auto shard_shape(const GlobalShape &global_shape,
const TShardIndex &shard_index) const noexcept {
auto get_dim = [&, this] {
- return std::get(axis_policies)
- .template shard_dim(global_shape[fixed_dim_v],
- shard_index);
+ auto policy = std::get(axis_policies);
+ return policy.template shard_dim(
+ global_shape[fixed_dim_v], shard_index);
};
auto get_all_dims = [&](std::index_sequence) {
return make_shape(get_dim.template operator()()...);
@@ -235,12 +251,14 @@ constexpr auto tensor_axes_of_non_split_shard_policies() noexcept {
template
constexpr auto local_shard_dim(const TSharding &sharding,
const GlobalShape &global_shape) noexcept {
- static_assert(GlobalShape::rank() == TSharding::rank(), "Invalid sharding.");
+ static_assert(GlobalShape::rank() == TSharding::rank(),
+ "Invalid sharding.");
using mesh_type = typename TSharding::mesh_type;
const auto local_index = mesh_type::local_index();
- return std::get(sharding.axis_policies)
- .template shard_dim(global_shape[fixed_dim_v], local_index);
+ auto policy = std::get(sharding.axis_policies);
+ return policy.template shard_dim(
+ global_shape[fixed_dim_v], local_index);
}
template
diff --git a/ntt/include/nncase/ntt/kernels/paged_attention.h b/ntt/include/nncase/ntt/kernels/paged_attention.h
index d109d965bb..da242e48a9 100644
--- a/ntt/include/nncase/ntt/kernels/paged_attention.h
+++ b/ntt/include/nncase/ntt/kernels/paged_attention.h
@@ -150,7 +150,7 @@ constexpr void update_paged_attention_kv_cache(const TSlots &slots_tensor,
.squeeze(local_slots_squeeze);
// process kv_head different sharding on slot and kv cache.
- const auto kv_head_policy = config_t::template axis_policy<
+ auto kv_head_policy = config_t::template axis_policy<
caching::paged_kvcache_dim_kind::num_kv_heads>();
const auto global_head_id =
slots_global_offset[head_index] + local_head_id;
diff --git a/ntt/include/nncase/ntt/kernels/reshard.h b/ntt/include/nncase/ntt/kernels/reshard.h
index 3123591590..8b47bc6952 100644
--- a/ntt/include/nncase/ntt/kernels/reshard.h
+++ b/ntt/include/nncase/ntt/kernels/reshard.h
@@ -120,7 +120,7 @@ struct reshard_impl {
[&](auto last_acc, auto global_dim, auto axis) {
auto [last_global_offset, last_local_offset, last_shape] =
last_acc;
- const auto policy =
+ auto policy =
std::get(src.sharding().axis_policies);
if constexpr (distributed::SplitShardPolicy<
std::decay_t>) {
@@ -330,16 +330,16 @@ struct reshard_impl {
"Cannot reshard between different mesh types.");
constexpr void operator()(const SrcTensor &src, DestTensor &dest) noexcept {
- if constexpr (std::is_same_v) {
- if (src.shape() == dest.shape()) {
- // make sure src ready.
- distributed::topology_synchronize();
- overlap_aware_reshard(src, dest);
- distributed::topology_synchronize();
- return;
- }
- }
+ // if constexpr (std::is_same_v) {
+ // if (src.shape() == dest.shape()) {
+ // // make sure src ready.
+ // distributed::topology_synchronize();
+ // overlap_aware_reshard(src, dest);
+ // distributed::topology_synchronize();
+ // return;
+ // }
+ // }
copy_to_global(src);
copy_from_global(dest);
@@ -388,7 +388,7 @@ struct reshard_impl {
std::array counts{};
auto get_coord = [&]() {
- const auto policy = std::get(src.sharding().axis_policies);
+ auto policy = std::get(src.sharding().axis_policies);
if constexpr (distributed::SplitShardPolicy>) {
size_t num_blocks = 1;
constexpr auto policy_rank = policy.axes.rank();
diff --git a/ntt/src/cpu_runtime.cpp b/ntt/src/cpu_runtime.cpp
index 90409b5d50..e57f41b008 100644
--- a/ntt/src/cpu_runtime.cpp
+++ b/ntt/src/cpu_runtime.cpp
@@ -38,7 +38,13 @@ using namespace nncase::ntt::runtime;
decltype(nncase::ntt::make_tensor>(
nncase::ntt::distributed::topology_shape))
- nncase::ntt::distributed::detail::global_local_data_ptr =
+ nncase::ntt::distributed::detail::global_thread_local_data_ptr =
+ nncase::ntt::make_tensor>(
+ nncase::ntt::distributed::topology_shape);
+
+decltype(nncase::ntt::make_tensor>(
+ nncase::ntt::distributed::topology_shape))
+ nncase::ntt::distributed::detail::global_block_local_data_ptr =
nncase::ntt::make_tensor>(
nncase::ntt::distributed::topology_shape);
@@ -194,11 +200,16 @@ extern "C" void block_entry(const cpu_block_entry_params_t ¶ms) {
}
}
- ntt::distributed::detail::global_local_data_ptr(program_ids)(
+ ntt::distributed::detail::global_thread_local_data_ptr(program_ids)(
0_dim) = (uintptr_t)thread_local_data.data();
- ntt::distributed::detail::global_local_data_ptr(program_ids)(
+ ntt::distributed::detail::global_thread_local_data_ptr(program_ids)(
1_dim) = (uintptr_t)(thread_local_data.data() +
thread_local_data.size_bytes());
+ ntt::distributed::detail::global_block_local_data_ptr(program_ids)(
+ 0_dim) = (uintptr_t)block_local_data.data();
+ ntt::distributed::detail::global_block_local_data_ptr(program_ids)(
+ 1_dim) = (uintptr_t)(block_local_data.data() +
+ block_local_data.size_bytes());
ntt::distributed::detail::global_block_local_rdata_ptr(program_ids)(
0_dim) = (uintptr_t)params.block_local_rdata.data();
ntt::distributed::detail::global_block_local_rdata_ptr(program_ids)(
diff --git a/ntt/test/ctest/test_ntt_reshard.cpp b/ntt/test/ctest/test_ntt_reshard.cpp
index c0d201b5ad..b3c08119d4 100644
--- a/ntt/test/ctest/test_ntt_reshard.cpp
+++ b/ntt/test/ctest/test_ntt_reshard.cpp
@@ -51,7 +51,7 @@ using namespace ortki;
decltype(nncase::ntt::make_tensor>(
nncase::ntt::distributed::topology_shape))
- nncase::ntt::distributed::detail::global_local_data_ptr =
+ nncase::ntt::distributed::detail::global_thread_local_data_ptr =
nncase::ntt::make_tensor>(
nncase::ntt::distributed::topology_shape);
@@ -145,13 +145,13 @@ TEST(CpuTest, reshard_2D_same_sharding_spec_broadcast) {
constexpr size_t num = cdims * bdims * tdims;
ntt::apply(
- ntt::distributed::detail::global_local_data_ptr.shape(),
+ ntt::distributed::detail::global_thread_local_data_ptr.shape(),
[&](auto index) {
- ntt::distributed::detail::global_local_data_ptr(index)(0_dim) =
+ ntt::distributed::detail::global_thread_local_data_ptr(index)(0_dim) =
(uintptr_t)(ntt::runtime::thread_alloc(M * N * sizeof(float),
8));
- ntt::distributed::detail::global_local_data_ptr(index)(1_dim) =
- ntt::distributed::detail::global_local_data_ptr(index)(0_dim) +
+ ntt::distributed::detail::global_thread_local_data_ptr(index)(1_dim) =
+ ntt::distributed::detail::global_thread_local_data_ptr(index)(0_dim) +
M * N * sizeof(float);
});
@@ -199,7 +199,7 @@ TEST(CpuTest, reshard_2D_same_sharding_spec_broadcast) {
const auto program_ids = make_shape(cid, bid, tid);
float *local_data_src = reinterpret_cast(
- ((size_t)ntt::distributed::detail::global_local_data_ptr(program_ids)(
+ ((size_t)ntt::distributed::detail::global_thread_local_data_ptr(program_ids)(
0_dim)));
auto local_data_dst = reinterpret_cast(
nncase::ntt::runtime::thread_alloc(M * N * sizeof(float), 8));
@@ -238,10 +238,10 @@ TEST(CpuTest, reshard_2D_same_sharding_spec_broadcast) {
t.join();
- ntt::apply(ntt::distributed::detail::global_local_data_ptr.shape(),
+ ntt::apply(ntt::distributed::detail::global_thread_local_data_ptr.shape(),
[&](auto index) {
thread_free(
- (void *)(size_t)ntt::distributed::detail::global_local_data_ptr(
+ (void *)(size_t)ntt::distributed::detail::global_thread_local_data_ptr(
index)(0_dim));
});
}
@@ -268,13 +268,13 @@ TEST(CpuTest, reshard_2D_same_sharding_sepc_split) {
constexpr size_t num = cdims * bdims * tdims;
ntt::apply(
- ntt::distributed::detail::global_local_data_ptr.shape(),
+ ntt::distributed::detail::global_thread_local_data_ptr.shape(),
[&](auto index) {
- ntt::distributed::detail::global_local_data_ptr(index)(0_dim) =
+ ntt::distributed::detail::global_thread_local_data_ptr(index)(0_dim) =
(uintptr_t)(ntt::runtime::thread_alloc(M * N * sizeof(float),
8));
- ntt::distributed::detail::global_local_data_ptr(index)(1_dim) =
- ntt::distributed::detail::global_local_data_ptr(index)(0_dim) +
+ ntt::distributed::detail::global_thread_local_data_ptr(index)(1_dim) =
+ ntt::distributed::detail::global_thread_local_data_ptr(index)(0_dim) +
M * N * sizeof(float);
});
@@ -323,7 +323,7 @@ TEST(CpuTest, reshard_2D_same_sharding_sepc_split) {
const auto program_ids = make_shape(cid, bid, tid);
float *local_data_src = reinterpret_cast(
- ((size_t)ntt::distributed::detail::global_local_data_ptr(program_ids)(
+ ((size_t)ntt::distributed::detail::global_thread_local_data_ptr(program_ids)(
0_dim)));
auto local_data_dst = reinterpret_cast(
nncase::ntt::runtime::thread_alloc(M * N * sizeof(float), 8));
@@ -359,10 +359,10 @@ TEST(CpuTest, reshard_2D_same_sharding_sepc_split) {
for (auto &t : threads)
t.join();
- ntt::apply(ntt::distributed::detail::global_local_data_ptr.shape(),
+ ntt::apply(ntt::distributed::detail::global_thread_local_data_ptr.shape(),
[&](auto index) {
thread_free(
- (void *)(size_t)ntt::distributed::detail::global_local_data_ptr(
+ (void *)(size_t)ntt::distributed::detail::global_thread_local_data_ptr(
index)(0_dim));
});
}
@@ -394,13 +394,13 @@ TEST(CpuTest, reshard_2D_different_sharding_spec_broadcast_split) {
constexpr size_t num = cdims * bdims * tdims;
ntt::apply(
- ntt::distributed::detail::global_local_data_ptr.shape(),
+ ntt::distributed::detail::global_thread_local_data_ptr.shape(),
[&](auto index) {
- ntt::distributed::detail::global_local_data_ptr(index)(0_dim) =
+ ntt::distributed::detail::global_thread_local_data_ptr(index)(0_dim) =
(uintptr_t)(ntt::runtime::thread_alloc(M * N * sizeof(float),
8));
- ntt::distributed::detail::global_local_data_ptr(index)(1_dim) =
- ntt::distributed::detail::global_local_data_ptr(index)(0_dim) +
+ ntt::distributed::detail::global_thread_local_data_ptr(index)(1_dim) =
+ ntt::distributed::detail::global_thread_local_data_ptr(index)(0_dim) +
M * N * sizeof(float);
});
@@ -450,7 +450,7 @@ TEST(CpuTest, reshard_2D_different_sharding_spec_broadcast_split) {
const auto program_ids = make_shape(cid, bid, tid);
float *local_data_src = reinterpret_cast(
- ((size_t)ntt::distributed::detail::global_local_data_ptr(program_ids)(
+ ((size_t)ntt::distributed::detail::global_thread_local_data_ptr(program_ids)(
0_dim)));
auto local_data_dst = reinterpret_cast(
nncase::ntt::runtime::thread_alloc(M * N * sizeof(float), 8));
@@ -492,10 +492,10 @@ TEST(CpuTest, reshard_2D_different_sharding_spec_broadcast_split) {
for (auto &t : threads)
t.join();
- ntt::apply(ntt::distributed::detail::global_local_data_ptr.shape(),
+ ntt::apply(ntt::distributed::detail::global_thread_local_data_ptr.shape(),
[&](auto index) {
thread_free(
- (void *)(size_t)ntt::distributed::detail::global_local_data_ptr(
+ (void *)(size_t)ntt::distributed::detail::global_thread_local_data_ptr(
index)(0_dim));
});
}
@@ -524,13 +524,13 @@ TEST(CpuTest, reshard_2D_different_sharding_spec_split_broadcast) {
constexpr size_t num = cdims * bdims * tdims;
ntt::apply(
- ntt::distributed::detail::global_local_data_ptr.shape(),
+ ntt::distributed::detail::global_thread_local_data_ptr.shape(),
[&](auto index) {
- ntt::distributed::detail::global_local_data_ptr(index)(0_dim) =
+ ntt::distributed::detail::global_thread_local_data_ptr(index)(0_dim) =
(uintptr_t)(ntt::runtime::thread_alloc(M * N * sizeof(float),
8));
- ntt::distributed::detail::global_local_data_ptr(index)(1_dim) =
- ntt::distributed::detail::global_local_data_ptr(index)(0_dim) +
+ ntt::distributed::detail::global_thread_local_data_ptr(index)(1_dim) =
+ ntt::distributed::detail::global_thread_local_data_ptr(index)(0_dim) +
M * N * sizeof(float);
});
@@ -578,7 +578,7 @@ TEST(CpuTest, reshard_2D_different_sharding_spec_split_broadcast) {
const auto program_ids = make_shape(cid, bid, tid);
float *local_data_src = reinterpret_cast(
- ((size_t)ntt::distributed::detail::global_local_data_ptr(program_ids)(
+ ((size_t)ntt::distributed::detail::global_thread_local_data_ptr(program_ids)(
0_dim)));
auto local_data_dst = reinterpret_cast(
nncase::ntt::runtime::thread_alloc(M * N * sizeof(float), 8));
@@ -617,10 +617,10 @@ TEST(CpuTest, reshard_2D_different_sharding_spec_split_broadcast) {
for (auto &t : threads)
t.join();
- ntt::apply(ntt::distributed::detail::global_local_data_ptr.shape(),
+ ntt::apply(ntt::distributed::detail::global_thread_local_data_ptr.shape(),
[&](auto index) {
thread_free(
- (void *)(size_t)ntt::distributed::detail::global_local_data_ptr(
+ (void *)(size_t)ntt::distributed::detail::global_thread_local_data_ptr(
index)(0_dim));
});
}
@@ -659,13 +659,13 @@ TEST(CpuTest, reshard_2D_different_sharding_spec_different_split_axis) {
constexpr size_t num = cdims * bdims * tdims;
ntt::apply(
- ntt::distributed::detail::global_local_data_ptr.shape(),
+ ntt::distributed::detail::global_thread_local_data_ptr.shape(),
[&](auto index) {
- ntt::distributed::detail::global_local_data_ptr(index)(0_dim) =
+ ntt::distributed::detail::global_thread_local_data_ptr(index)(0_dim) =
(uintptr_t)(ntt::runtime::thread_alloc(M * N * sizeof(float),
8));
- ntt::distributed::detail::global_local_data_ptr(index)(1_dim) =
- ntt::distributed::detail::global_local_data_ptr(index)(0_dim) +
+ ntt::distributed::detail::global_thread_local_data_ptr(index)(1_dim) =
+ ntt::distributed::detail::global_thread_local_data_ptr(index)(0_dim) +
M * N * sizeof(float);
});
@@ -716,7 +716,7 @@ TEST(CpuTest, reshard_2D_different_sharding_spec_different_split_axis) {
const auto program_ids = make_shape(cid, bid, tid);
float *local_data_src = reinterpret_cast(
- (size_t)(ntt::distributed::detail::global_local_data_ptr(program_ids)(
+ (size_t)(ntt::distributed::detail::global_thread_local_data_ptr(program_ids)(
0_dim)));
auto local_data_dst = reinterpret_cast(
nncase::ntt::runtime::thread_alloc(M * N * sizeof(float), 8));
@@ -759,10 +759,10 @@ TEST(CpuTest, reshard_2D_different_sharding_spec_different_split_axis) {
EXPECT_TRUE(NttTest::compare_tensor(ntt_input, ntt_output));
- ntt::apply(ntt::distributed::detail::global_local_data_ptr.shape(),
+ ntt::apply(ntt::distributed::detail::global_thread_local_data_ptr.shape(),
[&](auto index) {
thread_free(
- (void *)(size_t)ntt::distributed::detail::global_local_data_ptr(
+ (void *)(size_t)ntt::distributed::detail::global_thread_local_data_ptr(
index)(0_dim));
});
}
@@ -792,13 +792,13 @@ TEST(CpuTest, reshard_2D_different_sharding_spec_non_divisible_SB2BS) {
constexpr size_t num = cdims * bdims * tdims;
ntt::apply(
- ntt::distributed::detail::global_local_data_ptr.shape(),
+ ntt::distributed::detail::global_thread_local_data_ptr.shape(),
[&](auto index) {
- ntt::distributed::detail::global_local_data_ptr(index)(0_dim) =
+ ntt::distributed::detail::global_thread_local_data_ptr(index)(0_dim) =
(uintptr_t)(ntt::runtime::thread_alloc(M * N * sizeof(float),
8));
- ntt::distributed::detail::global_local_data_ptr(index)(1_dim) =
- ntt::distributed::detail::global_local_data_ptr(index)(0_dim) +
+ ntt::distributed::detail::global_thread_local_data_ptr(index)(1_dim) =
+ ntt::distributed::detail::global_thread_local_data_ptr(index)(0_dim) +
M * N * sizeof(float);
});
@@ -849,7 +849,7 @@ TEST(CpuTest, reshard_2D_different_sharding_spec_non_divisible_SB2BS) {
const auto program_ids = make_shape(cid, bid, tid);
float *local_data_src = reinterpret_cast(
- (size_t)(ntt::distributed::detail::global_local_data_ptr(program_ids)(
+ (size_t)(ntt::distributed::detail::global_thread_local_data_ptr(program_ids)(
0_dim)));
auto local_data_dst = reinterpret_cast(
nncase::ntt::runtime::thread_alloc(M * N * sizeof(float), 8));
@@ -892,10 +892,10 @@ TEST(CpuTest, reshard_2D_different_sharding_spec_non_divisible_SB2BS) {
EXPECT_TRUE(NttTest::compare_tensor(ntt_input, ntt_output));
- ntt::apply(ntt::distributed::detail::global_local_data_ptr.shape(),
+ ntt::apply(ntt::distributed::detail::global_thread_local_data_ptr.shape(),
[&](auto index) {
thread_free(
- (void *)(size_t)ntt::distributed::detail::global_local_data_ptr(
+ (void *)(size_t)ntt::distributed::detail::global_thread_local_data_ptr(
index)(0_dim));
});
}
@@ -926,13 +926,13 @@ TEST(CpuTest, reshard_2D_different_sharding_spec_non_divisible_BS2SB) {
constexpr size_t num = cdims * bdims * tdims;
ntt::apply(
- ntt::distributed::detail::global_local_data_ptr.shape(),
+ ntt::distributed::detail::global_thread_local_data_ptr.shape(),
[&](auto index) {
- ntt::distributed::detail::global_local_data_ptr(index)(0_dim) =
+ ntt::distributed::detail::global_thread_local_data_ptr(index)(0_dim) =
(uintptr_t)(ntt::runtime::thread_alloc(M * N * sizeof(float),
8));
- ntt::distributed::detail::global_local_data_ptr(index)(1_dim) =
- ntt::distributed::detail::global_local_data_ptr(index)(0_dim) +
+ ntt::distributed::detail::global_thread_local_data_ptr(index)(1_dim) =
+ ntt::distributed::detail::global_thread_local_data_ptr(index)(0_dim) +
M * N * sizeof(float);
});
@@ -983,7 +983,7 @@ TEST(CpuTest, reshard_2D_different_sharding_spec_non_divisible_BS2SB) {
const auto program_ids = make_shape(cid, bid, tid);
float *local_data_src = reinterpret_cast(
- (size_t)(ntt::distributed::detail::global_local_data_ptr(program_ids)(
+ (size_t)(ntt::distributed::detail::global_thread_local_data_ptr(program_ids)(
0_dim)));
auto local_data_dst = reinterpret_cast(
nncase::ntt::runtime::thread_alloc(M * N * sizeof(float), 8));
@@ -1026,10 +1026,10 @@ TEST(CpuTest, reshard_2D_different_sharding_spec_non_divisible_BS2SB) {
EXPECT_TRUE(NttTest::compare_tensor(ntt_input, ntt_output));
- ntt::apply(ntt::distributed::detail::global_local_data_ptr.shape(),
+ ntt::apply(ntt::distributed::detail::global_thread_local_data_ptr.shape(),
[&](auto index) {
thread_free(
- (void *)(size_t)ntt::distributed::detail::global_local_data_ptr(
+ (void *)(size_t)ntt::distributed::detail::global_thread_local_data_ptr(
index)(0_dim));
});
}
@@ -1059,13 +1059,13 @@ TEST(CpuTest, reshard_2D_split_multilple_axes) {
constexpr size_t num = cdims * bdims * tdims;
ntt::apply(
- ntt::distributed::detail::global_local_data_ptr.shape(),
+ ntt::distributed::detail::global_thread_local_data_ptr.shape(),
[&](auto index) {
- ntt::distributed::detail::global_local_data_ptr(index)(0_dim) =
+ ntt::distributed::detail::global_thread_local_data_ptr(index)(0_dim) =
(uintptr_t)(ntt::runtime::thread_alloc(M * N * sizeof(float),
8));
- ntt::distributed::detail::global_local_data_ptr(index)(1_dim) =
- ntt::distributed::detail::global_local_data_ptr(index)(0_dim) +
+ ntt::distributed::detail::global_thread_local_data_ptr(index)(1_dim) =
+ ntt::distributed::detail::global_thread_local_data_ptr(index)(0_dim) +
M * N * sizeof(float);
});
@@ -1115,7 +1115,7 @@ TEST(CpuTest, reshard_2D_split_multilple_axes) {
const auto program_ids = make_shape(cid, bid, tid);
float *local_data_src = reinterpret_cast(
- (size_t)(ntt::distributed::detail::global_local_data_ptr(program_ids)(
+ (size_t)(ntt::distributed::detail::global_thread_local_data_ptr(program_ids)(
0_dim)));
auto local_data_dst = reinterpret_cast(
nncase::ntt::runtime::thread_alloc(M * N * sizeof(float), 8));
@@ -1158,10 +1158,10 @@ TEST(CpuTest, reshard_2D_split_multilple_axes) {
EXPECT_TRUE(NttTest::compare_tensor(ntt_input, ntt_output));
- ntt::apply(ntt::distributed::detail::global_local_data_ptr.shape(),
+ ntt::apply(ntt::distributed::detail::global_thread_local_data_ptr.shape(),
[&](auto index) {
thread_free(
- (void *)(size_t)ntt::distributed::detail::global_local_data_ptr(
+ (void *)(size_t)ntt::distributed::detail::global_thread_local_data_ptr(
index)(0_dim));
});
}
@@ -1192,13 +1192,13 @@ TEST(CpuTest, reshard_3D_different_sharding_spec_different_split_axis) {
constexpr size_t num = cdims * bdims * tdims;
ntt::apply(
- ntt::distributed::detail::global_local_data_ptr.shape(),
+ ntt::distributed::detail::global_thread_local_data_ptr.shape(),
[&](auto index) {
- ntt::distributed::detail::global_local_data_ptr(index)(0_dim) =
+ ntt::distributed::detail::global_thread_local_data_ptr(index)(0_dim) =
(uintptr_t)(ntt::runtime::thread_alloc(M * N * K * sizeof(float),
8));
- ntt::distributed::detail::global_local_data_ptr(index)(1_dim) =
- ntt::distributed::detail::global_local_data_ptr(index)(0_dim) +
+ ntt::distributed::detail::global_thread_local_data_ptr(index)(1_dim) =
+ ntt::distributed::detail::global_thread_local_data_ptr(index)(0_dim) +
M * N * K *sizeof(float);
});
@@ -1251,7 +1251,7 @@ TEST(CpuTest, reshard_3D_different_sharding_spec_different_split_axis) {
const auto program_ids = make_shape(cid, bid, tid);
float *local_data_src = reinterpret_cast(
- (size_t)(ntt::distributed::detail::global_local_data_ptr(program_ids)(
+ (size_t)(ntt::distributed::detail::global_thread_local_data_ptr(program_ids)(
0_dim)));
auto local_data_dst = reinterpret_cast(
nncase::ntt::runtime::thread_alloc(M * N * K * sizeof(float), 8));
@@ -1297,10 +1297,10 @@ TEST(CpuTest, reshard_3D_different_sharding_spec_different_split_axis) {
EXPECT_TRUE(NttTest::compare_tensor(ntt_input, ntt_output));
- ntt::apply(ntt::distributed::detail::global_local_data_ptr.shape(),
+ ntt::apply(ntt::distributed::detail::global_thread_local_data_ptr.shape(),
[&](auto index) {
thread_free(
- (void *)(size_t)ntt::distributed::detail::global_local_data_ptr(
+ (void *)(size_t)ntt::distributed::detail::global_thread_local_data_ptr(
index)(0_dim));
});
}
@@ -1416,4 +1416,4 @@ TEST(CpuTest, reshard_reshape) {
int main(int argc, char *argv[]) {
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
-}
\ No newline at end of file
+}
diff --git a/src/Native/src/test.cpp b/src/Native/src/test.cpp
index c8298c3ee7..7afd028ab7 100644
--- a/src/Native/src/test.cpp
+++ b/src/Native/src/test.cpp
@@ -189,6 +189,29 @@ void test_sharding() {
ntt::distributed::detail::mesh_axes_of_non_split_shard_policies<
sharding_type>() == ntt::fixed_shape_v<0, 1>);
}
+
+ // Sharing split with S<2>
+ {
+ ntt::dim_t seq_length = 122;
+ [[maybe_unused]] float *xx = new float[seq_length];
+ [[maybe_unused]] auto sp = ntt::distributed::shard_policy::S<0>(16_dim);
+ [[maybe_unused]] auto sb = ntt::distributed::shard_policy::B;
+ [[maybe_unused]] auto sharding = ntt::distributed::make_sharding<
+ ntt::distributed::mesh>(
+ ntt::distributed::shard_policy::S<0>(16_dim),
+ ntt::distributed::shard_policy::B);
+ [[maybe_unused]] auto buffer_0 =
+ ntt::distributed::make_sharded_tensor_view(
+ span_cast(make_subspan(
+ std::span((std::byte *)xx + 4096UL, 4096),
+ 0_dim, 4096_dim)),
+ ntt::make_shape(seq_length, 64_dim),
+ ntt::distributed::make_sharding>(
+ ntt::distributed::shard_policy::S<0>(16_dim),
+ ntt::distributed::shard_policy::B),
+ ntt::make_strides(64_dim, 1_dim));
+ }
}
void test_matmul_normal() {
@@ -633,8 +656,9 @@ void test_caching() {
constexpr auto head_dim_policy = paged_config_t::axis_policy<
ntt::caching::paged_kvcache_dim_kind::num_kv_heads>();
static_assert(
- std::is_same_v,
- ntt::distributed::shard_policy::split<0>>,
+ std::is_same_v<
+ std::remove_cv_t,
+ ntt::distributed::shard_policy::split>,
"find failed!");
auto context_lens = ntt::make_tensor(ntt::make_shape(1));
diff --git a/src/Nncase.Core/DistributedType.cs b/src/Nncase.Core/DistributedType.cs
index ccb7bb6bb4..93d7d75193 100644
--- a/src/Nncase.Core/DistributedType.cs
+++ b/src/Nncase.Core/DistributedType.cs
@@ -26,21 +26,19 @@ public abstract record SBP
{
public static SBPBroadCast B => SBPBroadCast.Instance;
- public static SBPPartial P(ReduceOp op = ReduceOp.Sum) => new SBPPartial(op);
+ public static SBPPartial P(IRArray axes, ReduceOp op = ReduceOp.Sum) => new SBPPartial(axes, op);
- public static SBPSplit S(IRArray axes) => new SBPSplit(axes);
-
- public static SBPSplit S(params int[] axes) => new SBPSplit(axes);
+ public static SBPSplit S(IRArray axes, Dimension? granularity = null) => new SBPSplit(axes, granularity);
}
-public sealed record SBPSplit(IRArray Axes) : SBP
+public sealed record SBPSplit(IRArray Axes, Dimension? Granularity = null) : SBP
{
- public override string ToString() => $"S({string.Join(",", Axes)})";
+ public override string ToString() => $"S([{string.Join(",", Axes)}], {Granularity})";
}
-public sealed record SBPPartial(ReduceOp Op) : SBP
+public sealed record SBPPartial(IRArray Axes, ReduceOp Op) : SBP
{
- public override string ToString() => $"P({Op})";
+ public override string ToString() => $"P([{string.Join(",", Axes)}], {Op})";
}
public sealed record SBPBroadCast : SBP
@@ -87,6 +85,10 @@ public override SBP Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSeri
{
sbpSplit = new SBPSplit(irAxes);
}
+ else if (typeDiscriminator == "P")
+ {
+ sbpPartial = new SBPPartial(irAxes, ReduceOp.Sum);
+ }
else
{
throw new InvalidDataException("Axes must be used in SBP split");
@@ -95,7 +97,7 @@ public override SBP Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSeri
break;
case "Op":
ReduceOp partialOp = JsonSerializer.Deserialize(ref reader, options);
- sbpPartial = new SBPPartial(partialOp);
+ sbpPartial = new SBPPartial(sbpPartial!.Axes, partialOp);
break;
default:
reader.Skip();
@@ -157,7 +159,7 @@ public sealed record Placement(IRArray Hierarchy, string Name, HierarchyKin
public override string ToString() => $"[{string.Join(',', Hierarchy.Zip(Name).Select(t => t.Second.ToString() + ':' + t.First.ToString()))}]";
}
-public sealed record DistributedType(TensorType TensorType, IRArray AxisPolicies, Placement Placement, bool Partial = false) : IRType
+public sealed record DistributedType(TensorType TensorType, IRArray AxisPolicies, Placement Placement, SBPPartial? Partial = null) : IRType
{
public override string ToString() => $"{TensorType}, ({string.Join(',', AxisPolicies)}), {Placement}, Partial: {Partial}";
}
diff --git a/src/Nncase.Core/Utilities/DistributedUtility.cs b/src/Nncase.Core/Utilities/DistributedUtility.cs
index fe77372334..3b18b2d8cf 100644
--- a/src/Nncase.Core/Utilities/DistributedUtility.cs
+++ b/src/Nncase.Core/Utilities/DistributedUtility.cs
@@ -22,7 +22,7 @@ static DistributedUtility()
}
}
- public delegate bool DivideByDelegate(long input, int divisor);
+ public delegate bool DivideByDelegate(long input, int divisor, bool isFixed);
[Flags]
public enum DivideFlags
@@ -73,9 +73,9 @@ public static IReadOnlyList> GetLeafCandidatePolicies(TensorType te
{
var axis = splitsAxes[ti];
var divisor = axis.Select(a => placement.Hierarchy[a]).Aggregate(1, (a, b) => a * b);
- if (axis.All(a => placement.Hierarchy[a] > 1) && divisor > 1 && DivideByFunc(maxShape[di], divisor))
+ if (axis.All(a => placement.Hierarchy[a] > 1) && divisor > 1 && DivideByFunc(maxShape[di], divisor, tensorType.Shape[di].IsFixed))
{
- policy.Add(SBP.S(axis.ToArray()));
+ policy.Add(SBP.S(axis.ToArray(), (int)MathUtility.CeilDiv(maxShape[di], divisor)));
}
}
@@ -142,7 +142,7 @@ public static bool IsDistributable(TensorType tensorType, ReadOnlySpan poli
// 2. All shapes are divisible by the mesh.
var maxShape = CompilerServices.GetMaxShape(tensorType.Shape);
var divisors = GetDivisors(new DistributedType(tensorType, polices.ToArray(), placement));
- return divisors.Select((d, axis) => (d, axis)).All(p => p.d == 0 ? true : DivideByFunc(maxShape[p.axis], p.d));
+ return divisors.Select((d, axis) => (d, axis)).All(p => p.d == 0 ? true : DivideByFunc(maxShape[p.axis], p.d, tensorType.Shape[p.axis].IsFixed));
}
public static bool IsDistributable(ReadOnlySpan polices)
@@ -213,7 +213,7 @@ public static bool TryGetDividedTensorType(DistributedType distributedType, [May
public static IRArray AxisPolicesToNDSBP(IRArray axisPolices, int rank)
{
- var ndsbp = new SBP[rank];
+ var ndsbp = Enumerable.Repeat(SBP.B, rank).Select(p => (SBP)p).ToArray();
for (var i = 0; i < axisPolices.Count; i++)
{
var policy = axisPolices[i];
@@ -221,27 +221,36 @@ public static IRArray AxisPolicesToNDSBP(IRArray axisPolices, int rank
{
foreach (var ax in split.Axes)
{
- ndsbp[ax] = SBP.S([i]);
+ ndsbp[ax] = SBP.S([i], split.Granularity);
+ }
+ }
+ else if (policy is SBPPartial partial)
+ {
+ foreach (var ax in partial.Axes)
+ {
+ ndsbp[ax] = SBP.P(ndsbp[ax] is SBPPartial p ? p.Axes.Append(i).ToArray() : [i], partial.Op);
}
}
}
- return ndsbp.Select(sbp => sbp is SBPSplit ? sbp : SBP.B).ToArray();
+ return ndsbp;
}
public static IRArray NDSBPToAxisPolices(IRArray ndsbp, int rank)
{
- var polices = new SBP[rank];
+ var polices = Enumerable.Repeat(SBP.B, rank).Select(p => (SBP)p).ToArray();
for (int d = 0; d < polices.Length; d++)
{
var splitAxes = Enumerable.Range(0, ndsbp.Count).Where(i => ndsbp[i] is SBPSplit split && split.Axes[0] == d).ToArray();
+ var partialAxes = Enumerable.Range(0, ndsbp.Count).Where(i => ndsbp[i] is SBPSplit partial && partial.Axes.Contains(d)).ToArray();
if (splitAxes.Any())
{
- polices[d] = SBP.S(splitAxes);
+ polices[d] = SBP.S(splitAxes, ((SBPSplit)ndsbp[splitAxes[0]]).Granularity);
}
- else
+
+ if (partialAxes.Any())
{
- polices[d] = SBP.B;
+ polices[d] = SBP.P(partialAxes, ((SBPPartial)ndsbp[partialAxes[0]]).Op);
}
}
@@ -293,9 +302,9 @@ from item in array
return ret.ToList();
}
- public static bool IsDivideBy(long input, int divisor)
+ public static bool IsDivideBy(long input, int divisor, bool isFixed)
{
- if (input >= divisor)
+ if (!isFixed || input >= divisor)
{
return true;
}
@@ -303,9 +312,9 @@ public static bool IsDivideBy(long input, int divisor)
return false;
}
- public static bool IsDivideExactly(long input, int divisor)
+ public static bool IsDivideExactly(long input, int divisor, bool isFixed = true)
{
- if (input >= divisor && input % divisor == 0)
+ if (!isFixed || (input >= divisor && input % divisor == 0))
{
return true;
}
@@ -313,6 +322,53 @@ public static bool IsDivideExactly(long input, int divisor)
return false;
}
+ public static bool AreSamePolicies(IRArray? a, IRArray? b, bool checkGranularity = true)
+ {
+ if (a == null && b == null)
+ {
+ return true;
+ }
+
+ if (a == null || b == null || a.Value.Count != b.Value.Count)
+ {
+ return false;
+ }
+
+ for (int i = 0; i < a.Value.Count; i++)
+ {
+ if (!IsSamePolicy(a.Value[i], b.Value[i], checkGranularity))
+ {
+ return false;
+ }
+ }
+
+ return true;
+ }
+
+ public static bool IsSamePolicy(SBP a, SBP b, bool checkGranularity = true)
+ {
+ if (a == null || b == null)
+ {
+ return false;
+ }
+
+ if (a is SBPSplit splitA && b is SBPSplit splitB)
+ {
+ if (checkGranularity)
+ {
+ return a == b;
+ }
+ else
+ {
+ return splitA.Axes == splitB.Axes;
+ }
+ }
+ else
+ {
+ return a == b;
+ }
+ }
+
public static float GetDividedTensorEfficiency(DistributedType distributedType, int burstLength)
{
var (tiles, shape) = GetDividedTile(distributedType);
@@ -358,7 +414,8 @@ public static (long[] Offset, long[] Shape) GetLocalOffsetAndShape(DistributedTy
var shape = new long[distributedType.TensorType.Shape.Rank];
for (int axis = 0; axis < offset.Length; axis++)
{
- var splits = distributedType.AxisPolicies[axis] is SBPSplit s
+ var policy = distributedType.AxisPolicies[axis];
+ var splits = policy is SBPSplit s
? s.Axes.Select(td => (Placement: td, DeviceIndex: shardIndex[td], DeviceDim: distributedType.Placement.Hierarchy[td])).ToArray()
: Array.Empty<(int Placement, int DeviceIndex, int DeviceDim)>();
if (splits.Any())
@@ -368,7 +425,7 @@ public static (long[] Offset, long[] Shape) GetLocalOffsetAndShape(DistributedTy
var subHierarchySize = (int)TensorUtilities.GetProduct(subHierarchies);
var subShardIndex = splits.Select(x => x.DeviceIndex).ToArray();
var linearIndex = TensorUtilities.GetLinearOffset(subHierarchyStrides, subShardIndex);
- var localDim = MathUtility.CeilDiv(globalShape[axis], subHierarchySize);
+ var localDim = ((SBPSplit)policy).Granularity is not null ? (long)((SBPSplit)policy).Granularity!.Metadata.Range!.Value.Max : MathUtility.CeilDiv(globalShape[axis], subHierarchySize);
offset[axis] = linearIndex * localDim;
shape[axis] = Math.Min(localDim, globalShape[axis] - offset[axis]);
}
@@ -390,14 +447,21 @@ private static (RankedShape Tile, RankedShape Shape) GetDividedTile(DistributedT
{
if (distributedType.AxisPolicies.Count > d && distributedType.AxisPolicies[d] is SBPSplit split)
{
- var divisor = split.Axes.Select(t => distributedType.Placement.Hierarchy[t]).Aggregate(1, (a, b) => a * b);
- if (divideFlags.HasFlag(DivideFlags.FloorDiv))
+ if (split.Granularity is not null)
{
- tiles[d] = tiles[d] / divisor;
+ tiles[d] = split.Granularity;
}
else
{
- tiles[d] = Dimension.CeilDiv(tiles[d], divisor);
+ var divisor = split.Axes.Select(t => distributedType.Placement.Hierarchy[t]).Aggregate(1, (a, b) => a * b);
+ if (divideFlags.HasFlag(DivideFlags.FloorDiv))
+ {
+ tiles[d] = tiles[d] / divisor;
+ }
+ else
+ {
+ tiles[d] = Dimension.CeilDiv(tiles[d], divisor);
+ }
}
}
}
diff --git a/src/Nncase.Diagnostics/Diagnostics/ILPrintVisitor.cs b/src/Nncase.Diagnostics/Diagnostics/ILPrintVisitor.cs
index a40ec71c2d..5b05aa790d 100644
--- a/src/Nncase.Diagnostics/Diagnostics/ILPrintVisitor.cs
+++ b/src/Nncase.Diagnostics/Diagnostics/ILPrintVisitor.cs
@@ -154,9 +154,16 @@ public override string VisitType(DistributedType type)
{
if (s is SBPSplit split)
{
- var divisor = split.Axes.Select(a => type.Placement.Hierarchy[a]).Aggregate(1, (a, b) => a * b);
- usedCeil[r] = shape[r] % divisor != 0;
- shape[r] = (shape[r] + divisor - 1) / divisor;
+ if (split.Granularity is null)
+ {
+ var divisor = split.Axes.Select(a => type.Placement.Hierarchy[a]).Aggregate(1, (a, b) => a * b);
+ usedCeil[r] = shape[r] % divisor != 0;
+ shape[r] = (shape[r] + divisor - 1) / divisor;
+ }
+ else
+ {
+ shape[r] = split.Granularity.FixedValue;
+ }
}
}
@@ -169,7 +176,7 @@ public override string VisitType(DistributedType type)
}
}
- return $"{{{VisitType(type.TensorType)}, ({string.Join(',', type.AxisPolicies)}), [{string.Join(',', sshape)}]}}";
+ return $"{{{VisitType(type.TensorType)}, ({string.Join(',', type.AxisPolicies)}), [{string.Join(',', sshape)}], {type.Partial?.ToString() ?? string.Empty}}}";
}
protected override string DispatchVisit(BaseExpr expr)
diff --git a/src/Nncase.Diagnostics/Diagnostics/ScriptPrintVisitor.cs b/src/Nncase.Diagnostics/Diagnostics/ScriptPrintVisitor.cs
index aa44efab09..af4d375a07 100644
--- a/src/Nncase.Diagnostics/Diagnostics/ScriptPrintVisitor.cs
+++ b/src/Nncase.Diagnostics/Diagnostics/ScriptPrintVisitor.cs
@@ -144,9 +144,16 @@ public override string VisitType(DistributedType type)
{
if (s is SBPSplit split)
{
- var divisor = split.Axes.Select(a => type.Placement.Hierarchy[a]).Aggregate(1, (a, b) => a * b);
- usedCeil[r] = shape[r] % divisor != 0;
- shape[r] = (shape[r] + divisor - 1) / divisor;
+ if (split.Granularity is null)
+ {
+ var divisor = split.Axes.Select(a => type.Placement.Hierarchy[a]).Aggregate(1, (a, b) => a * b);
+ usedCeil[r] = shape[r] % divisor != 0;
+ shape[r] = (shape[r] + divisor - 1) / divisor;
+ }
+ else
+ {
+ shape[r] = split.Granularity.FixedValue;
+ }
}
}
@@ -159,7 +166,7 @@ public override string VisitType(DistributedType type)
}
}
- return $"Dist({VisitType(type.TensorType)}, ({string.Join(',', type.AxisPolicies)}), [{string.Join(',', sshape)}])";
+ return $"Dist({VisitType(type.TensorType)}, ({string.Join(',', type.AxisPolicies)}), [{string.Join(',', sshape)}], {type.Partial?.ToString() ?? string.Empty})";
}
///
diff --git a/src/Nncase.Evaluator/Math/MatMul.cs b/src/Nncase.Evaluator/Math/MatMul.cs
index 99133c33cf..930b64158d 100644
--- a/src/Nncase.Evaluator/Math/MatMul.cs
+++ b/src/Nncase.Evaluator/Math/MatMul.cs
@@ -45,7 +45,7 @@ public static IRType VisitDistributedType(DistributedType a, DistributedType b,
var (lm, lk, rk, rn) = dimInfo ?? new(aRank - 2, aRank - 1, bRank - 2, bRank - 1);
var aMaxShape = CompilerServices.GetMaxShape(a.TensorType.Shape);
var bMaxShape = CompilerServices.GetMaxShape(b.TensorType.Shape);
- bool isPartial = false;
+ SBPPartial? partial = null;
// TODO: keep summa only
if (!a.TensorType.Shape.IsFixed || !b.TensorType.Shape.IsFixed || transB || (a.Placement.HierarchyKind == HierarchyKind.SMT && a.TensorType.DType is VectorType vt && vt.ElemType == DataTypes.Float8E4M3))
@@ -106,11 +106,11 @@ public static IRType VisitDistributedType(DistributedType a, DistributedType b,
ndsbp[oRank - 2] = a.AxisPolicies[lm];
ndsbp[oRank - 1] = b.AxisPolicies[rn];
- if (a.AxisPolicies[lk] is SBPSplit || b.AxisPolicies[rk] is SBPSplit)
+ if (a.AxisPolicies[lk] is SBPSplit sk && b.AxisPolicies[rk] is SBPSplit)
{
ndsbp[oRank - 2] = ndsbp[oRank - 2];
ndsbp[oRank - 1] = ndsbp[oRank - 1];
- isPartial = true;
+ partial = SBP.P(sk.Axes);
}
if (!DistributedUtility.IsDistributable(outType, ndsbp, a.Placement))
@@ -118,7 +118,7 @@ public static IRType VisitDistributedType(DistributedType a, DistributedType b,
return new InvalidType("no valid sbp.");
}
- return new DistributedType(outType, ndsbp, a.Placement, Partial: isPartial);
+ return new DistributedType(outType, ndsbp, a.Placement, Partial: partial);
}
else
{
@@ -149,7 +149,6 @@ public static IRType VisitDistributedType(DistributedType a, DistributedType b,
{
ndsbp[oRank - 2] = a.AxisPolicies[lm];
ndsbp[oRank - 1] = b.AxisPolicies[rn];
- isPartial = true;
}
else
{
@@ -211,7 +210,7 @@ public static IRType VisitDistributedType(DistributedType a, DistributedType b,
if (DistributedUtility.IsDistributable(ndsbp))
{
- return new DistributedType(outType, ndsbp, a.Placement, isPartial);
+ return new DistributedType(outType, ndsbp, a.Placement, partial);
}
return new InvalidType("no valid sbp.");
diff --git a/src/Nncase.Evaluator/NN/Conv2D.cs b/src/Nncase.Evaluator/NN/Conv2D.cs
index 6309e05883..ba73aaf5a8 100644
--- a/src/Nncase.Evaluator/NN/Conv2D.cs
+++ b/src/Nncase.Evaluator/NN/Conv2D.cs
@@ -140,7 +140,7 @@ private IRType Visit(ITypeInferenceContext context, Conv2D target, DistributedTy
{
if (ndsbpBias[i] is SBPBroadCast)
{
- ndsbp[i] = SBP.P();
+ ndsbp[i] = SBP.P([1]);
}
else
{
diff --git a/src/Nncase.Evaluator/Tensors/Bitcast.cs b/src/Nncase.Evaluator/Tensors/Bitcast.cs
index dc10fcb47a..09b71cf693 100644
--- a/src/Nncase.Evaluator/Tensors/Bitcast.cs
+++ b/src/Nncase.Evaluator/Tensors/Bitcast.cs
@@ -89,7 +89,14 @@ private IRType Visit(Bitcast target, DistributedType input)
return invalid;
}
- ndsbp[i] = input.AxisPolicies[i];
+ if (input.AxisPolicies[i] is SBPSplit split)
+ {
+ ndsbp[i] = SBP.S(split.Axes, split.Granularity is null ? null : split.Granularity * outTensorType.Shape[i] / input.TensorType.Shape[i]);
+ }
+ else
+ {
+ ndsbp[i] = input.AxisPolicies[i];
+ }
}
return new DistributedType(outTensorType, ndsbp, input.Placement);
diff --git a/src/Nncase.Evaluator/Tensors/Cast.cs b/src/Nncase.Evaluator/Tensors/Cast.cs
index cba956c88e..f8ff0857bf 100644
--- a/src/Nncase.Evaluator/Tensors/Cast.cs
+++ b/src/Nncase.Evaluator/Tensors/Cast.cs
@@ -73,7 +73,7 @@ private IRType Visit(Cast target, DistributedType inType)
{
var invalid = new InvalidType(inType.ToString());
var outType = Visit(target, inType.TensorType);
- var ndsbp = new SBP[inType.TensorType.Shape.Rank];
+ var ndsbp = inType.AxisPolicies.ToArray();
var shape = CompilerServices.GetMaxShape(inType.TensorType.Shape);
for (int i = 0; i < ndsbp.Length; i++)
{
@@ -92,10 +92,13 @@ private IRType Visit(Cast target, DistributedType inType)
{
return invalid;
}
+ else
+ {
+ var scale = 1f * outShape[i] / shape[i];
+ ndsbp[i] = SBP.S(split.Axes, split.Granularity is not null ? (scale >= 1 ? split.Granularity * (long)scale : split.Granularity / (long)(1f / scale)) : null);
+ }
}
}
-
- ndsbp[i] = inType.AxisPolicies[i];
}
return new DistributedType((TensorType)outType, ndsbp, inType.Placement);
diff --git a/src/Nncase.Evaluator/Tensors/Gather.cs b/src/Nncase.Evaluator/Tensors/Gather.cs
index cf56e1a562..bc32aeb5ae 100644
--- a/src/Nncase.Evaluator/Tensors/Gather.cs
+++ b/src/Nncase.Evaluator/Tensors/Gather.cs
@@ -3,6 +3,7 @@
using System;
using System.Linq;
+using DryIoc.ImTools;
using NetFabric.Hyperlinq;
using Nncase.CostModel;
using Nncase.IR;
@@ -49,16 +50,10 @@ public Cost Visit(ICostEvaluateContext context, Gather target)
var indexType = context.GetArgumentType(target, Gather.Index);
var retType = context.GetReturnType();
- var gatherPart = 1U;
- if (inputType is DistributedType d && d.AxisPolicies[target.Axis] is SBPSplit split)
- {
- gatherPart = split.Axes.Select(a => d.Placement.Hierarchy[a]).Aggregate(1U, (a, b) => (uint)(a * b));
- }
-
return new()
{
[CostFactorNames.MemoryLoad] = CostUtility.GetMemoryAccess(inputType) + CostUtility.GetMemoryAccess(indexType),
- [CostFactorNames.MemoryStore] = CostUtility.GetMemoryAccess(retType) * gatherPart,
+ [CostFactorNames.MemoryStore] = CostUtility.GetMemoryAccess(retType),
};
}
@@ -112,12 +107,12 @@ private IRType Visit(DistributedType input, int axis, DistributedType index)
return invalid;
}
- // not support partial
+ // not support partial in ndsbp
if (ndsbp.Any(sbp => sbp is SBPPartial))
{
return invalid;
}
- return new DistributedType(tensorType, ndsbp, input.Placement);
+ return new DistributedType(tensorType, ndsbp, input.Placement, input.AxisPolicies[axis] is SBPSplit split ? SBP.P(split.Axes) : null);
}
}
diff --git a/src/Nncase.Evaluator/Tensors/Pack.cs b/src/Nncase.Evaluator/Tensors/Pack.cs
index 188875b895..7a61418264 100644
--- a/src/Nncase.Evaluator/Tensors/Pack.cs
+++ b/src/Nncase.Evaluator/Tensors/Pack.cs
@@ -109,12 +109,12 @@ private IRType Visit(ITypeInferenceContext context, Pack target, DistributedType
var ndsbp = new SBP[input.TensorType.Shape.Rank];
for (int i = 0; i < input.TensorType.Shape.Rank; i++)
{
- if (input.AxisPolicies[i] is SBPSplit && target.Axes.Contains(i))
+ if (input.AxisPolicies[i] is SBPSplit split && target.Axes.Contains(i))
{
var lane = target.Lanes[target.Axes.IndexOf(i)];
if (input.TensorType.Shape[i] is { IsFixed: true, FixedValue: long s } && s / lane % divisor[i] == 0)
{
- ndsbp[i] = input.AxisPolicies[i];
+ ndsbp[i] = SBP.S(split.Axes, split.Granularity is not null ? split.Granularity / lane : null);
}
else
{
diff --git a/src/Nncase.Evaluator/Tensors/Reshape.cs b/src/Nncase.Evaluator/Tensors/Reshape.cs
index 2f40baa687..82986f06c5 100644
--- a/src/Nncase.Evaluator/Tensors/Reshape.cs
+++ b/src/Nncase.Evaluator/Tensors/Reshape.cs
@@ -60,9 +60,16 @@ public static IRType VisitDistributedType(DistributedType inType, RankedShape ne
return invalidType;
}
+ var granularity = split.Granularity;
+ if (newDims[newSplitAxis] != inShape[inAxis])
+ {
+ // If the new split axis is not the same as the old split axis, we need to adjust the granularity.
+ granularity = granularity is null ? null : granularity / newAxes.Except([newAxesOffset + newSplitAxis]).Aggregate((Dimension)1, (a, b) => a * maxNewShape[b]);
+ }
+
foreach (var newAxis in newAxes)
{
- newAxisPolicies[newAxis] = newAxis == (newAxesOffset + newSplitAxis) ? split : SBP.B;
+ newAxisPolicies[newAxis] = newAxis == (newAxesOffset + newSplitAxis) ? SBP.S(split.Axes, granularity) : SBP.B;
}
}
else
@@ -101,7 +108,9 @@ where inPolicy is SBPSplit
return invalidType;
}
- newAxisPolicies[newAxis] = inType.AxisPolicies[firstSplitAxis.Value];
+ var split = (SBPSplit)inType.AxisPolicies[firstSplitAxis.Value];
+ newAxisPolicies[newAxis] = split.Granularity is null ? split :
+ SBP.S(split.Axes, split.Granularity! * inAxes.Except([firstSplitAxis.Value]).Aggregate((Dimension)1, (a, b) => a * maxInShape[b]));
}
else
{
diff --git a/src/Nncase.Evaluator/Tensors/Unpack.cs b/src/Nncase.Evaluator/Tensors/Unpack.cs
index 25723275ed..f8e0615869 100644
--- a/src/Nncase.Evaluator/Tensors/Unpack.cs
+++ b/src/Nncase.Evaluator/Tensors/Unpack.cs
@@ -104,6 +104,7 @@ private IRType Visit(ITypeInferenceContext context, Unpack target, DistributedTy
}
// [m]<8>@8@4 -> [m*8]@8@4, when max(m)=256 and runtime m=12, input and output have different local shape.
+ var newPolicies = input.AxisPolicies.ToArray();
foreach (var (s, r) in input.AxisPolicies.Select((s, r) => (s, r)))
{
if (s is SBPSplit split && target.Axes.Contains(r))
@@ -114,9 +115,11 @@ private IRType Visit(ITypeInferenceContext context, Unpack target, DistributedTy
{
return new InvalidType("Not support non-divisible input");
}
+
+ newPolicies[r] = SBP.S(split.Axes, split.Granularity is not null ? split.Granularity * target.Lanes[target.Axes.IndexOf(r)] : null);
}
}
- return new DistributedType(tensorType, input.AxisPolicies, input.Placement);
+ return new DistributedType(tensorType, newPolicies, input.Placement);
}
}
diff --git a/src/Nncase.Passes/Distributed/AutoDistributed.cs b/src/Nncase.Passes/Distributed/AutoDistributed.cs
index 04549c4fb1..71c609dc58 100644
--- a/src/Nncase.Passes/Distributed/AutoDistributed.cs
+++ b/src/Nncase.Passes/Distributed/AutoDistributed.cs
@@ -321,7 +321,7 @@ public void FilterByScheme(BaseExpr expr, DistributedSearchGraph cluster)
{
bool Matched(SearchableNode node, (IRArray Policies, Placement Placement) tp)
{
- return node.IRType is DistributedType dtype && dtype.AxisPolicies == tp.Policies && dtype.Placement == tp.Placement;
+ return node.IRType is DistributedType dtype && DistributedUtility.AreSamePolicies(dtype.AxisPolicies, tp.Policies, false) && dtype.Placement == tp.Placement;
}
foreach (var name in expr.Metadata.OutputNames ?? Array.Empty())
@@ -602,6 +602,11 @@ string DescribeSbp(IRType? type)
}
}
+ if (expr.Target is not Boxing && ((Call)newExpr).Arguments.AsValueEnumerable().Any(a => a.CheckedType is DistributedType dt && dt.Partial is not null))
+ {
+ continue;
+ }
+
if (!newExpr.InferenceType(_inferencer_cache) || newExpr.CheckedType is InvalidType)
{
continue;
@@ -693,7 +698,9 @@ string DescribeSbp(IRType? type)
|| expr.Users.Any(u => u is Call call && (call.Target.GetType().FullName!.Contains("CustomNTT", StringComparison.Ordinal) || (TargetOptions.HierarchyKind == HierarchyKind.SMT && expr.Target is PagedAttention)))
|| expr.Target.GetType().FullName!.Contains("CustomNTT", StringComparison.Ordinal)
|| expr.Target.GetType().FullName!.Contains("VectorizedRoPE", StringComparison.Ordinal)
- || (TargetOptions.HierarchyKind == HierarchyKind.SMT && expr.Target is PagedAttention))
+ || expr.Target.GetType().FullName!.Contains("Matmul", StringComparison.InvariantCultureIgnoreCase)
+ || (TargetOptions.HierarchyKind == HierarchyKind.SMT && expr.Target is PagedAttention)
+ || expr.Target is Gather)
{
bucket = callCluster.CreateCluster(SearchGraphKind.Bucket);
var linked = false;
@@ -1023,14 +1030,54 @@ private IRType CheckBoxingType(IRType inType, IRType outType, bool isReshape = f
{
IRType VisitD2D(DistributedType inv, DistributedType outv)
{
- if (inv == outv)
+ if (inv.Partial == outv.Partial && DistributedUtility.AreSamePolicies(inv.AxisPolicies, outv.AxisPolicies))
{
return new InvalidType("Same DistributedType");
}
- if (inv.AxisPolicies.Any(s => s is SBPPartial) || outv.AxisPolicies.Any(s => s is SBPPartial))
+ if (inv.AxisPolicies.Any(sbp => sbp is SBPPartial) || outv.AxisPolicies.Any(sbp => sbp is SBPPartial))
+ {
+ return new InvalidType("Not Support Partial in Policeis.");
+ }
+
+ var partialDims = new List();
+ if (inv.Partial is not null)
+ {
+ for (int i = 0; i < inv.AxisPolicies.Count; i++)
+ {
+ if (inv.AxisPolicies[i] is SBPSplit && outv.AxisPolicies[i] is SBPBroadCast)
+ {
+ return new InvalidType("Not supported input is BroadCast output is Split");
+ }
+
+ if (outv.AxisPolicies[i] is SBPSplit s)
+ {
+ if (inv.AxisPolicies[i] is SBPSplit splitIn)
+ {
+ if (splitIn.Axes.Except(s.Axes).Any())
+ {
+ return new InvalidType("Not Supported Split-> Split.");
+ }
+ }
+
+ if (s.Axes.Except(inv.Partial.Axes).ToArray() != s.Axes)
+ {
+ partialDims.Add(i);
+ }
+ }
+ }
+
+ var ndspsIn = DistributedUtility.AxisPolicesToNDSBP(inv.AxisPolicies, inv.Placement.Rank);
+ var ndspsOut = DistributedUtility.AxisPolicesToNDSBP(outv.AxisPolicies, outv.Placement.Rank);
+ if (Enumerable.Range(0, ndspsIn.Count).Any(i => ndspsIn[i] is SBPSplit si && (ndspsOut[i] is SBPBroadCast || (ndspsOut[i] is SBPSplit so && so.Axes[0] != si.Axes[0]))))
+ {
+ return new InvalidType("Not Supported Split-> Broadcast.");
+ }
+ }
+
+ if (partialDims.Count > 0 && !Enumerable.Range(0, inv.AxisPolicies.Count).Except(partialDims.ToArray()).All(i => DistributedUtility.IsSamePolicy(inv.AxisPolicies[i], outv.AxisPolicies[i])))
{
- return new InvalidType("Not supported input/output is Partial");
+ return new InvalidType("Not Supported Partial.");
}
return outv;
@@ -1038,7 +1085,7 @@ IRType VisitD2D(DistributedType inv, DistributedType outv)
IRType VisitD2T(DistributedType inv, TensorType outv)
{
- if (inv.AxisPolicies.Any(s => s is SBPPartial))
+ if (inv.AxisPolicies.Any(s => s is SBPPartial) || inv.Partial is not null)
{
return new InvalidType("Not supported input is Partial output is Unshard");
}
@@ -1048,7 +1095,7 @@ IRType VisitD2T(DistributedType inv, TensorType outv)
IRType VisitT2D(TensorType inv, DistributedType outv)
{
- if (outv.AxisPolicies.Any(s => s is SBPPartial))
+ if (outv.AxisPolicies.Any(s => s is SBPPartial) || outv.Partial is not null)
{
return new InvalidType("Not supported input is Unshard output is Partial");
}
diff --git a/src/Nncase.Passes/Rules/Distributed/UpdateBoxingTensorType.cs b/src/Nncase.Passes/Rules/Distributed/UpdateBoxingTensorType.cs
new file mode 100644
index 0000000000..0e7892bf99
--- /dev/null
+++ b/src/Nncase.Passes/Rules/Distributed/UpdateBoxingTensorType.cs
@@ -0,0 +1,42 @@
+// Copyright (c) Canaan Inc. All rights reserved.
+// Licensed under the Apache license. See LICENSE file in the project root for full license information.
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using System.Threading.Tasks;
+using Nncase.IR;
+using Nncase.IR.Distributed;
+using Nncase.PatternMatch;
+using static Nncase.IR.F.NN;
+
+using static Nncase.IR.TypePatternUtility;
+using static Nncase.PatternMatch.F.Distributed;
+using static Nncase.PatternMatch.Utility;
+
+namespace Nncase.Passes.Rules;
+
+[RuleGenerator]
+public partial class UpdateBoxingTensorType : RewriteRule
+{
+ ///
+ public override Pattern Pattern { get; } = IsBoxing(
+ target_name: "boxing",
+ _ => true,
+ IsWildcard("input"));
+
+ private Expr? GetReplace(Boxing boxing, Expr input, RunPassContext context)
+ {
+ if (boxing.NewType is DistributedType dt)
+ {
+ var ttype = dt.TensorType;
+ var dtype = dt with { TensorType = ttype with { Shape = ttype.Shape.Select(d => d.Simplify()).ToArray() } };
+ var newBoxing = new Call(new IR.Distributed.Boxing(dtype), input);
+ context.MatchOptions.SuppressPattern(newBoxing, Pattern);
+ return newBoxing;
+ }
+
+ return null;
+ }
+}
diff --git a/src/Nncase.Schedule/Transforms/AutoTilePass.cs b/src/Nncase.Schedule/Transforms/AutoTilePass.cs
index b6e0ccb6ae..590702a877 100644
--- a/src/Nncase.Schedule/Transforms/AutoTilePass.cs
+++ b/src/Nncase.Schedule/Transforms/AutoTilePass.cs
@@ -11,6 +11,7 @@
using Nncase.IR;
using Nncase.IR.Affine;
using Nncase.Passes.GraphPartition;
+using Nncase.Passes.Rules;
using Nncase.Schedule;
using QuikGraph;
using QuikGraph.Algorithms;
@@ -116,6 +117,7 @@ protected override Task RunCoreAsync(BaseFunction input, RunPassCo
var constructor = new AutoTileReconstructor(tiler, ModuleKind, CompileOptions, condenseAlgo, dimVars.ToArray());
var post = constructor.Construct();
+ post = CompilerServices.Rewrite(post, [new UpdateBoxingTensorType()], new());
return Task.FromResult((BaseFunction)func.With(body: post));
}
}
diff --git a/src/Nncase.Tests/Core/UnitTestLayout.cs b/src/Nncase.Tests/Core/UnitTestLayout.cs
index 836908e239..719823168c 100644
--- a/src/Nncase.Tests/Core/UnitTestLayout.cs
+++ b/src/Nncase.Tests/Core/UnitTestLayout.cs
@@ -55,7 +55,7 @@ public void TestUnflatten()
public void TestDistributedTypeLayout()
{
var placement = new Placement([1, 2, 8, 4, 4], "cdyxt");
- var distType = new DistributedType(new TensorType(DataTypes.Float32, new[] { 2048, 1024 }), new SBP[] { SBP.S(1, 3), SBP.B }, placement);
+ var distType = new DistributedType(new TensorType(DataTypes.Float32, new[] { 2048, 1024 }), new SBP[] { SBP.S([1, 3]), SBP.B }, placement);
var layout = Layout.From(distType.TensorType);
Assert.Equal("Layout((2048, 1024):(1024, 1))", layout.ToString());
@@ -88,8 +88,8 @@ public void TestCoorinateWithDifferentLayout()
{
var placement = new Placement([1, 2, 8, 4, 4], "cdyxt");
var tensorType = new TensorType(DataTypes.Float32, new[] { 2048, 1024 });
- var distTypeA = new DistributedType(tensorType, new SBP[] { SBP.S(1, 3), SBP.B }, placement);
- var distTypeB = new DistributedType(tensorType, new SBP[] { SBP.B, SBP.S(2) }, placement);
+ var distTypeA = new DistributedType(tensorType, new SBP[] { SBP.S([1, 3]), SBP.B }, placement);
+ var distTypeB = new DistributedType(tensorType, new SBP[] { SBP.B, SBP.S([2]) }, placement);
var shardA = Layout.From(distTypeA);
var shardB = Layout.From(distTypeB);
diff --git a/src/Nncase.Tests/Core/UnitTestTensor.cs b/src/Nncase.Tests/Core/UnitTestTensor.cs
index b60aa2bf70..87c5628ca1 100644
--- a/src/Nncase.Tests/Core/UnitTestTensor.cs
+++ b/src/Nncase.Tests/Core/UnitTestTensor.cs
@@ -456,7 +456,7 @@ public void TestTensorSerialize()
},
new[] { 32 },
new[] { IR.NN.PagedKVCacheDimKind.NumBlocks },
- new[] { SBP.S(0) });
+ new[] { SBP.S([0]) });
var obj = new Evaluator.NN.RefPagedAttentionKVCache(cfg, 1, 4, Tensor.From([0L]), Tensor.From([4L]), Tensor.From([0L, 1L, 0L, 2L], [1, 2, 2]), Tensor.From([0L, 1L, 0L, 2L, 0L, 3L, 0L, 4L], [4, 2]), 4, Tensor.Zeros>([1, 1, 2, 3, 4, 5, 6]));
var original = Tensor.From(new Reference[] { new(obj) }, []);
using (var stream = File.Create(path))
diff --git a/src/Nncase.Tests/Distributed/UnitTestDistributeSchema.cs b/src/Nncase.Tests/Distributed/UnitTestDistributeSchema.cs
index ee2d743377..24b51a6a07 100644
--- a/src/Nncase.Tests/Distributed/UnitTestDistributeSchema.cs
+++ b/src/Nncase.Tests/Distributed/UnitTestDistributeSchema.cs
@@ -94,6 +94,6 @@ public async Task TestLoadScheme()
Dumpper.DumpIR(result, "result");
- Assert.True(result is Function { Body: Call { Target: IR.Distributed.Boxing } boxing } && boxing.Arguments[0] is Call { Target: IR.Math.Unary { UnaryOp: UnaryOp.Cos } } unary && unary.CheckedType is DistributedType dt && dt == new DistributedType(new(DataTypes.Float32, new[] { 1, 512, 8192 }), new[] { (SBP)SBP.B, SBP.S(new[] { 0 }), SBP.S(new[] { 1, 2 }) }, new(new[] { 8, 8, 4 }, "cbt")));
+ Assert.True(result is Function { Body: Call { Target: IR.Distributed.Boxing } boxing } && boxing.Arguments[0] is Call { Target: IR.Math.Unary { UnaryOp: UnaryOp.Cos } } unary && unary.CheckedType is DistributedType dt && dt == new DistributedType(new(DataTypes.Float32, new[] { 1, 512, 8192 }), new[] { (SBP)SBP.B, SBP.S([0], 64), SBP.S([1, 2], 256) }, new(new[] { 8, 8, 4 }, "cbt")));
}
}
diff --git a/src/Nncase.Tests/Evaluator/UnitTestEvaluatorNN.cs b/src/Nncase.Tests/Evaluator/UnitTestEvaluatorNN.cs
index f5d1c1034a..35c1d02b4a 100755
--- a/src/Nncase.Tests/Evaluator/UnitTestEvaluatorNN.cs
+++ b/src/Nncase.Tests/Evaluator/UnitTestEvaluatorNN.cs
@@ -78,7 +78,7 @@ private static readonly (PagedKVCacheDimKind[] Cache, PagedKVCacheDimKind[] Vect
private static readonly (PagedKVCacheDimKind[] Sharding, SBPSplit[] Policies, Placement Placement)[] ShardingConfigs =
[
(Array.Empty(), Array.Empty(), new Placement(new[] { 1 }, "t")),
- (new[] { PagedKVCacheDimKind.NumBlocks }, new[] { SBP.S(0) }, new Placement(new[] { 1 }, "t")),
+ (new[] { PagedKVCacheDimKind.NumBlocks }, new[] { SBP.S([0]) }, new Placement(new[] { 1 }, "t")),
];
private static readonly (AttentionDimKind[] QLayout, AttentionDimKind[] KLayout)[] QKLayoutConfigs =
@@ -114,7 +114,7 @@ public PagedAttentionKVCacheTestData()
}
}
- Add(new TestFixture.PagedAttentionKVCacheTestFixture([4], [4], 4, 2, 32, 8, 32, Runtime.TypeCode.Float32, 1, [PagedKVCacheDimKind.NumLayers, PagedKVCacheDimKind.NumBlocks, PagedKVCacheDimKind.KV, PagedKVCacheDimKind.NumKVHeads, PagedKVCacheDimKind.HeadDim, PagedKVCacheDimKind.BlockSize], [PagedKVCacheDimKind.HeadDim], new[] { PagedKVCacheDimKind.NumBlocks }, new[] { SBP.S(0) }, [AttentionDimKind.Head, AttentionDimKind.Dim, AttentionDimKind.Seq], [AttentionDimKind.Head, AttentionDimKind.Dim, AttentionDimKind.Seq]), new Placement(new[] { 1 }, "t"));
+ Add(new TestFixture.PagedAttentionKVCacheTestFixture([4], [4], 4, 2, 32, 8, 32, Runtime.TypeCode.Float32, 1, [PagedKVCacheDimKind.NumLayers, PagedKVCacheDimKind.NumBlocks, PagedKVCacheDimKind.KV, PagedKVCacheDimKind.NumKVHeads, PagedKVCacheDimKind.HeadDim, PagedKVCacheDimKind.BlockSize], [PagedKVCacheDimKind.HeadDim], new[] { PagedKVCacheDimKind.NumBlocks }, new[] { SBP.S([0]) }, [AttentionDimKind.Head, AttentionDimKind.Dim, AttentionDimKind.Seq], [AttentionDimKind.Head, AttentionDimKind.Dim, AttentionDimKind.Seq]), new Placement(new[] { 1 }, "t"));
}
}
@@ -159,8 +159,8 @@ private static readonly (PagedKVCacheDimKind[] Cache, PagedKVCacheDimKind[] Vect
private static readonly (PagedKVCacheDimKind[] Sharding, SBPSplit[] Policies, int[] Hierarchy)[] ShardingConfigs =
[
(Array.Empty(), Array.Empty(), Array.Empty()),
- (new[] { PagedKVCacheDimKind.NumBlocks }, new[] { SBP.S(0) }, new[] { 1 }),
- (new[] { PagedKVCacheDimKind.NumBlocks }, new[] { SBP.S(0) }, new[] { 8 }),
+ (new[] { PagedKVCacheDimKind.NumBlocks }, new[] { SBP.S([0]) }, new[] { 1 }),
+ (new[] { PagedKVCacheDimKind.NumBlocks }, new[] { SBP.S([0]) }, new[] { 8 }),
];
public PagedAttentionSchedulerTestData()
diff --git a/src/Nncase.Tests/Simulator/UnitTestInterop.cs b/src/Nncase.Tests/Simulator/UnitTestInterop.cs
index c516895155..dac65cf5b7 100644
--- a/src/Nncase.Tests/Simulator/UnitTestInterop.cs
+++ b/src/Nncase.Tests/Simulator/UnitTestInterop.cs
@@ -271,11 +271,11 @@ public void TestRTAttentionConfig()
}
{
- var config = new IR.NN.PagedAttentionConfig(1, 2, 3, DataTypes.Float16, 4, new[] { IR.NN.PagedKVCacheDimKind.BlockSize, IR.NN.PagedKVCacheDimKind.HeadDim, IR.NN.PagedKVCacheDimKind.KV, IR.NN.PagedKVCacheDimKind.NumBlocks, IR.NN.PagedKVCacheDimKind.NumKVHeads, IR.NN.PagedKVCacheDimKind.NumLayers }, new[] { IR.NN.PagedKVCacheDimKind.HeadDim }, new[] { 32 }, new[] { IR.NN.PagedKVCacheDimKind.NumBlocks }, new[] { SBP.S(1, 2) });
+ var config = new IR.NN.PagedAttentionConfig(1, 2, 3, DataTypes.Float16, 4, new[] { IR.NN.PagedKVCacheDimKind.BlockSize, IR.NN.PagedKVCacheDimKind.HeadDim, IR.NN.PagedKVCacheDimKind.KV, IR.NN.PagedKVCacheDimKind.NumBlocks, IR.NN.PagedKVCacheDimKind.NumKVHeads, IR.NN.PagedKVCacheDimKind.NumLayers }, new[] { IR.NN.PagedKVCacheDimKind.HeadDim }, new[] { 32 }, new[] { IR.NN.PagedKVCacheDimKind.NumBlocks }, new[] { SBP.S([1, 2]) });
var rtConfig = RTAttentionConfig.FromConfig(config);
Assert.IsType(rtConfig);
var rtPagedConfig = (RTPagedAttentionConfig)rtConfig;
- Assert.True(rtPagedConfig.AxisPolicies.SequenceEqual([SBP.S(1, 2)]));
+ Assert.True(rtPagedConfig.AxisPolicies.SequenceEqual([SBP.S([1, 2])]));
}
}
diff --git a/src/Nncase.Tests/Targets/UnitTestCPUKernels.cs b/src/Nncase.Tests/Targets/UnitTestCPUKernels.cs
index cff88fe0fa..d4490b4dcf 100644
--- a/src/Nncase.Tests/Targets/UnitTestCPUKernels.cs
+++ b/src/Nncase.Tests/Targets/UnitTestCPUKernels.cs
@@ -98,7 +98,7 @@ private static readonly (PagedKVCacheDimKind[] Cache, PagedKVCacheDimKind[] Vect
private static readonly (PagedKVCacheDimKind[] Sharding, SBPSplit[] Policies, int[] Hierarchy)[] ShardingConfigs =
[
- (new[] { PagedKVCacheDimKind.NumBlocks }, new[] { SBP.S(0) }, [1]),
+ (new[] { PagedKVCacheDimKind.NumBlocks }, new[] { SBP.S([0]) }, [1]),
];
private static readonly (AttentionDimKind[] QLayout, AttentionDimKind[] KLayout)[] QKLayoutConfigs =
@@ -133,7 +133,7 @@ public TestUpdatePagedAttentionCaseData()
}
}
- Add(new TestFixture.PagedAttentionKVCacheTestFixture([256], [256], 14, 2, 64, 256, 16, Runtime.TypeCode.Float32, 1, [PagedKVCacheDimKind.NumBlocks, PagedKVCacheDimKind.NumLayers, PagedKVCacheDimKind.KV, PagedKVCacheDimKind.NumKVHeads, PagedKVCacheDimKind.HeadDim, PagedKVCacheDimKind.BlockSize], [PagedKVCacheDimKind.HeadDim], [PagedKVCacheDimKind.NumBlocks], [SBP.S(0)], [AttentionDimKind.Head, AttentionDimKind.Dim, AttentionDimKind.Seq], [AttentionDimKind.Head, AttentionDimKind.Dim, AttentionDimKind.Seq]), [1], count++);
+ Add(new TestFixture.PagedAttentionKVCacheTestFixture([256], [256], 14, 2, 64, 256, 16, Runtime.TypeCode.Float32, 1, [PagedKVCacheDimKind.NumBlocks, PagedKVCacheDimKind.NumLayers, PagedKVCacheDimKind.KV, PagedKVCacheDimKind.NumKVHeads, PagedKVCacheDimKind.HeadDim, PagedKVCacheDimKind.BlockSize], [PagedKVCacheDimKind.HeadDim], [PagedKVCacheDimKind.NumBlocks], [SBP.S([0])], [AttentionDimKind.Head, AttentionDimKind.Dim, AttentionDimKind.Seq], [AttentionDimKind.Head, AttentionDimKind.Dim, AttentionDimKind.Seq]), [1], count++);
}
}
@@ -186,7 +186,7 @@ private static readonly (PagedKVCacheDimKind[] Cache, PagedKVCacheDimKind[] Vect
private static readonly (PagedKVCacheDimKind[] Sharding, SBPSplit[] Policies, int[] Hierarchy)[] ShardingConfigs =
[
- (new[] { PagedKVCacheDimKind.NumBlocks }, new[] { SBP.S(0) }, [1]),
+ (new[] { PagedKVCacheDimKind.NumBlocks }, new[] { SBP.S([0]) }, [1]),
];
private static readonly (AttentionDimKind[] QLayout, AttentionDimKind[] KLayout)[] QKLayoutConfigs =
@@ -221,7 +221,7 @@ public TestPagedAttentionCaseData()
}
}
- Add(new TestFixture.PagedAttentionKVCacheTestFixture([4], [4], 14, 2, 64, 256, 16, Runtime.TypeCode.Float32, 1, [PagedKVCacheDimKind.NumBlocks, PagedKVCacheDimKind.NumLayers, PagedKVCacheDimKind.KV, PagedKVCacheDimKind.NumKVHeads, PagedKVCacheDimKind.HeadDim, PagedKVCacheDimKind.BlockSize], [PagedKVCacheDimKind.HeadDim], [PagedKVCacheDimKind.NumBlocks], [SBP.S(0)], [AttentionDimKind.Head, AttentionDimKind.Dim, AttentionDimKind.Seq], [AttentionDimKind.Head, AttentionDimKind.Dim, AttentionDimKind.Seq]), [1], count++);
+ Add(new TestFixture.PagedAttentionKVCacheTestFixture([4], [4], 14, 2, 64, 256, 16, Runtime.TypeCode.Float32, 1, [PagedKVCacheDimKind.NumBlocks, PagedKVCacheDimKind.NumLayers, PagedKVCacheDimKind.KV, PagedKVCacheDimKind.NumKVHeads, PagedKVCacheDimKind.HeadDim, PagedKVCacheDimKind.BlockSize], [PagedKVCacheDimKind.HeadDim], [PagedKVCacheDimKind.NumBlocks], [SBP.S([0])], [AttentionDimKind.Head, AttentionDimKind.Dim, AttentionDimKind.Seq], [AttentionDimKind.Head, AttentionDimKind.Dim, AttentionDimKind.Seq]), [1], count++);
}
}
@@ -509,25 +509,25 @@ public async Task TestGatherReduceScatter(long[] shape, int[] hierarchy, int cou
};
var placement = new Placement(hierarchy, targetOptions.HierarchyNames);
- var ndsbp = Enumerable.Repeat(SBP.B, hierarchy.Length).ToArray();
+ var ndsbp = Enumerable.Repeat(SBP.B, shape.Length).ToArray();
var posts = new List();
var broadcast = IR.F.Distributed.Boxing(input, new DistributedType(inputType, ndsbp, placement));
- foreach (var comb in LinqUtility.Combination(hierarchy.Length))
+ foreach (var comb in LinqUtility.Combination(shape.Length))
{
var newsbp = ndsbp.ToArray();
foreach (var axis in comb)
{
- newsbp[axis] = SBP.P();
+ newsbp[axis] = SBP.B;
}
- var partial = IR.F.Distributed.ForceBoxing(broadcast, new DistributedType(inputType, newsbp, placement));
+ var partial = IR.F.Distributed.ForceBoxing(broadcast, new DistributedType(inputType, newsbp, placement, SBP.P(Enumerable.Range(0, hierarchy.Length).ToArray())));
var sumed = IR.F.Distributed.Boxing(partial, new DistributedType(inputType, ndsbp, placement));
var post = IR.F.Distributed.Boxing(sumed, inputType);
post.Metadata = new Passes.Distributed.AutoDistributedMetaData() { Skip = true };
posts.Add(post);
}
- await RunCases($"Theory{count}", feedDict, posts);
+ await RunCases($"Theory{count}", feedDict, posts, null, false);
}
[Fact]
@@ -1204,7 +1204,7 @@ public async Task TestDynamicGetPositionIds(long[] queryLens, long[] seqLens, in
Metadata = new() { Range = new(1, MathUtility.AlignUp(queryLens.Sum(), 128)) },
};
- var fixture = new PagedAttentionKVCacheTestFixture(queryLens, seqLens, 2, 2, 64, 64, (int)MathUtility.CeilDiv(seqLens.Select(seq_len => MathUtility.CeilDiv(seq_len, 64)).Sum(), hierarchy.Max()) * hierarchy.Max(), Runtime.TypeCode.Float32, 1, [PagedKVCacheDimKind.NumBlocks, PagedKVCacheDimKind.NumLayers, PagedKVCacheDimKind.KV, PagedKVCacheDimKind.NumKVHeads, PagedKVCacheDimKind.HeadDim, PagedKVCacheDimKind.BlockSize], [PagedKVCacheDimKind.HeadDim], [PagedKVCacheDimKind.NumBlocks], [SBP.S(0)], [AttentionDimKind.Seq, AttentionDimKind.Dim, AttentionDimKind.Head], [AttentionDimKind.Seq, AttentionDimKind.Dim, AttentionDimKind.Head]);
+ var fixture = new PagedAttentionKVCacheTestFixture(queryLens, seqLens, 2, 2, 64, 64, (int)MathUtility.CeilDiv(seqLens.Select(seq_len => MathUtility.CeilDiv(seq_len, 64)).Sum(), hierarchy.Max()) * hierarchy.Max(), Runtime.TypeCode.Float32, 1, [PagedKVCacheDimKind.NumBlocks, PagedKVCacheDimKind.NumLayers, PagedKVCacheDimKind.KV, PagedKVCacheDimKind.NumKVHeads, PagedKVCacheDimKind.HeadDim, PagedKVCacheDimKind.BlockSize], [PagedKVCacheDimKind.HeadDim], [PagedKVCacheDimKind.NumBlocks], [SBP.S([0])], [AttentionDimKind.Seq, AttentionDimKind.Dim, AttentionDimKind.Head], [AttentionDimKind.Seq, AttentionDimKind.Dim, AttentionDimKind.Head]);
var placement = new Placement(hierarchy, targetOptions.HierarchyNames);
var dataGeneratorOptions = new PagedAttentionKVCacheTestFixture.DataGeneratorOptions(Random: true, IncreaseBy: [AttentionDimKind.Head], ResetForKV: true);