Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
178 changes: 178 additions & 0 deletions ntt/include/nncase/ntt/arch/riscv64/ukernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -547,6 +547,7 @@ template <reduce_op Op, class T> struct u_reduce_policy<Op, T, true> {
static constexpr size_t unroll = 8;
};

#if 0
// cast
template <> struct u_cast_policy<true> {
static constexpr size_t unroll = 8;
Expand Down Expand Up @@ -899,6 +900,183 @@ DEFINE_U_CAST_2_1(half, 16, float_e4m3_t, 8, _Float16, int8_t, f16, float8e4m3)
DEFINE_U_CAST_1_2(half, 16, float, 32, _Float16, float, float16, f32)
#if defined(NNCASE_XPU_MODULE) && defined(SYS_MODE)
DEFINE_U_CAST_1_2(float_e4m3_t, 8, half, 16, int8_t, _Float16, float8e4m3, f16)
#endif
#else
// cast
template <> struct u_cast_policy<true> {
static constexpr size_t unroll = 8;
};

#define DEFINE_U_CAST_2_1(IN_ELEM, IN_BW, OUT_ELEM, OUT_BW, IN_BUILTIN_ELEM, \
OUT_BUILTIN_ELEM, IN_INTRINSIC_ELEM, \
OUT_INTRINSIC_ELEM) \
template <template <class> class TPostOps, class Stride> \
struct u_cast<true, vector<IN_ELEM, 2, NTT_VLEN / IN_BW>, \
vector<OUT_ELEM, NTT_VLEN / OUT_BW>, TPostOps, Stride> { \
public: \
using T2Elem = OUT_ELEM; \
using T1 = vector<IN_ELEM, 2, NTT_VLEN / IN_BW>; \
using T2 = vector<OUT_ELEM, NTT_VLEN / OUT_BW>; \
constexpr static size_t in_offset_scale = 2; \
\
constexpr void operator()(const T1 *input, Stride input_stride, \
T2 *output, \
[[maybe_unused]] Stride output_stride, \
size_t count) noexcept { \
using policy_t = u_cast_policy<true>; \
constexpr auto unroll = policy_t::unroll; \
while (count / unroll) { \
auto v0 = ntt::cast_elem<T2Elem>(*(input + 0 * input_stride)); \
auto v2 = ntt::cast_elem<T2Elem>(*(input + 1 * input_stride)); \
auto v4 = ntt::cast_elem<T2Elem>(*(input + 2 * input_stride)); \
auto v6 = ntt::cast_elem<T2Elem>(*(input + 3 * input_stride)); \
auto v8 = ntt::cast_elem<T2Elem>(*(input + 4 * input_stride)); \
auto v10 = ntt::cast_elem<T2Elem>(*(input + 5 * input_stride)); \
auto v12 = ntt::cast_elem<T2Elem>(*(input + 6 * input_stride)); \
auto v14 = ntt::cast_elem<T2Elem>(*(input + 7 * input_stride)); \
\
v0 = TPostOps<vector<OUT_ELEM, NTT_VLEN / OUT_BW>>()(v0); \
v2 = TPostOps<vector<OUT_ELEM, NTT_VLEN / OUT_BW>>()(v2); \
v4 = TPostOps<vector<OUT_ELEM, NTT_VLEN / OUT_BW>>()(v4); \
v6 = TPostOps<vector<OUT_ELEM, NTT_VLEN / OUT_BW>>()(v6); \
v8 = TPostOps<vector<OUT_ELEM, NTT_VLEN / OUT_BW>>()(v8); \
v10 = TPostOps<vector<OUT_ELEM, NTT_VLEN / OUT_BW>>()(v10); \
v12 = TPostOps<vector<OUT_ELEM, NTT_VLEN / OUT_BW>>()(v12); \
v14 = TPostOps<vector<OUT_ELEM, NTT_VLEN / OUT_BW>>()(v14); \
\
asm volatile("vs1r.v %0, (%1);" ::"vr"((v##OUT_INTRINSIC_ELEM##m1_t)v0), \
"r"(output + 0 * output_stride) \
: "memory"); \
asm volatile("vs1r.v %0, (%1);" ::"vr"((v##OUT_INTRINSIC_ELEM##m1_t)v2), \
"r"(output + 1 * output_stride) \
: "memory"); \
asm volatile("vs1r.v %0, (%1);" ::"vr"((v##OUT_INTRINSIC_ELEM##m1_t)v4), \
"r"(output + 2 * output_stride) \
: "memory"); \
asm volatile("vs1r.v %0, (%1);" ::"vr"((v##OUT_INTRINSIC_ELEM##m1_t)v6), \
"r"(output + 3 * output_stride) \
: "memory"); \
asm volatile("vs1r.v %0, (%1);" ::"vr"((v##OUT_INTRINSIC_ELEM##m1_t)v8), \
"r"(output + 4 * output_stride) \
: "memory"); \
asm volatile("vs1r.v %0, (%1);" ::"vr"((v##OUT_INTRINSIC_ELEM##m1_t)v10), \
"r"(output + 5 * output_stride) \
: "memory"); \
asm volatile("vs1r.v %0, (%1);" ::"vr"((v##OUT_INTRINSIC_ELEM##m1_t)v12), \
"r"(output + 6 * output_stride) \
: "memory"); \
asm volatile("vs1r.v %0, (%1);" ::"vr"((v##OUT_INTRINSIC_ELEM##m1_t)v14), \
"r"(output + 7 * output_stride) \
: "memory"); \
output += unroll; \
input += unroll; \
count -= unroll; \
} \
\
for (size_t i = 0; i < count; i++) { \
auto v0 = ntt::cast_elem<T2Elem>(*input); \
v0 = TPostOps<vector<OUT_ELEM, NTT_VLEN / OUT_BW>>()(v0); \
asm volatile("vs1r.v %0, (%1);" ::"vr"((v##OUT_INTRINSIC_ELEM##m1_t)v0), \
"r"(output) \
: "memory"); \
input += input_stride; \
output += output_stride; \
} \
} \
};

DEFINE_U_CAST_2_1(float, 32, half, 16, float, _Float16, f32, float16)
#if defined(NNCASE_XPU_MODULE) && defined(SYS_MODE)
DEFINE_U_CAST_2_1(half, 16, float_e4m3_t, 8, _Float16, int8_t, f16, float8e4m3)
#endif


#define DEFINE_U_CAST_1_2(IN_ELEM, IN_BW, OUT_ELEM, OUT_BW, IN_BUILTIN_ELEM, \
OUT_BUILTIN_ELEM, IN_INTRINSIC_ELEM, \
OUT_INTRINSIC_ELEM) \
template <template <class> class TPostOps, class Stride> \
struct u_cast<true, vector<IN_ELEM, NTT_VLEN / IN_BW>, \
vector<OUT_ELEM, 2 * NTT_VLEN / OUT_BW>, TPostOps, Stride> { \
constexpr void \
operator()(const vector<IN_ELEM, NTT_VLEN / IN_BW> *input, \
[[maybe_unused]] Stride input_stride, \
vector<OUT_ELEM, 2 * NTT_VLEN / OUT_BW> *output, \
Stride output_stride, size_t count) noexcept { \
using policy_t = u_cast_policy<true>; \
constexpr auto unroll = policy_t::unroll; \
constexpr auto half_unroll = unroll / 2; \
\
using T2Elem = OUT_ELEM; \
using T1 = vector<IN_ELEM, NTT_VLEN / IN_BW>; \
using T2 = vector<OUT_ELEM, 2 * NTT_VLEN / OUT_BW>; \
[[maybe_unused]] constexpr static size_t out_offset_scale = 2; \
\
while (count / unroll) { \
constexpr auto vl_in = NTT_VLEN / IN_BW; \
constexpr auto vl_out = 2 * NTT_VLEN / OUT_BW; \
auto v0 = ntt::cast_elem<T2Elem>(*(input + 0 * input_stride)); \
auto v2 = ntt::cast_elem<T2Elem>(*(input + 1 * input_stride)); \
auto v4 = ntt::cast_elem<T2Elem>(*(input + 2 * input_stride)); \
auto v6 = ntt::cast_elem<T2Elem>(*(input + 3 * input_stride)); \
auto v8 = ntt::cast_elem<T2Elem>(*(input + 4 * input_stride)); \
auto v10 = ntt::cast_elem<T2Elem>(*(input + 5 * input_stride)); \
auto v12 = ntt::cast_elem<T2Elem>(*(input + 6 * input_stride)); \
auto v14 = ntt::cast_elem<T2Elem>(*(input + 7 * input_stride)); \
\
v0 = TPostOps<vector<OUT_ELEM, 2, NTT_VLEN / OUT_BW>>()(v0); \
v2 = TPostOps<vector<OUT_ELEM, 2, NTT_VLEN / OUT_BW>>()(v2); \
v4 = TPostOps<vector<OUT_ELEM, 2, NTT_VLEN / OUT_BW>>()(v4); \
v6 = TPostOps<vector<OUT_ELEM, 2, NTT_VLEN / OUT_BW>>()(v6); \
v8 = TPostOps<vector<OUT_ELEM, 2, NTT_VLEN / OUT_BW>>()(v8); \
v10 = TPostOps<vector<OUT_ELEM, 2, NTT_VLEN / OUT_BW>>()(v10); \
v12 = TPostOps<vector<OUT_ELEM, 2, NTT_VLEN / OUT_BW>>()(v12); \
v14 = TPostOps<vector<OUT_ELEM, 2, NTT_VLEN / OUT_BW>>()(v14); \
\
asm volatile("vs2r.v %0, (%1);" ::"vr"((v##OUT_INTRINSIC_ELEM##m2_t)v0), \
"r"(output + 0 * output_stride) \
: "memory"); \
asm volatile("vs2r.v %0, (%1);" ::"vr"((v##OUT_INTRINSIC_ELEM##m2_t)v2), \
"r"(output + 1 * output_stride) \
: "memory"); \
asm volatile("vs2r.v %0, (%1);" ::"vr"((v##OUT_INTRINSIC_ELEM##m2_t)v4), \
"r"(output + 2 * output_stride) \
: "memory"); \
asm volatile("vs2r.v %0, (%1);" ::"vr"((v##OUT_INTRINSIC_ELEM##m2_t)v6), \
"r"(output + 3 * output_stride) \
: "memory"); \
asm volatile("vs2r.v %0, (%1);" ::"vr"((v##OUT_INTRINSIC_ELEM##m2_t)v8), \
"r"(output + 4 * output_stride) \
: "memory"); \
asm volatile("vs2r.v %0, (%1);" ::"vr"((v##OUT_INTRINSIC_ELEM##m2_t)v10), \
"r"(output + 5 * output_stride) \
: "memory"); \
asm volatile("vs2r.v %0, (%1);" ::"vr"((v##OUT_INTRINSIC_ELEM##m2_t)v12), \
"r"(output + 6 * output_stride) \
: "memory"); \
asm volatile("vs2r.v %0, (%1);" ::"vr"((v##OUT_INTRINSIC_ELEM##m2_t)v14), \
"r"(output + 7 * output_stride) \
: "memory"); \
input += unroll; \
output += unroll; \
count -= unroll; \
} \
for (size_t i = 0; i < count; i++) { \
auto v0 = ntt::cast_elem<T2Elem>(*input); \
v0 = TPostOps<vector<OUT_ELEM, 2 * NTT_VLEN / OUT_BW>>()(v0); \
asm volatile("vs2r.v %0, (%1);" ::"vr"((v##OUT_INTRINSIC_ELEM##m2_t)v0), \
"r"(output) \
: "memory"); \
input += input_stride; \
output += output_stride; \
} \
} \
};

DEFINE_U_CAST_1_2(half, 16, float, 32, _Float16, float, float16, float32)
#if defined(NNCASE_XPU_MODULE) && defined(SYS_MODE)
DEFINE_U_CAST_1_2(float_e4m3_t, 8, half, 16, int8_t, _Float16, float8e4m3, f16)
#endif

#endif

template <Scalar TProbs, Scalar TIndices, size_t Rank, size_t Axis, bool Norm>
Expand Down
Loading
Loading