Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 10 additions & 7 deletions xllm/core/framework/block/block_manager_pool.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -216,8 +216,7 @@ bool BlockManagerPool::allocate(Sequence* sequence, size_t num_tokens) {
const size_t block_size = options_.block_size();
const size_t num_blocks_needed = (num_tokens + block_size - 1) / block_size;
if (num_blocks_needed <= num_blocks) {
process_beam_search(sequence, /*need_swap*/ true);
return true;
return process_beam_search(sequence, /*need_swap*/ true);
}
process_beam_search(sequence);

Expand Down Expand Up @@ -263,27 +262,31 @@ std::vector<Block> BlockManagerPool::allocate(size_t num_tokens,
return block_managers_[dp_rank]->allocate(num_blocks_needed);
}

// Beam-search block handling for `sequence`. Delegates the block-list update
// to sequence->kv_state().process_beam_search(), optionally handing it one
// freshly allocated block to replace the last one.
//
// NOTE(review): this span is a rendered diff without +/- markers. Every
// changed statement appears twice: the pre-change line first, the post-change
// line directly after it. The comments describe the post-change form, which
// returns bool instead of void.
//
// Returns true when there is nothing to do (the beam-search check fails or
// there are no source blocks) and when processing completes. Returns false
// only when the block manager returns no block for the swap.
void BlockManagerPool::process_beam_search(Sequence* sequence, bool need_swap) {
bool BlockManagerPool::process_beam_search(Sequence* sequence, bool need_swap) {
// Nothing to do unless the sequence reports beam search.
if (!sequence->check_beam_search()) {
return;
return true;
}

// Nothing to do when the KV state holds no source blocks.
auto src_blocks = sequence->kv_state().src_blocks();
if (src_blocks.size() == 0) {
return;
return true;
}

// When the sequence needs to swap its last block and no new block was
// appended, allocate a new block for this sequence.
if (need_swap && sequence->kv_state().need_swap()) {
int32_t dp_rank = get_dp_rank(sequence);
auto new_blocks = block_managers_[dp_rank]->allocate(1);
// Post-change: an empty result is reported to the caller as failure.
if (new_blocks.size() == 0) {
return false;
}
// Record the (last source block id, new block id) pair for this dp rank.
swap_block_transfer_infos_[dp_rank].emplace_back(src_blocks.back().id(),
new_blocks[0].id());
sequence->kv_state().process_beam_search(new_blocks);
sequence->kv_state().process_beam_search(new_blocks[0]);
} else {
// No swap: the KV state is updated without a replacement block.
sequence->kv_state().process_beam_search({});
sequence->kv_state().process_beam_search(std::nullopt);
}
return true;
}

uint32_t BlockManagerPool::pre_allocate(Sequence* sequence) {
Expand Down
2 changes: 1 addition & 1 deletion xllm/core/framework/block/block_manager_pool.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ class BlockManagerPool final : public KVCacheManager {
void allocate_host_shared(Sequence* sequence);
void save_offload_blocks(Sequence* sequence);

// Beam-search block handling for `sequence`; need_swap defaults to false.
// NOTE(review): rendered diff without +/- markers. The void declaration is
// the pre-change line and the bool declaration is the post-change line. In
// the post-change definition, false is returned only when no block could be
// allocated for the swap.
void process_beam_search(Sequence* sequence, bool need_swap = false);
bool process_beam_search(Sequence* sequence, bool need_swap = false);

private:
std::vector<std::unique_ptr<BlockManager>> block_managers_;
Expand Down
7 changes: 3 additions & 4 deletions xllm/core/framework/request/sequence_kv_state.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -143,14 +143,13 @@ void KVCacheState::reset() {
transfer_kv_info_.reset();
}

// Replaces blocks_ with the contents moved out of src_blocks_. If a new block
// is supplied, the last entry of blocks_ is dropped and the new block is
// appended in its place.
//
// NOTE(review): this span is a rendered diff without +/- markers. The
// vector-based lines (signature, empty() check, CHECK_EQ, insert) are the
// pre-change form; the std::optional-based lines that follow each of them are
// the post-change form.
void KVCacheState::process_beam_search(const std::vector<Block>& new_blocks) {
void KVCacheState::process_beam_search(std::optional<Block> new_block) {
blocks_.clear();
blocks_ = std::move(src_blocks_);

if (!new_blocks.empty()) {
CHECK_EQ(new_blocks.size(), 1);
if (new_block.has_value()) {
// NOTE(review): pop_back requires blocks_ to be non-empty here, i.e.
// src_blocks_ was non-empty before the move — confirm every caller
// guarantees this.
blocks_.pop_back();
blocks_.insert(blocks_.end(), new_blocks.begin(), new_blocks.end());
blocks_.emplace_back(new_block.value());
}
}

Expand Down
2 changes: 1 addition & 1 deletion xllm/core/framework/request/sequence_kv_state.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ class KVCacheState {

void reset();

// Updates the block list for beam search, optionally replacing the last block.
// NOTE(review): rendered diff without +/- markers. The vector overload is the
// pre-change declaration. The std::optional overload, which defaults to
// std::nullopt, is the post-change declaration.
void process_beam_search(const std::vector<Block>& new_blocks);
void process_beam_search(std::optional<Block> new_block = std::nullopt);

private:
// number of tokens in kv cache
Expand Down
Loading