diff --git a/.gitignore b/.gitignore
index 767db187..b728e6ea 100644
--- a/.gitignore
+++ b/.gitignore
@@ -29,3 +29,5 @@
 __pycache__/
 *.txt
 *.http
+
+*.nsys-rep
diff --git a/csrc/engine/compiler/general_compiler.cpp b/csrc/engine/compiler/general_compiler.cpp
new file mode 100644
index 00000000..84ee670d
--- /dev/null
+++ b/csrc/engine/compiler/general_compiler.cpp
@@ -0,0 +1,26 @@
+#include "general_compiler.hpp"
+
+namespace infinilm::engine {
+GeneralCompiler::GeneralCompiler(const std::shared_ptr<InfinilmModel> &model, RankBarrier *barrier) : GraphCompiler(model, barrier) {
+    static_batching_compiler_ = std::make_unique<StaticBatchingCompiler>(model_, barrier);
+    paged_compiler_ = std::make_unique<PagedCompiler>(model_, barrier);
+}
+
+void GeneralCompiler::compile() {
+    static_batching_compiler_->compile();
+    paged_compiler_->compile();
+}
+
+GeneralCompiler::Compiled GeneralCompiler::get_compiled(const InfinilmModel::Input &input) {
+    GeneralCompiler::Compiled result = {nullptr, nullptr};
+
+    // try each compiler, return the first valid result
+    result = static_batching_compiler_.get()->get_compiled(input);
+    if (std::get<0>(result) != nullptr && std::get<1>(result) != nullptr) {
+        return result;
+    }
+    result = paged_compiler_.get()->get_compiled(input);
+    return result;
+}
+
+} // namespace infinilm::engine
diff --git a/csrc/engine/compiler/general_compiler.hpp b/csrc/engine/compiler/general_compiler.hpp
new file mode 100644
index 00000000..e8b84b5d
--- /dev/null
+++ b/csrc/engine/compiler/general_compiler.hpp
@@ -0,0 +1,19 @@
+#pragma once
+
+#include "paged_compiler.hpp"
+#include "static_batching_compiler.hpp"
+
+namespace infinilm::engine {
+class GeneralCompiler : public GraphCompiler {
+public:
+    GeneralCompiler(const std::shared_ptr<InfinilmModel> &model, RankBarrier *barrier);
+
+    void compile() override;
+
+    Compiled get_compiled(const InfinilmModel::Input &input) override;
+
+private:
+    std::unique_ptr<StaticBatchingCompiler> static_batching_compiler_;
+    std::unique_ptr<PagedCompiler> paged_compiler_;
+};
+} // namespace infinilm::engine
diff --git a/csrc/engine/compiler/graph_compiler.hpp b/csrc/engine/compiler/graph_compiler.hpp
new file mode 100644
index 00000000..5173994f
--- /dev/null
+++ b/csrc/engine/compiler/graph_compiler.hpp
@@ -0,0 +1,25 @@
+#pragma once
+
+#include "../../models/infinilm_model.hpp"
+#include "../rank_barrier.hpp"
+
+namespace infinilm::engine {
+
+class GraphCompiler {
+public:
+    using Compiled = std::tuple<
+        std::shared_ptr<infinicore::graph::Graph>,
+        std::shared_ptr<InfinilmModel::Output>>;
+
+    explicit GraphCompiler(const std::shared_ptr<InfinilmModel> &model, RankBarrier *barrier) : model_(model), barrier_(barrier) {}
+    virtual ~GraphCompiler() = default;
+
+    virtual void compile() = 0;
+    virtual Compiled get_compiled(const InfinilmModel::Input &input) = 0;
+
+protected:
+    std::shared_ptr<InfinilmModel> model_;
+    RankBarrier *barrier_;
+};
+
+} // namespace infinilm::engine
diff --git a/csrc/engine/compiler/paged_compiler.cpp b/csrc/engine/compiler/paged_compiler.cpp
new file mode 100644
index 00000000..c32811ce
--- /dev/null
+++ b/csrc/engine/compiler/paged_compiler.cpp
@@ -0,0 +1,93 @@
+#include "paged_compiler.hpp"
+
+namespace infinilm::engine {
+PagedCompiler::PagedCompiler(const std::shared_ptr<InfinilmModel> &model, RankBarrier *barrier)
+    : GraphCompiler(model, barrier) {
+    for (size_t b = 1; b < 32; b++) {
+        decode_batch_sizes_.push_back(b);
+    }
+    for (size_t b = 32; b < 64; b += 8) {
+        decode_batch_sizes_.push_back(b);
+    }
+    for (size_t b = 64; b < 128; b += 16) {
+        decode_batch_sizes_.push_back(b);
+    }
+    for (size_t b = 128; b < 256; b += 32) {
+        decode_batch_sizes_.push_back(b);
+    }
+    for (size_t b = 256; b <= 512; b += 64) {
+        decode_batch_sizes_.push_back(b);
+    }
+}
+
+void PagedCompiler::compile() {
+    if (model_->get_cache_config() != nullptr && dynamic_cast<const cache::PagedKVCacheConfig *>(model_->get_cache_config())) {
+        size_t nblocks = dynamic_cast<const cache::PagedKVCacheConfig *>(model_->get_cache_config())->num_blocks();
+        size_t max_batch_size = *std::max_element(decode_batch_sizes_.begin(), decode_batch_sizes_.end());
+        compiled_map_decode_.clear();
+        block_tables_holder_ = infinicore::Tensor::empty(
+            {nblocks}, infinicore::DataType::I64, infinicore::context::getDevice());
+        for (size_t b : decode_batch_sizes_) {
+            size_t block_per_req = nblocks / b;
+            InfinilmModel::Input input;
+            input.input_ids = infinicore::Tensor::empty({1, b}, infinicore::DataType::I64, infinicore::context::getDevice());
+            input.position_ids = infinicore::Tensor::empty({b}, infinicore::DataType::I64, infinicore::context::getDevice());
+            input.total_sequence_lengths = infinicore::Tensor::empty({b}, infinicore::DataType::I64, infinicore::context::getDevice());
+            std::vector<int64_t> total_sequence_lengths_vec(b, 1);
+            infinicore::context::memcpyH2D(input.total_sequence_lengths.value()->data(), total_sequence_lengths_vec.data(), b * sizeof(int64_t), false);
+            input.input_offsets = infinicore::Tensor::empty({b + 1}, infinicore::DataType::I64, infinicore::context::getDevice());
+            std::vector<int64_t> input_offsets_vec(b + 1, 0);
+            for (size_t i = 0; i <= b; i++) {
+                input_offsets_vec[i] = i;
+            }
+            infinicore::context::memcpyH2D(input.input_offsets.value()->data(), input_offsets_vec.data(), (b + 1) * sizeof(int64_t), false);
+            input.block_tables = block_tables_holder_->as_strided({b, block_per_req}, {(ptrdiff_t)block_per_req, 1});
+            input.slot_mapping = infinicore::Tensor::empty({b}, infinicore::DataType::I64, infinicore::context::getDevice());
+
+            barrier_->wait();
+            infinicore::context::startGraphRecording();
+            auto output = model_->forward(input);
+            auto graph = infinicore::context::stopGraphRecording();
+            barrier_->wait();
+
+            auto shared_output = std::shared_ptr<InfinilmModel::Output>(
+                new InfinilmModel::Output{infinicore::graph::GraphTensor(output.logits)});
+
+            compiled_map_decode_[b] = CompiledResult{std::move(input), std::make_tuple(graph, shared_output)};
+        }
+    }
+}
+
+PagedCompiler::Compiled PagedCompiler::get_compiled(const InfinilmModel::Input &input) {
+    if (model_->get_cache_config() != nullptr && dynamic_cast<const cache::PagedKVCacheConfig *>(model_->get_cache_config())) {
+        size_t batch_size = input.block_tables.value()->size(0);
+        size_t block_per_req = input.block_tables.value()->size(1);
+
+        // only support decode only batch
+        if (batch_size != input.input_ids.value()->size(1)) {
+            return {nullptr, nullptr};
+        } else {
+            auto result = compiled_map_decode_.find(batch_size);
+            if (result == compiled_map_decode_.end()) {
+                return {nullptr, nullptr};
+            }
+            auto &graph_input = result->second.input;
+
+            graph_input.input_ids.value()->copy_from(input.input_ids.value());
+            graph_input.position_ids.value()->copy_from(input.position_ids.value());
+            graph_input.total_sequence_lengths.value()->copy_from(input.total_sequence_lengths.value());
+            graph_input.input_offsets.value()->copy_from(input.input_offsets.value());
+            graph_input.block_tables.value()->narrow({{1, 0, block_per_req}})->copy_from(input.block_tables.value());
+            graph_input.slot_mapping.value()->copy_from(input.slot_mapping.value());
+
+            auto graph = std::get<0>(result->second.compiled);
+            auto shared_output = std::shared_ptr<InfinilmModel::Output>(new InfinilmModel::Output{std::get<1>(result->second.compiled)->logits->resume_from_blob_()});
+
+            return std::make_tuple(graph, shared_output);
+        }
+    } else {
+        return {nullptr, nullptr};
+    }
+}
+
+} // namespace infinilm::engine
diff --git a/csrc/engine/compiler/paged_compiler.hpp b/csrc/engine/compiler/paged_compiler.hpp
new file mode 100644
index 00000000..a1125864
--- /dev/null
+++ b/csrc/engine/compiler/paged_compiler.hpp
@@ -0,0 +1,31 @@
+#pragma once
+
+#include "graph_compiler.hpp"
+
+#include <unordered_map>
+
+namespace infinilm::engine {
+class PagedCompiler : public GraphCompiler {
+public:
+    PagedCompiler(const std::shared_ptr<InfinilmModel> &model, RankBarrier *barrier);
+
+    void compile() override;
+
+    Compiled get_compiled(const InfinilmModel::Input &input) override;
+
+private:
+    std::vector<size_t> decode_batch_sizes_;
+
+    infinicore::Tensor block_tables_holder_;
+
+    struct CompiledResult {
+        InfinilmModel::Input input;
+        Compiled compiled;
+    };
+
+    std::unordered_map<
+        size_t, // num_requests
+        CompiledResult>
+        compiled_map_decode_;
+};
+} // namespace infinilm::engine
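
For reference, the set of decode batch sizes that PagedCompiler precompiles comes straight from the bucket loops in its constructor above; a batch size outside this set misses the lookup in get_compiled and falls back to eager execution in RankWorker. A short Python sketch reproducing the same set (illustrative only, mirroring the C++ loops):

```python
# Mirror of the bucket loops in PagedCompiler's constructor (paged_compiler.cpp above).
decode_batch_sizes = (
    list(range(1, 32))           # every size from 1 to 31
    + list(range(32, 64, 8))     # 32, 40, ..., 56
    + list(range(64, 128, 16))   # 64, 80, ..., 112
    + list(range(128, 256, 32))  # 128, 160, ..., 224
    + list(range(256, 513, 64))  # 256, 320, ..., 512
)
print(len(decode_batch_sizes), decode_batch_sizes[-3:])
```
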
diff --git a/csrc/engine/compiler/static_batching_compiler.cpp b/csrc/engine/compiler/static_batching_compiler.cpp
new file mode 100644
index 00000000..34873038
--- /dev/null
+++ b/csrc/engine/compiler/static_batching_compiler.cpp
@@ -0,0 +1,56 @@
+#include "static_batching_compiler.hpp"
+
+#include "../../cache/cache.hpp"
+
+namespace infinilm::engine {
+StaticBatchingCompiler::StaticBatchingCompiler(const std::shared_ptr<InfinilmModel> &model, RankBarrier *barrier)
+    : GraphCompiler(model, barrier) {
+}
+
+void StaticBatchingCompiler::compile() {
+    if (model_->get_cache_config() != nullptr && dynamic_cast<const cache::StaticKVCacheConfig *>(model_->get_cache_config())) {
+        size_t b = dynamic_cast<const cache::StaticKVCacheConfig *>(model_->get_cache_config())->max_batch_size();
+        InfinilmModel::Input input;
+        input.input_ids = infinicore::Tensor::empty({b, 1}, infinicore::DataType::I64, infinicore::context::getDevice());
+        input.position_ids = infinicore::Tensor::empty({b, 1}, infinicore::DataType::I64, infinicore::context::getDevice());
+        input.past_sequence_lengths = infinicore::Tensor::empty({b}, infinicore::DataType::I64, infinicore::context::getDevice());
+        input.total_sequence_lengths = infinicore::Tensor::empty({b}, infinicore::DataType::I64, infinicore::context::getDevice());
+        std::vector<int64_t> total_sequence_lengths_vec(b, 1);
+        infinicore::context::memcpyH2D(input.total_sequence_lengths.value()->data(), total_sequence_lengths_vec.data(), b * sizeof(int64_t), false);
+
+        barrier_->wait();
+        infinicore::context::startGraphRecording();
+        auto output = model_->forward(input);
+        auto graph = infinicore::context::stopGraphRecording();
+        barrier_->wait();
+
+        auto shared_output = std::shared_ptr<InfinilmModel::Output>(new InfinilmModel::Output{infinicore::graph::GraphTensor(output.logits)});
+
+        compiled_map_[std::make_tuple(b, 1)] = CompiledResult{std::move(input), std::make_tuple(graph, shared_output)};
+    }
+}
+
+StaticBatchingCompiler::Compiled StaticBatchingCompiler::get_compiled(
+    const InfinilmModel::Input &input) {
+    if (model_->get_cache_config() != nullptr && dynamic_cast<const cache::StaticKVCacheConfig *>(model_->get_cache_config())) {
+        size_t batch_size = input.input_ids.value()->size(0);
+        size_t seqlen = input.input_ids.value()->size(1);
+        auto result = compiled_map_.find(std::make_tuple(batch_size, seqlen));
+        if (result == compiled_map_.end()) {
+            return std::make_tuple(nullptr, nullptr);
+        } else {
+            auto &graph_input = result->second.input;
+            graph_input.input_ids.value()->copy_from(input.input_ids.value());
+            graph_input.position_ids.value()->copy_from(input.position_ids.value());
+            graph_input.past_sequence_lengths.value()->copy_from(input.past_sequence_lengths.value());
+            graph_input.total_sequence_lengths.value()->copy_from(input.total_sequence_lengths.value());
+
+            auto graph = std::get<0>(result->second.compiled);
+            auto shared_output = std::shared_ptr<InfinilmModel::Output>(new InfinilmModel::Output{std::get<1>(result->second.compiled)->logits->resume_from_blob_()});
+            return std::make_tuple(graph, shared_output);
+        }
+    } else {
+        return std::make_tuple(nullptr, nullptr);
+    }
+}
+} // namespace infinilm::engine
diff --git a/csrc/engine/compiler/static_batching_compiler.hpp b/csrc/engine/compiler/static_batching_compiler.hpp
new file mode 100644
index 00000000..fe1180fc
--- /dev/null
+++ b/csrc/engine/compiler/static_batching_compiler.hpp
@@ -0,0 +1,36 @@
+#pragma once
+
+#include "graph_compiler.hpp"
+
+#include <unordered_map>
+
+namespace infinilm::engine {
+class StaticBatchingCompiler : public GraphCompiler {
+public:
+    StaticBatchingCompiler(const std::shared_ptr<InfinilmModel> &model, RankBarrier *barrier);
+
+    void compile() override;
+
+    Compiled get_compiled(const InfinilmModel::Input &input) override;
+
+private:
+    struct TupleHash {
+        size_t operator()(const std::tuple<size_t, size_t> &t) const noexcept {
+            auto h1 = std::hash<size_t>{}(std::get<0>(t));
+            auto h2 = std::hash<size_t>{}(std::get<1>(t));
+            return h1 ^ (h2 + 0x9e3779b97f4a7c15ULL + (h1 << 6) + (h1 >> 2));
+        }
+    };
+
+    struct CompiledResult {
+        InfinilmModel::Input input;
+        Compiled compiled;
+    };
+
+    std::unordered_map<
+        std::tuple<size_t, size_t>, // (batch_size, seq_len)
+        CompiledResult,
+        TupleHash>
+        compiled_map_;
+};
+} // namespace infinilm::engine
diff --git a/csrc/engine/infer_engine.cpp b/csrc/engine/infer_engine.cpp
index 482117c0..f49a9108 100644
--- a/csrc/engine/infer_engine.cpp
+++ b/csrc/engine/infer_engine.cpp
@@ -10,7 +10,8 @@
 InferEngine::InferEngine(
     const InfinilmModel::Config &config,
     const distributed::DistConfig &distributed_config,
     infinicore::Device::Type device_type,
-    const cache::CacheConfig *cache_config) // Changed parameter
+    const cache::CacheConfig *cache_config,
+    bool enable_graph_compiling) // Changed parameter
     : communication_group_(distributed_config, device_type),
       model_config_(config) {
@@ -19,13 +20,19 @@
     }
     // Create one RankWorker per rank
     int world_size = communication_group_.get_world_size();
+    barrier_ = std::make_unique<RankBarrier>((size_t)world_size);
     workers_.reserve(world_size);
     for (int r = 0; r < world_size; ++r) {
         workers_.emplace_back(std::make_unique<RankWorker>(
             model_config_,
             communication_group_.get_rank_info(r),
-            cache_config_ != nullptr ? cache_config_.get() : nullptr));
+            cache_config_ != nullptr ? cache_config_.get() : nullptr,
+            barrier_.get(),
+            enable_graph_compiling));
     }
+
+    // Compile the model on all workers
+    this->compile();
 
 }
 
@@ -65,9 +72,9 @@ InferEngine::Input::to_model_input(infinicore::Device device) const {
 
     };
     return {
-        input_ids, // @todo: on device in the future
+        to_device(input_ids), // @todo: on device in the future
         to_device(position_ids),
-        past_sequence_lengths, // @todo: on device in the future
+        to_device(past_sequence_lengths), // @todo: on device in the future
         to_device(total_sequence_lengths),
         to_device(input_offsets),
         to_device(block_tables),
@@ -88,6 +95,16 @@ InferEngine::Output InferEngine::forward(const InferEngine::Input &input) {
     return workers_[0]->get_output();
 }
 
+void InferEngine::compile() {
+    for (auto &worker : workers_) {
+        worker->compile();
+    }
+    // Wait for all workers
+    for (auto &worker : workers_) {
+        worker->wait();
+    }
+}
+
 //------------------------------------------------------
 // Destructor
 //------------------------------------------------------
@@ -112,6 +129,8 @@ void InferEngine::reset_cache(const cache::CacheConfig *new_config) {
     for (auto &worker : workers_) {
         worker->wait();
     }
+
+    this->compile();
 }
 
 } // namespace infinilm::engine
diff --git a/csrc/engine/infer_engine.hpp b/csrc/engine/infer_engine.hpp
index 315e1c7c..ce834c6a 100644
--- a/csrc/engine/infer_engine.hpp
+++ b/csrc/engine/infer_engine.hpp
@@ -4,6 +4,7 @@
 #include "../models/llama/llama_config.hpp"
 #include "distributed/distributed.hpp"
 #include "infinicore/tensor.hpp"
+#include "rank_barrier.hpp"
 #include "rank_worker.hpp"
 
 #include <memory>
@@ -22,7 +23,8 @@ class InferEngine {
         const InfinilmModel::Config &config,
         const distributed::DistConfig &distributed_config = distributed::DistConfig(),
         infinicore::Device::Type device_type = infinicore::context::getDevice().getType(),
-        const cache::CacheConfig *cache_config = nullptr);
+        const cache::CacheConfig *cache_config = nullptr,
+        bool enable_graph_compiling = false);
 
     // Load a parameter to all workers (each can extract its shard inside RankWorker)
     void load_param(const std::string &name, const infinicore::Tensor &param);
@@ -33,6 +35,8 @@
     // Run a single forward pass on all workers and return the outputs from all ranks
     Output forward(const Input &input);
 
+    void compile();
+
     void reset_cache(const cache::CacheConfig *new_config);
 
     ~InferEngine();
@@ -44,6 +48,7 @@
 protected:
     std::vector<std::unique_ptr<RankWorker>> workers_;
+    std::unique_ptr<RankBarrier> barrier_;
     distributed::CommunicationGroup communication_group_;
     const InfinilmModel::Config &model_config_;
     std::unique_ptr<cache::CacheConfig> cache_config_;
 
diff --git a/csrc/engine/rank_barrier.cpp b/csrc/engine/rank_barrier.cpp
new file mode 100644
index 00000000..5e852ac6
--- /dev/null
+++ b/csrc/engine/rank_barrier.cpp
@@ -0,0 +1,19 @@
+#include "rank_barrier.hpp"
+
+namespace infinilm::engine {
+RankBarrier::RankBarrier(size_t num_ranks) : thread_count_(num_ranks), generation_(0), arrived_(0) {}
+
+void RankBarrier::wait() {
+    std::unique_lock lock(mutex_);
+    int gen = generation_;
+
+    if (++arrived_ == thread_count_) {
+        // last thread
+        generation_++;
+        arrived_ = 0;
+        cv_.notify_all();
+    } else {
+        cv_.wait(lock, [&] { return gen != generation_; });
+    }
+}
+} // namespace infinilm::engine
diff --git a/csrc/engine/rank_barrier.hpp b/csrc/engine/rank_barrier.hpp
new file mode 100644
index 00000000..dd068e99
--- /dev/null
+++ b/csrc/engine/rank_barrier.hpp
@@ -0,0 +1,20 @@
+#pragma once
+
+#include <condition_variable>
+#include <mutex>
+
+namespace infinilm::engine {
+class RankBarrier {
+public:
+    explicit RankBarrier(size_t nranks);
+
+    void wait();
+
+private:
+    const size_t thread_count_;
+    size_t arrived_;
+    size_t generation_;
+    std::mutex mutex_;
+    std::condition_variable cv_;
+};
+} // namespace infinilm::engine
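
RankBarrier above is a small reusable, generation-counted barrier: every rank blocks in wait() until the last rank arrives, which is what keeps all tensor-parallel workers entering and leaving graph recording in lockstep. Conceptually it behaves like Python's threading.Barrier; a minimal sketch of the same synchronization pattern (an analogy, not the actual implementation):

```python
import threading

NUM_RANKS = 2
barrier = threading.Barrier(NUM_RANKS)  # plays the role of RankBarrier(world_size)

def worker(rank):
    # ... build per-rank inputs ...
    barrier.wait()  # all ranks start recording together
    # ... record / replay the graph ...
    barrier.wait()  # all ranks finish recording together
    print(f"rank {rank} done")

threads = [threading.Thread(target=worker, args=(r,)) for r in range(NUM_RANKS)]
for t in threads:
    t.start()
for t in threads:
    t.join()
```
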
diff --git a/csrc/engine/rank_worker.cpp b/csrc/engine/rank_worker.cpp
index 003fb265..3b7c2e9f 100644
--- a/csrc/engine/rank_worker.cpp
+++ b/csrc/engine/rank_worker.cpp
@@ -12,14 +12,18 @@
 namespace infinilm::engine {
 RankWorker::RankWorker(const InfinilmModel::Config &model_config,
                        const distributed::RankInfo &rank_info,
-                       const cache::CacheConfig *cache_config)
+                       const cache::CacheConfig *cache_config,
+                       RankBarrier *barrier,
+                       bool enable_graph_compiling)
     : model_config_(model_config),
       rank_info_(rank_info),
+      enable_graph_compiling_(enable_graph_compiling),
       job_cmd_(Command::INIT),
       has_job_(false),
       job_done_(false),
       should_exit_(false),
-      init_done_(false) {
+      init_done_(false),
+      barrier_(barrier) {
     if (cache_config != nullptr) {
         pending_cache_config_ = cache_config->unique_copy();
     }
@@ -112,6 +116,21 @@ void RankWorker::run(const Input &args) {
     cv_.notify_all();
 }
 
+//------------------------------------------------------
+// compile -- asynchronous
+//------------------------------------------------------
+void RankWorker::compile() {
+    std::lock_guard lock(mutex_);
+    if (should_exit_) {
+        throw std::runtime_error("RankWorker is closing; cannot run");
+    }
+
+    job_cmd_ = Command::COMPILE;
+    has_job_ = true;
+    job_done_ = false;
+    cv_.notify_all();
+}
+
 //------------------------------------------------------
 // wait -- asynchronous
 //------------------------------------------------------
@@ -179,6 +198,10 @@ void RankWorker::thread_loop() {
             if (!model_) {
                 throw std::runtime_error("Failed to create model");
             }
+            if (enable_graph_compiling_) {
+                compiler_ = std::make_unique<GeneralCompiler>(model_, barrier_);
+            }
+
             init_done_ = true;
         }
         cv_.notify_all();
@@ -244,9 +267,21 @@
 
             {
                 std::lock_guard lk(mutex_);
-                auto model_args = local_args.to_model_input(rank_info_.device);
-                // Forward calculation
-                auto logits{model_->forward(model_args).logits};
+                infinicore::Tensor logits;
+                // Try to get compiled graph
+                if (compiler_ != nullptr) {
+                    auto [graph, output] = compiler_->get_compiled(local_args.to_model_input(infinicore::Device::cpu()));
+                    if (graph != nullptr && output != nullptr) {
+                        graph->run();
+                        logits = output->logits;
+                    }
+                }
+                // Fall back to eager mode
+                if (!logits) {
+                    auto model_args = local_args.to_model_input(rank_info_.device);
+                    logits = model_->forward(model_args).logits;
+                }
+
                 // Random sampling (rank 0 only)
                 if (rank_info_.tp_rank == 0) {
                     auto temperature{local_args.temperature};
@@ -295,7 +330,6 @@
             } else if (local_cmd == Command::RESET_CACHE) {
                 try {
                     model_->reset_cache(local_cache_config != nullptr ? local_cache_config.get() : nullptr);
-
                     {
                         std::lock_guard lk(mutex_);
                         job_done_ = true;
@@ -310,6 +344,26 @@
                     spdlog::error("[{}] exception during reset_cache: {}\n", info(), e.what());
                     break;
                 }
+            } else if (local_cmd == Command::COMPILE) {
+                try {
+                    if (compiler_ != nullptr) {
+                        compiler_->compile();
+                    }
+                    {
+                        std::lock_guard lk(mutex_);
+                        job_done_ = true;
+                    }
+                    cv_.notify_all();
+
+                } catch (const std::exception &e) {
+                    std::lock_guard lk(mutex_);
+                    should_exit_ = true;
+                    job_done_ = true;
+                    cv_.notify_all();
+                    spdlog::error("[{}] exception during compile: {}\n", info(), e.what());
+                    break;
+                }
+
             } else {
                 // Shouldn't reach here (no-op)
             }
diff --git a/csrc/engine/rank_worker.hpp b/csrc/engine/rank_worker.hpp
index 98bb4b87..51304a6a 100644
--- a/csrc/engine/rank_worker.hpp
+++ b/csrc/engine/rank_worker.hpp
@@ -2,7 +2,9 @@
 
 #include "../cache/cache.hpp"
 #include "../models/model_factory.hpp"
+#include "compiler/general_compiler.hpp"
 #include "distributed/distributed.hpp"
+#include "rank_barrier.hpp"
 
 #include <condition_variable>
 #include <thread>
@@ -19,6 +21,7 @@
         LOAD,
         RUN,
         RESET_CACHE,
+        COMPILE,
         STOP
     };
 
@@ -56,7 +59,9 @@
 
     RankWorker(const InfinilmModel::Config &model_config,
                const distributed::RankInfo &rank_info,
-               const cache::CacheConfig *cache_config);
+               const cache::CacheConfig *cache_config,
+               RankBarrier *barrier,
+               bool enable_graph_compiling);
 
     // Submit a parameter load job and wait until the load completes on the worker thread.
     void load_param(const std::string &name,
@@ -71,6 +76,9 @@
     // Reset the internal cache with a new configuration
     void reset_cache(const cache::CacheConfig *new_config);
 
+    // Compile the model graph if enabled.
+    void compile();
+
     // Wait until run job completes. The result can be retrieved with get_output().
     void wait();
 
@@ -92,6 +100,10 @@
     std::shared_ptr<InfinilmModel> model_;
     std::shared_ptr<cache::Cache> cache_;
 
+    // Graph Compiling
+    bool enable_graph_compiling_;
+    std::unique_ptr<GeneralCompiler> compiler_;
+
     // Command for the pending job (protected by mutex_)
     Command job_cmd_;
 
@@ -114,6 +126,7 @@
     std::thread thread_;
     std::mutex mutex_;
    std::condition_variable cv_;
+    RankBarrier *barrier_;
 };
 
 } // namespace infinilm::engine
diff --git a/csrc/models/infinilm_model.hpp b/csrc/models/infinilm_model.hpp
index 4cad3b6c..3537bc75 100644
--- a/csrc/models/infinilm_model.hpp
+++ b/csrc/models/infinilm_model.hpp
@@ -43,5 +43,6 @@
     virtual Output forward(const Input &input) const = 0;
 
     virtual void reset_cache(const cache::CacheConfig *cache_config) = 0;
+    virtual const cache::CacheConfig *get_cache_config() const = 0;
 };
 } // namespace infinilm
diff --git a/csrc/models/llama/llama_for_causal_lm.cpp b/csrc/models/llama/llama_for_causal_lm.cpp
index 6ce1fd98..c7f8728e 100644
--- a/csrc/models/llama/llama_for_causal_lm.cpp
+++ b/csrc/models/llama/llama_for_causal_lm.cpp
@@ -45,7 +45,12 @@ LlamaForCausalLM::Output LlamaForCausalLM::forward(const Input &input) const {
 }
 
 void LlamaForCausalLM::reset_cache(const cache::CacheConfig *cache_config) {
-    model_->reset_cache(cache_config);
+    cache_config_ = cache_config->unique_copy();
+    model_->reset_cache(cache_config_.get());
+}
+
+const cache::CacheConfig *LlamaForCausalLM::get_cache_config() const {
+    return cache_config_.get();
 }
 
 } // namespace infinilm::models::llama
diff --git a/csrc/models/llama/llama_for_causal_lm.hpp b/csrc/models/llama/llama_for_causal_lm.hpp
index dd6f90fa..4b7275cd 100644
--- a/csrc/models/llama/llama_for_causal_lm.hpp
+++ b/csrc/models/llama/llama_for_causal_lm.hpp
@@ -42,6 +42,8 @@
 
     void reset_cache(const cache::CacheConfig *cache_config) override;
 
+    const cache::CacheConfig *get_cache_config() const override;
+
     // Module information
     const LlamaConfig &config() const { return model_->config(); }
     LlamaModel &model() { return *model_; }
@@ -53,6 +55,8 @@
 
     // Language modeling head
     INFINICORE_NN_MODULE(infinicore::nn::Linear, lm_head);
+
+    std::unique_ptr<cache::CacheConfig> cache_config_;
 };
 
 } // namespace infinilm::models::llama
diff --git a/csrc/pybind11/engine/engine.hpp b/csrc/pybind11/engine/engine.hpp
index 5ac38d70..f5dae4a7 100644
--- a/csrc/pybind11/engine/engine.hpp
+++ b/csrc/pybind11/engine/engine.hpp
@@ -35,17 +35,20 @@ inline void bind_infer_engine(py::module &m) {
                 const InfinilmModel::Config &cfg,
                 const distributed::DistConfig &dist,
                 infinicore::Device::Type dev,
-                std::shared_ptr<cache::CacheConfig> cache_cfg) {
+                std::shared_ptr<cache::CacheConfig> cache_cfg,
+                bool enable_graph_compiling) {
                 return std::make_shared<InferEngine>(
                     cfg,
                     dist,
                     dev,
-                    cache_cfg ? cache_cfg.get() : nullptr);
+                    cache_cfg ? cache_cfg.get() : nullptr,
+                    enable_graph_compiling);
             }),
             py::arg("config"),
            py::arg("distributed_config") = distributed::DistConfig(),
             py::arg("device_type") = infinicore::context::getDevice().getType(),
-            py::arg("cache_config") = py::none())
+            py::arg("cache_config") = py::none(),
+            py::arg("enable_graph_compiling") = false)
         .def("load_param", &InferEngine::load_param,
              py::arg("name"), py::arg("param"),
             "Load a parameter tensor into all workers (each worker picks its shard)")
diff --git a/examples/bench.py b/examples/bench.py
index f5d9ddbb..46cdd08a 100644
--- a/examples/bench.py
+++ b/examples/bench.py
@@ -3,7 +3,7 @@
 from infinilm.modeling_utils import load_model_state_dict_by_file
 from infinilm.distributed import DistConfig
 from infinilm.infer_engine import GenerationConfig, InferEngine
-from infinilm.cache import StaticKVCacheConfig
+from infinilm.cache import StaticKVCacheConfig, PagedKVCacheConfig
 import argparse
 import sys
 import time
@@ -179,6 +179,16 @@ def get_args():
         action="store_true",
         help="skip loading model weights",
     )
+    parser.add_argument(
+        "--enable-paged-attn",
+        action="store_true",
+        help="use paged cache",
+    )
+    parser.add_argument(
+        "--enable-graph",
+        action="store_true",
+        help="enable graph compiling",
+    )
 
     return parser.parse_args()
 
@@ -202,6 +212,8 @@ def __init__(
         infini_device=infinicore.device("cpu", 0),
         tp=1,
         skip_load=False,
+        cache_config=None,
+        enable_graph=False,
     ) -> None:
         model_path = os.path.expanduser(model_path)
         # ---------------------------------------------------------------------------- #
@@ -211,6 +223,8 @@
             model_path,
             device=infini_device,
             distributed_config=DistConfig(tp),
+            cache_config=cache_config,
+            enable_graph_compiling=enable_graph,
         )
 
         # ---------------------------------------------------------------------------- #
@@ -306,6 +320,8 @@ def run(
     batch_size = args.batch_size
     input_len = args.input_len
    output_len = args.output_len
+    enable_paged_attn = args.enable_paged_attn
+    enable_graph = args.enable_graph
 
     if isinstance(batch_size, int):
         batch_size = [batch_size]
@@ -320,13 +336,25 @@
    # -------------------------------------------------------- #
    # Test
    # -------------------------------------------------------- #
-    # print("=================== start test ====================", type(batch_size))
+    if enable_paged_attn:
+        paged_kv_block_size = 16
+        max_num_blocks = max(
+            [
+                ((c_["input_len"] + c_["output_len"] + 15) // 16) * c_["batch_size"]
+                for _, c_ in cases_dict.items()
+            ]
+        )
+        cache_config = PagedKVCacheConfig(max_num_blocks, paged_kv_block_size)
+    else:
+        cache_config = None
 
     test = TestModel(
         model_path,
         infini_device=infini_device,
         tp=tp,
         skip_load=skip_load,
+        cache_config=cache_config,
+        enable_graph=enable_graph,
     )
 
     for idx, case in tqdm(cases_dict.items(), desc="Processing cases"):
@@ -336,13 +364,14 @@
         input_len = case["input_len"]
         output_len = case["output_len"]
 
-        # reset cache for each case
-        initial_capacity = input_len + output_len
-        test.model.reset_cache(
-            StaticKVCacheConfig(
-                max_batch_size=batch_size, max_cache_len=initial_capacity
+        if not enable_paged_attn:
+            # reset cache if static kvcache is used
+            initial_capacity = input_len + output_len
+            test.model.reset_cache(
+                StaticKVCacheConfig(
+                    max_batch_size=batch_size, max_cache_len=initial_capacity
+                )
             )
-        )
 
         # run test one case
         test.run(
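
As a quick check of the block-count arithmetic added to bench.py above (the case values here are illustrative, not from the benchmark itself): with input_len=128, output_len=128 and batch_size=32, each request needs ceil(256 / 16) = 16 blocks, so max_num_blocks = 16 * 32 = 512 at paged_kv_block_size = 16.

```python
# Illustrative only: mirrors the max_num_blocks computation added to bench.py.
cases = [
    {"batch_size": 32, "input_len": 128, "output_len": 128},
    {"batch_size": 8, "input_len": 512, "output_len": 64},
]
max_num_blocks = max(
    ((c["input_len"] + c["output_len"] + 15) // 16) * c["batch_size"] for c in cases
)
print(max_num_blocks)  # 512 for the first case, 288 for the second -> 512 overall
```
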
help="use paged cache", ) + parser.add_argument( + "--enable-graph", + action="store_true", + help="enable graph compiling", + ) return parser.parse_args() @@ -99,6 +104,7 @@ def test( infini_device=infinicore.device("cpu", 0), tp=1, enable_paged_attn=False, + enable_graph=False, ): model_path = os.path.expanduser(model_path) # ---------------------------------------------------------------------------- # @@ -108,6 +114,7 @@ def test( model_path, device=infini_device, distributed_config=DistConfig(tp), + enable_graph_compiling=enable_graph, ) # ---------------------------------------------------------------------------- # @@ -164,7 +171,7 @@ def test( batch_size = 1 if prompts is str else len(prompts) max_total_tokens = max_new_tokens + len(input_ids_list[0]) cache_config = PagedKVCacheConfig( - num_blocks=(max_total_tokens // 16 + 1) * batch_size, block_size=16 + num_blocks=((max_total_tokens + 15) // 16) * batch_size, block_size=16 ) else: batch_size = 1 if prompts is str else len(prompts) @@ -231,6 +238,7 @@ def test( backend = args.backend tp = args.tp enable_paged_attn = args.enable_paged_attn + enable_graph = args.enable_graph if backend != "cpp": raise ValueError(f"Unsupported backend: {backend}.") @@ -243,4 +251,5 @@ def test( infini_device=infini_device, tp=tp, enable_paged_attn=enable_paged_attn, + enable_graph=enable_graph, ) diff --git a/python/infinilm/infer_engine.py b/python/infinilm/infer_engine.py index 1a3e9255..510255b1 100644 --- a/python/infinilm/infer_engine.py +++ b/python/infinilm/infer_engine.py @@ -28,6 +28,7 @@ def __init__( device=None, distributed_config=DistConfig(1), cache_config=None, + enable_graph_compiling=False, ): self.config = AutoConfig.from_pretrained(model_path) @@ -39,6 +40,7 @@ def __init__( distributed_config._underlying, device._underlying.type, cache_config, + enable_graph_compiling, ) self.use_cache = False