diff --git a/xllm/core/distributed_runtime/llm_engine.cpp b/xllm/core/distributed_runtime/llm_engine.cpp index e841923c7..346d1d514 100644 --- a/xllm/core/distributed_runtime/llm_engine.cpp +++ b/xllm/core/distributed_runtime/llm_engine.cpp @@ -34,6 +34,7 @@ limitations under the License. #include "common/options.h" #include "framework/block/hierarchy_block_manager_pool.h" #include "framework/kv_cache/kv_cache_shape.h" +#include "framework/kv_cache/kv_cache_utils.h" #include "framework/model/model_args.h" #include "framework/model_loader.h" #include "framework/xtensor/page_allocator.h" @@ -456,7 +457,7 @@ KVCacheCapacity LLMEngine::estimate_kv_cache_capacity() { if (options_.enable_mla()) { #if defined(USE_NPU) - if (args_.model_type() == "deepseek_v3" && FLAGS_enable_prefix_cache) { + if (use_npu_nz_kv_cache_layout(args_.model_type())) { slot_size = cache_dtype_size * ((args_.kv_lora_rank() + NZ_ALIGNMENT - 1) / NZ_ALIGNMENT + diff --git a/xllm/core/framework/kv_cache/kv_cache_shape.cpp b/xllm/core/framework/kv_cache/kv_cache_shape.cpp index dc3fecbd5..522297cd3 100644 --- a/xllm/core/framework/kv_cache/kv_cache_shape.cpp +++ b/xllm/core/framework/kv_cache/kv_cache_shape.cpp @@ -22,6 +22,7 @@ limitations under the License. #include #include "common/global_flags.h" +#include "framework/kv_cache/kv_cache_utils.h" #include "worker.pb.h" namespace xllm { @@ -209,7 +210,7 @@ void KVCacheShape::init_key_cache_shape(const KVCacheCapacity& kv_cache_cap, int64_t world_size) { if (model_args.enable_mla()) { #if defined(USE_NPU) - if (model_args.model_type() == "deepseek_v3" && FLAGS_enable_prefix_cache) { + if (use_npu_nz_kv_cache_layout(model_args.model_type())) { key_cache_shape_ = std::vector{ kv_cache_cap.n_blocks(), ceil_div(model_args.kv_lora_rank(), kNzAlignment), @@ -240,7 +241,7 @@ void KVCacheShape::init_value_cache_shape(const KVCacheCapacity& kv_cache_cap, int64_t world_size) { if (model_args.enable_mla()) { #if defined(USE_NPU) - if (model_args.model_type() == "deepseek_v3" && FLAGS_enable_prefix_cache) { + if (use_npu_nz_kv_cache_layout(model_args.model_type())) { value_cache_shape_ = std::vector{ kv_cache_cap.n_blocks(), ceil_div(model_args.qk_rope_head_dim(), kNzAlignment), diff --git a/xllm/core/framework/kv_cache/kv_cache_utils.cpp b/xllm/core/framework/kv_cache/kv_cache_utils.cpp index 9b5c8173a..285772fe9 100644 --- a/xllm/core/framework/kv_cache/kv_cache_utils.cpp +++ b/xllm/core/framework/kv_cache/kv_cache_utils.cpp @@ -27,6 +27,11 @@ bool is_linear_attention_layer(int64_t layer_idx, return (layer_idx + 1) % full_attention_interval != 0; } +bool use_npu_nz_kv_cache_layout(const std::string& model_type) { + return (model_type == "deepseek_v3" || model_type == "deepseek_v3_mtp") && + FLAGS_enable_prefix_cache; +} + KVCacheTensors create_kv_cache_tensors( const KVCacheShape& kv_cache_shape, const KVCacheCreateOptions& create_options) { @@ -152,9 +157,8 @@ LinearAttentionKVCacheTensors create_linear_attention_kv_cache_tensors( #if defined(USE_NPU) aclFormat get_npu_kv_cache_format(const std::string& model_type) { - return model_type == "deepseek_v3" && FLAGS_enable_prefix_cache - ? ACL_FORMAT_FRACTAL_NZ - : ACL_FORMAT_ND; + return use_npu_nz_kv_cache_layout(model_type) ? ACL_FORMAT_FRACTAL_NZ + : ACL_FORMAT_ND; } #endif diff --git a/xllm/core/framework/kv_cache/kv_cache_utils.h b/xllm/core/framework/kv_cache/kv_cache_utils.h index 8ec003bfa..4196078fe 100644 --- a/xllm/core/framework/kv_cache/kv_cache_utils.h +++ b/xllm/core/framework/kv_cache/kv_cache_utils.h @@ -101,6 +101,9 @@ struct LinearAttentionKVCacheTensors { bool is_linear_attention_layer(int64_t layer_idx, int64_t full_attention_interval); +// Whether NPU KV cache should use FRACTAL_NZ layout for a model type. +bool use_npu_nz_kv_cache_layout(const std::string& model_type); + KVCacheTensors create_kv_cache_tensors( const KVCacheShape& kv_cache_shape, const KVCacheCreateOptions& create_options); diff --git a/xllm/core/framework/kv_cache_transfer/llm_data_dist_transfer.cpp b/xllm/core/framework/kv_cache_transfer/llm_data_dist_transfer.cpp index c57a9b330..7293d67a7 100644 --- a/xllm/core/framework/kv_cache_transfer/llm_data_dist_transfer.cpp +++ b/xllm/core/framework/kv_cache_transfer/llm_data_dist_transfer.cpp @@ -17,6 +17,7 @@ limitations under the License. #include +#include "framework/kv_cache/kv_cache_utils.h" #include "util/net.h" namespace xllm { @@ -155,10 +156,7 @@ void LlmDataDistTransfer::allocate_kv_cache( } // convert memory addrs to torch tensors - aclFormat npu_format_type = - model_type_ == "deepseek_v3" && FLAGS_enable_prefix_cache - ? ACL_FORMAT_FRACTAL_NZ - : ACL_FORMAT_ND; + aclFormat npu_format_type = get_npu_kv_cache_format(model_type_); auto k_torch_tensors = convert_to_torch_tensor( key_cache_shape, dtype, k_cache_.tensor_addrs, npu_format_type); diff --git a/xllm/core/framework/kv_cache_transfer/mooncake_kv_cache_transfer.cpp b/xllm/core/framework/kv_cache_transfer/mooncake_kv_cache_transfer.cpp index a71cfc802..0aade75e2 100644 --- a/xllm/core/framework/kv_cache_transfer/mooncake_kv_cache_transfer.cpp +++ b/xllm/core/framework/kv_cache_transfer/mooncake_kv_cache_transfer.cpp @@ -30,6 +30,7 @@ limitations under the License. #endif #include "common/global_flags.h" +#include "framework/kv_cache/kv_cache_utils.h" #include "framework/xtensor/global_xtensor.h" #include "framework/xtensor/xtensor_allocator.h" #include "util/net.h" @@ -211,10 +212,7 @@ void MooncakeKVCacheTransferDefault::allocate_kv_cache_impl( } // convert memory addrs to torch tensors - aclFormat npu_format_type = - model_type_ == "deepseek_v3" && FLAGS_enable_prefix_cache - ? ACL_FORMAT_FRACTAL_NZ - : ACL_FORMAT_ND; + aclFormat npu_format_type = get_npu_kv_cache_format(model_type_); auto k_torch_tensors = convert_to_torch_tensor( key_cache_shape, dtype, k_tensor_addrs, npu_format_type); auto v_torch_tensors = convert_to_torch_tensor( diff --git a/xllm/core/framework/kv_cache_transfer/spec_kv_cache_transfer.cpp b/xllm/core/framework/kv_cache_transfer/spec_kv_cache_transfer.cpp index d199ce88b..b1bc56530 100644 --- a/xllm/core/framework/kv_cache_transfer/spec_kv_cache_transfer.cpp +++ b/xllm/core/framework/kv_cache_transfer/spec_kv_cache_transfer.cpp @@ -18,6 +18,8 @@ limitations under the License. #include #include +#include "framework/kv_cache/kv_cache_utils.h" + namespace xllm { namespace { #define CHECK_LDD_RET(ret) \ @@ -133,10 +135,7 @@ void SpecKVCacheTransfer::allocate_kv_cache_internal( } // convert memory addrs to torch tensors - aclFormat npu_format_type = - model_type_ == "deepseek_v3" && FLAGS_enable_prefix_cache - ? ACL_FORMAT_FRACTAL_NZ - : ACL_FORMAT_ND; + aclFormat npu_format_type = get_npu_kv_cache_format(model_type_); auto k_torch_tensors = convert_to_torch_tensor( key_cache_shape, dtype, k_cache.tensor_addrs, npu_format_type); auto v_torch_tensors = convert_to_torch_tensor(