diff --git a/.gitmodules b/.gitmodules index eab6041a..ade5ff58 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +1,6 @@ [submodule "third_party/spdlog"] path = third_party/spdlog url = https://github.com/gabime/spdlog.git +[submodule "third_party/json"] + path = third_party/json + url = https://github.com/nlohmann/json.git diff --git a/csrc/config/global_config.cpp b/csrc/config/global_config.cpp new file mode 100644 index 00000000..2f3ce308 --- /dev/null +++ b/csrc/config/global_config.cpp @@ -0,0 +1,88 @@ +#include "global_config.hpp" + +namespace infinilm::config::global_config { +GlobalConfig::GlobalConfig(const std::string &path) { + std::ifstream file(path); + if (file.is_open()) { + file >> config_json; + file.close(); + } else { + throw std::runtime_error("Could not open config file: " + path); + } + this->quant_config = quantization::QuantConfig(config_json["quantization_config"]); +} + +infinicore::nn::QuantScheme +GlobalConfig::get_quant_scheme() const { + if (quant_config.get_quant_scheme() != infinicore::nn::QuantScheme::NONE) { + return quant_config.get_quant_scheme(); + } else { + return infinicore::nn::QuantScheme::NONE; + } +} + +std::shared_ptr +GlobalConfig::get_rope_scaling() const { + if (!config_json.contains("rope_scaling") || config_json["rope_scaling"].is_null()) { + return nullptr; + } + + const auto &rope_scaling = config_json["rope_scaling"]; + if (!rope_scaling.is_object()) { + throw std::runtime_error("rope_scaling must be an object"); + } + + if (!rope_scaling.contains("type")) { + throw std::runtime_error("rope_scaling must contain 'type' field"); + } + + std::string type_str = rope_scaling["type"].get(); + if (type_str == "longrope") { + // Required fields for LongRopeConfig + if (!rope_scaling.contains("short_factor") || !rope_scaling.contains("long_factor") || !rope_scaling.contains("original_max_position_embeddings")) { + throw std::runtime_error( + "LongRopeConfig requires 'short_factor', 'long_factor', and 
'original_max_position_embeddings'"); + } + + auto short_factor = rope_scaling["short_factor"].get>(); + auto long_factor = rope_scaling["long_factor"].get>(); + size_t original_max_position_embeddings = rope_scaling["original_max_position_embeddings"].get(); + + float factor = 1.0f; + if (rope_scaling.contains("factor")) { + factor = rope_scaling["factor"].get(); + } + + return std::make_shared( + std::move(short_factor), + std::move(long_factor), + original_max_position_embeddings, + factor); + } else if (type_str == "default" || type_str == "none") { + // Default scaling, no scaling applied + return nullptr; + } else { + throw std::runtime_error("Unsupported rope_scaling type: " + type_str); + } +} + +infinicore::DataType +GlobalConfig::get_dtype() const { + try { + std::string dtype_str = this->get("torch_dtype"); + if (dtype_str == "float32") { + return infinicore::DataType::F32; + } else if (dtype_str == "float16") { + return infinicore::DataType::F16; + } else if (dtype_str == "bfloat16") { + return infinicore::DataType::BF16; + } else if (dtype_str == "int8") { + return infinicore::DataType::I8; + } else { + throw std::runtime_error("Unsupported dtype string: " + dtype_str); + } + } catch (const std::exception &e) { + throw std::runtime_error("Error getting dtype from config: " + std::string(e.what())); + } +} +} // namespace infinilm::config::global_config diff --git a/csrc/config/global_config.hpp b/csrc/config/global_config.hpp new file mode 100644 index 00000000..e8be1ec2 --- /dev/null +++ b/csrc/config/global_config.hpp @@ -0,0 +1,63 @@ +#pragma once + +// #include "infinicore/nn/quantization.hpp" +#include "infinicore/nn/rope.hpp" +#include "infinicore/ops.hpp" +#include "quant_config.hpp" +#include +#include + +namespace infinilm::config::global_config { +struct GlobalConfig { + // Global config is implemented using nlohmann/json and is primarily used for advanced configuration + // beyond the standard model config. 
It is initialized via GlobalConfig(const std::string& path) + // and passed through the InferEngine during inference. +public: + GlobalConfig() = default; + GlobalConfig(const nlohmann::json &json) : config_json(json) {}; + GlobalConfig(const std::string &path); + + // Template Function to get a value by key with type safety + template + T get(const std::string &key) const { + if (!config_json.contains(key)) { + throw std::out_of_range("Key '" + key + "' not found in config."); + } + try { + return config_json.at(key).get(); + } catch (const nlohmann::json::type_error &e) { + throw std::runtime_error("Type conversion failed for key '" + key + "': " + std::string(e.what())); + } + } + + template + T get_or(const std::string &key, const T &default_value) const { + if (!config_json.contains(key) || config_json.at(key).is_null()) { + return default_value; + } + try { + return config_json.at(key).get(); + } catch (const nlohmann::json::type_error &) { + // If type conversion fails, return default value + return default_value; + } + } + size_t get_kv_dim() const { + return get("hidden_size") * get("num_key_value_heads") / get("num_attention_heads"); + } + size_t get_head_dim() const { + if (config_json.contains("head_dim")) { + return get("head_dim"); + } + return get("hidden_size") / get("num_attention_heads"); + } + + infinicore::DataType get_dtype() const; + infinicore::nn::QuantScheme get_quant_scheme() const; + std::shared_ptr get_rope_scaling() const; + +private: + nlohmann::json config_json; + quantization::QuantConfig quant_config; +}; +} // namespace infinilm::config::global_config diff --git a/csrc/config/quant_config.cpp b/csrc/config/quant_config.cpp new file mode 100644 index 00000000..8984661f --- /dev/null +++ b/csrc/config/quant_config.cpp @@ -0,0 +1,22 @@ +#include "quant_config.hpp" + +namespace infinilm::config::quantization { +QuantConfig::QuantConfig(const nlohmann::json &json) : quantization_config(json) { + this->quantization_method = 
get_quantization_method(); +} + +std::shared_ptr +QuantConfig::get_quantization_method() const { + if (quantization_config.is_null()) { + return nullptr; + } + + // Determine the quantization scheme from the JSON config + if (quantization_config["quant_method"] == "compressed-tensors") { + return std::make_shared(quantization_config); + } + // Add other schemes as needed + + return nullptr; // Default case if no matching scheme +} +} // namespace infinilm::config::quantization diff --git a/csrc/config/quant_config.hpp b/csrc/config/quant_config.hpp new file mode 100644 index 00000000..dec3750e --- /dev/null +++ b/csrc/config/quant_config.hpp @@ -0,0 +1,28 @@ +#pragma once +#include "../quantization/quantization.hpp" +#include "nlohmann/json.hpp" + +namespace infinilm::config::quantization { + +class QuantConfig { + // QuantConfig is used to store and parse the "quantization" field from config.json. + // This is currently a basic version and will be extended in the future. +public: + QuantConfig() = default; + QuantConfig(const nlohmann::json &json); + + infinicore::nn::QuantScheme get_quant_scheme() const { + if (quantization_method != nullptr) { + return quantization_method->get_quant_scheme(); + } else { + return infinicore::nn::QuantScheme::NONE; + } + } + +private: + nlohmann::json quantization_config; + std::shared_ptr get_quantization_method() const; + std::shared_ptr quantization_method; +}; + +} // namespace infinilm::config::quantization diff --git a/csrc/engine/infer_engine.cpp b/csrc/engine/infer_engine.cpp index 482117c0..c737ffc1 100644 --- a/csrc/engine/infer_engine.cpp +++ b/csrc/engine/infer_engine.cpp @@ -1,5 +1,6 @@ #include "infer_engine.hpp" #include "spdlog/spdlog.h" +#include namespace infinilm::engine { @@ -7,24 +8,28 @@ namespace infinilm::engine { // Constructor //------------------------------------------------------ InferEngine::InferEngine( - const InfinilmModel::Config &config, const distributed::DistConfig &distributed_config, 
infinicore::Device::Type device_type, - const cache::CacheConfig *cache_config) // Changed parameter - : communication_group_(distributed_config, device_type), - model_config_(config) { + const cache::CacheConfig *cache_config, + const std::string &model_path) // Changed parameter + : communication_group_(distributed_config, device_type) { if (cache_config != nullptr) { cache_config_ = cache_config->unique_copy(); } + + // Load global config if model_path is provided, model_path must be valid, and config.json exists + this->global_config_ = std::make_shared(model_path + "/config.json"); + // Create one RankWorker per rank int world_size = communication_group_.get_world_size(); workers_.reserve(world_size); for (int r = 0; r < world_size; ++r) { workers_.emplace_back(std::make_unique( - model_config_, + // model_config_, communication_group_.get_rank_info(r), - cache_config_ != nullptr ? cache_config_.get() : nullptr)); + cache_config_ != nullptr ? cache_config_.get() : nullptr, + global_config_)); } } diff --git a/csrc/engine/infer_engine.hpp b/csrc/engine/infer_engine.hpp index 315e1c7c..f8a2d95c 100644 --- a/csrc/engine/infer_engine.hpp +++ b/csrc/engine/infer_engine.hpp @@ -1,5 +1,6 @@ #pragma once +#include "../config/global_config.hpp" #include "../models/infinilm_model.hpp" #include "../models/llama/llama_config.hpp" #include "distributed/distributed.hpp" @@ -19,10 +20,10 @@ class InferEngine { // Updated constructor: accept CacheConfig instead of CacheType InferEngine( - const InfinilmModel::Config &config, const distributed::DistConfig &distributed_config = distributed::DistConfig(), infinicore::Device::Type device_type = infinicore::context::getDevice().getType(), - const cache::CacheConfig *cache_config = nullptr); + const cache::CacheConfig *cache_config = nullptr, + const std::string &model_path = ""); // Load a parameter to all workers (each can extract its shard inside RankWorker) void load_param(const std::string &name, const infinicore::Tensor &param);
@@ -45,8 +46,9 @@ class InferEngine { protected: std::vector> workers_; distributed::CommunicationGroup communication_group_; - const InfinilmModel::Config &model_config_; + // const InfinilmModel::Config &model_config_; std::unique_ptr cache_config_; + std::shared_ptr global_config_; }; } // namespace infinilm::engine diff --git a/csrc/engine/rank_worker.cpp b/csrc/engine/rank_worker.cpp index 003fb265..be287bb9 100644 --- a/csrc/engine/rank_worker.cpp +++ b/csrc/engine/rank_worker.cpp @@ -10,14 +10,15 @@ namespace infinilm::engine { -RankWorker::RankWorker(const InfinilmModel::Config &model_config, - const distributed::RankInfo &rank_info, - const cache::CacheConfig *cache_config) - : model_config_(model_config), - rank_info_(rank_info), +RankWorker::RankWorker( + const distributed::RankInfo &rank_info, + const cache::CacheConfig *cache_config, + std::shared_ptr global_config) + : rank_info_(rank_info), job_cmd_(Command::INIT), has_job_(false), job_done_(false), + global_config_(global_config), should_exit_(false), init_done_(false) { if (cache_config != nullptr) { @@ -25,7 +26,6 @@ RankWorker::RankWorker(const InfinilmModel::Config &model_config, } // start the thread thread_ = std::thread(&RankWorker::thread_loop, this); - // Wait until the worker thread finishes initialization (model created) std::unique_lock lk(mutex_); cv_.wait(lk, [&] { return init_done_; }); @@ -175,7 +175,7 @@ void RankWorker::thread_loop() { infinicore::context::setDevice(rank_info_.device); // Create model using factory (may be expensive) - model_ = InfinilmModelFactory::createModel(model_config_, rank_info_, pending_cache_config_ != nullptr ? pending_cache_config_.get() : nullptr); + model_ = InfinilmModelFactory::createModel(rank_info_, pending_cache_config_ != nullptr ? 
pending_cache_config_.get() : nullptr, global_config_); if (!model_) { throw std::runtime_error("Failed to create model"); } diff --git a/csrc/engine/rank_worker.hpp b/csrc/engine/rank_worker.hpp index 98bb4b87..b939b3c1 100644 --- a/csrc/engine/rank_worker.hpp +++ b/csrc/engine/rank_worker.hpp @@ -1,6 +1,7 @@ #pragma once #include "../cache/cache.hpp" +#include "../config/global_config.hpp" #include "../models/model_factory.hpp" #include "distributed/distributed.hpp" @@ -54,9 +55,9 @@ class RankWorker { infinicore::Tensor output_ids; }; - RankWorker(const InfinilmModel::Config &model_config, - const distributed::RankInfo &rank_info, - const cache::CacheConfig *cache_config); + RankWorker(const distributed::RankInfo &rank_info, + const cache::CacheConfig *cache_config, + std::shared_ptr global_config); // Submit a parameter load job and wait until the load completes on the worker thread. void load_param(const std::string &name, @@ -87,10 +88,11 @@ class RankWorker { private: // Worker properties - const InfinilmModel::Config &model_config_; + // const InfinilmModel::Config &model_config_; distributed::RankInfo rank_info_; std::shared_ptr model_; std::shared_ptr cache_; + std::shared_ptr global_config_; // Command for the pending job (protected by mutex_) Command job_cmd_; diff --git a/csrc/layers/fused_linear.cpp b/csrc/layers/fused_linear.cpp index 9b2c813d..e108b275 100644 --- a/csrc/layers/fused_linear.cpp +++ b/csrc/layers/fused_linear.cpp @@ -13,12 +13,14 @@ QKVParallelLinear::QKVParallelLinear(size_t hidden_size, bool bias, const infinicore::DataType &dtype, const infinicore::Device &device, - engine::distributed::RankInfo rank_info) + engine::distributed::RankInfo rank_info, + std::optional quant_scheme) : QKVParallelLinear(hidden_size, head_dim, head_dim, head_dim, num_q_head, num_kv_head, num_kv_head, bias, bias, bias, - dtype, device, rank_info) {} + dtype, device, rank_info, + quant_scheme) {} QKVParallelLinear::QKVParallelLinear(size_t hidden_size, 
size_t q_dim, size_t k_dim, size_t v_dim, @@ -26,15 +28,17 @@ QKVParallelLinear::QKVParallelLinear(size_t hidden_size, bool q_bias, bool k_bias, bool v_bias, const infinicore::DataType &dtype, const infinicore::Device &device, - engine::distributed::RankInfo rank_info) + engine::distributed::RankInfo rank_info, + std::optional quant_scheme) : infinicore::nn::ColumnParallelLinear( - hidden_size, - num_q_head * q_dim + num_k_head * k_dim + num_v_head * v_dim, - (q_bias || k_bias || v_bias), - dtype, - device, - rank_info.tp_rank, - rank_info.tp_size), + hidden_size, + num_q_head * q_dim + num_k_head * k_dim + num_v_head * v_dim, + (q_bias || k_bias || v_bias), + dtype, + device, + rank_info.tp_rank, + rank_info.tp_size, + quant_scheme), q_dim_(q_dim), k_dim_(k_dim), v_dim_(v_dim), @@ -86,6 +90,23 @@ infinicore::nn::Parameter QKVParallelLinear::get_v_weight() const { 0, tp_rank_, tp_size_); } +infinicore::nn::Parameter QKVParallelLinear::get_q_weight_scale() const { + return infinicore::nn::Parameter( + weight_scale_->narrow({{0, 0, q_out_size_}}), 0, tp_rank_, tp_size_); +} + +infinicore::nn::Parameter QKVParallelLinear::get_k_weight_scale() const { + return infinicore::nn::Parameter( + weight_scale_->narrow({{0, q_out_size_, k_out_size_}}), + 0, tp_rank_, tp_size_); +} + +infinicore::nn::Parameter QKVParallelLinear::get_v_weight_scale() const { + return infinicore::nn::Parameter( + weight_scale_->narrow({{0, q_out_size_ + k_out_size_, v_out_size_}}), + 0, tp_rank_, tp_size_); +} + infinicore::nn::Parameter QKVParallelLinear::get_q_bias() const { if (!q_bias_) { return infinicore::nn::Parameter(); @@ -122,14 +143,16 @@ bool QKVParallelLinear::has_v_bias() const { return v_bias_; } // --------------------------------------------------------- GateUpParallelLinear::GateUpParallelLinear(size_t hidden_size, size_t intermediate_size, bool bias, const infinicore::DataType &dtype, const infinicore::Device &device, - engine::distributed::RankInfo rank_info) - : 
GateUpParallelLinear(hidden_size, intermediate_size, bias, bias, dtype, device, rank_info) { + engine::distributed::RankInfo rank_info, + std::optional quant_scheme) + : GateUpParallelLinear(hidden_size, intermediate_size, bias, bias, dtype, device, rank_info, quant_scheme) { } GateUpParallelLinear::GateUpParallelLinear(size_t hidden_size, size_t intermediate_size, bool gate_bias, bool up_bias, const infinicore::DataType &dtype, const infinicore::Device &device, - engine::distributed::RankInfo rank_info) - : infinicore::nn::ColumnParallelLinear(hidden_size, intermediate_size * 2, gate_bias || up_bias, dtype, device, rank_info.tp_rank, rank_info.tp_size), gate_bias_(gate_bias), up_bias_(up_bias) { + engine::distributed::RankInfo rank_info, + std::optional quant_scheme) + : infinicore::nn::ColumnParallelLinear(hidden_size, intermediate_size * 2, gate_bias || up_bias, dtype, device, rank_info.tp_rank, rank_info.tp_size, quant_scheme), gate_bias_(gate_bias), up_bias_(up_bias) { if (gate_bias_ != up_bias_) { throw std::runtime_error("Not supported yet: gate_bias and up_bias should be given at the same time"); } @@ -168,6 +191,14 @@ infinicore::nn::Parameter GateUpParallelLinear::get_up_bias() const { } } +infinicore::nn::Parameter GateUpParallelLinear::get_gate_weight_scale() const { + return infinicore::nn::Parameter(weight_scale_->narrow({{0, 0, weight_scale_->size(0) / 2}}), 0, tp_rank_, tp_size_); +} + +infinicore::nn::Parameter GateUpParallelLinear::get_up_weight_scale() const { + return infinicore::nn::Parameter(weight_scale_->narrow({{0, weight_scale_->size(0) / 2, weight_scale_->size(0) / 2}}), 0, tp_rank_, tp_size_); +} + bool GateUpParallelLinear::has_gate_bias() const { return gate_bias_; } diff --git a/csrc/layers/fused_linear.hpp b/csrc/layers/fused_linear.hpp index 1e32ce50..f3d95bae 100644 --- a/csrc/layers/fused_linear.hpp +++ b/csrc/layers/fused_linear.hpp @@ -1,5 +1,6 @@ #pragma once #include "infinicore/nn/linear.hpp" +#include 
"infinicore/nn/quantization.hpp" #include "../engine/distributed/communication_group.hpp" @@ -12,7 +13,8 @@ class QKVParallelLinear : public infinicore::nn::ColumnParallelLinear { bool q_bias, bool k_bias, bool v_bias, const infinicore::DataType &dtype = infinicore::DataType::F32, const infinicore::Device &device = infinicore::Device(), - engine::distributed::RankInfo rank_info = engine::distributed::RankInfo()); + engine::distributed::RankInfo rank_info = engine::distributed::RankInfo(), + std::optional quant_scheme = std::nullopt); // A more common case where all heads have the same dimension explicit QKVParallelLinear(size_t hidden_size, @@ -21,7 +23,8 @@ class QKVParallelLinear : public infinicore::nn::ColumnParallelLinear { bool bias = false, const infinicore::DataType &dtype = infinicore::DataType::F32, const infinicore::Device &device = infinicore::Device(), - engine::distributed::RankInfo rank_info = engine::distributed::RankInfo()); + engine::distributed::RankInfo rank_info = engine::distributed::RankInfo(), + std::optional quant_scheme = std::nullopt); std::tuple forward_split(infinicore::Tensor &input); @@ -30,6 +33,10 @@ class QKVParallelLinear : public infinicore::nn::ColumnParallelLinear { infinicore::nn::Parameter get_k_weight() const; infinicore::nn::Parameter get_v_weight() const; + infinicore::nn::Parameter get_q_weight_scale() const; + infinicore::nn::Parameter get_k_weight_scale() const; + infinicore::nn::Parameter get_v_weight_scale() const; + infinicore::nn::Parameter get_q_bias() const; infinicore::nn::Parameter get_k_bias() const; infinicore::nn::Parameter get_v_bias() const; @@ -57,20 +64,26 @@ class GateUpParallelLinear : public infinicore::nn::ColumnParallelLinear { public: GateUpParallelLinear(size_t hidden_size, size_t intermediate_size, bool bias = false, const infinicore::DataType &dtype = infinicore::DataType::F32, const infinicore::Device &device = infinicore::Device(), - engine::distributed::RankInfo rank_info = 
engine::distributed::RankInfo()); + engine::distributed::RankInfo rank_info = engine::distributed::RankInfo(), + std::optional quant_scheme = std::nullopt); GateUpParallelLinear(size_t hidden_size, size_t intermediate_size, bool gate_bias, bool up_bias, const infinicore::DataType &dtype = infinicore::DataType::F32, const infinicore::Device &device = infinicore::Device(), - engine::distributed::RankInfo rank_info = engine::distributed::RankInfo()); + engine::distributed::RankInfo rank_info = engine::distributed::RankInfo(), + std::optional quant_scheme = std::nullopt); std::tuple forward_split(infinicore::Tensor &input); infinicore::nn::Parameter get_gate_weight() const; + infinicore::nn::Parameter get_gate_weight_scale() const; + infinicore::nn::Parameter get_gate_bias() const; infinicore::nn::Parameter get_up_weight() const; + infinicore::nn::Parameter get_up_weight_scale() const; + infinicore::nn::Parameter get_up_bias() const; bool has_gate_bias() const; @@ -103,4 +116,39 @@ class GateUpParallelLinear : public infinicore::nn::ColumnParallelLinear { if (name##_->has_up_bias()) \ this->register_parameter(std::string(up_name) + ".bias", name##_->get_up_bias()); +// ========================= QKV Quantization ================================== +#define INFINILM_QKV_LINEAR_W8A8_INIT(name, q_name, k_name, v_name, ...) 
\ + name##_ = std::make_shared(__VA_ARGS__); \ + /* 注册 Q 权重 */ \ + this->register_parameter(std::string(q_name) + ".weight", name##_->get_q_weight()); \ + this->register_parameter(std::string(q_name) + ".weight_scale", name##_->get_q_weight_scale()); \ + /* 注册 K 权重 */ \ + this->register_parameter(std::string(k_name) + ".weight", name##_->get_k_weight()); \ + this->register_parameter(std::string(k_name) + ".weight_scale", name##_->get_k_weight_scale()); \ + /* 注册 V 权重 */ \ + this->register_parameter(std::string(v_name) + ".weight", name##_->get_v_weight()); \ + this->register_parameter(std::string(v_name) + ".weight_scale", name##_->get_v_weight_scale()); \ + /* bias 保持原样 */ \ + if (name##_->has_q_bias()) \ + this->register_parameter(std::string(q_name) + ".bias", name##_->get_q_bias()); \ + if (name##_->has_k_bias()) \ + this->register_parameter(std::string(k_name) + ".bias", name##_->get_k_bias()); \ + if (name##_->has_v_bias()) \ + this->register_parameter(std::string(v_name) + ".bias", name##_->get_v_bias()); + +// ========================= Gate-Up Quantization ============================== +#define INFINILM_GATE_UP_LINEAR_W8A8_INIT(name, gate_name, up_name, ...) 
\ + name##_ = std::make_shared(__VA_ARGS__); \ + /* 注册 Gate 权重 */ \ + this->register_parameter(std::string(gate_name) + ".weight", name##_->get_gate_weight()); \ + this->register_parameter(std::string(gate_name) + ".weight_scale", name##_->get_gate_weight_scale()); \ + /* 注册 Up 权重 */ \ + this->register_parameter(std::string(up_name) + ".weight", name##_->get_up_weight()); \ + this->register_parameter(std::string(up_name) + ".weight_scale", name##_->get_up_weight_scale()); \ + /* bias 保持原样 */ \ + if (name##_->has_gate_bias()) \ + this->register_parameter(std::string(gate_name) + ".bias", name##_->get_gate_bias()); \ + if (name##_->has_up_bias()) \ + this->register_parameter(std::string(up_name) + ".bias", name##_->get_up_bias()); + } // namespace infinilm::layers diff --git a/csrc/models/infinilm_model.hpp b/csrc/models/infinilm_model.hpp index 4cad3b6c..fcad67fc 100644 --- a/csrc/models/infinilm_model.hpp +++ b/csrc/models/infinilm_model.hpp @@ -1,8 +1,8 @@ #pragma once -#include "infinicore/nn/module.hpp" - #include "../cache/cache.hpp" +#include "infinicore/nn/module.hpp" +#include "nlohmann/json.hpp" #include @@ -13,7 +13,6 @@ class InfinilmModel : public infinicore::nn::Module { public: struct Config { std::string model_type; - virtual ~Config() = default; }; diff --git a/csrc/models/llama/llama.hpp b/csrc/models/llama/llama.hpp index fe554c32..eebac92b 100644 --- a/csrc/models/llama/llama.hpp +++ b/csrc/models/llama/llama.hpp @@ -16,9 +16,10 @@ * - LlamaForCausalLM: Complete model with language modeling head */ -#include "llama_config.hpp" +#include "../../config/global_config.hpp" #include "llama_attention.hpp" -#include "llama_mlp.hpp" +#include "llama_config.hpp" #include "llama_decoder_layer.hpp" -#include "llama_model.hpp" #include "llama_for_causal_lm.hpp" +#include "llama_mlp.hpp" +#include "llama_model.hpp" diff --git a/csrc/models/llama/llama_attention.cpp b/csrc/models/llama/llama_attention.cpp index c78040e2..6ca77034 100644 --- 
a/csrc/models/llama/llama_attention.cpp +++ b/csrc/models/llama/llama_attention.cpp @@ -17,27 +17,28 @@ namespace infinilm::models::llama { -LlamaAttention::LlamaAttention(const LlamaConfig &config, - const infinicore::Device &device, +LlamaAttention::LlamaAttention(const infinicore::Device &device, size_t layer_idx, - engine::distributed::RankInfo rank_info) + engine::distributed::RankInfo rank_info, + std::shared_ptr global_config) : layer_idx_(layer_idx), - hidden_size_(config.hidden_size), - num_attention_heads_(config.num_attention_heads), - num_key_value_heads_(config.num_key_value_heads), - head_dim_(config.head_dim), - kv_dim_(config.kv_dim()), - use_bias_(config.attention_bias), - use_output_bias_(config.attention_output_bias), - use_qk_norm_(config.qk_norm), - max_position_embeddings_(config.max_position_embeddings), rank_info_(rank_info) { - const auto &dtype{config.dtype}; + hidden_size_(global_config->get("hidden_size")), + num_attention_heads_(global_config->get("num_attention_heads")), + num_key_value_heads_(global_config->get("num_key_value_heads")), + head_dim_(global_config->get_head_dim()), + kv_dim_(global_config->get_kv_dim()), + use_bias_(global_config->get_or("attention_bias", true)), + use_output_bias_(global_config->get_or("attention_output_bias", false)), + max_position_embeddings_(global_config->get("max_position_embeddings")), + rank_info_(rank_info), + global_config_(global_config) { + const auto &dtype{global_config_->get_dtype()}; int tp_rank = rank_info.tp_rank; int tp_size = rank_info.tp_size; - int num_attention_heads = config.num_attention_heads; - int num_key_value_heads = config.num_key_value_heads; + int num_attention_heads = global_config_->get("num_attention_heads"); + int num_key_value_heads = global_config_->get("num_key_value_heads"); if ((num_key_value_heads >= tp_size) && (0 == (num_key_value_heads % tp_size))) { this->num_attention_heads_ = num_attention_heads / tp_size; @@ -47,17 +48,29 @@ 
LlamaAttention::LlamaAttention(const LlamaConfig &config, } scaling_ = 1.0f / std::sqrt(static_cast(head_dim_)); - // Initialize projection layers - INFINILM_QKV_LINEAR_INIT(qkv_proj, "q_proj", "k_proj", "v_proj", hidden_size_, head_dim_, config.num_attention_heads, config.num_key_value_heads, use_bias_, - dtype, device, rank_info); - // Output projection uses attention_output_bias (can be different from qkv) - INFINICORE_NN_MODULE_INIT(o_proj, num_attention_heads * head_dim_, hidden_size_, use_output_bias_, - dtype, device, tp_rank, tp_size, rank_info.comm); - - // Initialize qk RMSNorm - if (use_qk_norm_) { - INFINICORE_NN_MODULE_INIT(q_norm, head_dim_, config.rms_norm_eps, dtype, device); - INFINICORE_NN_MODULE_INIT(k_norm, head_dim_, config.rms_norm_eps, dtype, device); + auto quant_scheme = this->global_config_->get_quant_scheme(); + switch (quant_scheme) { + case infinicore::nn::QuantScheme::COMPRESSED_TENSOR_W8A8I8: + INFINILM_QKV_LINEAR_W8A8_INIT(qkv_proj, "q_proj", "k_proj", "v_proj", hidden_size_, head_dim_, global_config_->get("num_attention_heads"), global_config_->get("num_key_value_heads"), use_bias_, + dtype, device, rank_info, quant_scheme); + + // INFINICORE_NN_MODULE_INIT(o_proj, hidden_size_, hidden_size_, use_output_bias_, + // dtype, device, tp_rank, tp_size, rank_info.comm, quant_scheme); + INFINICORE_NN_MODULE_INIT(o_proj, global_config_->get("num_attention_heads") * head_dim_, hidden_size_, use_output_bias_, + dtype, device, tp_rank, tp_size, rank_info.comm, quant_scheme); + break; + + default: + INFINILM_QKV_LINEAR_INIT(qkv_proj, "q_proj", "k_proj", "v_proj", hidden_size_, head_dim_, global_config_->get("num_attention_heads"), global_config_->get("num_key_value_heads"), use_bias_, + dtype, device, rank_info); + + INFINICORE_NN_MODULE_INIT(o_proj, global_config_->get("num_attention_heads") * head_dim_, hidden_size_, use_output_bias_, + dtype, device, tp_rank, tp_size, rank_info.comm); + break; + } + if (global_config_->get("model_type") == 
"qwen3") { + INFINICORE_NN_MODULE_INIT(q_norm, head_dim_, global_config_->get("rms_norm_eps"), dtype, device); + INFINICORE_NN_MODULE_INIT(k_norm, head_dim_, global_config_->get("rms_norm_eps"), dtype, device); } } @@ -75,7 +88,7 @@ infinicore::Tensor LlamaAttention::forward_(const infinicore::Tensor &hidden_sta // 1. Project Q, K, V auto [q, k, v] = qkv_proj_->forward_split(hidden_states_mutable); - if (use_qk_norm_) { + if (global_config_->get("model_type") == "qwen3") { q = q_norm_->forward(q->view({batch_size * seq_len, num_attention_heads_, head_dim_})); k = k_norm_->forward(k->view({batch_size * seq_len, num_key_value_heads_, head_dim_})); } @@ -184,7 +197,7 @@ infinicore::Tensor LlamaAttention::forward_paged_(const infinicore::Tensor &hidd auto k_reshaped = k->view({seq_len, num_key_value_heads_, head_dim_}); auto v_reshaped = v->view({seq_len, num_key_value_heads_, head_dim_}); - if (use_qk_norm_) { + if (global_config_->get("model_type") == "qwen3") { q_reshaped = q_norm_->forward(q_reshaped); k_reshaped = k_norm_->forward(k_reshaped); } diff --git a/csrc/models/llama/llama_attention.hpp b/csrc/models/llama/llama_attention.hpp index 9d464bcf..17f6f95e 100644 --- a/csrc/models/llama/llama_attention.hpp +++ b/csrc/models/llama/llama_attention.hpp @@ -1,6 +1,7 @@ #pragma once #include "../../cache/kv_cache.hpp" +#include "../../config/global_config.hpp" #include "../../engine/distributed/distributed.hpp" #include "../../layers/fused_linear.hpp" #include "llama_config.hpp" @@ -36,10 +37,10 @@ class LlamaAttention : public infinicore::nn::Module { * @param layer_idx Layer index for cache access * @param dtype Optional data type for model parameters (defaults to F32) */ - LlamaAttention(const LlamaConfig &config, - const infinicore::Device &device, + LlamaAttention(const infinicore::Device &device, size_t layer_idx, - engine::distributed::RankInfo rank_info = engine::distributed::RankInfo()); + engine::distributed::RankInfo rank_info = 
engine::distributed::RankInfo(), + std::shared_ptr global_config = nullptr); /** * @brief Forward pass: compute attention @@ -109,10 +110,10 @@ class LlamaAttention : public infinicore::nn::Module { size_t kv_dim_; bool use_bias_; // Bias for Q/K/V projections bool use_output_bias_; // Bias for output projection (o_proj) - bool use_qk_norm_; // Whether to use QK RMSNorm size_t max_position_embeddings_; // For cache initialization (deprecated, kept for compatibility) float scaling_; + std::shared_ptr global_config_; }; } // namespace infinilm::models::llama diff --git a/csrc/models/llama/llama_config.hpp b/csrc/models/llama/llama_config.hpp index 59108546..fe5ba7e9 100644 --- a/csrc/models/llama/llama_config.hpp +++ b/csrc/models/llama/llama_config.hpp @@ -6,6 +6,8 @@ #include #include "../infinilm_model.hpp" +#include "infinicore/nn/quantization.hpp" +#include "nlohmann/json.hpp" #include @@ -70,7 +72,8 @@ struct LlamaConfig : public InfinilmModel::Config { * @brief Compute key-value dimension for Grouped Query Attention (GQA) * @return The dimension for key/value projections */ - size_t kv_dim() const { + size_t + kv_dim() const { return hidden_size * num_key_value_heads / num_attention_heads; } diff --git a/csrc/models/llama/llama_decoder_layer.cpp b/csrc/models/llama/llama_decoder_layer.cpp index 35a1acab..332095c6 100644 --- a/csrc/models/llama/llama_decoder_layer.cpp +++ b/csrc/models/llama/llama_decoder_layer.cpp @@ -6,21 +6,21 @@ namespace infinilm::models::llama { -LlamaDecoderLayer::LlamaDecoderLayer(const LlamaConfig &config, - const infinicore::Device &device, +LlamaDecoderLayer::LlamaDecoderLayer(const infinicore::Device &device, size_t layer_idx, - engine::distributed::RankInfo rank_info) : layer_idx_(layer_idx), rank_info_(rank_info) { - const auto &dtype{config.dtype}; + engine::distributed::RankInfo rank_info, + std::shared_ptr global_config) : layer_idx_(layer_idx), rank_info_(rank_info), global_config_(global_config) { + const auto 
&dtype{global_config_->get_dtype()}; // Initialize layer normalization layers - INFINICORE_NN_MODULE_INIT(input_layernorm, config.hidden_size, config.rms_norm_eps, + INFINICORE_NN_MODULE_INIT(input_layernorm, global_config_->get("hidden_size"), global_config_->get("rms_norm_eps"), dtype, device); - INFINICORE_NN_MODULE_INIT(post_attention_layernorm, config.hidden_size, config.rms_norm_eps, + INFINICORE_NN_MODULE_INIT(post_attention_layernorm, global_config_->get("hidden_size"), global_config_->get("rms_norm_eps"), dtype, device); // Initialize attention and MLP modules - INFINICORE_NN_MODULE_INIT(self_attn, config, device, layer_idx, rank_info_); - INFINICORE_NN_MODULE_INIT(mlp, config, device, rank_info_); + INFINICORE_NN_MODULE_INIT(self_attn, device, layer_idx, rank_info_, global_config); + INFINICORE_NN_MODULE_INIT(mlp, device, rank_info_, global_config); } infinicore::Tensor LlamaDecoderLayer::forward(const infinicore::Tensor &hidden_states, diff --git a/csrc/models/llama/llama_decoder_layer.hpp b/csrc/models/llama/llama_decoder_layer.hpp index 4ded50a7..e377d2ae 100644 --- a/csrc/models/llama/llama_decoder_layer.hpp +++ b/csrc/models/llama/llama_decoder_layer.hpp @@ -33,10 +33,10 @@ class LlamaDecoderLayer : public infinicore::nn::Module { * @param layer_idx Layer index for cache management and debugging * @param dtype Optional data type for model parameters (defaults to F32) */ - LlamaDecoderLayer(const LlamaConfig &config, - const infinicore::Device &device, + LlamaDecoderLayer(const infinicore::Device &device, size_t layer_idx, - engine::distributed::RankInfo rank_info = engine::distributed::RankInfo()); + engine::distributed::RankInfo rank_info = engine::distributed::RankInfo(), + std::shared_ptr global_config = nullptr); /** * @brief Forward pass: process one decoder layer @@ -75,6 +75,7 @@ class LlamaDecoderLayer : public infinicore::nn::Module { INFINICORE_NN_MODULE(LlamaAttention, self_attn); INFINICORE_NN_MODULE(LlamaMLP, mlp); 
engine::distributed::RankInfo rank_info_; + std::shared_ptr global_config_; private: size_t layer_idx_; // Layer index for cache management and debugging diff --git a/csrc/models/llama/llama_for_causal_lm.cpp b/csrc/models/llama/llama_for_causal_lm.cpp index 6ce1fd98..e63a9b22 100644 --- a/csrc/models/llama/llama_for_causal_lm.cpp +++ b/csrc/models/llama/llama_for_causal_lm.cpp @@ -6,22 +6,23 @@ namespace infinilm::models::llama { -LlamaForCausalLM::LlamaForCausalLM(const LlamaConfig &config, - const infinicore::Device &device, - engine::distributed::RankInfo rank_info) { +LlamaForCausalLM::LlamaForCausalLM(const infinicore::Device &device, + engine::distributed::RankInfo rank_info, + std::shared_ptr global_config) { // Initialize module's device_ member device_ = device; - const auto &dtype{config.dtype}; + const auto &dtype{global_config->get_dtype()}; // Initialize base model - INFINICORE_NN_MODULE_INIT(model, config, device, rank_info); + INFINICORE_NN_MODULE_INIT(model, device, rank_info, global_config); // Initialize language modeling head // Note: If tie_word_embeddings is true, we would share weights with embed_tokens // For now, we create a separate linear layer - INFINICORE_NN_MODULE_INIT(lm_head, config.hidden_size, config.vocab_size, false, + + INFINICORE_NN_MODULE_INIT(lm_head, global_config->get("hidden_size"), global_config->get("vocab_size"), false, dtype, device); } diff --git a/csrc/models/llama/llama_for_causal_lm.hpp b/csrc/models/llama/llama_for_causal_lm.hpp index dd6f90fa..b7eef806 100644 --- a/csrc/models/llama/llama_for_causal_lm.hpp +++ b/csrc/models/llama/llama_for_causal_lm.hpp @@ -28,9 +28,9 @@ class LlamaForCausalLM : public InfinilmModel { * @param config Model configuration * @param device Device to create tensors on */ - LlamaForCausalLM(const LlamaConfig &config, - const infinicore::Device &device, - engine::distributed::RankInfo rank_info = engine::distributed::RankInfo()); + LlamaForCausalLM(const infinicore::Device &device, + 
engine::distributed::RankInfo rank_info = engine::distributed::RankInfo(), + std::shared_ptr global_config = nullptr); /** * @brief Forward pass: compute language modeling logits @@ -43,7 +43,7 @@ class LlamaForCausalLM : public InfinilmModel { void reset_cache(const cache::CacheConfig *cache_config) override; // Module information - const LlamaConfig &config() const { return model_->config(); } + // const LlamaConfig &config() const { return model_->config(); } LlamaModel &model() { return *model_; } const LlamaModel &model() const { return *model_; } diff --git a/csrc/models/llama/llama_mlp.cpp b/csrc/models/llama/llama_mlp.cpp index fc7abd69..1f4ee436 100644 --- a/csrc/models/llama/llama_mlp.cpp +++ b/csrc/models/llama/llama_mlp.cpp @@ -1,25 +1,39 @@ #include "llama_mlp.hpp" #include "infinicore/nn/linear.hpp" #include "infinicore/ops.hpp" +#include namespace infinilm::models::llama { -LlamaMLP::LlamaMLP(const LlamaConfig &config, - const infinicore::Device &device, - engine::distributed::RankInfo rank_info) - : hidden_size_(config.hidden_size), - intermediate_size_(config.intermediate_size), - use_bias_(config.mlp_bias), rank_info_(rank_info) { - const auto &dtype{config.dtype}; +LlamaMLP::LlamaMLP(const infinicore::Device &device, + engine::distributed::RankInfo rank_info, + std::shared_ptr global_config) + : hidden_size_(global_config->get("hidden_size")), + intermediate_size_(global_config->get("intermediate_size")), + use_bias_(global_config->get_or("mlp_bias", false)), rank_info_(rank_info), global_config_(global_config) { + const auto &dtype{global_config_->get_dtype()}; int tp_rank = rank_info.tp_rank; int tp_size = rank_info.tp_size; // Initialize projection layers - INFINILM_GATE_UP_LINEAR_INIT(gate_up_proj, "gate_proj", "up_proj", hidden_size_, intermediate_size_, use_bias_, - dtype, device, rank_info_); - INFINICORE_NN_MODULE_INIT(down_proj, intermediate_size_, hidden_size_, use_bias_, - dtype, device, tp_rank, tp_size, rank_info.comm); + auto 
quant_scheme = this->global_config_->get_quant_scheme(); + // std::cout << "LlamaMLP quant_scheme: " << static_cast(quant_scheme) << std::endl; + switch (quant_scheme) { + case infinicore::nn::QuantScheme::COMPRESSED_TENSOR_W8A8I8: + INFINILM_GATE_UP_LINEAR_W8A8_INIT(gate_up_proj, "gate_proj", "up_proj", hidden_size_, intermediate_size_, use_bias_, + dtype, device, rank_info_, quant_scheme); + INFINICORE_NN_MODULE_INIT(down_proj, intermediate_size_, hidden_size_, use_bias_, + dtype, device, tp_rank, tp_size, rank_info.comm, quant_scheme); + break; + + default: + INFINILM_GATE_UP_LINEAR_INIT(gate_up_proj, "gate_proj", "up_proj", hidden_size_, intermediate_size_, use_bias_, + dtype, device, rank_info_); + INFINICORE_NN_MODULE_INIT(down_proj, intermediate_size_, hidden_size_, use_bias_, + dtype, device, tp_rank, tp_size, rank_info.comm); + break; + } } infinicore::Tensor LlamaMLP::forward(const infinicore::Tensor &hidden_states) const { diff --git a/csrc/models/llama/llama_mlp.hpp b/csrc/models/llama/llama_mlp.hpp index 665dac70..38249cb3 100644 --- a/csrc/models/llama/llama_mlp.hpp +++ b/csrc/models/llama/llama_mlp.hpp @@ -3,6 +3,7 @@ #include "../../layers/fused_linear.hpp" #include "llama_config.hpp" +#include "../../config/global_config.hpp" #include "infinicore/device.hpp" #include "infinicore/nn/linear.hpp" #include "infinicore/nn/module.hpp" @@ -33,9 +34,9 @@ class LlamaMLP : public infinicore::nn::Module { * @param device Device to create tensors on * @param dtype Optional data type for model parameters (defaults to F32) */ - LlamaMLP(const LlamaConfig &config, - const infinicore::Device &device, - engine::distributed::RankInfo rank_info = engine::distributed::RankInfo()); + LlamaMLP(const infinicore::Device &device, + engine::distributed::RankInfo rank_info = engine::distributed::RankInfo(), + std::shared_ptr global_config = nullptr); /** * @brief Forward pass: compute MLP output @@ -57,6 +58,8 @@ class LlamaMLP : public infinicore::nn::Module { size_t 
hidden_size_; size_t intermediate_size_; bool use_bias_; + + std::shared_ptr global_config_; }; } // namespace infinilm::models::llama diff --git a/csrc/models/llama/llama_model.cpp b/csrc/models/llama/llama_model.cpp index 34c3c0b2..9baeb454 100644 --- a/csrc/models/llama/llama_model.cpp +++ b/csrc/models/llama/llama_model.cpp @@ -7,34 +7,33 @@ namespace infinilm::models::llama { -LlamaModel::LlamaModel(const LlamaConfig &config, - const infinicore::Device &device, - engine::distributed::RankInfo rank_info) - : config_(config), rank_info_(rank_info) { - const auto &dtype{config.dtype}; +LlamaModel::LlamaModel(const infinicore::Device &device, + engine::distributed::RankInfo rank_info, + std::shared_ptr global_config) + : rank_info_(rank_info), global_config_(global_config) { + const auto &dtype{global_config_->get_dtype()}; // Initialize token embeddings - INFINICORE_NN_MODULE_INIT(embed_tokens, config.vocab_size, config.hidden_size, + INFINICORE_NN_MODULE_INIT(embed_tokens, global_config_->get("vocab_size"), global_config_->get("hidden_size"), std::nullopt, dtype, device); - // Initialize decoder layers with layer indices // TODO: Update INFINICORE_NN_MODULE_VEC_INIT macro to support per-layer constructor arguments // (e.g., via a factory function or lambda that receives the layer index) // Currently, we can't use the macro because each layer needs a different layer_idx - layers_.reserve(config.num_hidden_layers); - for (size_t i = 0; i < config.num_hidden_layers; ++i) { + layers_.reserve(global_config_->get("num_hidden_layers")); + for (size_t i = 0; i < global_config_->get("num_hidden_layers"); ++i) { layers_.push_back(this->register_module( - "layers." + std::to_string(i), config, device, i, rank_info)); + "layers." 
+ std::to_string(i), device, i, rank_info, global_config_)); } // Initialize final layer normalization - INFINICORE_NN_MODULE_INIT(norm, config.hidden_size, config.rms_norm_eps, + INFINICORE_NN_MODULE_INIT(norm, global_config_->get("hidden_size"), global_config_->get("rms_norm_eps"), dtype, device); // Initialize Rotary Position Embeddings (shared across all layers) // Use GPT-J-style inverse frequencies (default) and GPT_NEOX rotation pairing - INFINICORE_NN_MODULE_INIT(rotary_emb, config.head_dim, config.max_position_embeddings, - config.rope_theta, infinicore::nn::RoPE::Algo::GPT_NEOX, - dtype, device, config.rope_scaling); + INFINICORE_NN_MODULE_INIT(rotary_emb, global_config_->get_head_dim(), global_config_->get("max_position_embeddings"), + global_config_->get("rope_theta"), infinicore::nn::RoPE::Algo::GPT_NEOX, + dtype, device, global_config_->get_rope_scaling()); for (auto &layer : layers_) { if (layer) { @@ -69,24 +68,23 @@ void LlamaModel::reset_cache(const cache::CacheConfig *cache_config) { } if (auto kv_cache_config = dynamic_cast(cache_config)) { kv_cache_ = std::make_shared( - config_.head_dim, - config_.head_dim, - config_.num_key_value_heads, - config_.num_key_value_heads, - config_.num_hidden_layers, - config_.max_position_embeddings, - config_.dtype, + global_config_->get_head_dim(), + global_config_->get_head_dim(), + global_config_->get("num_key_value_heads"), + global_config_->get("num_key_value_heads"), + global_config_->get("num_hidden_layers"), + global_config_->get("max_position_embeddings"), + global_config_->get_dtype(), *kv_cache_config, rank_info_); - } else if (auto paged_kv_cache_config = dynamic_cast(cache_config)) { kv_cache_ = std::make_shared( - config_.head_dim, - config_.head_dim, - config_.num_key_value_heads, - config_.num_key_value_heads, - config_.num_hidden_layers, - config_.dtype, + global_config_->get_head_dim(), + global_config_->get_head_dim(), + global_config_->get("num_key_value_heads"), + 
global_config_->get("num_key_value_heads"), + global_config_->get("num_hidden_layers"), + global_config_->get_dtype(), *paged_kv_cache_config, rank_info_); } else { diff --git a/csrc/models/llama/llama_model.hpp b/csrc/models/llama/llama_model.hpp index 5a008b0f..422c1bd6 100644 --- a/csrc/models/llama/llama_model.hpp +++ b/csrc/models/llama/llama_model.hpp @@ -1,7 +1,6 @@ #pragma once #include "../../cache/kv_cache.hpp" -#include "llama_config.hpp" #include "llama_decoder_layer.hpp" #include "infinicore/nn/embedding.hpp" @@ -38,9 +37,9 @@ class LlamaModel : public infinicore::nn::Module { * @param device Device to create tensors on * @param dtype Optional data type for model parameters (defaults to F32) */ - LlamaModel(const LlamaConfig &config, - const infinicore::Device &device, - engine::distributed::RankInfo rank_info = engine::distributed::RankInfo()); + LlamaModel(const infinicore::Device &device, + engine::distributed::RankInfo rank_info = engine::distributed::RankInfo(), + std::shared_ptr global_config = nullptr); /** * @brief Forward pass: process input through the model @@ -64,8 +63,7 @@ class LlamaModel : public infinicore::nn::Module { void reset_cache(const cache::CacheConfig *cache_config); // Module information - const LlamaConfig &config() const { return config_; } - size_t num_layers() const { return config_.num_hidden_layers; } + size_t num_layers() const { return global_config_->get("num_hidden_layers"); } protected: // Token embeddings @@ -85,7 +83,7 @@ class LlamaModel : public infinicore::nn::Module { std::shared_ptr kv_cache_; private: - LlamaConfig config_; + std::shared_ptr global_config_; }; } // namespace infinilm::models::llama diff --git a/csrc/models/model_factory.cpp b/csrc/models/model_factory.cpp index 999bb364..b4fd634a 100644 --- a/csrc/models/model_factory.cpp +++ b/csrc/models/model_factory.cpp @@ -3,15 +3,16 @@ namespace infinilm { std::shared_ptr InfinilmModelFactory::createModel( - const InfinilmModel::Config &config, 
engine::distributed::RankInfo rank_info, - const cache::CacheConfig *cache) { + const cache::CacheConfig *cache, + std::shared_ptr global_config) { std::shared_ptr model; - if (const auto llama_config_ptr = dynamic_cast(&config)) { - const auto &llama_config = *llama_config_ptr; + //****************************NEED TO BE FIXED */ + if (true) { + // const auto &llama_config = *llama_config_ptr; model = std::make_shared( - llama_config, rank_info.device, rank_info); + rank_info.device, rank_info, global_config); } else { throw std::invalid_argument("InfinilmModelFactory::createModel: Unsupported model config type"); } diff --git a/csrc/models/model_factory.hpp b/csrc/models/model_factory.hpp index a73f432c..c020f6a5 100644 --- a/csrc/models/model_factory.hpp +++ b/csrc/models/model_factory.hpp @@ -1,5 +1,6 @@ #pragma once +#include "../config/global_config.hpp" #include "infinilm_model.hpp" #include "../engine/distributed/distributed.hpp" @@ -8,8 +9,9 @@ namespace infinilm { class InfinilmModelFactory { public: static std::shared_ptr createModel( - const InfinilmModel::Config &config, + // const InfinilmModel::Config &config, engine::distributed::RankInfo rank_info = engine::distributed::RankInfo(), - const cache::CacheConfig *cache = nullptr); + const cache::CacheConfig *cache = nullptr, + std::shared_ptr global_config = nullptr); }; } // namespace infinilm diff --git a/csrc/pybind11/engine/engine.hpp b/csrc/pybind11/engine/engine.hpp index 5ac38d70..8d610c61 100644 --- a/csrc/pybind11/engine/engine.hpp +++ b/csrc/pybind11/engine/engine.hpp @@ -32,20 +32,23 @@ inline void bind_infer_engine(py::module &m) { py::class_> infer_engine(m, "InferEngine"); infer_engine .def(py::init([]( - const InfinilmModel::Config &cfg, + // const InfinilmModel::Config &cfg, const distributed::DistConfig &dist, infinicore::Device::Type dev, - std::shared_ptr cache_cfg) { + std::shared_ptr cache_cfg, + const std::string &modle_path) { return std::make_shared( - cfg, + // cfg, dist, dev, - 
cache_cfg ? cache_cfg.get() : nullptr); + cache_cfg ? cache_cfg.get() : nullptr, + modle_path); }), - py::arg("config"), + // py::arg("config"), py::arg("distributed_config") = distributed::DistConfig(), py::arg("device_type") = infinicore::context::getDevice().getType(), - py::arg("cache_config") = py::none()) + py::arg("cache_config") = py::none(), + py::arg("model_path") = "") .def("load_param", &InferEngine::load_param, py::arg("name"), py::arg("param"), "Load a parameter tensor into all workers (each worker picks its shard)") @@ -60,20 +63,12 @@ inline void bind_infer_engine(py::module &m) { } return state_dict_tp_all; }) - .def( - "forward", [](InferEngine &self, const InferEngine::Input &input) -> InferEngine::Output { return self.forward(input); }, "Run inference on all ranks with arbitrary arguments") - .def( - "reset_cache", [](InferEngine &self, std::shared_ptr cfg) { - self.reset_cache(cfg ? cfg.get() : nullptr); - }, - py::arg("cache_config") = py::none()) + .def("forward", [](InferEngine &self, const InferEngine::Input &input) -> InferEngine::Output { return self.forward(input); }, "Run inference on all ranks with arbitrary arguments") + .def("reset_cache", [](InferEngine &self, std::shared_ptr cfg) { self.reset_cache(cfg ? 
cfg.get() : nullptr); }, py::arg("cache_config") = py::none()) .def("get_cache_config", [](const InferEngine &self) { auto cfg = self.get_cache_config(); - return std::shared_ptr(std::move(cfg->unique_copy())); - }) - .def("__repr__", [](const InferEngine &self) { - return ""; - }); + return std::shared_ptr(std::move(cfg->unique_copy())); }) + .def("__repr__", [](const InferEngine &self) { return ""; }); py::class_(infer_engine, "Input") .def( diff --git a/csrc/quantization/base_quantization.hpp b/csrc/quantization/base_quantization.hpp new file mode 100644 index 00000000..0d1f52ce --- /dev/null +++ b/csrc/quantization/base_quantization.hpp @@ -0,0 +1,18 @@ +#pragma once +#include "../config/quant_config.hpp" +#include "infinicore/nn/quantization.hpp" +#include "nlohmann/json.hpp" + +namespace infinilm::quantization { +class BaseQuantization { + // Base class for quantization schemes. Intended to be extended to support various quantization methods. +public: + explicit BaseQuantization(const nlohmann::json &quant_config) : quant_config_(quant_config) {}; + virtual ~BaseQuantization() = default; + + virtual infinicore::nn::QuantScheme get_quant_scheme() const = 0; + +protected: + nlohmann::json quant_config_; +}; +} // namespace infinilm::quantization diff --git a/csrc/quantization/compressed_tensors.hpp b/csrc/quantization/compressed_tensors.hpp new file mode 100644 index 00000000..f502f398 --- /dev/null +++ b/csrc/quantization/compressed_tensors.hpp @@ -0,0 +1,21 @@ +#pragma once + +#include "../config/quant_config.hpp" +#include "base_quantization.hpp" +namespace infinilm::quantization { + +class CompressedTensors : public BaseQuantization { + // This is a temporary class that currently only returns COMPRESSED_TENSOR_W8A8I8. + // Future enhancements should parse quant_config to extract detailed quantization + // information and support multiple quantization schemes. 
+public: + explicit CompressedTensors(const nlohmann::json &quant_config) + : BaseQuantization(quant_config) {}; + + infinicore::nn::QuantScheme + get_quant_scheme() const override { + return infinicore::nn::QuantScheme::COMPRESSED_TENSOR_W8A8I8; + }; +}; + +} // namespace infinilm::quantization diff --git a/csrc/quantization/quantization.hpp b/csrc/quantization/quantization.hpp new file mode 100644 index 00000000..48b7646e --- /dev/null +++ b/csrc/quantization/quantization.hpp @@ -0,0 +1,5 @@ +#pragma once + +#include "base_quantization.hpp" +#include "compressed_tensors.hpp" +#include "infinicore/nn/quantization.hpp" diff --git a/examples/jiuge.py b/examples/jiuge.py index c1ad567e..9ea1019a 100644 --- a/examples/jiuge.py +++ b/examples/jiuge.py @@ -109,7 +109,6 @@ def test( device=infini_device, distributed_config=DistConfig(tp), ) - # ---------------------------------------------------------------------------- # # Load Weights # ---------------------------------------------------------------------------- # @@ -119,7 +118,6 @@ def test( # create tokenizer # ---------------------------------------------------------------------------- # tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) - if "llama" == model.config.model_type: backend = getattr(tokenizer, "backend_tokenizer", None) target = getattr(backend, "_tokenizer", backend) diff --git a/python/infinilm/infer_engine.py b/python/infinilm/infer_engine.py index 8d5ea985..94f02e30 100644 --- a/python/infinilm/infer_engine.py +++ b/python/infinilm/infer_engine.py @@ -35,10 +35,11 @@ def __init__( device = infinicore.device() super().__init__( - self.config, + # self.config, distributed_config._underlying, device._underlying.type, cache_config, + model_path, ) self.use_cache = False diff --git a/python/infinilm/modeling_utils.py b/python/infinilm/modeling_utils.py index 792aa503..d1b26dd9 100644 --- a/python/infinilm/modeling_utils.py +++ b/python/infinilm/modeling_utils.py @@ -75,7 +75,7 
@@ def load_state_dict( ) for k in f.keys(): - state_dict[k] = f.get_tensor(k).to(device=device, dtype=dtype) + state_dict[k] = f.get_tensor(k).to(device=device) return state_dict @@ -155,7 +155,6 @@ def load_model_state_dict_by_file( model_param_infini = {} for key in model_param.keys(): model_param_infini[key] = infinicore.from_torch(model_param[key]) - model.load_state_dict(model_param_infini, strict=False) infinicore.sync_device() @@ -168,7 +167,6 @@ def load_model_state_dict_by_file( model_param_infini[key] = infinicore.from_torch( model_params[key].to(dtype=torch_dtype) ) - already_loaded_keys.append(key) model.load_state_dict(model_param_infini, strict=True) diff --git a/python/infinilm/models/llama/configuration_llama.py b/python/infinilm/models/llama/configuration_llama.py index 15776c84..8d07a657 100644 --- a/python/infinilm/models/llama/configuration_llama.py +++ b/python/infinilm/models/llama/configuration_llama.py @@ -21,7 +21,6 @@ from ...configuration_utils import PretrainedConfig - class LlamaConfig(PretrainedConfig, _infinilm.LlamaConfig): r""" This is the configuration class to store the configuration of a [`LlamaModel`]. It is used to instantiate an LLaMA @@ -244,4 +243,4 @@ def __init__( eos_token_id=eos_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs, - ) + ) \ No newline at end of file diff --git a/third_party/json b/third_party/json new file mode 160000 index 00000000..5ed07097 --- /dev/null +++ b/third_party/json @@ -0,0 +1 @@ +Subproject commit 5ed07097faa6c50199c4a3b66e5ed37d4fbfccc2 diff --git a/xmake.lua b/xmake.lua index ad636197..aab1a0c7 100644 --- a/xmake.lua +++ b/xmake.lua @@ -6,6 +6,7 @@ set_toolchains("gcc") -- Add spdlog from third_party directory add_includedirs("third_party/spdlog/include") +add_includedirs("third_party/json/single_include/") target("infinicore_infer") set_kind("shared")