-
Notifications
You must be signed in to change notification settings - Fork 1.3k
[WIP]Full fine tune support #2020
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1404,7 +1404,8 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> { | |
| void backward(const void* grad_output, void* grad_input, void* grad_gate_lora_a, void* grad_gate_lora_b, | ||
| void* grad_up_lora_a, void* grad_up_lora_b, void* grad_down_lora_a, void* grad_down_lora_b, | ||
| void* grad_weights, int full_intermediate_size = 0, float* fp32_grad_down_lora_b = nullptr, | ||
| float* fp32_grad_gate_lora_a = nullptr, float* fp32_grad_up_lora_a = nullptr) { | ||
| float* fp32_grad_gate_lora_a = nullptr, float* fp32_grad_up_lora_a = nullptr, | ||
| void* grad_gate_proj = nullptr, void* grad_up_proj = nullptr, void* grad_down_proj = nullptr) { | ||
| // If full_intermediate_size not provided, use local (non-TP mode) | ||
| if (full_intermediate_size == 0) full_intermediate_size = config_.intermediate_size; | ||
| SFT_POOL_LOG("bwd_enter", config_.layer_idx, tp_part_idx, 0, cache_stack_top_, forward_pool_bytes_, | ||
|
|
@@ -1888,7 +1889,14 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> { | |
| print_grad_stats_fp32("grad_weights", (const float*)grad_weights, qlen * k); | ||
| } | ||
|
|
||
| // ★ Cache pool is NOT freed here — kept for reuse across steps. | ||
| // ===================================================================== | ||
| // Step 5: Base weight gradient accumulation (full weight grad mode) | ||
| // ===================================================================== | ||
| if (sft_config_.full_weight_grad && grad_gate_proj && grad_up_proj && grad_down_proj) { | ||
| backward_base_weight_grad(cache, grad_output, grad_gate_proj, grad_up_proj, grad_down_proj); | ||
| } | ||
|
|
||
| // \u2605 Cache pool is NOT freed here \u2014 kept for reuse across steps. | ||
| // alloc_or_resize_cache_pool() is grow-only, so same-seqlen steps | ||
| // reuse the existing allocation without malloc/free overhead. | ||
| // Previously: free_seqlen_buffers() was called here, costing ~3.6ms per TP. | ||
|
|
@@ -1897,6 +1905,92 @@ class AMX_SFT_MOE_TP : public BaseMOE<T> { | |
| cache.valid = false; | ||
| } | ||
|
|
||
| /** | ||
| * @brief Compute base weight gradients via outer-product accumulation. | ||
| * | ||
| * For each activated expert e, computes: | ||
| * grad_gate_proj[e] = grad_gate_out[e]^T @ input[e] -> [I, H] | ||
| * grad_up_proj[e] = grad_up_out[e]^T @ input[e] -> [I, H] | ||
| * grad_down_proj[e] = grad_output[e]^T @ intermediate[e] -> [H, I] | ||
| * | ||
| * Uses FP32 accumulator for precision, writes BF16 output. | ||
| * | ||
| * Scratch space: three FP32 accumulators (3 * I * H floats in total) are taken | ||
| * from the start of forward_pool_, zeroed for every expert, and converted to | ||
| * BF16 when stored. | ||
| * | ||
| * Output semantics: each processed expert's slice of the output buffers is | ||
| * assigned (not added to). Experts with no routed tokens (m == 0) are skipped, | ||
| * so their slices are left untouched. | ||
| * | ||
| * Row addressing: input and grad_output rows are selected by the token position | ||
| * read from m_local_pos_cache; gate/up gradient rows (member buffers | ||
| * grad_gate_output_ / grad_up_output_) and intermediate rows are selected by | ||
| * pos_start + t, where pos_start is the sum of the token counts of the | ||
| * activated experts that precede the current one. | ||
| * | ||
| * NOTE(review): nothing here checks that forward_pool_ holds 3 * I * H floats | ||
| * or that its previous contents are no longer needed -- confirm. | ||
| * NOTE(review): m_local_pos_cache is indexed with expert_idx; confirm that this | ||
| * matches how the forward pass populates it. | ||
| * NOTE(review): the local E (expert_num) is declared but not used below. | ||
| * | ||
| * @param cache Forward cache; supplies activated_expert_cache, | ||
| * m_expert_id_map_cache, m_local_num_cache, m_local_pos_cache, | ||
| * input_cache and intermediate_cache. | ||
| * @param grad_output Read as BF16, one row of H values per token position. | ||
| * @param grad_gate_proj Written as BF16 at [expert_idx][I][H]. | ||
| * @param grad_up_proj Written as BF16 at [expert_idx][I][H]. | ||
| * @param grad_down_proj Written as BF16 at [expert_idx][H][I]. | ||
| */ | ||
| void backward_base_weight_grad(const ForwardCache& cache, const void* grad_output, | ||
| void* grad_gate_proj, void* grad_up_proj, void* grad_down_proj) { | ||
| const int H = config_.hidden_size; | ||
| const int I = config_.intermediate_size; | ||
| const int E = config_.expert_num; | ||
| int activated_expert = cache.activated_expert_cache; | ||
|
|
||
| auto* ggp = static_cast<ggml_bf16_t*>(grad_gate_proj); // [E, I, H] | ||
| auto* gup_ptr = static_cast<ggml_bf16_t*>(grad_up_proj); // [E, I, H] | ||
| auto* gdp = static_cast<ggml_bf16_t*>(grad_down_proj); // [E, H, I] | ||
| auto* grad_out_bf16 = static_cast<const ggml_bf16_t*>(grad_output); | ||
|
|
||
| for (int task_id = 0; task_id < activated_expert; task_id++) { | ||
| int expert_idx = cache.m_expert_id_map_cache[task_id]; | ||
| int m = cache.m_local_num_cache[expert_idx]; | ||
| if (m == 0) continue; | ||
|
|
||
| int pos_start = 0; | ||
| for (int prev_id = 0; prev_id < task_id; prev_id++) { | ||
| pos_start += cache.m_local_num_cache[cache.m_expert_id_map_cache[prev_id]]; | ||
| } | ||
|
|
||
| const auto& local_pos = cache.m_local_pos_cache[expert_idx]; | ||
|
|
||
| // Allocate FP32 accumulators from forward pool (safe during backward) -- NOTE(review): pool capacity of 3 * I * H floats is not checked here; confirm | ||
| float* acc_gate = static_cast<float*>(forward_pool_); // [I, H] | ||
| float* acc_up = acc_gate + (size_t)I * H; // [I, H] | ||
| float* acc_down = acc_up + (size_t)I * H; // [H, I] | ||
|
Comment on lines
+1943
to
+1945
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Critical issue: |
||
|
|
||
| std::memset(acc_gate, 0, (size_t)I * H * sizeof(float)); | ||
| std::memset(acc_up, 0, (size_t)I * H * sizeof(float)); | ||
| std::memset(acc_down, 0, (size_t)H * I * sizeof(float)); | ||
|
|
||
| for (int t = 0; t < m; t++) { | ||
| int tok_pos = local_pos[t]; | ||
| const ggml_bf16_t* input_row = cache.input_cache + (size_t)tok_pos * H; | ||
| const ggml_bf16_t* gate_grad_row = grad_gate_output_ + (size_t)(pos_start + t) * I; | ||
| const ggml_bf16_t* up_grad_row = grad_up_output_ + (size_t)(pos_start + t) * I; | ||
| const ggml_bf16_t* inter_row = cache.intermediate_cache + (size_t)(pos_start + t) * I; | ||
| const ggml_bf16_t* grad_out_row = grad_out_bf16 + (size_t)tok_pos * H; | ||
|
|
||
| // gate_proj grad: [I, H] += grad_gate_out[t]^T @ input[t] (the same loop also accumulates the up_proj grad into acc_up) | ||
| for (int i = 0; i < I; i++) { | ||
| float gg = GGML_BF16_TO_FP32(gate_grad_row[i]); | ||
| float gu = GGML_BF16_TO_FP32(up_grad_row[i]); | ||
| for (int h = 0; h < H; h++) { | ||
| float inp = GGML_BF16_TO_FP32(input_row[h]); | ||
| acc_gate[i * H + h] += gg * inp; | ||
| acc_up[i * H + h] += gu * inp; | ||
| } | ||
| } | ||
|
|
||
| // down_proj grad: [H, I] += grad_output[t]^T @ intermediate[t] | ||
| for (int h = 0; h < H; h++) { | ||
| float go = GGML_BF16_TO_FP32(grad_out_row[h]); | ||
| for (int i = 0; i < I; i++) { | ||
| acc_down[h * I + i] += go * GGML_BF16_TO_FP32(inter_row[i]); | ||
| } | ||
| } | ||
|
Comment on lines
+1959
to
+1976
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. High severity performance issue: The base weight gradient computation uses a triple nested loop without vectorization or parallelization. For typical MoE sizes, this results in billions of operations executed sequentially on a single thread. This should be implemented using a vectorized outer product (rank-1 update) or a GEMM ( |
||
| } | ||
|
|
||
| // Convert FP32 accumulators to BF16 and store (plain assignment into this expert's slice; any previous contents of the slice are replaced) | ||
| for (int i = 0; i < I; i++) { | ||
| for (int h = 0; h < H; h++) { | ||
| ggp[(size_t)expert_idx * I * H + (size_t)i * H + h] = GGML_FP32_TO_BF16(acc_gate[i * H + h]); | ||
| gup_ptr[(size_t)expert_idx * I * H + (size_t)i * H + h] = GGML_FP32_TO_BF16(acc_up[i * H + h]); | ||
| } | ||
| } | ||
| for (int h = 0; h < H; h++) { | ||
| for (int i = 0; i < I; i++) { | ||
| gdp[(size_t)expert_idx * H * I + (size_t)h * I + i] = GGML_FP32_TO_BF16(acc_down[h * I + i]); | ||
| } | ||
|
Comment on lines
+1982
to
+1989
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. High severity issue: The gradients are being overwritten in the output buffer ( |
||
| } | ||
| } | ||
| } | ||
|
|
||
| /** | ||
| * @brief Get qlen from the top of the forward cache stack. | ||
| * | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Critical bug:
`cache.m_local_pos_cache` is indexed by token ID, not expert index. Accessing it with `expert_idx` will lead to out-of-bounds access or incorrect data. You need a mapping from expert to the original token indices to correctly retrieve `input_row` and `grad_out_row` from the global buffers.