Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions xllm/core/kernels/cuda/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ set(CUDA_HEADER_FILES
cutlass_w8a8/c3x/scaled_mm_sm90_fp8_dispatch.cuh
fp8_quant_utils.cuh
global_capture_instance.h
llm_decode_metadata_update.h
piecewise_graphs.h
topk_last_dim.cuh
type_convert.cuh
Expand Down Expand Up @@ -107,6 +108,7 @@ set(CUDA_SOURCE_FILES
fp8_scaled_quantize.cpp
fused_qknorm_rope.cu
global_capture_instance.cpp
llm_decode_metadata_update.cu
matmul.cpp
norm.cu
piecewise_graphs.cpp
Expand Down
88 changes: 88 additions & 0 deletions xllm/core/kernels/cuda/llm_decode_metadata_update.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
/* Copyright 2026 The xLLM Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://github.com/jd-opensource/xllm/blob/main/LICENSE

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cuda_runtime.h>
#include <glog/logging.h>

#include <algorithm>

#include "core/kernels/cuda/llm_decode_metadata_update.h"

namespace xllm::kernel::cuda {
namespace {

constexpr int32_t kThreadsPerBlock = 256;
constexpr int64_t kMaxBlocksPerLaunch = 4096;

__global__ void llm_decode_metadata_update_kernel(
LlmDecodeMetadataUpdateParams params,
int64_t max_work_size) {
const int64_t thread_idx =
static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
const int64_t step = static_cast<int64_t>(blockDim.x) * gridDim.x;
for (int64_t idx = thread_idx; idx < max_work_size; idx += step) {
if (idx < params.actual_num_tokens) {
params.dst_tokens[idx] = params.src_tokens[idx];
params.dst_positions[idx] = params.src_positions[idx];
params.dst_new_cache_slots[idx] = params.src_new_cache_slots[idx];
}
if (idx >= params.actual_num_tokens && idx < params.padded_num_tokens) {
params.dst_tokens[idx] = 0;
params.dst_new_cache_slots[idx] = 0;
}
if (idx < params.actual_batch_size + 1) {
params.dst_kv_seq_lens[idx] = params.src_kv_seq_lens[idx];
params.dst_paged_kv_indptr[idx] = params.src_paged_kv_indptr[idx];
}
if (idx < params.actual_batch_size) {
params.dst_kv_seq_lens_delta[idx] =
params.src_kv_seq_lens[idx + 1] - params.src_kv_seq_lens[idx];
params.dst_paged_kv_last_page_len[idx] =
params.src_paged_kv_last_page_len[idx];
}
if (idx < params.actual_indices_size) {
params.dst_paged_kv_indices[idx] = params.src_paged_kv_indices[idx];
}
}
}

} // namespace

void update_llm_decode_metadata(const LlmDecodeMetadataUpdateParams& params,
cudaStream_t stream) {
const int64_t max_work_size = std::max({params.actual_num_tokens,
params.padded_num_tokens,
params.actual_batch_size + 1,
params.actual_indices_size});
if (max_work_size <= 0) {
return;
}
// Cap the grid size because the kernel already uses a strided loop.
// This keeps launch overhead bounded for large inputs without reducing
// coverage.
const int64_t num_blocks = std::min<int64_t>(
(max_work_size + kThreadsPerBlock - 1) / kThreadsPerBlock,
kMaxBlocksPerLaunch);
llm_decode_metadata_update_kernel<<<static_cast<uint32_t>(num_blocks),
kThreadsPerBlock,
/*shared_mem_bytes=*/0,
stream>>>(params, max_work_size);
Comment thread
RobbieLeung marked this conversation as resolved.
const cudaError_t error = cudaGetLastError();
CHECK_EQ(error, cudaSuccess)
<< "llm_decode_metadata_update kernel launch failed: "
<< cudaGetErrorString(error);
}

} // namespace xllm::kernel::cuda
49 changes: 49 additions & 0 deletions xllm/core/kernels/cuda/llm_decode_metadata_update.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
/* Copyright 2026 The xLLM Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://github.com/jd-opensource/xllm/blob/main/LICENSE

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#pragma once

#include <cuda_runtime.h>

#include <cstdint>

namespace xllm::kernel::cuda {

struct LlmDecodeMetadataUpdateParams {
const int32_t* src_tokens;
const int32_t* src_positions;
const int32_t* src_new_cache_slots;
const int32_t* src_kv_seq_lens;
const int32_t* src_paged_kv_indptr;
const int32_t* src_paged_kv_indices;
const int32_t* src_paged_kv_last_page_len;
int32_t* dst_tokens;
int32_t* dst_positions;
int32_t* dst_new_cache_slots;
int32_t* dst_kv_seq_lens;
int32_t* dst_kv_seq_lens_delta;
int32_t* dst_paged_kv_indptr;
int32_t* dst_paged_kv_indices;
int32_t* dst_paged_kv_last_page_len;
int64_t actual_num_tokens;
int64_t padded_num_tokens;
int64_t actual_batch_size;
int64_t actual_indices_size;
};

void update_llm_decode_metadata(const LlmDecodeMetadataUpdateParams& params,
cudaStream_t stream);

} // namespace xllm::kernel::cuda
Loading
Loading