Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 22 additions & 5 deletions kt-kernel/ext_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -322,22 +322,29 @@ class MOESFTBindings {
intptr_t grad_down_lora_a;
intptr_t grad_down_lora_b;
intptr_t grad_weights;
intptr_t grad_gate_proj;
intptr_t grad_up_proj;
intptr_t grad_down_proj;
};
static void inner(void* args) {
Args* args_ = (Args*)args;
args_->cpuinfer->enqueue(&TP_MOE_SFT<T>::backward_binding, args_->moe, args_->grad_output, args_->grad_input,
args_->grad_gate_lora_a, args_->grad_gate_lora_b, args_->grad_up_lora_a,
args_->grad_up_lora_b, args_->grad_down_lora_a, args_->grad_down_lora_b,
args_->grad_weights);
args_->grad_weights, args_->grad_gate_proj, args_->grad_up_proj,
args_->grad_down_proj);
}
static std::pair<intptr_t, intptr_t> cpuinfer_interface(std::shared_ptr<TP_MOE_SFT<T>> moe, intptr_t grad_output,
intptr_t grad_input, intptr_t grad_gate_lora_a,
intptr_t grad_gate_lora_b, intptr_t grad_up_lora_a,
intptr_t grad_up_lora_b, intptr_t grad_down_lora_a,
intptr_t grad_down_lora_b, intptr_t grad_weights) {
intptr_t grad_down_lora_b, intptr_t grad_weights,
intptr_t grad_gate_proj, intptr_t grad_up_proj,
intptr_t grad_down_proj) {
Args* args = new Args{nullptr, moe.get(), grad_output, grad_input,
grad_gate_lora_a, grad_gate_lora_b, grad_up_lora_a, grad_up_lora_b,
grad_down_lora_a, grad_down_lora_b, grad_weights};
grad_down_lora_a, grad_down_lora_b, grad_weights,
grad_gate_proj, grad_up_proj, grad_down_proj};
return std::make_pair((intptr_t)&inner, (intptr_t)args);
}
};
Expand Down Expand Up @@ -400,7 +407,13 @@ void bind_moe_sft_module(py::module_& moe_module, const char* name) {
self.prepare_and_save_bwd((void*)gate, (void*)up, (void*)down, path);
})
.def("submit_backward_repack", &MoeClass::submit_backward_repack)
.def("wait_backward_repack", &MoeClass::wait_backward_repack);
.def("wait_backward_repack", &MoeClass::wait_backward_repack)
// Update base weight BF16 pointers for reload_base_weights (full mode training)
// After calling this, call load_weights_task() to re-quantize BF16->AMX
.def("set_base_weight_pointers",
[](MoeClass& self, intptr_t gate, intptr_t up, intptr_t down) {
self.set_base_weight_pointers((void*)gate, (void*)up, (void*)down);
});
}
#endif // defined(__x86_64__) && defined(USE_AMX_AVX_KERNEL)

