Skip to content
Open
Show file tree
Hide file tree
Changes from 29 commits
Commits
Show all changes
66 commits
Select commit Hold shift + click to select a range
053ff6d
add granularity for split
xhuohai Oct 14, 2025
40d992b
fix warnings
xhuohai Oct 14, 2025
92b8d4f
fix split with granularity
xhuohai Oct 15, 2025
c02a8cb
Merge branch 'dev/3.0' into feature/split_granularity
xhuohai Oct 15, 2025
23f2ace
Apply code-format changes
xhuohai Oct 15, 2025
85692f6
fix dist scheme
xhuohai Oct 15, 2025
a3ee274
fix sharding apple-clang compile
zhen8838 Oct 15, 2025
97fbf56
Merge branch 'feature/split_granularity' of github.com:kendryte/nncas…
xhuohai Oct 15, 2025
5dc015c
Apply code-format changes
xhuohai Oct 15, 2025
d8d5e62
fix bug
zhen8838 Oct 15, 2025
c5dde13
fix auto-dist of matmul
xhuohai Oct 15, 2025
e1df1ca
update type infer of custom op
xhuohai Oct 15, 2025
b6f2a56
fix build
xhuohai Oct 15, 2025
a7a27f3
fix type infer of vectorize cast with granularity
xhuohai Oct 16, 2025
fca5bd0
fix bitcast to tir
xhuohai Oct 16, 2025
bc941cd
update to_string of split
xhuohai Oct 16, 2025
baae5fe
fix type infer of bitcast and reshape
xhuohai Oct 16, 2025
b4b5bba
fix SBP check of custom op
xhuohai Oct 16, 2025
292c931
fix build
xhuohai Oct 16, 2025
56281dc
fix auto-dist of reduce
xhuohai Oct 16, 2025
f4be4f4
fix type infer of reshape
xhuohai Oct 16, 2025
32b17b8
fix sbp of custom gemm
xhuohai Oct 16, 2025
7e3ac0e
close overlap_aware_reshard due to conflict with split granularity
xhuohai Oct 16, 2025
411342c
fix type infer of reshape
xhuohai Oct 17, 2025
d4b19c7
fix type infer of reshape
xhuohai Oct 17, 2025
a6c66a0
fix custom layernorm type infer
xhuohai Oct 17, 2025
1903ad2
fix TIR GenerateReshape/Boxing
xhuohai Oct 20, 2025
25fc6c1
add ToBlockLocalData
xhuohai Oct 20, 2025
863fb89
support Partial and open for Gather
xhuohai Oct 21, 2025
cf9fecd
update AreSamePolicies according to reviews
xhuohai Oct 21, 2025
662020a
fix build
xhuohai Oct 21, 2025
f116765
fix topo aware
xhuohai Oct 21, 2025
45a3a8f
fix reduce with dynamic shape
xhuohai Oct 21, 2025
8c0b840
fix codegen of matmul
xhuohai Oct 21, 2025
c4b630f
fix eval of ForceBoxing
xhuohai Oct 22, 2025
59f9eca
fix TestGatherReduceScatter
xhuohai Oct 22, 2025
382fad5
reduce using shardy.remote
xhuohai Oct 22, 2025
54c3df4
fix reduce
xhuohai Oct 23, 2025
593f827
fix test_ntt_reshard
xhuohai Oct 23, 2025
6b06373
Merge branch 'dev/3.0' into feature/split_granularity
xhuohai Oct 23, 2025
bb727e2
fix type infer of P->S
xhuohai Oct 24, 2025
0e5214a
fix P->S reduce
xhuohai Oct 24, 2025
cc5241e
non-divisible dist on dynamic dimension
xhuohai Oct 27, 2025
d3812a2
fix possible issue of reduce
xhuohai Oct 28, 2025
d1f214c
update partial
xhuohai Oct 29, 2025
82147d2
Apply code-format changes
xhuohai Oct 29, 2025
db53475
fix eval of boxing
xhuohai Oct 30, 2025
5b2047d
skip block local data opt of tiling ops
xhuohai Oct 30, 2025
f3a4697
no partial for summa
xhuohai Oct 30, 2025
85935a2
Merge branch 'dev/3.0' into feature/split_granularity
xhuohai Oct 30, 2025
b157c06
Merge branch 'dev/3.0' into feature/split_granularity
xhuohai Dec 3, 2025
1c3320e
Apply code-format changes
xhuohai Dec 3, 2025
2c0441a
try to fix k80 CI
xhuohai Dec 8, 2025
e2f1bad
Merge branch 'dev/3.0' into feature/split_granularity
xhuohai Dec 8, 2025
f2a2e84
simplify boxing's new_type
xhuohai Dec 9, 2025
9752932
fix SBP check of Custom MoE
xhuohai Dec 9, 2025
83ec88f
fix build
xhuohai Dec 9, 2025
3d55d00
fix SBP check of Custom MoE
xhuohai Dec 9, 2025
2281e69
more constraints for P->S boxing
xhuohai Dec 12, 2025
42d5031
fix reshape with block_local_data input
xhuohai Dec 12, 2025
81dacb4
update type infer of Boxing
xhuohai Dec 15, 2025
0639537
Merge branch 'dev/3.0' into feature/split_granularity
xhuohai Dec 16, 2025
cccd88e
Merge branch 'dev/3.0' into feature/split_granularity
xhuohai Dec 19, 2025
8622464
Merge branch 'dev/3.0' into feature/split_granularity
xhuohai Dec 24, 2025
7ed29cd
fix build
xhuohai Dec 24, 2025
69da13d
exclude AI md in gitignore
xhuohai Feb 11, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ public void IndWrite(string? value)
/// <summary>
/// convert single prim function to c source.
/// </summary>
public abstract class CSourceConvertVisitor : ExprFunctor<CSymbol, Unit>
public class CSourceConvertVisitor : ExprFunctor<CSymbol, Unit>
{
protected readonly Dictionary<BaseExpr, CSymbol> _exprMemo = new(ReferenceEqualityComparer.Instance);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -447,11 +447,6 @@ protected override CSymbol VisitCall(Call expr)
case TIR.NTT.Gather gather:
{
WriteWithProfiler($"gather({VisitBuffer(args[0], local: false).Name}, {VisitBuffer(args[1], local: true).Name}, {VisitBuffer(args[2], local: true).Name}, {gather.Axis}_dim);\n");
if (args[0] is TIR.Buffer b && b.DistributedType?.AxisPolicies[gather.Axis] is SBPSplit s)
{
var reduceKind = "tar::reduce_kind::" + string.Join("_", Enumerable.Range(0, TargetOptions.HierarchyNames.Length).Select(i => (s.Axes.Contains(i) ? "r" : string.Empty) + TargetOptions.HierarchyNames[i]));
WriteIndWithProfiler($"tac::tensor_reduce_sync<reduce_op::{ReduceOp.Sum.ToC()}, {reduceKind}>({VisitBuffer(args[2], local: true).Name}, {VisitBuffer(args[2], local: true).Name});\n");
}
}

break;
Expand Down Expand Up @@ -536,8 +531,8 @@ protected override CSymbol VisitCall(Call expr)
{
// deprecated
var sbpPartial = (SBPPartial)grs.InType.AxisPolicies.Where(s => s is SBPPartial).Distinct().First();
var reduceKind = "tar::reduce_kind::" + string.Join("_", grs.InType.AxisPolicies.Select((s, i) => (s is SBPPartial ? "r" : string.Empty) + TargetOptions.HierarchyNames[i]));
WriteIndWithProfiler($"tac::tensor_reduce_sync<reduce_op::{sbpPartial.Op.ToC()}, {reduceKind}>({VisitBuffer(args[0], local: true).Name}, {VisitBuffer(args[1], local: true).Name});\n");
var reduceKind = "tar::reduce_kind::" + string.Join("_", Enumerable.Range(0, TargetOptions.HierarchyNames.Length).Select(i => (sbpPartial.Axes.Contains(i) ? "r" : string.Empty) + TargetOptions.HierarchyNames[i]));
WriteIndWithProfiler($"tac::tensor_reduce_sync<reduce_op::{sbpPartial.Op.ToC()}, {reduceKind}>({VisitBuffer(args[0], local: true).Name}, {VisitBuffer(args[1], local: false).Name});\n");
}
else
{
Expand Down
2 changes: 1 addition & 1 deletion modules/Nncase.Modules.NTT/CodeGen/CPU/KernelUtility.cs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ public static string SBPToC(this SBP value)
{
if (value is SBPSplit s)
{
return $"S<{string.Join(", ", s.Axes)}>()";
return $"S<{string.Join(", ", s.Axes)}>({new CSourceConvertVisitor().Visit(s.Granularity as BaseExpr ?? None.Default).Name})";
}
else
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -177,112 +177,149 @@ class tensor_reduce_sync_impl {
// collect all tensors pointer for access tensor from other nodes.
using TElem = typename TIn::element_type;
using TOutBase = std::decay_t<TOut>;
static_assert(ShardedTensor<TOutBase>, "dest must be sharded tensor");
constexpr size_t Rank = TIn::rank();
constexpr auto group_hierarchy = group_hierarchy_getter<Kind>::group_hierarchy;
auto cur_index = ntt::make_shape(@(cur_index));
auto cur_index_g = index_global2group(cur_index);
tar::src_ptr_tensor(cur_index) =
reinterpret_cast<void *>(src.elements().data());
tar::dest_ptr_tensor(cur_index) =
reinterpret_cast<void *>(dest.elements().data());
reinterpret_cast<void *>(dest.local().elements().data());
reduce_group_sync();

// according to the group size split the tensor.
// todo should using better split strategy.
constexpr auto group_size = ntt::fixed_dim_v<get_group_size()>;
const auto axis = [&] {
dim_t axis = -1;
loop<Rank>([&](auto i) {
if (axis == -1 && src.shape()[i] >= group_size) {
axis = i;
}
});
if (axis == -1) {
axis = 0;
}
return axis;
}();

auto remain = src.shape()[axis] % (group_size);
auto frac = src.shape()[axis] / (group_size);

auto node_number_g = ntt::linear_offset(cur_index_g, group_hierarchy);

// reduce-scatter, communicate (group_size - 1) times
for (auto i = 0; i < group_size - 1; i++)
{
auto new_shape = ntt::generate_shape<Rank>([&](auto j) {
if (j == axis) {
return ntt::where(node_number_g == group_size - 1, frac + remain, frac);
} else {
return (dim_t)src.shape()[j];
}
});
auto starts = ntt::generate_shape<Rank>([&](auto j) {
if (j == axis) {
return node_number_g * frac;
} else {
return (dim_t)0;
}
});
if (src.shape() != dest.local().shape()) {
using mesh_type = typename TOutBase::mesh_type;
const auto local_shard_index = mesh_type::local_index();
auto node_number_g = ntt::linear_offset(cur_index_g, group_hierarchy);
auto new_shape = dest.local().shape();
auto starts = dest.sharding().global_offset(dest.shape(), local_shard_index);
auto viewed_src1_tensor = src.view(starts, new_shape);
auto viewed_dest_tensor = dest.view(starts, new_shape);
auto viewed_dest_tensor = dest.local();

auto next_index_g = ntt::unravel_index((node_number_g + i + 1) % group_size, group_hierarchy);
// reduce-scatter, communicate (group_size - 1) times
for (auto i = 0; i < group_size - 1; i++)
{
auto next_index_g = ntt::unravel_index((node_number_g + i + 1) % group_size, group_hierarchy);

// keep the non-reduce axis invariant.
auto next_index = index_group2global(next_index_g, cur_index);
// keep the non-reduce axis invariant.
auto next_index = index_group2global(next_index_g, cur_index);

auto src2_tensor = ntt::make_tensor_view_from_address<TElem>(
(TElem *)tar::src_ptr_tensor(next_index), src.shape(),
src.strides());
auto viewed_src2_tensor = src2_tensor.view(starts, new_shape);
auto src2_tensor = ntt::make_tensor_view_from_address<TElem>(
(TElem *)tar::src_ptr_tensor(next_index), src.shape(),
src.strides());
auto viewed_src2_tensor = src2_tensor.view(starts, new_shape);

if (i == 0) {
reduce_impl(viewed_src1_tensor, viewed_src2_tensor,
viewed_dest_tensor);
} else {
reduce_impl(viewed_dest_tensor, viewed_src2_tensor,
viewed_dest_tensor);
if (i == 0) {
reduce_impl(viewed_src1_tensor, viewed_src2_tensor,
viewed_dest_tensor);
} else {
reduce_impl(viewed_dest_tensor, viewed_src2_tensor,
viewed_dest_tensor);
}
}
}

reduce_group_sync();

// all gather
for (size_t i = 0; i < group_size - 1; i++) {
auto offset = (node_number_g + i + 1) % (group_size);
auto src_index_g = ntt::unravel_index(offset % group_size, group_hierarchy);
auto src_index = index_group2global(src_index_g, cur_index);

auto src_tensor = ntt::make_tensor_view_from_address<TElem>(
(TElem *)tar::dest_ptr_tensor(src_index), dest.shape(),
dest.strides());
auto starts = ntt::generate_shape<Rank>([&](auto j) {
if (j == axis) {
return offset * frac;
} else {
return (dim_t)0;
ntt::tensor_copy_wait<void>();
reduce_group_sync(ctx, group_target_value);
} else {
const auto axis = [&] {
dim_t axis = -1;
loop<Rank>([&](auto i) {
if (axis == -1 && src.shape()[i] >= group_size) {
axis = i;
}
});
if (axis == -1) {
axis = 0;
}
});
auto new_shape = ntt::generate_shape<Rank>([&](auto j) {
if (j == axis) {
return ntt::where(offset == group_size - 1, frac + remain, frac);
return axis;
}();

auto remain = src.shape()[axis] % (group_size);
auto frac = src.shape()[axis] / (group_size);

auto node_number_g = ntt::linear_offset(cur_index_g, group_hierarchy);

// reduce-scatter, communicate (group_size - 1) times
for (auto i = 0; i < group_size - 1; i++)
{
auto new_shape = ntt::generate_shape<Rank>([&](auto j) {
if (j == axis) {
return ntt::where(node_number_g == group_size - 1, frac + remain, frac);
} else {
return (dim_t)src.shape()[j];
}
});
auto starts = ntt::generate_shape<Rank>([&](auto j) {
if (j == axis) {
return node_number_g * frac;
} else {
return (dim_t)0;
}
});
auto viewed_src1_tensor = src.view(starts, new_shape);
auto viewed_dest_tensor = dest.local().view(starts, new_shape);

auto next_index_g = ntt::unravel_index((node_number_g + i + 1) % group_size, group_hierarchy);

// keep the non-reduce axis invariant.
auto next_index = index_group2global(next_index_g, cur_index);

auto src2_tensor = ntt::make_tensor_view_from_address<TElem>(
(TElem *)tar::src_ptr_tensor(next_index), src.shape(),
src.strides());
auto viewed_src2_tensor = src2_tensor.view(starts, new_shape);

if (i == 0) {
reduce_impl(viewed_src1_tensor, viewed_src2_tensor,
viewed_dest_tensor);
} else {
return (dim_t)src.shape()[j];
reduce_impl(viewed_dest_tensor, viewed_src2_tensor,
viewed_dest_tensor);
}
});
auto viewed_src_tensor = src_tensor.view(starts, new_shape);
auto viewed_dest_tensor = dest.view(starts, new_shape);
ntt::tensor_copy_async(viewed_src_tensor, viewed_dest_tensor);
}
}

ntt::tensor_copy_wait<void>();
reduce_group_sync();
reduce_group_sync();

// all gather
for (size_t i = 0; i < group_size - 1; i++) {
auto offset = (node_number_g + i + 1) % (group_size);
auto src_index_g = ntt::unravel_index(offset % group_size, group_hierarchy);
auto src_index = index_group2global(src_index_g, cur_index);

auto src_tensor = ntt::make_tensor_view_from_address<TElem>(
(TElem *)tar::dest_ptr_tensor(src_index), dest.local().shape(),
dest.local().strides());
auto starts = ntt::generate_shape<Rank>([&](auto j) {
if (j == axis) {
return offset * frac;
} else {
return (dim_t)0;
}
});
auto new_shape = ntt::generate_shape<Rank>([&](auto j) {
if (j == axis) {
return ntt::where(offset == group_size - 1, frac + remain, frac);
} else {
return (dim_t)src.shape()[j];
}
});
auto viewed_src_tensor = src_tensor.view(starts, new_shape);
auto viewed_dest_tensor = dest.local().view(starts, new_shape);
ntt::tensor_copy_async(viewed_src_tensor, viewed_dest_tensor);
}

if (Op == ntt::reduce_op::mean) {
auto numerator = (element_or_scalar_t<TElem>)(size_t)group_size;
ntt::binary<ntt::ops::div>(dest, ntt::make_tensor_view_from_address(&numerator, ntt::fixed_shape_v<>), dest);
ntt::tensor_copy_wait<void>();
reduce_group_sync();

if (Op == ntt::reduce_op::mean) {
auto numerator = (element_or_scalar_t<TElem>)(size_t)group_size;
ntt::binary<ntt::ops::div>(dest.local(), ntt::make_tensor_view_from_address(&numerator, ntt::fixed_shape_v<>), dest.local());
}
}
}
};
Expand Down
29 changes: 25 additions & 4 deletions modules/Nncase.Modules.NTT/Evaluator/CustomOp/NTT/LayerNorm.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Licensed under the Apache license. See LICENSE file in the project root for full license information.

using System;
using System.Diagnostics;
using System.Linq;
using System.Runtime.CompilerServices;
using DryIoc.ImTools;
Expand Down Expand Up @@ -47,7 +48,7 @@ public IRType Visit(ITypeInferenceContext context, LayerNorm target)
{
return (input, scale, bias) switch
{
(DistributedType a, DistributedType b, DistributedType c) => new DistributedType((TensorType)VisitTensorType(target, a.TensorType, b.TensorType, c.TensorType), target.OutSBPs, a.Placement),
(DistributedType a, DistributedType b, DistributedType c) => VisitDistributedType(target, a, b, c),
(TensorType a, TensorType b, TensorType c) => VisitTensorType(target, a, b, c),
_ => new InvalidType($"{input} {scale} {bias} not support"),
};
Expand All @@ -68,17 +69,17 @@ private bool CheckCustomSBP(IRType input, IRType scale, IRType bias, LayerNorm l
{
if (input is DistributedType a && scale is DistributedType b && bias is DistributedType c)
{
if (Enumerable.Range(0, a.TensorType.Shape.Rank).Any(i => a.AxisPolicies[i] != layerNorm.InSBPs[i]))
if (Enumerable.Range(0, a.TensorType.Shape.Rank).Any(i => !DistributedUtility.IsSamePolicy(a.AxisPolicies[i], layerNorm.InSBPs[i], false)))
{
return false;
}

if (Enumerable.Range(0, b.TensorType.Shape.Rank).Any(i => b.AxisPolicies[i] != layerNorm.ScaleSBPs[i]))
if (Enumerable.Range(0, b.TensorType.Shape.Rank).Any(i => !DistributedUtility.IsSamePolicy(b.AxisPolicies[i], layerNorm.ScaleSBPs[i], false)))
{
return false;
}

if (Enumerable.Range(0, c.TensorType.Shape.Rank).Any(i => c.AxisPolicies[i] != layerNorm.BiasSBPs[i]))
if (Enumerable.Range(0, c.TensorType.Shape.Rank).Any(i => !DistributedUtility.IsSamePolicy(c.AxisPolicies[i], layerNorm.BiasSBPs[i], false)))
{
return false;
}
Expand Down Expand Up @@ -114,4 +115,24 @@ private IRType VisitTensorType(LayerNorm target, TensorType input, TensorType sc
return new TensorType(target.OutputDataType, input.Shape);
}
}

private IRType VisitDistributedType(LayerNorm target, DistributedType input, DistributedType scale, DistributedType bias)
{
var tensorType = (TensorType)VisitTensorType(target, input.TensorType, scale.TensorType, bias.TensorType);

var ndsbps = new SBP[tensorType.Shape.Rank];
for (var i = 0; i < ndsbps.Length; i++)
{
if (i == target.VectorizedAxes[0] && input.AxisPolicies[i] is SBPSplit split)
{
ndsbps[i] = SBP.S(split.Axes, split.Granularity is null ? null : split.Granularity * ((VectorType)input.TensorType.DType).Lanes[0] / ((VectorType)tensorType.DType).Lanes[0]);
}
else
{
ndsbps[i] = input.AxisPolicies[i];
}
}

return new DistributedType(tensorType, ndsbps, input.Placement);
}
}
27 changes: 24 additions & 3 deletions modules/Nncase.Modules.NTT/Evaluator/CustomOp/NTT/Matmul.cs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ public IRType Visit(ITypeInferenceContext context, MatMul target)
{
return (lhs, rhs) switch
{
(DistributedType a, DistributedType b) => new DistributedType((TensorType)VisitTensorType(target, a.TensorType, b.TensorType, true, dimInfo), target.OutSBPs, a.Placement),
(DistributedType a, DistributedType b) => VisitDistributedType(target, a, b, true, dimInfo),
(TensorType a, TensorType b) => VisitTensorType(target, a, b, true, dimInfo),
_ => new InvalidType($"{lhs} {rhs} not support"),
};
Expand All @@ -78,12 +78,12 @@ private bool CheckCustomSBP(IRType lhs, IRType rhs, IRType extra, MatMul matmul)

if (lhs is DistributedType a && rhs is DistributedType b)
{
if (Enumerable.Range(0, a.TensorType.Shape.Rank).Any(i => a.AxisPolicies[i] != matmul.LhsSBPs[i]))
if (Enumerable.Range(0, a.TensorType.Shape.Rank).Any(i => !DistributedUtility.IsSamePolicy(a.AxisPolicies[i], matmul.LhsSBPs[i], false)))
{
return false;
}

if (Enumerable.Range(0, b.TensorType.Shape.Rank).Any(i => b.AxisPolicies[i] != matmul.RhsSBPs[i]))
if (Enumerable.Range(0, b.TensorType.Shape.Rank).Any(i => !DistributedUtility.IsSamePolicy(b.AxisPolicies[i], matmul.RhsSBPs[i], false)))
{
return false;
}
Expand Down Expand Up @@ -156,4 +156,25 @@ private IRType VisitTensorType(MatMul target, TensorType lhs, TensorType rhs, bo

return new TensorType(dtype, front.Concat(end).ToArray());
}

private IRType VisitDistributedType(MatMul target, DistributedType lhs, DistributedType rhs, bool vectorizeK, MatMulDimInfo dimInfo)
{
var tensorType = (TensorType)VisitTensorType(target, lhs.TensorType, rhs.TensorType, vectorizeK, dimInfo);

// FIXME: support rank>=2, and only support vectorize N of output.
var policyN = rhs.AxisPolicies[dimInfo!.Rn];
if (policyN is SBPSplit split)
{
policyN = SBP.S(split.Axes, split.Granularity is null ? null : split.Granularity / ((VectorType)tensorType.DType).Lanes[0]);
}

var policyM = lhs.AxisPolicies[dimInfo!.Lm];
var ndsbps = (target.TransposeA || target.TransposeB) ? new[] { policyN, policyM } : new[] { policyM, policyN };
if (DistributedUtility.AreSamePolicies(ndsbps, target.OutSBPs, false))
{
return new DistributedType(tensorType, ndsbps, lhs.Placement);
}

return new InvalidType("Please Check SBP Scheme.");
}
}
Loading
Loading