Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions xllm/core/distributed_runtime/llm_engine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@ bool LLMEngine::init_model(MasterStatus master_status) {
CHECK(tokenizer_ != nullptr);

args_ = model_loader->model_args();
args_.enable_mla(options_.enable_mla());
quant_args_ = model_loader->quant_args();
tokenizer_args_ = model_loader->tokenizer_args();

Expand Down
1 change: 0 additions & 1 deletion xllm/core/framework/kv_cache/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
include(cc_library)
include(cc_test)


cc_library(
NAME
kv_cache
Expand Down
4 changes: 3 additions & 1 deletion xllm/core/framework/kv_cache/indexed_kv_cache_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@ IndexedKVCacheImpl::IndexedKVCacheImpl(
: IndexedKVCacheImpl(
create_indexed_kv_cache_tensors(kv_cache_shape, create_options)) {
key_cache_shape_ = kv_cache_shape.key_cache_shape();
value_cache_shape_ = kv_cache_shape.value_cache_shape();
if (kv_cache_shape.has_value_cache_shape()) {
value_cache_shape_ = kv_cache_shape.value_cache_shape();
}
index_cache_shape_ = kv_cache_shape.index_cache_shape();
}

Expand Down
4 changes: 3 additions & 1 deletion xllm/core/framework/kv_cache/kv_cache_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@ KVCacheImpl::KVCacheImpl(const KVCacheShape& kv_cache_shape,
const KVCacheCreateOptions& create_options)
: KVCacheImpl(create_kv_cache_tensors(kv_cache_shape, create_options)) {
key_cache_shape_ = kv_cache_shape.key_cache_shape();
value_cache_shape_ = kv_cache_shape.value_cache_shape();
if (kv_cache_shape.has_value_cache_shape()) {
value_cache_shape_ = kv_cache_shape.value_cache_shape();
}
}

torch::Tensor KVCacheImpl::get_k_cache() const { return key_cache_; }
Expand Down
19 changes: 12 additions & 7 deletions xllm/core/framework/kv_cache/kv_cache_shape.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,9 @@ const std::vector<int64_t>& KVCacheShape::key_cache_shape() const {
}

const std::vector<int64_t>& KVCacheShape::value_cache_shape() const {
CHECK(value_cache_shape_.has_value())
<< "value_cache_shape is not initialized.";
if (!value_cache_shape_.has_value()) {
return empty_shape();
}
return *value_cache_shape_;
}

Expand Down Expand Up @@ -169,8 +170,10 @@ void KVCacheShape::to_proto(proto::KVCacheShape* proto_shape) const {
CHECK(proto_shape != nullptr) << "proto_shape must not be nullptr.";
proto_shape->Clear();
add_shape_to_proto(key_cache_shape(), proto_shape->mutable_key_cache_shape());
add_shape_to_proto(value_cache_shape(),
proto_shape->mutable_value_cache_shape());
if (has_value_cache_shape()) {
add_shape_to_proto(value_cache_shape(),
proto_shape->mutable_value_cache_shape());
}
if (has_index_cache_shape()) {
add_shape_to_proto(index_cache_shape(),
proto_shape->mutable_index_cache_shape());
Expand All @@ -187,8 +190,10 @@ KVCacheShape KVCacheShape::from_proto(const proto::KVCacheShape& proto_shape) {
KVCacheShape kv_cache_shape;
kv_cache_shape.key_cache_shape_ =
repeated_field_to_vector(proto_shape.key_cache_shape());
kv_cache_shape.value_cache_shape_ =
repeated_field_to_vector(proto_shape.value_cache_shape());
if (proto_shape.value_cache_shape_size() > 0) {
kv_cache_shape.value_cache_shape_ =
repeated_field_to_vector(proto_shape.value_cache_shape());
}
if (proto_shape.index_cache_shape_size() > 0) {
kv_cache_shape.index_cache_shape_ =
repeated_field_to_vector(proto_shape.index_cache_shape());
Expand Down Expand Up @@ -323,7 +328,7 @@ void KVCacheShape::apply_device_layout(const ModelArgs& model_args) {
CHECK_GE(key_cache_shape_->size(), 4) << "invalid mla key_cache_shape.";
(*key_cache_shape_)[3] =
model_args.kv_lora_rank() + model_args.qk_rope_head_dim();
value_cache_shape_ = std::vector<int64_t>{};
value_cache_shape_.reset();
}
#else
static_cast<void>(model_args);
Expand Down
46 changes: 46 additions & 0 deletions xllm/core/framework/kv_cache_transfer/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,26 @@
include(cc_library)
include(cc_test)

cc_library(
NAME
push_route
HDRS
push_route.h
SRCS
push_route.cpp
)

cc_library(
NAME
pd_topology_guard
HDRS
pd_topology_guard.h
SRCS
pd_topology_guard.cpp
DEPS
:common
glog::glog
)

cc_library(
NAME
Expand All @@ -26,6 +46,7 @@ cc_library(
DEPS
:common
:kv_cache
:push_route
:xtensor
$<$<BOOL:${USE_NPU}>:graph>
glog::glog
Expand All @@ -38,6 +59,26 @@ cc_library(
$<$<BOOL:${USE_NPU}>:platform_npu>
)

cc_test(
NAME
push_route_test
SRCS
push_route_test.cpp
DEPS
:push_route
GTest::gtest_main
)

cc_test(
NAME
pd_topology_guard_test
SRCS
pd_topology_guard_test.cpp
DEPS
:pd_topology_guard
GTest::gtest_main
)

if(USE_NPU OR USE_MLU)
cc_test(
NAME
Expand All @@ -46,6 +87,11 @@ cc_test(
mooncake_transfer_engine_test.cpp
DEPS
:kv_cache_transfer
:xllm_server
GTest::gtest_main
)

# Resolve static link order between xtensor and xllm_server for this test target.
target_link_libraries(mooncake_transfer_engine_test PRIVATE
"$<LINK_GROUP:RESCAN,xtensor,xllm_server>")
endif()
155 changes: 144 additions & 11 deletions xllm/core/framework/kv_cache_transfer/mooncake_kv_cache_transfer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ limitations under the License.
#include <glog/logging.h>

#include <numeric>
#include <unordered_set>

#if defined(USE_NPU)
#ifdef TORCH_HIGHER_THAN_PTA6
Expand All @@ -30,12 +31,100 @@ limitations under the License.
#endif

#include "common/global_flags.h"
#include "framework/kv_cache_transfer/push_route.h"
#include "framework/xtensor/global_xtensor.h"
#include "framework/xtensor/xtensor_allocator.h"
#include "platform/mlu/mlu_tensor_alloc.h"
#include "util/net.h"

namespace xllm {

namespace {

// Builds the map key used to coalesce transfer requests that target the
// same destination, i.e. the same (cluster id, address, k-cache id,
// v-cache id) tuple, joined with '_' separators.
std::string get_merge_key(const uint64_t dst_cluster_id,
                          const std::string& dst_addr,
                          const int64_t k_cache_id,
                          const int64_t v_cache_id) {
  std::string key = std::to_string(dst_cluster_id);
  key.append("_").append(dst_addr);
  key.append("_").append(std::to_string(k_cache_id));
  key.append("_").append(std::to_string(v_cache_id));
  return key;
}

void merge_xtensor_offsets(
std::vector<XTensorLayerOffsets>& merged_layer_offsets,
const std::vector<XTensorLayerOffsets>& layer_offsets) {
if (layer_offsets.empty()) {
return;
}
if (merged_layer_offsets.empty()) {
merged_layer_offsets = layer_offsets;
return;
}

for (size_t layer_id = 0; layer_id < layer_offsets.size() &&
layer_id < merged_layer_offsets.size();
++layer_id) {
std::vector<uint64_t>& k_target = merged_layer_offsets[layer_id].k_offsets;
const std::vector<uint64_t>& k_source = layer_offsets[layer_id].k_offsets;
k_target.reserve(k_target.size() + k_source.size());
k_target.insert(k_target.end(), k_source.begin(), k_source.end());

std::vector<uint64_t>& v_target = merged_layer_offsets[layer_id].v_offsets;
const std::vector<uint64_t>& v_source = layer_offsets[layer_id].v_offsets;
v_target.reserve(v_target.size() + v_source.size());
v_target.insert(v_target.end(), v_source.begin(), v_source.end());
}
}

// Folds one TransferKVInfo entry, addressed to destination rank `dst_rank`,
// into `merged_kv_infos`. Entries are keyed by get_merge_key() so that all
// blocks bound for the same destination cache end up in one KVCacheInfo
// request: a fresh entry is created on first sight, otherwise the existing
// entry's block lists and xtensor offsets are extended.
void merge_kv_info(
    std::unordered_map<std::string, KVCacheTransfer::KVCacheInfo>&
        merged_kv_infos,
    const TransferKVInfo& info,
    const int32_t dst_rank) {
  const uint64_t dst_cluster_id =
      info.remote_instance_info.cluster_ids[dst_rank];
  const std::string& dst_addr = info.remote_instance_info.addrs[dst_rank];
  const int64_t k_cache_id = info.remote_instance_info.k_cache_ids[dst_rank];
  const int64_t v_cache_id = info.remote_instance_info.v_cache_ids[dst_rank];
  const std::string key =
      get_merge_key(dst_cluster_id, dst_addr, k_cache_id, v_cache_id);

  auto it = merged_kv_infos.find(key);
  if (it == merged_kv_infos.end()) {
    // First request for this destination: start a fresh entry.
    KVCacheTransfer::KVCacheInfo kv_info;
    kv_info.dst_cluster_id = dst_cluster_id;
    kv_info.dst_addr = dst_addr;
    kv_info.dst_k_cache_id = k_cache_id;
    kv_info.dst_v_cache_id = v_cache_id;
    kv_info.src_blocks.assign(info.local_blocks_ids.begin(),
                              info.local_blocks_ids.end());
    kv_info.dst_blocks.assign(info.remote_blocks_ids.begin(),
                              info.remote_blocks_ids.end());
    merge_xtensor_offsets(kv_info.dst_xtensor_layer_offsets,
                          info.dst_xtensor_layer_offsets);
    merged_kv_infos.emplace(key, std::move(kv_info));
    return;
  }

  // Destination already seen: append this entry's blocks and offsets.
  auto append_blocks = [](std::vector<uint64_t>& target, const auto& source) {
    target.reserve(target.size() + source.size());
    target.insert(target.end(), source.begin(), source.end());
  };
  append_blocks(it->second.src_blocks, info.local_blocks_ids);
  append_blocks(it->second.dst_blocks, info.remote_blocks_ids);
  merge_xtensor_offsets(it->second.dst_xtensor_layer_offsets,
                        info.dst_xtensor_layer_offsets);
}

} // namespace

// ============================================================================
// MooncakeKVCacheTransferBase
// ============================================================================
Expand Down Expand Up @@ -153,15 +242,17 @@ void MooncakeKVCacheTransferDefault::allocate_kv_cache_impl(
const std::vector<int64_t>& value_cache_shape =
kv_cache_shape.value_cache_shape();
#if defined(USE_MLU)
torch::TensorOptions options =
torch::TensorOptions().dtype(dtype).device(device_);
for (int64_t i = 0; i < num_layers; ++i) {
torch::Tensor key_cache = torch::zeros(key_cache_shape, options);
torch::Tensor key_cache =
mlu::alloc_zero_tensor(key_cache_shape, dtype, device_);
torch::Tensor value_cache;
torch::Tensor index_cache;
value_cache = torch::zeros(value_cache_shape, options);
if (kv_cache_shape.has_value_cache_shape()) {
value_cache = mlu::alloc_zero_tensor(value_cache_shape, dtype, device_);
}
if (kv_cache_shape.has_index_cache_shape()) {
index_cache = torch::zeros(kv_cache_shape.index_cache_shape(), options);
index_cache = mlu::alloc_zero_tensor(
kv_cache_shape.index_cache_shape(), dtype, device_);
}
if (index_cache.defined()) {
kv_caches.emplace_back(IndexedKVCacheTensors{
Expand Down Expand Up @@ -293,8 +384,7 @@ void MooncakeKVCacheTransferDefault::register_kv_cache_impl(
}

if (!mooncake_te_->register_memory(addrs, lens, buf_bytes)) {
LOG(ERROR) << "register_kv_cache_impl failed";
return;
LOG(FATAL) << "register_kv_cache_impl failed";
}

LOG(INFO) << "register_kv_cache_impl success, num_layers=" << num_layers_
Expand Down Expand Up @@ -322,6 +412,51 @@ bool MooncakeKVCacheTransferDefault::pull_kv_blocks(
return true;
}

// Merges per-request transfer infos into destination-keyed push batches.
// Non-MLU builds, and MLU builds that still carry a separate v-cache,
// defer to the generic base-class merge. Otherwise (no v-cache — presumably
// the MLA cache layout; confirm against allocate_kv_cache_impl) each source
// TP rank only merges entries for the destination DP ranks it is linked to.
void MooncakeKVCacheTransferDefault::merge_kv_blocks(
    std::unordered_map<std::string, KVCacheInfo>& merged_kv_infos,
    const std::vector<TransferKVInfo>& transfer_kv_infos,
    const ParallelArgs& parallel_args) {
#if !defined(USE_MLU)
  KVCacheTransfer::merge_kv_blocks(
      merged_kv_infos, transfer_kv_infos, parallel_args);
#else
  if (has_v_cache_) {
    KVCacheTransfer::merge_kv_blocks(
        merged_kv_infos, transfer_kv_infos, parallel_args);
    return;
  }

  const int32_t src_rank = parallel_args.rank();
  const int32_t src_dp_size = parallel_args.dp_size();
  const int32_t src_world_size = parallel_args.world_size();
  const int32_t src_tp_size = src_world_size / src_dp_size;
  const int32_t src_tp_rank = src_rank % src_tp_size;

  for (const TransferKVInfo& info : transfer_kv_infos) {
    const int32_t dst_dp_rank = info.dp_rank;
    const int32_t dst_dp_size = info.remote_instance_info.dp_size;
    const int32_t dst_world_size =
        static_cast<int32_t>(info.remote_instance_info.cluster_ids.size());
    const int32_t dst_tp_size = dst_world_size / dst_dp_size;

    // Walk the destination ranks this source TP rank covers (stride of
    // src_tp_size across the destination world) and check whether any of
    // them belongs to this entry's destination DP rank.
    bool linked = false;
    for (int32_t i = src_tp_rank; i < dst_world_size; i += src_tp_size) {
      if (i / dst_tp_size == dst_dp_rank) {
        linked = true;
        break;
      }
    }
    if (!linked) {
      continue;
    }

    for (const int32_t dst_rank :
         get_dst_ranks(src_tp_rank, src_tp_size, dst_tp_size, dst_dp_rank)) {
      merge_kv_info(merged_kv_infos, info, dst_rank);
    }
  }
#endif
}

bool MooncakeKVCacheTransferDefault::push_kv_blocks(
std::unordered_map<std::string, KVCacheInfo>& merged_kv_infos,
std::shared_ptr<KVPushSynchronizerImpl>& layer_synchronizer,
Expand Down Expand Up @@ -423,8 +558,7 @@ void MooncakeKVCacheTransferXTensor::register_kv_cache_impl() {
// XTensor mode registers one shared GlobalXTensor memory region.
auto& global_xtensor = GlobalXTensor::get_instance();
if (!global_xtensor.is_initialized()) {
LOG(ERROR) << "GlobalXTensor not initialized in xtensor mode";
return;
LOG(FATAL) << "GlobalXTensor not initialized in xtensor mode";
}

if (global_xtensor.is_mooncake_registered()) {
Expand All @@ -437,8 +571,7 @@ void MooncakeKVCacheTransferXTensor::register_kv_cache_impl() {
std::vector<uint64_t> buf_bytes = {static_cast<uint64_t>(size_per_block_)};

if (!mooncake_te_->register_memory(addrs, lens, buf_bytes)) {
LOG(ERROR) << "register GlobalXTensor failed";
return;
LOG(FATAL) << "register GlobalXTensor failed";
}

global_xtensor.set_mooncake_registered(true);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,11 @@ class MooncakeKVCacheTransferDefault final
const std::vector<uint64_t>& src_blocks,
const std::vector<uint64_t>& dst_blocks) override;

void merge_kv_blocks(
std::unordered_map<std::string, KVCacheInfo>& merged_kv_infos,
const std::vector<TransferKVInfo>& transfer_kv_infos,
const ParallelArgs& parallel_args) override;

bool push_kv_blocks(
std::unordered_map<std::string, KVCacheInfo>& merged_kv_infos,
std::shared_ptr<KVPushSynchronizerImpl>& layer_synchronizer,
Expand Down
Loading
Loading