Expand Down Expand Up @@ -775,7 +788,11 @@ PYBIND11_MODULE(kt_kernel_ext, m) {
.DEF_PTR_PROPERTY(MOESFTConfig, up_lora_a)
.DEF_PTR_PROPERTY(MOESFTConfig, up_lora_b)
.DEF_PTR_PROPERTY(MOESFTConfig, down_lora_a)
.DEF_PTR_PROPERTY(MOESFTConfig, down_lora_b);
.DEF_PTR_PROPERTY(MOESFTConfig, down_lora_b)
.def_readwrite("full_weight_grad", &MOESFTConfig::full_weight_grad)
.DEF_PTR_PROPERTY(MOESFTConfig, grad_gate_proj)
.DEF_PTR_PROPERTY(MOESFTConfig, grad_up_proj)
.DEF_PTR_PROPERTY(MOESFTConfig, grad_down_proj);

py::class_<MoE_Interface, std::shared_ptr<MoE_Interface>>(moe_module, "MoE_Interface");

Expand Down
98 changes: 96 additions & 2 deletions kt-kernel/operators/amx/sft_moe.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1404,7 +1404,8 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
void backward(const void* grad_output, void* grad_input, void* grad_gate_lora_a, void* grad_gate_lora_b,
void* grad_up_lora_a, void* grad_up_lora_b, void* grad_down_lora_a, void* grad_down_lora_b,
void* grad_weights, int full_intermediate_size = 0, float* fp32_grad_down_lora_b = nullptr,
float* fp32_grad_gate_lora_a = nullptr, float* fp32_grad_up_lora_a = nullptr) {
float* fp32_grad_gate_lora_a = nullptr, float* fp32_grad_up_lora_a = nullptr,
void* grad_gate_proj = nullptr, void* grad_up_proj = nullptr, void* grad_down_proj = nullptr) {
// If full_intermediate_size not provided, use local (non-TP mode)
if (full_intermediate_size == 0) full_intermediate_size = config_.intermediate_size;
SFT_POOL_LOG("bwd_enter", config_.layer_idx, tp_part_idx, 0, cache_stack_top_, forward_pool_bytes_,
Expand Down Expand Up @@ -1888,7 +1889,14 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
print_grad_stats_fp32("grad_weights", (const float*)grad_weights, qlen * k);
}

// ★ Cache pool is NOT freed here — kept for reuse across steps.
// =====================================================================
// Step 5: Base weight gradient accumulation (full weight grad mode)
// =====================================================================
if (sft_config_.full_weight_grad && grad_gate_proj && grad_up_proj && grad_down_proj) {
backward_base_weight_grad(cache, grad_output, grad_gate_proj, grad_up_proj, grad_down_proj);
}

// ★ Cache pool is NOT freed here — kept for reuse across steps.
// alloc_or_resize_cache_pool() is grow-only, so same-seqlen steps
// reuse the existing allocation without malloc/free overhead.
// Previously: free_seqlen_buffers() was called here, costing ~3.6ms per TP.
Expand All @@ -1897,6 +1905,92 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> {
cache.valid = false;
}

/**
* @brief Compute base weight gradients via outer-product accumulation.
*
* For each activated expert e, computes:
* grad_gate_proj[e] = grad_gate_out[e]^T @ input[e] -> [I, H]
* grad_up_proj[e] = grad_up_out[e]^T @ input[e] -> [I, H]
* grad_down_proj[e] = grad_output[e]^T @ intermediate[e] -> [H, I]
*
* Uses FP32 accumulator for precision, writes BF16 output.
*/
void backward_base_weight_grad(const ForwardCache& cache, const void* grad_output,
void* grad_gate_proj, void* grad_up_proj, void* grad_down_proj) {
const int H = config_.hidden_size;
const int I = config_.intermediate_size;
const int E = config_.expert_num;
int activated_expert = cache.activated_expert_cache;

auto* ggp = static_cast<ggml_bf16_t*>(grad_gate_proj); // [E, I, H]
auto* gup_ptr = static_cast<ggml_bf16_t*>(grad_up_proj); // [E, I, H]
auto* gdp = static_cast<ggml_bf16_t*>(grad_down_proj); // [E, H, I]
auto* grad_out_bf16 = static_cast<const ggml_bf16_t*>(grad_output);

for (int task_id = 0; task_id < activated_expert; task_id++) {
int expert_idx = cache.m_expert_id_map_cache[task_id];
int m = cache.m_local_num_cache[expert_idx];
if (m == 0) continue;

int pos_start = 0;
for (int prev_id = 0; prev_id < task_id; prev_id++) {
pos_start += cache.m_local_num_cache[cache.m_expert_id_map_cache[prev_id]];
}

const auto& local_pos = cache.m_local_pos_cache[expert_idx];

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

Critical bug: cache.m_local_pos_cache is indexed by token ID, not expert index. Accessing it with expert_idx will lead to out-of-bounds access or incorrect data. You need a mapping from expert to the original token indices to correctly retrieve input_row and grad_out_row from the global buffers.


// Allocate FP32 accumulators from forward pool (safe during backward)
float* acc_gate = static_cast<float*>(forward_pool_); // [I, H]
float* acc_up = acc_gate + (size_t)I * H; // [I, H]
float* acc_down = acc_up + (size_t)I * H; // [H, I]
Comment on lines +1943 to +1945

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

Critical issue: forward_pool_ is used for large FP32 accumulators (acc_gate, acc_up, acc_down), but its size is only determined by LoRA working buffers in alloc_forward_buffers. For full fine-tuning, these accumulators require (2 * I * H + H * I) * sizeof(float) bytes, which can be hundreds of megabytes (e.g., ~175MB for I=2048, H=7168). This will cause a buffer overflow if forward_pool_ is not explicitly resized to accommodate these.


std::memset(acc_gate, 0, (size_t)I * H * sizeof(float));
std::memset(acc_up, 0, (size_t)I * H * sizeof(float));
std::memset(acc_down, 0, (size_t)H * I * sizeof(float));

for (int t = 0; t < m; t++) {
int tok_pos = local_pos[t];
const ggml_bf16_t* input_row = cache.input_cache + (size_t)tok_pos * H;
const ggml_bf16_t* gate_grad_row = grad_gate_output_ + (size_t)(pos_start + t) * I;
const ggml_bf16_t* up_grad_row = grad_up_output_ + (size_t)(pos_start + t) * I;
const ggml_bf16_t* inter_row = cache.intermediate_cache + (size_t)(pos_start + t) * I;
const ggml_bf16_t* grad_out_row = grad_out_bf16 + (size_t)tok_pos * H;

// gate_proj grad: [I, H] += grad_gate_out[t]^T @ input[t]
for (int i = 0; i < I; i++) {
float gg = GGML_BF16_TO_FP32(gate_grad_row[i]);
float gu = GGML_BF16_TO_FP32(up_grad_row[i]);
for (int h = 0; h < H; h++) {
float inp = GGML_BF16_TO_FP32(input_row[h]);
acc_gate[i * H + h] += gg * inp;
acc_up[i * H + h] += gu * inp;
}
}

// down_proj grad: [H, I] += grad_output[t]^T @ intermediate[t]
for (int h = 0; h < H; h++) {
float go = GGML_BF16_TO_FP32(grad_out_row[h]);
for (int i = 0; i < I; i++) {
acc_down[h * I + i] += go * GGML_BF16_TO_FP32(inter_row[i]);
}
}
Comment on lines +1959 to +1976

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

High severity performance issue: The base weight gradient computation uses a triple nested loop without vectorization or parallelization. For typical MoE sizes, this results in billions of operations executed sequentially on a single thread. This should be implemented using a vectorized outer product (rank-1 update) or a GEMM ($\mathrm{Grad} = \mathrm{GradOut}^{\top} \times \mathrm{Input}$), and parallelized across experts using the worker pool.

}

// Convert FP32 accumulators to BF16 and store
for (int i = 0; i < I; i++) {
for (int h = 0; h < H; h++) {
ggp[(size_t)expert_idx * I * H + (size_t)i * H + h] = GGML_FP32_TO_BF16(acc_gate[i * H + h]);
gup_ptr[(size_t)expert_idx * I * H + (size_t)i * H + h] = GGML_FP32_TO_BF16(acc_up[i * H + h]);
}
}
for (int h = 0; h < H; h++) {
for (int i = 0; i < I; i++) {
gdp[(size_t)expert_idx * H * I + (size_t)h * I + i] = GGML_FP32_TO_BF16(acc_down[h * I + i]);
}
Comment on lines +1982 to +1989

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

High severity issue: The gradients are being overwritten in the output buffer (ggp[...] = ...) instead of being accumulated. This will break gradient accumulation if multiple micro-batches are used before an optimizer step. While PyTorch usually handles accumulation, here the C++ kernel writes directly into a persistent buffer in the wrapper, so it must perform in-place accumulation if the buffer is reused.

}
}
}

/**
* @brief Get qlen from the top of the forward cache stack.
*
Expand Down
9 changes: 9 additions & 0 deletions kt-kernel/operators/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,15 @@ struct MOESFTConfig : public GeneralMOEConfig {
void* down_lora_a = nullptr; // [expert_num, lora_rank, intermediate_size]
void* down_lora_b = nullptr; // [expert_num, hidden_size, lora_rank]

// Full weight gradient configuration
bool full_weight_grad = false;

// Base weight gradient buffer pointers (directly pointing to Python tensor memory, zero-copy)
// Only used when full_weight_grad == true
void* grad_gate_proj = nullptr; // [expert_num, intermediate_size, hidden_size]
void* grad_up_proj = nullptr; // [expert_num, intermediate_size, hidden_size]
void* grad_down_proj = nullptr; // [expert_num, hidden_size, intermediate_size]

MOESFTConfig() : GeneralMOEConfig() {}

MOESFTConfig(int expert_num, int routed_expert_num, int hidden_size, int intermediate_size)
Expand Down
29 changes: 24 additions & 5 deletions kt-kernel/operators/moe-sft-tp.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ class TP_MOE_SFT : public TP_MOE<T> {
throw std::runtime_error("K2 pre-quantized mode does not support TP > 1 yet");
}
} else if (config.gate_proj != nullptr) {
printf("TP_MOE_SFT: From BF16 with partitioning\n");
// printf("TP_MOE_SFT: From BF16 with partitioning\n");

// Temporary storage for partitioned weights
std::vector<ggml_bf16_t*> temp_gate(tp_count);
Expand Down Expand Up @@ -548,7 +548,8 @@ class TP_MOE_SFT : public TP_MOE<T> {
*/
void backward(const void* grad_output, void* grad_input, void* grad_gate_lora_a, void* grad_gate_lora_b,
void* grad_up_lora_a, void* grad_up_lora_b, void* grad_down_lora_a, void* grad_down_lora_b,
void* grad_weights) {
void* grad_weights, void* grad_gate_proj = nullptr, void* grad_up_proj = nullptr,
void* grad_down_proj = nullptr) {
auto pool = config.pool;


Expand Down Expand Up @@ -694,7 +695,8 @@ class TP_MOE_SFT : public TP_MOE<T> {
tp_down_a_ptr[numa_id], /* copy-type: direct write */
nullptr, /* grad_down_lora_b — unused, FP32 path below */
part_grad_weights_[numa_id], full_intermediate_size, tp_fp32_down_b[numa_id],
tp_fp32_gate_a[numa_id], tp_fp32_up_a[numa_id]);
tp_fp32_gate_a[numa_id], tp_fp32_up_a[numa_id],
grad_gate_proj, grad_up_proj, grad_down_proj);
});

