Skip to content

Commit 6e86fe3

Browse files
committed
perf: add llm decode metadata update fast path.
- add a decode-only fused metadata update kernel for ordinary LLM CUDA graph execution
- reuse persistent kv seq len delta buffers and keep block_tables on the legacy copy path
- add decode fast-path coverage and fallback equivalence tests
1 parent 32412da commit 6e86fe3

6 files changed

Lines changed: 525 additions & 75 deletions

File tree

xllm/core/kernels/cuda/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ set(CUDA_HEADER_FILES
3333
cutlass_w8a8/c3x/scaled_mm_sm90_fp8_dispatch.cuh
3434
fp8_quant_utils.cuh
3535
global_capture_instance.h
36+
llm_decode_metadata_update.h
3637
piecewise_graphs.h
3738
topk_last_dim.cuh
3839
type_convert.cuh
@@ -107,6 +108,7 @@ set(CUDA_SOURCE_FILES
107108
fp8_scaled_quantize.cpp
108109
fused_qknorm_rope.cu
109110
global_capture_instance.cpp
111+
llm_decode_metadata_update.cu
110112
matmul.cpp
111113
norm.cu
112114
piecewise_graphs.cpp
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
/* Copyright 2026 The xLLM Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
https://github.com/jd-opensource/xllm/blob/main/LICENSE
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#include <c10/cuda/CUDAException.h>
17+
#include <cuda_runtime.h>
18+
#include <glog/logging.h>
19+
20+
#include <algorithm>
21+
22+
#include "core/kernels/cuda/llm_decode_metadata_update.h"
23+
24+
namespace xllm::kernel::cuda {
25+
namespace {
26+
27+
constexpr int32_t kThreadsPerBlock = 256;
28+
29+
// Fused grid-stride kernel that refreshes all decode-step metadata buffers in
// a single launch: token ids, positions, new cache slots, cumulative kv
// sequence lengths (plus per-sequence deltas derived from them), and the
// paged-kv indptr/indices/last-page-len arrays.
//
// Each guard below bounds one logical sub-copy; `max_work_size` is the
// largest of those extents, so one stride loop services them all. Entries in
// [actual_num_tokens, padded_num_tokens) of dst_tokens/dst_new_cache_slots
// are zeroed (dst_positions is left untouched in that padded region).
// Launch shape is arbitrary; the stride loop handles any grid size.
__global__ void llm_decode_metadata_update_kernel(
    LlmDecodeMetadataUpdateParams params,
    int64_t max_work_size) {
  const int64_t tid =
      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const int64_t grid_stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  for (int64_t i = tid; i < max_work_size; i += grid_stride) {
    // Token-aligned buffers: copy the live prefix, zero the padding tail.
    if (i < params.actual_num_tokens) {
      params.dst_tokens[i] = params.src_tokens[i];
      params.dst_positions[i] = params.src_positions[i];
      params.dst_new_cache_slots[i] = params.src_new_cache_slots[i];
    } else if (i < params.padded_num_tokens) {
      params.dst_tokens[i] = 0;
      params.dst_new_cache_slots[i] = 0;
    }
    // Prefix-sum style buffers carry batch_size + 1 entries.
    if (i < params.actual_batch_size + 1) {
      params.dst_kv_cu_seq_lens[i] = params.src_kv_cu_seq_lens[i];
      params.dst_paged_kv_indptr[i] = params.src_paged_kv_indptr[i];
    }
    // Per-sequence buffers; each delta is the difference of adjacent
    // cumulative kv sequence lengths.
    if (i < params.actual_batch_size) {
      params.dst_kv_seq_lens_delta[i] =
          params.src_kv_cu_seq_lens[i + 1] - params.src_kv_cu_seq_lens[i];
      params.dst_paged_kv_last_page_len[i] =
          params.src_paged_kv_last_page_len[i];
    }
    if (i < params.actual_indices_size) {
      params.dst_paged_kv_indices[i] = params.src_paged_kv_indices[i];
    }
  }
}
60+
61+
} // namespace
62+
63+
// Enqueues the fused decode metadata update on `stream`.
//
// The kernel strides over the union of all sub-ranges, so the launch is sized
// by the widest extent in `params`. Returns immediately (no launch) when
// every extent is empty. Launch-configuration failures are surfaced right
// away via cudaGetLastError(); asynchronous execution errors, as usual,
// surface at a later synchronizing call.
void UpdateLlmDecodeMetadata(const LlmDecodeMetadataUpdateParams& params,
                             cudaStream_t stream) {
  const int64_t work_size = std::max({params.actual_num_tokens,
                                      params.padded_num_tokens,
                                      params.actual_batch_size + 1,
                                      params.actual_indices_size});
  if (work_size <= 0) {
    return;  // nothing to update
  }
  // Ceil-divide into blocks, capped at 4096 — larger inputs are absorbed by
  // the kernel's grid-stride loop.
  const int64_t grid_size = std::min<int64_t>(
      (work_size + kThreadsPerBlock - 1) / kThreadsPerBlock, 4096);
  llm_decode_metadata_update_kernel<<<static_cast<uint32_t>(grid_size),
                                      kThreadsPerBlock,
                                      0,
                                      stream>>>(params, work_size);
  const cudaError_t error = cudaGetLastError();
  CHECK_EQ(error, cudaSuccess)
      << "llm_decode_metadata_update kernel launch failed: "
      << cudaGetErrorString(error);
}
83+
84+
} // namespace xllm::kernel::cuda
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
/* Copyright 2026 The xLLM Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
https://github.com/jd-opensource/xllm/blob/main/LICENSE
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#pragma once
17+
18+
#include <cuda_runtime.h>
19+
20+
#include <cstdint>
21+
22+
namespace xllm::kernel::cuda {
23+
24+
// Pointer/extent bundle consumed by the fused decode metadata update kernel
// (see UpdateLlmDecodeMetadata). All pointers reference int32 buffers that
// the kernel dereferences on-device; src_* buffers are read-only, dst_*
// buffers are written. The struct is passed to the kernel by value.
struct LlmDecodeMetadataUpdateParams {
  // --- Sources (read-only) ---
  const int32_t* src_tokens;                  // [actual_num_tokens]
  const int32_t* src_positions;               // [actual_num_tokens]
  const int32_t* src_new_cache_slots;         // [actual_num_tokens]
  const int32_t* src_kv_cu_seq_lens;          // [actual_batch_size + 1], cumulative
  const int32_t* src_paged_kv_indptr;         // [actual_batch_size + 1]
  const int32_t* src_paged_kv_indices;        // [actual_indices_size]
  const int32_t* src_paged_kv_last_page_len;  // [actual_batch_size]
  // --- Destinations (written in place) ---
  int32_t* dst_tokens;             // padded tail zero-filled
  int32_t* dst_positions;          // live prefix only; padded tail left as-is
  int32_t* dst_new_cache_slots;    // padded tail zero-filled
  int32_t* dst_kv_cu_seq_lens;
  int32_t* dst_kv_seq_lens_delta;  // per-sequence length, derived from adjacent cu_seq_lens
  int32_t* dst_paged_kv_indptr;
  int32_t* dst_paged_kv_indices;
  int32_t* dst_paged_kv_last_page_len;
  // --- Extents bounding the kernel's sub-copies ---
  int64_t actual_num_tokens;    // live token count in this decode step
  int64_t padded_num_tokens;    // padded token count (presumably the CUDA-graph
                                // capture shape — >= actual_num_tokens)
  int64_t actual_batch_size;    // number of sequences in the batch
  int64_t actual_indices_size;  // number of valid paged_kv_indices entries
};
45+
46+
// Asynchronously refreshes all decode-step metadata buffers described by
// `params` with a single fused kernel launch on `stream`. No-op when every
// extent in `params` is <= 0; launch failures are fatal (CHECK).
void UpdateLlmDecodeMetadata(const LlmDecodeMetadataUpdateParams& params,
                             cudaStream_t stream);
48+
49+
} // namespace xllm::kernel::cuda

0 commit comments

Comments
 (0)