diff --git a/xllm/core/distributed_runtime/llm_engine.cpp b/xllm/core/distributed_runtime/llm_engine.cpp
index e841923c7..0e467f59d 100644
--- a/xllm/core/distributed_runtime/llm_engine.cpp
+++ b/xllm/core/distributed_runtime/llm_engine.cpp
@@ -187,6 +187,7 @@ bool LLMEngine::init_model(MasterStatus master_status) {
   CHECK(tokenizer_ != nullptr);
 
   args_ = model_loader->model_args();
+  args_.enable_mla(options_.enable_mla());
   quant_args_ = model_loader->quant_args();
   tokenizer_args_ = model_loader->tokenizer_args();
 
diff --git a/xllm/core/framework/kv_cache/CMakeLists.txt b/xllm/core/framework/kv_cache/CMakeLists.txt
index 2bd1c1a9c..97f6d5722 100644
--- a/xllm/core/framework/kv_cache/CMakeLists.txt
+++ b/xllm/core/framework/kv_cache/CMakeLists.txt
@@ -1,7 +1,6 @@
 include(cc_library)
 include(cc_test)
-
 cc_library(
   NAME
     kv_cache
diff --git a/xllm/core/framework/kv_cache/indexed_kv_cache_impl.cpp b/xllm/core/framework/kv_cache/indexed_kv_cache_impl.cpp
index 104d4589a..d735210ec 100644
--- a/xllm/core/framework/kv_cache/indexed_kv_cache_impl.cpp
+++ b/xllm/core/framework/kv_cache/indexed_kv_cache_impl.cpp
@@ -40,7 +40,9 @@ IndexedKVCacheImpl::IndexedKVCacheImpl(
     : IndexedKVCacheImpl(
           create_indexed_kv_cache_tensors(kv_cache_shape, create_options)) {
   key_cache_shape_ = kv_cache_shape.key_cache_shape();
-  value_cache_shape_ = kv_cache_shape.value_cache_shape();
+  if (kv_cache_shape.has_value_cache_shape()) {
+    value_cache_shape_ = kv_cache_shape.value_cache_shape();
+  }
   index_cache_shape_ = kv_cache_shape.index_cache_shape();
 }
 
diff --git a/xllm/core/framework/kv_cache/kv_cache_impl.cpp b/xllm/core/framework/kv_cache/kv_cache_impl.cpp
index eef895379..a595f9687 100644
--- a/xllm/core/framework/kv_cache/kv_cache_impl.cpp
+++ b/xllm/core/framework/kv_cache/kv_cache_impl.cpp
@@ -31,7 +31,9 @@ KVCacheImpl::KVCacheImpl(const KVCacheShape& kv_cache_shape,
                          const KVCacheCreateOptions& create_options)
     : KVCacheImpl(create_kv_cache_tensors(kv_cache_shape, create_options)) {
   key_cache_shape_ = kv_cache_shape.key_cache_shape();
-  value_cache_shape_ = kv_cache_shape.value_cache_shape();
+  if (kv_cache_shape.has_value_cache_shape()) {
+    value_cache_shape_ = kv_cache_shape.value_cache_shape();
+  }
 }
 
 torch::Tensor KVCacheImpl::get_k_cache() const { return key_cache_; }
diff --git a/xllm/core/framework/kv_cache/kv_cache_shape.cpp b/xllm/core/framework/kv_cache/kv_cache_shape.cpp
index dc3fecbd5..2bfa7a3c8 100644
--- a/xllm/core/framework/kv_cache/kv_cache_shape.cpp
+++ b/xllm/core/framework/kv_cache/kv_cache_shape.cpp
@@ -96,8 +96,9 @@ const std::vector<int64_t>& KVCacheShape::key_cache_shape() const {
 }
 
 const std::vector<int64_t>& KVCacheShape::value_cache_shape() const {
-  CHECK(value_cache_shape_.has_value())
-      << "value_cache_shape is not initialized.";
+  if (!value_cache_shape_.has_value()) {
+    return empty_shape();
+  }
   return *value_cache_shape_;
 }
 
@@ -169,8 +170,10 @@ void KVCacheShape::to_proto(proto::KVCacheShape* proto_shape) const {
   CHECK(proto_shape != nullptr) << "proto_shape must not be nullptr.";
   proto_shape->Clear();
   add_shape_to_proto(key_cache_shape(), proto_shape->mutable_key_cache_shape());
-  add_shape_to_proto(value_cache_shape(),
-                     proto_shape->mutable_value_cache_shape());
+  if (has_value_cache_shape()) {
+    add_shape_to_proto(value_cache_shape(),
+                       proto_shape->mutable_value_cache_shape());
+  }
   if (has_index_cache_shape()) {
     add_shape_to_proto(index_cache_shape(),
                        proto_shape->mutable_index_cache_shape());
@@ -187,8 +190,10 @@ KVCacheShape KVCacheShape::from_proto(const proto::KVCacheShape& proto_shape) {
   KVCacheShape kv_cache_shape;
   kv_cache_shape.key_cache_shape_ =
       repeated_field_to_vector(proto_shape.key_cache_shape());
-  kv_cache_shape.value_cache_shape_ =
-      repeated_field_to_vector(proto_shape.value_cache_shape());
+  if (proto_shape.value_cache_shape_size() > 0) {
+    kv_cache_shape.value_cache_shape_ =
+        repeated_field_to_vector(proto_shape.value_cache_shape());
+  }
   if (proto_shape.index_cache_shape_size() > 0) {
     kv_cache_shape.index_cache_shape_ =
         repeated_field_to_vector(proto_shape.index_cache_shape());
@@ -323,7 +328,7 @@ void KVCacheShape::apply_device_layout(const ModelArgs& model_args) {
   CHECK_GE(key_cache_shape_->size(), 4) << "invalid mla key_cache_shape.";
   (*key_cache_shape_)[3] =
       model_args.kv_lora_rank() + model_args.qk_rope_head_dim();
-  value_cache_shape_ = std::vector<int64_t>{};
+  value_cache_shape_.reset();
 }
 #else
   static_cast<void>(model_args);
diff --git a/xllm/core/framework/kv_cache_transfer/CMakeLists.txt b/xllm/core/framework/kv_cache_transfer/CMakeLists.txt
index 3075efaaf..69de7f4df 100644
--- a/xllm/core/framework/kv_cache_transfer/CMakeLists.txt
+++ b/xllm/core/framework/kv_cache_transfer/CMakeLists.txt
@@ -1,6 +1,26 @@
 include(cc_library)
 include(cc_test)
 
+cc_library(
+  NAME
+    push_route
+  HDRS
+    push_route.h
+  SRCS
+    push_route.cpp
+)
+
+cc_library(
+  NAME
+    pd_topology_guard
+  HDRS
+    pd_topology_guard.h
+  SRCS
+    pd_topology_guard.cpp
+  DEPS
+    :common
+    glog::glog
+)
 cc_library(
   NAME
     kv_cache_transfer
@@ -26,6 +46,7 @@ cc_library(
   DEPS
     :common
     :kv_cache
+    :push_route
    :xtensor
    $<$<BOOL:${USE_NPU}>:graph>
    glog::glog
@@ -38,6 +59,26 @@ cc_library(
   $<$<BOOL:${USE_NPU}>:platform_npu>
 )
 
+cc_test(
+  NAME
+    push_route_test
+  SRCS
+    push_route_test.cpp
+  DEPS
+    :push_route
+    GTest::gtest_main
+)
+
+cc_test(
+  NAME
+    pd_topology_guard_test
+  SRCS
+    pd_topology_guard_test.cpp
+  DEPS
+    :pd_topology_guard
+    GTest::gtest_main
+)
+
 if(USE_NPU OR USE_MLU)
 cc_test(
   NAME
@@ -46,6 +87,11 @@ cc_test(
     mooncake_transfer_engine_test.cpp
   DEPS
     :kv_cache_transfer
+    :xllm_server
     GTest::gtest_main
 )
+
+# Resolve static link order between xtensor and xllm_server for this test target.
+target_link_libraries(mooncake_transfer_engine_test PRIVATE
+  "$<LINK_GROUP:RESCAN,xtensor,xllm_server>")
 endif()
diff --git a/xllm/core/framework/kv_cache_transfer/mooncake_kv_cache_transfer.cpp b/xllm/core/framework/kv_cache_transfer/mooncake_kv_cache_transfer.cpp
index a71cfc802..e8b48c2c2 100644
--- a/xllm/core/framework/kv_cache_transfer/mooncake_kv_cache_transfer.cpp
+++ b/xllm/core/framework/kv_cache_transfer/mooncake_kv_cache_transfer.cpp
@@ -18,6 +18,7 @@ limitations under the License.
 
 #include
 #include
+#include <unordered_set>
 
 #if defined(USE_NPU)
 #ifdef TORCH_HIGHER_THAN_PTA6
@@ -30,12 +31,100 @@
 #endif
 
 #include "common/global_flags.h"
+#include "framework/kv_cache_transfer/push_route.h"
 #include "framework/xtensor/global_xtensor.h"
 #include "framework/xtensor/xtensor_allocator.h"
+#include "platform/mlu/mlu_tensor_alloc.h"
 #include "util/net.h"
 
 namespace xllm {
 
+namespace {
+
+std::string get_merge_key(const uint64_t dst_cluster_id,
+                          const std::string& dst_addr,
+                          const int64_t k_cache_id,
+                          const int64_t v_cache_id) {
+  return std::to_string(dst_cluster_id) + "_" + dst_addr + "_" +
+         std::to_string(k_cache_id) + "_" + std::to_string(v_cache_id);
+}
+
+void merge_xtensor_offsets(
+    std::vector& merged_layer_offsets,
+    const std::vector& layer_offsets) {
+  if (layer_offsets.empty()) {
+    return;
+  }
+  if (merged_layer_offsets.empty()) {
+    merged_layer_offsets = layer_offsets;
+    return;
+  }
+
+  for (size_t layer_id = 0; layer_id < layer_offsets.size() &&
+                            layer_id < merged_layer_offsets.size();
+       ++layer_id) {
+    std::vector& k_target = merged_layer_offsets[layer_id].k_offsets;
+    const std::vector& k_source = layer_offsets[layer_id].k_offsets;
+    k_target.reserve(k_target.size() + k_source.size());
+    k_target.insert(k_target.end(), k_source.begin(), k_source.end());
+
+    std::vector& v_target = merged_layer_offsets[layer_id].v_offsets;
+    const std::vector& v_source = layer_offsets[layer_id].v_offsets;
+    v_target.reserve(v_target.size() + v_source.size());
+    v_target.insert(v_target.end(), v_source.begin(), v_source.end());
+  }
+}
+
+void merge_kv_info(
+    std::unordered_map<std::string, KVCacheTransfer::KVCacheInfo>&
+        merged_kv_infos,
+    const TransferKVInfo& info,
+    const int32_t dst_rank) {
+  uint64_t dst_cluster_id = info.remote_instance_info.cluster_ids[dst_rank];
+  const std::string& dst_addr = info.remote_instance_info.addrs[dst_rank];
+  int64_t k_cache_id = info.remote_instance_info.k_cache_ids[dst_rank];
+  int64_t v_cache_id = info.remote_instance_info.v_cache_ids[dst_rank];
+  std::string key =
+      get_merge_key(dst_cluster_id, dst_addr, k_cache_id, v_cache_id);
+
+  auto it = merged_kv_infos.find(key);
+  if (it == merged_kv_infos.end()) {
+    KVCacheTransfer::KVCacheInfo kv_info;
+    kv_info.dst_cluster_id = dst_cluster_id;
+    kv_info.dst_addr = dst_addr;
+    kv_info.dst_k_cache_id = k_cache_id;
+    kv_info.dst_v_cache_id = v_cache_id;
+    kv_info.src_blocks.reserve(info.local_blocks_ids.size());
+    kv_info.src_blocks.insert(kv_info.src_blocks.end(),
+                              info.local_blocks_ids.begin(),
+                              info.local_blocks_ids.end());
+    kv_info.dst_blocks.reserve(info.remote_blocks_ids.size());
+    kv_info.dst_blocks.insert(kv_info.dst_blocks.end(),
+                              info.remote_blocks_ids.begin(),
+                              info.remote_blocks_ids.end());
+    merge_xtensor_offsets(kv_info.dst_xtensor_layer_offsets,
+                          info.dst_xtensor_layer_offsets);
+    merged_kv_infos.emplace(key, std::move(kv_info));
+    return;
+  }
+
+  std::vector& src_blocks = it->second.src_blocks;
+  src_blocks.reserve(src_blocks.size() + info.local_blocks_ids.size());
+  src_blocks.insert(src_blocks.end(),
+                    info.local_blocks_ids.begin(),
+                    info.local_blocks_ids.end());
+
+  std::vector& dst_blocks = it->second.dst_blocks;
+  dst_blocks.reserve(dst_blocks.size() + info.remote_blocks_ids.size());
+  dst_blocks.insert(dst_blocks.end(),
+                    info.remote_blocks_ids.begin(),
+                    info.remote_blocks_ids.end());
+  merge_xtensor_offsets(it->second.dst_xtensor_layer_offsets,
+                        info.dst_xtensor_layer_offsets);
+}
+
+}  // namespace
+
 // ============================================================================
 // MooncakeKVCacheTransferBase
 // ============================================================================
@@ -153,15 +242,17 @@ void MooncakeKVCacheTransferDefault::allocate_kv_cache_impl(
   const std::vector<int64_t>& value_cache_shape =
       kv_cache_shape.value_cache_shape();
 #if defined(USE_MLU)
-  torch::TensorOptions options =
-      torch::TensorOptions().dtype(dtype).device(device_);
   for (int64_t i = 0; i < num_layers; ++i) {
-    torch::Tensor key_cache = torch::zeros(key_cache_shape, options);
+    torch::Tensor key_cache =
+        mlu::alloc_zero_tensor(key_cache_shape, dtype, device_);
     torch::Tensor value_cache;
     torch::Tensor index_cache;
-    value_cache = torch::zeros(value_cache_shape, options);
+    if (kv_cache_shape.has_value_cache_shape()) {
+      value_cache = mlu::alloc_zero_tensor(value_cache_shape, dtype, device_);
+    }
     if (kv_cache_shape.has_index_cache_shape()) {
-      index_cache = torch::zeros(kv_cache_shape.index_cache_shape(), options);
+      index_cache = mlu::alloc_zero_tensor(
+          kv_cache_shape.index_cache_shape(), dtype, device_);
     }
     if (index_cache.defined()) {
       kv_caches.emplace_back(IndexedKVCacheTensors{
@@ -293,8 +384,7 @@ void MooncakeKVCacheTransferDefault::register_kv_cache_impl(
   }
 
   if (!mooncake_te_->register_memory(addrs, lens, buf_bytes)) {
-    LOG(ERROR) << "register_kv_cache_impl failed";
-    return;
+    LOG(FATAL) << "register_kv_cache_impl failed";
   }
 
   LOG(INFO) << "register_kv_cache_impl success, num_layers=" << num_layers_
@@ -322,6 +412,51 @@ bool MooncakeKVCacheTransferDefault::pull_kv_blocks(
   return true;
 }
 
+void MooncakeKVCacheTransferDefault::merge_kv_blocks(
+    std::unordered_map<std::string, KVCacheTransfer::KVCacheInfo>&
+        merged_kv_infos,
+    const std::vector<TransferKVInfo>& transfer_kv_infos,
+    const ParallelArgs& parallel_args) {
+#if !defined(USE_MLU)
+  KVCacheTransfer::merge_kv_blocks(
+      merged_kv_infos, transfer_kv_infos, parallel_args);
+#else
+  if (has_v_cache_) {
+    KVCacheTransfer::merge_kv_blocks(
+        merged_kv_infos, transfer_kv_infos, parallel_args);
+    return;
+  }
+
+  int32_t src_rank = parallel_args.rank();
+  int32_t src_dp_size = parallel_args.dp_size();
+  int32_t src_world_size = parallel_args.world_size();
+  int32_t src_tp_size = src_world_size / src_dp_size;
+  int32_t src_tp_rank = src_rank % src_tp_size;
+
+  for (const TransferKVInfo& info : transfer_kv_infos) {
+    int32_t dst_dp_rank = info.dp_rank;
+    int32_t dst_dp_size = info.remote_instance_info.dp_size;
+    int32_t dst_world_size =
+        static_cast<int32_t>(info.remote_instance_info.cluster_ids.size());
+    int32_t dst_tp_size = dst_world_size / dst_dp_size;
+
+    std::unordered_set<int32_t> linked_dp_ranks;
+    for (int32_t i = src_tp_rank; i < dst_world_size; i += src_tp_size) {
+      int32_t linked_dp_rank = i / dst_tp_size;
+      linked_dp_ranks.emplace(linked_dp_rank);
+    }
+    if (linked_dp_ranks.find(dst_dp_rank) == linked_dp_ranks.end()) {
+      continue;
+    }
+
+    std::vector<int32_t> dst_ranks =
+        get_dst_ranks(src_tp_rank, src_tp_size, dst_tp_size, dst_dp_rank);
+    for (int32_t dst_rank : dst_ranks) {
+      merge_kv_info(merged_kv_infos, info, dst_rank);
+    }
+  }
+#endif
+}
+
 bool MooncakeKVCacheTransferDefault::push_kv_blocks(
     std::unordered_map<std::string, KVCacheTransfer::KVCacheInfo>& merged_kv_infos,
     std::shared_ptr& layer_synchronizer,
@@ -423,8 +558,7 @@ void MooncakeKVCacheTransferXTensor::register_kv_cache_impl() {
   // XTensor mode registers one shared GlobalXTensor memory region.
   auto& global_xtensor = GlobalXTensor::get_instance();
   if (!global_xtensor.is_initialized()) {
-    LOG(ERROR) << "GlobalXTensor not initialized in xtensor mode";
-    return;
+    LOG(FATAL) << "GlobalXTensor not initialized in xtensor mode";
   }
 
   if (global_xtensor.is_mooncake_registered()) {
@@ -437,8 +571,7 @@ void MooncakeKVCacheTransferXTensor::register_kv_cache_impl() {
 
   std::vector buf_bytes = {static_cast(size_per_block_)};
 
   if (!mooncake_te_->register_memory(addrs, lens, buf_bytes)) {
-    LOG(ERROR) << "register GlobalXTensor failed";
-    return;
+    LOG(FATAL) << "register GlobalXTensor failed";
   }
 
   global_xtensor.set_mooncake_registered(true);
diff --git a/xllm/core/framework/kv_cache_transfer/mooncake_kv_cache_transfer.h b/xllm/core/framework/kv_cache_transfer/mooncake_kv_cache_transfer.h
index 711dab0bd..8a528079b 100644
--- a/xllm/core/framework/kv_cache_transfer/mooncake_kv_cache_transfer.h
+++ b/xllm/core/framework/kv_cache_transfer/mooncake_kv_cache_transfer.h
@@ -84,6 +84,11 @@ class MooncakeKVCacheTransferDefault final
       const std::vector& src_blocks,
       const std::vector& dst_blocks) override;
 
+  void merge_kv_blocks(
+      std::unordered_map<std::string, KVCacheInfo>& merged_kv_infos,
+      const std::vector<TransferKVInfo>& transfer_kv_infos,
+      const ParallelArgs& parallel_args) override;
+
   bool push_kv_blocks(
       std::unordered_map<std::string, KVCacheInfo>& merged_kv_infos,
       std::shared_ptr& layer_synchronizer,
diff --git a/xllm/core/framework/kv_cache_transfer/mooncake_transfer_engine_test.cpp b/xllm/core/framework/kv_cache_transfer/mooncake_transfer_engine_test.cpp
index 5b55e4851..814468f8c 100644
--- a/xllm/core/framework/kv_cache_transfer/mooncake_transfer_engine_test.cpp
+++ b/xllm/core/framework/kv_cache_transfer/mooncake_transfer_engine_test.cpp
@@ -18,8 +18,64 @@ limitations under the License.
 
 #include
 #include
+#include <string>
+#include <unordered_map>
+
+#include "framework/kv_cache_transfer/kv_cache_transfer.h"
+
+#define private public
+#include "framework/kv_cache_transfer/mooncake_kv_cache_transfer.h"
+#undef private
+
 namespace xllm {
 
+namespace {
+
+TransferKVInfo make_info(int32_t dst_dp_size,
+                         int32_t dst_tp_size,
+                         int32_t dst_dp_rank) {
+  TransferKVInfo info;
+  info.request_id = "req";
+  info.local_blocks_ids = {11, 12};
+  info.remote_blocks_ids = {21, 22};
+  info.dp_rank = dst_dp_rank;
+  info.remote_instance_info.dp_size = dst_dp_size;
+
+  int32_t dst_world_size = dst_dp_size * dst_tp_size;
+  for (int32_t i = 0; i < dst_world_size; ++i) {
+    info.remote_instance_info.cluster_ids.emplace_back(
+        static_cast<uint64_t>(100 + i));
+    info.remote_instance_info.addrs.emplace_back("addr_" + std::to_string(i));
+    info.remote_instance_info.k_cache_ids.emplace_back(200 + i);
+    info.remote_instance_info.v_cache_ids.emplace_back(300 + i);
+  }
+
+  return info;
+}
+
+ParallelArgs make_args(int32_t rank, int32_t world_size, int32_t dp_size) {
+  return ParallelArgs(rank, world_size, dp_size, nullptr);
+}
+
+void expect_same_merge(
+    const std::unordered_map<std::string, KVCacheTransfer::KVCacheInfo>& lhs,
+    const std::unordered_map<std::string, KVCacheTransfer::KVCacheInfo>& rhs) {
+  ASSERT_EQ(lhs.size(), rhs.size());
+  for (const auto& [key, lhs_info] : lhs) {
+    auto it = rhs.find(key);
+    ASSERT_NE(it, rhs.end());
+    const KVCacheTransfer::KVCacheInfo& rhs_info = it->second;
+    EXPECT_EQ(lhs_info.dst_cluster_id, rhs_info.dst_cluster_id);
+    EXPECT_EQ(lhs_info.dst_addr, rhs_info.dst_addr);
+    EXPECT_EQ(lhs_info.dst_k_cache_id, rhs_info.dst_k_cache_id);
+    EXPECT_EQ(lhs_info.dst_v_cache_id, rhs_info.dst_v_cache_id);
+    EXPECT_EQ(lhs_info.src_blocks, rhs_info.src_blocks);
+    EXPECT_EQ(lhs_info.dst_blocks, rhs_info.dst_blocks);
+  }
+}
+
+}  // namespace
+
 TEST(MooncakeTransferEngineServiceTest, OpenSessionRejectsMissingAddr) {
   MooncakeTransferEngineService service;
   proto::SessionInfo request;
@@ -54,4 +110,82 @@ TEST(MooncakeTransferEngineServiceTest, CloseSessionWithoutHandleReturnsTrue) {
   EXPECT_TRUE(response.ok());
 }
 
+#if defined(USE_MLU)
+TEST(MooncakeKVCacheTransferDefaultTest, OwnerRankMergesSingleDst) {
+  MooncakeKVCacheTransferDefault transfer(
+      0, 0, torch::Device(torch::kCPU), "test");
+  transfer.has_v_cache_ = false;
+
+  const TransferKVInfo info = make_info(1, 3, 0);
+  const ParallelArgs parallel_args = make_args(2, 8, 1);
+  std::unordered_map<std::string, KVCacheTransfer::KVCacheInfo>
+      merged_kv_infos;
+
+  transfer.merge_kv_blocks(merged_kv_infos, {info}, parallel_args);
+
+  ASSERT_EQ(merged_kv_infos.size(), 1U);
+  const KVCacheTransfer::KVCacheInfo& kv_info =
+      merged_kv_infos.begin()->second;
+  EXPECT_EQ(kv_info.dst_cluster_id, 102U);
+  EXPECT_EQ(kv_info.dst_addr, "addr_2");
+  EXPECT_EQ(kv_info.dst_k_cache_id, 202);
+  EXPECT_EQ(kv_info.dst_v_cache_id, 302);
+  EXPECT_EQ(kv_info.src_blocks, info.local_blocks_ids);
+  EXPECT_EQ(kv_info.dst_blocks, info.remote_blocks_ids);
+}
+
+TEST(MooncakeKVCacheTransferDefaultTest, WrappedOwnerRankKeepsMerge) {
+  MooncakeKVCacheTransferDefault transfer(
+      0, 0, torch::Device(torch::kCPU), "test");
+  transfer.has_v_cache_ = false;
+
+  const TransferKVInfo info = make_info(2, 3, 1);
+  const ParallelArgs parallel_args = make_args(5, 8, 1);
+  std::unordered_map<std::string, KVCacheTransfer::KVCacheInfo>
+      merged_kv_infos;
+
+  transfer.merge_kv_blocks(merged_kv_infos, {info}, parallel_args);
+
+  ASSERT_EQ(merged_kv_infos.size(), 1U);
+  const KVCacheTransfer::KVCacheInfo& kv_info =
+      merged_kv_infos.begin()->second;
+  EXPECT_EQ(kv_info.dst_cluster_id, 105U);
+  EXPECT_EQ(kv_info.dst_addr, "addr_5");
+  EXPECT_EQ(kv_info.dst_k_cache_id, 205);
+  EXPECT_EQ(kv_info.dst_v_cache_id, 305);
+  EXPECT_EQ(kv_info.src_blocks, info.local_blocks_ids);
+  EXPECT_EQ(kv_info.dst_blocks, info.remote_blocks_ids);
+}
+
+TEST(MooncakeKVCacheTransferDefaultTest, HasVCacheUsesBaseMerge) {
+  MooncakeKVCacheTransferDefault transfer(
+      0, 0, torch::Device(torch::kCPU), "test");
+  transfer.has_v_cache_ = true;
+
+  const TransferKVInfo info = make_info(2, 3, 1);
+  const ParallelArgs parallel_args = make_args(5, 8, 1);
+  std::unordered_map<std::string, KVCacheTransfer::KVCacheInfo>
+      merged_kv_infos;
+  std::unordered_map<std::string, KVCacheTransfer::KVCacheInfo> base_kv_infos;
+
+  transfer.merge_kv_blocks(merged_kv_infos, {info}, parallel_args);
+  transfer.KVCacheTransfer::merge_kv_blocks(
+      base_kv_infos, {info}, parallel_args);
+
+  expect_same_merge(merged_kv_infos, base_kv_infos);
+}
+
+TEST(MooncakeKVCacheTransferDefaultTest, SmallSrcTpUsesBaseMerge) {
+  MooncakeKVCacheTransferDefault transfer(
+      0, 0, torch::Device(torch::kCPU), "test");
+  transfer.has_v_cache_ = false;
+
+  const TransferKVInfo info = make_info(1, 4, 0);
+  const ParallelArgs parallel_args = make_args(1, 2, 1);
+  std::unordered_map<std::string, KVCacheTransfer::KVCacheInfo>
+      merged_kv_infos;
+  std::unordered_map<std::string, KVCacheTransfer::KVCacheInfo> base_kv_infos;
+
+  transfer.merge_kv_blocks(merged_kv_infos, {info}, parallel_args);
+  transfer.KVCacheTransfer::merge_kv_blocks(
+      base_kv_infos, {info}, parallel_args);
+
+  expect_same_merge(merged_kv_infos, base_kv_infos);
+}
+#endif
+
 }  // namespace xllm
diff --git a/xllm/core/framework/kv_cache_transfer/pd_topology_guard.cpp b/xllm/core/framework/kv_cache_transfer/pd_topology_guard.cpp
new file mode 100644
index 000000000..e261d39f3
--- /dev/null
+++ b/xllm/core/framework/kv_cache_transfer/pd_topology_guard.cpp
@@ -0,0 +1,122 @@
+/* Copyright 2026 The xLLM Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    https://github.com/jd-opensource/xllm/blob/main/LICENSE
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "framework/kv_cache_transfer/pd_topology_guard.h"
+
+#include <glog/logging.h>
+
+#include <limits>
+#include <string>
+
+namespace xllm {
+
+namespace {
+
+bool fail_topo(const std::string& msg, std::string* reason) {
+  if (reason != nullptr) {
+    *reason = msg;
+  }
+  return false;
+}
+
+PdTopoResult check_hetero_pd_req(bool is_mlu_build,
+                                 const std::string& kv_mode,
+                                 bool enable_mla) {
+  if (!is_mlu_build) {
+    return PdTopoResult{PdTopoStatus::DENY_HETERO,
+                        "hetero pd requires is_mlu_build=true"};
+  }
+  if (kv_mode != "PUSH") {
+    return PdTopoResult{PdTopoStatus::DENY_HETERO,
+                        "hetero pd requires kv_mode=PUSH"};
+  }
+  // Non-MLA KV cache still shards KV heads by TP. Hetero TP needs separate
+  // head-dimension split/merge support, so this path is limited to MLA.
+  if (!enable_mla) {
+    return PdTopoResult{PdTopoStatus::DENY_HETERO,
+                        "hetero pd requires enable_mla=true"};
+  }
+
+  return PdTopoResult{PdTopoStatus::ALLOW_HETERO, ""};
+}
+
+}  // namespace
+
+bool try_get_pd_topo(const InstanceInfo& info,
+                     PdTopo* topo,
+                     std::string* reason) {
+  if (topo == nullptr) {
+    return fail_topo("topo must not be null", reason);
+  }
+  if (info.dp_size <= 0) {
+    return fail_topo("dp_size must be greater than 0", reason);
+  }
+
+  const size_t cluster_num = info.cluster_ids.size();
+  if (cluster_num == static_cast<size_t>(0)) {
+    return fail_topo("cluster_ids must not be empty", reason);
+  }
+
+  const size_t dp_size = static_cast<size_t>(info.dp_size);
+  if (cluster_num % dp_size != 0) {
+    return fail_topo("cluster_ids.size() must be divisible by dp_size",
+                     reason);
+  }
+  if (cluster_num > static_cast<size_t>(std::numeric_limits<int32_t>::max())) {
+    return fail_topo("cluster_ids.size() exceeds int32_t range", reason);
+  }
+
+  topo->dp_size = info.dp_size;
+  topo->tp_size = static_cast<int32_t>(cluster_num / dp_size);
+  if (reason != nullptr) {
+    reason->clear();
+  }
+  return true;
+}
+
+PdTopo get_pd_topo(const InstanceInfo& info) {
+  PdTopo topo;
+  std::string reason;
+  CHECK(try_get_pd_topo(info, &topo, &reason)) << reason;
+  return topo;
+}
+
+PdTopoResult check_pd_topo(const InstanceInfo& local,
+                           const InstanceInfo& remote,
+                           bool is_mlu_build,
+                           const std::string& kv_mode,
+                           bool enable_mla) {
+  PdTopo local_topo;
+  std::string reason;
+  if (!try_get_pd_topo(local, &local_topo, &reason)) {
+    return PdTopoResult{PdTopoStatus::INVALID_LOCAL,
+                        "invalid local pd topo: " + reason};
+  }
+
+  PdTopo remote_topo;
+  if (!try_get_pd_topo(remote, &remote_topo, &reason)) {
+    return PdTopoResult{PdTopoStatus::INVALID_REMOTE,
+                        "invalid remote pd topo: " + reason};
+  }
+
+  const bool same_dp = local_topo.dp_size == remote_topo.dp_size;
+  const bool same_tp = local_topo.tp_size == remote_topo.tp_size;
+  if (same_dp && same_tp) {
+    return PdTopoResult{PdTopoStatus::ALLOW_HOMO, ""};
+  }
+
+  return check_hetero_pd_req(is_mlu_build, kv_mode, enable_mla);
+}
+
+}  // namespace xllm
diff --git a/xllm/core/framework/kv_cache_transfer/pd_topology_guard.h b/xllm/core/framework/kv_cache_transfer/pd_topology_guard.h
new file mode 100644
index 000000000..f21055cc4
--- /dev/null
+++ b/xllm/core/framework/kv_cache_transfer/pd_topology_guard.h
@@ -0,0 +1,55 @@
+/* Copyright 2026 The xLLM Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    https://github.com/jd-opensource/xllm/blob/main/LICENSE
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#pragma once
+
+#include <cstdint>
+#include <string>
+
+#include "common/types.h"
+
+namespace xllm {
+
+struct PdTopo {
+  int32_t dp_size = 0;
+  int32_t tp_size = 0;
+};
+
+enum class PdTopoStatus : int8_t {
+  ALLOW_HOMO = 0,
+  ALLOW_HETERO = 1,
+  DENY_HETERO = 2,
+  INVALID_LOCAL = 3,
+  INVALID_REMOTE = 4,
+};
+
+struct PdTopoResult {
+  PdTopoStatus status = PdTopoStatus::DENY_HETERO;
+  std::string reason = "";
+};
+
+bool try_get_pd_topo(const InstanceInfo& info,
+                     PdTopo* topo,
+                     std::string* reason);
+
+PdTopo get_pd_topo(const InstanceInfo& info);
+
+PdTopoResult check_pd_topo(const InstanceInfo& local,
+                           const InstanceInfo& remote,
+                           bool is_mlu_build,
+                           const std::string& kv_mode,
+                           bool enable_mla);
+
+}  // namespace xllm
diff --git a/xllm/core/framework/kv_cache_transfer/pd_topology_guard_test.cpp b/xllm/core/framework/kv_cache_transfer/pd_topology_guard_test.cpp
new file mode 100644
index 000000000..22adbe5f6
--- /dev/null
+++ b/xllm/core/framework/kv_cache_transfer/pd_topology_guard_test.cpp
@@ -0,0 +1,170 @@
+/* Copyright 2026 The xLLM Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    https://github.com/jd-opensource/xllm/blob/main/LICENSE
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "framework/kv_cache_transfer/pd_topology_guard.h"
+
+#include <gtest/gtest.h>
+
+#include <cstdint>
+#include <string>
+#include <vector>
+
+namespace xllm {
+
+namespace {
+
+void set_death_style() { GTEST_FLAG_SET(death_test_style, "threadsafe"); }
+
+InstanceInfo make_info(int32_t dp_size,
+                       const std::vector<uint64_t>& cluster_ids) {
+  InstanceInfo info;
+  info.dp_size = dp_size;
+  info.cluster_ids = cluster_ids;
+  return info;
+}
+
+TEST(PdTopologyGuardTest, HomoTopoBypass) {
+  const InstanceInfo local_info = make_info(2, {0, 1, 2, 3});
+  const InstanceInfo remote_info = make_info(2, {0, 1, 2, 3});
+
+  const PdTopo topo = get_pd_topo(local_info);
+  EXPECT_EQ(topo.dp_size, 2);
+  EXPECT_EQ(topo.tp_size, 2);
+
+  const PdTopoResult result =
+      check_pd_topo(local_info, remote_info, false, "PULL", false);
+  EXPECT_EQ(result.status, PdTopoStatus::ALLOW_HOMO);
+  EXPECT_TRUE(result.reason.empty());
+}
+
+TEST(PdTopologyGuardTest, TryGetPdTopoReturnTopo) {
+  const InstanceInfo info = make_info(2, {0, 1, 2, 3});
+
+  PdTopo topo;
+  std::string reason;
+  EXPECT_TRUE(try_get_pd_topo(info, &topo, &reason));
+  EXPECT_EQ(topo.dp_size, 2);
+  EXPECT_EQ(topo.tp_size, 2);
+  EXPECT_TRUE(reason.empty());
+}
+
+TEST(PdTopologyGuardTest, HeteroTopoNeedMla) {
+  const InstanceInfo local_info = make_info(2, {0, 1, 2, 3});
+  const InstanceInfo remote_info = make_info(1, {0, 1, 2, 3});
+
+  const PdTopoResult result =
+      check_pd_topo(local_info, remote_info, true, "PUSH", false);
+  EXPECT_EQ(result.status, PdTopoStatus::DENY_HETERO);
+  EXPECT_EQ(result.reason, "hetero pd requires enable_mla=true");
+}
+
+TEST(PdTopologyGuardTest, HeteroTopoNeedMluBuild) {
+  const InstanceInfo local_info = make_info(2, {0, 1, 2, 3});
+  const InstanceInfo remote_info = make_info(1, {0, 1, 2, 3});
+
+  const PdTopoResult result =
+      check_pd_topo(local_info, remote_info, false, "PUSH", true);
+  EXPECT_EQ(result.status, PdTopoStatus::DENY_HETERO);
+  EXPECT_EQ(result.reason, "hetero pd requires is_mlu_build=true");
+}
+
+TEST(PdTopologyGuardTest, HeteroTopoNeedPushKv) {
+  const InstanceInfo local_info = make_info(2, {0, 1, 2, 3});
+  const InstanceInfo remote_info = make_info(1, {0, 1, 2, 3});
+
+  const PdTopoResult result =
+      check_pd_topo(local_info, remote_info, true, "PULL", true);
+  EXPECT_EQ(result.status, PdTopoStatus::DENY_HETERO);
+  EXPECT_EQ(result.reason, "hetero pd requires kv_mode=PUSH");
+}
+
+TEST(PdTopologyGuardTest, HeteroTopoAllowOnMluPushMla) {
+  const InstanceInfo local_info = make_info(2, {0, 1, 2, 3});
+  const InstanceInfo remote_info = make_info(1, {0, 1, 2, 3});
+
+  const PdTopoResult result =
+      check_pd_topo(local_info, remote_info, true, "PUSH", true);
+  EXPECT_EQ(result.status, PdTopoStatus::ALLOW_HETERO);
+  EXPECT_TRUE(result.reason.empty());
+}
+
+TEST(PdTopologyGuardTest, CheckPdTopoRejectInvalidLocalTopo) {
+  const InstanceInfo local_info = make_info(0, {0, 1, 2, 3});
+  const InstanceInfo remote_info = make_info(1, {0, 1, 2, 3});
+
+  const PdTopoResult result =
+      check_pd_topo(local_info, remote_info, true, "PUSH", true);
+  EXPECT_EQ(result.status, PdTopoStatus::INVALID_LOCAL);
+  EXPECT_EQ(result.reason,
+            "invalid local pd topo: dp_size must be greater than 0");
+}
+
+TEST(PdTopologyGuardTest, CheckPdTopoRejectInvalidRemoteTopo) {
+  const InstanceInfo local_info = make_info(1, {0, 1, 2, 3});
+  const InstanceInfo remote_info = make_info(2, {0, 1, 2});
+
+  const PdTopoResult result =
+      check_pd_topo(local_info, remote_info, true, "PUSH", true);
+  EXPECT_EQ(result.status, PdTopoStatus::INVALID_REMOTE);
+  EXPECT_EQ(result.reason,
+            "invalid remote pd topo: cluster_ids.size() must be divisible by "
+            "dp_size");
+}
+
+TEST(PdTopologyGuardTest, TryGetPdTopoRejectBadClusterSplit) {
+  const InstanceInfo info = make_info(2, {0, 1, 2});
+
+  PdTopo topo;
+  std::string reason;
+  EXPECT_FALSE(try_get_pd_topo(info, &topo, &reason));
+  EXPECT_EQ(reason, "cluster_ids.size() must be divisible by dp_size");
+}
+
+TEST(PdTopologyGuardTest, TryGetPdTopoRejectEmptyClusterIds) {
+  const InstanceInfo info = make_info(2, {});
+
+  PdTopo topo;
+  std::string reason;
+  EXPECT_FALSE(try_get_pd_topo(info, &topo, &reason));
+  EXPECT_EQ(reason, "cluster_ids must not be empty");
+}
+
+TEST(PdTopologyGuardTest, TryGetPdTopoRejectZeroDpSize) {
+  const InstanceInfo info = make_info(0, {0, 1, 2, 3});
+
+  PdTopo topo;
+  std::string reason;
+  EXPECT_FALSE(try_get_pd_topo(info, &topo, &reason));
+  EXPECT_EQ(reason, "dp_size must be greater than 0");
+}
+
+TEST(PdTopologyGuardTest, GetPdTopoRejectBadClusterSplit) {
+  set_death_style();
+  const InstanceInfo info = make_info(2, {0, 1, 2});
+
+  EXPECT_DEATH(get_pd_topo(info),
+               "cluster_ids.size\\(\\) must be divisible by dp_size");
+}
+
+TEST(PdTopologyGuardTest, GetPdTopoRejectEmptyClusterIds) {
+  set_death_style();
+  const InstanceInfo info = make_info(2, {});
+
+  EXPECT_DEATH(get_pd_topo(info), "cluster_ids must not be empty");
+}
+
+}  // namespace
+
+}  // namespace xllm
diff --git a/xllm/core/framework/kv_cache_transfer/push_route.cpp b/xllm/core/framework/kv_cache_transfer/push_route.cpp
new file mode 100644
index 000000000..2a9b347e1
--- /dev/null
+++ b/xllm/core/framework/kv_cache_transfer/push_route.cpp
@@ -0,0 +1,54 @@
+/* Copyright 2026 The xLLM Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    https://github.com/jd-opensource/xllm/blob/main/LICENSE
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "framework/kv_cache_transfer/push_route.h"
+
+#include <cstddef>
+
+namespace xllm {
+
+bool use_push_owner(const int32_t src_tp_size, const int32_t dst_tp_size) {
+  return src_tp_size > 0 && dst_tp_size > 0 && src_tp_size > dst_tp_size;
+}
+
+std::vector<int32_t> get_dst_ranks(const int32_t src_tp_rank,
+                                   const int32_t src_tp_size,
+                                   const int32_t dst_tp_size,
+                                   const int32_t dst_dp_rank) {
+  std::vector<int32_t> dst_ranks;
+  if (src_tp_size <= 0 || dst_tp_size <= 0 || src_tp_rank < 0 ||
+      src_tp_rank >= src_tp_size || dst_dp_rank < 0) {
+    return dst_ranks;
+  }
+
+  if (use_push_owner(src_tp_size, dst_tp_size)) {
+    dst_ranks.reserve(1);
+    int32_t dst_rank = dst_dp_rank * dst_tp_size + src_tp_rank % dst_tp_size;
+    dst_ranks.emplace_back(dst_rank);
+    return dst_ranks;
+  }
+
+  int32_t start_rank = src_tp_rank % dst_tp_size + dst_tp_size * dst_dp_rank;
+  int32_t end_rank = dst_tp_size * (dst_dp_rank + 1);
+  const size_t dst_rank_num =
+      static_cast<size_t>((end_rank - 1 - start_rank) / src_tp_size + 1);
+  dst_ranks.reserve(dst_rank_num);
+  for (int32_t i = start_rank; i < end_rank; i += src_tp_size) {
+    dst_ranks.emplace_back(i);
+  }
+  return dst_ranks;
+}
+
+}  // namespace xllm
diff --git a/xllm/core/framework/kv_cache_transfer/push_route.h b/xllm/core/framework/kv_cache_transfer/push_route.h
new file mode 100644
index 000000000..cb31b2c19
--- /dev/null
+++ b/xllm/core/framework/kv_cache_transfer/push_route.h
@@ -0,0 +1,30 @@
+/* Copyright 2026 The xLLM Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    https://github.com/jd-opensource/xllm/blob/main/LICENSE
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#pragma once
+
+#include <cstdint>
+#include <vector>
+
+namespace xllm {
+
+bool use_push_owner(int32_t src_tp_size, int32_t dst_tp_size);
+
+std::vector<int32_t> get_dst_ranks(int32_t src_tp_rank,
+                                   int32_t src_tp_size,
+                                   int32_t dst_tp_size,
+                                   int32_t dst_dp_rank);
+
+}  // namespace xllm
diff --git a/xllm/core/framework/kv_cache_transfer/push_route_test.cpp b/xllm/core/framework/kv_cache_transfer/push_route_test.cpp
new file mode 100644
index 000000000..65511a5f3
--- /dev/null
+++ b/xllm/core/framework/kv_cache_transfer/push_route_test.cpp
@@ -0,0 +1,74 @@
+/* Copyright 2026 The xLLM Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    https://github.com/jd-opensource/xllm/blob/main/LICENSE
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "framework/kv_cache_transfer/push_route.h"
+
+#include <gtest/gtest.h>
+
+#include <vector>
+
+namespace xllm {
+
+TEST(PushRouteTest, SrcTpLessThanDstTpKeepsMulticast) {
+  EXPECT_FALSE(use_push_owner(2, 8));
+
+  const std::vector<int32_t> dst_ranks = get_dst_ranks(1, 2, 8, 3);
+  const std::vector<int32_t> expect_ranks = {25, 27, 29, 31};
+  EXPECT_EQ(dst_ranks, expect_ranks);
+}
+
+TEST(PushRouteTest, InvalidTpSizeNotUseOwnerAndReturnEmpty) {
+  EXPECT_FALSE(use_push_owner(0, 4));
+  EXPECT_FALSE(use_push_owner(4, 0));
+  EXPECT_FALSE(use_push_owner(-1, 2));
+  EXPECT_FALSE(use_push_owner(2, -1));
+
+  const std::vector<int32_t> dst_ranks = get_dst_ranks(0, 0, 4, 0);
+  EXPECT_TRUE(dst_ranks.empty());
+}
+
+TEST(PushRouteTest, SrcTpEqualsDstTpNotUseOwner) {
+  EXPECT_FALSE(use_push_owner(4, 4));
+
+  const std::vector<int32_t> dst_ranks = get_dst_ranks(2, 4, 4, 1);
+  const std::vector<int32_t> expect_ranks = {6};
+  EXPECT_EQ(dst_ranks, expect_ranks);
+}
+
+TEST(PushRouteTest, SrcTpGreaterThanDstTpUsesOwnerRouting) {
+  EXPECT_TRUE(use_push_owner(8, 3));
+
+  const std::vector<int32_t> owner_ranks = get_dst_ranks(2, 8, 3, 2);
+  const std::vector<int32_t> expect_owner_ranks = {8};
+  EXPECT_EQ(owner_ranks, expect_owner_ranks);
+
+  const std::vector<int32_t> wrapped_owner_ranks = get_dst_ranks(5, 8, 3, 2);
+  const std::vector<int32_t> expect_wrapped_owner_ranks = {8};
+  EXPECT_EQ(wrapped_owner_ranks, expect_wrapped_owner_ranks);
+}
+
+TEST(PushRouteTest, HeteroTpTwoToOneKeepsOddDpRoute) {
+  const std::vector<int32_t> odd_dp_ranks = get_dst_ranks(1, 2, 1, 3);
+  const std::vector<int32_t> expect_odd_dp_ranks = {3};
+  EXPECT_EQ(odd_dp_ranks, expect_odd_dp_ranks);
+}
+
+TEST(PushRouteTest, DstDpRankOffsetApplied) {
+  const std::vector<int32_t> dst_ranks = get_dst_ranks(1, 6, 4, 3);
+  const std::vector<int32_t> expect_ranks = {13};
+  EXPECT_EQ(dst_ranks, expect_ranks);
+}
+
+}  // namespace xllm
diff --git a/xllm/core/platform/CMakeLists.txt b/xllm/core/platform/CMakeLists.txt
index 7b7465516..ca3665c05 100644
--- a/xllm/core/platform/CMakeLists.txt
+++ b/xllm/core/platform/CMakeLists.txt
@@ -15,6 +15,7 @@ cc_library(
     shared_vmm_allocator.h
     vmm_torch_allocator.h
     $<$<BOOL:${USE_MLU}>:mlu/mlu_layer_synchronizer.h>
+    $<$<BOOL:${USE_MLU}>:mlu/mlu_tensor_alloc.h>
    $<$<BOOL:${USE_CUDA}>:cuda/cuda_utils.h>
    $<$:numa_utils.h>
  SRCS
@@ -23,6 +24,7 @@ cc_library(
     vmm_api.cpp
     shared_vmm_allocator.cpp
     $<$<BOOL:${USE_MLU}>:mlu/mlu_layer_synchronizer.cpp>
+    $<$<BOOL:${USE_MLU}>:mlu/mlu_tensor_alloc.cpp>
    $<$:numa_utils.cpp>
  DEPS
    torch
diff --git a/xllm/core/platform/mlu/mlu_tensor_alloc.cpp b/xllm/core/platform/mlu/mlu_tensor_alloc.cpp
new file mode 100644
index 000000000..bfc0952a4
--- /dev/null
+++ b/xllm/core/platform/mlu/mlu_tensor_alloc.cpp
@@ -0,0 +1,91 @@
+/* Copyright 2026 The xLLM Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    https://github.com/jd-opensource/xllm/blob/main/LICENSE
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "platform/mlu/mlu_tensor_alloc.h"
+
+#include <cnrt.h>
+#include <glog/logging.h>
+
+#include <limits>
+
+namespace xllm::mlu {
+
+namespace {
+
+size_t get_nbytes(const std::vector<int64_t>& dims,
+                  const torch::ScalarType dtype) {
+  size_t count = 1;
+  for (int64_t dim : dims) {
+    CHECK_GE(dim, 0) << "tensor dim must be non-negative";
+    const size_t dim_size = static_cast<size_t>(dim);
+    if (dim_size > static_cast<size_t>(0)) {
+      CHECK_LE(count, std::numeric_limits<size_t>::max() / dim_size)
+          << "tensor element count overflow";
+    }
+    count *= dim_size;
+  }
+  const size_t elem_size = static_cast<size_t>(torch::elementSize(dtype));
+  CHECK_GT(elem_size, static_cast<size_t>(0)) << "tensor dtype size is zero";
+  CHECK_LE(count, std::numeric_limits<size_t>::max() / elem_size)
+      << "tensor byte size overflow";
+  return count * elem_size;
+}
+
+void free_tensor(void* ptr, int32_t device_id) {
+  if (ptr == nullptr) {
+    return;
+  }
+
+  cnrtRet_t ret = cnrtSetDevice(device_id);
+  CHECK(ret == cnrtSuccess)
+      << "cnrtSetDevice failed, ret=" << static_cast<int>(ret)
+      << ", device_id=" << device_id;
+  ret = cnrtFree(ptr);
+  CHECK(ret == cnrtSuccess)
+      << "cnrtFree failed, ret=" << static_cast<int>(ret) << ", ptr=" << ptr;
+}
+
+}  // namespace
+
+torch::Tensor alloc_zero_tensor(const std::vector<int64_t>& dims,
+                                torch::ScalarType dtype,
+                                const torch::Device& device) {
+  CHECK(device.has_index()) << "MLU device index is required";
+  int32_t device_id = static_cast<int32_t>(device.index());
+
+  cnrtRet_t ret = cnrtSetDevice(device_id);
+  CHECK(ret == cnrtSuccess)
+      << "cnrtSetDevice failed, ret=" << static_cast<int>(ret)
+      << ", device_id=" << device_id;
+
+  size_t nbytes = get_nbytes(dims, dtype);
+  void* ptr = nullptr;
+  ret = cnrtMalloc(&ptr, nbytes);
+  CHECK(ret == cnrtSuccess)
+      << "cnrtMalloc failed, ret=" << static_cast<int>(ret)
+      << ", nbytes=" << nbytes;
+  ret = cnrtMemset(ptr, 0, nbytes);
+  CHECK(ret == cnrtSuccess)
+      << "cnrtMemset failed, ret=" << static_cast<int>(ret)
+      << ", nbytes=" << nbytes;
+
+  auto deleter = [device_id](void* data) { free_tensor(data, device_id); };
+  auto options =
+      torch::TensorOptions().dtype(dtype).device(device).requires_grad(false);
+  return torch::from_blob(ptr, dims, deleter, options);
+}
+
+}  // namespace xllm::mlu
diff --git a/xllm/core/platform/mlu/mlu_tensor_alloc.h b/xllm/core/platform/mlu/mlu_tensor_alloc.h
new file mode 100644
index 000000000..2b4f41d34
--- /dev/null
+++ b/xllm/core/platform/mlu/mlu_tensor_alloc.h
@@ -0,0 +1,29 @@
+/* Copyright 2026 The xLLM Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    https://github.com/jd-opensource/xllm/blob/main/LICENSE
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#pragma once
+
+#include <torch/torch.h>
+
+#include <cstdint>
+#include <vector>
+
+namespace xllm::mlu {
+
+torch::Tensor alloc_zero_tensor(const std::vector<int64_t>& dims,
+                                torch::ScalarType dtype,
+                                const torch::Device& device);
+
+}  // namespace xllm::mlu
diff --git a/xllm/core/scheduler/CMakeLists.txt b/xllm/core/scheduler/CMakeLists.txt
index 23e58a15c..47aee1c9c 100644
--- a/xllm/core/scheduler/CMakeLists.txt
+++ b/xllm/core/scheduler/CMakeLists.txt
@@ -40,6 +40,7 @@ cc_library(
     :request
     :runtime
     :profile
+    :pd_topology_guard
     glog::glog
     Folly::folly
     absl::time
diff --git a/xllm/core/scheduler/disagg_pd_scheduler.cpp b/xllm/core/scheduler/disagg_pd_scheduler.cpp
index c35a466d8..641dd556e 100644
--- a/xllm/core/scheduler/disagg_pd_scheduler.cpp
+++ b/xllm/core/scheduler/disagg_pd_scheduler.cpp
@@ -19,6 +19,7 @@ limitations under the License.
 #include
 #include
 #include
+#include
 
 #include
 
@@ -28,6 +29,7 @@ limitations under the License.
 #include "disagg_pd_scheduler.h"
 #include "distributed_runtime/engine.h"
 #include "framework/batch/batch_factory.h"
+#include "framework/kv_cache_transfer/pd_topology_guard.h"
 #include "framework/request/request.h"
 #include "framework/request/request_state.h"
 #include "framework/request/sequence.h"
@@ -326,6 +328,50 @@ void DisaggPDScheduler::dispatch_requests() {
     std::vector<std::shared_ptr<Request>> requests;
     requests.emplace_back(request);
     std::string selected_instance = request->state().decode_address;
+
+    const InstanceInfo remote_info =
+        xservice_client_->get_instance_info(selected_instance);
+    if (remote_info.name.empty()) {
+      response_processor_->process_failed_request(
+          request,
+          {StatusCode::UNKNOWN,
+           "failed to fetch remote decode instance info"});
+      continue;
+    }
+    remote_instances_info_[selected_instance] = remote_info;
+
+    bool is_mlu_build = false;
+#if defined(USE_MLU)
+    is_mlu_build = true;
+#endif
+    const bool enable_mla = engine_->model_args().enable_mla();
+    const PdTopoResult topo_result =
+        check_pd_topo(instance_info_,
+                      remote_info,
+                      is_mlu_build,
+                      options_.kv_cache_transfer_mode(),
+                      enable_mla);
+    const bool allow_pd_topo =
+        topo_result.status == PdTopoStatus::ALLOW_HOMO ||
+        topo_result.status == PdTopoStatus::ALLOW_HETERO;
+    if (!allow_pd_topo) {
+      if (topo_result.status == PdTopoStatus::INVALID_REMOTE) {
+        remote_instances_info_.erase(selected_instance);
+      }
+      response_processor_->process_failed_request(
+          request,
+          {StatusCode::INVALID_ARGUMENT,
+           "decode instance " + selected_instance +
+               " is incompatible: " + topo_result.reason});
+      continue;
+    }
+    if (topo_result.status == PdTopoStatus::ALLOW_HETERO && VLOG_IS_ON(1)) {
+      const PdTopo local_topo = get_pd_topo(instance_info_);
+      const PdTopo remote_topo = get_pd_topo(remote_info);
+      VLOG(1) << "Allow hetero pd topo guard: local dp/tp="
+              << local_topo.dp_size << "/" << local_topo.tp_size
+              << ", remote dp/tp=" << remote_topo.dp_size << "/"
+              << remote_topo.tp_size;
+    }
+
     proto::DisaggPDService_Stub* stub = create_rpc_channel(selected_instance);
     if (stub == nullptr) {
       response_processor_->process_failed_request(
diff --git a/xllm/models/llm/deepseek_v32.h b/xllm/models/llm/deepseek_v32.h
index 07f25dc98..f2afcb534 100644
--- a/xllm/models/llm/deepseek_v32.h
+++ b/xllm/models/llm/deepseek_v32.h
@@ -108,6 +108,11 @@ class DeepseekV32ModelImpl : public DeepseekV2ModelImpl {
                                         attn_metadata,
                                         kv_caches[i],
                                         modified_input_params);
+      if (!modified_input_params.record_layer(static_cast(i),
+                                              hidden_states.device())) {
+        active_sequence_parallel_context_ = nullptr;
+        return ModelOutput();
+      }
     }
 
     hidden_states =
         layer::v32_sp::gather_and_restore_global(hidden_states, sp_ctx.value());
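
For reference, a minimal standalone sketch of the routing rule that push_route.cpp introduces (an illustrative driver, not part of the patch; the function body mirrors get_dst_ranks() above): when the source TP degree exceeds the destination TP degree, each source TP rank pushes to exactly one "owner" destination rank; otherwise it multicasts inside the destination DP group with stride src_tp_size. main() replays two cases covered by push_route_test.cpp.

#include <cstdint>
#include <iostream>
#include <vector>

// Mirrors xllm::get_dst_ranks() from push_route.cpp.
std::vector<int32_t> get_dst_ranks(int32_t src_tp_rank, int32_t src_tp_size,
                                   int32_t dst_tp_size, int32_t dst_dp_rank) {
  std::vector<int32_t> out;
  if (src_tp_size <= 0 || dst_tp_size <= 0 || src_tp_rank < 0 ||
      src_tp_rank >= src_tp_size || dst_dp_rank < 0) {
    return out;
  }
  if (src_tp_size > dst_tp_size) {
    // Owner routing: exactly one destination rank per source TP rank.
    out.push_back(dst_dp_rank * dst_tp_size + src_tp_rank % dst_tp_size);
    return out;
  }
  // Multicast: stride src_tp_size inside the destination DP group.
  for (int32_t i = src_tp_rank % dst_tp_size + dst_tp_size * dst_dp_rank;
       i < dst_tp_size * (dst_dp_rank + 1); i += src_tp_size) {
    out.push_back(i);
  }
  return out;
}

int main() {
  // src TP=2 -> dst TP=8, dst DP rank 3: source rank 1 multicasts to
  // 25 27 29 31 (PushRouteTest.SrcTpLessThanDstTpKeepsMulticast).
  for (int32_t r : get_dst_ranks(1, 2, 8, 3)) std::cout << r << ' ';
  std::cout << '\n';
  // src TP=8 -> dst TP=3, dst DP rank 2: source ranks 2 and 5 both route to
  // owner rank 8 (PushRouteTest.SrcTpGreaterThanDstTpUsesOwnerRouting).
  for (int32_t r : get_dst_ranks(2, 8, 3, 2)) std::cout << r << ' ';
  std::cout << '\n';
}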