// // Collect per-thread timing from all NUMA subpools
Expand Down Expand Up @@ -891,10 +893,11 @@ class TP_MOE_SFT : public TP_MOE<T> {
*/
void backward_binding(intptr_t grad_output, intptr_t grad_input, intptr_t grad_gate_lora_a, intptr_t grad_gate_lora_b,
intptr_t grad_up_lora_a, intptr_t grad_up_lora_b, intptr_t grad_down_lora_a,
intptr_t grad_down_lora_b, intptr_t grad_weights) {
intptr_t grad_down_lora_b, intptr_t grad_weights, intptr_t grad_gate_proj,
intptr_t grad_up_proj, intptr_t grad_down_proj) {
backward((const void*)grad_output, (void*)grad_input, (void*)grad_gate_lora_a, (void*)grad_gate_lora_b,
(void*)grad_up_lora_a, (void*)grad_up_lora_b, (void*)grad_down_lora_a, (void*)grad_down_lora_b,
(void*)grad_weights);
(void*)grad_weights, (void*)grad_gate_proj, (void*)grad_up_proj, (void*)grad_down_proj);
}

/**
Expand Down Expand Up @@ -1106,6 +1109,22 @@ class TP_MOE_SFT : public TP_MOE<T> {
update_lora_weights((void*)gate_lora_a, (void*)gate_lora_b, (void*)up_lora_a, (void*)up_lora_b, (void*)down_lora_a,
(void*)down_lora_b);
}

/**
 * @brief Repoint the MoE config at a new set of BF16 base-weight buffers.
 *
 * Only the three raw pointers in `config` are replaced and `weights_loaded`
 * is cleared; nothing is copied, partitioned or quantized by this call.
 * Intended for the reload_base_weights path of full-mode training: follow it
 * with load_weights_task() so the BF16->AMX re-quantization runs and the
 * kernel's internal quantized buffers are refreshed. Reusing the existing
 * object this way was reported by the author as ~0.6s/layer (quantization
 * only) versus ~1.9s/layer for recreating the C++ MOE object.
 *
 * NOTE(review): the buffers are stored by pointer, so they presumably must
 * stay alive until the subsequent reload has consumed them — confirm with
 * the Python caller.
 *
 * @param gate New value for config.gate_proj.
 * @param up   New value for config.up_proj.
 * @param down New value for config.down_proj.
 */
