@@ -43,16 +43,22 @@ bool KVCacheStore::init(const StoreConfig& config,
4343 }
4444 client_ptr_ = client_opt.value ();
4545
46- auto k_tensor_one_block = host_kv_caches_->at (0 ).get_k_cache ();
47- auto v_tensor_one_block = host_kv_caches_->at (0 ).get_v_cache ();
48-
49- k_cache_size_per_block_ =
50- k_tensor_one_block.numel () * k_tensor_one_block.element_size ();
51- v_cache_size_per_block_ =
52- v_tensor_one_block.numel () * v_tensor_one_block.element_size ();
46+ auto k_cache = host_kv_caches_->at (0 ).get_k_cache ();
47+ k_cache_size_per_block_ = k_cache.numel () * k_cache.element_size ();
48+ LOG (INFO) << " key cache size per block: " << k_cache_size_per_block_;
49+
50+ auto v_cache = host_kv_caches_->at (0 ).get_v_cache ();
51+ if (v_cache.defined () && v_cache.numel () != 0 ) {
52+ v_cache_size_per_block_ = v_cache.numel () * v_cache.element_size ();
53+ LOG (INFO) << " value cache size per block: " << v_cache_size_per_block_;
54+ }
5355
54- LOG (INFO) << " k_cache_size_per_block: " << k_cache_size_per_block_;
55- LOG (INFO) << " v_cache_size_per_block: " << v_cache_size_per_block_;
56+ auto index_cache = host_kv_caches_->at (0 ).get_index_cache ();
57+ if (index_cache.defined () && index_cache.numel () != 0 ) {
58+ index_cache_size_per_block_ =
59+ index_cache.numel () * index_cache.element_size ();
60+ LOG (INFO) << " index cache size per block: " << index_cache_size_per_block_;
61+ }
5662
5763 if (config_.protocol == " rdma" ) {
5864 if (config_.total_size > 0 && config_.tensor_data != nullptr ) {
@@ -103,14 +109,28 @@ uint32_t KVCacheStore::batch_put(
103109
104110 str_keys.emplace_back (str_key);
105111
112+ std::vector<mooncake::Slice> slice;
113+ slice.reserve (3 );
114+
106115 void * k_cache =
107116 host_kv_caches_->at (block_info.dst_block_id ).get_k_cache ().data_ptr ();
108- void * v_cache =
109- host_kv_caches_->at (block_info.dst_block_id ).get_k_cache ().data_ptr ();
117+ slice.emplace_back (mooncake::Slice{k_cache, k_cache_size_per_block_});
118+
119+ if (v_cache_size_per_block_ != 0 ) {
120+ void * v_cache =
121+ host_kv_caches_->at (block_info.dst_block_id ).get_v_cache ().data_ptr ();
122+ slice.emplace_back (mooncake::Slice{v_cache, v_cache_size_per_block_});
123+ }
110124
111- slices.emplace_back (std::vector<mooncake::Slice>{
112- mooncake::Slice{k_cache, k_cache_size_per_block_},
113- mooncake::Slice{v_cache, v_cache_size_per_block_}});
125+ if (index_cache_size_per_block_ != 0 ) {
126+ void * index_cache = host_kv_caches_->at (block_info.dst_block_id )
127+ .get_index_cache ()
128+ .data_ptr ();
129+ slice.emplace_back (
130+ mooncake::Slice{index_cache, index_cache_size_per_block_});
131+ }
132+
133+ slices.emplace_back (std::move (slice));
114134 }
115135
116136 if (str_keys.size () == 0 ) {
@@ -150,16 +170,28 @@ uint32_t KVCacheStore::batch_get(
150170
151171 str_keys.emplace_back (str_key);
152172
173+ std::vector<mooncake::Slice> slice;
174+ slice.reserve (3 );
175+
153176 void * k_cache =
154177 host_kv_caches_->at (block_info.dst_block_id ).get_k_cache ().data_ptr ();
155- void * v_cache =
156- host_kv_caches_->at (block_info.dst_block_id ).get_k_cache ().data_ptr ();
178+ slice.emplace_back (mooncake::Slice{k_cache, k_cache_size_per_block_});
157179
158- slices.insert (
159- std::make_pair (str_key,
160- std::vector<mooncake::Slice>{
161- mooncake::Slice{k_cache, k_cache_size_per_block_},
162- mooncake::Slice{v_cache, v_cache_size_per_block_}}));
180+ if (v_cache_size_per_block_ != 0 ) {
181+ void * v_cache =
182+ host_kv_caches_->at (block_info.dst_block_id ).get_v_cache ().data_ptr ();
183+ slice.emplace_back (mooncake::Slice{v_cache, v_cache_size_per_block_});
184+ }
185+
186+ if (index_cache_size_per_block_ != 0 ) {
187+ void * index_cache = host_kv_caches_->at (block_info.dst_block_id )
188+ .get_index_cache ()
189+ .data_ptr ();
190+ slice.emplace_back (
191+ mooncake::Slice{index_cache, index_cache_size_per_block_});
192+ }
193+
194+ slices.insert (std::make_pair (str_key, std::move (slice)));
163195 }
164196
165197 if (str_keys.size () == 0 ) {
@@ -177,24 +209,6 @@ uint32_t KVCacheStore::batch_get(
177209 return success_cnt;
178210}
179211
180- uint32_t KVCacheStore::batch_remove (
181- Slice<BlockTransferInfo>& block_transfer_info) {
182- CHECK (is_initialized_) << " KVCacheStore is not initialized." ;
183- uint32_t success_cnt = 0 ;
184- for (auto block_info : block_transfer_info) {
185- std::string str_key (reinterpret_cast <const char *>(block_info.hash_key ),
186- MURMUR_HASH3_VALUE_LEN);
187- str_key.append (std::to_string (config_.tp_rank ));
188-
189- auto result = client_ptr_->Remove (str_key);
190-
191- if (result.has_value ()) {
192- success_cnt++;
193- }
194- }
195- return success_cnt;
196- }
197-
198212uint32_t KVCacheStore::batch_exist (std::vector<std::string>&& keys) {
199213 if (!is_initialized_) {
200214 return 0 ;
0 commit comments