diff --git a/.gitignore b/.gitignore index a48d1f65f8..66a44247c8 100644 --- a/.gitignore +++ b/.gitignore @@ -320,4 +320,8 @@ global.json tests/llm/*/ # agent instructions -.github/instructions/ \ No newline at end of file +.github/instructions/ + +# AI +CLAUDE.md +AGENTS.md diff --git a/modules/Nncase.Modules.NTT/CodeGen/CPU/CSourceConvertVisitor.cs b/modules/Nncase.Modules.NTT/CodeGen/CPU/CSourceConvertVisitor.cs index 3ebc69ab4a..038db85702 100644 --- a/modules/Nncase.Modules.NTT/CodeGen/CPU/CSourceConvertVisitor.cs +++ b/modules/Nncase.Modules.NTT/CodeGen/CPU/CSourceConvertVisitor.cs @@ -119,7 +119,7 @@ public void IndWrite(string? value) /// /// convert single prim function to c source. /// -public abstract class CSourceConvertVisitor : ExprFunctor +public class CSourceConvertVisitor : ExprFunctor { protected readonly Dictionary _exprMemo = new(ReferenceEqualityComparer.Instance); diff --git a/modules/Nncase.Modules.NTT/CodeGen/CPU/KernelCSourceConvertVisitor.cs b/modules/Nncase.Modules.NTT/CodeGen/CPU/KernelCSourceConvertVisitor.cs index c3f3c1db12..346e4bffa7 100644 --- a/modules/Nncase.Modules.NTT/CodeGen/CPU/KernelCSourceConvertVisitor.cs +++ b/modules/Nncase.Modules.NTT/CodeGen/CPU/KernelCSourceConvertVisitor.cs @@ -401,18 +401,12 @@ protected override CSymbol VisitCall(Call expr) break; case TIR.NTT.Matmul matmul: { - var dimInfo = IR.NTT.VectorizedMatMul.GetDimInfo(matmul.TransposeA, matmul.TransposeB, args[0].CheckedShape.Rank, args[1].CheckedShape.Rank); WriteWithProfiler( RazorTemplateEngine.RenderAsync("~/CodeGen/CPU/Templates/Kernels/Matmul.cshtml", new TypedKernelTemplateModel(matmul) { Arguments = args.Select(x => new KernelArgument { Symbol = VisitBuffer(x, local: true) }).ToArray(), }).Result, "matmul"); - if (args[0] is TIR.Buffer a && a.DistributedType?.AxisPolicies[dimInfo.Lk] is SBPSplit s) - { - var reduceKind = "tar::reduce_kind::" + string.Join("_", Enumerable.Range(0, TargetOptions.HierarchyNames.Length).Select(i => (s.Axes.Contains(i) 
? "r" : string.Empty) + TargetOptions.HierarchyNames[i])); - WriteIndWithProfiler($"tac::tensor_reduce_sync({VisitBuffer(args[2], local: true).Name}, {VisitBuffer(args[2], local: true).Name});\n"); - } } break; @@ -428,18 +422,12 @@ protected override CSymbol VisitCall(Call expr) break; case TIR.NTT.PackedMatMul matmul: { - var dimInfo = IR.NTT.VectorizedMatMul.GetDimInfo(false, true, args[0].CheckedShape.Rank, args[1].CheckedShape.Rank); WriteWithProfiler( RazorTemplateEngine.RenderAsync("~/CodeGen/CPU/Templates/Kernels/PackedMatMul.cshtml", new TypedKernelTemplateModel(matmul) { Arguments = args.Select(x => new KernelArgument { Symbol = VisitBuffer(x, local: true) }).ToArray(), }).Result, "packed_matmul"); - if (args[0] is TIR.Buffer a && a.DistributedType?.AxisPolicies[dimInfo.Lk] is SBPSplit s) - { - var reduceKind = "tar::reduce_kind::" + string.Join("_", Enumerable.Range(0, TargetOptions.HierarchyNames.Length).Select(i => (s.Axes.Contains(i) ? "r" : string.Empty) + TargetOptions.HierarchyNames[i])); - WriteIndWithProfiler($"tac::tensor_reduce_sync({VisitBuffer(args[2], local: true).Name}, {VisitBuffer(args[2], local: true).Name});\n"); - } } break; @@ -449,11 +437,6 @@ protected override CSymbol VisitCall(Call expr) case TIR.NTT.Gather gather: { WriteWithProfiler($"gather({VisitBuffer(args[0], local: false).Name}, {VisitBuffer(args[1], local: true).Name}, {VisitBuffer(args[2], local: true).Name}, {gather.Axis}_dim);\n"); - if (args[0] is TIR.Buffer b && b.DistributedType?.AxisPolicies[gather.Axis] is SBPSplit s) - { - var reduceKind = "tar::reduce_kind::" + string.Join("_", Enumerable.Range(0, TargetOptions.HierarchyNames.Length).Select(i => (s.Axes.Contains(i) ? 
"r" : string.Empty) + TargetOptions.HierarchyNames[i])); - WriteIndWithProfiler($"tac::tensor_reduce_sync({VisitBuffer(args[2], local: true).Name}, {VisitBuffer(args[2], local: true).Name});\n"); - } } break; @@ -534,12 +517,11 @@ protected override CSymbol VisitCall(Call expr) break; case TIR.NTT.GatherReduceScatter grs: { - if (grs.InType.AxisPolicies.Any(s => s is SBPPartial)) + if (grs.InType.Partial is not null) { - // deprecated - var sbpPartial = (SBPPartial)grs.InType.AxisPolicies.Where(s => s is SBPPartial).Distinct().First(); - var reduceKind = "tar::reduce_kind::" + string.Join("_", grs.InType.AxisPolicies.Select((s, i) => (s is SBPPartial ? "r" : string.Empty) + TargetOptions.HierarchyNames[i])); - WriteIndWithProfiler($"tac::tensor_reduce_sync({VisitBuffer(args[0], local: true).Name}, {VisitBuffer(args[1], local: true).Name});\n"); + var sbpPartial = grs.InType.Partial; + var reduceKind = "tar::reduce_kind::" + string.Join("_", Enumerable.Range(0, TargetOptions.HierarchyNames.Length).Select(i => (sbpPartial.Axes.Contains(i) ? "r" : string.Empty) + TargetOptions.HierarchyNames[i])); + WriteIndWithProfiler($"tac::tensor_reduce_sync({VisitBuffer(args[0], local: false).Name}, {VisitBuffer(args[1], local: false).Name});\n"); } else { diff --git a/modules/Nncase.Modules.NTT/CodeGen/CPU/KernelUtility.cs b/modules/Nncase.Modules.NTT/CodeGen/CPU/KernelUtility.cs index d47e5a2749..664070e994 100644 --- a/modules/Nncase.Modules.NTT/CodeGen/CPU/KernelUtility.cs +++ b/modules/Nncase.Modules.NTT/CodeGen/CPU/KernelUtility.cs @@ -35,7 +35,7 @@ public static string SBPToC(this SBP value) { if (value is SBPSplit s) { - return $"S<{string.Join(", ", s.Axes)}>()"; + return $"S<{string.Join(", ", s.Axes)}>({new CSourceConvertVisitor().Visit(s.Granularity as BaseExpr ?? 
None.Default).Name})"; } else { diff --git a/modules/Nncase.Modules.NTT/CodeGen/CPU/Templates/topo_aware_runtime.cshtml b/modules/Nncase.Modules.NTT/CodeGen/CPU/Templates/topo_aware_runtime.cshtml index 0ad5b88376..90dea517ad 100644 --- a/modules/Nncase.Modules.NTT/CodeGen/CPU/Templates/topo_aware_runtime.cshtml +++ b/modules/Nncase.Modules.NTT/CodeGen/CPU/Templates/topo_aware_runtime.cshtml @@ -58,8 +58,6 @@ enum reduce_kind { }; constexpr std::array Hierarchy = {@(string.Join(", ", hierarchy))}; -auto src_ptr_tensor = nncase::ntt::make_tensor(nncase::ntt::fixed_shape_v<@(string.Join(",", hierarchy))>); -auto dest_ptr_tensor = nncase::ntt::make_tensor(nncase::ntt::fixed_shape_v<@(string.Join(",", hierarchy))>); template static std::byte *get_cache_address() { return reinterpret_cast( @@ -159,8 +157,8 @@ class tensor_reduce_sync_impl { var cur_index = string.Join(", ", Enumerable.Range(0, hierarchy.Length).Select(i => "ntt::distributed::" + hierarchyNames[i] + "id()")); } - template - void reduce_impl(TSliceIn &local, TSliceIn &remote, TSliceOut &dest) { + template + void reduce_impl(TSliceIn1 &local, TSliceIn2 &remote, TSliceOut &dest) { if constexpr (Op == ntt::reduce_op::max) { ntt::binary(local, remote, dest); } else if constexpr (Op == ntt::reduce_op::sum || @@ -177,112 +175,141 @@ class tensor_reduce_sync_impl { // collect all tensors pointer for access tensor from other nodes. using TElem = typename TIn::element_type; using TOutBase = std::decay_t; + static_assert(ShardedTensor, "dest must be sharded tensor"); constexpr size_t Rank = TIn::rank(); constexpr auto group_hierarchy = group_hierarchy_getter::group_hierarchy; auto cur_index = ntt::make_shape(@(cur_index)); auto cur_index_g = index_global2group(cur_index); - tar::src_ptr_tensor(cur_index) = - reinterpret_cast(src.elements().data()); - tar::dest_ptr_tensor(cur_index) = - reinterpret_cast(dest.elements().data()); reduce_group_sync(); // according to the group size split the tensor. 
// todo should using better split strategy. constexpr auto group_size = ntt::fixed_dim_v; - const auto axis = [&] { - dim_t axis = -1; - loop([&](auto i) { - if (axis == -1 && src.shape()[i] >= group_size) { - axis = i; - } - }); - if (axis == -1) { - axis = 0; - } - return axis; - }(); - - auto remain = src.shape()[axis] % (group_size); - auto frac = src.shape()[axis] / (group_size); - auto node_number_g = ntt::linear_offset(cur_index_g, group_hierarchy); - - // reduce-scatter, communicate (group_size - 1) times - for (auto i = 0; i < group_size - 1; i++) - { - auto new_shape = ntt::generate_shape([&](auto j) { - if (j == axis) { - return ntt::where(node_number_g == group_size - 1, frac + remain, frac); + if (!src.shape().aggregate(true, [&](auto same, auto s, auto i){ return same && std::is_same_v(src.sharding().axis_policies))>, std::decay_t(dest.sharding().axis_policies))>>; })) { + using mesh_type = typename TOutBase::mesh_type; + const auto local_shard_index = mesh_type::local_index(); + auto node_number_g = ntt::linear_offset(cur_index_g, group_hierarchy); + auto new_shape = dest.local().shape(); + auto starts_src = src.sharding().global_offset(src.shape(), local_shard_index); + auto starts_dest = dest.sharding().global_offset(dest.shape(), local_shard_index); + auto starts = ntt::generate_shape([&](auto axis) { return starts_dest[axis] - starts_src[axis]; }); + auto viewed_src1_tensor = src.local().view(starts, new_shape); + auto viewed_dest_tensor = dest.local(); + + // reduce-scatter, communicate (group_size - 1) times + for (auto i = 0; i < group_size - 1; i++) + { + auto next_index_g = ntt::unravel_index((node_number_g + i + 1) % group_size, group_hierarchy); + + // keep the non-reduce axis invariant. 
+ auto next_index = index_group2global(next_index_g, cur_index); + + auto src2_tensor = src.template remote(next_index); + auto viewed_src2_tensor = src2_tensor.view(starts, new_shape); + + if (i == 0) { + reduce_impl(viewed_src1_tensor, viewed_src2_tensor, + viewed_dest_tensor); } else { - return (dim_t)src.shape()[j]; + reduce_impl(viewed_dest_tensor, viewed_src2_tensor, + viewed_dest_tensor); } - }); - auto starts = ntt::generate_shape([&](auto j) { - if (j == axis) { - return node_number_g * frac; - } else { - return (dim_t)0; - } - }); - auto viewed_src1_tensor = src.view(starts, new_shape); - auto viewed_dest_tensor = dest.view(starts, new_shape); - - auto next_index_g = ntt::unravel_index((node_number_g + i + 1) % group_size, group_hierarchy); - - // keep the non-reduce axis invariant. - auto next_index = index_group2global(next_index_g, cur_index); - - auto src2_tensor = ntt::make_tensor_view_from_address( - (TElem *)tar::src_ptr_tensor(next_index), src.shape(), - src.strides()); - auto viewed_src2_tensor = src2_tensor.view(starts, new_shape); - - if (i == 0) { - reduce_impl(viewed_src1_tensor, viewed_src2_tensor, - viewed_dest_tensor); - } else { - reduce_impl(viewed_dest_tensor, viewed_src2_tensor, - viewed_dest_tensor); } - } - - reduce_group_sync(); - // all gather - for (size_t i = 0; i < group_size - 1; i++) { - auto offset = (node_number_g + i + 1) % (group_size); - auto src_index_g = ntt::unravel_index(offset % group_size, group_hierarchy); - auto src_index = index_group2global(src_index_g, cur_index); - - auto src_tensor = ntt::make_tensor_view_from_address( - (TElem *)tar::dest_ptr_tensor(src_index), dest.shape(), - dest.strides()); - auto starts = ntt::generate_shape([&](auto j) { - if (j == axis) { - return offset * frac; - } else { - return (dim_t)0; + ntt::tensor_copy_wait(); + reduce_group_sync(); + } else { + const auto axis = [&] { + dim_t axis = -1; + loop([&](auto i) { + if (axis == -1 && src.local().shape()[i] >= group_size) { + axis = 
i; + } + }); + if (axis == -1) { + axis = 0; } - }); - auto new_shape = ntt::generate_shape([&](auto j) { - if (j == axis) { - return ntt::where(offset == group_size - 1, frac + remain, frac); + return axis; + }(); + + auto remain = src.local().shape()[axis] % (group_size); + auto frac = src.local().shape()[axis] / (group_size); + + auto node_number_g = ntt::linear_offset(cur_index_g, group_hierarchy); + + // reduce-scatter, communicate (group_size - 1) times + for (auto i = 0; i < group_size - 1; i++) + { + auto new_shape = ntt::generate_shape([&](auto j) { + if (j == axis) { + return ntt::where(node_number_g == group_size - 1, frac + remain, frac); + } else { + return (dim_t)src.local().shape()[j]; + } + }); + auto starts = ntt::generate_shape([&](auto j) { + if (j == axis) { + return node_number_g * frac; + } else { + return (dim_t)0; + } + }); + auto viewed_src1_tensor = src.local().view(starts, new_shape); + auto viewed_dest_tensor = dest.local().view(starts, new_shape); + + auto next_index_g = ntt::unravel_index((node_number_g + i + 1) % group_size, group_hierarchy); + + // keep the non-reduce axis invariant. 
+ auto next_index = index_group2global(next_index_g, cur_index); + + auto src2_tensor = src.template remote(next_index); + auto viewed_src2_tensor = src2_tensor.view(starts, new_shape); + + if (i == 0) { + reduce_impl(viewed_src1_tensor, viewed_src2_tensor, + viewed_dest_tensor); } else { - return (dim_t)src.shape()[j]; + reduce_impl(viewed_dest_tensor, viewed_src2_tensor, + viewed_dest_tensor); } - }); - auto viewed_src_tensor = src_tensor.view(starts, new_shape); - auto viewed_dest_tensor = dest.view(starts, new_shape); - ntt::tensor_copy_async(viewed_src_tensor, viewed_dest_tensor); - } + } - ntt::tensor_copy_wait(); - reduce_group_sync(); + reduce_group_sync(); + + // all gather + for (size_t i = 0; i < group_size - 1; i++) { + auto offset = (node_number_g + i + 1) % (group_size); + auto src_index_g = ntt::unravel_index(offset % group_size, group_hierarchy); + auto src_index = index_group2global(src_index_g, cur_index); + + auto src_tensor = dest.template remote(src_index); + auto starts = ntt::generate_shape([&](auto j) { + if (j == axis) { + return offset * frac; + } else { + return (dim_t)0; + } + }); + auto new_shape = ntt::generate_shape([&](auto j) { + if (j == axis) { + return ntt::where(offset == group_size - 1, frac + remain, frac); + } else { + return (dim_t)src.local().shape()[j]; + } + }); + auto viewed_src_tensor = src_tensor.view(starts, new_shape); + auto viewed_dest_tensor = dest.local().view(starts, new_shape); + ntt::tensor_copy_async(viewed_src_tensor, viewed_dest_tensor); + } - if (Op == ntt::reduce_op::mean) { - auto numerator = (element_or_scalar_t)(size_t)group_size; - ntt::binary(dest, ntt::make_tensor_view_from_address(&numerator, ntt::fixed_shape_v<>), dest); + ntt::tensor_copy_wait(); + reduce_group_sync(); + + if (Op == ntt::reduce_op::mean) { + auto numerator = (element_or_scalar_t)(size_t)group_size; + ntt::binary(dest.local(), ntt::make_tensor_view_from_address(&numerator, ntt::fixed_shape_v<>), dest.local()); + } } } }; diff 
--git a/modules/Nncase.Modules.NTT/Evaluator/CustomOp/NTT/LayerNorm.cs b/modules/Nncase.Modules.NTT/Evaluator/CustomOp/NTT/LayerNorm.cs index 7f7e01d6e9..5ca68de7ff 100644 --- a/modules/Nncase.Modules.NTT/Evaluator/CustomOp/NTT/LayerNorm.cs +++ b/modules/Nncase.Modules.NTT/Evaluator/CustomOp/NTT/LayerNorm.cs @@ -2,6 +2,7 @@ // Licensed under the Apache license. See LICENSE file in the project root for full license information. using System; +using System.Diagnostics; using System.Linq; using System.Runtime.CompilerServices; using DryIoc.ImTools; @@ -47,7 +48,7 @@ public IRType Visit(ITypeInferenceContext context, LayerNorm target) { return (input, scale, bias) switch { - (DistributedType a, DistributedType b, DistributedType c) => new DistributedType((TensorType)VisitTensorType(target, a.TensorType, b.TensorType, c.TensorType), target.OutSBPs, a.Placement), + (DistributedType a, DistributedType b, DistributedType c) => VisitDistributedType(target, a, b, c), (TensorType a, TensorType b, TensorType c) => VisitTensorType(target, a, b, c), _ => new InvalidType($"{input} {scale} {bias} not support"), }; @@ -68,17 +69,17 @@ private bool CheckCustomSBP(IRType input, IRType scale, IRType bias, LayerNorm l { if (input is DistributedType a && scale is DistributedType b && bias is DistributedType c) { - if (Enumerable.Range(0, a.TensorType.Shape.Rank).Any(i => a.AxisPolicies[i] != layerNorm.InSBPs[i])) + if (Enumerable.Range(0, a.TensorType.Shape.Rank).Any(i => !DistributedUtility.IsSamePolicy(a.AxisPolicies[i], layerNorm.InSBPs[i], false))) { return false; } - if (Enumerable.Range(0, b.TensorType.Shape.Rank).Any(i => b.AxisPolicies[i] != layerNorm.ScaleSBPs[i])) + if (Enumerable.Range(0, b.TensorType.Shape.Rank).Any(i => !DistributedUtility.IsSamePolicy(b.AxisPolicies[i], layerNorm.ScaleSBPs[i], false))) { return false; } - if (Enumerable.Range(0, c.TensorType.Shape.Rank).Any(i => c.AxisPolicies[i] != layerNorm.BiasSBPs[i])) + if (Enumerable.Range(0, 
c.TensorType.Shape.Rank).Any(i => !DistributedUtility.IsSamePolicy(c.AxisPolicies[i], layerNorm.BiasSBPs[i], false))) { return false; } @@ -114,4 +115,24 @@ private IRType VisitTensorType(LayerNorm target, TensorType input, TensorType sc return new TensorType(target.OutputDataType, input.Shape); } } + + private IRType VisitDistributedType(LayerNorm target, DistributedType input, DistributedType scale, DistributedType bias) + { + var tensorType = (TensorType)VisitTensorType(target, input.TensorType, scale.TensorType, bias.TensorType); + + var ndsbps = new SBP[tensorType.Shape.Rank]; + for (var i = 0; i < ndsbps.Length; i++) + { + if (i == target.VectorizedAxes[0] && input.AxisPolicies[i] is SBPSplit split) + { + ndsbps[i] = SBP.S(split.Axes, split.Granularity is null ? null : split.Granularity * ((VectorType)input.TensorType.DType).Lanes[0] / ((VectorType)tensorType.DType).Lanes[0]); + } + else + { + ndsbps[i] = input.AxisPolicies[i]; + } + } + + return new DistributedType(tensorType, ndsbps, input.Placement); + } } diff --git a/modules/Nncase.Modules.NTT/Evaluator/CustomOp/NTT/Matmul.cs b/modules/Nncase.Modules.NTT/Evaluator/CustomOp/NTT/Matmul.cs index 52fda68dfe..8f639988c6 100644 --- a/modules/Nncase.Modules.NTT/Evaluator/CustomOp/NTT/Matmul.cs +++ b/modules/Nncase.Modules.NTT/Evaluator/CustomOp/NTT/Matmul.cs @@ -52,7 +52,7 @@ public IRType Visit(ITypeInferenceContext context, MatMul target) { return (lhs, rhs) switch { - (DistributedType a, DistributedType b) => new DistributedType((TensorType)VisitTensorType(target, a.TensorType, b.TensorType, true, dimInfo), target.OutSBPs, a.Placement), + (DistributedType a, DistributedType b) => VisitDistributedType(target, a, b, true, dimInfo), (TensorType a, TensorType b) => VisitTensorType(target, a, b, true, dimInfo), _ => new InvalidType($"{lhs} {rhs} not support"), }; @@ -78,12 +78,12 @@ private bool CheckCustomSBP(IRType lhs, IRType rhs, IRType extra, MatMul matmul) if (lhs is DistributedType a && rhs is 
DistributedType b) { - if (Enumerable.Range(0, a.TensorType.Shape.Rank).Any(i => a.AxisPolicies[i] != matmul.LhsSBPs[i])) + if (Enumerable.Range(0, a.TensorType.Shape.Rank).Any(i => !DistributedUtility.IsSamePolicy(a.AxisPolicies[i], matmul.LhsSBPs[i], false))) { return false; } - if (Enumerable.Range(0, b.TensorType.Shape.Rank).Any(i => b.AxisPolicies[i] != matmul.RhsSBPs[i])) + if (Enumerable.Range(0, b.TensorType.Shape.Rank).Any(i => !DistributedUtility.IsSamePolicy(b.AxisPolicies[i], matmul.RhsSBPs[i], false))) { return false; } @@ -156,4 +156,25 @@ private IRType VisitTensorType(MatMul target, TensorType lhs, TensorType rhs, bo return new TensorType(dtype, front.Concat(end).ToArray()); } + + private IRType VisitDistributedType(MatMul target, DistributedType lhs, DistributedType rhs, bool vectorizeK, MatMulDimInfo dimInfo) + { + var tensorType = (TensorType)VisitTensorType(target, lhs.TensorType, rhs.TensorType, vectorizeK, dimInfo); + + // FIXME: support rank>=2, and only support vectorize N of output. + var policyN = rhs.AxisPolicies[dimInfo!.Rn]; + if (policyN is SBPSplit split) + { + policyN = SBP.S(split.Axes, split.Granularity is null ? null : split.Granularity / ((VectorType)tensorType.DType).Lanes[0]); + } + + var policyM = lhs.AxisPolicies[dimInfo!.Lm]; + var ndsbps = (target.TransposeA || target.TransposeB) ? 
new[] { policyN, policyM } : new[] { policyM, policyN }; + if (DistributedUtility.AreSamePolicies(ndsbps, target.OutSBPs, false)) + { + return new DistributedType(tensorType, ndsbps, lhs.Placement); + } + + return new InvalidType("Please Check SBP Scheme."); + } } diff --git a/modules/Nncase.Modules.NTT/Evaluator/CustomOp/NTT/SparseExperts.cs b/modules/Nncase.Modules.NTT/Evaluator/CustomOp/NTT/SparseExperts.cs index 41f14343eb..b77aea008a 100644 --- a/modules/Nncase.Modules.NTT/Evaluator/CustomOp/NTT/SparseExperts.cs +++ b/modules/Nncase.Modules.NTT/Evaluator/CustomOp/NTT/SparseExperts.cs @@ -165,25 +165,25 @@ private bool CheckCustomSBP( { if (q is DistributedType a && gate is DistributedType b && down is DistributedType c && up is DistributedType d) { - if (Enumerable.Range(0, a.TensorType.Shape.Rank).Any(i => a.AxisPolicies[i] != se.QSBPs[i])) + if (Enumerable.Range(0, a.TensorType.Shape.Rank).Any(i => !DistributedUtility.IsSamePolicy(a.AxisPolicies[i], se.QSBPs[i], checkGranularity: false))) { Console.WriteLine($"[SparseExperts] Q SBP not match: {string.Join(",", a.AxisPolicies.Select(p => p.ToString()))} != {string.Join(",", se.QSBPs.Select(p => p.ToString()))}"); return false; } - if (Enumerable.Range(0, b.TensorType.Shape.Rank).Any(i => b.AxisPolicies[i] != se.GateSBPs[i])) + if (Enumerable.Range(0, b.TensorType.Shape.Rank).Any(i => !DistributedUtility.IsSamePolicy(b.AxisPolicies[i], se.GateSBPs[i], checkGranularity: false))) { Console.WriteLine($"[SparseExperts] Gate SBP not match: {string.Join(",", b.AxisPolicies.Select(p => p.ToString()))} != {string.Join(",", se.GateSBPs.Select(p => p.ToString()))}"); return false; } - if (Enumerable.Range(0, c.TensorType.Shape.Rank).Any(i => c.AxisPolicies[i] != se.DownSBPs[i])) + if (Enumerable.Range(0, c.TensorType.Shape.Rank).Any(i => !DistributedUtility.IsSamePolicy(c.AxisPolicies[i], se.DownSBPs[i], checkGranularity: false))) { Console.WriteLine($"[SparseExperts] Down SBP not match: {string.Join(",", 
c.AxisPolicies.Select(p => p.ToString()))} != {string.Join(",", se.DownSBPs.Select(p => p.ToString()))}"); return false; } - if (Enumerable.Range(0, d.TensorType.Shape.Rank).Any(i => d.AxisPolicies[i] != se.UpSBPs[i])) + if (Enumerable.Range(0, d.TensorType.Shape.Rank).Any(i => !DistributedUtility.IsSamePolicy(d.AxisPolicies[i], se.UpSBPs[i], checkGranularity: false))) { Console.WriteLine($"[SparseExperts] Up SBP not match: {string.Join(",", d.AxisPolicies.Select(p => p.ToString()))} != {string.Join(",", se.UpSBPs.Select(p => p.ToString()))}"); return false; diff --git a/modules/Nncase.Modules.NTT/Evaluator/CustomOp/NTT/Unary.cs b/modules/Nncase.Modules.NTT/Evaluator/CustomOp/NTT/Unary.cs index 31f57017fc..bb9f5a1fdb 100644 --- a/modules/Nncase.Modules.NTT/Evaluator/CustomOp/NTT/Unary.cs +++ b/modules/Nncase.Modules.NTT/Evaluator/CustomOp/NTT/Unary.cs @@ -6,6 +6,7 @@ using Nncase.CostModel; using Nncase.IR; using Nncase.IR.CustomNTT; +using Nncase.Utilities; using OrtKISharp; namespace Nncase.Evaluator.CustomNTT; @@ -43,7 +44,7 @@ private bool CheckCustomSBP(IRType input, Unary target) { if (input is DistributedType inType) { - if (Enumerable.Range(0, inType.TensorType.Shape.Rank).Any(i => inType.AxisPolicies[i] != target.InSBPs[i])) + if (Enumerable.Range(0, inType.TensorType.Shape.Rank).Any(i => !DistributedUtility.IsSamePolicy(inType.AxisPolicies[i], target.InSBPs[i], false))) { return false; } diff --git a/modules/Nncase.Modules.NTT/Evaluator/Distributed/Boxing.cs b/modules/Nncase.Modules.NTT/Evaluator/Distributed/Boxing.cs index 79ab794779..ee6bb6770e 100644 --- a/modules/Nncase.Modules.NTT/Evaluator/Distributed/Boxing.cs +++ b/modules/Nncase.Modules.NTT/Evaluator/Distributed/Boxing.cs @@ -27,32 +27,67 @@ IRType VisitD2D(DistributedType inv, DistributedType outv) if (inv.TensorType != outv.TensorType) { - if (!inv.AxisPolicies.Any(sbp => sbp is SBPPartial)) + if (!inv.AxisPolicies.Any(sbp => sbp is SBPPartial) && inv.Partial is null) { return outv; } + else + 
{ + return new InvalidType("Not Support Partial when shape changes."); + } + } + + if (inv.AxisPolicies.Any(sbp => sbp is SBPPartial) || outv.AxisPolicies.Any(sbp => sbp is SBPPartial)) + { + return new InvalidType("Not Support Partial in Policeis."); } - var ndsbpsA = DistributedUtility.AxisPolicesToNDSBP(inv.AxisPolicies, inv.Placement.Rank); - var ndsbpsB = DistributedUtility.AxisPolicesToNDSBP(outv.AxisPolicies, outv.Placement.Rank); - for (int i = 0; i < ndsbpsA.Count; i++) + var partialDims = new List(); + if (inv.Partial is not null) { - switch (ndsbpsA[i], ndsbpsB[i]) + for (int i = 0; i < inv.AxisPolicies.Count; i++) { - case (SBPPartial, SBPSplit): - return new InvalidType("partial to split"); - case (not SBPPartial, SBPPartial): - return new InvalidType("split/broadcast to partial"); + if (inv.AxisPolicies[i] is SBPSplit && outv.AxisPolicies[i] is SBPBroadCast) + { + return new InvalidType("Not supported input is BroadCast output is Split"); + } + + if (outv.AxisPolicies[i] is SBPSplit s) + { + if (inv.AxisPolicies[i] is SBPSplit splitIn) + { + if (splitIn.Axes.Except(s.Axes).Any()) + { + return new InvalidType("Not Supported Split-> Split."); + } + } + + if (s.Axes.Except(inv.Partial.Axes).ToArray() != s.Axes) + { + partialDims.Add(i); + } + } + } + + var ndspsIn = DistributedUtility.AxisPolicesToNDSBP(inv.AxisPolicies, inv.Placement.Rank); + var ndspsOut = DistributedUtility.AxisPolicesToNDSBP(outv.AxisPolicies, outv.Placement.Rank); + if (Enumerable.Range(0, ndspsIn.Count).Any(i => ndspsIn[i] is SBPSplit si && (ndspsOut[i] is SBPBroadCast || (ndspsOut[i] is SBPSplit so && so.Axes[0] != si.Axes[0])))) + { + return new InvalidType("Not Supported Split-> Broadcast."); } } + if (partialDims.Count > 0 && !Enumerable.Range(0, inv.AxisPolicies.Count).Except(partialDims.ToArray()).All(i => DistributedUtility.IsSamePolicy(inv.AxisPolicies[i], outv.AxisPolicies[i]))) + { + return new InvalidType("Not Supported Partial."); + } + return outv; } IRType 
VisitD2T(DistributedType inv, TensorType outv) { - var ndsbpsA = DistributedUtility.AxisPolicesToNDSBP(inv.AxisPolicies, inv.Placement.Rank); - if (ndsbpsA.Any(s => s is SBPPartial)) + if (inv.AxisPolicies.Any(s => s is SBPPartial) || inv.Partial is not null) { return new InvalidType("Not supported input is Partial output is Unshard"); } @@ -62,8 +97,7 @@ IRType VisitD2T(DistributedType inv, TensorType outv) IRType VisitT2D(TensorType inv, DistributedType outv) { - var ndsbpsB = DistributedUtility.AxisPolicesToNDSBP(outv.AxisPolicies, outv.Placement.Rank); - if (ndsbpsB.Any(s => s is SBPPartial)) + if (outv.AxisPolicies.Any(s => s is SBPPartial) || outv.Partial is not null) { return new InvalidType("Not supported input is Unshard output is Partial"); } @@ -308,6 +342,12 @@ public Cost Visit(ICostEvaluateContext context, Boxing target) case (SBPBroadCast, SBPSplit splitOut): splitOut.Axes.ToArray().ForEach(s => scatterPart *= hierarchyPenalty[s]); break; + case (SBPPartial, SBPSplit splitOut): + // actually partial to split needs gather. 
+ break; + case (SBPPartial sBPPartial, SBPBroadCast): + sBPPartial.Axes.ToArray().ForEach(s => gatherPart *= hierarchyPenalty[s]); + break; default: throw new NotSupportedException($"{a} to {b}"); } diff --git a/modules/Nncase.Modules.NTT/Evaluator/Distributed/ForceBoxing.cs b/modules/Nncase.Modules.NTT/Evaluator/Distributed/ForceBoxing.cs index 9218bd558c..4a0d4a7383 100644 --- a/modules/Nncase.Modules.NTT/Evaluator/Distributed/ForceBoxing.cs +++ b/modules/Nncase.Modules.NTT/Evaluator/Distributed/ForceBoxing.cs @@ -21,7 +21,7 @@ IRType VisitD2D(DistributedType inv, DistributedType outv) var ndsbpsB = DistributedUtility.AxisPolicesToNDSBP(outv.AxisPolicies, outv.Placement.Rank).ToArray(); // TODO: add more invalid cases - if (ndsbpsA.Distinct().Count() == 1 && ndsbpsB.Distinct().Count() == 1 && ndsbpsA[0] == ndsbpsB[0]) + if (ndsbpsA.Distinct().Count() == 1 && ndsbpsB.Distinct().Count() == 1 && ndsbpsA[0] == ndsbpsB[0] && inv.Partial == outv.Partial) { return new InvalidType("Same NDSBP"); } @@ -90,7 +90,8 @@ public IValue Visit(IEvaluateContext context, ForceBoxing target) var inTenor = context.GetArgumentValueAsTensor(target, ForceBoxing.Input); var input = inTenor.ToOrtTensor(); var output = input - input; - var repeat = target.NewType.AxisPolicies.Select((x, i) => (x is SBPPartial) ? target.NewType.Placement.Hierarchy[i] : 1).Aggregate(1, (x, i) => x * i); + var ndsbps = DistributedUtility.AxisPolicesToNDSBP(target.NewType.AxisPolicies, target.NewType.Placement.Rank).ToArray(); + var repeat = ndsbps.Select((x, i) => (x is SBPPartial) ? 
target.NewType.Placement.Hierarchy[i] : 1).Aggregate(1, (x, i) => x * i); for (int i = 0; i < repeat; i++) { output += input; diff --git a/modules/Nncase.Modules.NTT/Evaluator/NTT/Im2col.cs b/modules/Nncase.Modules.NTT/Evaluator/NTT/Im2col.cs index 149e881f56..c9c0e8e37b 100644 --- a/modules/Nncase.Modules.NTT/Evaluator/NTT/Im2col.cs +++ b/modules/Nncase.Modules.NTT/Evaluator/NTT/Im2col.cs @@ -106,7 +106,6 @@ private IRType Visit(DistributedType dt, Im2col target) return new InvalidType("im2col typeinfer failed"); } - var outShape = tensorType.Shape.ToArray(); var ndsbp = new SBP[tensorType.Shape.Rank]; for (int i = 0; i < dt.AxisPolicies.Count; i++) @@ -117,7 +116,6 @@ private IRType Visit(DistributedType dt, Im2col target) switch (sbp) { case SBPSplit split: - outShape[1] /= split.Axes.Select(a => dt.Placement.Hierarchy[a]).Aggregate(1, (a, b) => a * b); ndsbp[1] = split; break; case SBPPartial: @@ -132,7 +130,6 @@ private IRType Visit(DistributedType dt, Im2col target) switch (sbp) { case SBPSplit split: - outShape[0] /= split.Axes.Select(a => dt.Placement.Hierarchy[a]).Aggregate(1, (a, b) => a * b); ndsbp[0] = split; break; case SBPPartial: @@ -146,7 +143,7 @@ private IRType Visit(DistributedType dt, Im2col target) { switch (sbp) { - case SBPSplit split: + case SBPSplit: return new InvalidType($"can't split on {i}"); case SBPPartial: return new InvalidType($"can't be partial sum"); diff --git a/modules/Nncase.Modules.NTT/Evaluator/NTT/VectorizedCast.cs b/modules/Nncase.Modules.NTT/Evaluator/NTT/VectorizedCast.cs index 2fa8111046..753b0afbd4 100644 --- a/modules/Nncase.Modules.NTT/Evaluator/NTT/VectorizedCast.cs +++ b/modules/Nncase.Modules.NTT/Evaluator/NTT/VectorizedCast.cs @@ -134,7 +134,7 @@ private IRType Visit(VectorizedCast target, DistributedType inType) { var invalid = new InvalidType(inType.ToString()); var outType = Visit(target, inType.TensorType); - var ndsbp = new SBP[inType.TensorType.Shape.Rank]; + var ndsbp = inType.AxisPolicies.ToArray(); var 
shape = CompilerServices.GetMaxShape(inType.TensorType.Shape); for (int i = 0; i < ndsbp.Length; i++) { @@ -153,10 +153,13 @@ private IRType Visit(VectorizedCast target, DistributedType inType) { return invalid; } + else + { + var scale = 1f * outShape[i] / shape[i]; + ndsbp[i] = SBP.S(split.Axes, split.Granularity is not null ? (scale >= 1 ? split.Granularity * (long)scale : split.Granularity / (long)(1f / scale)) : null); + } } } - - ndsbp[i] = inType.AxisPolicies[i]; } return new DistributedType((TensorType)outType, ndsbp, inType.Placement); diff --git a/modules/Nncase.Modules.NTT/Passes/AffineSelection/MatMul.cs b/modules/Nncase.Modules.NTT/Passes/AffineSelection/MatMul.cs index b2b1b0c8b4..1cead83cce 100644 --- a/modules/Nncase.Modules.NTT/Passes/AffineSelection/MatMul.cs +++ b/modules/Nncase.Modules.NTT/Passes/AffineSelection/MatMul.cs @@ -6,6 +6,7 @@ using Nncase.IR.Affine; using Nncase.TIR; using Nncase.TIR.NTT; +using Nncase.Utilities; namespace Nncase.Passes; @@ -27,8 +28,8 @@ public Expr SelectMatMul(Op op, Call call, Expr output) var dinfo = pmm.GetDimInfo(dta.TensorType.Shape.Rank, dtb.TensorType.Shape.Rank); if (dta.AxisPolicies[^2..].AsValueEnumerable().All(x => x is SBPSplit) && dtb.AxisPolicies[^2..].AsValueEnumerable().All(x => x is SBPSplit) && - dta.AxisPolicies[dinfo.Lk] == dtb.AxisPolicies[dinfo.Rn] && - dta.AxisPolicies[dinfo.Lm] == dtb.AxisPolicies[dinfo.Rk]) + DistributedUtility.IsSamePolicy(dta.AxisPolicies[dinfo.Lk], dtb.AxisPolicies[dinfo.Rn], false) && + DistributedUtility.IsSamePolicy(dta.AxisPolicies[dinfo.Lm], dtb.AxisPolicies[dinfo.Rk], false)) { return call; } @@ -42,8 +43,8 @@ public Expr SelectMatMul(Op op, Call call, Expr output) { if (dta.AxisPolicies[^2..].AsValueEnumerable().All(x => x is SBPSplit) && dtb.AxisPolicies[^2..].AsValueEnumerable().All(x => x is SBPSplit) && - dta.AxisPolicies[^2] == dtb.AxisPolicies[^2] && - dta.AxisPolicies[^1] == dtb.AxisPolicies[^1]) + DistributedUtility.IsSamePolicy(dta.AxisPolicies[^2], 
dtb.AxisPolicies[^2], false) && + DistributedUtility.IsSamePolicy(dta.AxisPolicies[^1], dtb.AxisPolicies[^1], false)) { return call; } @@ -57,8 +58,8 @@ public Expr SelectMatMul(Op op, Call call, Expr output) { if (dta.AxisPolicies[^2..].AsValueEnumerable().All(x => x is SBPSplit) && dtb.AxisPolicies[^2..].AsValueEnumerable().All(x => x is SBPSplit) && - dta.AxisPolicies[^2] == dtb.AxisPolicies[^1] && - dta.AxisPolicies[^1] == dtb.AxisPolicies[^2]) + DistributedUtility.IsSamePolicy(dta.AxisPolicies[^2], dtb.AxisPolicies[^1], false) && + DistributedUtility.IsSamePolicy(dta.AxisPolicies[^1], dtb.AxisPolicies[^2], false)) { return call; } diff --git a/modules/Nncase.Modules.NTT/Passes/NTTTIRSelectionPass.cs b/modules/Nncase.Modules.NTT/Passes/NTTTIRSelectionPass.cs index 40bf45e71c..552771a4e3 100644 --- a/modules/Nncase.Modules.NTT/Passes/NTTTIRSelectionPass.cs +++ b/modules/Nncase.Modules.NTT/Passes/NTTTIRSelectionPass.cs @@ -65,8 +65,8 @@ protected override Expr SelectCall(Call call, IReadOnlyList arguments, var dinfo = vectorizedMatMul.GetDimInfo(dta.TensorType.Shape.Rank, dtb.TensorType.Shape.Rank); if (dta.AxisPolicies[^2..].AsValueEnumerable().All(x => x is SBPSplit) && dtb.AxisPolicies[^2..].AsValueEnumerable().All(x => x is SBPSplit) && - dta.AxisPolicies[dinfo.Lk] == dtb.AxisPolicies[dinfo.Rn] && - dta.AxisPolicies[dinfo.Lm] == dtb.AxisPolicies[dinfo.Rk]) + DistributedUtility.IsSamePolicy(dta.AxisPolicies[dinfo.Lk], dtb.AxisPolicies[dinfo.Rn], false) && + DistributedUtility.IsSamePolicy(dta.AxisPolicies[dinfo.Lm], dtb.AxisPolicies[dinfo.Rk], false)) { return TIR.F.NTT.SUMMA((Expr)arguments[0], (Expr)arguments[1], output, None.Default, (Expr)call[IR.NTT.VectorizedMatMul.Scale], vectorizedMatMul.LhsVectorizedAxes, vectorizedMatMul.RhsVectorizedAxes, vectorizedMatMul.TransposeA, vectorizedMatMul.TransposeB); } @@ -78,8 +78,8 @@ protected override Expr SelectCall(Call call, IReadOnlyList arguments, case IR.Math.MatMul when GetArgumentType(arguments[0]) is 
DistributedType dta && GetArgumentType(arguments[1]) is DistributedType dtb: if (dta.AxisPolicies[^2..].AsValueEnumerable().All(x => x is SBPSplit) && dtb.AxisPolicies[^2..].AsValueEnumerable().All(x => x is SBPSplit) && - dta.AxisPolicies[^2] == dtb.AxisPolicies[^2] && - dta.AxisPolicies[^1] == dtb.AxisPolicies[^1]) + DistributedUtility.IsSamePolicy(dta.AxisPolicies[^2], dtb.AxisPolicies[^2], false) && + DistributedUtility.IsSamePolicy(dta.AxisPolicies[^1], dtb.AxisPolicies[^1], false)) { return TIR.F.NTT.SUMMA((Expr)arguments[0], (Expr)arguments[1], output, None.Default, (Expr)call[IR.Math.MatMul.Scale]); } @@ -237,7 +237,7 @@ private Expr GenerateReshape(Expr input, ref Expr output, bool sequeeze = false) var bitcast = outBuffer.DistributedType is DistributedType ? false : true; if (!bitcast) { - if (inBuffer.DistributedType!.AxisPolicies.Where(sbp => sbp is not SBPBroadCast).ToArray().SequenceEqual(outBuffer.DistributedType!.AxisPolicies.Where(sbp => sbp is not SBPBroadCast).ToArray())) + if (DistributedUtility.AreSamePolicies(inBuffer.DistributedType!.AxisPolicies.Where(sbp => sbp is not SBPBroadCast).ToArray(), outBuffer.DistributedType!.AxisPolicies.Where(sbp => sbp is not SBPBroadCast).ToArray(), false)) { bitcast = true; } @@ -246,7 +246,19 @@ private Expr GenerateReshape(Expr input, ref Expr output, bool sequeeze = false) // If the size is not same, we cannot bitcast. 
if ((inBuffer.MemSpan.Size == outBuffer.MemSpan.Size) && (bitcast || sequeeze)) { - output = inBuffer.With(name: outBuffer.Name, elemType: outBuffer.ElemType, dimensions: outBuffer.Dimensions.ToArray(), strides: outBuffer.Strides.ToArray(), distributedType: outBuffer.DistributedType); + var newStrides = outBuffer.Strides.ToArray(); + if (inBuffer.MemSpan.Buffer.Location == MemoryLocation.BlockLocalData && bitcast) + { + var outType = outBuffer.DistributedType!; + var threadAxis = outType.Placement.Rank - 1; + int splitAxis = outType.AxisPolicies.ToArray().IndexOf(x => x is SBPSplit split && split.Axes.Contains(threadAxis)); + if (splitAxis >= 0) + { + Enumerable.Range(0, splitAxis).ToArray().ForEach(i => newStrides[i] *= outType.Placement.Hierarchy[threadAxis]); + } + } + + output = inBuffer.With(name: outBuffer.Name, elemType: outBuffer.ElemType, dimensions: outBuffer.Dimensions.ToArray(), strides: newStrides, distributedType: outBuffer.DistributedType); return T.Nop(); } else @@ -288,7 +300,7 @@ private Expr GenerateBitcast(Expr input, ref Expr output, DataType newType) } } - var distributedType = inBuffer.DistributedType is DistributedType dt + var distributedType = ((TIR.Buffer)output).DistributedType is DistributedType dt ? dt with { TensorType = new TensorType(newType, newDimensions) } : null; output = inBuffer.With(name: ((TIR.Buffer)output).Name, elemType: newType, dimensions: newDimensions, strides: newStrides, distributedType: distributedType); @@ -325,8 +337,6 @@ private Expr GenerateBoxing(Call call, IR.Distributed.Boxing boxing, IReadOnlyLi private Expr GenerateReshard(Expr input, ref Expr output, DistributedType inType, DistributedType outType) { - // FIXME: re-balance issue. 
-#if false if (input is TIR.Buffer inBuffer) { if (TryGenerateGatherThreadsReshard(inBuffer, ref output, inType, outType, out var newCall)) @@ -338,19 +348,24 @@ private Expr GenerateReshard(Expr input, ref Expr output, DistributedType inType return newCall; } } -#endif return TIR.F.NTT.GatherReduceScatter(input, output, inType, outType); } private bool TryGenerateGatherThreadsReshard(TIR.Buffer inBuffer, ref Expr output, DistributedType inType, DistributedType outType, [MaybeNullWhen(false)] out Expr newCall) { + if (inType.Partial is not null || inBuffer.Users.Any(u => u is Call { Target: TIR.PrimFunction })) + { + newCall = null; + return false; + } + var threadAxis = inType.Placement.Rank - 1; PhysicalBuffer? oldPhysicalBuffer = null; // S -> B var reducedInPolices = inType.AxisPolicies.Select(sbp => sbp is SBPSplit split && split.Axes.Contains(threadAxis) ? (split.Axes.Count == 1 ? (SBP)SBP.B : SBP.S(split.Axes.Except([threadAxis]).ToArray())) : sbp); - if (reducedInPolices.ToArray().SequenceEqual(outType.AxisPolicies.ToArray())) + if (DistributedUtility.AreSamePolicies(reducedInPolices.ToArray(), outType.AxisPolicies.ToArray(), false)) { oldPhysicalBuffer = inBuffer.MemSpan.Buffer; } @@ -422,7 +437,7 @@ private bool TryGenerateGatherThreadsReshard(TIR.Buffer inBuffer, ref Expr outpu private bool TryGenerateSplitThreadsReshard(TIR.Buffer inBuffer, ref Expr output, DistributedType inType, DistributedType outType, [MaybeNullWhen(false)] out Expr newCall) { // Avoid P -> B -> S - if (inType.Partial) + if (inType.Partial is not null || inBuffer.Users.Any(u => u is Call { Target: TIR.PrimFunction })) { newCall = null; return false; @@ -434,7 +449,7 @@ private bool TryGenerateSplitThreadsReshard(TIR.Buffer inBuffer, ref Expr output int splitAxis = -1; // B -> S - if (reducedOutPolices.ToArray().SequenceEqual(inType.AxisPolicies.ToArray())) + if (DistributedUtility.AreSamePolicies(reducedOutPolices.ToArray(), inType.AxisPolicies.ToArray(), false)) { oldPhysicalBuffer 
= inBuffer.MemSpan.Buffer; splitAxis = outType.AxisPolicies.ToArray().IndexOf(x => x is SBPSplit split && split.Axes.Contains(threadAxis));
diff --git a/modules/Nncase.Modules.NTT/Passes/Rules/Mutators/ToBlockLocalData.cs b/modules/Nncase.Modules.NTT/Passes/Rules/Mutators/ToBlockLocalData.cs
new file mode 100644
index 0000000000..50caa14920
--- /dev/null
+++ b/modules/Nncase.Modules.NTT/Passes/Rules/Mutators/ToBlockLocalData.cs
@@ -0,0 +1,116 @@
+// Copyright (c) Canaan Inc. All rights reserved.
+// Licensed under the Apache license. See LICENSE file in the project root for full license information.
+
+using System;
+using System.Collections.Generic;
+using System.IO;
+using System.Linq;
+using System.Reactive;
+using Nncase.IR;
+using Nncase.TIR;
+
+namespace Nncase.Passes.Mutators;
+
+///
+/// Retargets physical buffers that adjacent NTT kernel calls share (one call's output argument, the next call's input argument) from MemoryLocation.Data to MemoryLocation.BlockLocalData.
+///
+public sealed class ToBlockLocalData : ExprRewriter
+{
+    // Buffer expression (compared by reference) -> replacement physical buffer located in BlockLocalData.
+    private readonly Dictionary _bufferMemo = new(ReferenceEqualityComparer.Instance);
+
+    // Records a replacement physical buffer for every matched call pattern, then lets the base rewriter visit the sequence.
+    protected override BaseExpr VisitSequential(Sequential expr, Unit context)
+    {
+        // The loop bound only guarantees that Fields[i + 1] exists; any pattern that looks further ahead must check its own bound.
+        for (int i = 0; i < expr.Fields.Length - 1; i++)
+        {
+            {
+                // Layer norm followed by matmul: the layer norm output and the matmul lhs use the same physical buffer in Data.
+                if (expr.Fields[i] is Call { Target: TIR.NTT.VectorizedLayerNorm ln } lnCall
+                    && expr.Fields[i + 1] is Call { Target: TIR.NTT.Matmul mm } mmCall
+                    && ln.CSourcePath is not null && mm.CSourcePath is not null)
+                {
+                    var lnOutput = lnCall.Arguments[TIR.NTT.VectorizedLayerNorm.Output.Index];
+                    var mmLhs = mmCall.Arguments[TIR.NTT.Matmul.Lhs.Index];
+                    if (lnOutput is TIR.Buffer a && mmLhs is TIR.Buffer b && a.MemSpan.Buffer == b.MemSpan.Buffer && b.MemSpan.Buffer.Location == MemoryLocation.Data)
+                    {
+                        var newPhysicalBuffer = a.MemSpan.Buffer.With(location: MemoryLocation.BlockLocalData);
+                        _bufferMemo.TryAdd(lnOutput, newPhysicalBuffer);
+                        _bufferMemo.TryAdd(mmLhs, newPhysicalBuffer);
+                    }
+                }
+            }
+
+            {
+                // Matmul followed by layer norm: the matmul output and the layer norm input use the same physical buffer in Data.
+                if (expr.Fields[i] is Call { Target: TIR.NTT.Matmul mm } mmCall
+                    && expr.Fields[i + 1] is Call { Target: TIR.NTT.VectorizedLayerNorm ln } lnCall
+                    && ln.CSourcePath is not null && mm.CSourcePath is not null)
+                {
+                    var lnInput = lnCall.Arguments[TIR.NTT.VectorizedLayerNorm.Input.Index];
+                    var mmOutput = mmCall.Arguments[TIR.NTT.Matmul.Output.Index];
+                    if (lnInput is TIR.Buffer a && mmOutput is TIR.Buffer b && a.MemSpan.Buffer == b.MemSpan.Buffer && b.MemSpan.Buffer.Location == MemoryLocation.Data)
+                    {
+                        var newPhysicalBuffer = a.MemSpan.Buffer.With(location: MemoryLocation.BlockLocalData);
+                        _bufferMemo.TryAdd(lnInput, newPhysicalBuffer);
+                        _bufferMemo.TryAdd(mmOutput, newPhysicalBuffer);
+                    }
+                }
+            }
+
+            {
+                // Matmul, thread sync, then two KV-cache updates: the matmul output and the second update's slots use the same physical buffer in Data.
+                // This pattern reads up to Fields[i + 3], so it is skipped when fewer than four statements remain (previously an out-of-range access).
+                if (i + 3 < expr.Fields.Length
+                    && expr.Fields[i] is Call { Target: TIR.NTT.Matmul mm } mmCall
+                    && expr.Fields[i + 1] is Call { Target: TIR.NTT.SynchronizeThreads }
+                    && expr.Fields[i + 2] is Call { Target: TIR.NTT.UpdatePagedAttentionKVCache }
+                    && expr.Fields[i + 3] is Call { Target: TIR.NTT.UpdatePagedAttentionKVCache } upkvCall2
+                    && mm.CSourcePath is not null)
+                {
+                    var mmOutput = mmCall.Arguments[TIR.NTT.Matmul.Output.Index];
+                    var upkvInput = upkvCall2.Arguments[TIR.NTT.UpdatePagedAttentionKVCache.Slots.Index];
+                    if (upkvInput is TIR.Buffer a && mmOutput is TIR.Buffer b && a.MemSpan.Buffer == b.MemSpan.Buffer && b.MemSpan.Buffer.Location == MemoryLocation.Data)
+                    {
+                        var newPhysicalBuffer = a.MemSpan.Buffer.With(location: MemoryLocation.BlockLocalData);
+                        _bufferMemo.TryAdd(upkvInput, newPhysicalBuffer);
+                        _bufferMemo.TryAdd(mmOutput, newPhysicalBuffer);
+                    }
+                }
+            }
+
+            // FIXME: need to add sync after primfunc.
+            // {
+            //     if (expr.Fields[i] is Call { Target: TIR.NTT.Matmul mm } mmCall
+            //         && expr.Fields[i + 1] is Call { Target: TIR.PrimFunction } fnCall
+            //         && mm.CSourcePath is not null)
+            //     {
+            //         var mmOutput = mmCall.Arguments[TIR.NTT.Matmul.Output.Index];
+            //         foreach (var arg in fnCall.Arguments)
+            //         {
+            //             if (arg is TIR.Buffer a && mmOutput is TIR.Buffer b && a.MemSpan.Buffer == b.MemSpan.Buffer && b.MemSpan.Buffer.Location == MemoryLocation.Data)
+            //             {
+            //                 var newPhysicalBuffer = a.MemSpan.Buffer.With(location: MemoryLocation.BlockLocalData);
+            //                 _bufferMemo.TryAdd(arg, newPhysicalBuffer);
+            //                 _bufferMemo.TryAdd(mmOutput, newPhysicalBuffer);
+            //             }
+            //         }
+            //     }
+            // }
+        }
+
+        return base.VisitSequential(expr, context);
+    }
+
+    // Swaps in the recorded BlockLocalData physical buffer; buffers without a recorded replacement are returned unchanged.
+    protected override BaseExpr RewriteLeafBuffer(TIR.Buffer expr)
+    {
+        if (_bufferMemo.TryGetValue(expr, out var buffer))
+        {
+            return expr.With(memSpan: expr.MemSpan.With(buffer: buffer));
+        }
+
+        return expr;
+    }
+}
diff --git a/modules/Nncase.Modules.NTT/Targets/CPUTarget.cs b/modules/Nncase.Modules.NTT/Targets/CPUTarget.cs index 04312ec46d..2bb9d04d09 100644 --- a/modules/Nncase.Modules.NTT/Targets/CPUTarget.cs +++ b/modules/Nncase.Modules.NTT/Targets/CPUTarget.cs @@ -119,6 +119,11 @@ public override void RegisterAutoVectorizeRules(IRulesAddable pass, CompileOptio public override void RegisterTIRSelectionPass(IPassManager passManager, CompileOptions optionsÍ) { passManager.Add(); + passManager.AddWithName("ToBlockLocalData").Configure(p => + { + p.Add(); + p.Add(); + }); } public override void RegisterPostAutoVectorizePass(IPassManager passManager, CompileOptions options) diff --git a/ntt/include/nncase/ntt/arch/cpu/remote_tensor.h b/ntt/include/nncase/ntt/arch/cpu/remote_tensor.h index f339dabd13..392abfe57d 100644 --- a/ntt/include/nncase/ntt/arch/cpu/remote_tensor.h +++ b/ntt/include/nncase/ntt/arch/cpu/remote_tensor.h @@ -20,7 +20,10 @@ namespace nncase::ntt::distributed { namespace detail { extern
decltype(nncase::ntt::make_tensor>( - nncase::ntt::distributed::topology_shape)) global_local_data_ptr; + nncase::ntt::distributed::topology_shape)) global_thread_local_data_ptr; + +extern decltype(nncase::ntt::make_tensor>( + nncase::ntt::distributed::topology_shape)) global_block_local_data_ptr; extern decltype(nncase::ntt::make_tensor>( nncase::ntt::distributed::topology_shape)) global_thread_local_rdata_ptr; @@ -37,9 +40,9 @@ template = end) { start = (size_t)global_thread_local_rdata_ptr(local_program_ids)(0_dim); end = (size_t)global_thread_local_rdata_ptr(local_program_ids)(1_dim); @@ -48,8 +51,15 @@ static auto get_remote_address(const TLocalProgramIds &local_program_ids, if ((uintptr_t)local_address < start || (uintptr_t)local_address >= end) { start = (size_t)global_block_local_rdata_ptr(local_program_ids)(0_dim); + end = (size_t)global_block_local_rdata_ptr(local_program_ids)(1_dim); remote_address = (size_t)global_block_local_rdata_ptr(remote_program_ids)(0_dim); + if ((uintptr_t)local_address < start || + (uintptr_t)local_address >= end) { + start = (size_t)global_block_local_data_ptr(local_program_ids)(0_dim); + remote_address = + (size_t)global_block_local_data_ptr(remote_program_ids)(0_dim); + } } } diff --git a/ntt/include/nncase/ntt/arch/xpu/remote_tensor.h b/ntt/include/nncase/ntt/arch/xpu/remote_tensor.h index 354d5abb4c..751483084d 100644 --- a/ntt/include/nncase/ntt/arch/xpu/remote_tensor.h +++ b/ntt/include/nncase/ntt/arch/xpu/remote_tensor.h @@ -25,7 +25,10 @@ namespace nncase::ntt::distributed { namespace detail { #if not defined(SYS_MODE) extern decltype(nncase::ntt::make_tensor>( - nncase::ntt::distributed::topology_shape)) global_local_data_ptr; + nncase::ntt::distributed::topology_shape)) global_thread_local_data_ptr; + +extern decltype(nncase::ntt::make_tensor>( + nncase::ntt::distributed::topology_shape)) global_block_local_data_ptr; extern decltype(nncase::ntt::make_tensor>( nncase::ntt::distributed::topology_shape)) 
global_thread_local_rdata_ptr; @@ -46,9 +49,9 @@ static auto get_remote_address(const TLocalProgramIds &local_program_ids, const TRemoteProgramIds &remote_program_ids, T *local_address) { #if not defined(SYS_MODE) - auto start = global_local_data_ptr(local_program_ids)(0_dim); - auto end = global_local_data_ptr(local_program_ids)(1_dim); - auto remote_address = global_local_data_ptr(remote_program_ids)(0_dim); + auto start = global_thread_local_data_ptr(local_program_ids)(0_dim); + auto end = global_thread_local_data_ptr(local_program_ids)(1_dim); + auto remote_address = global_thread_local_data_ptr(remote_program_ids)(0_dim); if ((uintptr_t)local_address < start || (uintptr_t)local_address >= end) { start = global_thread_local_rdata_ptr(local_program_ids)(0_dim); end = global_thread_local_rdata_ptr(local_program_ids)(1_dim); @@ -57,8 +60,15 @@ static auto get_remote_address(const TLocalProgramIds &local_program_ids, if ((uintptr_t)local_address < start || (uintptr_t)local_address >= end) { start = global_block_local_rdata_ptr(local_program_ids)(0_dim); + end = global_block_local_rdata_ptr(local_program_ids)(1_dim); remote_address = global_block_local_rdata_ptr(remote_program_ids)(0_dim); + if ((uintptr_t)local_address < start || + (uintptr_t)local_address >= end) { + start = global_block_local_data_ptr(local_program_ids)(0_dim); + remote_address = + global_block_local_data_ptr(remote_program_ids)(0_dim); + } } } diff --git a/ntt/include/nncase/ntt/caching.h b/ntt/include/nncase/ntt/caching.h index 6108f5ed16..cee33abf7e 100644 --- a/ntt/include/nncase/ntt/caching.h +++ b/ntt/include/nncase/ntt/caching.h @@ -145,7 +145,7 @@ template class attention_kv_cache { namespace detail { template constexpr auto -kv_dim(const distributed::shard_policy::split &split) noexcept { +kv_dim(distributed::shard_policy::split split) noexcept { return split.template divider(); } @@ -184,8 +184,8 @@ constexpr auto origin_kv_cache_one_block_shape() noexcept { // num_blocks is not 
sharded, so we just return return last_shape; } else { - auto dim = - axis_policy_t::template try_shard_dim_without_shard_index< + auto policy = axis_policy_t{}; + auto dim = policy.template try_shard_dim_without_shard_index< Mesh>(last_shape[sharding_axis]); static_assert(dim != -1_dim, "Only uniform shard dim is supported."); diff --git a/ntt/include/nncase/ntt/distributed/sharding.h b/ntt/include/nncase/ntt/distributed/sharding.h index 13919a27b1..55e3327551 100644 --- a/ntt/include/nncase/ntt/distributed/sharding.h +++ b/ntt/include/nncase/ntt/distributed/sharding.h @@ -54,17 +54,17 @@ struct broadcast { inline constexpr broadcast B; // Split -template struct split { +template struct split { + TGranularity granularity; static constexpr auto axes = fixed_shape_v; - template static constexpr auto divider() { + template constexpr auto divider() { return Mesh::shape.select(axes).length(); } template TShardIndex> - static constexpr auto - global_offset(const TDim &global_dim, - const TShardIndex &shard_index) noexcept { + constexpr auto global_offset(const TDim &global_dim, + const TShardIndex &shard_index) noexcept { constexpr auto submesh_shape = fixed_shape_v]...>; auto subshard_index = make_shape(shard_index[fixed_dim_v]...); @@ -76,16 +76,21 @@ template struct split { } template - static constexpr auto + constexpr auto try_shard_dim_without_shard_index(const TDim &global_dim) noexcept { - const auto remainder = global_dim % divider(); - return ntt::where(remainder == dim_zero, global_dim / divider(), - -1_dim); + if constexpr (std::is_same_v) { + const auto remainder = global_dim % divider(); + return ntt::where(remainder == dim_zero, + global_dim / divider(), -1_dim); + } else { + return ntt::where(global_dim == granularity * divider(), + granularity, -1_dim); + } } template TShardIndex> - static constexpr auto shard_dim(const TDim &global_dim, - const TShardIndex &shard_index) noexcept { + constexpr auto shard_dim(const TDim &global_dim, + const TShardIndex 
&shard_index) noexcept { const auto shard_dim_v = try_shard_dim_without_shard_index(global_dim); return ntt::where(shard_dim_v != -1_dim, shard_dim_v, [&] { @@ -96,20 +101,31 @@ template struct split { } template - static constexpr auto max_shard_dim(const TDim &global_dim) noexcept { - return ntt::ceil_div(global_dim, divider()); + constexpr auto max_shard_dim(const TDim &global_dim) noexcept { + if constexpr (std::is_same_v) { + return ntt::ceil_div(global_dim, divider()); + } else { + return granularity; + } } + + constexpr split() requires(std::is_same_v) + : granularity(nullptr) {} + + constexpr split(TGranularity granularity_) : granularity(granularity_) {} }; -template constexpr auto S() noexcept { - return split{}; +template +constexpr auto S(TGranularity granularity = nullptr) noexcept { + return split{granularity}; } } // namespace shard_policy template struct is_split_shard_policy : std::false_type {}; -template -struct is_split_shard_policy> : std::true_type {}; +template +struct is_split_shard_policy> + : std::true_type {}; template concept SplitShardPolicy = is_split_shard_policy::value; @@ -134,9 +150,9 @@ template struct sharding { global_offset(const GlobalShape &global_shape, const TShardIndex &shard_index) const noexcept { auto get_dim = [&, this] { - return std::get(axis_policies) - .template global_offset(global_shape[fixed_dim_v], - shard_index); + auto policy = std::get(axis_policies); + return policy.template global_offset( + global_shape[fixed_dim_v], shard_index); }; auto get_all_dims = [&](std::index_sequence) { return make_shape(get_dim.template operator()()...); @@ -148,9 +164,9 @@ template struct sharding { constexpr auto shard_shape(const GlobalShape &global_shape, const TShardIndex &shard_index) const noexcept { auto get_dim = [&, this] { - return std::get(axis_policies) - .template shard_dim(global_shape[fixed_dim_v], - shard_index); + auto policy = std::get(axis_policies); + return policy.template shard_dim( + 
global_shape[fixed_dim_v], shard_index); }; auto get_all_dims = [&](std::index_sequence) { return make_shape(get_dim.template operator()()...); @@ -235,12 +251,14 @@ constexpr auto tensor_axes_of_non_split_shard_policies() noexcept { template constexpr auto local_shard_dim(const TSharding &sharding, const GlobalShape &global_shape) noexcept { - static_assert(GlobalShape::rank() == TSharding::rank(), "Invalid sharding."); + static_assert(GlobalShape::rank() == TSharding::rank(), + "Invalid sharding."); using mesh_type = typename TSharding::mesh_type; const auto local_index = mesh_type::local_index(); - return std::get(sharding.axis_policies) - .template shard_dim(global_shape[fixed_dim_v], local_index); + auto policy = std::get(sharding.axis_policies); + return policy.template shard_dim( + global_shape[fixed_dim_v], local_index); } template diff --git a/ntt/include/nncase/ntt/kernels/paged_attention.h b/ntt/include/nncase/ntt/kernels/paged_attention.h index d109d965bb..da242e48a9 100644 --- a/ntt/include/nncase/ntt/kernels/paged_attention.h +++ b/ntt/include/nncase/ntt/kernels/paged_attention.h @@ -150,7 +150,7 @@ constexpr void update_paged_attention_kv_cache(const TSlots &slots_tensor, .squeeze(local_slots_squeeze); // process kv_head different sharding on slot and kv cache. 
- const auto kv_head_policy = config_t::template axis_policy< + auto kv_head_policy = config_t::template axis_policy< caching::paged_kvcache_dim_kind::num_kv_heads>(); const auto global_head_id = slots_global_offset[head_index] + local_head_id; diff --git a/ntt/include/nncase/ntt/kernels/reshard.h b/ntt/include/nncase/ntt/kernels/reshard.h index 3123591590..8b47bc6952 100644 --- a/ntt/include/nncase/ntt/kernels/reshard.h +++ b/ntt/include/nncase/ntt/kernels/reshard.h @@ -120,7 +120,7 @@ struct reshard_impl { [&](auto last_acc, auto global_dim, auto axis) { auto [last_global_offset, last_local_offset, last_shape] = last_acc; - const auto policy = + auto policy = std::get(src.sharding().axis_policies); if constexpr (distributed::SplitShardPolicy< std::decay_t>) { @@ -330,16 +330,16 @@ struct reshard_impl { "Cannot reshard between different mesh types."); constexpr void operator()(const SrcTensor &src, DestTensor &dest) noexcept { - if constexpr (std::is_same_v) { - if (src.shape() == dest.shape()) { - // make sure src ready. - distributed::topology_synchronize(); - overlap_aware_reshard(src, dest); - distributed::topology_synchronize(); - return; - } - } + // if constexpr (std::is_same_v) { + // if (src.shape() == dest.shape()) { + // // make sure src ready. 
+ // distributed::topology_synchronize(); + // overlap_aware_reshard(src, dest); + // distributed::topology_synchronize(); + // return; + // } + // } copy_to_global(src); copy_from_global(dest); @@ -388,7 +388,7 @@ struct reshard_impl { std::array counts{}; auto get_coord = [&]() { - const auto policy = std::get(src.sharding().axis_policies); + auto policy = std::get(src.sharding().axis_policies); if constexpr (distributed::SplitShardPolicy>) { size_t num_blocks = 1; constexpr auto policy_rank = policy.axes.rank(); diff --git a/ntt/src/cpu_runtime.cpp b/ntt/src/cpu_runtime.cpp index 90409b5d50..e57f41b008 100644 --- a/ntt/src/cpu_runtime.cpp +++ b/ntt/src/cpu_runtime.cpp @@ -38,7 +38,13 @@ using namespace nncase::ntt::runtime; decltype(nncase::ntt::make_tensor>( nncase::ntt::distributed::topology_shape)) - nncase::ntt::distributed::detail::global_local_data_ptr = + nncase::ntt::distributed::detail::global_thread_local_data_ptr = + nncase::ntt::make_tensor>( + nncase::ntt::distributed::topology_shape); + +decltype(nncase::ntt::make_tensor>( + nncase::ntt::distributed::topology_shape)) + nncase::ntt::distributed::detail::global_block_local_data_ptr = nncase::ntt::make_tensor>( nncase::ntt::distributed::topology_shape); @@ -194,11 +200,16 @@ extern "C" void block_entry(const cpu_block_entry_params_t ¶ms) { } } - ntt::distributed::detail::global_local_data_ptr(program_ids)( + ntt::distributed::detail::global_thread_local_data_ptr(program_ids)( 0_dim) = (uintptr_t)thread_local_data.data(); - ntt::distributed::detail::global_local_data_ptr(program_ids)( + ntt::distributed::detail::global_thread_local_data_ptr(program_ids)( 1_dim) = (uintptr_t)(thread_local_data.data() + thread_local_data.size_bytes()); + ntt::distributed::detail::global_block_local_data_ptr(program_ids)( + 0_dim) = (uintptr_t)block_local_data.data(); + ntt::distributed::detail::global_block_local_data_ptr(program_ids)( + 1_dim) = (uintptr_t)(block_local_data.data() + + block_local_data.size_bytes()); 
ntt::distributed::detail::global_block_local_rdata_ptr(program_ids)( 0_dim) = (uintptr_t)params.block_local_rdata.data(); ntt::distributed::detail::global_block_local_rdata_ptr(program_ids)( diff --git a/ntt/test/ctest/test_ntt_reshard.cpp b/ntt/test/ctest/test_ntt_reshard.cpp index c0d201b5ad..b3c08119d4 100644 --- a/ntt/test/ctest/test_ntt_reshard.cpp +++ b/ntt/test/ctest/test_ntt_reshard.cpp @@ -51,7 +51,7 @@ using namespace ortki; decltype(nncase::ntt::make_tensor>( nncase::ntt::distributed::topology_shape)) - nncase::ntt::distributed::detail::global_local_data_ptr = + nncase::ntt::distributed::detail::global_thread_local_data_ptr = nncase::ntt::make_tensor>( nncase::ntt::distributed::topology_shape); @@ -145,13 +145,13 @@ TEST(CpuTest, reshard_2D_same_sharding_spec_broadcast) { constexpr size_t num = cdims * bdims * tdims; ntt::apply( - ntt::distributed::detail::global_local_data_ptr.shape(), + ntt::distributed::detail::global_thread_local_data_ptr.shape(), [&](auto index) { - ntt::distributed::detail::global_local_data_ptr(index)(0_dim) = + ntt::distributed::detail::global_thread_local_data_ptr(index)(0_dim) = (uintptr_t)(ntt::runtime::thread_alloc(M * N * sizeof(float), 8)); - ntt::distributed::detail::global_local_data_ptr(index)(1_dim) = - ntt::distributed::detail::global_local_data_ptr(index)(0_dim) + + ntt::distributed::detail::global_thread_local_data_ptr(index)(1_dim) = + ntt::distributed::detail::global_thread_local_data_ptr(index)(0_dim) + M * N * sizeof(float); }); @@ -199,7 +199,7 @@ TEST(CpuTest, reshard_2D_same_sharding_spec_broadcast) { const auto program_ids = make_shape(cid, bid, tid); float *local_data_src = reinterpret_cast( - ((size_t)ntt::distributed::detail::global_local_data_ptr(program_ids)( + ((size_t)ntt::distributed::detail::global_thread_local_data_ptr(program_ids)( 0_dim))); auto local_data_dst = reinterpret_cast( nncase::ntt::runtime::thread_alloc(M * N * sizeof(float), 8)); @@ -238,10 +238,10 @@ TEST(CpuTest, 
reshard_2D_same_sharding_spec_broadcast) { t.join(); - ntt::apply(ntt::distributed::detail::global_local_data_ptr.shape(), + ntt::apply(ntt::distributed::detail::global_thread_local_data_ptr.shape(), [&](auto index) { thread_free( - (void *)(size_t)ntt::distributed::detail::global_local_data_ptr( + (void *)(size_t)ntt::distributed::detail::global_thread_local_data_ptr( index)(0_dim)); }); } @@ -268,13 +268,13 @@ TEST(CpuTest, reshard_2D_same_sharding_sepc_split) { constexpr size_t num = cdims * bdims * tdims; ntt::apply( - ntt::distributed::detail::global_local_data_ptr.shape(), + ntt::distributed::detail::global_thread_local_data_ptr.shape(), [&](auto index) { - ntt::distributed::detail::global_local_data_ptr(index)(0_dim) = + ntt::distributed::detail::global_thread_local_data_ptr(index)(0_dim) = (uintptr_t)(ntt::runtime::thread_alloc(M * N * sizeof(float), 8)); - ntt::distributed::detail::global_local_data_ptr(index)(1_dim) = - ntt::distributed::detail::global_local_data_ptr(index)(0_dim) + + ntt::distributed::detail::global_thread_local_data_ptr(index)(1_dim) = + ntt::distributed::detail::global_thread_local_data_ptr(index)(0_dim) + M * N * sizeof(float); }); @@ -323,7 +323,7 @@ TEST(CpuTest, reshard_2D_same_sharding_sepc_split) { const auto program_ids = make_shape(cid, bid, tid); float *local_data_src = reinterpret_cast( - ((size_t)ntt::distributed::detail::global_local_data_ptr(program_ids)( + ((size_t)ntt::distributed::detail::global_thread_local_data_ptr(program_ids)( 0_dim))); auto local_data_dst = reinterpret_cast( nncase::ntt::runtime::thread_alloc(M * N * sizeof(float), 8)); @@ -359,10 +359,10 @@ TEST(CpuTest, reshard_2D_same_sharding_sepc_split) { for (auto &t : threads) t.join(); - ntt::apply(ntt::distributed::detail::global_local_data_ptr.shape(), + ntt::apply(ntt::distributed::detail::global_thread_local_data_ptr.shape(), [&](auto index) { thread_free( - (void *)(size_t)ntt::distributed::detail::global_local_data_ptr( + (void 
*)(size_t)ntt::distributed::detail::global_thread_local_data_ptr( index)(0_dim)); }); } @@ -394,13 +394,13 @@ TEST(CpuTest, reshard_2D_different_sharding_spec_broadcast_split) { constexpr size_t num = cdims * bdims * tdims; ntt::apply( - ntt::distributed::detail::global_local_data_ptr.shape(), + ntt::distributed::detail::global_thread_local_data_ptr.shape(), [&](auto index) { - ntt::distributed::detail::global_local_data_ptr(index)(0_dim) = + ntt::distributed::detail::global_thread_local_data_ptr(index)(0_dim) = (uintptr_t)(ntt::runtime::thread_alloc(M * N * sizeof(float), 8)); - ntt::distributed::detail::global_local_data_ptr(index)(1_dim) = - ntt::distributed::detail::global_local_data_ptr(index)(0_dim) + + ntt::distributed::detail::global_thread_local_data_ptr(index)(1_dim) = + ntt::distributed::detail::global_thread_local_data_ptr(index)(0_dim) + M * N * sizeof(float); }); @@ -450,7 +450,7 @@ TEST(CpuTest, reshard_2D_different_sharding_spec_broadcast_split) { const auto program_ids = make_shape(cid, bid, tid); float *local_data_src = reinterpret_cast( - ((size_t)ntt::distributed::detail::global_local_data_ptr(program_ids)( + ((size_t)ntt::distributed::detail::global_thread_local_data_ptr(program_ids)( 0_dim))); auto local_data_dst = reinterpret_cast( nncase::ntt::runtime::thread_alloc(M * N * sizeof(float), 8)); @@ -492,10 +492,10 @@ TEST(CpuTest, reshard_2D_different_sharding_spec_broadcast_split) { for (auto &t : threads) t.join(); - ntt::apply(ntt::distributed::detail::global_local_data_ptr.shape(), + ntt::apply(ntt::distributed::detail::global_thread_local_data_ptr.shape(), [&](auto index) { thread_free( - (void *)(size_t)ntt::distributed::detail::global_local_data_ptr( + (void *)(size_t)ntt::distributed::detail::global_thread_local_data_ptr( index)(0_dim)); }); } @@ -524,13 +524,13 @@ TEST(CpuTest, reshard_2D_different_sharding_spec_split_broadcast) { constexpr size_t num = cdims * bdims * tdims; ntt::apply( - 
ntt::distributed::detail::global_local_data_ptr.shape(), + ntt::distributed::detail::global_thread_local_data_ptr.shape(), [&](auto index) { - ntt::distributed::detail::global_local_data_ptr(index)(0_dim) = + ntt::distributed::detail::global_thread_local_data_ptr(index)(0_dim) = (uintptr_t)(ntt::runtime::thread_alloc(M * N * sizeof(float), 8)); - ntt::distributed::detail::global_local_data_ptr(index)(1_dim) = - ntt::distributed::detail::global_local_data_ptr(index)(0_dim) + + ntt::distributed::detail::global_thread_local_data_ptr(index)(1_dim) = + ntt::distributed::detail::global_thread_local_data_ptr(index)(0_dim) + M * N * sizeof(float); }); @@ -578,7 +578,7 @@ TEST(CpuTest, reshard_2D_different_sharding_spec_split_broadcast) { const auto program_ids = make_shape(cid, bid, tid); float *local_data_src = reinterpret_cast( - ((size_t)ntt::distributed::detail::global_local_data_ptr(program_ids)( + ((size_t)ntt::distributed::detail::global_thread_local_data_ptr(program_ids)( 0_dim))); auto local_data_dst = reinterpret_cast( nncase::ntt::runtime::thread_alloc(M * N * sizeof(float), 8)); @@ -617,10 +617,10 @@ TEST(CpuTest, reshard_2D_different_sharding_spec_split_broadcast) { for (auto &t : threads) t.join(); - ntt::apply(ntt::distributed::detail::global_local_data_ptr.shape(), + ntt::apply(ntt::distributed::detail::global_thread_local_data_ptr.shape(), [&](auto index) { thread_free( - (void *)(size_t)ntt::distributed::detail::global_local_data_ptr( + (void *)(size_t)ntt::distributed::detail::global_thread_local_data_ptr( index)(0_dim)); }); } @@ -659,13 +659,13 @@ TEST(CpuTest, reshard_2D_different_sharding_spec_different_split_axis) { constexpr size_t num = cdims * bdims * tdims; ntt::apply( - ntt::distributed::detail::global_local_data_ptr.shape(), + ntt::distributed::detail::global_thread_local_data_ptr.shape(), [&](auto index) { - ntt::distributed::detail::global_local_data_ptr(index)(0_dim) = + ntt::distributed::detail::global_thread_local_data_ptr(index)(0_dim) = 
(uintptr_t)(ntt::runtime::thread_alloc(M * N * sizeof(float), 8)); - ntt::distributed::detail::global_local_data_ptr(index)(1_dim) = - ntt::distributed::detail::global_local_data_ptr(index)(0_dim) + + ntt::distributed::detail::global_thread_local_data_ptr(index)(1_dim) = + ntt::distributed::detail::global_thread_local_data_ptr(index)(0_dim) + M * N * sizeof(float); }); @@ -716,7 +716,7 @@ TEST(CpuTest, reshard_2D_different_sharding_spec_different_split_axis) { const auto program_ids = make_shape(cid, bid, tid); float *local_data_src = reinterpret_cast( - (size_t)(ntt::distributed::detail::global_local_data_ptr(program_ids)( + (size_t)(ntt::distributed::detail::global_thread_local_data_ptr(program_ids)( 0_dim))); auto local_data_dst = reinterpret_cast( nncase::ntt::runtime::thread_alloc(M * N * sizeof(float), 8)); @@ -759,10 +759,10 @@ TEST(CpuTest, reshard_2D_different_sharding_spec_different_split_axis) { EXPECT_TRUE(NttTest::compare_tensor(ntt_input, ntt_output)); - ntt::apply(ntt::distributed::detail::global_local_data_ptr.shape(), + ntt::apply(ntt::distributed::detail::global_thread_local_data_ptr.shape(), [&](auto index) { thread_free( - (void *)(size_t)ntt::distributed::detail::global_local_data_ptr( + (void *)(size_t)ntt::distributed::detail::global_thread_local_data_ptr( index)(0_dim)); }); } @@ -792,13 +792,13 @@ TEST(CpuTest, reshard_2D_different_sharding_spec_non_divisible_SB2BS) { constexpr size_t num = cdims * bdims * tdims; ntt::apply( - ntt::distributed::detail::global_local_data_ptr.shape(), + ntt::distributed::detail::global_thread_local_data_ptr.shape(), [&](auto index) { - ntt::distributed::detail::global_local_data_ptr(index)(0_dim) = + ntt::distributed::detail::global_thread_local_data_ptr(index)(0_dim) = (uintptr_t)(ntt::runtime::thread_alloc(M * N * sizeof(float), 8)); - ntt::distributed::detail::global_local_data_ptr(index)(1_dim) = - ntt::distributed::detail::global_local_data_ptr(index)(0_dim) + + 
ntt::distributed::detail::global_thread_local_data_ptr(index)(1_dim) = + ntt::distributed::detail::global_thread_local_data_ptr(index)(0_dim) + M * N * sizeof(float); }); @@ -849,7 +849,7 @@ TEST(CpuTest, reshard_2D_different_sharding_spec_non_divisible_SB2BS) { const auto program_ids = make_shape(cid, bid, tid); float *local_data_src = reinterpret_cast( - (size_t)(ntt::distributed::detail::global_local_data_ptr(program_ids)( + (size_t)(ntt::distributed::detail::global_thread_local_data_ptr(program_ids)( 0_dim))); auto local_data_dst = reinterpret_cast( nncase::ntt::runtime::thread_alloc(M * N * sizeof(float), 8)); @@ -892,10 +892,10 @@ TEST(CpuTest, reshard_2D_different_sharding_spec_non_divisible_SB2BS) { EXPECT_TRUE(NttTest::compare_tensor(ntt_input, ntt_output)); - ntt::apply(ntt::distributed::detail::global_local_data_ptr.shape(), + ntt::apply(ntt::distributed::detail::global_thread_local_data_ptr.shape(), [&](auto index) { thread_free( - (void *)(size_t)ntt::distributed::detail::global_local_data_ptr( + (void *)(size_t)ntt::distributed::detail::global_thread_local_data_ptr( index)(0_dim)); }); } @@ -926,13 +926,13 @@ TEST(CpuTest, reshard_2D_different_sharding_spec_non_divisible_BS2SB) { constexpr size_t num = cdims * bdims * tdims; ntt::apply( - ntt::distributed::detail::global_local_data_ptr.shape(), + ntt::distributed::detail::global_thread_local_data_ptr.shape(), [&](auto index) { - ntt::distributed::detail::global_local_data_ptr(index)(0_dim) = + ntt::distributed::detail::global_thread_local_data_ptr(index)(0_dim) = (uintptr_t)(ntt::runtime::thread_alloc(M * N * sizeof(float), 8)); - ntt::distributed::detail::global_local_data_ptr(index)(1_dim) = - ntt::distributed::detail::global_local_data_ptr(index)(0_dim) + + ntt::distributed::detail::global_thread_local_data_ptr(index)(1_dim) = + ntt::distributed::detail::global_thread_local_data_ptr(index)(0_dim) + M * N * sizeof(float); }); @@ -983,7 +983,7 @@ TEST(CpuTest, 
reshard_2D_different_sharding_spec_non_divisible_BS2SB) { const auto program_ids = make_shape(cid, bid, tid); float *local_data_src = reinterpret_cast( - (size_t)(ntt::distributed::detail::global_local_data_ptr(program_ids)( + (size_t)(ntt::distributed::detail::global_thread_local_data_ptr(program_ids)( 0_dim))); auto local_data_dst = reinterpret_cast( nncase::ntt::runtime::thread_alloc(M * N * sizeof(float), 8)); @@ -1026,10 +1026,10 @@ TEST(CpuTest, reshard_2D_different_sharding_spec_non_divisible_BS2SB) { EXPECT_TRUE(NttTest::compare_tensor(ntt_input, ntt_output)); - ntt::apply(ntt::distributed::detail::global_local_data_ptr.shape(), + ntt::apply(ntt::distributed::detail::global_thread_local_data_ptr.shape(), [&](auto index) { thread_free( - (void *)(size_t)ntt::distributed::detail::global_local_data_ptr( + (void *)(size_t)ntt::distributed::detail::global_thread_local_data_ptr( index)(0_dim)); }); } @@ -1059,13 +1059,13 @@ TEST(CpuTest, reshard_2D_split_multilple_axes) { constexpr size_t num = cdims * bdims * tdims; ntt::apply( - ntt::distributed::detail::global_local_data_ptr.shape(), + ntt::distributed::detail::global_thread_local_data_ptr.shape(), [&](auto index) { - ntt::distributed::detail::global_local_data_ptr(index)(0_dim) = + ntt::distributed::detail::global_thread_local_data_ptr(index)(0_dim) = (uintptr_t)(ntt::runtime::thread_alloc(M * N * sizeof(float), 8)); - ntt::distributed::detail::global_local_data_ptr(index)(1_dim) = - ntt::distributed::detail::global_local_data_ptr(index)(0_dim) + + ntt::distributed::detail::global_thread_local_data_ptr(index)(1_dim) = + ntt::distributed::detail::global_thread_local_data_ptr(index)(0_dim) + M * N * sizeof(float); }); @@ -1115,7 +1115,7 @@ TEST(CpuTest, reshard_2D_split_multilple_axes) { const auto program_ids = make_shape(cid, bid, tid); float *local_data_src = reinterpret_cast( - (size_t)(ntt::distributed::detail::global_local_data_ptr(program_ids)( + 
(size_t)(ntt::distributed::detail::global_thread_local_data_ptr(program_ids)( 0_dim))); auto local_data_dst = reinterpret_cast( nncase::ntt::runtime::thread_alloc(M * N * sizeof(float), 8)); @@ -1158,10 +1158,10 @@ TEST(CpuTest, reshard_2D_split_multilple_axes) { EXPECT_TRUE(NttTest::compare_tensor(ntt_input, ntt_output)); - ntt::apply(ntt::distributed::detail::global_local_data_ptr.shape(), + ntt::apply(ntt::distributed::detail::global_thread_local_data_ptr.shape(), [&](auto index) { thread_free( - (void *)(size_t)ntt::distributed::detail::global_local_data_ptr( + (void *)(size_t)ntt::distributed::detail::global_thread_local_data_ptr( index)(0_dim)); }); } @@ -1192,13 +1192,13 @@ TEST(CpuTest, reshard_3D_different_sharding_spec_different_split_axis) { constexpr size_t num = cdims * bdims * tdims; ntt::apply( - ntt::distributed::detail::global_local_data_ptr.shape(), + ntt::distributed::detail::global_thread_local_data_ptr.shape(), [&](auto index) { - ntt::distributed::detail::global_local_data_ptr(index)(0_dim) = + ntt::distributed::detail::global_thread_local_data_ptr(index)(0_dim) = (uintptr_t)(ntt::runtime::thread_alloc(M * N * K * sizeof(float), 8)); - ntt::distributed::detail::global_local_data_ptr(index)(1_dim) = - ntt::distributed::detail::global_local_data_ptr(index)(0_dim) + + ntt::distributed::detail::global_thread_local_data_ptr(index)(1_dim) = + ntt::distributed::detail::global_thread_local_data_ptr(index)(0_dim) + M * N * K *sizeof(float); }); @@ -1251,7 +1251,7 @@ TEST(CpuTest, reshard_3D_different_sharding_spec_different_split_axis) { const auto program_ids = make_shape(cid, bid, tid); float *local_data_src = reinterpret_cast( - (size_t)(ntt::distributed::detail::global_local_data_ptr(program_ids)( + (size_t)(ntt::distributed::detail::global_thread_local_data_ptr(program_ids)( 0_dim))); auto local_data_dst = reinterpret_cast( nncase::ntt::runtime::thread_alloc(M * N * K * sizeof(float), 8)); @@ -1297,10 +1297,10 @@ TEST(CpuTest, 
reshard_3D_different_sharding_spec_different_split_axis) { EXPECT_TRUE(NttTest::compare_tensor(ntt_input, ntt_output)); - ntt::apply(ntt::distributed::detail::global_local_data_ptr.shape(), + ntt::apply(ntt::distributed::detail::global_thread_local_data_ptr.shape(), [&](auto index) { thread_free( - (void *)(size_t)ntt::distributed::detail::global_local_data_ptr( + (void *)(size_t)ntt::distributed::detail::global_thread_local_data_ptr( index)(0_dim)); }); } @@ -1416,4 +1416,4 @@ TEST(CpuTest, reshard_reshape) { int main(int argc, char *argv[]) { ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); -} \ No newline at end of file +} diff --git a/src/Native/src/test.cpp b/src/Native/src/test.cpp index c8298c3ee7..7afd028ab7 100644 --- a/src/Native/src/test.cpp +++ b/src/Native/src/test.cpp @@ -189,6 +189,29 @@ void test_sharding() { ntt::distributed::detail::mesh_axes_of_non_split_shard_policies< sharding_type>() == ntt::fixed_shape_v<0, 1>); } + + // Sharing split with S<2> + { + ntt::dim_t seq_length = 122; + [[maybe_unused]] float *xx = new float[seq_length]; + [[maybe_unused]] auto sp = ntt::distributed::shard_policy::S<0>(16_dim); + [[maybe_unused]] auto sb = ntt::distributed::shard_policy::B; + [[maybe_unused]] auto sharding = ntt::distributed::make_sharding< + ntt::distributed::mesh>( + ntt::distributed::shard_policy::S<0>(16_dim), + ntt::distributed::shard_policy::B); + [[maybe_unused]] auto buffer_0 = + ntt::distributed::make_sharded_tensor_view( + span_cast(make_subspan( + std::span((std::byte *)xx + 4096UL, 4096), + 0_dim, 4096_dim)), + ntt::make_shape(seq_length, 64_dim), + ntt::distributed::make_sharding>( + ntt::distributed::shard_policy::S<0>(16_dim), + ntt::distributed::shard_policy::B), + ntt::make_strides(64_dim, 1_dim)); + } } void test_matmul_normal() { @@ -633,8 +656,9 @@ void test_caching() { constexpr auto head_dim_policy = paged_config_t::axis_policy< ntt::caching::paged_kvcache_dim_kind::num_kv_heads>(); static_assert( - 
std::is_same_v, - ntt::distributed::shard_policy::split<0>>, + std::is_same_v< + std::remove_cv_t, + ntt::distributed::shard_policy::split>, "find failed!"); auto context_lens = ntt::make_tensor(ntt::make_shape(1)); diff --git a/src/Nncase.Core/DistributedType.cs b/src/Nncase.Core/DistributedType.cs index ccb7bb6bb4..93d7d75193 100644 --- a/src/Nncase.Core/DistributedType.cs +++ b/src/Nncase.Core/DistributedType.cs @@ -26,21 +26,19 @@ public abstract record SBP { public static SBPBroadCast B => SBPBroadCast.Instance; - public static SBPPartial P(ReduceOp op = ReduceOp.Sum) => new SBPPartial(op); + public static SBPPartial P(IRArray axes, ReduceOp op = ReduceOp.Sum) => new SBPPartial(axes, op); - public static SBPSplit S(IRArray axes) => new SBPSplit(axes); - - public static SBPSplit S(params int[] axes) => new SBPSplit(axes); + public static SBPSplit S(IRArray axes, Dimension? granularity = null) => new SBPSplit(axes, granularity); } -public sealed record SBPSplit(IRArray Axes) : SBP +public sealed record SBPSplit(IRArray Axes, Dimension? 
Granularity = null) : SBP { - public override string ToString() => $"S({string.Join(",", Axes)})"; + public override string ToString() => $"S([{string.Join(",", Axes)}], {Granularity})"; } -public sealed record SBPPartial(ReduceOp Op) : SBP +public sealed record SBPPartial(IRArray Axes, ReduceOp Op) : SBP { - public override string ToString() => $"P({Op})"; + public override string ToString() => $"P([{string.Join(",", Axes)}], {Op})"; } public sealed record SBPBroadCast : SBP @@ -87,6 +85,10 @@ public override SBP Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSeri { sbpSplit = new SBPSplit(irAxes); } + else if (typeDiscriminator == "P") + { + sbpPartial = new SBPPartial(irAxes, ReduceOp.Sum); + } else { throw new InvalidDataException("Axes must be used in SBP split"); @@ -95,7 +97,7 @@ public override SBP Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSeri break; case "Op": ReduceOp partialOp = JsonSerializer.Deserialize(ref reader, options); - sbpPartial = new SBPPartial(partialOp); + sbpPartial = new SBPPartial(sbpPartial!.Axes, partialOp); break; default: reader.Skip(); @@ -157,7 +159,7 @@ public sealed record Placement(IRArray Hierarchy, string Name, HierarchyKin public override string ToString() => $"[{string.Join(',', Hierarchy.Zip(Name).Select(t => t.Second.ToString() + ':' + t.First.ToString()))}]"; } -public sealed record DistributedType(TensorType TensorType, IRArray AxisPolicies, Placement Placement, bool Partial = false) : IRType +public sealed record DistributedType(TensorType TensorType, IRArray AxisPolicies, Placement Placement, SBPPartial? 
Partial = null) : IRType { public override string ToString() => $"{TensorType}, ({string.Join(',', AxisPolicies)}), {Placement}, Partial: {Partial}"; } diff --git a/src/Nncase.Core/Utilities/DistributedUtility.cs b/src/Nncase.Core/Utilities/DistributedUtility.cs index fe77372334..3b18b2d8cf 100644 --- a/src/Nncase.Core/Utilities/DistributedUtility.cs +++ b/src/Nncase.Core/Utilities/DistributedUtility.cs @@ -22,7 +22,7 @@ static DistributedUtility() } } - public delegate bool DivideByDelegate(long input, int divisor); + public delegate bool DivideByDelegate(long input, int divisor, bool isFixed); [Flags] public enum DivideFlags @@ -73,9 +73,9 @@ public static IReadOnlyList> GetLeafCandidatePolicies(TensorType te { var axis = splitsAxes[ti]; var divisor = axis.Select(a => placement.Hierarchy[a]).Aggregate(1, (a, b) => a * b); - if (axis.All(a => placement.Hierarchy[a] > 1) && divisor > 1 && DivideByFunc(maxShape[di], divisor)) + if (axis.All(a => placement.Hierarchy[a] > 1) && divisor > 1 && DivideByFunc(maxShape[di], divisor, tensorType.Shape[di].IsFixed)) { - policy.Add(SBP.S(axis.ToArray())); + policy.Add(SBP.S(axis.ToArray(), (int)MathUtility.CeilDiv(maxShape[di], divisor))); } } @@ -142,7 +142,7 @@ public static bool IsDistributable(TensorType tensorType, ReadOnlySpan poli // 2. All shapes are divisible by the mesh. var maxShape = CompilerServices.GetMaxShape(tensorType.Shape); var divisors = GetDivisors(new DistributedType(tensorType, polices.ToArray(), placement)); - return divisors.Select((d, axis) => (d, axis)).All(p => p.d == 0 ? true : DivideByFunc(maxShape[p.axis], p.d)); + return divisors.Select((d, axis) => (d, axis)).All(p => p.d == 0 ? 
true : DivideByFunc(maxShape[p.axis], p.d, tensorType.Shape[p.axis].IsFixed)); } public static bool IsDistributable(ReadOnlySpan polices) @@ -213,7 +213,7 @@ public static bool TryGetDividedTensorType(DistributedType distributedType, [May public static IRArray AxisPolicesToNDSBP(IRArray axisPolices, int rank) { - var ndsbp = new SBP[rank]; + var ndsbp = Enumerable.Repeat(SBP.B, rank).Select(p => (SBP)p).ToArray(); for (var i = 0; i < axisPolices.Count; i++) { var policy = axisPolices[i]; @@ -221,27 +221,36 @@ public static IRArray AxisPolicesToNDSBP(IRArray axisPolices, int rank { foreach (var ax in split.Axes) { - ndsbp[ax] = SBP.S([i]); + ndsbp[ax] = SBP.S([i], split.Granularity); + } + } + else if (policy is SBPPartial partial) + { + foreach (var ax in partial.Axes) + { + ndsbp[ax] = SBP.P(ndsbp[ax] is SBPPartial p ? p.Axes.Append(i).ToArray() : [i], partial.Op); } } } - return ndsbp.Select(sbp => sbp is SBPSplit ? sbp : SBP.B).ToArray(); + return ndsbp; } public static IRArray NDSBPToAxisPolices(IRArray ndsbp, int rank) { - var polices = new SBP[rank]; + var polices = Enumerable.Repeat(SBP.B, rank).Select(p => (SBP)p).ToArray(); for (int d = 0; d < polices.Length; d++) { var splitAxes = Enumerable.Range(0, ndsbp.Count).Where(i => ndsbp[i] is SBPSplit split && split.Axes[0] == d).ToArray(); + var partialAxes = Enumerable.Range(0, ndsbp.Count).Where(i => ndsbp[i] is SBPSplit partial && partial.Axes.Contains(d)).ToArray(); if (splitAxes.Any()) { - polices[d] = SBP.S(splitAxes); + polices[d] = SBP.S(splitAxes, ((SBPSplit)ndsbp[splitAxes[0]]).Granularity); } - else + + if (partialAxes.Any()) { - polices[d] = SBP.B; + polices[d] = SBP.P(partialAxes, ((SBPPartial)ndsbp[partialAxes[0]]).Op); } } @@ -293,9 +302,9 @@ from item in array return ret.ToList(); } - public static bool IsDivideBy(long input, int divisor) + public static bool IsDivideBy(long input, int divisor, bool isFixed) { - if (input >= divisor) + if (!isFixed || input >= divisor) { return true; } @@ -303,9 
+312,9 @@ public static bool IsDivideBy(long input, int divisor) return false; } - public static bool IsDivideExactly(long input, int divisor) + public static bool IsDivideExactly(long input, int divisor, bool isFixed = true) { - if (input >= divisor && input % divisor == 0) + if (!isFixed || (input >= divisor && input % divisor == 0)) { return true; } @@ -313,6 +322,53 @@ public static bool IsDivideExactly(long input, int divisor) return false; } + public static bool AreSamePolicies(IRArray? a, IRArray? b, bool checkGranularity = true) + { + if (a == null && b == null) + { + return true; + } + + if (a == null || b == null || a.Value.Count != b.Value.Count) + { + return false; + } + + for (int i = 0; i < a.Value.Count; i++) + { + if (!IsSamePolicy(a.Value[i], b.Value[i], checkGranularity)) + { + return false; + } + } + + return true; + } + + public static bool IsSamePolicy(SBP a, SBP b, bool checkGranularity = true) + { + if (a == null || b == null) + { + return false; + } + + if (a is SBPSplit splitA && b is SBPSplit splitB) + { + if (checkGranularity) + { + return a == b; + } + else + { + return splitA.Axes == splitB.Axes; + } + } + else + { + return a == b; + } + } + public static float GetDividedTensorEfficiency(DistributedType distributedType, int burstLength) { var (tiles, shape) = GetDividedTile(distributedType); @@ -358,7 +414,8 @@ public static (long[] Offset, long[] Shape) GetLocalOffsetAndShape(DistributedTy var shape = new long[distributedType.TensorType.Shape.Rank]; for (int axis = 0; axis < offset.Length; axis++) { - var splits = distributedType.AxisPolicies[axis] is SBPSplit s + var policy = distributedType.AxisPolicies[axis]; + var splits = policy is SBPSplit s ? 
s.Axes.Select(td => (Placement: td, DeviceIndex: shardIndex[td], DeviceDim: distributedType.Placement.Hierarchy[td])).ToArray() : Array.Empty<(int Placement, int DeviceIndex, int DeviceDim)>(); if (splits.Any()) @@ -368,7 +425,7 @@ public static (long[] Offset, long[] Shape) GetLocalOffsetAndShape(DistributedTy var subHierarchySize = (int)TensorUtilities.GetProduct(subHierarchies); var subShardIndex = splits.Select(x => x.DeviceIndex).ToArray(); var linearIndex = TensorUtilities.GetLinearOffset(subHierarchyStrides, subShardIndex); - var localDim = MathUtility.CeilDiv(globalShape[axis], subHierarchySize); + var localDim = ((SBPSplit)policy).Granularity is not null ? (long)((SBPSplit)policy).Granularity!.Metadata.Range!.Value.Max : MathUtility.CeilDiv(globalShape[axis], subHierarchySize); offset[axis] = linearIndex * localDim; shape[axis] = Math.Min(localDim, globalShape[axis] - offset[axis]); } @@ -390,14 +447,21 @@ private static (RankedShape Tile, RankedShape Shape) GetDividedTile(DistributedT { if (distributedType.AxisPolicies.Count > d && distributedType.AxisPolicies[d] is SBPSplit split) { - var divisor = split.Axes.Select(t => distributedType.Placement.Hierarchy[t]).Aggregate(1, (a, b) => a * b); - if (divideFlags.HasFlag(DivideFlags.FloorDiv)) + if (split.Granularity is not null) { - tiles[d] = tiles[d] / divisor; + tiles[d] = split.Granularity; } else { - tiles[d] = Dimension.CeilDiv(tiles[d], divisor); + var divisor = split.Axes.Select(t => distributedType.Placement.Hierarchy[t]).Aggregate(1, (a, b) => a * b); + if (divideFlags.HasFlag(DivideFlags.FloorDiv)) + { + tiles[d] = tiles[d] / divisor; + } + else + { + tiles[d] = Dimension.CeilDiv(tiles[d], divisor); + } } } } diff --git a/src/Nncase.Diagnostics/Diagnostics/ILPrintVisitor.cs b/src/Nncase.Diagnostics/Diagnostics/ILPrintVisitor.cs index a40ec71c2d..5b05aa790d 100644 --- a/src/Nncase.Diagnostics/Diagnostics/ILPrintVisitor.cs +++ b/src/Nncase.Diagnostics/Diagnostics/ILPrintVisitor.cs @@ -154,9 +154,16 
@@ public override string VisitType(DistributedType type) { if (s is SBPSplit split) { - var divisor = split.Axes.Select(a => type.Placement.Hierarchy[a]).Aggregate(1, (a, b) => a * b); - usedCeil[r] = shape[r] % divisor != 0; - shape[r] = (shape[r] + divisor - 1) / divisor; + if (split.Granularity is null) + { + var divisor = split.Axes.Select(a => type.Placement.Hierarchy[a]).Aggregate(1, (a, b) => a * b); + usedCeil[r] = shape[r] % divisor != 0; + shape[r] = (shape[r] + divisor - 1) / divisor; + } + else + { + shape[r] = split.Granularity.FixedValue; + } } } @@ -169,7 +176,7 @@ public override string VisitType(DistributedType type) } } - return $"{{{VisitType(type.TensorType)}, ({string.Join(',', type.AxisPolicies)}), [{string.Join(',', sshape)}]}}"; + return $"{{{VisitType(type.TensorType)}, ({string.Join(',', type.AxisPolicies)}), [{string.Join(',', sshape)}], {type.Partial?.ToString() ?? string.Empty}}}"; } protected override string DispatchVisit(BaseExpr expr) diff --git a/src/Nncase.Diagnostics/Diagnostics/ScriptPrintVisitor.cs b/src/Nncase.Diagnostics/Diagnostics/ScriptPrintVisitor.cs index aa44efab09..af4d375a07 100644 --- a/src/Nncase.Diagnostics/Diagnostics/ScriptPrintVisitor.cs +++ b/src/Nncase.Diagnostics/Diagnostics/ScriptPrintVisitor.cs @@ -144,9 +144,16 @@ public override string VisitType(DistributedType type) { if (s is SBPSplit split) { - var divisor = split.Axes.Select(a => type.Placement.Hierarchy[a]).Aggregate(1, (a, b) => a * b); - usedCeil[r] = shape[r] % divisor != 0; - shape[r] = (shape[r] + divisor - 1) / divisor; + if (split.Granularity is null) + { + var divisor = split.Axes.Select(a => type.Placement.Hierarchy[a]).Aggregate(1, (a, b) => a * b); + usedCeil[r] = shape[r] % divisor != 0; + shape[r] = (shape[r] + divisor - 1) / divisor; + } + else + { + shape[r] = split.Granularity.FixedValue; + } } } @@ -159,7 +166,7 @@ public override string VisitType(DistributedType type) } } - return $"Dist({VisitType(type.TensorType)}, 
({string.Join(',', type.AxisPolicies)}), [{string.Join(',', sshape)}])"; + return $"Dist({VisitType(type.TensorType)}, ({string.Join(',', type.AxisPolicies)}), [{string.Join(',', sshape)}], {type.Partial?.ToString() ?? string.Empty})"; } /// diff --git a/src/Nncase.Evaluator/Math/MatMul.cs b/src/Nncase.Evaluator/Math/MatMul.cs index 99133c33cf..930b64158d 100644 --- a/src/Nncase.Evaluator/Math/MatMul.cs +++ b/src/Nncase.Evaluator/Math/MatMul.cs @@ -45,7 +45,7 @@ public static IRType VisitDistributedType(DistributedType a, DistributedType b, var (lm, lk, rk, rn) = dimInfo ?? new(aRank - 2, aRank - 1, bRank - 2, bRank - 1); var aMaxShape = CompilerServices.GetMaxShape(a.TensorType.Shape); var bMaxShape = CompilerServices.GetMaxShape(b.TensorType.Shape); - bool isPartial = false; + SBPPartial? partial = null; // TODO: keep summa only if (!a.TensorType.Shape.IsFixed || !b.TensorType.Shape.IsFixed || transB || (a.Placement.HierarchyKind == HierarchyKind.SMT && a.TensorType.DType is VectorType vt && vt.ElemType == DataTypes.Float8E4M3)) @@ -106,11 +106,11 @@ public static IRType VisitDistributedType(DistributedType a, DistributedType b, ndsbp[oRank - 2] = a.AxisPolicies[lm]; ndsbp[oRank - 1] = b.AxisPolicies[rn]; - if (a.AxisPolicies[lk] is SBPSplit || b.AxisPolicies[rk] is SBPSplit) + if (a.AxisPolicies[lk] is SBPSplit sk && b.AxisPolicies[rk] is SBPSplit) { ndsbp[oRank - 2] = ndsbp[oRank - 2]; ndsbp[oRank - 1] = ndsbp[oRank - 1]; - isPartial = true; + partial = SBP.P(sk.Axes); } if (!DistributedUtility.IsDistributable(outType, ndsbp, a.Placement)) @@ -118,7 +118,7 @@ public static IRType VisitDistributedType(DistributedType a, DistributedType b, return new InvalidType("no valid sbp."); } - return new DistributedType(outType, ndsbp, a.Placement, Partial: isPartial); + return new DistributedType(outType, ndsbp, a.Placement, Partial: partial); } else { @@ -149,7 +149,6 @@ public static IRType VisitDistributedType(DistributedType a, DistributedType b, { ndsbp[oRank - 2] = 
a.AxisPolicies[lm]; ndsbp[oRank - 1] = b.AxisPolicies[rn]; - isPartial = true; } else { @@ -211,7 +210,7 @@ public static IRType VisitDistributedType(DistributedType a, DistributedType b, if (DistributedUtility.IsDistributable(ndsbp)) { - return new DistributedType(outType, ndsbp, a.Placement, isPartial); + return new DistributedType(outType, ndsbp, a.Placement, partial); } return new InvalidType("no valid sbp."); diff --git a/src/Nncase.Evaluator/NN/Conv2D.cs b/src/Nncase.Evaluator/NN/Conv2D.cs index 6309e05883..ba73aaf5a8 100644 --- a/src/Nncase.Evaluator/NN/Conv2D.cs +++ b/src/Nncase.Evaluator/NN/Conv2D.cs @@ -140,7 +140,7 @@ private IRType Visit(ITypeInferenceContext context, Conv2D target, DistributedTy { if (ndsbpBias[i] is SBPBroadCast) { - ndsbp[i] = SBP.P(); + ndsbp[i] = SBP.P([1]); } else { diff --git a/src/Nncase.Evaluator/Tensors/Bitcast.cs b/src/Nncase.Evaluator/Tensors/Bitcast.cs index dc10fcb47a..09b71cf693 100644 --- a/src/Nncase.Evaluator/Tensors/Bitcast.cs +++ b/src/Nncase.Evaluator/Tensors/Bitcast.cs @@ -89,7 +89,14 @@ private IRType Visit(Bitcast target, DistributedType input) return invalid; } - ndsbp[i] = input.AxisPolicies[i]; + if (input.AxisPolicies[i] is SBPSplit split) + { + ndsbp[i] = SBP.S(split.Axes, split.Granularity is null ? 
null : split.Granularity * outTensorType.Shape[i] / input.TensorType.Shape[i]); + } + else + { + ndsbp[i] = input.AxisPolicies[i]; + } } return new DistributedType(outTensorType, ndsbp, input.Placement); diff --git a/src/Nncase.Evaluator/Tensors/Cast.cs b/src/Nncase.Evaluator/Tensors/Cast.cs index cba956c88e..f8ff0857bf 100644 --- a/src/Nncase.Evaluator/Tensors/Cast.cs +++ b/src/Nncase.Evaluator/Tensors/Cast.cs @@ -73,7 +73,7 @@ private IRType Visit(Cast target, DistributedType inType) { var invalid = new InvalidType(inType.ToString()); var outType = Visit(target, inType.TensorType); - var ndsbp = new SBP[inType.TensorType.Shape.Rank]; + var ndsbp = inType.AxisPolicies.ToArray(); var shape = CompilerServices.GetMaxShape(inType.TensorType.Shape); for (int i = 0; i < ndsbp.Length; i++) { @@ -92,10 +92,13 @@ private IRType Visit(Cast target, DistributedType inType) { return invalid; } + else + { + var scale = 1f * outShape[i] / shape[i]; + ndsbp[i] = SBP.S(split.Axes, split.Granularity is not null ? (scale >= 1 ? 
split.Granularity * (long)scale : split.Granularity / (long)(1f / scale)) : null); + } } } - - ndsbp[i] = inType.AxisPolicies[i]; } return new DistributedType((TensorType)outType, ndsbp, inType.Placement); diff --git a/src/Nncase.Evaluator/Tensors/Gather.cs b/src/Nncase.Evaluator/Tensors/Gather.cs index cf56e1a562..bc32aeb5ae 100644 --- a/src/Nncase.Evaluator/Tensors/Gather.cs +++ b/src/Nncase.Evaluator/Tensors/Gather.cs @@ -3,6 +3,7 @@ using System; using System.Linq; +using DryIoc.ImTools; using NetFabric.Hyperlinq; using Nncase.CostModel; using Nncase.IR; @@ -49,16 +50,10 @@ public Cost Visit(ICostEvaluateContext context, Gather target) var indexType = context.GetArgumentType(target, Gather.Index); var retType = context.GetReturnType(); - var gatherPart = 1U; - if (inputType is DistributedType d && d.AxisPolicies[target.Axis] is SBPSplit split) - { - gatherPart = split.Axes.Select(a => d.Placement.Hierarchy[a]).Aggregate(1U, (a, b) => (uint)(a * b)); - } - return new() { [CostFactorNames.MemoryLoad] = CostUtility.GetMemoryAccess(inputType) + CostUtility.GetMemoryAccess(indexType), - [CostFactorNames.MemoryStore] = CostUtility.GetMemoryAccess(retType) * gatherPart, + [CostFactorNames.MemoryStore] = CostUtility.GetMemoryAccess(retType), }; } @@ -112,12 +107,12 @@ private IRType Visit(DistributedType input, int axis, DistributedType index) return invalid; } - // not support partial + // not support partial in ndsbp if (ndsbp.Any(sbp => sbp is SBPPartial)) { return invalid; } - return new DistributedType(tensorType, ndsbp, input.Placement); + return new DistributedType(tensorType, ndsbp, input.Placement, input.AxisPolicies[axis] is SBPSplit split ? 
SBP.P(split.Axes) : null); } } diff --git a/src/Nncase.Evaluator/Tensors/Pack.cs b/src/Nncase.Evaluator/Tensors/Pack.cs index 188875b895..7a61418264 100644 --- a/src/Nncase.Evaluator/Tensors/Pack.cs +++ b/src/Nncase.Evaluator/Tensors/Pack.cs @@ -109,12 +109,12 @@ private IRType Visit(ITypeInferenceContext context, Pack target, DistributedType var ndsbp = new SBP[input.TensorType.Shape.Rank]; for (int i = 0; i < input.TensorType.Shape.Rank; i++) { - if (input.AxisPolicies[i] is SBPSplit && target.Axes.Contains(i)) + if (input.AxisPolicies[i] is SBPSplit split && target.Axes.Contains(i)) { var lane = target.Lanes[target.Axes.IndexOf(i)]; if (input.TensorType.Shape[i] is { IsFixed: true, FixedValue: long s } && s / lane % divisor[i] == 0) { - ndsbp[i] = input.AxisPolicies[i]; + ndsbp[i] = SBP.S(split.Axes, split.Granularity is not null ? split.Granularity / lane : null); } else { diff --git a/src/Nncase.Evaluator/Tensors/Reshape.cs b/src/Nncase.Evaluator/Tensors/Reshape.cs index 2f40baa687..82986f06c5 100644 --- a/src/Nncase.Evaluator/Tensors/Reshape.cs +++ b/src/Nncase.Evaluator/Tensors/Reshape.cs @@ -60,9 +60,16 @@ public static IRType VisitDistributedType(DistributedType inType, RankedShape ne return invalidType; } + var granularity = split.Granularity; + if (newDims[newSplitAxis] != inShape[inAxis]) + { + // If the new split axis is not the same as the old split axis, we need to adjust the granularity. + granularity = granularity is null ? null : granularity / newAxes.Except([newAxesOffset + newSplitAxis]).Aggregate((Dimension)1, (a, b) => a * maxNewShape[b]); + } + foreach (var newAxis in newAxes) { - newAxisPolicies[newAxis] = newAxis == (newAxesOffset + newSplitAxis) ? split : SBP.B; + newAxisPolicies[newAxis] = newAxis == (newAxesOffset + newSplitAxis) ? 
SBP.S(split.Axes, granularity) : SBP.B; } } else @@ -101,7 +108,9 @@ where inPolicy is SBPSplit return invalidType; } - newAxisPolicies[newAxis] = inType.AxisPolicies[firstSplitAxis.Value]; + var split = (SBPSplit)inType.AxisPolicies[firstSplitAxis.Value]; + newAxisPolicies[newAxis] = split.Granularity is null ? split : + SBP.S(split.Axes, split.Granularity! * inAxes.Except([firstSplitAxis.Value]).Aggregate((Dimension)1, (a, b) => a * maxInShape[b])); } else { diff --git a/src/Nncase.Evaluator/Tensors/Unpack.cs b/src/Nncase.Evaluator/Tensors/Unpack.cs index 25723275ed..f8e0615869 100644 --- a/src/Nncase.Evaluator/Tensors/Unpack.cs +++ b/src/Nncase.Evaluator/Tensors/Unpack.cs @@ -104,6 +104,7 @@ private IRType Visit(ITypeInferenceContext context, Unpack target, DistributedTy } // [m]<8>@8@4 -> [m*8]@8@4, when max(m)=256 and runtime m=12, input and output have different local shape. + var newPolicies = input.AxisPolicies.ToArray(); foreach (var (s, r) in input.AxisPolicies.Select((s, r) => (s, r))) { if (s is SBPSplit split && target.Axes.Contains(r)) @@ -114,9 +115,11 @@ private IRType Visit(ITypeInferenceContext context, Unpack target, DistributedTy { return new InvalidType("Not support non-divisible input"); } + + newPolicies[r] = SBP.S(split.Axes, split.Granularity is not null ? 
split.Granularity * target.Lanes[target.Axes.IndexOf(r)] : null); } } - return new DistributedType(tensorType, input.AxisPolicies, input.Placement); + return new DistributedType(tensorType, newPolicies, input.Placement); } } diff --git a/src/Nncase.Passes/Distributed/AutoDistributed.cs b/src/Nncase.Passes/Distributed/AutoDistributed.cs index 04549c4fb1..71c609dc58 100644 --- a/src/Nncase.Passes/Distributed/AutoDistributed.cs +++ b/src/Nncase.Passes/Distributed/AutoDistributed.cs @@ -321,7 +321,7 @@ public void FilterByScheme(BaseExpr expr, DistributedSearchGraph cluster) { bool Matched(SearchableNode node, (IRArray Policies, Placement Placement) tp) { - return node.IRType is DistributedType dtype && dtype.AxisPolicies == tp.Policies && dtype.Placement == tp.Placement; + return node.IRType is DistributedType dtype && DistributedUtility.AreSamePolicies(dtype.AxisPolicies, tp.Policies, false) && dtype.Placement == tp.Placement; } foreach (var name in expr.Metadata.OutputNames ?? Array.Empty()) @@ -602,6 +602,11 @@ string DescribeSbp(IRType? type) } } + if (expr.Target is not Boxing && ((Call)newExpr).Arguments.AsValueEnumerable().Any(a => a.CheckedType is DistributedType dt && dt.Partial is not null)) + { + continue; + } + if (!newExpr.InferenceType(_inferencer_cache) || newExpr.CheckedType is InvalidType) { continue; @@ -693,7 +698,9 @@ string DescribeSbp(IRType? 
type) || expr.Users.Any(u => u is Call call && (call.Target.GetType().FullName!.Contains("CustomNTT", StringComparison.Ordinal) || (TargetOptions.HierarchyKind == HierarchyKind.SMT && expr.Target is PagedAttention))) || expr.Target.GetType().FullName!.Contains("CustomNTT", StringComparison.Ordinal) || expr.Target.GetType().FullName!.Contains("VectorizedRoPE", StringComparison.Ordinal) - || (TargetOptions.HierarchyKind == HierarchyKind.SMT && expr.Target is PagedAttention)) + || expr.Target.GetType().FullName!.Contains("Matmul", StringComparison.InvariantCultureIgnoreCase) + || (TargetOptions.HierarchyKind == HierarchyKind.SMT && expr.Target is PagedAttention) + || expr.Target is Gather) { bucket = callCluster.CreateCluster(SearchGraphKind.Bucket); var linked = false; @@ -1023,14 +1030,54 @@ private IRType CheckBoxingType(IRType inType, IRType outType, bool isReshape = f { IRType VisitD2D(DistributedType inv, DistributedType outv) { - if (inv == outv) + if (inv.Partial == outv.Partial && DistributedUtility.AreSamePolicies(inv.AxisPolicies, outv.AxisPolicies)) { return new InvalidType("Same DistributedType"); } - if (inv.AxisPolicies.Any(s => s is SBPPartial) || outv.AxisPolicies.Any(s => s is SBPPartial)) + if (inv.AxisPolicies.Any(sbp => sbp is SBPPartial) || outv.AxisPolicies.Any(sbp => sbp is SBPPartial)) + { + return new InvalidType("Not Support Partial in Policeis."); + } + + var partialDims = new List(); + if (inv.Partial is not null) + { + for (int i = 0; i < inv.AxisPolicies.Count; i++) + { + if (inv.AxisPolicies[i] is SBPSplit && outv.AxisPolicies[i] is SBPBroadCast) + { + return new InvalidType("Not supported input is BroadCast output is Split"); + } + + if (outv.AxisPolicies[i] is SBPSplit s) + { + if (inv.AxisPolicies[i] is SBPSplit splitIn) + { + if (splitIn.Axes.Except(s.Axes).Any()) + { + return new InvalidType("Not Supported Split-> Split."); + } + } + + if (s.Axes.Except(inv.Partial.Axes).ToArray() != s.Axes) + { + partialDims.Add(i); + } + } + } + 
+ var ndspsIn = DistributedUtility.AxisPolicesToNDSBP(inv.AxisPolicies, inv.Placement.Rank); + var ndspsOut = DistributedUtility.AxisPolicesToNDSBP(outv.AxisPolicies, outv.Placement.Rank); + if (Enumerable.Range(0, ndspsIn.Count).Any(i => ndspsIn[i] is SBPSplit si && (ndspsOut[i] is SBPBroadCast || (ndspsOut[i] is SBPSplit so && so.Axes[0] != si.Axes[0])))) + { + return new InvalidType("Not Supported Split-> Broadcast."); + } + } + + if (partialDims.Count > 0 && !Enumerable.Range(0, inv.AxisPolicies.Count).Except(partialDims.ToArray()).All(i => DistributedUtility.IsSamePolicy(inv.AxisPolicies[i], outv.AxisPolicies[i]))) { - return new InvalidType("Not supported input/output is Partial"); + return new InvalidType("Not Supported Partial."); } return outv; @@ -1038,7 +1085,7 @@ IRType VisitD2D(DistributedType inv, DistributedType outv) IRType VisitD2T(DistributedType inv, TensorType outv) { - if (inv.AxisPolicies.Any(s => s is SBPPartial)) + if (inv.AxisPolicies.Any(s => s is SBPPartial) || inv.Partial is not null) { return new InvalidType("Not supported input is Partial output is Unshard"); } @@ -1048,7 +1095,7 @@ IRType VisitD2T(DistributedType inv, TensorType outv) IRType VisitT2D(TensorType inv, DistributedType outv) { - if (outv.AxisPolicies.Any(s => s is SBPPartial)) + if (outv.AxisPolicies.Any(s => s is SBPPartial) || outv.Partial is not null) { return new InvalidType("Not supported input is Unshard output is Partial"); } diff --git a/src/Nncase.Passes/Rules/Distributed/UpdateBoxingTensorType.cs b/src/Nncase.Passes/Rules/Distributed/UpdateBoxingTensorType.cs new file mode 100644 index 0000000000..0e7892bf99 --- /dev/null +++ b/src/Nncase.Passes/Rules/Distributed/UpdateBoxingTensorType.cs @@ -0,0 +1,42 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. 
+ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using Nncase.IR; +using Nncase.IR.Distributed; +using Nncase.PatternMatch; +using static Nncase.IR.F.NN; + +using static Nncase.IR.TypePatternUtility; +using static Nncase.PatternMatch.F.Distributed; +using static Nncase.PatternMatch.Utility; + +namespace Nncase.Passes.Rules; + +[RuleGenerator] +public partial class UpdateBoxingTensorType : RewriteRule +{ + /// + public override Pattern Pattern { get; } = IsBoxing( + target_name: "boxing", + _ => true, + IsWildcard("input")); + + private Expr? GetReplace(Boxing boxing, Expr input, RunPassContext context) + { + if (boxing.NewType is DistributedType dt) + { + var ttype = dt.TensorType; + var dtype = dt with { TensorType = ttype with { Shape = ttype.Shape.Select(d => d.Simplify()).ToArray() } }; + var newBoxing = new Call(new IR.Distributed.Boxing(dtype), input); + context.MatchOptions.SuppressPattern(newBoxing, Pattern); + return newBoxing; + } + + return null; + } +} diff --git a/src/Nncase.Schedule/Transforms/AutoTilePass.cs b/src/Nncase.Schedule/Transforms/AutoTilePass.cs index b6e0ccb6ae..590702a877 100644 --- a/src/Nncase.Schedule/Transforms/AutoTilePass.cs +++ b/src/Nncase.Schedule/Transforms/AutoTilePass.cs @@ -11,6 +11,7 @@ using Nncase.IR; using Nncase.IR.Affine; using Nncase.Passes.GraphPartition; +using Nncase.Passes.Rules; using Nncase.Schedule; using QuikGraph; using QuikGraph.Algorithms; @@ -116,6 +117,7 @@ protected override Task RunCoreAsync(BaseFunction input, RunPassCo var constructor = new AutoTileReconstructor(tiler, ModuleKind, CompileOptions, condenseAlgo, dimVars.ToArray()); var post = constructor.Construct(); + post = CompilerServices.Rewrite(post, [new UpdateBoxingTensorType()], new()); return Task.FromResult((BaseFunction)func.With(body: post)); } } diff --git a/src/Nncase.Tests/Core/UnitTestLayout.cs b/src/Nncase.Tests/Core/UnitTestLayout.cs index 
836908e239..719823168c 100644 --- a/src/Nncase.Tests/Core/UnitTestLayout.cs +++ b/src/Nncase.Tests/Core/UnitTestLayout.cs @@ -55,7 +55,7 @@ public void TestUnflatten() public void TestDistributedTypeLayout() { var placement = new Placement([1, 2, 8, 4, 4], "cdyxt"); - var distType = new DistributedType(new TensorType(DataTypes.Float32, new[] { 2048, 1024 }), new SBP[] { SBP.S(1, 3), SBP.B }, placement); + var distType = new DistributedType(new TensorType(DataTypes.Float32, new[] { 2048, 1024 }), new SBP[] { SBP.S([1, 3]), SBP.B }, placement); var layout = Layout.From(distType.TensorType); Assert.Equal("Layout((2048, 1024):(1024, 1))", layout.ToString()); @@ -88,8 +88,8 @@ public void TestCoorinateWithDifferentLayout() { var placement = new Placement([1, 2, 8, 4, 4], "cdyxt"); var tensorType = new TensorType(DataTypes.Float32, new[] { 2048, 1024 }); - var distTypeA = new DistributedType(tensorType, new SBP[] { SBP.S(1, 3), SBP.B }, placement); - var distTypeB = new DistributedType(tensorType, new SBP[] { SBP.B, SBP.S(2) }, placement); + var distTypeA = new DistributedType(tensorType, new SBP[] { SBP.S([1, 3]), SBP.B }, placement); + var distTypeB = new DistributedType(tensorType, new SBP[] { SBP.B, SBP.S([2]) }, placement); var shardA = Layout.From(distTypeA); var shardB = Layout.From(distTypeB); diff --git a/src/Nncase.Tests/Core/UnitTestTensor.cs b/src/Nncase.Tests/Core/UnitTestTensor.cs index b60aa2bf70..87c5628ca1 100644 --- a/src/Nncase.Tests/Core/UnitTestTensor.cs +++ b/src/Nncase.Tests/Core/UnitTestTensor.cs @@ -456,7 +456,7 @@ public void TestTensorSerialize() }, new[] { 32 }, new[] { IR.NN.PagedKVCacheDimKind.NumBlocks }, - new[] { SBP.S(0) }); + new[] { SBP.S([0]) }); var obj = new Evaluator.NN.RefPagedAttentionKVCache(cfg, 1, 4, Tensor.From([0L]), Tensor.From([4L]), Tensor.From([0L, 1L, 0L, 2L], [1, 2, 2]), Tensor.From([0L, 1L, 0L, 2L, 0L, 3L, 0L, 4L], [4, 2]), 4, Tensor.Zeros>([1, 1, 2, 3, 4, 5, 6])); var original = Tensor.From(new Reference[] { new(obj) 
}, []); using (var stream = File.Create(path)) diff --git a/src/Nncase.Tests/Distributed/UnitTestDistributeSchema.cs b/src/Nncase.Tests/Distributed/UnitTestDistributeSchema.cs index ee2d743377..24b51a6a07 100644 --- a/src/Nncase.Tests/Distributed/UnitTestDistributeSchema.cs +++ b/src/Nncase.Tests/Distributed/UnitTestDistributeSchema.cs @@ -94,6 +94,6 @@ public async Task TestLoadScheme() Dumpper.DumpIR(result, "result"); - Assert.True(result is Function { Body: Call { Target: IR.Distributed.Boxing } boxing } && boxing.Arguments[0] is Call { Target: IR.Math.Unary { UnaryOp: UnaryOp.Cos } } unary && unary.CheckedType is DistributedType dt && dt == new DistributedType(new(DataTypes.Float32, new[] { 1, 512, 8192 }), new[] { (SBP)SBP.B, SBP.S(new[] { 0 }), SBP.S(new[] { 1, 2 }) }, new(new[] { 8, 8, 4 }, "cbt"))); + Assert.True(result is Function { Body: Call { Target: IR.Distributed.Boxing } boxing } && boxing.Arguments[0] is Call { Target: IR.Math.Unary { UnaryOp: UnaryOp.Cos } } unary && unary.CheckedType is DistributedType dt && dt == new DistributedType(new(DataTypes.Float32, new[] { 1, 512, 8192 }), new[] { (SBP)SBP.B, SBP.S([0], 64), SBP.S([1, 2], 256) }, new(new[] { 8, 8, 4 }, "cbt"))); } } diff --git a/src/Nncase.Tests/Evaluator/UnitTestEvaluatorNN.cs b/src/Nncase.Tests/Evaluator/UnitTestEvaluatorNN.cs index f5d1c1034a..35c1d02b4a 100755 --- a/src/Nncase.Tests/Evaluator/UnitTestEvaluatorNN.cs +++ b/src/Nncase.Tests/Evaluator/UnitTestEvaluatorNN.cs @@ -78,7 +78,7 @@ private static readonly (PagedKVCacheDimKind[] Cache, PagedKVCacheDimKind[] Vect private static readonly (PagedKVCacheDimKind[] Sharding, SBPSplit[] Policies, Placement Placement)[] ShardingConfigs = [ (Array.Empty(), Array.Empty(), new Placement(new[] { 1 }, "t")), - (new[] { PagedKVCacheDimKind.NumBlocks }, new[] { SBP.S(0) }, new Placement(new[] { 1 }, "t")), + (new[] { PagedKVCacheDimKind.NumBlocks }, new[] { SBP.S([0]) }, new Placement(new[] { 1 }, "t")), ]; private static readonly 
(AttentionDimKind[] QLayout, AttentionDimKind[] KLayout)[] QKLayoutConfigs = @@ -114,7 +114,7 @@ public PagedAttentionKVCacheTestData() } } - Add(new TestFixture.PagedAttentionKVCacheTestFixture([4], [4], 4, 2, 32, 8, 32, Runtime.TypeCode.Float32, 1, [PagedKVCacheDimKind.NumLayers, PagedKVCacheDimKind.NumBlocks, PagedKVCacheDimKind.KV, PagedKVCacheDimKind.NumKVHeads, PagedKVCacheDimKind.HeadDim, PagedKVCacheDimKind.BlockSize], [PagedKVCacheDimKind.HeadDim], new[] { PagedKVCacheDimKind.NumBlocks }, new[] { SBP.S(0) }, [AttentionDimKind.Head, AttentionDimKind.Dim, AttentionDimKind.Seq], [AttentionDimKind.Head, AttentionDimKind.Dim, AttentionDimKind.Seq]), new Placement(new[] { 1 }, "t")); + Add(new TestFixture.PagedAttentionKVCacheTestFixture([4], [4], 4, 2, 32, 8, 32, Runtime.TypeCode.Float32, 1, [PagedKVCacheDimKind.NumLayers, PagedKVCacheDimKind.NumBlocks, PagedKVCacheDimKind.KV, PagedKVCacheDimKind.NumKVHeads, PagedKVCacheDimKind.HeadDim, PagedKVCacheDimKind.BlockSize], [PagedKVCacheDimKind.HeadDim], new[] { PagedKVCacheDimKind.NumBlocks }, new[] { SBP.S([0]) }, [AttentionDimKind.Head, AttentionDimKind.Dim, AttentionDimKind.Seq], [AttentionDimKind.Head, AttentionDimKind.Dim, AttentionDimKind.Seq]), new Placement(new[] { 1 }, "t")); } } @@ -159,8 +159,8 @@ private static readonly (PagedKVCacheDimKind[] Cache, PagedKVCacheDimKind[] Vect private static readonly (PagedKVCacheDimKind[] Sharding, SBPSplit[] Policies, int[] Hierarchy)[] ShardingConfigs = [ (Array.Empty(), Array.Empty(), Array.Empty()), - (new[] { PagedKVCacheDimKind.NumBlocks }, new[] { SBP.S(0) }, new[] { 1 }), - (new[] { PagedKVCacheDimKind.NumBlocks }, new[] { SBP.S(0) }, new[] { 8 }), + (new[] { PagedKVCacheDimKind.NumBlocks }, new[] { SBP.S([0]) }, new[] { 1 }), + (new[] { PagedKVCacheDimKind.NumBlocks }, new[] { SBP.S([0]) }, new[] { 8 }), ]; public PagedAttentionSchedulerTestData() diff --git a/src/Nncase.Tests/Simulator/UnitTestInterop.cs b/src/Nncase.Tests/Simulator/UnitTestInterop.cs index 
c516895155..dac65cf5b7 100644 --- a/src/Nncase.Tests/Simulator/UnitTestInterop.cs +++ b/src/Nncase.Tests/Simulator/UnitTestInterop.cs @@ -271,11 +271,11 @@ public void TestRTAttentionConfig() } { - var config = new IR.NN.PagedAttentionConfig(1, 2, 3, DataTypes.Float16, 4, new[] { IR.NN.PagedKVCacheDimKind.BlockSize, IR.NN.PagedKVCacheDimKind.HeadDim, IR.NN.PagedKVCacheDimKind.KV, IR.NN.PagedKVCacheDimKind.NumBlocks, IR.NN.PagedKVCacheDimKind.NumKVHeads, IR.NN.PagedKVCacheDimKind.NumLayers }, new[] { IR.NN.PagedKVCacheDimKind.HeadDim }, new[] { 32 }, new[] { IR.NN.PagedKVCacheDimKind.NumBlocks }, new[] { SBP.S(1, 2) }); + var config = new IR.NN.PagedAttentionConfig(1, 2, 3, DataTypes.Float16, 4, new[] { IR.NN.PagedKVCacheDimKind.BlockSize, IR.NN.PagedKVCacheDimKind.HeadDim, IR.NN.PagedKVCacheDimKind.KV, IR.NN.PagedKVCacheDimKind.NumBlocks, IR.NN.PagedKVCacheDimKind.NumKVHeads, IR.NN.PagedKVCacheDimKind.NumLayers }, new[] { IR.NN.PagedKVCacheDimKind.HeadDim }, new[] { 32 }, new[] { IR.NN.PagedKVCacheDimKind.NumBlocks }, new[] { SBP.S([1, 2]) }); var rtConfig = RTAttentionConfig.FromConfig(config); Assert.IsType(rtConfig); var rtPagedConfig = (RTPagedAttentionConfig)rtConfig; - Assert.True(rtPagedConfig.AxisPolicies.SequenceEqual([SBP.S(1, 2)])); + Assert.True(rtPagedConfig.AxisPolicies.SequenceEqual([SBP.S([1, 2])])); } } diff --git a/src/Nncase.Tests/Targets/UnitTestCPUKernels.cs b/src/Nncase.Tests/Targets/UnitTestCPUKernels.cs index cff88fe0fa..d4490b4dcf 100644 --- a/src/Nncase.Tests/Targets/UnitTestCPUKernels.cs +++ b/src/Nncase.Tests/Targets/UnitTestCPUKernels.cs @@ -98,7 +98,7 @@ private static readonly (PagedKVCacheDimKind[] Cache, PagedKVCacheDimKind[] Vect private static readonly (PagedKVCacheDimKind[] Sharding, SBPSplit[] Policies, int[] Hierarchy)[] ShardingConfigs = [ - (new[] { PagedKVCacheDimKind.NumBlocks }, new[] { SBP.S(0) }, [1]), + (new[] { PagedKVCacheDimKind.NumBlocks }, new[] { SBP.S([0]) }, [1]), ]; private static readonly (AttentionDimKind[] 
QLayout, AttentionDimKind[] KLayout)[] QKLayoutConfigs = @@ -133,7 +133,7 @@ public TestUpdatePagedAttentionCaseData() } } - Add(new TestFixture.PagedAttentionKVCacheTestFixture([256], [256], 14, 2, 64, 256, 16, Runtime.TypeCode.Float32, 1, [PagedKVCacheDimKind.NumBlocks, PagedKVCacheDimKind.NumLayers, PagedKVCacheDimKind.KV, PagedKVCacheDimKind.NumKVHeads, PagedKVCacheDimKind.HeadDim, PagedKVCacheDimKind.BlockSize], [PagedKVCacheDimKind.HeadDim], [PagedKVCacheDimKind.NumBlocks], [SBP.S(0)], [AttentionDimKind.Head, AttentionDimKind.Dim, AttentionDimKind.Seq], [AttentionDimKind.Head, AttentionDimKind.Dim, AttentionDimKind.Seq]), [1], count++); + Add(new TestFixture.PagedAttentionKVCacheTestFixture([256], [256], 14, 2, 64, 256, 16, Runtime.TypeCode.Float32, 1, [PagedKVCacheDimKind.NumBlocks, PagedKVCacheDimKind.NumLayers, PagedKVCacheDimKind.KV, PagedKVCacheDimKind.NumKVHeads, PagedKVCacheDimKind.HeadDim, PagedKVCacheDimKind.BlockSize], [PagedKVCacheDimKind.HeadDim], [PagedKVCacheDimKind.NumBlocks], [SBP.S([0])], [AttentionDimKind.Head, AttentionDimKind.Dim, AttentionDimKind.Seq], [AttentionDimKind.Head, AttentionDimKind.Dim, AttentionDimKind.Seq]), [1], count++); } } @@ -186,7 +186,7 @@ private static readonly (PagedKVCacheDimKind[] Cache, PagedKVCacheDimKind[] Vect private static readonly (PagedKVCacheDimKind[] Sharding, SBPSplit[] Policies, int[] Hierarchy)[] ShardingConfigs = [ - (new[] { PagedKVCacheDimKind.NumBlocks }, new[] { SBP.S(0) }, [1]), + (new[] { PagedKVCacheDimKind.NumBlocks }, new[] { SBP.S([0]) }, [1]), ]; private static readonly (AttentionDimKind[] QLayout, AttentionDimKind[] KLayout)[] QKLayoutConfigs = @@ -221,7 +221,7 @@ public TestPagedAttentionCaseData() } } - Add(new TestFixture.PagedAttentionKVCacheTestFixture([4], [4], 14, 2, 64, 256, 16, Runtime.TypeCode.Float32, 1, [PagedKVCacheDimKind.NumBlocks, PagedKVCacheDimKind.NumLayers, PagedKVCacheDimKind.KV, PagedKVCacheDimKind.NumKVHeads, PagedKVCacheDimKind.HeadDim, 
PagedKVCacheDimKind.BlockSize], [PagedKVCacheDimKind.HeadDim], [PagedKVCacheDimKind.NumBlocks], [SBP.S(0)], [AttentionDimKind.Head, AttentionDimKind.Dim, AttentionDimKind.Seq], [AttentionDimKind.Head, AttentionDimKind.Dim, AttentionDimKind.Seq]), [1], count++); + Add(new TestFixture.PagedAttentionKVCacheTestFixture([4], [4], 14, 2, 64, 256, 16, Runtime.TypeCode.Float32, 1, [PagedKVCacheDimKind.NumBlocks, PagedKVCacheDimKind.NumLayers, PagedKVCacheDimKind.KV, PagedKVCacheDimKind.NumKVHeads, PagedKVCacheDimKind.HeadDim, PagedKVCacheDimKind.BlockSize], [PagedKVCacheDimKind.HeadDim], [PagedKVCacheDimKind.NumBlocks], [SBP.S([0])], [AttentionDimKind.Head, AttentionDimKind.Dim, AttentionDimKind.Seq], [AttentionDimKind.Head, AttentionDimKind.Dim, AttentionDimKind.Seq]), [1], count++); } } @@ -509,25 +509,25 @@ public async Task TestGatherReduceScatter(long[] shape, int[] hierarchy, int cou }; var placement = new Placement(hierarchy, targetOptions.HierarchyNames); - var ndsbp = Enumerable.Repeat(SBP.B, hierarchy.Length).ToArray(); + var ndsbp = Enumerable.Repeat(SBP.B, shape.Length).ToArray(); var posts = new List(); var broadcast = IR.F.Distributed.Boxing(input, new DistributedType(inputType, ndsbp, placement)); - foreach (var comb in LinqUtility.Combination(hierarchy.Length)) + foreach (var comb in LinqUtility.Combination(shape.Length)) { var newsbp = ndsbp.ToArray(); foreach (var axis in comb) { - newsbp[axis] = SBP.P(); + newsbp[axis] = SBP.B; } - var partial = IR.F.Distributed.ForceBoxing(broadcast, new DistributedType(inputType, newsbp, placement)); + var partial = IR.F.Distributed.ForceBoxing(broadcast, new DistributedType(inputType, newsbp, placement, SBP.P(Enumerable.Range(0, hierarchy.Length).ToArray()))); var sumed = IR.F.Distributed.Boxing(partial, new DistributedType(inputType, ndsbp, placement)); var post = IR.F.Distributed.Boxing(sumed, inputType); post.Metadata = new Passes.Distributed.AutoDistributedMetaData() { Skip = true }; posts.Add(post); } - await 
RunCases($"Theory{count}", feedDict, posts); + await RunCases($"Theory{count}", feedDict, posts, null, false); } [Fact] @@ -1204,7 +1204,7 @@ public async Task TestDynamicGetPositionIds(long[] queryLens, long[] seqLens, in Metadata = new() { Range = new(1, MathUtility.AlignUp(queryLens.Sum(), 128)) }, }; - var fixture = new PagedAttentionKVCacheTestFixture(queryLens, seqLens, 2, 2, 64, 64, (int)MathUtility.CeilDiv(seqLens.Select(seq_len => MathUtility.CeilDiv(seq_len, 64)).Sum(), hierarchy.Max()) * hierarchy.Max(), Runtime.TypeCode.Float32, 1, [PagedKVCacheDimKind.NumBlocks, PagedKVCacheDimKind.NumLayers, PagedKVCacheDimKind.KV, PagedKVCacheDimKind.NumKVHeads, PagedKVCacheDimKind.HeadDim, PagedKVCacheDimKind.BlockSize], [PagedKVCacheDimKind.HeadDim], [PagedKVCacheDimKind.NumBlocks], [SBP.S(0)], [AttentionDimKind.Seq, AttentionDimKind.Dim, AttentionDimKind.Head], [AttentionDimKind.Seq, AttentionDimKind.Dim, AttentionDimKind.Head]); + var fixture = new PagedAttentionKVCacheTestFixture(queryLens, seqLens, 2, 2, 64, 64, (int)MathUtility.CeilDiv(seqLens.Select(seq_len => MathUtility.CeilDiv(seq_len, 64)).Sum(), hierarchy.Max()) * hierarchy.Max(), Runtime.TypeCode.Float32, 1, [PagedKVCacheDimKind.NumBlocks, PagedKVCacheDimKind.NumLayers, PagedKVCacheDimKind.KV, PagedKVCacheDimKind.NumKVHeads, PagedKVCacheDimKind.HeadDim, PagedKVCacheDimKind.BlockSize], [PagedKVCacheDimKind.HeadDim], [PagedKVCacheDimKind.NumBlocks], [SBP.S([0])], [AttentionDimKind.Seq, AttentionDimKind.Dim, AttentionDimKind.Head], [AttentionDimKind.Seq, AttentionDimKind.Dim, AttentionDimKind.Head]); var placement = new Placement(hierarchy, targetOptions.HierarchyNames); var dataGeneratorOptions = new PagedAttentionKVCacheTestFixture.DataGeneratorOptions(Random: true, IncreaseBy: [AttentionDimKind.Head], ResetForKV: true);