diff --git a/xllm/core/kernels/cuda/CMakeLists.txt b/xllm/core/kernels/cuda/CMakeLists.txt
index a19686ce7..4907177bb 100644
--- a/xllm/core/kernels/cuda/CMakeLists.txt
+++ b/xllm/core/kernels/cuda/CMakeLists.txt
@@ -33,6 +33,7 @@ set(CUDA_HEADER_FILES
     cutlass_w8a8/c3x/scaled_mm_sm90_fp8_dispatch.cuh
     fp8_quant_utils.cuh
     global_capture_instance.h
+    llm_decode_metadata_update.h
     piecewise_graphs.h
     topk_last_dim.cuh
     type_convert.cuh
@@ -107,6 +108,7 @@ set(CUDA_SOURCE_FILES
     fp8_scaled_quantize.cpp
     fused_qknorm_rope.cu
     global_capture_instance.cpp
+    llm_decode_metadata_update.cu
     matmul.cpp
     norm.cu
     piecewise_graphs.cpp
diff --git a/xllm/core/kernels/cuda/llm_decode_metadata_update.cu b/xllm/core/kernels/cuda/llm_decode_metadata_update.cu
new file mode 100644
index 000000000..64fd81640
--- /dev/null
+++ b/xllm/core/kernels/cuda/llm_decode_metadata_update.cu
@@ -0,0 +1,88 @@
+/* Copyright 2026 The xLLM Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    https://github.com/jd-opensource/xllm/blob/main/LICENSE
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <cuda_runtime.h>
+#include <glog/logging.h>
+
+#include <algorithm>
+
+#include "core/kernels/cuda/llm_decode_metadata_update.h"
+
+namespace xllm::kernel::cuda {
+namespace {
+
+constexpr int32_t kThreadsPerBlock = 256;
+constexpr int64_t kMaxBlocksPerLaunch = 4096;
+
+__global__ void llm_decode_metadata_update_kernel(
+    LlmDecodeMetadataUpdateParams params,
+    int64_t max_work_size) {
+  const int64_t thread_idx =
+      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
+  const int64_t step = static_cast<int64_t>(blockDim.x) * gridDim.x;
+  for (int64_t idx = thread_idx; idx < max_work_size; idx += step) {
+    if (idx < params.actual_num_tokens) {
+      params.dst_tokens[idx] = params.src_tokens[idx];
+      params.dst_positions[idx] = params.src_positions[idx];
+      params.dst_new_cache_slots[idx] = params.src_new_cache_slots[idx];
+    }
+    if (idx >= params.actual_num_tokens && idx < params.padded_num_tokens) {
+      params.dst_tokens[idx] = 0;
+      params.dst_new_cache_slots[idx] = 0;
+    }
+    if (idx < params.actual_batch_size + 1) {
+      params.dst_kv_seq_lens[idx] = params.src_kv_seq_lens[idx];
+      params.dst_paged_kv_indptr[idx] = params.src_paged_kv_indptr[idx];
+    }
+    if (idx < params.actual_batch_size) {
+      params.dst_kv_seq_lens_delta[idx] =
+          params.src_kv_seq_lens[idx + 1] - params.src_kv_seq_lens[idx];
+      params.dst_paged_kv_last_page_len[idx] =
+          params.src_paged_kv_last_page_len[idx];
+    }
+    if (idx < params.actual_indices_size) {
+      params.dst_paged_kv_indices[idx] = params.src_paged_kv_indices[idx];
+    }
+  }
+}
+
+}  // namespace
+
+void update_llm_decode_metadata(const LlmDecodeMetadataUpdateParams& params,
+                                cudaStream_t stream) {
+  const int64_t max_work_size = std::max({params.actual_num_tokens,
+                                          params.padded_num_tokens,
+                                          params.actual_batch_size + 1,
+                                          params.actual_indices_size});
+  if (max_work_size <= 0) {
+    return;
+  }
+  // Cap the grid size because the kernel already uses a strided loop.
+  // This keeps launch overhead bounded for large inputs without reducing
+  // coverage.
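+  // For example, with kThreadsPerBlock = 256 a max_work_size of 1'000'000
+  // requests ceil(1'000'000 / 256) = 3907 blocks, which stays under the cap,
+  // while 2'000'000 would be clamped from 7813 to kMaxBlocksPerLaunch = 4096
+  // blocks and each thread then covers roughly two elements via the stride
+  // loop.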
+  const int64_t num_blocks = std::min(
+      (max_work_size + kThreadsPerBlock - 1) / kThreadsPerBlock,
+      kMaxBlocksPerLaunch);
+  llm_decode_metadata_update_kernel<<<static_cast<uint32_t>(num_blocks),
+                                      kThreadsPerBlock,
+                                      /*shared_mem_bytes=*/0,
+                                      stream>>>(params, max_work_size);
+  const cudaError_t error = cudaGetLastError();
+  CHECK_EQ(error, cudaSuccess)
+      << "llm_decode_metadata_update kernel launch failed: "
+      << cudaGetErrorString(error);
+}
+
+}  // namespace xllm::kernel::cuda
diff --git a/xllm/core/kernels/cuda/llm_decode_metadata_update.h b/xllm/core/kernels/cuda/llm_decode_metadata_update.h
new file mode 100644
index 000000000..9de9b1845
--- /dev/null
+++ b/xllm/core/kernels/cuda/llm_decode_metadata_update.h
@@ -0,0 +1,49 @@
+/* Copyright 2026 The xLLM Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    https://github.com/jd-opensource/xllm/blob/main/LICENSE
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#pragma once
+
+#include <cuda_runtime.h>
+
+#include <cstdint>
+
+namespace xllm::kernel::cuda {
+
+struct LlmDecodeMetadataUpdateParams {
+  const int32_t* src_tokens;
+  const int32_t* src_positions;
+  const int32_t* src_new_cache_slots;
+  const int32_t* src_kv_seq_lens;
+  const int32_t* src_paged_kv_indptr;
+  const int32_t* src_paged_kv_indices;
+  const int32_t* src_paged_kv_last_page_len;
+  int32_t* dst_tokens;
+  int32_t* dst_positions;
+  int32_t* dst_new_cache_slots;
+  int32_t* dst_kv_seq_lens;
+  int32_t* dst_kv_seq_lens_delta;
+  int32_t* dst_paged_kv_indptr;
+  int32_t* dst_paged_kv_indices;
+  int32_t* dst_paged_kv_last_page_len;
+  int64_t actual_num_tokens;
+  int64_t padded_num_tokens;
+  int64_t actual_batch_size;
+  int64_t actual_indices_size;
+};
+
+void update_llm_decode_metadata(const LlmDecodeMetadataUpdateParams& params,
+                                cudaStream_t stream);
+
+}  // namespace xllm::kernel::cuda
diff --git a/xllm/core/runtime/cuda_graph_executor_impl.cpp b/xllm/core/runtime/cuda_graph_executor_impl.cpp
index 2cc1c92dc..2f6e8cdc5 100644
--- a/xllm/core/runtime/cuda_graph_executor_impl.cpp
+++ b/xllm/core/runtime/cuda_graph_executor_impl.cpp
@@ -96,6 +96,11 @@ size_t get_allocator_reserved_bytes(c10::DeviceIndex device_index) {
   return static_cast<size_t>(device_stats.reserved_bytes[stat_index].current);
 }
 
+bool is_cuda_contiguous_int32_tensor(const torch::Tensor& tensor) {
+  return tensor.defined() && tensor.is_cuda() &&
+         tensor.scalar_type() == torch::kInt32 && tensor.is_contiguous();
+}
+
 }  // namespace
 
 // CudaGraphPersistentParam implementation
@@ -177,6 +182,8 @@ CudaGraphPersistentParam::CudaGraphPersistentParam(
   // max_seqs_per_batch]
   persistent_decode_qo_indptr_ = torch::arange(
       0, max_seqs_per_batch + 1, torch::dtype(torch::kInt).device(device));
+  persistent_kv_seq_lens_delta_ = torch::zeros(
+      {max_seqs_per_batch}, torch::dtype(torch::kInt).device(device));
   // will be updated by q_cu_seq_lens, q_cu_seq_lens is the cumulative sum of
   // q_seq_lens
   persistent_chunked_prefill_qo_indptr_ = torch::zeros(
@@ -184,6 +191,61 @@
   // aux_hidden_states will be lazily initialized when needed
 }
 
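+// LLM decode fast path: when the step is a plain decode batch (no rec
+// multi-round mode, no llmrec params, no input embeddings) and every metadata
+// tensor involved is a contiguous int32 CUDA tensor, a single fused kernel
+// (xllm::kernel::cuda::update_llm_decode_metadata) refreshes all persistent
+// CUDA graph buffers in one launch instead of issuing a separate
+// slice()/copy_() per tensor.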
+bool CudaGraphPersistentParam::can_use_llm_decode_fast_path(
+    const torch::Tensor& tokens,
+    const torch::Tensor& positions,
+    const ModelInputParams& params) const {
+  if (!params.batch_forward_type.is_decode() || is_rec_multi_round_mode() ||
+      params.has_llmrec_params() || params.input_embedding.defined()) {
+    return false;
+  }
+  return is_cuda_contiguous_int32_tensor(tokens) &&
+         is_cuda_contiguous_int32_tensor(positions) &&
+         is_cuda_contiguous_int32_tensor(params.new_cache_slots) &&
+         is_cuda_contiguous_int32_tensor(params.kv_seq_lens) &&
+         is_cuda_contiguous_int32_tensor(params.paged_kv_indptr) &&
+         is_cuda_contiguous_int32_tensor(params.paged_kv_indices) &&
+         is_cuda_contiguous_int32_tensor(params.paged_kv_last_page_len);
+}
+
+void CudaGraphPersistentParam::update_llm_decode_metadata_fast_path(
+    const torch::Tensor& tokens,
+    const torch::Tensor& positions,
+    const ModelInputParams& params,
+    uint32_t padded_num_tokens,
+    int64_t actual_batch_size,
+    int64_t actual_num_tokens) {
+  CHECK_GE(actual_batch_size, 0) << "actual_batch_size must be >= 0";
+  CHECK_GE(actual_num_tokens, 0) << "actual_num_tokens must be >= 0";
+  const int64_t actual_indices_size = params.paged_kv_indices.size(0);
+  xllm::kernel::cuda::LlmDecodeMetadataUpdateParams update_params{
+      .src_tokens = tokens.data_ptr<int32_t>(),
+      .src_positions = positions.data_ptr<int32_t>(),
+      .src_new_cache_slots = params.new_cache_slots.data_ptr<int32_t>(),
+      .src_kv_seq_lens = params.kv_seq_lens.data_ptr<int32_t>(),
+      .src_paged_kv_indptr = params.paged_kv_indptr.data_ptr<int32_t>(),
+      .src_paged_kv_indices = params.paged_kv_indices.data_ptr<int32_t>(),
+      .src_paged_kv_last_page_len =
+          params.paged_kv_last_page_len.data_ptr<int32_t>(),
+      .dst_tokens = persistent_tokens_.data_ptr<int32_t>(),
+      .dst_positions = persistent_positions_.data_ptr<int32_t>(),
+      .dst_new_cache_slots = persistent_new_cache_slots_.data_ptr<int32_t>(),
+      .dst_kv_seq_lens = kv_seq_lens_.data_ptr<int32_t>(),
+      .dst_kv_seq_lens_delta =
+          persistent_kv_seq_lens_delta_.data_ptr<int32_t>(),
+      .dst_paged_kv_indptr = persistent_paged_kv_indptr_.data_ptr<int32_t>(),
+      .dst_paged_kv_indices = persistent_paged_kv_indices_.data_ptr<int32_t>(),
+      .dst_paged_kv_last_page_len =
+          persistent_paged_kv_last_page_len_.data_ptr<int32_t>(),
+      .actual_num_tokens = actual_num_tokens,
+      .padded_num_tokens = static_cast<int64_t>(padded_num_tokens),
+      .actual_batch_size = actual_batch_size,
+      .actual_indices_size = actual_indices_size,
+  };
+  const cudaStream_t stream = c10::cuda::getCurrentCUDAStream(device_.index());
+  xllm::kernel::cuda::update_llm_decode_metadata(update_params, stream);
+}
+
 void CudaGraphPersistentParam::set_aux_hidden_states(
     const torch::Tensor& value) {
   if (!value.defined()) {
@@ -231,6 +293,7 @@ size_t CudaGraphPersistentParam::get_persistent_tensor_bytes() const {
   total += bytes(persistent_paged_kv_indices_);
   total += bytes(persistent_paged_kv_last_page_len_);
   total += bytes(persistent_decode_qo_indptr_);
+  total += bytes(persistent_kv_seq_lens_delta_);
   total += bytes(persistent_chunked_prefill_qo_indptr_);
   return total;
 }
@@ -280,33 +343,46 @@ std::optional CudaGraphPersistentParam::update(
   const uint32_t actual_num_tokens = tokens.size(0);
   const int64_t actual_batch_size = params.num_sequences;
 
+  const bool use_llm_decode_fast_path =
+      can_use_llm_decode_fast_path(tokens, positions, params);
   // Copy data from input parameters to persistent graph tensors
-  VLOG(kGraphExecutorLogVerboseLevel)
-      << "copy_ tokens: src shape=" << tokens.sizes() << ", dst slice shape=["
-      << actual_num_tokens << "]";
-  persistent_tokens_.slice(/*dim=*/0, /*start=*/0, /*end=*/actual_num_tokens)
-      .copy_(tokens,
/*non_blocking=*/true); - - // Zero out padding region for tokens to avoid stale data - // This is needed for both capture and replay when using padded tensors - if (padded_num_tokens > actual_num_tokens) { + if (use_llm_decode_fast_path) { VLOG(kGraphExecutorLogVerboseLevel) - << "fill_ tokens padding: [" << actual_num_tokens << ", " - << padded_num_tokens << "] with 0"; - persistent_tokens_ - .slice( - /*dim=*/0, /*start=*/actual_num_tokens, /*end=*/padded_num_tokens) - .fill_(0); - } + << "use fast path for LLM decode metadata update"; + update_llm_decode_metadata_fast_path(tokens, + positions, + params, + padded_num_tokens, + actual_batch_size, + actual_num_tokens); + } else { + VLOG(kGraphExecutorLogVerboseLevel) + << "copy_ tokens: src shape=" << tokens.sizes() << ", dst slice shape=[" + << actual_num_tokens << "]"; + persistent_tokens_.slice(/*dim=*/0, /*start=*/0, /*end=*/actual_num_tokens) + .copy_(tokens, /*non_blocking=*/true); - VLOG(kGraphExecutorLogVerboseLevel) - << "copy_ positions: src shape=" << positions.sizes() - << ", dst slice shape=[" << actual_num_tokens << "]"; - persistent_positions_.slice(/*dim=*/0, /*start=*/0, /*end=*/actual_num_tokens) - .copy_(positions, /*non_blocking=*/true); + if (padded_num_tokens > actual_num_tokens) { + VLOG(kGraphExecutorLogVerboseLevel) + << "fill_ tokens padding: [" << actual_num_tokens << ", " + << padded_num_tokens << "] with 0"; + persistent_tokens_ + .slice(/*dim=*/0, + /*start=*/actual_num_tokens, + /*end=*/padded_num_tokens) + .fill_(0); + } - if (!is_rec_multi_round_mode()) { + VLOG(kGraphExecutorLogVerboseLevel) + << "copy_ positions: src shape=" << positions.sizes() + << ", dst slice shape=[" << actual_num_tokens << "]"; + persistent_positions_ + .slice(/*dim=*/0, /*start=*/0, /*end=*/actual_num_tokens) + .copy_(positions, /*non_blocking=*/true); + } + + if (!is_rec_multi_round_mode() && !use_llm_decode_fast_path) { // q_seq_lens is q_cu_seq_lens in GPU Model. // kv_seq_lens is kv_cu_seq_lens in GPU Model. VLOG(kGraphExecutorLogVerboseLevel) @@ -327,19 +403,6 @@ std::optional CudaGraphPersistentParam::update( persistent_new_cache_slots_ .slice(/*dim=*/0, /*start=*/0, /*end=*/actual_num_tokens) .copy_(params.new_cache_slots, /*non_blocking=*/true); - if (!params.linear_state_ids.empty()) { - if (params.linear_state_indices.defined()) { - persistent_linear_state_indices_ - .slice(/*dim=*/0, /*start=*/0, /*end=*/actual_batch_size) - .copy_(params.linear_state_indices, /*non_blocking=*/true); - } else { - persistent_linear_state_indices_ - .slice(/*dim=*/0, /*start=*/0, /*end=*/actual_batch_size) - .copy_( - torch::tensor(params.linear_state_ids, torch::kInt).to(device_), - /*non_blocking=*/true); - } - } if (padded_num_tokens > actual_num_tokens) { persistent_new_cache_slots_ .slice(/*dim=*/0, @@ -361,6 +424,20 @@ std::optional CudaGraphPersistentParam::update( persistent_new_cache_slots(slot_mapping_tokens); } + if (!is_rec_multi_round_mode() && !params.linear_state_ids.empty()) { + if (params.linear_state_indices.defined()) { + persistent_linear_state_indices_ + .slice(/*dim=*/0, /*start=*/0, /*end=*/actual_batch_size) + .copy_(params.linear_state_indices, /*non_blocking=*/true); + } else { + persistent_linear_state_indices_ + .slice(/*dim=*/0, /*start=*/0, /*end=*/actual_batch_size) + .copy_( + torch::tensor(params.linear_state_ids, torch::kInt).to(device_), + /*non_blocking=*/true); + } + } + // Copy block table data. 
In rec multi-round, block_tables may already be
   // expanded to batch_size * beam_width rows while num_sequences still tracks
   // the logical request count. Use the tensor's real row count here.
@@ -572,60 +649,73 @@ std::optional CudaGraphPersistentParam::update(
         /*is_shared_stage_plan*/ false);
     return build_capture_params_if_needed();
   }
-  CHECK(params.paged_kv_indptr.defined())
-      << "paged_kv_indptr should not be null";
-  VLOG(kGraphExecutorLogVerboseLevel)
-      << "copy_ paged_kv_indptr: src shape=" << params.paged_kv_indptr.sizes()
-      << ", dst slice shape=[" << (actual_batch_size + 1) << "]";
-  if (VLOG_IS_ON(kGraphExecutorLogVerboseLevel)) {
-    torch::Tensor paged_kv_indptr_cpu = params.paged_kv_indptr.to(torch::kCPU);
+  if (use_llm_decode_fast_path) {
+    const uint32_t slot_mapping_tokens =
+        padded_num_tokens > 0 ? padded_num_tokens : actual_num_tokens;
+    attn_metadata->q_cu_seq_lens =
+        persistent_decode_qo_indptr(static_cast<uint32_t>(actual_batch_size));
+    attn_metadata->kv_cu_seq_lens =
+        kv_seq_lens(static_cast<uint32_t>(actual_batch_size + 1));
+    attn_metadata->kv_seq_lens =
+        persistent_kv_seq_lens_delta(static_cast<uint32_t>(actual_batch_size));
+    attn_metadata->slot_mapping =
+        persistent_new_cache_slots(slot_mapping_tokens);
+    attn_metadata->paged_kv_indptr =
+        persistent_paged_kv_indptr(static_cast<uint32_t>(actual_batch_size));
+    attn_metadata->paged_kv_indices = persistent_paged_kv_indices_;
+    attn_metadata->paged_kv_last_page_len = persistent_paged_kv_last_page_len(
+        static_cast<uint32_t>(actual_batch_size));
+    attn_metadata->qo_indptr =
+        persistent_decode_qo_indptr(static_cast<uint32_t>(actual_batch_size));
+  } else {
+    CHECK(params.paged_kv_indptr.defined())
+        << "paged_kv_indptr should not be null";
+    VLOG(kGraphExecutorLogVerboseLevel)
+        << "copy_ paged_kv_indptr: src shape=" << params.paged_kv_indptr.sizes()
+        << ", dst slice shape=[" << (actual_batch_size + 1) << "]";
+    if (VLOG_IS_ON(kGraphExecutorLogVerboseLevel)) {
+      torch::Tensor paged_kv_indptr_cpu =
+          params.paged_kv_indptr.to(torch::kCPU);
+      VLOG(kGraphExecutorLogVerboseLevel)
+          << "copy_ paged_kv_indptr: src values=" << paged_kv_indptr_cpu;
+    }
+    persistent_paged_kv_indptr_
+        .slice(/*dim=*/0,
+               /*start=*/0,
+               /*end=*/actual_batch_size + 1)
+        .copy_(params.paged_kv_indptr, /*non_blocking=*/true);
+    CHECK(params.paged_kv_indices.defined())
+        << "paged_kv_indices should not be null";
+    const int64_t actual_indices_size = params.paged_kv_indices.size(0);
+    VLOG(kGraphExecutorLogVerboseLevel)
+        << "copy_ paged_kv_indices: src shape="
+        << params.paged_kv_indices.sizes() << ", dst slice shape=["
+        << actual_indices_size << "]";
+    persistent_paged_kv_indices_
+        .slice(/*dim=*/0,
+               /*start=*/0,
+               /*end=*/actual_indices_size)
+        .copy_(params.paged_kv_indices, /*non_blocking=*/true);
+    CHECK(params.paged_kv_last_page_len.defined())
+        << "paged_kv_last_page_len should not be null";
+    VLOG(kGraphExecutorLogVerboseLevel)
+        << "copy_ paged_kv_last_page_len: src shape="
+        << params.paged_kv_last_page_len.sizes() << ", dst slice shape=["
+        << actual_batch_size << "]";
+    persistent_paged_kv_last_page_len_
+        .slice(/*dim=*/0,
+               /*start=*/0,
+               /*end=*/actual_batch_size)
+        .copy_(params.paged_kv_last_page_len, /*non_blocking=*/true);
+    attn_metadata->kv_seq_lens =
+        torch::diff(kv_seq_lens(/*actual_batch_size=*/actual_batch_size + 1));
+    attn_metadata->paged_kv_indptr =
+        persistent_paged_kv_indptr(actual_batch_size);
+    attn_metadata->paged_kv_indices = persistent_paged_kv_indices_;
+    attn_metadata->paged_kv_last_page_len =
+ persistent_paged_kv_last_page_len(actual_batch_size); + attn_metadata->qo_indptr = persistent_decode_qo_indptr(actual_batch_size); } - persistent_paged_kv_indptr_ - .slice(/*dim=*/0, - /*start=*/0, - /*end=*/actual_batch_size + 1) - .copy_(params.paged_kv_indptr, /*non_blocking=*/true); - CHECK(params.paged_kv_indices.defined()) - << "paged_kv_indices should not be null"; - const int64_t actual_indices_size = params.paged_kv_indices.size(0); - VLOG(kGraphExecutorLogVerboseLevel) - << "copy_ paged_kv_indices: src shape=" << params.paged_kv_indices.sizes() - << ", dst slice shape=[" << actual_indices_size << "]"; - persistent_paged_kv_indices_ - .slice(/*dim=*/0, - /*start=*/0, - /*end=*/actual_indices_size) - .copy_(params.paged_kv_indices, /*non_blocking=*/true); - CHECK(params.paged_kv_last_page_len.defined()) - << "paged_kv_last_page_len should not be null"; - VLOG(kGraphExecutorLogVerboseLevel) - << "copy_ paged_kv_last_page_len: src shape=" - << params.paged_kv_last_page_len.sizes() << ", dst slice shape=[" - << actual_batch_size << "]"; - persistent_paged_kv_last_page_len_ - .slice(/*dim=*/0, - /*start=*/0, - /*end=*/actual_batch_size) - .copy_(params.paged_kv_last_page_len, /*non_blocking=*/true); - // Convert cumulative lengths to individual sequence lengths using torch::diff - // This matches the behavior in attention_metadata_builder.cpp for decode mode - attn_metadata->kv_seq_lens = - torch::diff(kv_seq_lens(/*actual_batch_size=*/actual_batch_size + 1)); - // Set FlashInfer decode parameters (always update, not just for capture) - // This ensures attn_metadata points to updated persistent buffers for - // plan_info calculation - attn_metadata->paged_kv_indptr = - persistent_paged_kv_indptr(actual_batch_size); - // Match FlashInfer's CUDAGraph wrapper behavior: always pass the full - // pre-allocated indices buffer and use indptr to delimit valid range. - // This keeps kernel arguments stable across replays. - attn_metadata->paged_kv_indices = persistent_paged_kv_indices_; - attn_metadata->paged_kv_last_page_len = - persistent_paged_kv_last_page_len(actual_batch_size); - // qo_indptr is q_cu_seq_lens in GPU Model. - attn_metadata->qo_indptr = persistent_decode_qo_indptr(actual_batch_size); // Update plan_info if attn_metadata exists and enable_cuda_graph is true // This ensures plan_info is updated before CUDA graph capture/replay { diff --git a/xllm/core/runtime/cuda_graph_executor_impl.h b/xllm/core/runtime/cuda_graph_executor_impl.h index d6662b6fb..b4eb5e708 100644 --- a/xllm/core/runtime/cuda_graph_executor_impl.h +++ b/xllm/core/runtime/cuda_graph_executor_impl.h @@ -36,6 +36,7 @@ limitations under the License. 
#include "core/framework/kv_cache/kv_cache.h" #include "core/framework/model/causal_lm.h" #include "core/framework/model/model_input_params.h" +#include "core/kernels/cuda/llm_decode_metadata_update.h" #include "core/kernels/cuda/piecewise_graphs.h" #include "executor_impl.h" #include "executor_impl_factory.h" @@ -188,8 +189,25 @@ class CudaGraphPersistentParam { } return persistent_decode_qo_indptr_; } + torch::Tensor persistent_kv_seq_lens_delta(uint32_t actual_batch_size) const { + if (actual_batch_size > 0) { + return persistent_kv_seq_lens_delta_.slice( + /*dim=*/0, /*start=*/0, /*end=*/actual_batch_size); + } + return persistent_kv_seq_lens_delta_; + } private: + bool can_use_llm_decode_fast_path(const torch::Tensor& tokens, + const torch::Tensor& positions, + const ModelInputParams& params) const; + void update_llm_decode_metadata_fast_path(const torch::Tensor& tokens, + const torch::Tensor& positions, + const ModelInputParams& params, + uint32_t padded_num_tokens, + int64_t actual_batch_size, + int64_t actual_num_tokens); + const ModelArgs& args_; const torch::Device& device_; const runtime::Options& options_; @@ -211,6 +229,7 @@ class CudaGraphPersistentParam { torch::Tensor persistent_paged_kv_indices_; torch::Tensor persistent_paged_kv_last_page_len_; torch::Tensor persistent_decode_qo_indptr_; + torch::Tensor persistent_kv_seq_lens_delta_; // TODO maybe not used. or use q_cu_seq_lens instead. torch::Tensor persistent_chunked_prefill_qo_indptr_; diff --git a/xllm/core/runtime/cuda_graph_executor_test.cpp b/xllm/core/runtime/cuda_graph_executor_test.cpp index 09741ed4e..0c63dc57f 100644 --- a/xllm/core/runtime/cuda_graph_executor_test.cpp +++ b/xllm/core/runtime/cuda_graph_executor_test.cpp @@ -20,6 +20,7 @@ limitations under the License. 
#include #include #include +#include #include #include @@ -292,6 +293,263 @@ std::vector MakeHybridKvCaches(const torch::Device& device, KVCache(full_kv.get_k_cache(), full_kv.get_v_cache())}; } +ModelArgs make_test_model_args() { + ModelArgs args; + args.model_type("fake_attn"); + args.dtype("bfloat16"); + args.hidden_size(256); + args.max_position_embeddings(32); + args.vocab_size(2048); + args.n_layers(1); + args.n_heads(2); + args.head_dim(128); + args.n_kv_heads(1); + return args; +} + +runtime::Options make_test_runtime_options(int64_t max_seqs_per_batch) { + runtime::Options options; + options.block_size(1); + options.max_seqs_per_batch(max_seqs_per_batch); + return options; +} + +ModelInputParams make_multi_sequence_decode_params( + const torch::Device& device) { + ModelInputParams p; + p.batch_forward_type = BatchForwardType::DECODE; + p.num_sequences = 2; + p.kv_max_seq_len = 9; + p.q_max_seq_len = 1; + p.enable_cuda_graph = false; + + torch::TensorOptions iopt = + torch::TensorOptions().dtype(torch::kInt32).device(device); + p.q_seq_lens = torch::tensor({0, 1, 2}, iopt); + p.kv_seq_lens = torch::tensor({0, 4, 9}, iopt); + p.q_cu_seq_lens = p.q_seq_lens; + p.new_cache_slots = torch::tensor({5, 7}, iopt); + p.block_tables = torch::tensor({{0, 1, 2, 3}, {4, 5, 6, 7}}, iopt); + p.paged_kv_indptr = torch::tensor({0, 1, 3}, iopt); + p.paged_kv_indices = torch::tensor({2, 4, 6}, iopt); + p.paged_kv_last_page_len = torch::tensor({1, 2}, iopt); + return p; +} + +TEST(CudaGraphExecutorTest, DecodeMetadataFastPathUpdatesPersistentBuffers) { + if (!torch::cuda::is_available()) { + GTEST_SKIP() << "CUDA is not available at runtime."; + } + + const torch::Device device = InitXllmCudaDeviceForTest(/*device_index=*/0); + xllm::layer::flashinfer::FlashinferWorkspace::get_instance().initialize( + device); + ModelArgs args = make_test_model_args(); + runtime::Options options = + make_test_runtime_options(/*max_seqs_per_batch=*/2); + runtime::cuda::CudaGraphPersistentParam persistent(args, device, options); + + torch::TensorOptions iopt = + torch::TensorOptions().dtype(torch::kInt32).device(device); + torch::Tensor tokens = torch::tensor({10, 11}, iopt); + torch::Tensor positions = torch::tensor({20, 21}, iopt); + ModelInputParams params = make_multi_sequence_decode_params(device); + std::vector kv = MakeKvCaches(device, + /*num_pages=*/16, + /*page_size=*/1, + /*num_kv_heads=*/1, + /*head_dim=*/128); + + std::optional updated = + persistent.update(tokens, + kv[0].get_k_cache(), + kv[0].get_v_cache(), + positions, + params, + /*padded_num_tokens=*/4, + /*return_capture_params=*/true); + + ASSERT_TRUE(updated.has_value()); + ASSERT_TRUE(updated->attn_metadata); + + EXPECT_TRUE( + torch::equal(persistent.persistent_tokens(/*actual_tokens=*/4).cpu(), + torch::tensor({10, 11, 0, 0}, torch::dtype(torch::kInt32)))); + EXPECT_TRUE( + torch::equal(persistent.persistent_positions(/*actual_tokens=*/2).cpu(), + torch::tensor({20, 21}, torch::dtype(torch::kInt32)))); + EXPECT_TRUE(torch::equal( + persistent.persistent_new_cache_slots(/*actual_tokens=*/4).cpu(), + torch::tensor({5, 7, 0, 0}, torch::dtype(torch::kInt32)))); + EXPECT_TRUE( + torch::equal(persistent.kv_seq_lens(/*actual_batch_size=*/3).cpu(), + torch::tensor({0, 4, 9}, torch::dtype(torch::kInt32)))); + EXPECT_TRUE(torch::equal( + persistent.persistent_kv_seq_lens_delta(/*actual_batch_size=*/2).cpu(), + torch::tensor({4, 5}, torch::dtype(torch::kInt32)))); + + EXPECT_TRUE( + torch::equal(updated->attn_metadata->q_cu_seq_lens.cpu(), + 
torch::tensor({0, 1, 2}, torch::dtype(torch::kInt32)))); + EXPECT_TRUE( + torch::equal(updated->attn_metadata->kv_cu_seq_lens.cpu(), + torch::tensor({0, 4, 9}, torch::dtype(torch::kInt32)))); + EXPECT_TRUE(torch::equal(updated->attn_metadata->kv_seq_lens.cpu(), + torch::tensor({4, 5}, torch::dtype(torch::kInt32)))); + EXPECT_TRUE( + torch::equal(updated->attn_metadata->slot_mapping.cpu(), + torch::tensor({5, 7, 0, 0}, torch::dtype(torch::kInt32)))); + EXPECT_TRUE( + torch::equal(updated->attn_metadata->paged_kv_indptr.cpu(), + torch::tensor({0, 1, 3}, torch::dtype(torch::kInt32)))); + EXPECT_TRUE( + torch::equal(updated->attn_metadata->paged_kv_indices + .slice(/*dim=*/0, + /*start=*/0, + /*end=*/3) + .cpu(), + torch::tensor({2, 4, 6}, torch::dtype(torch::kInt32)))); + EXPECT_TRUE(torch::equal(updated->attn_metadata->paged_kv_last_page_len.cpu(), + torch::tensor({1, 2}, torch::dtype(torch::kInt32)))); + ASSERT_TRUE(updated->attn_metadata->qo_indptr.has_value()); + EXPECT_TRUE( + torch::equal(updated->attn_metadata->qo_indptr.value().cpu(), + torch::tensor({0, 1, 2}, torch::dtype(torch::kInt32)))); + EXPECT_TRUE(torch::equal(updated->attn_metadata->block_table.cpu(), + params.block_tables.cpu())); +} + +TEST(CudaGraphExecutorTest, DecodeMetadataFastPathUpdatesLinearStateIndices) { + if (!torch::cuda::is_available()) { + GTEST_SKIP() << "CUDA is not available at runtime."; + } + + const torch::Device device = InitXllmCudaDeviceForTest(/*device_index=*/0); + xllm::layer::flashinfer::FlashinferWorkspace::get_instance().initialize( + device); + ModelArgs args = make_test_model_args(); + runtime::Options options = + make_test_runtime_options(/*max_seqs_per_batch=*/2); + runtime::cuda::CudaGraphPersistentParam persistent(args, device, options); + + torch::TensorOptions iopt = + torch::TensorOptions().dtype(torch::kInt32).device(device); + torch::Tensor tokens = torch::tensor({10, 11}, iopt); + torch::Tensor positions = torch::tensor({20, 21}, iopt); + ModelInputParams params = make_multi_sequence_decode_params(device); + params.linear_state_ids = {8, 6}; + params.linear_state_indices = torch::tensor({8, 6}, iopt); + std::vector kv = MakeKvCaches(device, + /*num_pages=*/16, + /*page_size=*/1, + /*num_kv_heads=*/1, + /*head_dim=*/128); + + std::optional updated = + persistent.update(tokens, + kv[0].get_k_cache(), + kv[0].get_v_cache(), + positions, + params, + /*padded_num_tokens=*/4, + /*return_capture_params=*/true); + + ASSERT_TRUE(updated.has_value()); + EXPECT_TRUE(torch::equal( + persistent.persistent_linear_state_indices(/*actual_batch_size=*/2).cpu(), + torch::tensor({8, 6}, torch::dtype(torch::kInt32)))); + EXPECT_TRUE(torch::equal(updated->linear_state_indices.cpu(), + torch::tensor({8, 6}, torch::dtype(torch::kInt32)))); +} + +TEST(CudaGraphExecutorTest, DecodeMetadataFastPathFallbackMatchesLegacyPath) { + if (!torch::cuda::is_available()) { + GTEST_SKIP() << "CUDA is not available at runtime."; + } + + const torch::Device device = InitXllmCudaDeviceForTest(/*device_index=*/0); + xllm::layer::flashinfer::FlashinferWorkspace::get_instance().initialize( + device); + ModelArgs args = make_test_model_args(); + runtime::Options options = + make_test_runtime_options(/*max_seqs_per_batch=*/2); + runtime::cuda::CudaGraphPersistentParam fast_path_persistent( + args, device, options); + runtime::cuda::CudaGraphPersistentParam fallback_persistent( + args, device, options); + + torch::TensorOptions iopt = + torch::TensorOptions().dtype(torch::kInt32).device(device); + torch::Tensor tokens = 
torch::tensor({10, 11}, iopt); + torch::Tensor positions = torch::tensor({20, 21}, iopt); + ModelInputParams fast_params = make_multi_sequence_decode_params(device); + ModelInputParams fallback_params = make_multi_sequence_decode_params(device); + torch::Tensor new_cache_slots_base = + torch::tensor({5, 99, 7, 88}, iopt).view({2, 2}); + fallback_params.new_cache_slots = new_cache_slots_base.select(1, 0); + ASSERT_FALSE(fallback_params.new_cache_slots.is_contiguous()); + + std::vector kv = MakeKvCaches(device, + /*num_pages=*/16, + /*page_size=*/1, + /*num_kv_heads=*/1, + /*head_dim=*/128); + + std::optional fast_updated = + fast_path_persistent.update(tokens, + kv[0].get_k_cache(), + kv[0].get_v_cache(), + positions, + fast_params, + /*padded_num_tokens=*/4, + /*return_capture_params=*/true); + std::optional fallback_updated = + fallback_persistent.update(tokens, + kv[0].get_k_cache(), + kv[0].get_v_cache(), + positions, + fallback_params, + /*padded_num_tokens=*/4, + /*return_capture_params=*/true); + + ASSERT_TRUE(fast_updated.has_value()); + ASSERT_TRUE(fallback_updated.has_value()); + ASSERT_TRUE(fast_updated->attn_metadata); + ASSERT_TRUE(fallback_updated->attn_metadata); + + EXPECT_TRUE(torch::equal( + fast_path_persistent.persistent_tokens(/*actual_tokens=*/4).cpu(), + fallback_persistent.persistent_tokens(/*actual_tokens=*/4).cpu())); + EXPECT_TRUE(torch::equal( + fast_path_persistent.persistent_positions(/*actual_tokens=*/2).cpu(), + fallback_persistent.persistent_positions(/*actual_tokens=*/2).cpu())); + EXPECT_TRUE(torch::equal( + fast_path_persistent.persistent_new_cache_slots(/*actual_tokens=*/4) + .cpu(), + fallback_persistent.persistent_new_cache_slots(/*actual_tokens=*/4) + .cpu())); + EXPECT_TRUE(torch::equal( + fast_path_persistent.kv_seq_lens(/*actual_batch_size=*/3).cpu(), + fallback_persistent.kv_seq_lens(/*actual_batch_size=*/3).cpu())); + EXPECT_TRUE(torch::equal(fast_updated->attn_metadata->kv_seq_lens.cpu(), + fallback_updated->attn_metadata->kv_seq_lens.cpu())); + EXPECT_TRUE( + torch::equal(fast_updated->attn_metadata->q_cu_seq_lens.cpu(), + fallback_updated->attn_metadata->q_cu_seq_lens.cpu())); + EXPECT_TRUE( + torch::equal(fast_updated->attn_metadata->slot_mapping.cpu(), + fallback_updated->attn_metadata->slot_mapping.cpu())); + EXPECT_TRUE( + torch::equal(fast_updated->attn_metadata->paged_kv_indptr.cpu(), + fallback_updated->attn_metadata->paged_kv_indptr.cpu())); + EXPECT_TRUE( + torch::equal(fast_updated->attn_metadata->paged_kv_indices.cpu(), + fallback_updated->attn_metadata->paged_kv_indices.cpu())); + EXPECT_TRUE(torch::equal( + fast_updated->attn_metadata->paged_kv_last_page_len.cpu(), + fallback_updated->attn_metadata->paged_kv_last_page_len.cpu())); +} + TEST(CudaGraphExecutorTest, BatchDecodeCaptureAndReplay) { if (!torch::cuda::is_available()) { GTEST_SKIP() << "CUDA is not available at runtime.";