diff --git a/kt-kernel/ext_bindings.cpp b/kt-kernel/ext_bindings.cpp index 66f6fd318..5f4e4ad1d 100644 --- a/kt-kernel/ext_bindings.cpp +++ b/kt-kernel/ext_bindings.cpp @@ -322,22 +322,29 @@ class MOESFTBindings { intptr_t grad_down_lora_a; intptr_t grad_down_lora_b; intptr_t grad_weights; + intptr_t grad_gate_proj; + intptr_t grad_up_proj; + intptr_t grad_down_proj; }; static void inner(void* args) { Args* args_ = (Args*)args; args_->cpuinfer->enqueue(&TP_MOE_SFT::backward_binding, args_->moe, args_->grad_output, args_->grad_input, args_->grad_gate_lora_a, args_->grad_gate_lora_b, args_->grad_up_lora_a, args_->grad_up_lora_b, args_->grad_down_lora_a, args_->grad_down_lora_b, - args_->grad_weights); + args_->grad_weights, args_->grad_gate_proj, args_->grad_up_proj, + args_->grad_down_proj); } static std::pair cpuinfer_interface(std::shared_ptr> moe, intptr_t grad_output, intptr_t grad_input, intptr_t grad_gate_lora_a, intptr_t grad_gate_lora_b, intptr_t grad_up_lora_a, intptr_t grad_up_lora_b, intptr_t grad_down_lora_a, - intptr_t grad_down_lora_b, intptr_t grad_weights) { + intptr_t grad_down_lora_b, intptr_t grad_weights, + intptr_t grad_gate_proj, intptr_t grad_up_proj, + intptr_t grad_down_proj) { Args* args = new Args{nullptr, moe.get(), grad_output, grad_input, grad_gate_lora_a, grad_gate_lora_b, grad_up_lora_a, grad_up_lora_b, - grad_down_lora_a, grad_down_lora_b, grad_weights}; + grad_down_lora_a, grad_down_lora_b, grad_weights, + grad_gate_proj, grad_up_proj, grad_down_proj}; return std::make_pair((intptr_t)&inner, (intptr_t)args); } }; @@ -400,7 +407,13 @@ void bind_moe_sft_module(py::module_& moe_module, const char* name) { self.prepare_and_save_bwd((void*)gate, (void*)up, (void*)down, path); }) .def("submit_backward_repack", &MoeClass::submit_backward_repack) - .def("wait_backward_repack", &MoeClass::wait_backward_repack); + .def("wait_backward_repack", &MoeClass::wait_backward_repack) + // Update base weight BF16 pointers for reload_base_weights (full mode training) + // After calling this, call load_weights_task() to re-quantize BF16->AMX + .def("set_base_weight_pointers", + [](MoeClass& self, intptr_t gate, intptr_t up, intptr_t down) { + self.set_base_weight_pointers((void*)gate, (void*)up, (void*)down); + }); } #endif // defined(__x86_64__) && defined(USE_AMX_AVX_KERNEL) @@ -775,7 +788,11 @@ PYBIND11_MODULE(kt_kernel_ext, m) { .DEF_PTR_PROPERTY(MOESFTConfig, up_lora_a) .DEF_PTR_PROPERTY(MOESFTConfig, up_lora_b) .DEF_PTR_PROPERTY(MOESFTConfig, down_lora_a) - .DEF_PTR_PROPERTY(MOESFTConfig, down_lora_b); + .DEF_PTR_PROPERTY(MOESFTConfig, down_lora_b) + .def_readwrite("full_weight_grad", &MOESFTConfig::full_weight_grad) + .DEF_PTR_PROPERTY(MOESFTConfig, grad_gate_proj) + .DEF_PTR_PROPERTY(MOESFTConfig, grad_up_proj) + .DEF_PTR_PROPERTY(MOESFTConfig, grad_down_proj); py::class_>(moe_module, "MoE_Interface"); diff --git a/kt-kernel/operators/amx/sft_moe.hpp b/kt-kernel/operators/amx/sft_moe.hpp index 295c263c0..01e609a20 100644 --- a/kt-kernel/operators/amx/sft_moe.hpp +++ b/kt-kernel/operators/amx/sft_moe.hpp @@ -1404,7 +1404,8 @@ class AMX_SFT_MOE_TP : public BaseMOE { void backward(const void* grad_output, void* grad_input, void* grad_gate_lora_a, void* grad_gate_lora_b, void* grad_up_lora_a, void* grad_up_lora_b, void* grad_down_lora_a, void* grad_down_lora_b, void* grad_weights, int full_intermediate_size = 0, float* fp32_grad_down_lora_b = nullptr, - float* fp32_grad_gate_lora_a = nullptr, float* fp32_grad_up_lora_a = nullptr) { + float* fp32_grad_gate_lora_a = nullptr, float* fp32_grad_up_lora_a = nullptr, + void* grad_gate_proj = nullptr, void* grad_up_proj = nullptr, void* grad_down_proj = nullptr) { // If full_intermediate_size not provided, use local (non-TP mode) if (full_intermediate_size == 0) full_intermediate_size = config_.intermediate_size; SFT_POOL_LOG("bwd_enter", config_.layer_idx, tp_part_idx, 0, cache_stack_top_, forward_pool_bytes_, @@ -1888,7 +1889,14 @@ class AMX_SFT_MOE_TP : public BaseMOE { print_grad_stats_fp32("grad_weights", (const float*)grad_weights, qlen * k); } - // ★ Cache pool is NOT freed here — kept for reuse across steps. + // ===================================================================== + // Step 5: Base weight gradient accumulation (full weight grad mode) + // ===================================================================== + if (sft_config_.full_weight_grad && grad_gate_proj && grad_up_proj && grad_down_proj) { + backward_base_weight_grad(cache, grad_output, grad_gate_proj, grad_up_proj, grad_down_proj); + } + + // \u2605 Cache pool is NOT freed here \u2014 kept for reuse across steps. // alloc_or_resize_cache_pool() is grow-only, so same-seqlen steps // reuse the existing allocation without malloc/free overhead. // Previously: free_seqlen_buffers() was called here, costing ~3.6ms per TP. @@ -1897,6 +1905,92 @@ class AMX_SFT_MOE_TP : public BaseMOE { cache.valid = false; } + /** + * @brief Compute base weight gradients via outer-product accumulation. + * + * For each activated expert e, computes: + * grad_gate_proj[e] = grad_gate_out[e]^T @ input[e] -> [I, H] + * grad_up_proj[e] = grad_up_out[e]^T @ input[e] -> [I, H] + * grad_down_proj[e] = grad_output[e]^T @ intermediate[e] -> [H, I] + * + * Uses FP32 accumulator for precision, writes BF16 output. + */ + void backward_base_weight_grad(const ForwardCache& cache, const void* grad_output, + void* grad_gate_proj, void* grad_up_proj, void* grad_down_proj) { + const int H = config_.hidden_size; + const int I = config_.intermediate_size; + const int E = config_.expert_num; + int activated_expert = cache.activated_expert_cache; + + auto* ggp = static_cast(grad_gate_proj); // [E, I, H] + auto* gup_ptr = static_cast(grad_up_proj); // [E, I, H] + auto* gdp = static_cast(grad_down_proj); // [E, H, I] + auto* grad_out_bf16 = static_cast(grad_output); + + for (int task_id = 0; task_id < activated_expert; task_id++) { + int expert_idx = cache.m_expert_id_map_cache[task_id]; + int m = cache.m_local_num_cache[expert_idx]; + if (m == 0) continue; + + int pos_start = 0; + for (int prev_id = 0; prev_id < task_id; prev_id++) { + pos_start += cache.m_local_num_cache[cache.m_expert_id_map_cache[prev_id]]; + } + + const auto& local_pos = cache.m_local_pos_cache[expert_idx]; + + // Allocate FP32 accumulators from forward pool (safe during backward) + float* acc_gate = static_cast(forward_pool_); // [I, H] + float* acc_up = acc_gate + (size_t)I * H; // [I, H] + float* acc_down = acc_up + (size_t)I * H; // [H, I] + + std::memset(acc_gate, 0, (size_t)I * H * sizeof(float)); + std::memset(acc_up, 0, (size_t)I * H * sizeof(float)); + std::memset(acc_down, 0, (size_t)H * I * sizeof(float)); + + for (int t = 0; t < m; t++) { + int tok_pos = local_pos[t]; + const ggml_bf16_t* input_row = cache.input_cache + (size_t)tok_pos * H; + const ggml_bf16_t* gate_grad_row = grad_gate_output_ + (size_t)(pos_start + t) * I; + const ggml_bf16_t* up_grad_row = grad_up_output_ + (size_t)(pos_start + t) * I; + const ggml_bf16_t* inter_row = cache.intermediate_cache + (size_t)(pos_start + t) * I; + const ggml_bf16_t* grad_out_row = grad_out_bf16 + (size_t)tok_pos * H; + + // gate_proj grad: [I, H] += grad_gate_out[t]^T @ input[t] + for (int i = 0; i < I; i++) { + float gg = GGML_BF16_TO_FP32(gate_grad_row[i]); + float gu = GGML_BF16_TO_FP32(up_grad_row[i]); + for (int h = 0; h < H; h++) { + float inp = GGML_BF16_TO_FP32(input_row[h]); + acc_gate[i * H + h] += gg * inp; + acc_up[i * H + h] += gu * inp; + } + } + + // down_proj grad: [H, I] += grad_output[t]^T @ intermediate[t] + for (int h = 0; h < H; h++) { + float go = GGML_BF16_TO_FP32(grad_out_row[h]); + for (int i = 0; i < I; i++) { + acc_down[h * I + i] += go * GGML_BF16_TO_FP32(inter_row[i]); + } + } + } + + // Convert FP32 accumulators to BF16 and store + for (int i = 0; i < I; i++) { + for (int h = 0; h < H; h++) { + ggp[(size_t)expert_idx * I * H + (size_t)i * H + h] = GGML_FP32_TO_BF16(acc_gate[i * H + h]); + gup_ptr[(size_t)expert_idx * I * H + (size_t)i * H + h] = GGML_FP32_TO_BF16(acc_up[i * H + h]); + } + } + for (int h = 0; h < H; h++) { + for (int i = 0; i < I; i++) { + gdp[(size_t)expert_idx * H * I + (size_t)h * I + i] = GGML_FP32_TO_BF16(acc_down[h * I + i]); + } + } + } + } + /** * @brief Get qlen from the top of the forward cache stack. * diff --git a/kt-kernel/operators/common.hpp b/kt-kernel/operators/common.hpp index da7bbd4bd..1e42b4929 100644 --- a/kt-kernel/operators/common.hpp +++ b/kt-kernel/operators/common.hpp @@ -346,6 +346,15 @@ struct MOESFTConfig : public GeneralMOEConfig { void* down_lora_a = nullptr; // [expert_num, lora_rank, intermediate_size] void* down_lora_b = nullptr; // [expert_num, hidden_size, lora_rank] + // Full weight gradient configuration + bool full_weight_grad = false; + + // Base weight gradient buffer pointers (directly pointing to Python tensor memory, zero-copy) + // Only used when full_weight_grad == true + void* grad_gate_proj = nullptr; // [expert_num, intermediate_size, hidden_size] + void* grad_up_proj = nullptr; // [expert_num, intermediate_size, hidden_size] + void* grad_down_proj = nullptr; // [expert_num, hidden_size, intermediate_size] + MOESFTConfig() : GeneralMOEConfig() {} MOESFTConfig(int expert_num, int routed_expert_num, int hidden_size, int intermediate_size) diff --git a/kt-kernel/operators/moe-sft-tp.hpp b/kt-kernel/operators/moe-sft-tp.hpp index afd845997..8801166a2 100644 --- a/kt-kernel/operators/moe-sft-tp.hpp +++ b/kt-kernel/operators/moe-sft-tp.hpp @@ -343,7 +343,7 @@ class TP_MOE_SFT : public TP_MOE { throw std::runtime_error("K2 pre-quantized mode does not support TP > 1 yet"); } } else if (config.gate_proj != nullptr) { - printf("TP_MOE_SFT: From BF16 with partitioning\n"); + // printf("TP_MOE_SFT: From BF16 with partitioning\n"); // Temporary storage for partitioned weights std::vector temp_gate(tp_count); @@ -548,7 +548,8 @@ class TP_MOE_SFT : public TP_MOE { */ void backward(const void* grad_output, void* grad_input, void* grad_gate_lora_a, void* grad_gate_lora_b, void* grad_up_lora_a, void* grad_up_lora_b, void* grad_down_lora_a, void* grad_down_lora_b, - void* grad_weights) { + void* grad_weights, void* grad_gate_proj = nullptr, void* grad_up_proj = nullptr, + void* grad_down_proj = nullptr) { auto pool = config.pool; @@ -694,7 +695,8 @@ class TP_MOE_SFT : public TP_MOE { tp_down_a_ptr[numa_id], /* copy-type: direct write */ nullptr, /* grad_down_lora_b — unused, FP32 path below */ part_grad_weights_[numa_id], full_intermediate_size, tp_fp32_down_b[numa_id], - tp_fp32_gate_a[numa_id], tp_fp32_up_a[numa_id]); + tp_fp32_gate_a[numa_id], tp_fp32_up_a[numa_id], + grad_gate_proj, grad_up_proj, grad_down_proj); }); // // Collect per-thread timing from all NUMA subpools @@ -891,10 +893,11 @@ class TP_MOE_SFT : public TP_MOE { */ void backward_binding(intptr_t grad_output, intptr_t grad_input, intptr_t grad_gate_lora_a, intptr_t grad_gate_lora_b, intptr_t grad_up_lora_a, intptr_t grad_up_lora_b, intptr_t grad_down_lora_a, - intptr_t grad_down_lora_b, intptr_t grad_weights) { + intptr_t grad_down_lora_b, intptr_t grad_weights, intptr_t grad_gate_proj, + intptr_t grad_up_proj, intptr_t grad_down_proj) { backward((const void*)grad_output, (void*)grad_input, (void*)grad_gate_lora_a, (void*)grad_gate_lora_b, (void*)grad_up_lora_a, (void*)grad_up_lora_b, (void*)grad_down_lora_a, (void*)grad_down_lora_b, - (void*)grad_weights); + (void*)grad_weights, (void*)grad_gate_proj, (void*)grad_up_proj, (void*)grad_down_proj); } /** @@ -1106,6 +1109,22 @@ class TP_MOE_SFT : public TP_MOE { update_lora_weights((void*)gate_lora_a, (void*)gate_lora_b, (void*)up_lora_a, (void*)up_lora_b, (void*)down_lora_a, (void*)down_lora_b); } + + /** + * @brief Update base weight BF16 pointers for reload_base_weights (full mode training). + * + * After calling this, call load_weights_task() to re-quantize BF16->AMX + * and update the C++ kernel's internal quantized buffers. + * This avoids creating a new C++ MOE object (~0.6s/layer for quantization + * vs ~1.9s/layer for full object recreation). + */ + void set_base_weight_pointers(void* gate, void* up, void* down) { + config.gate_proj = gate; + config.up_proj = up; + config.down_proj = down; + // Mark that weights need re-loading (partitioning + quantization) + weights_loaded = false; + } }; #endif // CPUINFER_OPERATOR_MOE_SFT_TP_HPP diff --git a/kt-kernel/python/experts.py b/kt-kernel/python/experts.py index 2e99bf339..d54c43eeb 100644 --- a/kt-kernel/python/experts.py +++ b/kt-kernel/python/experts.py @@ -143,6 +143,8 @@ def __new__( # Quantization config (for K-Group SFT methods) group_size: int = 128, zero_point: bool = True, + # Full weight gradient mode (for full fine-tuning without LoRA) + full_weight_grad: bool = False, # V4-Flash 2604B SwiGLU clamp limit. 0.0 = disabled (default for # every dtype except DSV4-2604B routed experts, which set this to # 10.0 to match trtllm gemm1_clamp_limit / deep_gemm @@ -247,6 +249,7 @@ def __new__( max_cache_depth=max_cache_depth, group_size=group_size, zero_point=zero_point, + full_weight_grad=full_weight_grad, ) # Forward static methods to the base class @@ -292,6 +295,7 @@ def clear_sft_buffer_cache(): to reset the buffer state or free memory during SFT. """ from .sft.base import KExpertsSFTBuffer + KExpertsSFTBuffer.clear_cache() @@ -393,6 +397,7 @@ def _create_sft_wrapper( max_cache_depth: int, group_size: int, zero_point: bool, + full_weight_grad: bool = False, ): """ Create an SFT wrapper based on the method. @@ -423,4 +428,5 @@ def _create_sft_wrapper( method=method, group_size=group_size, zero_point=zero_point, + full_weight_grad=full_weight_grad, ) diff --git a/kt-kernel/python/sft/__init__.py b/kt-kernel/python/sft/__init__.py index 7cab43bd2..88b266e08 100644 --- a/kt-kernel/python/sft/__init__.py +++ b/kt-kernel/python/sft/__init__.py @@ -14,8 +14,15 @@ from .base import BaseSFTMoEWrapper, KExpertsSFTBuffer from .amx import AMXSFTMoEWrapper from .arch import ( - MOEArchConfig, get_moe_arch_config, get_moe_module, move_non_experts_to_gpu, get_expert_device, - KTAMXError, KTAMXNotAvailableError, KTAMXModelNotSupportedError, KTAMXConfigError, + MOEArchConfig, + get_moe_arch_config, + get_moe_module, + move_non_experts_to_gpu, + get_expert_device, + KTAMXError, + KTAMXNotAvailableError, + KTAMXModelNotSupportedError, + KTAMXConfigError, ) from .autograd import KTMoEFunction from .layer import KTMoELayerWrapper @@ -28,6 +35,7 @@ from .lora import ( kt_adapt_peft_lora, get_kt_lora_params, + get_kt_trainable_params, update_kt_lora_pointers, sync_kt_lora_gradients, save_lora_experts_to_adapter, @@ -67,6 +75,7 @@ "INT8ExpertWeights", "kt_adapt_peft_lora", "get_kt_lora_params", + "get_kt_trainable_params", "update_kt_lora_pointers", "sync_kt_lora_gradients", "save_lora_experts_to_adapter", diff --git a/kt-kernel/python/sft/amx.py b/kt-kernel/python/sft/amx.py index 3f3270f01..595f1e081 100644 --- a/kt-kernel/python/sft/amx.py +++ b/kt-kernel/python/sft/amx.py @@ -9,6 +9,7 @@ from __future__ import annotations import ctypes +import logging import os import glob as _glob import torch @@ -16,6 +17,8 @@ from kt_kernel_ext.moe import MOESFTConfig +logger = logging.getLogger(__name__) + from ..utils.loader import BF16SafeTensorLoader, SafeTensorLoader try: @@ -79,6 +82,7 @@ def __init__( method: str = "AMXBF16_SFT", group_size: int = 128, zero_point: bool = True, + full_weight_grad: bool = False, ): if not _HAS_AMX_SFT_SUPPORT: raise RuntimeError( @@ -100,6 +104,7 @@ def __init__( lora_rank=lora_rank, lora_alpha=lora_alpha, max_cache_depth=max_cache_depth, + full_weight_grad=full_weight_grad, ) self.method = method @@ -138,19 +143,42 @@ def _make_backward_task(self, buffer: KExpertsSFTBuffer): return self.moe.backward_task( buffer.grad_output_cpu.data_ptr(), buffer.grad_input_cpu.data_ptr(), - 0, 0, 0, 0, 0, 0, + 0, + 0, + 0, + 0, + 0, + 0, buffer.grad_weights.data_ptr(), + 0, + 0, + 0, # grad_gate_proj, grad_up_proj, grad_down_proj ) + + # Base weight grad pointers (nullptr if not in full mode) + grad_gate_proj_ptr = ( + self.grad_gate_proj_buf.data_ptr() if self._full_weight_grad and self.grad_gate_proj_buf is not None else 0 + ) + grad_up_proj_ptr = ( + self.grad_up_proj_buf.data_ptr() if self._full_weight_grad and self.grad_up_proj_buf is not None else 0 + ) + grad_down_proj_ptr = ( + self.grad_down_proj_buf.data_ptr() if self._full_weight_grad and self.grad_down_proj_buf is not None else 0 + ) + return self.moe.backward_task( buffer.grad_output_cpu.data_ptr(), buffer.grad_input_cpu.data_ptr(), - self.grad_gate_lora_a.data_ptr(), - self.grad_gate_lora_b.data_ptr(), - self.grad_up_lora_a.data_ptr(), - self.grad_up_lora_b.data_ptr(), - self.grad_down_lora_a.data_ptr(), - self.grad_down_lora_b.data_ptr(), + self.grad_gate_lora_a.data_ptr() if self.lora_rank > 0 else 0, + self.grad_gate_lora_b.data_ptr() if self.lora_rank > 0 else 0, + self.grad_up_lora_a.data_ptr() if self.lora_rank > 0 else 0, + self.grad_up_lora_b.data_ptr() if self.lora_rank > 0 else 0, + self.grad_down_lora_a.data_ptr() if self.lora_rank > 0 else 0, + self.grad_down_lora_b.data_ptr() if self.lora_rank > 0 else 0, buffer.grad_weights.data_ptr(), + grad_gate_proj_ptr, + grad_up_proj_ptr, + grad_down_proj_ptr, ) # ========== Weight loading ========== @@ -174,6 +202,7 @@ def load_weights(self, physical_to_logical_map_cpu: torch.Tensor) -> None: config.layer_idx = self.layer_idx config.share_backward_bb = getattr(self, "share_backward_bb", False) config.share_cache_pool = getattr(self, "share_cache_pool", False) + config.full_weight_grad = self._full_weight_grad if getattr(self, "_use_kt_direct_load", False): config.load = True @@ -215,6 +244,13 @@ def load_weights(self, physical_to_logical_map_cpu: torch.Tensor) -> None: config.quant_config.group_size = self.group_size config.quant_config.zero_point = self.zero_point + # Release old C++ MOE object before creating a new one to avoid memory leak + old_moe = getattr(self, "moe", None) + if old_moe is not None: + del old_moe + import gc + gc.collect() + self.moe = self._moe_class(config) self.cpu_infer.submit(self.moe.load_weights_task()) @@ -224,9 +260,11 @@ def load_weights(self, physical_to_logical_map_cpu: torch.Tensor) -> None: self.cpu_infer.sync() # Release Python-side weight tensors (C++ copied them) - self.gate_proj = None - self.up_proj = None - self.down_proj = None + # In full_weight_grad mode, keep them for nn.Parameter initialization + if not self._full_weight_grad: + self.gate_proj = None + self.up_proj = None + self.down_proj = None if getattr(self, "_bf16_gate_proj", None) is not None: self._bf16_gate_proj = None @@ -235,18 +273,34 @@ def load_weights(self, physical_to_logical_map_cpu: torch.Tensor) -> None: if getattr(self, "_use_projs_path", False): for attr in [ - "_gate_weights_per_numa", "_up_weights_per_numa", "_down_weights_per_numa", - "_gate_scales_per_numa", "_up_scales_per_numa", "_down_scales_per_numa", - "_gate_projs_ptrs", "_up_projs_ptrs", "_down_projs_ptrs", - "_gate_scale_ptrs", "_up_scale_ptrs", "_down_scale_ptrs", + "_gate_weights_per_numa", + "_up_weights_per_numa", + "_down_weights_per_numa", + "_gate_scales_per_numa", + "_up_scales_per_numa", + "_down_scales_per_numa", + "_gate_projs_ptrs", + "_up_projs_ptrs", + "_down_projs_ptrs", + "_gate_scale_ptrs", + "_up_scale_ptrs", + "_down_scale_ptrs", ]: setattr(self, attr, None) if getattr(self, "_has_bwd_projs", False): for attr in [ - "_gate_bwd_weights_per_numa", "_up_bwd_weights_per_numa", "_down_bwd_weights_per_numa", - "_gate_bwd_scales_per_numa", "_up_bwd_scales_per_numa", "_down_bwd_scales_per_numa", - "_gate_bwd_projs_ptrs", "_up_bwd_projs_ptrs", "_down_bwd_projs_ptrs", - "_gate_bwd_scale_ptrs", "_up_bwd_scale_ptrs", "_down_bwd_scale_ptrs", + "_gate_bwd_weights_per_numa", + "_up_bwd_weights_per_numa", + "_down_bwd_weights_per_numa", + "_gate_bwd_scales_per_numa", + "_up_bwd_scales_per_numa", + "_down_bwd_scales_per_numa", + "_gate_bwd_projs_ptrs", + "_up_bwd_projs_ptrs", + "_down_bwd_projs_ptrs", + "_gate_bwd_scale_ptrs", + "_up_bwd_scale_ptrs", + "_down_bwd_scale_ptrs", ]: setattr(self, attr, None) @@ -297,6 +351,7 @@ def _load_base_weights_from_file(self) -> None: self.up_proj = torch.stack(up_weights, dim=0).contiguous() self.down_proj = torch.stack(down_weights, dim=0).contiguous() else: + def _make_ptrs(arrays_per_numa): return [ [ @@ -349,12 +404,18 @@ def _make_ptrs(arrays_per_numa): def init_lora_weights( self, - gate_lora_a: torch.Tensor, gate_lora_b: torch.Tensor, - up_lora_a: torch.Tensor, up_lora_b: torch.Tensor, - down_lora_a: torch.Tensor, down_lora_b: torch.Tensor, - grad_gate_lora_a: torch.Tensor, grad_gate_lora_b: torch.Tensor, - grad_up_lora_a: torch.Tensor, grad_up_lora_b: torch.Tensor, - grad_down_lora_a: torch.Tensor, grad_down_lora_b: torch.Tensor, + gate_lora_a: torch.Tensor, + gate_lora_b: torch.Tensor, + up_lora_a: torch.Tensor, + up_lora_b: torch.Tensor, + down_lora_a: torch.Tensor, + down_lora_b: torch.Tensor, + grad_gate_lora_a: torch.Tensor, + grad_gate_lora_b: torch.Tensor, + grad_up_lora_a: torch.Tensor, + grad_up_lora_b: torch.Tensor, + grad_down_lora_a: torch.Tensor, + grad_down_lora_b: torch.Tensor, ) -> None: expected_shapes = { "gate_lora_a": (self.num_experts, self.lora_rank, self.hidden_size), @@ -365,9 +426,12 @@ def init_lora_weights( "down_lora_b": (self.num_experts, self.hidden_size, self.lora_rank), } provided = { - "gate_lora_a": gate_lora_a, "gate_lora_b": gate_lora_b, - "up_lora_a": up_lora_a, "up_lora_b": up_lora_b, - "down_lora_a": down_lora_a, "down_lora_b": down_lora_b, + "gate_lora_a": gate_lora_a, + "gate_lora_b": gate_lora_b, + "up_lora_a": up_lora_a, + "up_lora_b": up_lora_b, + "down_lora_a": down_lora_a, + "down_lora_b": down_lora_b, } for name, tensor in provided.items(): expected = expected_shapes[name] @@ -399,6 +463,8 @@ def update_lora_weights(self) -> None: if self._is_skip_lora: return if not self._lora_initialized: + if self.lora_rank <= 0: + return # Full mode without LoRA — no LoRA weights to update raise RuntimeError("LoRA weights not initialized. Call init_lora_weights() first.") self.cpu_infer.submit( @@ -413,6 +479,52 @@ def update_lora_weights(self) -> None: ) self.cpu_infer.sync() + def update_base_weights(self) -> None: + """Sync updated base weight parameters back to C++ kernel after optimizer step.""" + if not self._weights_loaded: + raise RuntimeError("Weights not loaded. Call load_weights() first.") + if not self._full_weight_grad: + return # No base weights to update in LoRA mode + if self.gate_proj_buf is None: + raise RuntimeError("Base weight buffers not initialized. Call init_full_weight_grad_buffers() first.") + + logger.info(f"Layer {self.layer_idx}: update_base_weights() - syncing updated weights to C++ kernel") + + # Preferred path: update config pointers on existing C++ object and re-quantize. + # This avoids full C++ MOE object recreation (~0.6s/layer vs ~1.9s/layer). + if hasattr(self.moe, "set_base_weight_pointers"): + self.moe.set_base_weight_pointers( + self.gate_proj_buf.data.data_ptr(), + self.up_proj_buf.data.data_ptr(), + self.down_proj_buf.data.data_ptr(), + ) + self.cpu_infer.submit(self.moe.load_weights_task()) + self.cpu_infer.sync() + logger.info(f"Layer {self.layer_idx}: update_base_weights() - re-quantized existing kernel") + return + + # Fallback: full reload path (creates new C++ MOE object) + # This is slower but works without C++ set_base_weight_pointers support. + logger.warning( + f"Layer {self.layer_idx}: set_base_weight_pointers not available, " + f"falling back to full C++ MOE object recreation" + ) + old_moe = getattr(self, "moe", None) + if old_moe is not None: + del old_moe + + self.gate_proj = self.gate_proj_buf.data + self.up_proj = self.up_proj_buf.data + self.down_proj = self.down_proj_buf.data + physical_to_logical_map = torch.arange(self.num_experts, dtype=torch.int64, device="cpu") + self._weights_loaded = False # Allow re-load + self.load_weights_from_tensors( + gate_proj=self.gate_proj, + up_proj=self.up_proj, + down_proj=self.down_proj, + physical_to_logical_map_cpu=physical_to_logical_map, + ) + def save_backward_weights_from_tensors( self, gate_proj: torch.Tensor, diff --git a/kt-kernel/python/sft/autograd.py b/kt-kernel/python/sft/autograd.py index 0264e9de6..98c92ed97 100644 --- a/kt-kernel/python/sft/autograd.py +++ b/kt-kernel/python/sft/autograd.py @@ -38,12 +38,17 @@ def forward( training: bool, train_lora: bool, all_qlens: list[int] | tuple[int, ...] | None, + gate_proj_param: torch.Tensor | None = None, + up_proj_param: torch.Tensor | None = None, + down_proj_param: torch.Tensor | None = None, ) -> torch.Tensor: if _KT_SFT_DEBUG: logging.debug( "KTMoEFunction.forward: layer=%d training=%s train_lora=%s", - layer_idx, training, train_lora, + layer_idx, + training, + train_lora, ) original_device = hidden_states.device @@ -52,6 +57,7 @@ def forward( qlen = batch_size * seq_len import torch.distributed as dist + dist_on = dist.is_initialized() and dist.get_world_size() > 1 rank = dist.get_rank() if dist.is_initialized() else 0 world_size = dist.get_world_size() if dist_on else 1 @@ -65,13 +71,9 @@ def forward( else: all_qlens_list = [int(q) for q in all_qlens] if len(all_qlens_list) != world_size: - raise RuntimeError( - f"all_qlens length mismatch: got {len(all_qlens_list)}, expected {world_size}" - ) + raise RuntimeError(f"all_qlens length mismatch: got {len(all_qlens_list)}, expected {world_size}") if int(all_qlens_list[rank]) != qlen: - raise RuntimeError( - f"Rank {rank} qlen mismatch: local={qlen}, all_qlens[{rank}]={all_qlens_list[rank]}" - ) + raise RuntimeError(f"Rank {rank} qlen mismatch: local={qlen}, all_qlens[{rank}]={all_qlens_list[rank]}") total_qlen = sum(all_qlens_list) # Rank 0: sync CPU result and split by real lengths @@ -100,9 +102,7 @@ def forward( output = cpu_output.view(batch_size, seq_len, hidden_size).to(dtype=original_dtype) else: # Broadcast-only rank (no wrapper) - output = torch.empty( - batch_size, seq_len, hidden_size, device=original_device, dtype=original_dtype - ) + output = torch.empty(batch_size, seq_len, hidden_size, device=original_device, dtype=original_dtype) ctx.wrapper = wrapper ctx.hidden_size = hidden_size @@ -120,6 +120,11 @@ def forward( ctx.num_experts_per_tok = num_experts_per_tok ctx.layer_idx = layer_idx + # Store base weight param references for gradient flow in full mode + ctx.full_weight_grad = ( + wrapper is not None and getattr(wrapper, "_full_weight_grad", False) and gate_proj_param is not None + ) + # Save a sentinel tensor so non-reentrant checkpoint's saved_tensors # hooks can intercept it. When backward accesses ctx.saved_tensors, # the checkpoint unpack hook triggers a full recompute of the decoder @@ -135,7 +140,7 @@ def forward( @staticmethod def backward(ctx, grad_output: torch.Tensor): # Wait for any in-flight async repack before recompute forward uses the pool - if getattr(ctx.wrapper, 'share_backward_bb', False): + if getattr(ctx.wrapper, "share_backward_bb", False): ctx.wrapper.wait_backward_repack() # Access saved_tensors FIRST — under non-reentrant checkpoint this @@ -152,12 +157,15 @@ def backward(ctx, grad_output: torch.Tensor): num_experts_per_tok = ctx.num_experts_per_tok import torch.distributed as dist + rank = dist.get_rank() if dist.is_initialized() else 0 if _KT_SFT_DEBUG: logging.debug( "KTMoEFunction.backward: layer=%d dist_on=%s qlen=%d", - getattr(ctx, "layer_idx", -1), dist_on, qlen, + getattr(ctx, "layer_idx", -1), + dist_on, + qlen, ) if dist_on: @@ -243,12 +251,39 @@ def backward(ctx, grad_output: torch.Tensor): grad_weights = grad_weights.to(dtype=torch.bfloat16) else: # No wrapper, no dist — shouldn't happen in normal flow - grad_input = torch.zeros(batch_size, seq_len, hidden_size, device=ctx.original_device, dtype=ctx.original_dtype) + grad_input = torch.zeros( + batch_size, seq_len, hidden_size, device=ctx.original_device, dtype=ctx.original_dtype + ) grad_weights = torch.zeros(ctx.weights_shape, device=ctx.weights_device, dtype=ctx.weights_dtype) # Trigger async repack for next MoE layer in backward order - next_bwd = getattr(ctx.wrapper, '_next_backward_wrapper', None) - if next_bwd is not None and getattr(next_bwd, 'share_backward_bb', False): + next_bwd = getattr(ctx.wrapper, "_next_backward_wrapper", None) + if next_bwd is not None and getattr(next_bwd, "share_backward_bb", False): next_bwd.submit_backward_repack() - return grad_input, None, grad_weights, None, None, None, None, None, None, None, None + # Base weight gradients: return C++-written grad buffers in full mode, None otherwise + if ctx.full_weight_grad and ctx.wrapper is not None: + grad_gate_proj = ctx.wrapper.grad_gate_proj_buf + grad_up_proj = ctx.wrapper.grad_up_proj_buf + grad_down_proj = ctx.wrapper.grad_down_proj_buf + else: + grad_gate_proj = None + grad_up_proj = None + grad_down_proj = None + + return ( + grad_input, + None, + grad_weights, + None, + None, + None, + None, + None, + None, + None, + None, + grad_gate_proj, + grad_up_proj, + grad_down_proj, + ) diff --git a/kt-kernel/python/sft/base.py b/kt-kernel/python/sft/base.py index 25b0e2cb3..d56e248db 100644 --- a/kt-kernel/python/sft/base.py +++ b/kt-kernel/python/sft/base.py @@ -125,6 +125,7 @@ def __init__( lora_rank: int = 16, lora_alpha: float = 32.0, max_cache_depth: int = 1, + full_weight_grad: bool = False, ): self.cpu_infer = self._get_cpu_infer(cpuinfer_threads, threadpool_count) @@ -134,7 +135,7 @@ def __init__( moe_intermediate_size=moe_intermediate_size, num_experts_per_tok=num_experts_per_tok, ) - self._validate_sft_config(lora_rank, lora_alpha, max_cache_depth) + self._validate_sft_config(lora_rank, lora_alpha, max_cache_depth, full_weight_grad=full_weight_grad) self.layer_idx = layer_idx self.num_experts = num_experts @@ -148,9 +149,11 @@ def __init__( self.lora_rank = lora_rank self.lora_alpha = lora_alpha - self.lora_scaling = lora_alpha / lora_rank + self.lora_scaling = lora_alpha / lora_rank if lora_rank > 0 else 0.0 self.max_cache_depth = max_cache_depth + self._full_weight_grad = full_weight_grad + self.gate_lora_a: Optional[torch.Tensor] = None self.gate_lora_b: Optional[torch.Tensor] = None self.up_lora_a: Optional[torch.Tensor] = None @@ -158,22 +161,75 @@ def __init__( self.down_lora_a: Optional[torch.Tensor] = None self.down_lora_b: Optional[torch.Tensor] = None + # Base weight parameters for full fine-tuning + self.gate_proj_buf: Optional[torch.Tensor] = None + self.up_proj_buf: Optional[torch.Tensor] = None + self.down_proj_buf: Optional[torch.Tensor] = None + self.grad_gate_proj_buf: Optional[torch.Tensor] = None + self.grad_up_proj_buf: Optional[torch.Tensor] = None + self.grad_down_proj_buf: Optional[torch.Tensor] = None + self._weights_loaded: bool = False self._lora_initialized: bool = False self._cache_depth: int = 0 self._is_skip_lora: bool = False + self._base_weights_dirty: bool = False self.moe = None @staticmethod - def _validate_sft_config(lora_rank: int, lora_alpha: float, max_cache_depth: int) -> None: - if lora_rank <= 0: - raise ValueError(f"lora_rank must be positive, got {lora_rank}") - if lora_alpha <= 0: + def _validate_sft_config( + lora_rank: int, lora_alpha: float, max_cache_depth: int, full_weight_grad: bool = False + ) -> None: + if not full_weight_grad and lora_rank <= 0: + raise ValueError( + f"lora_rank must be positive in LoRA mode, got {lora_rank}. " + "Set kt_train_mode='full' for full fine-tuning." + ) + if lora_rank > 0 and lora_alpha <= 0: raise ValueError(f"lora_alpha must be positive, got {lora_alpha}") if max_cache_depth <= 0: raise ValueError(f"max_cache_depth must be positive, got {max_cache_depth}") + # ========== Full weight grad methods ========== + + def init_full_weight_grad_buffers( + self, gate_proj: torch.Tensor, up_proj: torch.Tensor, down_proj: torch.Tensor + ) -> None: + """Initialize base weight nn.Parameter buffers and gradient buffers for full fine-tuning. + + Args: + gate_proj: [num_experts, intermediate_size, hidden_size] BF16 CPU tensor + up_proj: [num_experts, intermediate_size, hidden_size] BF16 CPU tensor + down_proj: [num_experts, hidden_size, intermediate_size] BF16 CPU tensor + """ + import torch.nn as nn + + dtype = torch.bfloat16 + E = self.num_experts + I = self.moe_intermediate_size + H = self.hidden_size + + # Create nn.Parameter buffers (optimizer-visible) + self.gate_proj_buf = nn.Parameter(gate_proj.to(dtype=dtype, device="cpu").contiguous(), requires_grad=True) + self.up_proj_buf = nn.Parameter(up_proj.to(dtype=dtype, device="cpu").contiguous(), requires_grad=True) + self.down_proj_buf = nn.Parameter(down_proj.to(dtype=dtype, device="cpu").contiguous(), requires_grad=True) + + # Create gradient buffers (C++ writes directly to these) + self.grad_gate_proj_buf = torch.zeros(E, I, H, dtype=dtype, device="cpu") + self.grad_up_proj_buf = torch.zeros(E, I, H, dtype=dtype, device="cpu") + self.grad_down_proj_buf = torch.zeros(E, H, I, dtype=dtype, device="cpu") + + # Note: .grad is NOT pre-assigned here. PyTorch autograd will set it + # when KTMoEFunction.backward() returns the gradient buffers. + # The C++ kernel writes directly to grad_gate_proj_buf etc., + # and backward returns them so PyTorch can propagate correctly. + + @abstractmethod + def update_base_weights(self) -> None: + """Sync updated base weight parameters back to C++ kernel after optimizer step.""" + ... + # ========== Abstract methods for subclasses ========== @abstractmethod @@ -187,24 +243,27 @@ def _make_backward_task(self, buffer: KExpertsSFTBuffer): ... @abstractmethod - def load_weights(self, physical_to_logical_map_cpu: torch.Tensor) -> None: - ... + def load_weights(self, physical_to_logical_map_cpu: torch.Tensor) -> None: ... @abstractmethod def init_lora_weights( self, - gate_lora_a: torch.Tensor, gate_lora_b: torch.Tensor, - up_lora_a: torch.Tensor, up_lora_b: torch.Tensor, - down_lora_a: torch.Tensor, down_lora_b: torch.Tensor, - grad_gate_lora_a: torch.Tensor, grad_gate_lora_b: torch.Tensor, - grad_up_lora_a: torch.Tensor, grad_up_lora_b: torch.Tensor, - grad_down_lora_a: torch.Tensor, grad_down_lora_b: torch.Tensor, - ) -> None: - ... + gate_lora_a: torch.Tensor, + gate_lora_b: torch.Tensor, + up_lora_a: torch.Tensor, + up_lora_b: torch.Tensor, + down_lora_a: torch.Tensor, + down_lora_b: torch.Tensor, + grad_gate_lora_a: torch.Tensor, + grad_gate_lora_b: torch.Tensor, + grad_up_lora_a: torch.Tensor, + grad_up_lora_b: torch.Tensor, + grad_down_lora_a: torch.Tensor, + grad_down_lora_b: torch.Tensor, + ) -> None: ... @abstractmethod - def update_lora_weights(self) -> None: - ... + def update_lora_weights(self) -> None: ... # ========== Buffer helpers ========== @@ -222,7 +281,7 @@ def _get_buffer(self, qlen: int) -> KExpertsSFTBuffer: def _validate_forward_inputs(self, hidden_states: torch.Tensor, expert_ids: torch.Tensor, weights: torch.Tensor): if not self._weights_loaded: raise RuntimeError("Weights not loaded. Call load_weights() or load_weights_from_tensors() first.") - if not self._lora_initialized and not self._is_skip_lora: + if not self._lora_initialized and not self._is_skip_lora and not self._full_weight_grad: raise RuntimeError("LoRA weights not initialized. Call init_lora_weights() first.") qlen = hidden_states.shape[0] if qlen > self.chunked_prefill_size: @@ -235,12 +294,16 @@ def _validate_forward_inputs(self, hidden_states: torch.Tensor, expert_ids: torc f"expert_ids shape {tuple(expert_ids.shape)} must be ({qlen}, {self.num_experts_per_tok})." ) if weights.shape[0] != qlen or weights.shape[1] != self.num_experts_per_tok: - raise ValueError( - f"weights shape {tuple(weights.shape)} must be ({qlen}, {self.num_experts_per_tok})." - ) + raise ValueError(f"weights shape {tuple(weights.shape)} must be ({qlen}, {self.num_experts_per_tok}).") - def _copy_inputs_to_buffer(self, buffer: KExpertsSFTBuffer, hidden_states: torch.Tensor, - expert_ids: torch.Tensor, weights: torch.Tensor, qlen: int) -> torch.device: + def _copy_inputs_to_buffer( + self, + buffer: KExpertsSFTBuffer, + hidden_states: torch.Tensor, + expert_ids: torch.Tensor, + weights: torch.Tensor, + qlen: int, + ) -> torch.device: """Copy inputs to CPU buffer, return input device.""" input_device = hidden_states.device buffer.input_cpu[:qlen].copy_(hidden_states.to(torch.bfloat16), non_blocking=True) @@ -392,11 +455,11 @@ def sync_backward(self) -> Tuple[torch.Tensor, torch.Tensor]: def submit_backward_repack(self): if not self._weights_loaded or self.moe is None: return - if hasattr(self.moe, 'submit_backward_repack'): + if hasattr(self.moe, "submit_backward_repack"): self.moe.submit_backward_repack() def wait_backward_repack(self): if not self._weights_loaded or self.moe is None: return - if hasattr(self.moe, 'wait_backward_repack'): + if hasattr(self.moe, "wait_backward_repack"): self.moe.wait_backward_repack() diff --git a/kt-kernel/python/sft/config.py b/kt-kernel/python/sft/config.py index 35af4d3d6..0172ad5dd 100644 --- a/kt-kernel/python/sft/config.py +++ b/kt-kernel/python/sft/config.py @@ -75,6 +75,10 @@ class KTConfig: kt_lora_rank: int | None = None kt_lora_alpha: float | None = None + # Training mode + kt_train_mode: str | None = None # "lora" | "full" | "hybrid" + kt_full_weight_grad: bool | None = None # auto-set True when train_mode in (full, hybrid) + # LoRA Experts (GPU-side extra experts) kt_use_lora_experts: bool | None = None kt_lora_expert_num: int | None = None @@ -132,6 +136,10 @@ def __post_init__(self): self.kt_lora_alpha = _env_float("ACCELERATE_KT_LORA_ALPHA", None) if self.kt_lora_alpha is None and self.kt_lora_rank is not None: self.kt_lora_alpha = float(self.kt_lora_rank * 2) + if self.kt_train_mode is None: + self.kt_train_mode = os.environ.get("ACCELERATE_KT_TRAIN_MODE", "lora") + if self.kt_full_weight_grad is None: + self.kt_full_weight_grad = self.kt_train_mode in ("full", "hybrid") if self.kt_model_max_length is None: self.kt_model_max_length = _env_int("ACCELERATE_KT_MODEL_MAX_LENGTH", None) if self.kt_skip_expert_loading is None: diff --git a/kt-kernel/python/sft/layer.py b/kt-kernel/python/sft/layer.py index e4cb2b657..c28e4a843 100644 --- a/kt-kernel/python/sft/layer.py +++ b/kt-kernel/python/sft/layer.py @@ -13,6 +13,7 @@ import logging import os +from contextlib import nullcontext from typing import Any import torch @@ -62,7 +63,29 @@ def __init__( # 1. gate/router FIRST - keep original attribute name for PEFT compatibility router_attr = moe_config.router_attr # "gate" for Qwen3/DeepSeek - setattr(self, router_attr, getattr(original_moe, router_attr, None)) + original_router = getattr(original_moe, router_attr, None) + self._original_router = None # Set when router is not nn.Linear (e.g. TopKRouter) + + if original_router is not None and isinstance(original_router, nn.Linear): + # transformers <=4.x / some models: gate is nn.Linear - register directly. + setattr(self, router_attr, original_router) + elif original_router is not None and hasattr(original_router, "weight") and isinstance( + getattr(original_router, "weight"), nn.Parameter + ): + # transformers v5+: gate is a TopKRouter with nn.Parameter weight. + # Wrap it in nn.Linear so PEFT can discover and inject LoRA. + # The nn.Linear shares the same weight tensor - LoRA applied to it + # is equivalent to LoRA on the original gate. + router_weight = original_router.weight + router_linear = nn.Linear( + router_weight.shape[1], router_weight.shape[0], bias=False, + ) + router_linear.weight = router_weight # share the same parameter + setattr(self, router_attr, router_linear) + # Keep the original router for forward (top-k selection logic) + self._original_router = original_router + else: + setattr(self, router_attr, original_router) self._router_attr = router_attr # 2. experts SECOND (this is what PEFT targets for LoRA) @@ -84,10 +107,13 @@ def __init__( self._peft_lora_modules: dict[int, dict[str, tuple[nn.Module, nn.Module]]] | None = None self._lora_pointers_dirty = False + # Full weight grad mode (set during wrapping or kt_adapt_peft_lora) + self._full_weight_grad = getattr(wrapper, "_full_weight_grad", False) if wrapper is not None else False + def _apply(self, fn, recurse=True): # Protect experts from device transfer (PEFT LoRA should stay on CPU for KT) saved_experts = None - experts_attr = getattr(self, '_experts_attr', None) + experts_attr = getattr(self, "_experts_attr", None) if experts_attr is not None and getattr(self, experts_attr, None) is not None: saved_experts = getattr(self, experts_attr) @@ -103,6 +129,7 @@ def _apply(self, fn, recurse=True): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: import torch.distributed as dist + dist_on = dist.is_initialized() and dist.get_world_size() > 1 rank = dist.get_rank() if dist.is_initialized() else 0 @@ -112,11 +139,12 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: topk_ids, topk_weights = self._compute_routing(hidden_states) train_lora = self._peft_lora_modules is not None and len(self._peft_lora_modules) > 0 + full_weight_grad = self._full_weight_grad save_for_backward = ( self.training and torch.is_grad_enabled() - and (hidden_states.requires_grad or topk_weights.requires_grad or train_lora) + and (hidden_states.requires_grad or topk_weights.requires_grad or train_lora or full_weight_grad) ) use_autograd_path = save_for_backward save_for_backward_submit = use_autograd_path @@ -127,6 +155,11 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: self.update_lora_pointers() self._lora_pointers_dirty = False + # In full_weight_grad mode, sync base weights after optimizer step + if full_weight_grad and getattr(self.wrapper, "_base_weights_dirty", False): + self.wrapper.update_base_weights() + self.wrapper._base_weights_dirty = False + gpu_output, all_qlens = self._submit_and_compute_gpu( hidden_states, topk_ids, @@ -141,11 +174,15 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: if train_lora and self._peft_lora_modules: for expert_loras in self._peft_lora_modules.values(): for lora_A, lora_B in expert_loras.values(): - if hasattr(lora_A, 'weight') and lora_A.weight.requires_grad: + if hasattr(lora_A, "weight") and lora_A.weight.requires_grad: lora_ref = lora_A.weight break if lora_ref.numel() > 0: break + elif full_weight_grad and self.wrapper is not None: + # In full mode, use base weight param as autograd sentinel + if self.wrapper.gate_proj_buf is not None: + lora_ref = self.wrapper.gate_proj_buf moe_output = KTMoEFunction.apply( hidden_states, @@ -159,6 +196,10 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: save_for_backward, train_lora, all_qlens, + # Base weight params for full mode gradient flow + self.wrapper.gate_proj_buf if full_weight_grad and self.wrapper is not None else None, + self.wrapper.up_proj_buf if full_weight_grad and self.wrapper is not None else None, + self.wrapper.down_proj_buf if full_weight_grad and self.wrapper is not None else None, ) else: moe_output = self._sync_forward_output_no_autograd( @@ -194,13 +235,9 @@ def _sync_forward_output_no_autograd( else: all_qlens_list = [int(q) for q in all_qlens] if len(all_qlens_list) != world_size: - raise RuntimeError( - f"all_qlens length mismatch: got {len(all_qlens_list)}, expected {world_size}" - ) + raise RuntimeError(f"all_qlens length mismatch: got {len(all_qlens_list)}, expected {world_size}") if int(all_qlens_list[rank]) != qlen: - raise RuntimeError( - f"Rank {rank} qlen mismatch: local={qlen}, all_qlens[{rank}]={all_qlens_list[rank]}" - ) + raise RuntimeError(f"Rank {rank} qlen mismatch: local={qlen}, all_qlens[{rank}]={all_qlens_list[rank]}") total_qlen = sum(all_qlens_list) if rank == 0: @@ -234,12 +271,11 @@ def _sync_forward_output_no_autograd( return torch.empty(batch_size, seq_len, self.hidden_size, device=original_device, dtype=original_dtype) def _compute_routing(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - # Run routing under no_grad to avoid creating autograd nodes whose - # SavedVariables become orphan holders inside gradient checkpoint. - # The gate is frozen during LoRA fine-tuning and the main gradient - # flows through KTMoEFunction.backward()'s grad_input, so the - # routing gradient contribution to hidden_states can be safely dropped. - with torch.no_grad(): + # In full_weight_grad mode, Router gradients should flow (no torch.no_grad). + # In LoRA mode, Router is frozen — wrap in no_grad to avoid orphan autograd nodes. + no_grad_ctx = torch.no_grad() if not self._full_weight_grad else nullcontext() + + with no_grad_ctx: router = getattr(self, self._router_attr) if self.router_type == "deepseek_gate": # DeepSeek V3's MoEGate has `assert not self.training` in its noaux_tc @@ -259,6 +295,23 @@ def _compute_routing(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, t topk_weights = topk_weights.to(torch.bfloat16) return topk_ids, topk_weights + # When _original_router is set, self.gate is an nn.Linear wrapper + # around the TopKRouter's weight. Use it (with PEFT LoRA if + # applied) for the linear projection, then replicate top-k logic. + if self._original_router is not None: + orig_router = self._original_router + router_logits = router(hidden_states.view(-1, self.hidden_size)) + router_probs = F.softmax(router_logits, dtype=torch.float, dim=-1) + top_k = getattr(orig_router, "top_k", self.moe_config.num_experts_per_tok) + norm_topk_prob = getattr(orig_router, "norm_topk_prob", True) + topk_weights, topk_ids = torch.topk(router_probs, top_k, dim=-1) + if norm_topk_prob: + topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) + topk_weights = topk_weights.to(router_logits.dtype) + if topk_weights.is_floating_point(): + topk_weights = topk_weights.to(torch.bfloat16) + return topk_ids, topk_weights + router_output = router(hidden_states.view(-1, self.hidden_size)) # transformers v5 TopKRouter returns (router_logits, router_scores, router_indices) # directly — scores/indices are already topk-normalized. @@ -299,9 +352,7 @@ def _submit_and_compute_gpu( if dist_on: all_qlens = _all_gather_qlens(qlen, original_device, world_size) if int(all_qlens[rank]) != qlen: - raise RuntimeError( - f"Rank {rank} qlen mismatch: local={qlen}, all_qlens[{rank}]={all_qlens[rank]}" - ) + raise RuntimeError(f"Rank {rank} qlen mismatch: local={qlen}, all_qlens[{rank}]={all_qlens[rank]}") total_qlen = sum(all_qlens) hs_flat = hidden_states.view(qlen, self.hidden_size).contiguous() diff --git a/kt-kernel/python/sft/lora.py b/kt-kernel/python/sft/lora.py index 5a594ec8e..0d928b1ce 100644 --- a/kt-kernel/python/sft/lora.py +++ b/kt-kernel/python/sft/lora.py @@ -114,9 +114,9 @@ def get_kt_lora_params(model: nn.Module) -> list[nn.Parameter]: if peft_lora_modules is not None: for expert_loras in peft_lora_modules.values(): for lora_A, lora_B in expert_loras.values(): - if hasattr(lora_A, 'weight') and lora_A.weight.requires_grad: + if hasattr(lora_A, "weight") and lora_A.weight.requires_grad: params.append(lora_A.weight) - if hasattr(lora_B, 'weight') and lora_B.weight.requires_grad: + if hasattr(lora_B, "weight") and lora_B.weight.requires_grad: params.append(lora_B.weight) # Fused expert LoRA parameters (KT-managed, not PEFT) fused_params = getattr(wrapper, "_fused_expert_lora_params", None) @@ -129,6 +129,48 @@ def get_kt_lora_params(model: nn.Module) -> list[nn.Parameter]: return params +def get_kt_trainable_params(model: nn.Module) -> list[nn.Parameter]: + """Get all trainable parameters from KT model based on training mode. + + In full mode: returns base weight nn.Parameter buffers from wrappers. + In LoRA mode: returns LoRA parameters (same as get_kt_lora_params). + """ + wrappers = _find_kt_wrappers(model) + if not wrappers: + return [] + + # Check if any wrapper is in full_weight_grad mode + has_full_weight_grad = any(getattr(w, "_full_weight_grad", False) for w in wrappers) + + if has_full_weight_grad: + # Full mode: return base weight parameters + params: list[nn.Parameter] = [] + for wrapper in wrappers: + if getattr(wrapper, "_full_weight_grad", False) and wrapper.wrapper is not None: + if wrapper.wrapper.gate_proj_buf is not None: + params.append(wrapper.wrapper.gate_proj_buf) + if wrapper.wrapper.up_proj_buf is not None: + params.append(wrapper.wrapper.up_proj_buf) + if wrapper.wrapper.down_proj_buf is not None: + params.append(wrapper.wrapper.down_proj_buf) + # Also include LoRA params if in hybrid mode + peft_lora_modules = getattr(wrapper, "_peft_lora_modules", None) + if peft_lora_modules is not None: + for expert_loras in peft_lora_modules.values(): + for lora_A, lora_B in expert_loras.values(): + if hasattr(lora_A, "weight") and lora_A.weight.requires_grad: + params.append(lora_A.weight) + if hasattr(lora_B, "weight") and lora_B.weight.requires_grad: + params.append(lora_B.weight) + fused_params = getattr(wrapper, "_fused_expert_lora_params", None) + if fused_params is not None: + params.extend(fused_params) + return params + else: + # LoRA mode: return LoRA parameters + return get_kt_lora_params(model) + + # ============================================================================= # PEFT LoRA Adaptation # ============================================================================= @@ -175,8 +217,24 @@ def kt_adapt_peft_lora(model: nn.Module) -> None: # wrap as nn.Parameter for optimizer, and pre-assign .grad for C++ backward. if getattr(wrapper, "_fused_experts", False): lora_rank = getattr(wrapper, "_lora_rank", 1) + + # In full mode (lora_rank=0), skip LoRA buffer creation entirely. + # C++ kernel will not compute LoRA contributions when lora_rank=0. + if lora_rank == 0: + wrapper._fused_expert_lora_params = [] + wrapper._peft_lora_modules = None + logger.info( + f"[kt_adapt_peft_lora] Layer {layer_idx}: fused expert, " + f"full mode (lora_rank=0, no LoRA buffers)" + ) + adapted_count += 1 + continue + lora_buffers, lora_grad_buffers, lora_params = _create_fused_expert_lora_buffers( - wrapper, moe_config, lora_rank, torch.bfloat16, + wrapper, + moe_config, + lora_rank, + torch.bfloat16, ) if is_rank_0 and wrapper.wrapper is not None: @@ -197,6 +255,18 @@ def kt_adapt_peft_lora(model: nn.Module) -> None: if len(experts) == 0: continue + # In full mode (lora_rank=0), PEFT does not inject LoRA on experts. + # Skip LoRA detection and initialization entirely. + if getattr(wrapper, "_lora_rank", 1) == 0: + wrapper._peft_lora_modules = None + wrapper._fused_expert_lora_params = [] + logger.info( + f"[kt_adapt_peft_lora] Layer {layer_idx}: non-fused expert, " + f"full mode (lora_rank=0, no LoRA)" + ) + adapted_count += 1 + continue + # Collect references to PEFT LoRA modules for each expert # Structure: {expert_idx: {proj_name: (lora_A_module, lora_B_module)}} peft_lora_modules = {} @@ -228,7 +298,16 @@ def kt_adapt_peft_lora(model: nn.Module) -> None: # Store PEFT LoRA references on wrapper wrapper._peft_lora_modules = peft_lora_modules + # In full_weight_grad mode, PEFT LoRA is not injected by LlamaFactory, + # so no PEFT LoRA found is expected — skip the error. if not peft_lora_modules: + if getattr(wrapper, "_full_weight_grad", False): + logger.info( + f"[kt_adapt_peft_lora] Layer {layer_idx}: No PEFT LoRA found " + f"(full_weight_grad mode — expected, skipping)" + ) + adapted_count += 1 + continue raise RuntimeError( f"[kt_adapt_peft_lora] Layer {layer_idx}: No PEFT LoRA found on any expert. " f"Check that PEFT lora_target includes expert modules." @@ -510,9 +589,16 @@ def _replace_peft_weights_with_views( "[_replace_peft_weights_with_views] first param: " "id %s->%s (same=%s) data_ptr %s->%s buf_ptr=%s (match=%s) " "has_grad=%s requires_grad=%s shape=%s", - _old_id_a, _new_id_a, _old_id_a == _new_id_a, - _old_ptr_a, _new_ptr_a, _buf_ptr_a, _new_ptr_a == _buf_ptr_a, - _has_grad, lora_A.weight.requires_grad, tuple(lora_A.weight.shape), + _old_id_a, + _new_id_a, + _old_id_a == _new_id_a, + _old_ptr_a, + _new_ptr_a, + _buf_ptr_a, + _new_ptr_a == _buf_ptr_a, + _has_grad, + lora_A.weight.requires_grad, + tuple(lora_A.weight.shape), ) _first_logged = True _replaced += 1 @@ -526,12 +612,15 @@ def _replace_peft_weights_with_views( def update_kt_lora_pointers(model: nn.Module): - """Mark KT wrapper LoRA pointers as dirty after optimizer.step().""" + """Mark KT wrapper LoRA pointers and base weight pointers as dirty after optimizer.step().""" wrappers = _find_kt_wrappers(model) if wrappers: for wrapper in wrappers: wrapper._lora_pointers_dirty = True + # In full mode, base weights also need re-sync after optimizer step + if getattr(wrapper, "_full_weight_grad", False) and wrapper.wrapper is not None: + wrapper.wrapper._base_weights_dirty = True # ============================================================================= @@ -541,12 +630,10 @@ def update_kt_lora_pointers(model: nn.Module): def sync_kt_lora_gradients(model: nn.Module) -> None: """ - Synchronize KT-managed LoRA gradients across ranks. + Synchronize KT-managed gradients across ranks. - KT computes expert LoRA gradients only on rank 0 (gather/scatter path). This function broadcasts the - per-layer contiguous grad buffers from rank 0 to all ranks so that: - - gradient clipping sees identical grads on every rank - - optimizer.step() applies identical updates + In LoRA mode: synchronizes LoRA gradients only. + In full mode: synchronizes both base weight and LoRA gradients. """ import torch.distributed as dist @@ -557,17 +644,34 @@ def sync_kt_lora_gradients(model: nn.Module) -> None: if world_size <= 1: return + # Sync base weight gradients in full mode + wrappers = _find_kt_wrappers(model) + if wrappers: + for wrapper in wrappers: + if not getattr(wrapper, "_full_weight_grad", False): + continue + if wrapper.wrapper is None: + continue + for grad_buf in ( + wrapper.wrapper.grad_gate_proj_buf, + wrapper.wrapper.grad_up_proj_buf, + wrapper.wrapper.grad_down_proj_buf, + ): + if grad_buf is not None: + grad_gpu = grad_buf.cuda() + dist.all_reduce(grad_gpu, op=dist.ReduceOp.SUM) + grad_gpu.div_(world_size) + grad_buf.copy_(grad_gpu.cpu()) + + # Sync LoRA gradients params = get_kt_lora_params(model) if not params: return for param in params: if param.grad is not None: - # Move grad to the same device as the parameter for all-reduce - # Then move back to CPU original_device = param.grad.device if original_device.type == "cpu": - # All-reduce on CPU might be slow; consider using a GPU buffer grad_gpu = param.grad.cuda() dist.all_reduce(grad_gpu, op=dist.ReduceOp.SUM) grad_gpu.div_(world_size) diff --git a/kt-kernel/python/sft/weights.py b/kt-kernel/python/sft/weights.py index c15e22638..e535e1c48 100644 --- a/kt-kernel/python/sft/weights.py +++ b/kt-kernel/python/sft/weights.py @@ -108,9 +108,16 @@ def get_weight_tensor(mod): return gate_proj, up_proj, down_proj -def _clear_original_expert_weights(moe_module: nn.Module, moe_config: MOEArchConfig) -> None: +def _clear_original_expert_weights( + moe_module: nn.Module, moe_config: MOEArchConfig, full_weight_grad: bool = False +) -> None: """ Clear original expert weights to free memory after KT weights are loaded. + + In full_weight_grad mode, gate_proj_buf/up_proj_buf/down_proj_buf serve as + the authoritative copies for the optimizer. The original expert weights in + the model tree are redundant and cause double-counting in count_parameters(). + Clear them just like in LoRA mode. """ from .arch import detect_fused_experts @@ -127,10 +134,14 @@ def _clear_original_expert_weights(moe_module: nn.Module, moe_config: MOEArchCon original_dtype = param.dtype tiny_storage = torch.UntypedStorage(1, device="cpu") fake_tensor = torch.tensor([], dtype=original_dtype, device="cpu").set_( - tiny_storage, storage_offset=0, size=param.shape, + tiny_storage, + storage_offset=0, + size=param.shape, stride=[0] * len(param.shape), ) - experts._parameters[name] = nn.Parameter(fake_tensor, requires_grad=False) + placeholder = nn.Parameter(fake_tensor, requires_grad=False) + placeholder._kt_zero_storage = True # Mark for _setup_full_tuning / count_parameters to skip + experts._parameters[name] = placeholder return def _iter_weight_params(): @@ -141,7 +152,9 @@ def _iter_weight_params(): continue parametrizations = getattr(proj, "parametrizations", None) - parametrized_weight = getattr(parametrizations, "weight", None) if parametrizations is not None else None + parametrized_weight = ( + getattr(parametrizations, "weight", None) if parametrizations is not None else None + ) if parametrized_weight is not None: original = getattr(parametrized_weight, "original", None) if isinstance(original, torch.nn.Parameter): @@ -178,10 +191,13 @@ def _iter_weight_params(): # only used for shape/dtype discovery by PEFT. tiny_storage = torch.UntypedStorage(1, device="cpu") fake_tensor = torch.tensor([], dtype=original_dtype, device="cpu").set_( - tiny_storage, storage_offset=0, size=weight_param.shape, + tiny_storage, + storage_offset=0, + size=weight_param.shape, stride=[0] * len(weight_param.shape), ) new_param = nn.Parameter(fake_tensor, requires_grad=False) + new_param._kt_zero_storage = True # Mark for _setup_full_tuning / count_parameters to skip replaced_count += 1 # Avoid `KeyError: attribute 'weight' already exists` for parametrized modules @@ -201,9 +217,7 @@ def _iter_weight_params(): try: setattr(container, param_name, new_param) except Exception as exc: - logger.warning( - f"Failed to clear expert weight {type(proj).__name__}.{param_name}: {exc}" - ) + logger.warning(f"Failed to clear expert weight {type(proj).__name__}.{param_name}: {exc}") logger.info(f"Replaced {replaced_count} expert weight params") @@ -256,7 +270,9 @@ def _load_kt_weight_index(kt_weight_path: str) -> dict[str, str]: return index -def _dequant_fp8_experts(weights: list[torch.Tensor], scales: list[torch.Tensor | None], block_size: tuple[int, int]) -> torch.Tensor: +def _dequant_fp8_experts( + weights: list[torch.Tensor], scales: list[torch.Tensor | None], block_size: tuple[int, int] +) -> torch.Tensor: """Dequantize a list of FP8 expert weights and stack them (batched, vectorized). Args: @@ -468,9 +484,7 @@ def load_experts_from_kt_weight_path( f"Expected keys like 'blk.{layer_idx}.ffn_gate_exps.0.numa.0.weight'" ) - logger.info( - f"Loading INT8 weights for layer {layer_idx}: {num_experts} experts, {numa_count} NUMA partitions" - ) + logger.info(f"Loading INT8 weights for layer {layer_idx}: {num_experts} experts, {numa_count} NUMA partitions") gate_weights_list = [] gate_scales_list = [] diff --git a/kt-kernel/python/sft/wrapper.py b/kt-kernel/python/sft/wrapper.py index a53ea88af..bb5284353 100644 --- a/kt-kernel/python/sft/wrapper.py +++ b/kt-kernel/python/sft/wrapper.py @@ -95,9 +95,7 @@ def build_kt_device_map(config, kt_plugin, device: str = "cuda:0") -> dict[str, else: device_map[expert_key] = "cpu" - logger.info( - f"Built KT device_map: {num_gpu_experts} GPU experts, {num_experts - num_gpu_experts} CPU experts" - ) + logger.info(f"Built KT device_map: {num_gpu_experts} GPU experts, {num_experts - num_gpu_experts} CPU experts") return device_map @@ -163,8 +161,24 @@ def wrap_moe_layers_with_kt_wrapper(model: nn.Module, kt_plugin: Any) -> list[KT cfg = _get_kt_config(kt_plugin) # Read lora_rank/lora_alpha for C++ wrapper initialization (buffer allocation only) - lora_rank = getattr(cfg, "kt_lora_rank", 1) or 1 - lora_alpha = getattr(cfg, "kt_lora_alpha", 1.0) or 1.0 + # Use explicit None checks: lora_rank=0 is a valid value (full mode, no LoRA), + # but `or` pattern would treat 0 as falsy and replace it with 1. + _raw_rank = getattr(cfg, "kt_lora_rank", None) + lora_rank = _raw_rank if _raw_rank is not None else 1 + _raw_alpha = getattr(cfg, "kt_lora_alpha", None) + lora_alpha = _raw_alpha if _raw_alpha is not None else 1.0 + + # Read full_weight_grad mode + _raw_fwg = getattr(cfg, "kt_full_weight_grad", None) + full_weight_grad = _raw_fwg if _raw_fwg is not None else False + + # In full mode, lora_rank should be 0 (no LoRA, only base weight grad) + # If user explicitly set lora_rank > 0 in full mode (hybrid), keep it. + # Otherwise, auto-set lora_rank=0. + if full_weight_grad and lora_rank > 0: + _has_explicit_lora_rank = getattr(cfg, "kt_lora_rank", None) is not None + if not _has_explicit_lora_rank: + lora_rank = 0 # Read LoRA Experts configuration _raw_le = getattr(cfg, "kt_use_lora_experts", None) @@ -177,6 +191,8 @@ def wrap_moe_layers_with_kt_wrapper(model: nn.Module, kt_plugin: Any) -> list[KT f"LoRA Experts config: use_lora_experts={use_lora_experts}, " f"num={lora_expert_num}, intermediate_size={lora_expert_intermediate_size}" ) + if full_weight_grad: + logger.info(f"Full weight gradient mode enabled (lora_rank={lora_rank})") wrappers: list[KTMoELayerWrapper] = [] moe_layer_count = 0 @@ -225,7 +241,9 @@ def wrap_moe_layers_with_kt_wrapper(model: nn.Module, kt_plugin: Any) -> list[KT cfg.kt_sharded_metadata = sharded_metadata logger.info(f"Resolved {len(checkpoint_files)} checkpoint files from kt_expert_checkpoint_path") else: - logger.warning(f"Failed to resolve checkpoint files from kt_expert_checkpoint_path={kt_expert_checkpoint_path!r}") + logger.warning( + f"Failed to resolve checkpoint files from kt_expert_checkpoint_path={kt_expert_checkpoint_path!r}" + ) use_checkpoint_files = bool(checkpoint_files) and not use_kt_weight_path @@ -260,6 +278,7 @@ def wrap_moe_layers_with_kt_wrapper(model: nn.Module, kt_plugin: Any) -> list[KT ) import torch.distributed as _dist + _rank = _dist.get_rank() if _dist.is_initialized() else 0 model_container, layers = _get_model_container_and_layers(model, purpose="wrapping") @@ -329,6 +348,7 @@ def wrap_moe_layers_with_kt_wrapper(model: nn.Module, kt_plugin: Any) -> list[KT lora_rank=lora_rank, lora_alpha=lora_alpha, max_cache_depth=getattr(cfg, "kt_max_cache_depth", 2), + full_weight_grad=full_weight_grad, ) # Set share_backward_bb and share_cache_pool BEFORE load_weights (config is built during load) @@ -352,9 +372,18 @@ def wrap_moe_layers_with_kt_wrapper(model: nn.Module, kt_plugin: Any) -> list[KT physical_to_logical_map_cpu=physical_to_logical_map, ) - wrapper.gate_proj = None - wrapper.up_proj = None - wrapper.down_proj = None + # In full_weight_grad mode, keep weight references for nn.Parameter initialization + # and initialize the base weight buffers + if full_weight_grad: + wrapper.init_full_weight_grad_buffers( + gate_proj=wrapper.gate_proj if wrapper.gate_proj is not None else gate_proj, + up_proj=wrapper.up_proj if wrapper.up_proj is not None else up_proj, + down_proj=wrapper.down_proj if wrapper.down_proj is not None else down_proj, + ) + else: + wrapper.gate_proj = None + wrapper.up_proj = None + wrapper.down_proj = None # Create LoRA Experts if enabled lora_experts = None @@ -381,16 +410,18 @@ def wrap_moe_layers_with_kt_wrapper(model: nn.Module, kt_plugin: Any) -> list[KT setattr(layer, moe_config.moe_layer_attr, layer_wrapper) # Base weights have been copied into the C++ kernel's internal BufferB format. - # Do not hold a Python-side reference --- it wastes ~1 GB/layer. + # In full_weight_grad mode, the authoritative copies are gate_proj_buf etc. + # Always release local references to save ~1 GB/layer. del gate_proj, up_proj, down_proj wrappers.append(layer_wrapper) moe_layer_count += 1 - # Replace original expert weights with meta placeholders. + # Replace original expert weights with zero-storage placeholders. # Experts remain in the model tree (via wrapper.experts) so PEFT can discover them. # Rank 0 already copied weights to C++ kernel via load_weights_from_tensors. - _clear_original_expert_weights(moe_module, moe_config) + # gate_proj_buf serves as the authoritative copy in full_weight_grad mode. + _clear_original_expert_weights(moe_module, moe_config, full_weight_grad=full_weight_grad) logger.info(f"Wrapped {moe_layer_count} MoE layers with KTMoEWrapper") @@ -420,6 +451,17 @@ class should import it from the appropriate dataclasses module. from .config import KTConfig from accelerate.utils.dataclasses import KTransformersPlugin + # Map LlamaFactory finetuning_type to kt_train_mode + finetuning_type = getattr(finetuning_args, "finetuning_type", None) if finetuning_args else None + kt_train_mode_map = { + "full": "full", + "freeze": "hybrid", + "lora": "lora", + "galore": "full", + "badam": "full", + } + kt_train_mode = kt_train_mode_map.get(finetuning_type, None) if finetuning_type else None + kt_config = KTConfig( kt_backend=getattr(model_args, "kt_backend", None), kt_num_threads=getattr(model_args, "kt_num_threads", None), @@ -435,6 +477,7 @@ class should import it from the appropriate dataclasses module. kt_lora_rank=getattr(finetuning_args, "lora_rank", None) if finetuning_args else None, kt_lora_alpha=getattr(finetuning_args, "lora_alpha", None) if finetuning_args else None, kt_model_max_length=getattr(model_args, "model_max_length", None), + kt_train_mode=kt_train_mode, ) return KTransformersPlugin(enabled=True, kt_config=kt_config) @@ -509,7 +552,13 @@ def load_kt_model( **kwargs, ) -> nn.Module: """Load model with KTMoEWrapper backend.""" - from .arch import get_moe_arch_config, move_non_experts_to_gpu, get_expert_device, KTAMXNotAvailableError, KTAMXConfigError + from .arch import ( + get_moe_arch_config, + move_non_experts_to_gpu, + get_expert_device, + KTAMXNotAvailableError, + KTAMXConfigError, + ) if kt_plugin is None: if model_args is None: @@ -536,8 +585,11 @@ def load_kt_model( from transformers.integrations.kt import set_kt_config, unset_kt_config loading_kwargs = get_kt_loading_kwargs( - config, kt_plugin, torch_dtype=torch_dtype, - trust_remote_code=trust_remote_code, token=token, + config, + kt_plugin, + torch_dtype=torch_dtype, + trust_remote_code=trust_remote_code, + token=token, ) if model_args is not None: for key in ("cache_dir", "revision"): @@ -551,8 +603,10 @@ def load_kt_model( if getattr(cfg, "kt_skip_expert_loading", None) is None: checkpoint_files, sharded_metadata = _resolve_checkpoint_files( model_name_or_path=model_name_or_path, - cache_dir=cache_dir, revision=revision, - token=token, trust_remote_code=trust_remote_code, + cache_dir=cache_dir, + revision=revision, + token=token, + trust_remote_code=trust_remote_code, ) if checkpoint_files and all(f.endswith(".safetensors") for f in checkpoint_files): if getattr(cfg, "kt_weight_path", None) is None: