3 changes: 2 additions & 1 deletion xllm/core/distributed_runtime/llm_engine.cpp
@@ -34,6 +34,7 @@ limitations under the License.
#include "common/options.h"
#include "framework/block/hierarchy_block_manager_pool.h"
#include "framework/kv_cache/kv_cache_shape.h"
#include "framework/kv_cache/kv_cache_utils.h"
#include "framework/model/model_args.h"
#include "framework/model_loader.h"
#include "framework/xtensor/page_allocator.h"
@@ -456,7 +457,7 @@ KVCacheCapacity LLMEngine::estimate_kv_cache_capacity() {

   if (options_.enable_mla()) {
 #if defined(USE_NPU)
-    if (args_.model_type() == "deepseek_v3" && FLAGS_enable_prefix_cache) {
+    if (use_npu_nz_kv_cache_layout(args_.model_type())) {
       slot_size =
           cache_dtype_size *
           ((args_.kv_lora_rank() + NZ_ALIGNMENT - 1) / NZ_ALIGNMENT +
5 changes: 3 additions & 2 deletions xllm/core/framework/kv_cache/kv_cache_shape.cpp
@@ -22,6 +22,7 @@ limitations under the License.
 #include <utility>
 
 #include "common/global_flags.h"
+#include "framework/kv_cache/kv_cache_utils.h"
 #include "worker.pb.h"
 
 namespace xllm {
@@ -209,7 +210,7 @@ void KVCacheShape::init_key_cache_shape(const KVCacheCapacity& kv_cache_cap,
                                         int64_t world_size) {
   if (model_args.enable_mla()) {
 #if defined(USE_NPU)
-    if (model_args.model_type() == "deepseek_v3" && FLAGS_enable_prefix_cache) {
+    if (use_npu_nz_kv_cache_layout(model_args.model_type())) {
       key_cache_shape_ = std::vector<int64_t>{
           kv_cache_cap.n_blocks(),
           ceil_div(model_args.kv_lora_rank(), kNzAlignment),
@@ -240,7 +241,7 @@ void KVCacheShape::init_value_cache_shape(const KVCacheCapacity& kv_cache_cap,
                                           int64_t world_size) {
   if (model_args.enable_mla()) {
 #if defined(USE_NPU)
-    if (model_args.model_type() == "deepseek_v3" && FLAGS_enable_prefix_cache) {
+    if (use_npu_nz_kv_cache_layout(model_args.model_type())) {
       value_cache_shape_ = std::vector<int64_t>{
           kv_cache_cap.n_blocks(),
           ceil_div(model_args.qk_rope_head_dim(), kNzAlignment),
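Aside: the NZ-aligned dimensions above are plain ceiling divisions. A minimal standalone sketch of the arithmetic, assuming an illustrative alignment of 16 (the real kNzAlignment / NZ_ALIGNMENT constant is defined elsewhere in xllm, and the kv_lora_rank value below is just an example):

    #include <cstdint>
    #include <iostream>

    // Same rounding pattern as the diff: (n + align - 1) / align.
    constexpr int64_t ceil_div(int64_t n, int64_t align) {
      return (n + align - 1) / align;
    }

    int main() {
      const int64_t kNzAlignment = 16;  // assumed value, for illustration only
      // e.g. a 512-wide kv_lora_rank spans 32 NZ-aligned chunks
      std::cout << ceil_div(512, kNzAlignment) << '\n';
      return 0;
    }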
10 changes: 7 additions & 3 deletions xllm/core/framework/kv_cache/kv_cache_utils.cpp
@@ -27,6 +27,11 @@ bool is_linear_attention_layer(int64_t layer_idx,
   return (layer_idx + 1) % full_attention_interval != 0;
 }
 
+bool use_npu_nz_kv_cache_layout(const std::string& model_type) {
+  return (model_type == "deepseek_v3" || model_type == "deepseek_v3_mtp") &&
+         FLAGS_enable_prefix_cache;
+}
+
 KVCacheTensors create_kv_cache_tensors(
     const KVCacheShape& kv_cache_shape,
     const KVCacheCreateOptions& create_options) {
@@ -152,9 +157,8 @@ LinearAttentionKVCacheTensors create_linear_attention_kv_cache_tensors(

 #if defined(USE_NPU)
 aclFormat get_npu_kv_cache_format(const std::string& model_type) {
-  return model_type == "deepseek_v3" && FLAGS_enable_prefix_cache
-             ? ACL_FORMAT_FRACTAL_NZ
-             : ACL_FORMAT_ND;
+  return use_npu_nz_kv_cache_layout(model_type) ? ACL_FORMAT_FRACTAL_NZ
+                                                : ACL_FORMAT_ND;
 }
 #endif
 
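For reference, a minimal standalone sketch of what the consolidated predicate accepts; FLAGS_enable_prefix_cache is stubbed as a plain bool here (the real flag comes from gflags via common/global_flags), and the non-DeepSeek model-type string is an arbitrary example:

    #include <cassert>
    #include <string>

    // Stub standing in for the gflags-defined flag; an assumption of this sketch.
    static bool FLAGS_enable_prefix_cache = true;

    // Mirrors the helper added in kv_cache_utils.cpp above.
    bool use_npu_nz_kv_cache_layout(const std::string& model_type) {
      return (model_type == "deepseek_v3" || model_type == "deepseek_v3_mtp") &&
             FLAGS_enable_prefix_cache;
    }

    int main() {
      assert(use_npu_nz_kv_cache_layout("deepseek_v3"));
      assert(use_npu_nz_kv_cache_layout("deepseek_v3_mtp"));  // newly covered by this change
      assert(!use_npu_nz_kv_cache_layout("qwen2"));           // other model types stay on ND
      FLAGS_enable_prefix_cache = false;
      assert(!use_npu_nz_kv_cache_layout("deepseek_v3"));     // NZ only with prefix cache on
      return 0;
    }

Centralizing the check means the deepseek_v3_mtp case is picked up by every call site at once, instead of each ternary being patched individually.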
3 changes: 3 additions & 0 deletions xllm/core/framework/kv_cache/kv_cache_utils.h
@@ -101,6 +101,9 @@ struct LinearAttentionKVCacheTensors {
 bool is_linear_attention_layer(int64_t layer_idx,
                                int64_t full_attention_interval);
 
+// Whether the NPU KV cache should use the FRACTAL_NZ layout for a model type.
+bool use_npu_nz_kv_cache_layout(const std::string& model_type);
+
 KVCacheTensors create_kv_cache_tensors(
     const KVCacheShape& kv_cache_shape,
     const KVCacheCreateOptions& create_options);
@@ -17,6 +17,7 @@ limitations under the License.

 #include <glog/logging.h>
 
+#include "framework/kv_cache/kv_cache_utils.h"
 #include "util/net.h"
 
 namespace xllm {
@@ -155,10 +156,7 @@ void LlmDataDistTransfer::allocate_kv_cache(
   }
 
   // convert memory addrs to torch tensors
-  aclFormat npu_format_type =
-      model_type_ == "deepseek_v3" && FLAGS_enable_prefix_cache
-          ? ACL_FORMAT_FRACTAL_NZ
-          : ACL_FORMAT_ND;
+  aclFormat npu_format_type = get_npu_kv_cache_format(model_type_);
 
   auto k_torch_tensors = convert_to_torch_tensor(
       key_cache_shape, dtype, k_cache_.tensor_addrs, npu_format_type);
@@ -30,6 +30,7 @@ limitations under the License.
 #endif
 
 #include "common/global_flags.h"
+#include "framework/kv_cache/kv_cache_utils.h"
 #include "framework/xtensor/global_xtensor.h"
 #include "framework/xtensor/xtensor_allocator.h"
 #include "util/net.h"
@@ -211,10 +212,7 @@ void MooncakeKVCacheTransferDefault::allocate_kv_cache_impl(
   }
 
   // convert memory addrs to torch tensors
-  aclFormat npu_format_type =
-      model_type_ == "deepseek_v3" && FLAGS_enable_prefix_cache
-          ? ACL_FORMAT_FRACTAL_NZ
-          : ACL_FORMAT_ND;
+  aclFormat npu_format_type = get_npu_kv_cache_format(model_type_);
   auto k_torch_tensors = convert_to_torch_tensor(
       key_cache_shape, dtype, k_tensor_addrs, npu_format_type);
   auto v_torch_tensors = convert_to_torch_tensor(
@@ -18,6 +18,8 @@ limitations under the License.
 #include <glog/logging.h>
 #include <torch_npu/csrc/core/npu/NPUFormat.h>
 
+#include "framework/kv_cache/kv_cache_utils.h"
+
 namespace xllm {
 namespace {
 #define CHECK_LDD_RET(ret) \
@@ -133,10 +135,7 @@ void SpecKVCacheTransfer::allocate_kv_cache_internal(
   }
 
   // convert memory addrs to torch tensors
-  aclFormat npu_format_type =
-      model_type_ == "deepseek_v3" && FLAGS_enable_prefix_cache
-          ? ACL_FORMAT_FRACTAL_NZ
-          : ACL_FORMAT_ND;
+  aclFormat npu_format_type = get_npu_kv_cache_format(model_type_);
   auto k_torch_tensors = convert_to_torch_tensor(
       key_cache_shape, dtype, k_cache.tensor_addrs, npu_format_type);
   auto v_torch_tensors = convert_to_torch_tensor(