2 changes: 2 additions & 0 deletions .gitignore
@@ -29,3 +29,5 @@ __pycache__/
*.txt

*.http

*.nsys-rep
26 changes: 26 additions & 0 deletions csrc/engine/compiler/general_compiler.cpp
@@ -0,0 +1,26 @@
#include "general_compiler.hpp"

namespace infinilm::engine {
GeneralCompiler::GeneralCompiler(const std::shared_ptr<InfinilmModel> &model, RankBarrier *barrier) : GraphCompiler(model, barrier) {
static_batching_compiler_ = std::make_unique<StaticBatchingCompiler>(model_, barrier);
paged_compiler_ = std::make_unique<PagedCompiler>(model_, barrier);
}

void GeneralCompiler::compile() {
static_batching_compiler_->compile();
paged_compiler_->compile();
}

GeneralCompiler::Compiled GeneralCompiler::get_compiled(const InfinilmModel::Input &input) {
// Try each compiler in turn and return the first valid result.
auto result = static_batching_compiler_->get_compiled(input);
if (std::get<0>(result) != nullptr && std::get<1>(result) != nullptr) {
return result;
}
result = paged_compiler_->get_compiled(input);
return result;
}

} // namespace infinilm::engine
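
GeneralCompiler is a thin dispatcher: it asks the static-batching compiler first and treats a {nullptr, nullptr} result as a miss that falls through to the paged compiler. A minimal caller-side sketch of that contract follows; the replay() call and the eager forward() fallback are assumptions about the surrounding engine, not part of this diff.

    // Sketch only. Assumes a GeneralCompiler `compiler`, a model `model`,
    // and a populated InfinilmModel::Input `input`.
    auto [graph, compiled_output] = compiler.get_compiled(input);
    if (graph != nullptr && compiled_output != nullptr) {
        // A pre-recorded graph matches this input; get_compiled() already
        // copied the live inputs into the capture-time tensors.
        graph->replay(); // hypothetical method name; the Graph API is not shown here
    } else {
        // No compiled graph for this shape: run the model eagerly.
        auto output = model->forward(input);
    }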
19 changes: 19 additions & 0 deletions csrc/engine/compiler/general_compiler.hpp
@@ -0,0 +1,19 @@
#pragma once

#include "paged_compiler.hpp"
#include "static_batching_compiler.hpp"

namespace infinilm::engine {
class GeneralCompiler : public GraphCompiler {
public:
GeneralCompiler(const std::shared_ptr<InfinilmModel> &model, RankBarrier *barrier);

void compile() override;

Compiled get_compiled(const InfinilmModel::Input &input) override;

private:
std::unique_ptr<StaticBatchingCompiler> static_batching_compiler_;
std::unique_ptr<PagedCompiler> paged_compiler_;
};
} // namespace infinilm::engine
25 changes: 25 additions & 0 deletions csrc/engine/compiler/graph_compiler.hpp
@@ -0,0 +1,25 @@
#pragma once

#include "../../models/infinilm_model.hpp"
#include "../rank_barrier.hpp"

namespace infinilm::engine {

class GraphCompiler {
public:
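// A recorded graph paired with the graph-owned output tensors it writes
// into; {nullptr, nullptr} means no compiled graph matches the input.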
using Compiled = std::tuple<
std::shared_ptr<infinicore::graph::Graph>,
std::shared_ptr<InfinilmModel::Output>>;

explicit GraphCompiler(const std::shared_ptr<InfinilmModel> &model, RankBarrier *barrier) : model_(model), barrier_(barrier) {}
virtual ~GraphCompiler() = default;

virtual void compile() = 0;
virtual Compiled get_compiled(const InfinilmModel::Input &input) = 0;

protected:
std::shared_ptr<InfinilmModel> model_;
RankBarrier *barrier_;
};

} // namespace infinilm::engine
93 changes: 93 additions & 0 deletions csrc/engine/compiler/paged_compiler.cpp
@@ -0,0 +1,93 @@
#include "paged_compiler.hpp"

namespace infinilm::engine {
PagedCompiler::PagedCompiler(const std::shared_ptr<InfinilmModel> &model, RankBarrier *barrier)
: GraphCompiler(model, barrier) {
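// Pre-select decode batch-size buckets: every size from 1 to 31, then
// progressively coarser strides (8, 16, 32, 64) up to a maximum of 512.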
for (size_t b = 1; b < 32; b++) {
decode_batch_sizes_.push_back(b);
}
for (size_t b = 32; b < 64; b += 8) {
decode_batch_sizes_.push_back(b);
}
for (size_t b = 64; b < 128; b += 16) {
decode_batch_sizes_.push_back(b);
}
for (size_t b = 128; b < 256; b += 32) {
decode_batch_sizes_.push_back(b);
}
for (size_t b = 256; b <= 512; b += 64) {
decode_batch_sizes_.push_back(b);
}
}

void PagedCompiler::compile() {
const auto *paged_config = dynamic_cast<const cache::PagedKVCacheConfig *>(model_->get_cache_config());
if (paged_config != nullptr) {
size_t nblocks = paged_config->num_blocks();
size_t max_batch_size = *std::max_element(decode_batch_sizes_.begin(), decode_batch_sizes_.end());
compiled_map_decode_.clear();
block_tables_holder_ = infinicore::Tensor::empty(
{nblocks}, infinicore::DataType::I64, infinicore::context::getDevice());
for (size_t b : decode_batch_sizes_) {
size_t block_per_req = nblocks / b;
InfinilmModel::Input input;
input.input_ids = infinicore::Tensor::empty({1, b}, infinicore::DataType::I64, infinicore::context::getDevice());
input.position_ids = infinicore::Tensor::empty({b}, infinicore::DataType::I64, infinicore::context::getDevice());
input.total_sequence_lengths = infinicore::Tensor::empty({b}, infinicore::DataType::I64, infinicore::context::getDevice());
std::vector<int64_t> total_sequence_lengths_vec(b, 1);
infinicore::context::memcpyH2D(input.total_sequence_lengths.value()->data(), total_sequence_lengths_vec.data(), b * sizeof(int64_t), false);
input.input_offsets = infinicore::Tensor::empty({b + 1}, infinicore::DataType::I64, infinicore::context::getDevice());
std::vector<int64_t> input_offsets_vec(b + 1, 0);
for (size_t i = 0; i <= b; i++) {
input_offsets_vec[i] = i;
}
infinicore::context::memcpyH2D(input.input_offsets.value()->data(), input_offsets_vec.data(), (b + 1) * sizeof(int64_t), false);
input.block_tables = block_tables_holder_->as_strided({b, block_per_req}, {(ptrdiff_t)block_per_req, 1});
input.slot_mapping = infinicore::Tensor::empty({b}, infinicore::DataType::I64, infinicore::context::getDevice());

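// Record one graph per bucketed batch size; the barrier keeps all ranks'
// graph recordings in lockstep.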
barrier_->wait();
infinicore::context::startGraphRecording();
auto output = model_->forward(input);
auto graph = infinicore::context::stopGraphRecording();
barrier_->wait();

auto shared_output = std::shared_ptr<InfinilmModel::Output>(
new InfinilmModel::Output{infinicore::graph::GraphTensor(output.logits)});

compiled_map_decode_[b] = CompiledResult{std::move(input), std::make_tuple(graph, shared_output)};
}
}
}

PagedCompiler::Compiled PagedCompiler::get_compiled(const InfinilmModel::Input &input) {
if (dynamic_cast<const cache::PagedKVCacheConfig *>(model_->get_cache_config()) != nullptr) {
size_t batch_size = input.block_tables.value()->size(0);
size_t block_per_req = input.block_tables.value()->size(1);

// Only decode-only batches (one new token per request) are supported.
if (batch_size != input.input_ids.value()->size(1)) {
return {nullptr, nullptr};
} else {
auto result = compiled_map_decode_.find(batch_size);
if (result == compiled_map_decode_.end()) {
return {nullptr, nullptr};
}
auto &graph_input = result->second.input;

graph_input.input_ids.value()->copy_from(input.input_ids.value());
graph_input.position_ids.value()->copy_from(input.position_ids.value());
graph_input.total_sequence_lengths.value()->copy_from(input.total_sequence_lengths.value());
graph_input.input_offsets.value()->copy_from(input.input_offsets.value());
graph_input.block_tables.value()->narrow({{1, 0, block_per_req}})->copy_from(input.block_tables.value());
graph_input.slot_mapping.value()->copy_from(input.slot_mapping.value());

auto graph = std::get<0>(result->second.compiled);
auto shared_output = std::shared_ptr<InfinilmModel::Output>(new InfinilmModel::Output{std::get<1>(result->second.compiled)->logits->resume_from_blob_()});

return std::make_tuple(graph, shared_output);
}
} else {
return {nullptr, nullptr};
}
}

} // namespace infinilm::engine
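
Note that get_compiled() looks the batch size up with an exact find, so a graph is replayed only when the incoming decode batch exactly matches one of the pre-built buckets; any other size returns {nullptr, nullptr} and is left to the caller. The standalone sketch below reproduces the constructor's bucket table (same loop bounds as above) to make the coverage concrete.

    #include <cstdio>
    #include <vector>

    int main() {
        std::vector<size_t> buckets;
        for (size_t b = 1; b < 32; b++) buckets.push_back(b);         // 1..31, step 1
        for (size_t b = 32; b < 64; b += 8) buckets.push_back(b);     // 32..56, step 8
        for (size_t b = 64; b < 128; b += 16) buckets.push_back(b);   // 64..112, step 16
        for (size_t b = 128; b < 256; b += 32) buckets.push_back(b);  // 128..224, step 32
        for (size_t b = 256; b <= 512; b += 64) buckets.push_back(b); // 256..512, step 64
        // Prints: 48 decode graphs, largest batch 512
        std::printf("%zu decode graphs, largest batch %zu\n", buckets.size(), buckets.back());
        return 0;
    }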
31 changes: 31 additions & 0 deletions csrc/engine/compiler/paged_compiler.hpp
@@ -0,0 +1,31 @@
#pragma once

#include "graph_compiler.hpp"

#include <unordered_map>

namespace infinilm::engine {
class PagedCompiler : public GraphCompiler {
public:
PagedCompiler(const std::shared_ptr<InfinilmModel> &model, RankBarrier *barrier);

void compile() override;

Compiled get_compiled(const InfinilmModel::Input &input) override;

private:
std::vector<size_t> decode_batch_sizes_;

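// One shared device buffer of block ids; compile() carves a strided
// (b, num_blocks / b) view out of it for each bucketed batch size b.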
infinicore::Tensor block_tables_holder_;

struct CompiledResult {
InfinilmModel::Input input;
Compiled compiled;
};

std::unordered_map<
size_t, // num_requests
CompiledResult>
compiled_map_decode_;
};
} // namespace infinilm::engine
56 changes: 56 additions & 0 deletions csrc/engine/compiler/static_batching_compiler.cpp
@@ -0,0 +1,56 @@
#include "static_batching_compiler.hpp"

#include "../../cache/cache.hpp"

namespace infinilm::engine {
StaticBatchingCompiler::StaticBatchingCompiler(const std::shared_ptr<InfinilmModel> &model, RankBarrier *barrier)
: GraphCompiler(model, barrier) {
}

void StaticBatchingCompiler::compile() {
const auto *static_config = dynamic_cast<const cache::StaticKVCacheConfig *>(model_->get_cache_config());
if (static_config != nullptr) {
size_t b = static_config->max_batch_size();
InfinilmModel::Input input;
input.input_ids = infinicore::Tensor::empty({b, 1}, infinicore::DataType::I64, infinicore::context::getDevice());
input.position_ids = infinicore::Tensor::empty({b, 1}, infinicore::DataType::I64, infinicore::context::getDevice());
input.past_sequence_lengths = infinicore::Tensor::empty({b}, infinicore::DataType::I64, infinicore::context::getDevice());
input.total_sequence_lengths = infinicore::Tensor::empty({b}, infinicore::DataType::I64, infinicore::context::getDevice());
std::vector<int64_t> total_sequence_lengths_vec(b, 1);
infinicore::context::memcpyH2D(input.total_sequence_lengths.value()->data(), total_sequence_lengths_vec.data(), b * sizeof(int64_t), false);

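// Record the graph once against placeholder inputs; get_compiled() later
// copies live inputs into these same tensors and replays the graph.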
barrier_->wait();
infinicore::context::startGraphRecording();
auto output = model_->forward(input);
auto graph = infinicore::context::stopGraphRecording();
barrier_->wait();

auto shared_output = std::shared_ptr<InfinilmModel::Output>(new InfinilmModel::Output{infinicore::graph::GraphTensor(output.logits)});

compiled_map_[std::make_tuple(b, 1)] = CompiledResult{std::move(input), std::make_tuple(graph, shared_output)};
}
}

StaticBatchingCompiler::Compiled StaticBatchingCompiler::get_compiled(
const InfinilmModel::Input &input) {
if (dynamic_cast<const cache::StaticKVCacheConfig *>(model_->get_cache_config()) != nullptr) {
size_t batch_size = input.input_ids.value()->size(0);
size_t seqlen = input.input_ids.value()->size(1);
auto result = compiled_map_.find(std::make_tuple(batch_size, seqlen));
if (result == compiled_map_.end()) {
return std::make_tuple(nullptr, nullptr);
} else {
auto &graph_input = result->second.input;
graph_input.input_ids.value()->copy_from(input.input_ids.value());
graph_input.position_ids.value()->copy_from(input.position_ids.value());
graph_input.past_sequence_lengths.value()->copy_from(input.past_sequence_lengths.value());
graph_input.total_sequence_lengths.value()->copy_from(input.total_sequence_lengths.value());

auto graph = std::get<0>(result->second.compiled);
auto shared_output = std::shared_ptr<InfinilmModel::Output>(new InfinilmModel::Output{std::get<1>(result->second.compiled)->logits->resume_from_blob_()});
return std::make_tuple(graph, shared_output);
}
} else {
return std::make_tuple(nullptr, nullptr);
}
}
} // namespace infinilm::engine
36 changes: 36 additions & 0 deletions csrc/engine/compiler/static_batching_compiler.hpp
@@ -0,0 +1,36 @@
#pragma once

#include "graph_compiler.hpp"

#include <unordered_map>

namespace infinilm::engine {
class StaticBatchingCompiler : public GraphCompiler {
public:
StaticBatchingCompiler(const std::shared_ptr<InfinilmModel> &model, RankBarrier *barrier);

void compile() override;

Compiled get_compiled(const InfinilmModel::Input &input) override;

private:
struct TupleHash {
size_t operator()(const std::tuple<size_t, size_t> &t) const noexcept {
auto h1 = std::hash<size_t>{}(std::get<0>(t));
auto h2 = std::hash<size_t>{}(std::get<1>(t));
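// Boost-style hash_combine: 0x9e3779b97f4a7c15 is the 64-bit golden-ratio
// constant (2^64 / phi), used here to decorrelate the two hashes.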
return h1 ^ (h2 + 0x9e3779b97f4a7c15ULL + (h1 << 6) + (h1 >> 2));
}
};

struct CompiledResult {
InfinilmModel::Input input;
Compiled compiled;
};

std::unordered_map<
std::tuple<size_t, size_t>, // (batch_size, seq_len)
CompiledResult,
TupleHash>
compiled_map_;
};
} // namespace infinilm::engine
27 changes: 23 additions & 4 deletions csrc/engine/infer_engine.cpp
@@ -10,7 +10,8 @@ InferEngine::InferEngine(
const InfinilmModel::Config &config,
const distributed::DistConfig &distributed_config,
infinicore::Device::Type device_type,
const cache::CacheConfig *cache_config) // Changed parameter
const cache::CacheConfig *cache_config,
bool enable_graph_compiling) // Changed parameter
: communication_group_(distributed_config, device_type),
model_config_(config) {

@@ -19,13 +20,19 @@
}
// Create one RankWorker per rank
int world_size = communication_group_.get_world_size();
barrier_ = std::make_unique<RankBarrier>((size_t)world_size);
workers_.reserve(world_size);
for (int r = 0; r < world_size; ++r) {
workers_.emplace_back(std::make_unique<RankWorker>(
model_config_,
communication_group_.get_rank_info(r),
cache_config_ != nullptr ? cache_config_.get() : nullptr));
cache_config_ != nullptr ? cache_config_.get() : nullptr,
barrier_.get(),
enable_graph_compiling));
}

// Compile the model on all workers
this->compile();
}

//------------------------------------------------------
@@ -65,9 +72,9 @@ InferEngine::Input::to_model_input(infinicore::Device device) const {
};

return {
input_ids, // @todo: on device in the future
to_device(input_ids), // @todo: on device in the future
to_device(position_ids),
past_sequence_lengths, // @todo: on device in the future
to_device(past_sequence_lengths), // @todo: on device in the future
to_device(total_sequence_lengths),
to_device(input_offsets),
to_device(block_tables),
@@ -88,6 +95,16 @@ InferEngine::Output InferEngine::forward(const InferEngine::Input &input) {
return workers_[0]->get_output();
}

void InferEngine::compile() {
for (auto &worker : workers_) {
worker->compile();
}
// Wait for all workers to finish compiling
for (auto &worker : workers_) {
worker->wait();
}
}

//------------------------------------------------------
// Destructor
//------------------------------------------------------
@@ -112,6 +129,8 @@ void InferEngine::reset_cache(const cache::CacheConfig *new_config) {
for (auto &worker : workers_) {
worker->wait();
}

this->compile();
}

} // namespace infinilm::engine
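
With this change the engine compiles eagerly: the constructor takes the new enable_graph_compiling flag, hands every RankWorker the shared RankBarrier, and invokes compile() at the end of construction and again after reset_cache(). A hypothetical call site is sketched below; the device enumerator and the config objects are illustrative assumptions, and only the parameter order comes from this diff.

    // Sketch only: model_config, dist_config, and cache_config are assumed
    // to be constructed elsewhere.
    infinilm::engine::InferEngine engine(
        model_config,                     // InfinilmModel::Config
        dist_config,                      // distributed::DistConfig
        infinicore::Device::Type::CUDA,   // assumed enumerator name
        &cache_config,                    // cache::CacheConfig*, may be nullptr
        /*enable_graph_compiling=*/true); // flag added by this change
    // Graphs are pre-recorded at this point; forward() can replay them for
    // inputs whose shapes match a compiled graph.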