Skip to content

Commit 3e5a5a9

Browse files
authored
refactor: refactor the allocation of kvcache. (#1293)
1 parent ec0c913 commit 3e5a5a9

35 files changed

Lines changed: 1114 additions & 416 deletions

xllm/core/distributed_runtime/spawn_worker_server/CMakeLists.txt

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,7 @@ cc_binary(
99
spawn_worker_server.cpp
1010
spawn_worker_server_process.cpp
1111
DEPS
12-
:models
13-
:model
1412
:distributed_runtime
15-
absl::strings
16-
$<$<BOOL:${USE_NPU}>:xllm_atb_layers>
17-
$<$<BOOL:${USE_NPU}>:ascendcl>
18-
$<$<BOOL:${USE_NPU}>:nnopbase>
19-
$<$<BOOL:${USE_NPU}>:atb>
20-
$<$<BOOL:${USE_NPU}>:atb_customize>
21-
$<$<BOOL:${USE_NPU}>:c_sec>
22-
spdlog::spdlog
2313
)
2414

2515
add_dependencies(export_module spawn_worker)

xllm/core/framework/batch/batch_test.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -312,12 +312,12 @@ TEST(BatchTest, KVCacheEmptySupportsLinearOnlyAndFullOnlyLayouts) {
312312
auto conv_cache = torch::zeros({2, 4, 3}, options);
313313
auto ssm_cache = torch::zeros({2, 1, 4, 4}, options);
314314
KVCache linear_only_cache(
315-
torch::Tensor(), torch::Tensor(), conv_cache, ssm_cache);
315+
LinearAttentionKVCacheTensors{conv_cache, ssm_cache});
316316
EXPECT_FALSE(linear_only_cache.empty());
317317

318318
auto key_cache = torch::zeros({2, 4, 1, 8}, options);
319319
auto value_cache = torch::zeros({2, 4, 1, 8}, options);
320-
KVCache full_only_cache(key_cache, value_cache);
320+
KVCache full_only_cache(KVCacheTensors{key_cache, value_cache});
321321
EXPECT_FALSE(full_only_cache.empty());
322322

323323
KVCache empty_cache;

xllm/core/framework/eplb/eplb_policy_test.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,15 @@ limitations under the License.
1919
#include <gtest/gtest.h>
2020
#include <torch/torch.h>
2121

22+
#include "platform/device.h"
23+
2224
namespace xllm {
2325

2426
TEST(EplbPolicyTest, Build) {
27+
// use init device to trigger the loading of torch backend for different
28+
// devices
29+
// since the allocation of pinned memory on cpu is still backend-dependent.
30+
torch::Device device(Device::type_torch(), 0);
2531
std::string rank_table_file;
2632
EplbPolicy eplb_policy(5, 4, 1);
2733
std::vector<torch::Tensor> tensors;

xllm/core/framework/kv_cache/CMakeLists.txt

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
include(cc_binary)
21
include(cc_library)
32
include(cc_test)
43

@@ -8,13 +7,24 @@ cc_library(
87
kv_cache
98
HDRS
109
embedding_cache.h
10+
indexed_kv_cache_impl.h
1111
kv_cache.h
1212
kv_cache_event.h
13+
kv_cache_impl.h
14+
kv_cache_utils.h
15+
linear_attention_kv_cache_impl.h
16+
quantized_kv_cache_impl.h
1317
SRCS
1418
embedding_cache.cpp
19+
indexed_kv_cache_impl.cpp
1520
kv_cache.cpp
21+
kv_cache_impl.cpp
22+
kv_cache_utils.cpp
23+
linear_attention_kv_cache_impl.cpp
24+
quantized_kv_cache_impl.cpp
1625
DEPS
1726
:common
27+
:xtensor
1828
glog::glog
1929
torch
2030
$<$<BOOL:${USE_NPU}>:torch_npu>
@@ -26,9 +36,7 @@ cc_test(
2636
SRCS
2737
embedding_cache_test.cpp
2838
DEPS
29-
:xllm_server
3039
:kv_cache
31-
$<$<BOOL:${USE_NPU}>:xllm_server>
3240
GTest::gtest_main
3341
)
3442
target_link_libraries(embedding_cache_test

xllm/core/framework/kv_cache/embedding_cache_test.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ limitations under the License.
1717

1818
#include <gtest/gtest.h>
1919

20+
#include "platform/device.h"
21+
2022
namespace xllm {
2123

2224
namespace {
@@ -28,6 +30,10 @@ bool tensor_equal(const torch::Tensor& lhs, const torch::Tensor& rhs) {
2830
} // namespace
2931

3032
TEST(EmbeddingCacheTest, WriteAndClear) {
33+
// use init device to trigger the loading of torch backend for different
34+
// devices
35+
// since the allocation of pinned memory on cpu is still backend-dependent.
36+
torch::Device device(Device::type_torch(), 0);
3137
EmbeddingCache cache(/*total_nums=*/4);
3238

3339
std::vector<int32_t> ids = {3, 2};
@@ -57,6 +63,10 @@ TEST(EmbeddingCacheTest, WriteAndClear) {
5763
}
5864

5965
TEST(EmbeddingCacheTest, WriteSelectedOnlyProbs) {
66+
// use init device to trigger the loading of torch backend for different
67+
// devices
68+
// since the allocation of pinned memory on cpu is still backend-dependent.
69+
torch::Device device(Device::type_torch(), 0);
6070
EmbeddingCache cache(/*total_nums=*/2);
6171
std::vector<int32_t> ids = {0, 1};
6272
auto cached_tokens = torch::tensor({11, 12}, torch::kInt);
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
/* Copyright 2026 The xLLM Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
https://github.com/jd-opensource/xllm/blob/main/LICENSE
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#include "framework/kv_cache/indexed_kv_cache_impl.h"
17+
18+
#include "util/tensor_helper.h"
19+
20+
namespace xllm {
21+
22+
IndexedKVCacheImpl::IndexedKVCacheImpl(const IndexedKVCacheTensors& tensors)
23+
: KVCacheImpl(tensors.kv_cache_tensors),
24+
index_cache_(tensors.index_cache) {}
25+
26+
IndexedKVCacheImpl::IndexedKVCacheImpl(
27+
const std::vector<std::vector<int64_t>>& kv_cache_shape,
28+
const KVCacheCreateOptions& create_options)
29+
: IndexedKVCacheImpl(
30+
create_indexed_kv_cache_tensors(kv_cache_shape, create_options)) {}
31+
32+
torch::Tensor IndexedKVCacheImpl::get_index_cache() const {
33+
return index_cache_;
34+
}
35+
36+
bool IndexedKVCacheImpl::empty() const {
37+
return !key_cache_.defined() || !value_cache_.defined() ||
38+
!index_cache_.defined();
39+
}
40+
41+
std::vector<std::vector<int64_t>> IndexedKVCacheImpl::get_shapes() const {
42+
std::vector<std::vector<int64_t>> tensor_shapes(3);
43+
tensor_shapes[0] = get_tensor_shape(key_cache_);
44+
tensor_shapes[1] = get_tensor_shape(value_cache_);
45+
tensor_shapes[2] = get_tensor_shape(index_cache_);
46+
return tensor_shapes;
47+
}
48+
49+
} // namespace xllm
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
/* Copyright 2026 The xLLM Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
https://github.com/jd-opensource/xllm/blob/main/LICENSE
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#pragma once
17+
18+
#include "framework/kv_cache/kv_cache_impl.h"
19+
20+
namespace xllm {
21+
22+
class IndexedKVCacheImpl final : public KVCacheImpl {
23+
public:
24+
explicit IndexedKVCacheImpl(const IndexedKVCacheTensors& tensors);
25+
IndexedKVCacheImpl(const std::vector<std::vector<int64_t>>& kv_cache_shape,
26+
const KVCacheCreateOptions& create_options);
27+
28+
torch::Tensor get_index_cache() const override;
29+
30+
bool empty() const override;
31+
32+
std::vector<std::vector<int64_t>> get_shapes() const override;
33+
34+
void swap_blocks(torch::Tensor& src_tensor,
35+
torch::Tensor& dst_tensor) override {
36+
NOT_IMPLEMENTED();
37+
};
38+
39+
private:
40+
torch::Tensor index_cache_;
41+
};
42+
43+
} // namespace xllm

0 commit comments

Comments
 (0)