void set_base_weight_pointers(void* gate, void* up, void* down) {
  // Swap in the new source buffers; the three stores are independent.
  config.down_proj = down;
  config.up_proj = up;
  config.gate_proj = gate;

  // Flag the weights as stale so the next load redoes partitioning + quantization.
  weights_loaded = false;
}
};

#endif // CPUINFER_OPERATOR_MOE_SFT_TP_HPP
6 changes: 6 additions & 0 deletions kt-kernel/python/experts.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,8 @@ def __new__(
# Quantization config (for K-Group SFT methods)
group_size: int = 128,
zero_point: bool = True,
# Full weight gradient mode (for full fine-tuning without LoRA)
full_weight_grad: bool = False,
# V4-Flash 2604B SwiGLU clamp limit. 0.0 = disabled (default for
# every dtype except DSV4-2604B routed experts, which set this to
# 10.0 to match trtllm gemm1_clamp_limit / deep_gemm
Expand Down Expand Up @@ -247,6 +249,7 @@ def __new__(
max_cache_depth=max_cache_depth,
group_size=group_size,
zero_point=zero_point,
full_weight_grad=full_weight_grad,
)

# Forward static methods to the base class
Expand Down Expand Up @@ -292,6 +295,7 @@ def clear_sft_buffer_cache():
to reset the buffer state or free memory during SFT.
"""
from .sft.base import KExpertsSFTBuffer

KExpertsSFTBuffer.clear_cache()


Expand Down Expand Up @@ -393,6 +397,7 @@ def _create_sft_wrapper(
max_cache_depth: int,
group_size: int,
zero_point: bool,
full_weight_grad: bool = False,
):
"""
Create an SFT wrapper based on the method.
Expand Down Expand Up @@ -423,4 +428,5 @@ def _create_sft_wrapper(
method=method,
group_size=group_size,
zero_point=zero_point,
full_weight_grad=full_weight_grad,
)
13 changes: 11 additions & 2 deletions kt-kernel/python/sft/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,15 @@
from .base import BaseSFTMoEWrapper, KExpertsSFTBuffer
from .amx import AMXSFTMoEWrapper
from .arch import (
MOEArchConfig, get_moe_arch_config, get_moe_module, move_non_experts_to_gpu, get_expert_device,
KTAMXError, KTAMXNotAvailableError, KTAMXModelNotSupportedError, KTAMXConfigError,
MOEArchConfig,
get_moe_arch_config,
get_moe_module,
move_non_experts_to_gpu,
get_expert_device,
KTAMXError,
KTAMXNotAvailableError,
KTAMXModelNotSupportedError,
KTAMXConfigError,
)
from .autograd import KTMoEFunction
from .layer import KTMoELayerWrapper
Expand All @@ -28,6 +35,7 @@
from .lora import (
kt_adapt_peft_lora,
get_kt_lora_params,
get_kt_trainable_params,
update_kt_lora_pointers,
sync_kt_lora_gradients,
save_lora_experts_to_adapter,
Expand Down Expand Up @@ -67,6 +75,7 @@
"INT8ExpertWeights",
"kt_adapt_peft_lora",
"get_kt_lora_params",
"get_kt_trainable_params",
"update_kt_lora_pointers",
"sync_kt_lora_gradients",
"save_lora_experts_to_adapter",
Expand Down
Loading
Loading