diff --git a/.clang-format b/.clang-format
deleted file mode 100644
index a77ae97c..00000000
--- a/.clang-format
+++ /dev/null
@@ -1,30 +0,0 @@
----
-BasedOnStyle: LLVM
-IndentWidth: 4 # Indentation width; the LLVM default is 2, changed to 4
-AccessModifierOffset: -4 # Offset of public/protected/private access specifiers relative to members, paired with IndentWidth; the LLVM default is -2
-AlignOperands: AlignAfterOperator # Cross-line alignment of binary-operator operands; the LLVM default is Align, changed to wrap together with the operator
-BreakBeforeBinaryOperators: All # Break before binary operators; the LLVM default is None, changed so that the operator (including assignment, =) always starts the new line when wrapping
-ColumnLimit: 0 # Column limit; the LLVM default is 80, changed to unlimited
-AllowShortBlocksOnASingleLine: Always # Whether short blocks (single-statement blocks) may stay on one line; the LLVM default is Never, changed to allow
-AllowShortLoopsOnASingleLine: true # Whether short loops may stay on one line; the LLVM default is false, changed to allow
-InsertBraces: true # Whether to insert braces after if/for/while/switch statements; the LLVM default is false, changed to allow
-BreakBeforeBraces: Custom # Brace-wrapping configuration; the LLVM default is LLVM, changed to Custom so that BraceWrapping takes effect
-BraceWrapping:
-  AfterCaseLabel: false
-  AfterClass: false
-  AfterControlStatement: Never
-  AfterEnum: false
-  AfterFunction: false
-  AfterNamespace: false
-  AfterObjCDeclaration: false
-  AfterStruct: false
-  AfterUnion: false
-  AfterExternBlock: false
-  BeforeCatch: false
-  BeforeElse: false
-  BeforeLambdaBody: false
-  BeforeWhile: false
-  IndentBraces: false
-  SplitEmptyFunction: true
-  SplitEmptyRecord: true
-  SplitEmptyNamespace: true
diff --git a/.gitignore b/.gitignore
index 0c9ef52c..80bc06ef 100644
--- a/.gitignore
+++ b/.gitignore
@@ -27,3 +27,9 @@
 cache/
 *.txt
 *.http
+
+#weight_analyse
+weight_analyse/
+
+#llada_debug
+llada_debug/
diff --git a/LLaDA_Implementation_Guide.md b/LLaDA_Implementation_Guide.md
new file mode 100644
index 00000000..c8c6a714
--- /dev/null
+++ b/LLaDA_Implementation_Guide.md
@@ -0,0 +1,204 @@
+# LLaDA Python Implementation
+
+This document shows how to use LLaDA (a masked diffusion language model) for text generation in Python.
+
+## Overview
+
+LLaDA is a masked-diffusion language model: it generates text by gradually recovering meaningful tokens from a fully masked sequence. This implementation contains the same core algorithms as the original LLaDA paper:
+
+- Mask-based diffusion generation process
+- Gumbel-noise sampling
+- Low-confidence remasking strategy
+- Block-wise semi-autoregressive generation
+- Unsupervised classifier-free guidance (CFG)
+
+## Main Features
+
+### 1. Basic generate function
+
+The `generate()` function provides LLaDA's core generation logic, but uses a placeholder model forward pass.
+
+```python
+model.generate(
+    prompts="Your prompt here",
+    max_steps=128,               # number of sampling steps
+    gen_length=128,              # generation length
+    block_length=128,            # block length
+    temperature_=0.0,            # sampling temperature
+    cfg_scale=0.0,               # CFG scale
+    remasking='low_confidence',  # remasking strategy
+    verbose=True                 # verbose output
+)
+```
+
+### 2. C++ model integration function
+
+The `generate_with_cpp_model()` function shows how to integrate with the C++ model interface:
+
+```python
+model.generate_with_cpp_model(
+    prompts="Your prompt here",
+    max_steps=128,
+    gen_length=128,
+    block_length=32,   # use block-wise generation
+    temperature_=0.5,
+    cfg_scale=1.0,     # enable CFG
+    verbose=True
+)
+```
+
+## Core Algorithms
+
+### 1. Masked diffusion process
+
+LLaDA starts from a fully masked sequence and recovers tokens step by step:
+
+```python
+# Initialize to an all-mask sequence
+x = torch.full((batch_size, prompt_length + gen_length), mask_id, dtype=torch.long)
+x[:, :prompt_length] = input_ids  # keep the prompt part
+```
+
+### 2. Gumbel-noise sampling
+
+Gumbel-Max sampling is used to sample from the categorical distribution:
+
+```python
+def add_gumbel_noise(logits, temperature):
+    if temperature == 0:
+        return logits
+    logits = logits.to(torch.float64)
+    noise = torch.rand_like(logits, dtype=torch.float64)
+    gumbel_noise = (- torch.log(noise)) ** temperature
+    return logits.exp() / gumbel_noise
+```
+
+### 3. Low-confidence remasking
+
+At each step, the lowest-confidence positions are re-masked; only the highest-confidence predictions are kept:
+
+```python
+if remasking == 'low_confidence':
+    p = F.softmax(logits, dim=-1)
+    x0_p = torch.squeeze(torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1)
+
+# Keep (transfer) the highest-confidence positions; everything else stays masked
+_, select_index = torch.topk(confidence[j], k=num_transfer_tokens[j, i])
+```
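+
+The snippet above relies on a per-step transfer budget `num_transfer_tokens`. The reference LLaDA implementation spreads the number of still-masked positions evenly over the remaining steps; a minimal sketch is shown below (the helper name `get_num_transfer_tokens` follows the reference code and is illustrative here, not part of this repository):
+
+```python
+import torch
+
+def get_num_transfer_tokens(mask_index, steps):
+    # mask_index: bool tensor [batch, seq_len], True where tokens are still masked
+    mask_num = mask_index.sum(dim=1, keepdim=True)   # masked tokens per sample
+    base = mask_num // steps                         # even share per step
+    remainder = mask_num % steps                     # leftover spread over the first steps
+    num_transfer_tokens = torch.zeros(mask_num.size(0), steps, dtype=torch.int64) + base
+    for i in range(mask_num.size(0)):
+        num_transfer_tokens[i, :remainder[i]] += 1
+    return num_transfer_tokens                       # [batch, steps]
+```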
+
+### 4. Block-wise generation
+
+Semi-autoregressive block-wise generation is supported, which makes generating long text more efficient:
+
+```python
+# Split the generation length into multiple blocks
+num_blocks = gen_length // block_length
+for num_block in range(num_blocks):
+    # Process the current block
+    block_mask_index = (x[:, block_start:block_end] == mask_id)
+```
+
+## Parameters
+
+- **prompts**: input prompt(s) (a string or a list of strings)
+- **max_steps**: number of sampling steps (recommended ≤ gen_length)
+- **gen_length**: generation length
+- **block_length**: block length (≤ gen_length; if smaller than gen_length, semi-autoregressive generation is used)
+- **temperature_**: sampling temperature (0 = deterministic, >0 = stochastic)
+- **cfg_scale**: unsupervised classifier-free guidance scale (0 = disabled, >0 = enabled)
+- **remasking**: remasking strategy ('low_confidence' or 'random')
+- **mask_id**: mask token ID (126336 for LLaDA)
+- **logits_eos_inf**: whether to set the EOS token logits to -inf
+- **confidence_eos_eot_inf**: whether to set the EOS and EOT token confidences to -inf
+
+## C++ Integration Guide
+
+The current implementation already integrates with the C++ model interface through the following components:
+
+### 1. InferTask and BatchedTask
+- `InferTask`: encapsulates the parameters and state of a single inference request
+- `LLaDABatchedTask`: batches multiple InferTask objects and converts them into the format required by the C++ interface
+
+### 2. C++ interface methods
+
+**batch_infer_one_round(tasks)**:
+- Input: a list of InferTask objects
+- Output: a list of generated token IDs
+- Used for sampled inference
+
+**forward_logits_batch(input_ids_tensor, attention_mask_tensor)**:
+- Input: token IDs as PyTorch tensors
+- Output: a logits tensor [batch_size, seq_len, vocab_size]
+- Used to obtain the full logits for LLaDA sampling
+
+### 3. Automatic error handling
+- Falls back to placeholder logits automatically when the C++ model call fails
+- Detailed debug output helps with troubleshooting
+
+### 4. Implemented C++ interface
+```python
+# Already defined in scripts/libinfinicore_infer/llada.py:
+- inferBatchLLaDA: batched sampling inference
+- forwardBatchLLaDA: batched logits computation
+```
+
+## Usage Examples
+
+```python
+# Import the LLaDA model
+from scripts.llada import LLaDAForCauslLM
+
+# Load the model
+model = LLaDAForCauslLM(
+    model_dir_path="/path/to/llada/model",
+    device=DeviceType.DEVICE_TYPE_CPU,
+    ndev=1
+)
+
+# Basic generation
+result = model.generate(
+    prompts="The future of AI is",
+    max_steps=64,
+    gen_length=64,
+    temperature_=0.0,
+    verbose=True
+)
+
+# Generation with the C++ model integration
+result = model.generate_with_cpp_model(
+    prompts="Explain quantum computing:",
+    max_steps=128,
+    gen_length=128,
+    block_length=32,
+    temperature_=0.7,
+    cfg_scale=1.0,
+    remasking='low_confidence',
+    verbose=True
+)
+```
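+
+The examples above enable guidance through `cfg_scale`, but the guide does not show how guidance is applied. In the reference LLaDA formulation, a second forward pass is run with the prompt replaced by mask tokens, and the two logit sets are combined. A minimal sketch, assuming the placeholder names `model` and `prompt_index` (a boolean mask over prompt positions), which are not defined in this guide:
+
+```python
+import torch
+
+def cfg_logits(model, x, prompt_index, mask_id, cfg_scale):
+    # Unconditional branch: the same sequence with the prompt masked out
+    un_x = x.clone()
+    un_x[prompt_index] = mask_id
+    logits = model(torch.cat([x, un_x], dim=0)).logits
+    cond, uncond = torch.chunk(logits, 2, dim=0)
+    # Push the conditional prediction away from the unconditional one
+    return uncond + (cfg_scale + 1) * (cond - uncond)
+```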
+
+## Notes
+
+1. **Model interface**: adapt the `cpp_model_forward` function to your specific C++ model interface
+2. **Memory management**: make sure PyTorch tensors and C++ memory stay correctly synchronized
+3. **Device compatibility**: make sure the PyTorch tensors and the C++ model are on the same device (CPU/GPU)
+4. **Performance**: for production use, consider batching and memory optimizations
+5. **Token ID**: make sure the correct mask token ID is used (126336 for LLaDA)
+
+## Testing
+
+Run the test entry point:
+
+```bash
+python scripts/llada.py
+```
+
+This executes several test cases covering basic generation, C++ integration, and the advanced parameters.
+
+## Extensions
+
+Possible additions on top of this implementation:
+- More remasking strategies
+- Different sampling methods
+- Batch-processing optimizations
+- More advanced control parameters
\ No newline at end of file
diff --git a/README.md b/README.md
index 3260b723..7774b313 100644
--- a/README.md
+++ b/README.md
@@ -1,53 +1 @@
-# InfiniLM
-
-This project is an inference engine built on [`InfiniCore`](https://github.com/InfiniTensor/InfiniCore).
-
-## Usage
-
-- Build and install `InfiniCore`. Set the `INFINI_ROOT` environment variable as prompted (defaults to `$HOME/.infini`).
-
-- Build and install `InfiniLM`
-
-```bash
-xmake && xmake install
-```
-
-- Run the model inference test
-
-```bash
-python scripts/jiuge.py [--cpu | --nvidia | --cambricon | --ascend | --metax | --moore | --iluvatar | --kunlun | --hygon] path/to/model_dir [n_device]
-```
-
-- Deploy the model inference service
-
-```bash
-python scripts/launch_server.py --model-path MODEL_PATH [-h] [--dev {cpu,nvidia,cambricon,ascend,metax,moore,iluvatar,kunlun,hygon}] [--ndev NDEV] [--max-batch MAX_BATCH] [--max-tokens MAX_TOKENS]
-```
-
-- Benchmark the inference service
-
-```bash
-python scripts/test_perf.py
-```
-
-- Measure model perplexity (PPL) through the inference service
-
-```bash
-python scripts/test_ppl.py --model-path MODEL_PATH [--ndev NDEV] [--max-batch MAX_BATCH] [--max-tokens MAX_TOKENS]
-```
-
-## Usage (new)
-
-- Build and install `InfiniCore`; see the InfiniCore [`README`](https://github.com/InfiniTensor/InfiniCore) for details:
-
-  - Set the `INFINI_ROOT` environment variable as prompted (defaults to `$HOME/.infini`)
-  - Choose the xmake build configuration for your hardware platform
-  - Build and install InfiniCore
-    - Install the C++ library
-    - Install the Python package
-
-- Single-shot inference test
-  - llama example
-```bash
-python examples/llama.py [--cpu | --nvidia] --model_path=
\ No newline at end of file
+Personal development work has moved to the private repository
diff --git a/claude_code_env.sh b/claude_code_env.sh
new file mode 100644
index 00000000..de40f604
--- /dev/null
+++ b/claude_code_env.sh
@@ -0,0 +1,205 @@
+#!/bin/bash
+
+set -euo pipefail
+
+# ========================
+# Constants
+# ========================
+SCRIPT_NAME=$(basename "$0")
+NODE_MIN_VERSION=18
+NODE_INSTALL_VERSION=22
+NVM_VERSION="v0.40.3"
+CLAUDE_PACKAGE="@anthropic-ai/claude-code"
+CONFIG_DIR="$HOME/.claude"
+CONFIG_FILE="$CONFIG_DIR/settings.json"
+API_BASE_URL="https://open.bigmodel.cn/api/anthropic"
+API_KEY_URL="https://open.bigmodel.cn/usercenter/proj-mgmt/apikeys"
+API_TIMEOUT_MS=3000000
+
+# ========================
+# Utility functions
+# ========================
+
+log_info() {
+    echo "🔹 $*"
+}
+
+log_success() {
+    echo "✅ $*"
+}
+
+log_error() {
+    echo "❌ $*" >&2
+}
+
+ensure_dir_exists() {
+    local dir="$1"
+    if [ ! -d "$dir" ]; then
+        mkdir -p "$dir" || {
+            log_error "Failed to create directory: $dir"
+            exit 1
+        }
+    fi
+}
+
+# ========================
+# Node.js installation
+# ========================
+
+install_nodejs() {
+    local platform=$(uname -s)
+
+    case "$platform" in
+    Linux|Darwin)
+        log_info "Installing Node.js on $platform..."
+
+        # Install nvm
+        log_info "Installing nvm ($NVM_VERSION)..."
+        curl -s https://raw.githubusercontent.com/nvm-sh/nvm/"$NVM_VERSION"/install.sh | bash
+
+        # Load nvm
+        log_info "Loading nvm environment..."
+        \. "$HOME/.nvm/nvm.sh"
+
+        # Install Node.js
+        log_info "Installing Node.js $NODE_INSTALL_VERSION..."
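+        # nvm install also activates the freshly installed version in this shell, so the checks below run against it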
+ nvm install "$NODE_INSTALL_VERSION" + + # 验证安装 + node -v &>/dev/null || { + log_error "Node.js installation failed" + exit 1 + } + log_success "Node.js installed: $(node -v)" + log_success "npm version: $(npm -v)" + ;; + *) + log_error "Unsupported platform: $platform" + exit 1 + ;; + esac +} + +# ======================== +# Node.js 检查函数 +# ======================== + +check_nodejs() { + if command -v node &>/dev/null; then + current_version=$(node -v | sed 's/v//') + major_version=$(echo "$current_version" | cut -d. -f1) + + if [ "$major_version" -ge "$NODE_MIN_VERSION" ]; then + log_success "Node.js is already installed: v$current_version" + return 0 + else + log_info "Node.js v$current_version is installed but version < $NODE_MIN_VERSION. Upgrading..." + install_nodejs + fi + else + log_info "Node.js not found. Installing..." + install_nodejs + fi +} + +# ======================== +# Claude Code 安装 +# ======================== + +install_claude_code() { + if command -v claude &>/dev/null; then + log_success "Claude Code is already installed: $(claude --version)" + else + log_info "Installing Claude Code..." + npm install -g "$CLAUDE_PACKAGE" || { + log_error "Failed to install claude-code" + exit 1 + } + log_success "Claude Code installed successfully" + fi +} + +configure_claude_json(){ + node --eval ' + const os = require("os"); + const fs = require("fs"); + const path = require("path"); + + const homeDir = os.homedir(); + const filePath = path.join(homeDir, ".claude.json"); + if (fs.existsSync(filePath)) { + const content = JSON.parse(fs.readFileSync(filePath, "utf-8")); + fs.writeFileSync(filePath, JSON.stringify({ ...content, hasCompletedOnboarding: true }, null, 2), "utf-8"); + } else { + fs.writeFileSync(filePath, JSON.stringify({ hasCompletedOnboarding: true }, null, 2), "utf-8"); + }' +} + +# ======================== +# API Key 配置 +# ======================== + +configure_claude() { + log_info "Configuring Claude Code..." + echo " You can get your API key from: $API_KEY_URL" + read -s -p "🔑 Please enter your ZHIPU API key: " api_key + echo + + if [ -z "$api_key" ]; then + log_error "API key cannot be empty. Please run the script again." + exit 1 + fi + + ensure_dir_exists "$CONFIG_DIR" + + # 写入配置文件 + node --eval ' + const os = require("os"); + const fs = require("fs"); + const path = require("path"); + + const homeDir = os.homedir(); + const filePath = path.join(homeDir, ".claude", "settings.json"); + const apiKey = "'"$api_key"'"; + + const content = fs.existsSync(filePath) + ? JSON.parse(fs.readFileSync(filePath, "utf-8")) + : {}; + + fs.writeFileSync(filePath, JSON.stringify({ + ...content, + env: { + ANTHROPIC_AUTH_TOKEN: apiKey, + ANTHROPIC_BASE_URL: "'"$API_BASE_URL"'", + API_TIMEOUT_MS: "'"$API_TIMEOUT_MS"'", + CLAUDE_CODE_DISABLE_NONESSENTIAL_TRAFFIC: 1 + } + }, null, 2), "utf-8"); + ' || { + log_error "Failed to write settings.json" + exit 1 + } + + log_success "Claude Code configured successfully" +} + +# ======================== +# 主流程 +# ======================== + +main() { + echo "🚀 Starting $SCRIPT_NAME" + + check_nodejs + install_claude_code + configure_claude_json + configure_claude + + echo "" + log_success "🎉 Installation completed successfully!" 
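+    # Settings are written to $CONFIG_FILE; re-run this script to update the API key later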
+ echo "" + echo "🚀 You can now start using Claude Code with:" + echo " claude" +} + +main "$@" diff --git a/configuration_lladamoe.py b/configuration_lladamoe.py new file mode 100644 index 00000000..32d6ff68 --- /dev/null +++ b/configuration_lladamoe.py @@ -0,0 +1,97 @@ +""" +LLaDA MoE configuration +""" + +from transformers.configuration_utils import PretrainedConfig +from transformers.modeling_rope_utils import rope_config_validation + + +class LLaDAConfig(PretrainedConfig): + model_type = "llada" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=-1, + hidden_size=-1, + dense_intermediate_size=-1, + expert_intermediate_size=-1, + shared_expert_intermediate_size=-1, + num_hidden_layers=-1, + num_attention_heads=-1, + num_key_value_heads=None, + hidden_act="silu", + max_position_embeddings=4096, + initializer_range=0.02, + rms_norm_eps=1e-05, + use_cache=False, + pad_token_id=1, + bos_token_id=None, + eos_token_id=50279, + tie_word_embeddings=False, + rope_theta=-1, + partial_rotary_factor=-1, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + clip_qkv=None, + num_experts_per_tok=-1, + num_experts=-1, + output_router_logits=False, + router_aux_loss_coef=0.01, + norm_topk_prob=None, + qk_layernorm=None, + moe_layer_freq=[], + moe_router_enable_expert_bias=None, + moe_router_score_function=None, + routed_scaling_factor=1, + router_num_group=-2, + router_topk_group=-2, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.expert_intermediate_size = expert_intermediate_size + self.dense_intermediate_size = dense_intermediate_size + self.shared_expert_intermediate_size = shared_expert_intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.clip_qkv = clip_qkv + self.num_experts_per_tok = num_experts_per_tok + self.num_experts = num_experts + self.output_router_logits = output_router_logits + self.router_aux_loss_coef = router_aux_loss_coef + self.norm_topk_prob = norm_topk_prob + self.qk_layernorm = qk_layernorm + self.moe_layer_freq = moe_layer_freq + self.moe_router_enable_expert_bias = moe_router_enable_expert_bias + self.moe_router_score_function = moe_router_score_function + self.partial_rotary_factor = partial_rotary_factor + self.routed_scaling_factor = routed_scaling_factor + self.router_num_group = router_num_group + self.router_topk_group = router_topk_group + + if self.rope_scaling is not None and "type" in self.rope_scaling: + self.rope_scaling["rope_type"] = self.rope_scaling["type"] + rope_config_validation(self) + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) diff --git a/csrc/engine/distributed/communication_group.cpp b/csrc/engine/distributed/communication_group.cpp new file mode 100644 index 00000000..a110bd83 --- /dev/null +++ b/csrc/engine/distributed/communication_group.cpp @@ -0,0 +1,46 @@ +#include "communication_group.hpp" 
+#include "../../utils.hpp" + +namespace infinilm::engine::distributed { + +CommunicationGroup::CommunicationGroup(const DistConfig &dist_config, infinicore::Device::Type device_type) + : dist_config_(dist_config), device_type_(device_type), + communicators_(std::vector(dist_config.tp_device_ids.size(), nullptr)) { + if (infinicore::context::getDevice().getType() != device_type_) { + infinicore::context::setDevice(infinicore::Device(device_type_, 0)); + } + if (dist_config_.tp_device_ids.size() > 1) { + RUN_INFINI(infinicclCommInitAll( + (infiniDevice_t)infinicore::context::getDevice().getType(), + communicators_.data(), + dist_config.tp_device_ids.size(), + dist_config.tp_device_ids.data())); + } +} + +const DistConfig &CommunicationGroup::get_dist_config() const { + return dist_config_; +} + +RankInfo CommunicationGroup::get_rank_info(int rank) const { + RankInfo info; + info.tp_size = dist_config_.tp_device_ids.size(); + info.tp_rank = rank; + info.device = infinicore::Device(device_type_, dist_config_.tp_device_ids[rank]); + info.comm = communicators_[rank]; + return info; +} + +int CommunicationGroup::get_world_size() const { + return dist_config_.tp_device_ids.size(); +} + +CommunicationGroup::~CommunicationGroup() { + if (communicators_.size() > 1) { + for (auto &comm : communicators_) { + RUN_INFINI(infinicclCommDestroy(comm)); + } + } +} + +} // namespace infinilm::engine::distributed diff --git a/csrc/engine/distributed/communication_group.hpp b/csrc/engine/distributed/communication_group.hpp new file mode 100644 index 00000000..e4f3c81a --- /dev/null +++ b/csrc/engine/distributed/communication_group.hpp @@ -0,0 +1,53 @@ +#pragma once + +#include "dist_config.hpp" + +#include +#include + +#include +#include + +namespace infinilm::engine::distributed { + +// Communicator each rank will hold +struct RankInfo { + // Device Type and ID assigned to this rank + infinicore::Device device; + // Tensor parallelism size + int tp_size; + // Tensor parallelism rank number of this rank + int tp_rank; + // Communicator handle + infinicclComm_t comm; + + RankInfo(infinicore::Device _device = infinicore::context::getDevice()) + : tp_size(1), tp_rank(0), device(_device), comm(nullptr){}; + + std::string to_string() const { + std::stringstream ss; + ss << "RankInfo: device=" << device.toString() << ", tp_size=" << tp_size << ", tp_rank=" << tp_rank; + return ss.str(); + } +}; + +// The communication group managed by model infer engine +class CommunicationGroup { +public: + explicit CommunicationGroup(const DistConfig &dist_config, infinicore::Device::Type device_type); + + const DistConfig &get_dist_config() const; + + RankInfo get_rank_info(int rank) const; + + int get_world_size() const; + + ~CommunicationGroup(); + +protected: + DistConfig dist_config_; + infinicore::Device::Type device_type_; + std::vector communicators_; +}; + +} // namespace infinilm::engine::distributed diff --git a/csrc/engine/distributed/dist_config.cpp b/csrc/engine/distributed/dist_config.cpp new file mode 100644 index 00000000..4a6c678c --- /dev/null +++ b/csrc/engine/distributed/dist_config.cpp @@ -0,0 +1,29 @@ +#include "dist_config.hpp" + +namespace infinilm::engine::distributed { +DistConfig::DistConfig() + : tp_device_ids{0} {} + +DistConfig::DistConfig(int tp_size) + : tp_device_ids(tp_size, 0) { + for (int i = 0; i < tp_size; ++i) { + tp_device_ids[i] = i; + } +} + +DistConfig::DistConfig(const std::vector &tp_device_ids_) + : tp_device_ids(tp_device_ids_) {} + +DistConfig::operator std::string() const { + 
std::string repr = "DistConfig(tp_device_ids=["; + for (size_t i = 0; i < tp_device_ids.size(); ++i) { + repr += std::to_string(tp_device_ids[i]); + if (i != tp_device_ids.size() - 1) { + repr += ", "; + } + } + repr += "])"; + return repr; +} + +} // namespace infinilm::engine::distributed diff --git a/csrc/engine/distributed/dist_config.hpp b/csrc/engine/distributed/dist_config.hpp new file mode 100644 index 00000000..2b46fb8e --- /dev/null +++ b/csrc/engine/distributed/dist_config.hpp @@ -0,0 +1,19 @@ +#pragma once + +#include +#include + +namespace infinilm::engine::distributed { + +struct DistConfig { + // Device IDs for each rank in tensor parallelism + std::vector tp_device_ids; + + DistConfig(); + explicit DistConfig(int tp_size); + explicit DistConfig(const std::vector &tp_device_ids_); + + explicit operator std::string() const; +}; + +} // namespace infinilm::engine::distributed diff --git a/csrc/engine/distributed/distributed.hpp b/csrc/engine/distributed/distributed.hpp new file mode 100644 index 00000000..a8661d63 --- /dev/null +++ b/csrc/engine/distributed/distributed.hpp @@ -0,0 +1,4 @@ +#pragma once + +#include "communication_group.hpp" +#include "dist_config.hpp" diff --git a/csrc/engine/infer_engine.cpp b/csrc/engine/infer_engine.cpp new file mode 100644 index 00000000..6e4696b2 --- /dev/null +++ b/csrc/engine/infer_engine.cpp @@ -0,0 +1,132 @@ +#include "infer_engine.hpp" +#include "../models/llama/llama_config.hpp" +#include "spdlog/spdlog.h" + +namespace infinilm::engine { + +//------------------------------------------------------ +// Constructor +//------------------------------------------------------ +InferEngine::InferEngine( + const std::any &config, + const distributed::DistConfig &distributed_config, + infinicore::Device::Type device_type, + const cache::CacheConfig &cache_config) // Changed parameter + : communication_group_(distributed_config, device_type), + model_config_(config), + cache_config_(cache_config) { + + spdlog::info("Launch InferEngine with {}", std::string(distributed_config)); + spdlog::info("Cache configuration: type={}, layers={}, max_kv_cache_length={}", + static_cast(cache_config_.type), + cache_config_.num_layers, + cache_config_.max_kv_cache_length); + + // Try to extract model configuration to override default cache parameters if needed + try { + if (config.type() == typeid(models::llama::LlamaConfig)) { + const auto &llama_config = std::any_cast(config); + + cache_config_.num_layers = llama_config.num_hidden_layers; + cache_config_.max_kv_cache_length = llama_config.max_position_embeddings; + + spdlog::info("Updated cache config from model: layers={}, max_kv_cache_length={}", + cache_config_.num_layers, cache_config_.max_kv_cache_length); + } + } catch (...) 
{ + spdlog::warn("Could not extract model config, using provided CacheConfig"); + } + + // Create one RankWorker per rank + int world_size = communication_group_.get_world_size(); + workers_.reserve(world_size); + for (int r = 0; r < world_size; ++r) { + workers_.emplace_back(std::make_unique( + model_config_, + communication_group_.get_rank_info(r), + cache_config_)); + } +} + +//------------------------------------------------------ +// load_param +//------------------------------------------------------ +void InferEngine::load_param(const std::string &name, const infinicore::Tensor ¶m) { + // Load the parameter on all workers + for (auto &worker : workers_) { + worker->load_param(name, param); + } +} + +//------------------------------------------------------ +// state_dict +//------------------------------------------------------ +std::vector> InferEngine::state_dict() { + std::vector> results; + if (0 == workers_.size()) { + throw std::runtime_error(" Model object not found. "); + } + + for (auto &worker : workers_) { + results.push_back(worker->state_dict()); + } + return results; +} + +//------------------------------------------------------ +// generate +//------------------------------------------------------ +infinicore::Tensor InferEngine::generate(const infinicore::Tensor &input_ids, + const infinicore::Tensor &position_ids) { + // Trigger each worker to run inference + for (auto &worker : workers_) { + worker->run(std::vector({input_ids, position_ids})); + } + // Wait for all workers + for (auto &worker : workers_) { + worker->wait(); + } + + return workers_[0]->get_output(); +} + +//------------------------------------------------------ +// Destructor +//------------------------------------------------------ +InferEngine::~InferEngine() { + // Close all workers + for (auto &worker : workers_) { + worker->close(); + } +} + +const distributed::DistConfig &InferEngine::get_dist_config() const { + return communication_group_.get_dist_config(); +} + +//------------------------------------------------------ +// reset_cache +//------------------------------------------------------ +void InferEngine::reset_cache(size_t pos) { + for (auto &worker : workers_) { + worker->reset_cache(pos); + } + for (auto &worker : workers_) { + worker->wait(); + } +} + +//------------------------------------------------------ +// reset_cache (overloaded with CacheConfig) +//------------------------------------------------------ +void InferEngine::reset_cache(const cache::CacheConfig &new_config, size_t pos) { + cache_config_ = new_config; + for (auto &worker : workers_) { + worker->reset_cache(new_config, pos); + } + for (auto &worker : workers_) { + worker->wait(); + } +} + +} // namespace infinilm::engine diff --git a/csrc/engine/infer_engine.hpp b/csrc/engine/infer_engine.hpp new file mode 100644 index 00000000..df5c7599 --- /dev/null +++ b/csrc/engine/infer_engine.hpp @@ -0,0 +1,51 @@ +#pragma once + +#include "distributed/distributed.hpp" +#include "infinicore/tensor.hpp" +#include "rank_worker.hpp" + +#include +#include + +namespace infinilm::engine { + +class InferEngine { +public: + // Updated constructor: accept CacheConfig instead of CacheType + InferEngine( + const std::any &config, + const distributed::DistConfig &distributed_config = distributed::DistConfig(), + infinicore::Device::Type device_type = infinicore::context::getDevice().getType(), + const cache::CacheConfig &cache_config = cache::CacheConfig()); + + // Load a parameter to all workers (each can extract its shard inside 
RankWorker) + void load_param(const std::string &name, const infinicore::Tensor ¶m); + + // return the parameters (i.e. weights and biases). + std::vector> state_dict(); + + // Run a single forward pass on all workers and return the outputs from all ranks + infinicore::Tensor generate(const infinicore::Tensor &input_ids, + const infinicore::Tensor &position_ids); + + // Reset the internal cache pos in all workers (clears state between generations) + void reset_cache(size_t pos = 0); + + // Overload: reset cache with new KV configuration + void reset_cache(const cache::CacheConfig &new_config, size_t pos = 0); + + ~InferEngine(); + + const distributed::DistConfig &get_dist_config() const; + + // Get current KV configuration + const cache::CacheConfig &get_cache_config() const { return cache_config_; } + +protected: + std::vector> workers_; + distributed::CommunicationGroup communication_group_; + std::any model_config_; + cache::CacheConfig cache_config_; +}; + +} // namespace infinilm::engine diff --git a/csrc/engine/rank_worker.cpp b/csrc/engine/rank_worker.cpp new file mode 100644 index 00000000..8740f7a3 --- /dev/null +++ b/csrc/engine/rank_worker.cpp @@ -0,0 +1,330 @@ +#include "rank_worker.hpp" + +#include "../models/model_factory.hpp" + +#include +#include +#include + +namespace infinilm::engine { + +RankWorker::RankWorker(const std::any &model_config, + const distributed::RankInfo &rank_info, + const cache::CacheConfig &cache_config) + : model_config_(model_config), + rank_info_(rank_info), + job_cmd_(Command::INIT), + has_job_(false), + job_done_(false), + should_exit_(false), + init_done_(false), + pending_cache_config_(cache_config) { + // start the thread + thread_ = std::thread(&RankWorker::thread_loop, this); + + // Wait until the worker thread finishes initialization (model created) + std::unique_lock lk(mutex_); + cv_.wait(lk, [&] { return init_done_; }); +} + +std::string RankWorker::info() const { + std::stringstream ss; + + ss << "RankWorker{"; + + // Rank related + ss << rank_info_.to_string() << " "; + + // Flags + ss << "| init_done: " << (init_done_ ? "true" : "false") << " "; + ss << "| should_exit: " << (should_exit_ ? "true" : "false") << " "; + ss << "| has_job: " << (has_job_ ? "true" : "false") << " "; + ss << "| job_done: " << (job_done_ ? "true" : "false") << " "; + + ss << "}"; + + return ss.str(); +} + +//------------------------------------------------------ +// load_param -- synchronous (blocks until worker finishes loading) +//------------------------------------------------------ +void RankWorker::load_param(const std::string &name, + const infinicore::Tensor ¶m) { + { + std::lock_guard lock(mutex_); + // If the worker is stopping, don't submit new jobs. 
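+        // The payload below is published under the lock and consumed by thread_loop();
+        // the caller then blocks until the worker flips job_done_ (synchronous load).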
+ if (should_exit_) { + throw std::runtime_error("RankWorker is closing; cannot load_param"); + } + + pending_param_name_ = name; + pending_param_ = param; + + job_cmd_ = Command::LOAD; + has_job_ = true; + job_done_ = false; + } + cv_.notify_all(); + + // Wait for job completion + std::unique_lock lk(mutex_); + cv_.wait(lk, [&] { return job_done_ || should_exit_; }); + + if (should_exit_) { + throw std::runtime_error("RankWorker stopped while loading parameter"); + } +} + +//------------------------------------------------------ +// state_dict -- +//------------------------------------------------------ +std::unordered_map RankWorker::state_dict() { + return this->model_->state_dict(); +} + +//------------------------------------------------------ +// run -- asynchronous +//------------------------------------------------------ +void RankWorker::run(const std::vector &args) { + std::lock_guard lock(mutex_); + + if (should_exit_) { + throw std::runtime_error("RankWorker is closing; cannot run"); + } + + pending_args_ = args; + job_cmd_ = Command::RUN; + has_job_ = true; + job_done_ = false; + + cv_.notify_all(); +} + +//------------------------------------------------------ +// wait -- asynchronous +//------------------------------------------------------ +void RankWorker::wait() { + std::unique_lock lk(mutex_); + cv_.wait(lk, [&] { return job_done_ || should_exit_; }); + + if (should_exit_) { + throw std::runtime_error("RankWorker stopped during run"); + } +} + +//------------------------------------------------------ +// reset_cache -- synchronous by default, async optional (unstable) +//------------------------------------------------------ +void RankWorker::reset_cache(size_t pos) { + std::lock_guard lock(mutex_); + if (should_exit_) { + throw std::runtime_error("RankWorker is closing; cannot reset_cache"); + } + + pending_reset_pos_ = pos; + job_cmd_ = Command::RESET_CACHE; + has_job_ = true; + job_done_ = false; + cv_.notify_all(); +} + +void RankWorker::reset_cache(const cache::CacheConfig &new_config, size_t pos) { + std::lock_guard lock(mutex_); + if (should_exit_) { + throw std::runtime_error("RankWorker is closing; cannot reset_cache"); + } + + // Store both the position and the new config + pending_reset_pos_ = pos; + pending_cache_config_ = new_config; + job_cmd_ = Command::RESET_CACHE_WITH_CONFIG; + has_job_ = true; + job_done_ = false; + cv_.notify_all(); +} + +//------------------------------------------------------ +// close -- request shutdown and join thread +//------------------------------------------------------ +void RankWorker::close() { + { + std::lock_guard lock(mutex_); + should_exit_ = true; + has_job_ = false; // don't keep old jobs pending + job_cmd_ = Command::STOP; + } + cv_.notify_all(); + + if (thread_.joinable()) { + thread_.join(); + } +} + +//------------------------------------------------------ +// get_output (thread safe) +//------------------------------------------------------ +infinicore::Tensor RankWorker::get_output() { + std::lock_guard lock(mutex_); + return output_; +} + +//------------------------------------------------------ +// thread_loop +//------------------------------------------------------ +void RankWorker::thread_loop() { + try { + // Initialize device & model outside of holding the main mutex to avoid blocking callers. 
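+        // Bind this worker thread to its assigned device before building the cache and
+        // model, which stay owned by this thread for its entire lifetime.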
+ infinicore::context::setDevice(rank_info_.device); + + cache_ptr_ = std::make_shared(pending_cache_config_); + + // Create model using factory (may be expensive) + model_ = InfinilmModelFactory::createModel(model_config_, rank_info_, cache_ptr_); + + // Signal that initialization is done + { + std::lock_guard lk(mutex_); + init_done_ = true; + } + cv_.notify_all(); + + // Main loop: wait for jobs or exit + while (true) { + Command local_cmd = Command::INIT; + std::string local_param_name; + infinicore::Tensor local_param; + std::vector local_args; + size_t local_reset_pos = 0; + cache::CacheConfig local_reset_config; + + // Wait for a job or exit + { + std::unique_lock lk(mutex_); + cv_.wait(lk, [&] { return has_job_ || should_exit_; }); + + if (should_exit_) { + break; + } + + // capture job data and clear has_job_ + local_cmd = job_cmd_; + if (local_cmd == Command::LOAD) { + local_param_name = pending_param_name_; + local_param = pending_param_; + } else if (local_cmd == Command::RUN) { + local_args = pending_args_; + } else if (local_cmd == Command::RESET_CACHE) { + local_reset_pos = pending_reset_pos_; + } else if (local_cmd == Command::RESET_CACHE_WITH_CONFIG) { + local_reset_pos = pending_reset_pos_; + local_reset_config = pending_cache_config_; + } + + // mark job as being processed + has_job_ = false; + job_done_ = false; + } // unlock mutex while executing the job + + // Execute job outside the lock + if (local_cmd == Command::LOAD) { + try { + model_->load_parameter(local_param_name, local_param); + } catch (const std::exception &e) { + // convert exceptions to a safe behavior: set should_exit_ and notify caller + std::lock_guard lk(mutex_); + should_exit_ = true; + job_done_ = true; + cv_.notify_all(); + // rethrow so the thread can be joined and caller sees an error if desired (optional) + spdlog::error("[{}] exception during load_parameter_: {}\n", info(), e.what()); + break; + } + + // signal completion + { + std::lock_guard lk(mutex_); + job_done_ = true; + } + cv_.notify_all(); + + } else if (local_cmd == Command::RUN) { + try { + auto out = model_->forward(local_args); + + { + std::lock_guard lk(mutex_); + output_ = std::move(out); + job_done_ = true; + } + cv_.notify_all(); + + } catch (const std::exception &e) { + std::lock_guard lk(mutex_); + should_exit_ = true; + job_done_ = true; + cv_.notify_all(); + spdlog::error("[{}] exception during forward: {}\n", info(), e.what()); + break; + } + } else if (local_cmd == Command::RESET_CACHE) { + try { + // Option 1: Use model's reset_cache if it handles cache + model_->reset_cache(local_reset_pos); + + // Option 2: Reset cache directly if we have access + // if (cache_ptr_ != nullptr) { + // auto* dynamic_cache = static_cast(cache_ptr_); + // dynamic_cache->reset(local_reset_pos); + // } + + { + std::lock_guard lk(mutex_); + job_done_ = true; + } + cv_.notify_all(); + + } catch (const std::exception &e) { + std::lock_guard lk(mutex_); + should_exit_ = true; + job_done_ = true; + cv_.notify_all(); + spdlog::error("[{}] exception during reset_cache: {}\n", info(), e.what()); + break; + } + } else if (local_cmd == Command::RESET_CACHE_WITH_CONFIG) { + try { + // Use model's reset_cache with new configuration + model_->reset_cache(local_reset_config, local_reset_pos); + + { + std::lock_guard lk(mutex_); + job_done_ = true; + } + cv_.notify_all(); + + } catch (const std::exception &e) { + std::lock_guard lk(mutex_); + should_exit_ = true; + job_done_ = true; + cv_.notify_all(); + spdlog::error("[{}] exception during reset_cache 
with config: {}\n", info(), e.what()); + break; + } + } else { + // Shouldn't reach here (no-op) + } + } // while + } catch (const std::exception &e) { + // Top-level exception: ensure any waiters are woken and the thread exits cleanly. + { + std::lock_guard lk(mutex_); + should_exit_ = true; + job_done_ = true; + } + cv_.notify_all(); + spdlog::error("[{}] fatal exception in thread_loop: {} \n", info(), e.what()); + } +} + +} // namespace infinilm::engine diff --git a/csrc/engine/rank_worker.hpp b/csrc/engine/rank_worker.hpp new file mode 100644 index 00000000..63e4cef9 --- /dev/null +++ b/csrc/engine/rank_worker.hpp @@ -0,0 +1,93 @@ +#pragma once + +#include "../cache/cache.hpp" +#include "../models/model_factory.hpp" +#include "distributed/distributed.hpp" + +#include +#include +#include +#include +#include +#include + +namespace infinilm::engine { + +class RankWorker { + enum class Command { + INIT, + LOAD, + RUN, + RESET_CACHE, + RESET_CACHE_WITH_CONFIG, + STOP + }; + +public: + RankWorker(const std::any &model_config, + const distributed::RankInfo &rank_info, + const cache::CacheConfig &cache_config); + + // Submit a parameter load job and wait until the load completes on the worker thread. + void load_param(const std::string &name, + const infinicore::Tensor ¶m); + + // return the parameters (i.e. weights and biases). + std::unordered_map state_dict(); + + // Submit a run (forward) job. + void run(const std::vector &args); + + // Reset the internal cache in the model (clears state between generations) + void reset_cache(size_t pos = 0); + + // Reset the internal cache with a new configuration + void reset_cache(const cache::CacheConfig &new_config, size_t pos = 0); + + // Wait until run job completes. The result can be retrieved with get_output(). + void wait(); + + // Request worker shutdown and join the thread. + void close(); + + // Thread-safe accessor for last output produced by RUN. 
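+    // Note: the stored output is overwritten by the next RUN job, so callers should
+    // consume the returned tensor before submitting another run (see thread_loop()).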
+ infinicore::Tensor get_output(); + + std::string info() const; + +private: + void thread_loop(); + +private: + // Worker properties + std::any model_config_; + distributed::RankInfo rank_info_; + std::shared_ptr model_; + std::shared_ptr cache_ptr_; + + // Command for the pending job (protected by mutex_) + Command job_cmd_; + + // State flags (protected by mutex_) + bool has_job_ = false; // a job is pending + bool job_done_ = false; // last job completed + bool should_exit_ = false; // request to stop + bool init_done_ = false; // initialization finished + + // Task payloads (protected by mutex) + std::string pending_param_name_; + infinicore::Tensor pending_param_; + std::vector pending_args_; + size_t pending_reset_pos_ = 0; + cache::CacheConfig pending_cache_config_; + + // Output (protected by mutex) + infinicore::Tensor output_; + + // Thread sync + std::thread thread_; + std::mutex mutex_; + std::condition_variable cv_; +}; + +} // namespace infinilm::engine diff --git a/csrc/layers/fused_linear.cpp b/csrc/layers/fused_linear.cpp new file mode 100644 index 00000000..9b2c813d --- /dev/null +++ b/csrc/layers/fused_linear.cpp @@ -0,0 +1,178 @@ +#include "fused_linear.hpp" + +#include + +namespace infinilm::layers { +// --------------------------------------------------------- +// QKV Parallel Linear +// --------------------------------------------------------- +QKVParallelLinear::QKVParallelLinear(size_t hidden_size, + size_t head_dim, + size_t num_q_head, + size_t num_kv_head, + bool bias, + const infinicore::DataType &dtype, + const infinicore::Device &device, + engine::distributed::RankInfo rank_info) + : QKVParallelLinear(hidden_size, + head_dim, head_dim, head_dim, + num_q_head, num_kv_head, num_kv_head, + bias, bias, bias, + dtype, device, rank_info) {} + +QKVParallelLinear::QKVParallelLinear(size_t hidden_size, + size_t q_dim, size_t k_dim, size_t v_dim, + size_t num_q_head, size_t num_k_head, size_t num_v_head, + bool q_bias, bool k_bias, bool v_bias, + const infinicore::DataType &dtype, + const infinicore::Device &device, + engine::distributed::RankInfo rank_info) + : infinicore::nn::ColumnParallelLinear( + hidden_size, + num_q_head * q_dim + num_k_head * k_dim + num_v_head * v_dim, + (q_bias || k_bias || v_bias), + dtype, + device, + rank_info.tp_rank, + rank_info.tp_size), + q_dim_(q_dim), + k_dim_(k_dim), + v_dim_(v_dim), + num_q_head_(num_q_head), + num_k_head_(num_k_head), + num_v_head_(num_v_head), + q_bias_(q_bias), + k_bias_(k_bias), + v_bias_(v_bias) { + if (num_q_head % tp_size_ != 0 || num_k_head % tp_size_ != 0 || num_v_head % tp_size_ != 0) { + throw std::runtime_error("QKVParallelLinear: num_[q|k|v]_head must be divisible by tp_size"); + } + + if ((q_bias_ != k_bias_) || (k_bias_ != v_bias_)) { + throw std::runtime_error("q_bias, k_bias, v_bias must all match"); + } + + q_out_size_ = num_q_head_ * q_dim_ / tp_size_; + k_out_size_ = num_k_head_ * k_dim_ / tp_size_; + v_out_size_ = num_v_head_ * v_dim_ / tp_size_; +} + +std::tuple +QKVParallelLinear::forward_split(infinicore::Tensor &input) { + auto output = this->forward(input); + + auto q_out = output->narrow({{2, 0, q_out_size_}}); + auto k_out = output->narrow({{2, q_out_size_, k_out_size_}}); + auto v_out = output->narrow({{2, q_out_size_ + k_out_size_, v_out_size_}}); + + return std::make_tuple(q_out, k_out, v_out); +} + +infinicore::nn::Parameter QKVParallelLinear::get_q_weight() const { + return infinicore::nn::Parameter( + weight_->narrow({{0, 0, q_out_size_}}), + 0, tp_rank_, tp_size_); +} + 
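+// The K/V accessors below mirror get_q_weight(): each wraps a narrow() view into the fused
+// QKV weight, so per-projection shards can be loaded without copying the full matrix
+// (assuming narrow() returns a non-owning view, as its use here suggests).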
+infinicore::nn::Parameter QKVParallelLinear::get_k_weight() const { + return infinicore::nn::Parameter( + weight_->narrow({{0, q_out_size_, k_out_size_}}), + 0, tp_rank_, tp_size_); +} + +infinicore::nn::Parameter QKVParallelLinear::get_v_weight() const { + return infinicore::nn::Parameter( + weight_->narrow({{0, q_out_size_ + k_out_size_, v_out_size_}}), + 0, tp_rank_, tp_size_); +} + +infinicore::nn::Parameter QKVParallelLinear::get_q_bias() const { + if (!q_bias_) { + return infinicore::nn::Parameter(); + } + return infinicore::nn::Parameter( + bias_->narrow({{0, 0, q_out_size_}}), + 0, tp_rank_, tp_size_); +} + +infinicore::nn::Parameter QKVParallelLinear::get_k_bias() const { + if (!k_bias_) { + return infinicore::nn::Parameter(); + } + return infinicore::nn::Parameter( + bias_->narrow({{0, q_out_size_, k_out_size_}}), + 0, tp_rank_, tp_size_); +} + +infinicore::nn::Parameter QKVParallelLinear::get_v_bias() const { + if (!v_bias_) { + return infinicore::nn::Parameter(); + } + return infinicore::nn::Parameter( + bias_->narrow({{0, q_out_size_ + k_out_size_, v_out_size_}}), + 0, tp_rank_, tp_size_); +} + +bool QKVParallelLinear::has_q_bias() const { return q_bias_; } +bool QKVParallelLinear::has_k_bias() const { return k_bias_; } +bool QKVParallelLinear::has_v_bias() const { return v_bias_; } + +// --------------------------------------------------------- +// Gate-Up Parallel Linear +// --------------------------------------------------------- +GateUpParallelLinear::GateUpParallelLinear(size_t hidden_size, size_t intermediate_size, bool bias, + const infinicore::DataType &dtype, const infinicore::Device &device, + engine::distributed::RankInfo rank_info) + : GateUpParallelLinear(hidden_size, intermediate_size, bias, bias, dtype, device, rank_info) { +} + +GateUpParallelLinear::GateUpParallelLinear(size_t hidden_size, size_t intermediate_size, bool gate_bias, bool up_bias, + const infinicore::DataType &dtype, const infinicore::Device &device, + engine::distributed::RankInfo rank_info) + : infinicore::nn::ColumnParallelLinear(hidden_size, intermediate_size * 2, gate_bias || up_bias, dtype, device, rank_info.tp_rank, rank_info.tp_size), gate_bias_(gate_bias), up_bias_(up_bias) { + if (gate_bias_ != up_bias_) { + throw std::runtime_error("Not supported yet: gate_bias and up_bias should be given at the same time"); + } +} + +std::tuple GateUpParallelLinear::forward_split(infinicore::Tensor &input) { + auto output = this->forward(input); + auto cols = output->shape()[2]; + auto gate_output = output->narrow({{2, 0, cols / 2}}); + auto up_output = output->narrow({{2, cols / 2, cols / 2}}); + return std::make_tuple(gate_output, up_output); +} + +infinicore::nn::Parameter GateUpParallelLinear::get_gate_weight() const { + return infinicore::nn::Parameter(weight_->narrow({{0, 0, weight_->size(0) / 2}}), 0, tp_rank_, tp_size_); +} + +infinicore::nn::Parameter GateUpParallelLinear::get_gate_bias() const { + if (!gate_bias_) { + return infinicore::nn::Parameter(); + } else { + return infinicore::nn::Parameter(bias_->narrow({{0, 0, bias_->size(0) / 2}}), 0, tp_rank_, tp_size_); + } +} + +infinicore::nn::Parameter GateUpParallelLinear::get_up_weight() const { + return infinicore::nn::Parameter(weight_->narrow({{0, weight_->size(0) / 2, weight_->size(0) / 2}}), 0, tp_rank_, tp_size_); +} + +infinicore::nn::Parameter GateUpParallelLinear::get_up_bias() const { + if (!up_bias_) { + return infinicore::nn::Parameter(); + } else { + return infinicore::nn::Parameter(bias_->narrow({{0, bias_->size(0) / 2, 
bias_->size(0) / 2}}), + 0, tp_rank_, tp_size_); + } +} + +bool GateUpParallelLinear::has_gate_bias() const { + return gate_bias_; +} + +bool GateUpParallelLinear::has_up_bias() const { + return up_bias_; +} +} // namespace infinilm::layers diff --git a/csrc/layers/fused_linear.hpp b/csrc/layers/fused_linear.hpp new file mode 100644 index 00000000..1e32ce50 --- /dev/null +++ b/csrc/layers/fused_linear.hpp @@ -0,0 +1,106 @@ +#pragma once +#include "infinicore/nn/linear.hpp" + +#include "../engine/distributed/communication_group.hpp" + +namespace infinilm::layers { +class QKVParallelLinear : public infinicore::nn::ColumnParallelLinear { +public: + explicit QKVParallelLinear(size_t hidden_size, + size_t q_dim, size_t k_dim, size_t v_dim, + size_t num_q_head, size_t num_k_head, size_t num_v_head, + bool q_bias, bool k_bias, bool v_bias, + const infinicore::DataType &dtype = infinicore::DataType::F32, + const infinicore::Device &device = infinicore::Device(), + engine::distributed::RankInfo rank_info = engine::distributed::RankInfo()); + + // A more common case where all heads have the same dimension + explicit QKVParallelLinear(size_t hidden_size, + size_t head_dim, + size_t num_q_head, size_t num_kv_head, + bool bias = false, + const infinicore::DataType &dtype = infinicore::DataType::F32, + const infinicore::Device &device = infinicore::Device(), + engine::distributed::RankInfo rank_info = engine::distributed::RankInfo()); + + std::tuple + forward_split(infinicore::Tensor &input); + + infinicore::nn::Parameter get_q_weight() const; + infinicore::nn::Parameter get_k_weight() const; + infinicore::nn::Parameter get_v_weight() const; + + infinicore::nn::Parameter get_q_bias() const; + infinicore::nn::Parameter get_k_bias() const; + infinicore::nn::Parameter get_v_bias() const; + + bool has_q_bias() const; + bool has_k_bias() const; + bool has_v_bias() const; + +private: + size_t q_dim_; + size_t k_dim_; + size_t v_dim_; + size_t num_q_head_; + size_t num_k_head_; + size_t num_v_head_; + bool q_bias_; + bool k_bias_; + bool v_bias_; + size_t q_out_size_; // num_q_head * q_dim / tp_size + size_t k_out_size_; // num_k_head * k_dim / tp_size + size_t v_out_size_; // num_v_head * v_dim / tp_size +}; + +class GateUpParallelLinear : public infinicore::nn::ColumnParallelLinear { +public: + GateUpParallelLinear(size_t hidden_size, size_t intermediate_size, bool bias = false, + const infinicore::DataType &dtype = infinicore::DataType::F32, const infinicore::Device &device = infinicore::Device(), + engine::distributed::RankInfo rank_info = engine::distributed::RankInfo()); + + GateUpParallelLinear(size_t hidden_size, size_t intermediate_size, bool gate_bias, bool up_bias, + const infinicore::DataType &dtype = infinicore::DataType::F32, const infinicore::Device &device = infinicore::Device(), + engine::distributed::RankInfo rank_info = engine::distributed::RankInfo()); + + std::tuple forward_split(infinicore::Tensor &input); + + infinicore::nn::Parameter get_gate_weight() const; + + infinicore::nn::Parameter get_gate_bias() const; + + infinicore::nn::Parameter get_up_weight() const; + + infinicore::nn::Parameter get_up_bias() const; + + bool has_gate_bias() const; + + bool has_up_bias() const; + +private: + bool gate_bias_; + bool up_bias_; +}; + +#define INFINILM_QKV_LINEAR_INIT(name, q_name, k_name, v_name, ...) 
\ + name##_ = std::make_shared(__VA_ARGS__); \ + this->register_parameter(std::string(q_name) + ".weight", name##_->get_q_weight()); \ + this->register_parameter(std::string(k_name) + ".weight", name##_->get_k_weight()); \ + this->register_parameter(std::string(v_name) + ".weight", name##_->get_v_weight()); \ + if (name##_->has_q_bias()) \ + this->register_parameter(std::string(q_name) + ".bias", name##_->get_q_bias()); \ + if (name##_->has_k_bias()) \ + this->register_parameter(std::string(k_name) + ".bias", name##_->get_k_bias()); \ + if (name##_->has_v_bias()) \ + this->register_parameter(std::string(v_name) + ".bias", name##_->get_v_bias()); + +#define INFINILM_GATE_UP_LINEAR_INIT(name, gate_name, up_name, ...) \ + name##_ = std::make_shared(__VA_ARGS__); \ + this->register_parameter(std::string(gate_name) + ".weight", name##_->get_gate_weight()); \ + this->register_parameter(std::string(up_name) + ".weight", name##_->get_up_weight()); \ + if (name##_->has_gate_bias()) \ + this->register_parameter(std::string(gate_name) + ".bias", name##_->get_gate_bias()); \ + if (name##_->has_up_bias()) \ + this->register_parameter(std::string(up_name) + ".bias", name##_->get_up_bias()); + +} // namespace infinilm::layers diff --git a/csrc/models/debug_utils/hooks.cpp b/csrc/models/debug_utils/hooks.cpp new file mode 100644 index 00000000..06846318 --- /dev/null +++ b/csrc/models/debug_utils/hooks.cpp @@ -0,0 +1,44 @@ +#include "hooks.hpp" +#include + +namespace infinilm::models::debug_utils { + +void HookRegistry::register_hook(const std::string &name, HookCallback callback) { + hooks_[name] = callback; + SPDLOG_DEBUG("HookRegistry: Registered hook '{}'", name); +} + +void HookRegistry::call_hook(const std::string &name, const infinicore::Tensor &tensor, int layer_idx) const { + // Try exact match first + auto it = hooks_.find(name); + if (it != hooks_.end()) { + try { + it->second(name, tensor, layer_idx); + } catch (const std::exception &e) { + SPDLOG_ERROR("HookRegistry: Error calling hook '{}': {}", name, e.what()); + } + return; + } + + // Try pattern matching (e.g., "layer0_*" matches "layer0_q_after_proj") + for (const auto &[pattern, callback] : hooks_) { + if (pattern.back() == '*' && name.size() >= pattern.size() - 1) { + std::string prefix = pattern.substr(0, pattern.size() - 1); + if (name.substr(0, prefix.size()) == prefix) { + try { + callback(name, tensor, layer_idx); + } catch (const std::exception &e) { + SPDLOG_ERROR("HookRegistry: Error calling hook pattern '{}' for '{}': {}", pattern, name, e.what()); + } + return; + } + } + } +} + +void HookRegistry::clear() { + hooks_.clear(); + SPDLOG_DEBUG("HookRegistry: Cleared all hooks"); +} + +} // namespace infinilm::models::debug_utils diff --git a/csrc/models/debug_utils/hooks.hpp b/csrc/models/debug_utils/hooks.hpp new file mode 100644 index 00000000..bb3460c7 --- /dev/null +++ b/csrc/models/debug_utils/hooks.hpp @@ -0,0 +1,186 @@ +#pragma once + +#include "infinicore/tensor.hpp" +#include +#include +#include +#include + +namespace infinilm::models::debug_utils { + +// TODO: move to InfiniCore as common utils in future work + +/** + * @brief Hook callback type for capturing intermediate values (DEBUG ONLY) + * + * Hook functions are called with: + * - name: Identifier for the intermediate value (e.g., "layer0_q_after_proj") + * - tensor: The intermediate tensor value + * - layer_idx: Layer index (for layer-specific hooks, -1 if not applicable) + * + * NOTE: This is a debug utility. Do not use in production code. 
+ */ +using HookCallback = std::function; + +/** + * @brief Hook registry for managing hooks (DEBUG ONLY) + * + * NOTE: This is a debug utility for capturing intermediate tensor values + * during model execution. Do not use in production code. + */ +class HookRegistry { +public: + /** + * @brief Register a hook callback + * + * @param name Hook name (can be pattern like "layer0_*" or specific name) + * @param callback Hook callback function + */ + void register_hook(const std::string &name, HookCallback callback); + + /** + * @brief Call hook if registered + * + * @param name Full hook name + * @param tensor Tensor to pass to hook + * @param layer_idx Layer index (-1 if not applicable) + */ + void call_hook(const std::string &name, const infinicore::Tensor &tensor, int layer_idx = -1) const; + + /** + * @brief Clear all hooks + */ + void clear(); + + /** + * @brief Check if any hooks are registered + */ + bool has_hooks() const { return !hooks_.empty(); } + +private: + std::unordered_map hooks_; +}; + +/** + * @brief Macro to simplify hook registration (DEBUG ONLY) + * + * Usage: REGISTER_HOOK(registry, "hook_name", callback) + */ +#define REGISTER_HOOK(registry, name, callback) \ + (registry)->register_hook(name, callback) + +/** + * @brief Macro to simplify hook calls with automatic null and has_hooks checks (DEBUG ONLY) + * + * Usage: CALL_HOOK(registry, "hook_name", tensor) + * Note: layer_idx defaults to -1 + */ +#define CALL_HOOK(registry, name, tensor) \ + do { \ + if ((registry) && (registry)->has_hooks()) { \ + (registry)->call_hook(name, tensor, -1); \ + } \ + } while (0) + +/** + * @brief Macro to simplify hook calls with explicit layer index (DEBUG ONLY) + * + * Usage: CALL_HOOK_LAYER(registry, "hook_name", tensor, layer_idx) + */ +#define CALL_HOOK_LAYER(registry, name, tensor, layer_idx) \ + do { \ + if ((registry) && (registry)->has_hooks()) { \ + (registry)->call_hook(name, tensor, layer_idx); \ + } \ + } while (0) + +/** + * @brief Macros to simplify hook_registry and hook_prefix management in model classes + */ + +// Declare hook_registry and hook_prefix member variables +#define HOOK_REGISTRY_MEMBER() \ + std::shared_ptr hook_registry_; \ + std::string hook_prefix_; + +// Set hook_registry and hook_prefix (no forwarding to submodules) +#define SET_HOOK_REGISTRY_SIMPLE() \ + void set_hook_registry(const std::shared_ptr &hook_registry, const std::string &hook_prefix = "") { \ + hook_registry_ = hook_registry; \ + hook_prefix_ = hook_prefix; \ + } + +// Helper macro to build incremental hook prefix +#define BUILD_HOOK_PREFIX(prefix, name) \ + (prefix.empty() ? std::string(name) : prefix + "_" + std::string(name)) + +// Set hook_registry and hook_prefix and forward to one or more submodules +// Usage: SET_HOOK_REGISTRY(submodule1) or SET_HOOK_REGISTRY(submodule1, submodule2) +// The hook_prefix will be incremented for each submodule (e.g., "layer0" -> "layer0_attention") +// Note: Currently supports up to 2 submodules. For more, extend the pattern below. +#define SET_HOOK_REGISTRY(...) \ + SET_HOOK_REGISTRY_IMPL(__VA_ARGS__) + +// Helper to handle variable number of arguments using a reliable pattern +#define SET_HOOK_REGISTRY_IMPL(...) 
\ + SET_HOOK_REGISTRY_GET_NTH(__VA_ARGS__, SET_HOOK_REGISTRY_2, SET_HOOK_REGISTRY_1, SET_HOOK_REGISTRY_0,)(__VA_ARGS__) + +// Get the selector based on argument count +// Pattern: when we have N args, the (N+1)th parameter from the end is the selector +// For 0 args: _1=SET_HOOK_REGISTRY_2, _2=SET_HOOK_REGISTRY_1, _3=SET_HOOK_REGISTRY_0, N=(empty) → need to use _3 +// For 1 arg: _1=arg, _2=SET_HOOK_REGISTRY_2, _3=SET_HOOK_REGISTRY_1, N=SET_HOOK_REGISTRY_0 → wrong, need _3 +// For 2 args: _1=arg1, _2=arg2, _3=SET_HOOK_REGISTRY_2, N=SET_HOOK_REGISTRY_1 → wrong, need _3 + +// Use _3 as the selector (it's in the right position for all cases) +#define SET_HOOK_REGISTRY_GET_NTH(_1, _2, _3, N, ...) _3 + +// Implementation for 0 args (shouldn't be used, but handle gracefully) +#define SET_HOOK_REGISTRY_0() \ + void set_hook_registry(const std::shared_ptr &hook_registry, const std::string &hook_prefix = "") { \ + hook_registry_ = hook_registry; \ + hook_prefix_ = hook_prefix; \ + } + +// Implementation for 1 arg +#define SET_HOOK_REGISTRY_1(submodule) \ + void set_hook_registry(const std::shared_ptr &hook_registry, const std::string &hook_prefix = "") { \ + hook_registry_ = hook_registry; \ + hook_prefix_ = hook_prefix; \ + if (submodule##_) { \ + std::string submodule_prefix = BUILD_HOOK_PREFIX(hook_prefix, #submodule); \ + submodule##_->set_hook_registry(hook_registry, submodule_prefix); \ + } \ + } + +// Implementation for 2 args +#define SET_HOOK_REGISTRY_2(submodule1, submodule2) \ + void set_hook_registry(const std::shared_ptr &hook_registry, const std::string &hook_prefix = "") { \ + hook_registry_ = hook_registry; \ + hook_prefix_ = hook_prefix; \ + if (submodule1##_) { \ + std::string submodule1_prefix = BUILD_HOOK_PREFIX(hook_prefix, #submodule1); \ + submodule1##_->set_hook_registry(hook_registry, submodule1_prefix); \ + } \ + if (submodule2##_) { \ + std::string submodule2_prefix = BUILD_HOOK_PREFIX(hook_prefix, #submodule2); \ + submodule2##_->set_hook_registry(hook_registry, submodule2_prefix); \ + } \ + } + +// Set hook_registry and hook_prefix for a vector of submodules +// For vectors, the prefix is incremented with an index (e.g., "layer0", "layer1", ...) +// If parent has a prefix, it becomes "parent_layer0", "parent_layer1", etc. +#define SET_HOOK_REGISTRY_VEC(vec_name) \ + void set_hook_registry(const std::shared_ptr &hook_registry, const std::string &hook_prefix = "") { \ + hook_registry_ = hook_registry; \ + hook_prefix_ = hook_prefix; \ + for (size_t i = 0; i < vec_name##_.size(); ++i) { \ + if (vec_name##_[i]) { \ + std::string layer_name = "layer" + std::to_string(i); \ + std::string item_prefix = BUILD_HOOK_PREFIX(hook_prefix, layer_name); \ + vec_name##_[i]->set_hook_registry(hook_registry, item_prefix); \ + } \ + } \ + } + +} // namespace infinilm::models::debug_utils diff --git a/csrc/models/debug_utils/tensor_utils.hpp b/csrc/models/debug_utils/tensor_utils.hpp new file mode 100644 index 00000000..3cd0e219 --- /dev/null +++ b/csrc/models/debug_utils/tensor_utils.hpp @@ -0,0 +1,183 @@ +#pragma once + +#include "infinicore/tensor.hpp" +#include +#include +#include +#include +#include + +namespace infinilm::models::debug_utils { + +// Helper function to log tensor statistics and sample values +// This is useful for debugging intermediate values in model forward passes +// NOTE: This is a debug utility. Do not use in production code. 
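+// Assumes tensor->to(CPU) yields a host copy whose raw bytes can be reinterpreted per dtype;
+// F16/BF16 values are widened to float before min/max/mean are computed.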
+inline void log_tensor_stats(const infinicore::Tensor &tensor, const std::string &name, + bool log_samples = true, size_t max_samples = 10) { + auto shape = tensor->shape(); + auto dtype = tensor->dtype(); + auto device = tensor->device(); + + // Log basic info + std::string shape_str = "["; + for (size_t i = 0; i < shape.size(); ++i) { + if (i > 0) shape_str += ", "; + shape_str += std::to_string(shape[i]); + } + shape_str += "]"; + + SPDLOG_INFO(" {}: shape={}, dtype={}, device={}", name, shape_str, static_cast(dtype), device.toString()); + + // For F32, F16, and BF16 tensors, compute and log statistics + if (dtype == infinicore::DataType::F32 || + dtype == infinicore::DataType::F16 || + dtype == infinicore::DataType::BF16) { + // Copy to CPU if needed and compute stats + auto cpu_tensor = tensor->to(infinicore::Device(infinicore::Device::Type::CPU, 0)); + std::byte *raw_data = cpu_tensor->data(); + size_t numel = cpu_tensor->numel(); + + if (numel > 0) { + if (dtype == infinicore::DataType::F32) { + float *data = reinterpret_cast(raw_data); + float min_val = *std::min_element(data, data + numel); + float max_val = *std::max_element(data, data + numel); + float sum = std::accumulate(data, data + numel, 0.0f); + float mean_val = sum / static_cast(numel); + + SPDLOG_INFO(" Stats: min={:.6e}, max={:.6e}, mean={:.6e}, numel={}", + min_val, max_val, mean_val, numel); + + // Log sample values at specific positions + if (log_samples && numel > 0) { + size_t sample_count = std::min(max_samples, numel); + SPDLOG_INFO(" Sample values (first {}):", sample_count); + for (size_t i = 0; i < sample_count; ++i) { + SPDLOG_INFO(" [{}] = {:.6e}", i, data[i]); + } + } + } else if (dtype == infinicore::DataType::F16) { + // F16 is typically uint16_t, need to convert to float for logging + uint16_t *data = reinterpret_cast(raw_data); + std::vector float_data(numel); + for (size_t i = 0; i < numel; ++i) { + // Simple F16 to F32 conversion (approximate) + uint16_t h = data[i]; + uint32_t sign = (h >> 15) & 0x1; + uint32_t exp = (h >> 10) & 0x1F; + uint32_t mant = h & 0x3FF; + uint32_t f32 = (sign << 31) | ((exp + 112) << 23) | (mant << 13); + float_data[i] = *reinterpret_cast(&f32); + } + float min_val = *std::min_element(float_data.begin(), float_data.end()); + float max_val = *std::max_element(float_data.begin(), float_data.end()); + float sum = std::accumulate(float_data.begin(), float_data.end(), 0.0f); + float mean_val = sum / static_cast(numel); + + SPDLOG_INFO(" Stats (F16): min={:.6e}, max={:.6e}, mean={:.6e}, numel={}", + min_val, max_val, mean_val, numel); + + if (log_samples && numel > 0) { + size_t sample_count = std::min(max_samples, numel); + SPDLOG_INFO(" Sample values (first {}):", sample_count); + for (size_t i = 0; i < sample_count; ++i) { + SPDLOG_INFO(" [{}] = {:.6e}", i, float_data[i]); + } + } + } else if (dtype == infinicore::DataType::BF16) { + // BF16 is typically uint16_t, need to convert to float for logging + uint16_t *data = reinterpret_cast(raw_data); + std::vector float_data(numel); + for (size_t i = 0; i < numel; ++i) { + // BF16 to F32 conversion + uint16_t b = data[i]; + uint32_t f32 = (static_cast(b) << 16); + float_data[i] = *reinterpret_cast(&f32); + } + float min_val = *std::min_element(float_data.begin(), float_data.end()); + float max_val = *std::max_element(float_data.begin(), float_data.end()); + float sum = std::accumulate(float_data.begin(), float_data.end(), 0.0f); + float mean_val = sum / static_cast(numel); + + SPDLOG_INFO(" Stats (BF16): min={:.6e}, max={:.6e}, 
mean={:.6e}, numel={}", + min_val, max_val, mean_val, numel); + + if (log_samples && numel > 0) { + size_t sample_count = std::min(max_samples, numel); + SPDLOG_INFO(" Sample values (first {}):", sample_count); + for (size_t i = 0; i < sample_count; ++i) { + SPDLOG_INFO(" [{}] = {:.6e}", i, float_data[i]); + } + + // Also log last N values to see newly appended decode tokens + // This is critical for debugging precision issues at decode steps + if (numel > sample_count) { + SPDLOG_INFO(" Sample values (last {}):", sample_count); + for (size_t i = numel - sample_count; i < numel; ++i) { + SPDLOG_INFO(" [{}] = {:.6e}", i, float_data[i]); + } + } + } + } + } + } else { + SPDLOG_INFO(" {} (Stats computation skipped for unsupported dtype)", name); + } +} + +// Helper function to log specific tensor positions (for debugging) +// NOTE: This is a debug utility. Do not use in production code. +inline void log_tensor_positions(const infinicore::Tensor &tensor, const std::string &name, + const std::vector> &positions) { + auto shape = tensor->shape(); + auto dtype = tensor->dtype(); + + // Only log for F32 tensors (or copy to CPU) + if (dtype == infinicore::DataType::F32) { + auto cpu_tensor = tensor->to(infinicore::Device(infinicore::Device::Type::CPU, 0)); + std::byte *raw_data = cpu_tensor->data(); + float *data = reinterpret_cast(raw_data); + + SPDLOG_INFO(" {}: Logging specific positions:", name); + for (const auto &pos : positions) { + if (pos.size() != shape.size()) { + SPDLOG_INFO(" Position {}: dimension mismatch (expected {} dims, got {})", + pos.size(), shape.size()); + continue; + } + + // Calculate linear index + size_t idx = 0; + size_t stride = 1; + bool valid = true; + for (int i = static_cast(shape.size()) - 1; i >= 0; --i) { + if (pos[i] >= shape[i]) { + valid = false; + break; + } + idx += pos[i] * stride; + stride *= shape[i]; + } + + if (valid && idx < cpu_tensor->numel()) { + std::string pos_str = "["; + for (size_t i = 0; i < pos.size(); ++i) { + if (i > 0) pos_str += ", "; + pos_str += std::to_string(pos[i]); + } + pos_str += "]"; + SPDLOG_INFO(" Position {}: value = {:.6e}", pos_str, data[idx]); + } else { + std::string pos_str = "["; + for (size_t i = 0; i < pos.size(); ++i) { + if (i > 0) pos_str += ", "; + pos_str += std::to_string(pos[i]); + } + pos_str += "]"; + SPDLOG_INFO(" Position {}: invalid (out of bounds)", pos_str); + } + } + } +} + +} // namespace infinilm::models::debug_utils diff --git a/csrc/models/infinilm_model.hpp b/csrc/models/infinilm_model.hpp new file mode 100644 index 00000000..95f5e712 --- /dev/null +++ b/csrc/models/infinilm_model.hpp @@ -0,0 +1,18 @@ +#pragma once + +#include "infinicore/nn/module.hpp" + +#include "../cache/cache.hpp" + +#include + +namespace infinilm { +class InfinilmModel : public infinicore::nn::Module { +public: + virtual ~InfinilmModel() = default; + virtual infinicore::Tensor forward(std::vector) const = 0; + // Optional: reset cache; default no-op for models without cache + virtual void reset_cache(size_t pos = 0) {} + virtual void reset_cache(const cache::CacheConfig &new_config, size_t pos = 0) = 0; +}; +} // namespace infinilm diff --git a/csrc/models/llama/llama.hpp b/csrc/models/llama/llama.hpp new file mode 100644 index 00000000..fe554c32 --- /dev/null +++ b/csrc/models/llama/llama.hpp @@ -0,0 +1,24 @@ +#pragma once + +/** + * @file llama.hpp + * @brief Main header file for Llama model architecture + * + * This header includes all components of the Llama model architecture + * built using InfiniCore::nn::Module 
pattern. + * + * Components: + * - LlamaConfig: Model configuration structure + * - LlamaAttention: Multi-head self-attention module + * - LlamaMLP: Feed-forward network module + * - LlamaDecoderLayer: Single transformer decoder layer + * - LlamaModel: Core transformer model (without LM head) + * - LlamaForCausalLM: Complete model with language modeling head + */ + +#include "llama_config.hpp" +#include "llama_attention.hpp" +#include "llama_mlp.hpp" +#include "llama_decoder_layer.hpp" +#include "llama_model.hpp" +#include "llama_for_causal_lm.hpp" diff --git a/csrc/models/llama/llama_attention.cpp b/csrc/models/llama/llama_attention.cpp new file mode 100644 index 00000000..8d951f3f --- /dev/null +++ b/csrc/models/llama/llama_attention.cpp @@ -0,0 +1,143 @@ +#include "llama_attention.hpp" +#include "infinicore/nn/linear.hpp" +#include "infinicore/nn/rope.hpp" +#include "infinicore/ops.hpp" +#include "infinicore/ops/mul.hpp" +#include +#include +#include +#include +#include +#include +#include + +namespace infinilm::models::llama { + +LlamaAttention::LlamaAttention(const LlamaConfig &config, + const infinicore::Device &device, + size_t layer_idx, + infinicore::DataType dtype, + engine::distributed::RankInfo rank_info) + : layer_idx_(layer_idx), + hidden_size_(config.hidden_size), + num_attention_heads_(config.num_attention_heads), + num_key_value_heads_(config.num_key_value_heads), + head_dim_(config.head_dim), + kv_dim_(config.kv_dim()), + use_bias_(config.attention_bias), + use_output_bias_(config.attention_output_bias), + max_position_embeddings_(config.max_position_embeddings), rank_info_(rank_info) { + + int tp_rank = rank_info.tp_rank; + int tp_size = rank_info.tp_size; + + int num_attention_heads = config.num_attention_heads; + int num_key_value_heads = config.num_key_value_heads; + + if ((num_key_value_heads >= tp_size) && (0 == (num_key_value_heads % tp_size))) { + this->num_attention_heads_ = num_attention_heads / tp_size; + this->num_key_value_heads_ = num_key_value_heads / tp_size; + } else { + throw std::runtime_error("num_attention_heads / tp_size error."); + } + + // Initialize projection layers + INFINILM_QKV_LINEAR_INIT(qkv_proj, "q_proj", "k_proj", "v_proj", hidden_size_, head_dim_, config.num_attention_heads, config.num_key_value_heads, use_bias_, + dtype, device, rank_info); + // Output projection uses attention_output_bias (can be different from qkv) + INFINICORE_NN_MODULE_INIT(o_proj, hidden_size_, hidden_size_, use_output_bias_, + dtype, device, tp_rank, tp_size, rank_info.comm); +} + +infinicore::Tensor LlamaAttention::forward(const infinicore::Tensor &hidden_states, + const infinicore::Tensor &position_ids, + void *kv_cache) const { + if (!rotary_emb_) { + throw std::runtime_error("LlamaAttention: rotary_emb not configured"); + } + // Input shape: [batch, seq_len, hidden_size] + auto hidden_states_mutable = hidden_states; + auto shape = hidden_states->shape(); + size_t batch_size = shape[0]; + size_t seq_len = shape[1]; + + // 1. Project Q, K, V + auto [q, k, v] = qkv_proj_->forward_split(hidden_states_mutable); + + // 2. 
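(Step 2 of the forward pass continues right after this editorial sketch.) The block below is a minimal, standalone illustration of the grouped-query-attention head bookkeeping used in the reshapes here and in the grouped matmul later in this function. It is plain standard C++ with made-up sizes; nothing in it depends on InfiniCore.

```cpp
// Editorial sketch: GQA head arithmetic with illustrative sizes.
#include <cstddef>
#include <cstdio>

int main() {
    const std::size_t head_dim = 128;
    const std::size_t num_attention_heads = 32; // Q heads on this rank
    const std::size_t num_key_value_heads = 8;  // KV heads on this rank (GQA)
    const std::size_t batch = 2, seq_len = 16;

    // How many Q heads share one K/V head.
    const std::size_t ngroup = num_attention_heads / num_key_value_heads;

    std::printf("q view    : [%zu, %zu, %zu, %zu]\n", batch, seq_len, num_attention_heads, head_dim);
    std::printf("k/v view  : [%zu, %zu, %zu, %zu]\n", batch, seq_len, num_key_value_heads, head_dim);
    // Shape used for the grouped attention matmul further below.
    std::printf("grouped Q : [%zu, %zu, %zu]\n", batch * num_key_value_heads, ngroup * seq_len, head_dim);
    return 0;
}
```

// 2.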
Reshape for multi-head attention + + // Reshape Q, K, V to include batch dimension + // Python: query_states = self.q_proj(hidden_states).view(querys_shape) + // The view operation requires the tensor to be contiguous in the required dimensions + auto q_reshaped = q->view({batch_size, seq_len, num_attention_heads_, head_dim_}); + auto k_reshaped = k->view({batch_size, seq_len, num_key_value_heads_, head_dim_}); + auto v_reshaped = v->view({batch_size, seq_len, num_key_value_heads_, head_dim_}); + + // 3. Prepare position_ids for RoPE - align with Python pattern + // Python: bs, num = pos_ids.shape; pos_ids = pos_ids.view((bs * num,)) + auto pos_shape = position_ids->shape(); + infinicore::Tensor pos_ids_for_rope = position_ids; + if (pos_shape.size() == 2) { + auto pos_narrowed = position_ids->narrow({{0, 0, 1}}); + pos_ids_for_rope = pos_narrowed->contiguous()->view({pos_shape[1]}); + } else if (pos_shape.size() == 1) { + pos_ids_for_rope = position_ids->contiguous(); + } else { + throw std::runtime_error("Unexpected position_ids shape"); + } + + // 4. Apply RoPE to Q and K + auto q_rope = infinicore::Tensor::empty({batch_size, num_attention_heads_, seq_len, head_dim_}, q_reshaped->dtype(), q_reshaped->device())->permute({0, 2, 1, 3}); + rotary_emb_->forward(q_rope, q_reshaped, pos_ids_for_rope); // [bs, seq_len, n_q_head, head_dim] + rotary_emb_->forward(k_reshaped, pos_ids_for_rope, true); // [bs, seq_len, n_kv_head, head_dim] + + // 5. Prepare KV caches + // Convert to [batch, n_head, seq_len, head_dim] for cache + // Ensure contiguous after permute for F16 compatibility with cache operations + q_reshaped = q_rope->permute({0, 2, 1, 3}); // [bs, n_q_head, seq_len, head_dim] + auto k_permuted = k_reshaped->permute({0, 2, 1, 3}); // [bs, n_kv_head, seq_len, head_dim] + auto v_permuted = v_reshaped->permute({0, 2, 1, 3}); // [bs, n_kv_head, seq_len, head_dim] + infinilm::cache::DynamicCache *external_cache = static_cast(kv_cache); + infinicore::Tensor k_total; // [bs, n_kv_head, total_seq_len, head_dim] + infinicore::Tensor v_total; // [bs, n_kv_head, total_seq_len, head_dim] + if (external_cache != nullptr) { + auto [k_total_tmp, v_total_tmp] = external_cache->update(layer_idx_, k_permuted, v_permuted); + k_total = k_total_tmp; + v_total = v_total_tmp; + } else { + // No external cache - this shouldn't happen in normal operation, but handle gracefully + throw std::runtime_error("LlamaAttention: kv_cache is required but nullptr provided"); + } + auto total_seq_len = k_total->shape()[2]; + + // 6. 
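(Step 6 of the forward pass continues right after this editorial sketch.) Before the grouped matmul / causal-softmax / matmul sequence below, here is a self-contained reference of causal scaled-dot-product attention for a single head on plain `std::vector` data. Sizes and values are made up; it only illustrates the math, not the InfiniCore kernels.

```cpp
// Editorial sketch: single-head causal scaled-dot-product attention.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    const int seq = 3, head_dim = 2;
    // Row-major [seq][head_dim] toy data for one head.
    std::vector<std::vector<float>> Q = {{1, 0}, {0, 1}, {1, 1}};
    std::vector<std::vector<float>> K = Q;
    std::vector<std::vector<float>> V = {{1, 2}, {3, 4}, {5, 6}};
    const float scale = 1.0f / std::sqrt(static_cast<float>(head_dim));

    for (int i = 0; i < seq; ++i) {
        // scores = scale * Q K^T, with future positions masked out (causal mask).
        std::vector<float> w(seq, 0.0f);
        float max_w = -1e30f;
        for (int j = 0; j <= i; ++j) {
            for (int d = 0; d < head_dim; ++d) w[j] += Q[i][d] * K[j][d];
            w[j] *= scale;
            max_w = std::max(max_w, w[j]);
        }
        // Numerically stable softmax over positions 0..i only.
        float denom = 0.0f;
        for (int j = 0; j <= i; ++j) { w[j] = std::exp(w[j] - max_w); denom += w[j]; }
        // Output row i = softmax(scores) * V.
        std::vector<float> out(head_dim, 0.0f);
        for (int j = 0; j <= i; ++j)
            for (int d = 0; d < head_dim; ++d) out[d] += (w[j] / denom) * V[j][d];
        std::printf("out[%d] = (%.3f, %.3f)\n", i, out[0], out[1]);
    }
    return 0;
}
```

// 6.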
Compute attention + size_t ngroup = num_attention_heads_ / num_key_value_heads_; + auto Q = q_reshaped->view({batch_size * num_key_value_heads_, ngroup * seq_len, head_dim_}); + auto K = k_total->view({batch_size * num_key_value_heads_, total_seq_len, head_dim_}); + auto V = v_total->view({batch_size * num_key_value_heads_, total_seq_len, head_dim_}); + + auto K_transposed = K->permute({0, 2, 1}); // [bs * n_kv_head, head_dim, total_seq_len] + + float scaling = 1.0f / std::sqrt(static_cast(head_dim_)); + auto attn_weight = infinicore::op::matmul(Q, K_transposed, scaling); // [bs * n_kv_head, ng * seq_len, total_seq_len] + + auto attn_weight_softmax = attn_weight->view({batch_size * num_attention_heads_, seq_len, total_seq_len}); + infinicore::op::causal_softmax_(attn_weight_softmax, attn_weight_softmax); + + auto out = infinicore::op::matmul(attn_weight, V); // [bs * n_kv_head, ng * seq_len, head_dim] + + auto attn_output = out->view({batch_size, num_attention_heads_, seq_len, head_dim_}) + ->permute({0, 2, 1, 3}) + ->contiguous() + ->view({batch_size, seq_len, num_attention_heads_ * head_dim_}); // [bs, seq_len, n_q_head * head_dim] + + auto output = o_proj_->forward(attn_output); + + return output; +} + +void LlamaAttention::set_rotary_emb(const std::shared_ptr &rotary_emb) { + rotary_emb_ = rotary_emb; +} + +} // namespace infinilm::models::llama diff --git a/csrc/models/llama/llama_attention.hpp b/csrc/models/llama/llama_attention.hpp new file mode 100644 index 00000000..bdc98d23 --- /dev/null +++ b/csrc/models/llama/llama_attention.hpp @@ -0,0 +1,94 @@ +#pragma once + +#include "../../cache/kv_cache.hpp" +#include "../../engine/distributed/distributed.hpp" +#include "../../layers/fused_linear.hpp" +#include "llama_config.hpp" + +#include "infinicore/nn/linear.hpp" +#include "infinicore/nn/module.hpp" +#include "infinicore/nn/rope.hpp" +#include "infinicore/tensor.hpp" +#include "llama_config.hpp" +#include +#include +#include + +namespace infinilm::models::llama { + +/** + * @brief Multi-head self-attention module for Llama + * + * Implements the attention mechanism with: + * - Query, Key, Value projections + * - Output projection + * - Rotary Position Embeddings (RoPE) applied to Q and K + * - Support for Grouped Query Attention (GQA) + */ +class LlamaAttention : public infinicore::nn::Module { +public: + /** + * @brief Construct LlamaAttention module + * + * @param config Model configuration + * @param device Device to create tensors on + * @param layer_idx Layer index for cache access + * @param dtype Optional data type for model parameters (defaults to F32) + */ + LlamaAttention(const LlamaConfig &config, + const infinicore::Device &device, + size_t layer_idx, + infinicore::DataType dtype = infinicore::DataType::F32, + engine::distributed::RankInfo rank_info = engine::distributed::RankInfo()); + + /** + * @brief Forward pass: compute attention + * + * @param hidden_states Input tensor of shape [batch, seq_len, hidden_size] + * @param position_ids Position IDs tensor of shape [batch, seq_len] or [seq_len] + * @param kv_cache Optional model-level KV cache for incremental decoding + * @return Output tensor of shape [batch, seq_len, hidden_size] + */ + infinicore::Tensor forward(const infinicore::Tensor &hidden_states, + const infinicore::Tensor &position_ids, + void *kv_cache = nullptr) const; + + /** + * @brief Get the layer index + */ + size_t layer_idx() const { return layer_idx_; } + + /** + * @brief Provide shared RoPE module from parent model. 
+ */ + void set_rotary_emb(const std::shared_ptr &rotary_emb); + + // Module information + size_t num_heads() const { return num_attention_heads_; } + size_t num_kv_heads() const { return num_key_value_heads_; } + size_t head_dim() const { return head_dim_; } + size_t hidden_size() const { return hidden_size_; } + +protected: + // Projection layers + INFINICORE_NN_MODULE(infinilm::layers::QKVParallelLinear, qkv_proj); + INFINICORE_NN_MODULE(infinicore::nn::RowParallelLinear, o_proj); + + engine::distributed::RankInfo rank_info_; + + // Shared Rotary Position Embeddings (RoPE) + std::shared_ptr rotary_emb_; + +private: + size_t layer_idx_; // Layer index for cache access + size_t hidden_size_; + size_t num_attention_heads_; + size_t num_key_value_heads_; + size_t head_dim_; + size_t kv_dim_; + bool use_bias_; // Bias for Q/K/V projections + bool use_output_bias_; // Bias for output projection (o_proj) + size_t max_position_embeddings_; // For cache initialization (deprecated, kept for compatibility) +}; + +} // namespace infinilm::models::llama diff --git a/csrc/models/llama/llama_config.hpp b/csrc/models/llama/llama_config.hpp new file mode 100644 index 00000000..9fb14ce8 --- /dev/null +++ b/csrc/models/llama/llama_config.hpp @@ -0,0 +1,85 @@ +#pragma once + +#include +#include +#include +#include + +namespace infinilm::models::llama { + +/** + * @brief Configuration structure for Llama model architecture + * + * This struct holds all hyperparameters needed to construct a Llama model. + * It follows the same structure as HuggingFace's LlamaConfig. + */ +struct LlamaConfig { + // Vocabulary and embedding + size_t vocab_size = 32000; // Vocabulary size + size_t hidden_size = 4096; // Hidden dimension size + size_t intermediate_size = 11008; // MLP intermediate dimension + + // Architecture + size_t num_hidden_layers = 32; // Number of decoder layers + size_t num_attention_heads = 32; // Number of attention heads + size_t num_key_value_heads = 32; // Number of key-value heads (for GQA) + size_t head_dim = 128; // Attention head dimension (hidden_size / num_attention_heads) + + // Position embeddings + size_t max_position_embeddings = 2048; // Maximum sequence length + double rope_theta = 10000.0; // RoPE base frequency + + // Normalization + double rms_norm_eps = 1e-6; // RMSNorm epsilon + + // Activation + std::string hidden_act = "silu"; // Activation function (typically "silu") + std::string model_type = "llama"; // Model type identifier (matches HF configs) + + // Optional features + bool use_cache = true; // Whether to use KV cache + bool attention_bias = true; // Whether to use bias in Q/K/V projections (default true for 9G7B compatibility) + bool attention_output_bias = false; // Whether to use bias in output projection (o_proj) + bool mlp_bias = false; // Whether to use bias in MLP projections + bool tie_word_embeddings = false; // Whether to tie input/output embeddings + + // Training/initialization parameters + double attention_dropout = 0.0; // Dropout ratio for attention probabilities + double initializer_range = 0.02; // Standard deviation for weight initialization + size_t pretraining_tp = 1; // Tensor parallelism rank used during pretraining + + // Model metadata + std::string name_or_path = ""; // Model name or path identifier + + // Token IDs + int64_t pad_token_id = -1; // Padding token ID (optional) + std::vector bos_token_id = {1}; // Beginning of sequence token ID(s) + std::vector eos_token_id = {2}; // End of sequence token ID(s) + + /** + * @brief Compute key-value 
dimension for Grouped Query Attention (GQA) + * @return The dimension for key/value projections + */ + size_t kv_dim() const { + return hidden_size * num_key_value_heads / num_attention_heads; + } + + /** + * @brief Validate configuration parameters + * @return true if configuration is valid + */ + bool validate() const { + if (hidden_size % num_attention_heads != 0) { + return false; + } + if (num_attention_heads % num_key_value_heads != 0) { + return false; + } + if (head_dim != hidden_size / num_attention_heads) { + return false; + } + return true; + } +}; + +} // namespace infinilm::models::llama diff --git a/csrc/models/llama/llama_decoder_layer.cpp b/csrc/models/llama/llama_decoder_layer.cpp new file mode 100644 index 00000000..f8c900c9 --- /dev/null +++ b/csrc/models/llama/llama_decoder_layer.cpp @@ -0,0 +1,52 @@ +#include "llama_decoder_layer.hpp" +#include "infinicore/nn/rmsnorm.hpp" +#include "infinicore/ops.hpp" + +namespace infinilm::models::llama { + +LlamaDecoderLayer::LlamaDecoderLayer(const LlamaConfig &config, + const infinicore::Device &device, + size_t layer_idx, + infinicore::DataType dtype, + engine::distributed::RankInfo rank_info) : layer_idx_(layer_idx) , rank_info_(rank_info){ + // Initialize layer normalization layers + INFINICORE_NN_MODULE_INIT(input_layernorm, config.hidden_size, config.rms_norm_eps, + dtype, device); + INFINICORE_NN_MODULE_INIT(post_attention_layernorm, config.hidden_size, config.rms_norm_eps, + dtype, device); + + // Initialize attention and MLP modules + INFINICORE_NN_MODULE_INIT(self_attn, config, device, layer_idx, dtype, rank_info_); + INFINICORE_NN_MODULE_INIT(mlp, config, device, dtype, rank_info_); +} + +infinicore::Tensor LlamaDecoderLayer::forward(const infinicore::Tensor &hidden_states, + const infinicore::Tensor &position_ids, + void *kv_cache) const { + // Save residual for attention + auto residual = hidden_states; + + // 1. Pre-attention layer normalization + auto normed_states = input_layernorm_->forward(hidden_states); + + // 2. Self-attention with residual connection + auto attn_output = self_attn_->forward(normed_states, position_ids, kv_cache); + + // Add residual: hidden_states = hidden_states + attn_output + auto output = infinicore::op::add(residual, attn_output); + // Save residual for MLP + residual = output; + + // 3. Post-attention layer normalization + normed_states = post_attention_layernorm_->forward(output); + + // 4. 
MLP with residual connection + auto mlp_output = mlp_->forward(normed_states); + + // Add residual: output = output + mlp_output + output = infinicore::op::add(residual, mlp_output); + + return output; +} + +} // namespace infinilm::models::llama diff --git a/csrc/models/llama/llama_decoder_layer.hpp b/csrc/models/llama/llama_decoder_layer.hpp new file mode 100644 index 00000000..d6b25f9c --- /dev/null +++ b/csrc/models/llama/llama_decoder_layer.hpp @@ -0,0 +1,79 @@ +#pragma once + +#include "infinicore/device.hpp" +#include "infinicore/nn/module.hpp" +#include "infinicore/nn/rmsnorm.hpp" +#include "infinicore/tensor.hpp" +#include "llama_attention.hpp" +#include "llama_config.hpp" +#include "llama_mlp.hpp" + +#include "../../engine/distributed/distributed.hpp" + +namespace infinilm::models::llama { + +/** + * @brief Single decoder layer (transformer block) for Llama + * + * Each decoder layer consists of: + * - Input layer normalization (RMSNorm) + * - Self-attention mechanism + * - Post-attention layer normalization (RMSNorm) + * - MLP feed-forward network + * + * Residual connections are applied around both attention and MLP blocks. + */ +class LlamaDecoderLayer : public infinicore::nn::Module { +public: + /** + * @brief Construct LlamaDecoderLayer module + * + * @param config Model configuration + * @param device Device to create tensors on + * @param layer_idx Layer index for cache management and debugging + * @param dtype Optional data type for model parameters (defaults to F32) + */ + LlamaDecoderLayer(const LlamaConfig &config, + const infinicore::Device &device, + size_t layer_idx, + infinicore::DataType dtype = infinicore::DataType::F32, + engine::distributed::RankInfo rank_info = engine::distributed::RankInfo()); + + /** + * @brief Forward pass: process one decoder layer + * + * @param hidden_states Input tensor of shape [batch, seq_len, hidden_size] + * @param position_ids Position IDs tensor of shape [batch, seq_len] or [seq_len] + * @param kv_cache Optional KV cache for incremental decoding + * @return Output tensor of shape [batch, seq_len, hidden_size] + */ + infinicore::Tensor forward(const infinicore::Tensor &hidden_states, + const infinicore::Tensor &position_ids, + void *kv_cache = nullptr) const; + + /** + * @brief Get the layer index + */ + size_t layer_idx() const { return layer_idx_; } + + void set_rotary_emb(const std::shared_ptr &rotary_emb) { + if (self_attn_) { + self_attn_->set_rotary_emb(rotary_emb); + } + } + +protected: + // Layer normalization + INFINICORE_NN_MODULE(infinicore::nn::RMSNorm, input_layernorm); + INFINICORE_NN_MODULE(infinicore::nn::RMSNorm, post_attention_layernorm); + + // Attention and MLP + INFINICORE_NN_MODULE(LlamaAttention, self_attn); + INFINICORE_NN_MODULE(LlamaMLP, mlp); + engine::distributed::RankInfo rank_info_; + +private: + size_t layer_idx_; // Layer index for cache management and debugging +}; + +} // namespace infinilm::models::llama diff --git a/csrc/models/llama/llama_for_causal_lm.cpp b/csrc/models/llama/llama_for_causal_lm.cpp new file mode 100644 index 00000000..95a3bdf1 --- /dev/null +++ b/csrc/models/llama/llama_for_causal_lm.cpp @@ -0,0 +1,73 @@ +#include "llama_for_causal_lm.hpp" +#include "infinicore/context/context.hpp" +#include "infinicore/nn/linear.hpp" +#include "infinicore/ops.hpp" +#include + +namespace infinilm::models::llama { + +LlamaForCausalLM::LlamaForCausalLM(const LlamaConfig &config, + const infinicore::Device &device, + infinicore::DataType dtype, + engine::distributed::RankInfo rank_info) { + + // 
Initialize module's device_ member + device_ = device; + + // Initialize base model + INFINICORE_NN_MODULE_INIT(model, config, device, dtype, rank_info); + + // Initialize language modeling head + // Note: If tie_word_embeddings is true, we would share weights with embed_tokens + // For now, we create a separate linear layer + INFINICORE_NN_MODULE_INIT(lm_head, config.hidden_size, config.vocab_size, false, + dtype, device); +} + +infinicore::Tensor LlamaForCausalLM::forward(const infinicore::Tensor &input_ids, + const infinicore::Tensor &position_ids, + void *kv_cache) const { + // 1. Forward through base model to get hidden states + auto position_ids_device = position_ids->to(device_); + auto hidden_states = model_->forward(input_ids, position_ids_device, kv_cache); + + // 2. Apply language modeling head to get logits + auto logits = lm_head_->forward(hidden_states); + + // 3. CRITICAL: Synchronize the C++ backend's context after forward pass + // This ensures all C++ backend operations complete before returning to Python + if (device_.getType() != infinicore::Device::Type::CPU) { + infinicore::context::setDevice(device_, false); + infinicore::context::syncStream(); + } + + return logits; +} + +infinicore::Tensor LlamaForCausalLM::forward(std::vector args) const { + if (args.size() < 2) { + throw std::invalid_argument("LlamaForCausalLM::forward requires at least 2 arguments: input_ids and position_ids"); + } + + // Extract input tensors from args + const auto &input_ids = std::any_cast(args[0]); + const auto &position_ids = std::any_cast(args[1]); + + // Optional KV caches + std::vector *kv_caches = nullptr; + if (args.size() >= 3) { + kv_caches = std::any_cast *>(args[2]); + } + + return forward(input_ids, position_ids, kv_caches); +} + +void LlamaForCausalLM::reset_cache(size_t pos) { + model_->reset_cache(pos); +} + +void LlamaForCausalLM::reset_cache(const cache::CacheConfig &new_config, size_t pos) { + model_->reset_cache(new_config, pos); +} + +} // namespace infinilm::models::llama diff --git a/csrc/models/llama/llama_for_causal_lm.hpp b/csrc/models/llama/llama_for_causal_lm.hpp new file mode 100644 index 00000000..c78811b5 --- /dev/null +++ b/csrc/models/llama/llama_for_causal_lm.hpp @@ -0,0 +1,68 @@ +#pragma once + +#include "../infinilm_model.hpp" +#include "llama_model.hpp" + +#include "infinicore/device.hpp" +#include "infinicore/nn/linear.hpp" +#include "infinicore/nn/module.hpp" +#include "infinicore/tensor.hpp" + +#include "../../engine/distributed/distributed.hpp" + +namespace infinilm::models::llama { + +/** + * @brief Llama model for Causal Language Modeling + * + * Extends LlamaModel by adding a language modeling head (lm_head) that + * projects hidden states to vocabulary logits. + * + * This matches the structure of HuggingFace's LlamaForCausalLM. 
+ */ +class LlamaForCausalLM : public InfinilmModel { +public: + /** + * @brief Construct LlamaForCausalLM module + * + * @param config Model configuration + * @param device Device to create tensors on + * @param dtype Optional data type for model parameters (defaults to BF16) + */ + LlamaForCausalLM(const LlamaConfig &config, + const infinicore::Device &device, + infinicore::DataType dtype = infinicore::DataType::BF16, + engine::distributed::RankInfo rank_info = engine::distributed::RankInfo()); + + /** + * @brief Forward pass: compute language modeling logits + * + * @param input_ids Token IDs tensor of shape [batch, seq_len] + * @param position_ids Position IDs tensor of shape [batch, seq_len] or [seq_len] + * @param kv_cache Optional model-level KV cache for incremental decoding + * @return Logits tensor of shape [batch, seq_len, vocab_size] + */ + infinicore::Tensor forward(const infinicore::Tensor &input_ids, + const infinicore::Tensor &position_ids, + void *kv_cache = nullptr) const; + + infinicore::Tensor forward(std::vector args) const override; + + // Reset internal cache position + void reset_cache(size_t pos = 0) override; + void reset_cache(const cache::CacheConfig &new_config, size_t pos) override; + + // Module information + const LlamaConfig &config() const { return model_->config(); } + LlamaModel &model() { return *model_; } + const LlamaModel &model() const { return *model_; } + +protected: + // Base model + INFINICORE_NN_MODULE(LlamaModel, model); + + // Language modeling head + INFINICORE_NN_MODULE(infinicore::nn::Linear, lm_head); +}; + +} // namespace infinilm::models::llama diff --git a/csrc/models/llama/llama_mlp.cpp b/csrc/models/llama/llama_mlp.cpp new file mode 100644 index 00000000..1a9a3619 --- /dev/null +++ b/csrc/models/llama/llama_mlp.cpp @@ -0,0 +1,41 @@ +#include "llama_mlp.hpp" +#include "infinicore/nn/linear.hpp" +#include "infinicore/ops.hpp" + +namespace infinilm::models::llama { + +LlamaMLP::LlamaMLP(const LlamaConfig &config, + const infinicore::Device &device, + infinicore::DataType dtype, + engine::distributed::RankInfo rank_info) + : hidden_size_(config.hidden_size), + intermediate_size_(config.intermediate_size), + use_bias_(config.mlp_bias), rank_info_(rank_info) { + + int tp_rank = rank_info.tp_rank; + int tp_size = rank_info.tp_size; + + // Initialize projection layers + INFINILM_GATE_UP_LINEAR_INIT(gate_up_proj, "gate_proj", "up_proj", hidden_size_, intermediate_size_, use_bias_, + dtype, device, rank_info_); + INFINICORE_NN_MODULE_INIT(down_proj, intermediate_size_, hidden_size_, use_bias_, + dtype, device, tp_rank, tp_size, rank_info.comm); +} + +infinicore::Tensor LlamaMLP::forward(const infinicore::Tensor &hidden_states) const { + // 1. Project to gate and up + auto hidden_states_mutable = hidden_states; + auto [gate, up] = gate_up_proj_->forward_split(hidden_states_mutable); + + // 2. Apply SwiGLU: silu(gate) * up + // Note: swiglu kernel expects (up, gate) and computes gate * sigmoid(gate) * up + // So we pass (up, gate) to get the correct result: gate * sigmoid(gate) * up + auto intermediate = infinicore::op::swiglu(up, gate); + + // 3. 
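(Step 3 of the MLP continues right after this editorial sketch.) Since the fused kernel's argument order is easy to get wrong, here is the SwiGLU computation written out on scalars, following the convention described in the comment above (`swiglu(up, gate) = silu(gate) * up`). The values are arbitrary.

```cpp
// Editorial sketch: SwiGLU on scalars, matching the (up, gate) argument order noted above.
#include <cmath>
#include <cstdio>

static float silu(float x) { return x / (1.0f + std::exp(-x)); } // x * sigmoid(x)

static float swiglu(float up, float gate) { return silu(gate) * up; }

int main() {
    const float gate = 1.5f, up = 2.0f;
    std::printf("silu(gate)       = %.6f\n", silu(gate));       // ~1.226
    std::printf("swiglu(up, gate) = %.6f\n", swiglu(up, gate)); // ~2.453
    return 0;
}
```

// 3.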
Project down + auto output = down_proj_->forward(intermediate); + + return output; +} + +} // namespace infinilm::models::llama diff --git a/csrc/models/llama/llama_mlp.hpp b/csrc/models/llama/llama_mlp.hpp new file mode 100644 index 00000000..79eff520 --- /dev/null +++ b/csrc/models/llama/llama_mlp.hpp @@ -0,0 +1,63 @@ +#pragma once + +#include "../../layers/fused_linear.hpp" +#include "llama_config.hpp" + +#include "infinicore/device.hpp" +#include "infinicore/nn/linear.hpp" +#include "infinicore/nn/module.hpp" +#include "infinicore/tensor.hpp" +#include "llama_config.hpp" + +#include "../../engine/distributed/distributed.hpp" + +namespace infinilm::models::llama { + +/** + * @brief MLP (Feed-Forward Network) module for Llama + * + * Implements the MLP block with: + * - Gate projection + * - Up projection + * - Down projection + * - SiLU activation function + * + * Formula: down_proj(SiLU(gate_proj(x)) * up_proj(x)) + */ +class LlamaMLP : public infinicore::nn::Module { +public: + /** + * @brief Construct LlamaMLP module + * + * @param config Model configuration + * @param device Device to create tensors on + * @param dtype Optional data type for model parameters (defaults to F32) + */ + LlamaMLP(const LlamaConfig &config, + const infinicore::Device &device, + infinicore::DataType dtype = infinicore::DataType::F32, + engine::distributed::RankInfo rank_info = engine::distributed::RankInfo()); + + /** + * @brief Forward pass: compute MLP output + * + * @param hidden_states Input tensor of shape [batch, seq_len, hidden_size] + * @return Output tensor of shape [batch, seq_len, hidden_size] + */ + infinicore::Tensor forward(const infinicore::Tensor &hidden_states) const; + + // Module information + size_t hidden_size() const { return hidden_size_; } + size_t intermediate_size() const { return intermediate_size_; } + +protected: + INFINICORE_NN_MODULE(layers::GateUpParallelLinear, gate_up_proj); + INFINICORE_NN_MODULE(infinicore::nn::RowParallelLinear, down_proj); + + engine::distributed::RankInfo rank_info_; + size_t hidden_size_; + size_t intermediate_size_; + bool use_bias_; +}; + +} // namespace infinilm::models::llama diff --git a/csrc/models/llama/llama_model.cpp b/csrc/models/llama/llama_model.cpp new file mode 100644 index 00000000..c0dc020d --- /dev/null +++ b/csrc/models/llama/llama_model.cpp @@ -0,0 +1,117 @@ +#include "llama_model.hpp" +#include "infinicore/nn/embedding.hpp" +#include "infinicore/nn/rmsnorm.hpp" +#include "infinicore/nn/rope.hpp" +#include "infinicore/ops.hpp" +#include + +namespace infinilm::models::llama { + +LlamaModel::LlamaModel(const LlamaConfig &config, + const infinicore::Device &device, + infinicore::DataType dtype, + engine::distributed::RankInfo rank_info) + : config_(config) { + // Initialize token embeddings + INFINICORE_NN_MODULE_INIT(embed_tokens, config.vocab_size, config.hidden_size, + std::nullopt, dtype, device); + + // Initialize decoder layers with layer indices + // TODO: Update INFINICORE_NN_MODULE_VEC_INIT macro to support per-layer constructor arguments + // (e.g., via a factory function or lambda that receives the layer index) + // Currently, we can't use the macro because each layer needs a different layer_idx + layers_.reserve(config.num_hidden_layers); + for (size_t i = 0; i < config.num_hidden_layers; ++i) { + layers_.push_back(this->register_module( + "layers." 
+ std::to_string(i), config, device, i, dtype, rank_info)); + } + + // Initialize final layer normalization + INFINICORE_NN_MODULE_INIT(norm, config.hidden_size, config.rms_norm_eps, + dtype, device); + + // Initialize Rotary Position Embeddings (shared across all layers) + // Use GPT-J-style inverse frequencies (default) and GPT_NEOX rotation pairing + INFINICORE_NN_MODULE_INIT(rotary_emb, config.head_dim, config.max_position_embeddings, + config.rope_theta, infinicore::nn::RoPE::Algo::GPT_NEOX, + dtype, device); + + for (auto &layer : layers_) { + if (layer) { + layer->set_rotary_emb(rotary_emb_); + } + } +} + +infinicore::Tensor LlamaModel::forward(const infinicore::Tensor &input_ids, + const infinicore::Tensor &position_ids, + void *kv_cache) const { + // Use persistent internal cache if no external cache is provided + // This matches Python backend behavior: if use_cache and past_key_values is None, create DynamicCache + // The cache persists across forward calls to enable incremental decoding + void *cache_to_use = kv_cache; + + if (cache_to_use == nullptr) { + // Create or reuse persistent internal cache at model level + // This ensures the cache persists across multiple forward calls (prefill -> decode -> decode...) + if (external_cache_ != nullptr) { + cache_to_use = external_cache_; + } else { + // Fall back to internal cache + if (!internal_cache_) { + internal_cache_ = std::make_unique( + config_.num_hidden_layers, + config_.max_position_embeddings); + } + cache_to_use = internal_cache_.get(); + } + } + + // 1. Embed tokens: input_ids -> [batch, seq_len, hidden_size] + auto hidden_states = embed_tokens_->forward(input_ids); + + // 2. Process through all decoder layers + size_t num_layers = layers_.size(); + for (size_t i = 0; i < num_layers; ++i) { + // Pass model-level cache (layer index is now a property of the layer) + hidden_states = layers_.at(i)->forward(hidden_states, position_ids, cache_to_use); + + // DEBUG: Disabled previous final layer logging + // Logging moved to decoder layer for post-attention normalization + } + + // 3. 
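(Step 3 of LlamaModel::forward continues right after this editorial sketch.) The step below narrows the hidden states to the final token and applies RMSNorm to that single vector. As a reference for what `norm_` computes, here is a plain-C++ RMSNorm over one hidden-state vector; the values, the all-ones weight, and `eps` are illustrative (`eps` plays the role of `config.rms_norm_eps`).

```cpp
// Editorial sketch: reference RMSNorm over a single hidden-state vector.
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
    const double eps = 1e-6; // illustrative, corresponds to rms_norm_eps
    std::vector<float> last_token = {0.5f, -1.0f, 2.0f, 0.25f};
    std::vector<float> weight(last_token.size(), 1.0f); // learned scale, all ones here

    // rms = sqrt(mean(x^2) + eps); y_i = w_i * x_i / rms
    double sum_sq = 0.0;
    for (float v : last_token) sum_sq += static_cast<double>(v) * v;
    const double rms = std::sqrt(sum_sq / static_cast<double>(last_token.size()) + eps);

    for (std::size_t i = 0; i < last_token.size(); ++i)
        std::printf("y[%zu] = %.6f\n", i, static_cast<double>(weight[i] * last_token[i]) / rms);
    return 0;
}
```

// 3.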
Apply final layer normalization to last token only (aligns with transformers) + + // Narrow to last token: [batch, seq_len, hidden_size] -> [batch, 1, hidden_size] + auto shape = hidden_states->shape(); + size_t seq_len = shape[1]; + auto last_token = hidden_states->narrow({{1, seq_len - 1, 1}}); + + // DEBUG: Disabled previous final layer normalization logging + // Normalize only the last token (matches Python backend) + auto normalized_last_token = norm_->forward(last_token); + + return normalized_last_token; +} + +void LlamaModel::reset_cache(size_t pos) const { + if (internal_cache_) { + internal_cache_->reset(pos); + } + if (external_cache_) { + external_cache_->reset(pos); + } +} + +void LlamaModel::reset_cache(const cache::CacheConfig &new_config, size_t pos) const { + if (internal_cache_) { + internal_cache_->update_config(new_config); + internal_cache_->reset(pos); + } + if (external_cache_) { + external_cache_->update_config(new_config); + external_cache_->reset(pos); + } +} + +} // namespace infinilm::models::llama diff --git a/csrc/models/llama/llama_model.hpp b/csrc/models/llama/llama_model.hpp new file mode 100644 index 00000000..f02a9f7f --- /dev/null +++ b/csrc/models/llama/llama_model.hpp @@ -0,0 +1,108 @@ +#pragma once + +#include "../../cache/kv_cache.hpp" +#include "llama_config.hpp" +#include "llama_decoder_layer.hpp" + +#include "infinicore/nn/embedding.hpp" +#include "infinicore/nn/module.hpp" +#include "infinicore/nn/rmsnorm.hpp" +#include "infinicore/nn/rope.hpp" +#include "infinicore/tensor.hpp" +#include "llama_config.hpp" +#include "llama_decoder_layer.hpp" +#include +#include + +#include "../../engine/distributed/distributed.hpp" + +namespace infinilm::models::llama { + +/** + * @brief Main Llama model architecture (without language modeling head) + * + * This is the core transformer model consisting of: + * - Token embeddings (embed_tokens) + * - Multiple decoder layers (layers) + * - Final layer normalization (norm) + * - Rotary Position Embeddings (rotary_emb) + * + * This matches the structure of HuggingFace's LlamaModel. 
+ */ +class LlamaModel : public infinicore::nn::Module { +public: + /** + * @brief Construct LlamaModel module + * + * @param config Model configuration + * @param device Device to create tensors on + * @param dtype Optional data type for model parameters (defaults to F32) + */ + LlamaModel(const LlamaConfig &config, + const infinicore::Device &device, + infinicore::DataType dtype = infinicore::DataType::F32, + engine::distributed::RankInfo rank_info = engine::distributed::RankInfo()); + + /** + * @brief Forward pass: process input through the model + * + * @param input_ids Token IDs tensor of shape [batch, seq_len] + * @param position_ids Position IDs tensor of shape [batch, seq_len] or [seq_len] + * @param kv_cache Optional model-level KV cache for incremental decoding + * @return Output tensor of shape [batch, seq_len, hidden_size] + */ + infinicore::Tensor forward(const infinicore::Tensor &input_ids, + const infinicore::Tensor &position_ids, + void *kv_cache = nullptr) const; + + // Module information + const LlamaConfig &config() const { return config_; } + size_t num_layers() const { return config_.num_hidden_layers; } + + /** + * @brief Reset the internal cache to a specific position + * This should be called when starting a new generation sequence to prevent state + * from persisting between different questions/prompts + * @param pos Position to reset to (defaults to 0) + */ + void reset_cache(size_t pos = 0) const; + + /** + * @brief Reset the internal cache with a new configuration and position + * This should be called when changing cache parameters (e.g., initial capacity) + * @param new_config New cache configuration + * @param pos Position to reset to + */ + void reset_cache(const cache::CacheConfig &new_config, size_t pos = 0) const; + + /** + * @brief Set external cache for the model + * @param cache Pointer to external cache (managed by CacheManager) + */ + void set_external_cache(std::shared_ptr cache) { + external_cache_ = cache.get(); + } + +protected: + // Token embeddings + INFINICORE_NN_MODULE(infinicore::nn::Embedding, embed_tokens); + + // Decoder layers + INFINICORE_NN_MODULE_VEC(LlamaDecoderLayer, layers); + + // Final normalization + INFINICORE_NN_MODULE(infinicore::nn::RMSNorm, norm); + + // Rotary Position Embeddings (shared across all layers) + INFINICORE_NN_MODULE(infinicore::nn::RoPE, rotary_emb); + +private: + LlamaConfig config_; + // Persistent cache for when no external cache is provided + // Mutable because it's not part of the model's learned parameters, + // but needs to persist across forward calls for incremental decoding + mutable std::unique_ptr internal_cache_; + cache::DynamicCache *external_cache_ = nullptr; +}; + +} // namespace infinilm::models::llama diff --git a/csrc/models/model_factory.cpp b/csrc/models/model_factory.cpp new file mode 100644 index 00000000..f65315ac --- /dev/null +++ b/csrc/models/model_factory.cpp @@ -0,0 +1,24 @@ +#include "model_factory.hpp" +#include "llama/llama.hpp" + +namespace infinilm { +std::shared_ptr InfinilmModelFactory::createModel( + const std::any &config, + engine::distributed::RankInfo rank_info, + std::shared_ptr cache_ptr) { + + if (config.type() == typeid(models::llama::LlamaConfig)) { + const auto &llama_config = std::any_cast(config); + auto model = std::make_shared( + llama_config, rank_info.device, infinicore::DataType::BF16, rank_info); + + if (cache_ptr != nullptr) { + model->model().set_external_cache(cache_ptr); + } + + return model; + } else { + throw 
std::invalid_argument("InfinilmModelFactory::createModel: Unsupported model config type"); + } +} +} // namespace infinilm diff --git a/csrc/models/model_factory.hpp b/csrc/models/model_factory.hpp new file mode 100644 index 00000000..029c33b5 --- /dev/null +++ b/csrc/models/model_factory.hpp @@ -0,0 +1,12 @@ +#pragma once + +#include "infinilm_model.hpp" + +#include "../engine/distributed/distributed.hpp" + +namespace infinilm { +class InfinilmModelFactory { +public: + static std::shared_ptr createModel(const std::any &config, engine::distributed::RankInfo rank_info = engine::distributed::RankInfo(), std::shared_ptr cache_ptr = nullptr); +}; +} // namespace infinilm diff --git a/csrc/pybind11/bindings.cc b/csrc/pybind11/bindings.cc new file mode 100644 index 00000000..b77a521c --- /dev/null +++ b/csrc/pybind11/bindings.cc @@ -0,0 +1,16 @@ +#include + +#include "models/llama.hpp" +#include "engine.hpp" + +namespace py = pybind11; + +PYBIND11_MODULE(_infinilm, m) { + m.doc() = "InfiniLM Llama model Python bindings"; + + infinilm::cache::bind_cache_config(m); + + infinilm::models::llama::bind_llama(m); + infinilm::engine::distributed::bind_dist_config(m); + infinilm::engine::bind_infer_engine(m); +} diff --git a/csrc/pybind11/engine.hpp b/csrc/pybind11/engine.hpp new file mode 100644 index 00000000..28995d0b --- /dev/null +++ b/csrc/pybind11/engine.hpp @@ -0,0 +1,125 @@ +#include "../cache/cache_config.hpp" +#include "../engine/infer_engine.hpp" +#include "infinicore/tensor.hpp" +#include +#include + +namespace py = pybind11; + +namespace infinilm::cache { + +inline void bind_cache_config(py::module &m) { + // First bind the CacheType enum + py::enum_(m, "CacheType") + .value("DYNAMIC", CacheType::DYNAMIC) + .value("PAGED", CacheType::PAGED) + .export_values(); + + // Then bind the CacheResetMode enum + py::enum_(m, "CacheResetMode") + .value("PRESERVE", CacheResetMode::PRESERVE) + .value("RECREATE", CacheResetMode::RECREATE) + .export_values(); + + // Finally bind the CacheConfig struct + py::class_(m, "CacheConfig") + .def(py::init<>(), "Default constructor") + .def(py::init(), + py::arg("type") = CacheType::DYNAMIC, + py::arg("num_layers") = 32, + py::arg("max_kv_cache_length") = 4096, + "Constructor with parameters") + .def_readwrite("type", &CacheConfig::type, "Cache type") + .def_readwrite("num_layers", &CacheConfig::num_layers, "Number of layers") + .def_readwrite("max_kv_cache_length", &CacheConfig::max_kv_cache_length, + "Maximum KV cache length") + .def_readwrite("initial_capacity", &CacheConfig::initial_capacity, + "Initial cache capacity in tokens") + .def_readwrite("initial_batch_size", &CacheConfig::initial_batch_size, + "Initial batch size for cache allocation") + .def_readwrite("growth_factor", &CacheConfig::growth_factor, + "Cache growth factor when resizing (e.g., 2.0 for doubling)") + .def_readwrite("allow_expand", &CacheConfig::allow_expand, + "Whether to allow cache expansion") + .def_readwrite("reset_mode", &CacheConfig::reset_mode, + "Cache reset mode") + .def("__eq__", &CacheConfig::operator==, py::is_operator(), + "Check if two CacheConfig objects are equal") + .def("__ne__", &CacheConfig::operator!=, py::is_operator(), + "Check if two CacheConfig objects are not equal") + .def("__repr__", [](const CacheConfig &cfg) { + return fmt::format("CacheConfig(type={}, num_layers={}, max_kv_cache_length={}, " + "initial_capacity={}, initial_batch_size={}, growth_factor={}, " + "allow_expand={}, reset_mode={})", + static_cast(cfg.type), cfg.num_layers, + 
cfg.max_kv_cache_length, cfg.initial_capacity, + cfg.initial_batch_size, cfg.growth_factor, + cfg.allow_expand, static_cast(cfg.reset_mode)); + }); +} + +} // namespace infinilm::cache + +namespace infinilm::engine::distributed { + +inline void bind_dist_config(py::module &m) { + py::class_(m, "DistConfig") + .def(py::init<>(), "Default constructor, empty device list") + .def(py::init(), py::arg("tp_size"), + "Constructor with tensor parallel size, auto-assigns device IDs 0..tp_size-1") + .def(py::init &>(), py::arg("tp_device_ids"), + "Constructor with explicit device IDs") + .def_readwrite("tp_device_ids", &DistConfig::tp_device_ids, + "List of device IDs used in tensor parallelism") + .def("__repr__", [](const DistConfig &cfg) { + return std::string(cfg); + }) + .def("__str__", [](const DistConfig &cfg) { + return std::string(cfg); + }); +} + +} // namespace infinilm::engine::distributed + +namespace infinilm::engine { + +inline void bind_infer_engine(py::module &m) { + py::class_>(m, "InferEngine") + .def(py::init([](const infinilm::models::llama::LlamaConfig &cfg, + const infinilm::engine::distributed::DistConfig &dist, + infinicore::Device::Type dev, + const infinilm::cache::CacheConfig &cache_config) { + return new InferEngine(std::any(cfg), dist, dev, cache_config); + }), + py::arg("config"), + py::arg("distributed_config") = distributed::DistConfig(), + py::arg("device_type") = infinicore::context::getDevice().getType(), + py::arg("cache_config") = infinilm::cache::CacheConfig()) + .def("load_param", &InferEngine::load_param, + py::arg("name"), py::arg("param"), + "Load a parameter tensor into all workers (each worker picks its shard)") + .def("state_dict", [](InferEngine &self) { + py::list state_dict_tp_all; + for (const auto &state_dict_tp : self.state_dict()) { + py::dict result; + for (const auto &[name, param] : state_dict_tp) { + result[py::cast(name)] = infinicore::Tensor(param); + } + state_dict_tp_all.append(result); + } + return state_dict_tp_all; + }) + .def( + "generate", [](InferEngine &self, py::object input_ids, py::object position_ids) -> infinicore::Tensor { + return self.generate(input_ids.cast(), position_ids.cast()); + }, + "Run inference on all ranks with arbitrary arguments") + .def("reset_cache", py::overload_cast(&InferEngine::reset_cache), py::arg("pos") = 0, "Reset the internal cache in all workers to a specific position") + .def("reset_cache", py::overload_cast(&InferEngine::reset_cache), py::arg("cache_config"), py::arg("pos") = 0, "Reset cache with new KV configuration") + .def("get_cache_config", &InferEngine::get_cache_config, "Get current KV configuration") + .def("__repr__", [](const InferEngine &self) { + return ""; + }); +} + +} // namespace infinilm::engine diff --git a/csrc/pybind11/models/llama.hpp b/csrc/pybind11/models/llama.hpp new file mode 100644 index 00000000..3ce39ebe --- /dev/null +++ b/csrc/pybind11/models/llama.hpp @@ -0,0 +1,201 @@ +#pragma once + +#include "../../cache/kv_cache.hpp" +#include "../../models/debug_utils/hooks.hpp" +#include "../../models/llama/llama.hpp" +#include "../../models/llama/llama_attention.hpp" +#include "infinicore/device.hpp" +#include "infinicore/nn/module.hpp" +#include "infinicore/tensor.hpp" +#include +#include +#include + +namespace py = pybind11; +using infinicore::Device; +using infinilm::models::debug_utils::HookRegistry; + +namespace infinilm::models::llama { + +inline void bind_llama(py::module &m) { + // TODO: HookRegistry should be moved out from Llama-specific bindings to InfiniCore as 
common utils in future work + // Bind HookRegistry + py::class_>(m, "HookRegistry") + .def(py::init<>()) + .def( + "register_hook", [](HookRegistry &self, const std::string &name, py::object callback) { + // Convert Python callable to C++ function + self.register_hook(name, [callback](const std::string &hook_name, const infinicore::Tensor &tensor, int layer_idx) { + try { + // Call Python callback with hook name, tensor, and layer index + callback(hook_name, tensor, layer_idx); + } catch (const py::error_already_set &e) { + // Re-raise Python exception + throw; + } + }); + }, + py::arg("name"), py::arg("callback")) + .def("clear", &HookRegistry::clear) + .def("has_hooks", &HookRegistry::has_hooks); + + // Bind LlamaConfig + py::class_ config(m, "LlamaConfig"); + config + .def(py::init<>()) + .def_readwrite("vocab_size", &LlamaConfig::vocab_size) + .def_readwrite("hidden_size", &LlamaConfig::hidden_size) + .def_readwrite("intermediate_size", &LlamaConfig::intermediate_size) + .def_readwrite("num_hidden_layers", &LlamaConfig::num_hidden_layers) + .def_readwrite("num_attention_heads", &LlamaConfig::num_attention_heads) + .def_readwrite("num_key_value_heads", &LlamaConfig::num_key_value_heads) + .def_readwrite("head_dim", &LlamaConfig::head_dim) + .def_readwrite("max_position_embeddings", &LlamaConfig::max_position_embeddings) + .def_readwrite("rms_norm_eps", &LlamaConfig::rms_norm_eps) + .def_readwrite("hidden_act", &LlamaConfig::hidden_act) + .def_readwrite("model_type", &LlamaConfig::model_type) + .def_readwrite("rope_theta", &LlamaConfig::rope_theta) + .def_readwrite("attention_bias", &LlamaConfig::attention_bias) + .def_readwrite("attention_output_bias", &LlamaConfig::attention_output_bias) + .def_readwrite("mlp_bias", &LlamaConfig::mlp_bias) + .def_readwrite("tie_word_embeddings", &LlamaConfig::tie_word_embeddings) + .def_readwrite("use_cache", &LlamaConfig::use_cache) + .def_readwrite("attention_dropout", &LlamaConfig::attention_dropout) + .def_readwrite("initializer_range", &LlamaConfig::initializer_range) + .def_readwrite("pretraining_tp", &LlamaConfig::pretraining_tp) + .def_readwrite("name_or_path", &LlamaConfig::name_or_path) + .def_readwrite("pad_token_id", &LlamaConfig::pad_token_id) + .def_property("bos_token_id", [](const LlamaConfig &self) { + // Always return as list to match Python config format + return py::cast(self.bos_token_id); }, [](LlamaConfig &self, py::object value) { + // Accept both single int and list + if (py::isinstance(value)) { + self.bos_token_id = {value.cast()}; + } else if (py::isinstance(value) || py::isinstance(value)) { + self.bos_token_id = value.cast>(); + } else { + throw py::type_error("bos_token_id must be int or list of ints"); + } }) + .def_property("eos_token_id", [](const LlamaConfig &self) { + // Always return as list to match Python config format + return py::cast(self.eos_token_id); }, [](LlamaConfig &self, py::object value) { + // Accept both single int and list + if (py::isinstance(value)) { + self.eos_token_id = {value.cast()}; + } else if (py::isinstance(value) || py::isinstance(value)) { + self.eos_token_id = value.cast>(); + } else { + throw py::type_error("eos_token_id must be int or list of ints"); + } }) + .def("validate", &LlamaConfig::validate) + .def("kv_dim", &LlamaConfig::kv_dim) + // Add __dir__ to make attributes discoverable via dir() in Python + .def("__dir__", [](const LlamaConfig &self) { + py::list dir_list; + dir_list.append("vocab_size"); + dir_list.append("hidden_size"); + dir_list.append("intermediate_size"); + 
dir_list.append("num_hidden_layers"); + dir_list.append("num_attention_heads"); + dir_list.append("num_key_value_heads"); + dir_list.append("head_dim"); + dir_list.append("max_position_embeddings"); + dir_list.append("rms_norm_eps"); + dir_list.append("hidden_act"); + dir_list.append("model_type"); + dir_list.append("rope_theta"); + dir_list.append("attention_bias"); + dir_list.append("attention_output_bias"); + dir_list.append("mlp_bias"); + dir_list.append("tie_word_embeddings"); + dir_list.append("use_cache"); + dir_list.append("attention_dropout"); + dir_list.append("initializer_range"); + dir_list.append("pretraining_tp"); + dir_list.append("name_or_path"); + dir_list.append("pad_token_id"); + dir_list.append("bos_token_id"); + dir_list.append("eos_token_id"); + dir_list.append("validate"); + dir_list.append("kv_dim"); + return dir_list; }); + + // Note: Device is already bound in InfiniCore bindings, so we don't need to bind it here + + // Bind LlamaForCausalLM + py::class_>(m, "LlamaForCausalLM") + .def("state_dict", [](const LlamaForCausalLM &model) { + // Return a dictionary containing references to the whole state of the module. + auto state_dict = model.state_dict(); + py::dict result; + for (const auto &[name, param] : state_dict) { + result[py::cast(name)] = infinicore::Tensor(param); + } + return result; + }) + .def("get_parameter", [](const LlamaForCausalLM &model, const std::string &name) { + // Get actual tensor parameter by name + auto state_dict = model.state_dict(); + auto it = state_dict.find(name); + if (it != state_dict.end()) { + // Parameter inherits from Tensor, cast to Tensor for pybind11 + const infinicore::Tensor &tensor = it->second; + return tensor; + } + throw std::runtime_error("Parameter '" + name + "' not found in model"); }, py::arg("name")) + .def("load_state_dict", [](LlamaForCausalLM &model, py::dict state_dict) { + // Convert Python dict to C++ state_dict + std::unordered_map cpp_state_dict; + for (auto item : state_dict) { + std::string key = item.first.cast(); + py::object value = item.second.cast(); + // Extract InfiniCore tensor from Python object + infinicore::Tensor tensor; + if (py::hasattr(value, "_underlying")) { + tensor = value.attr("_underlying").cast(); + } else { + tensor = value.cast(); + } + cpp_state_dict.emplace(key, tensor); + } + model.load_state_dict(cpp_state_dict); }, py::arg("state_dict")) + .def("config", &LlamaForCausalLM::config, py::return_value_policy::reference_internal) + .def("reset_cache", [](const LlamaForCausalLM &model, size_t pos = 0) { + // Reset the internal cache to prevent state from persisting between generations + model.model().reset_cache(pos); }, py::arg("pos") = 0, "Reset the internal cache to a specific position (clears state between generations)") + .def("forward", [](const LlamaForCausalLM &model, py::object input_ids, py::object position_ids, py::object kv_cache = py::none()) { + // Helper to extract C++ tensor from Python InfiniCore tensor + auto get_tensor = [](py::object obj) -> infinicore::Tensor { + // If it's already a Python InfiniCore tensor wrapper, extract underlying + if (py::hasattr(obj, "_underlying")) { + return obj.attr("_underlying").cast(); + } + // Try direct cast (in case it's already a C++ tensor) + return obj.cast(); + }; + + // Extract InfiniCore tensors from Python objects + auto infini_input_ids = get_tensor(input_ids); + auto infini_position_ids = get_tensor(position_ids); + + // Handle kv_cache if provided (model-level DynamicCache) + void *kv_cache_ptr = nullptr; + if 
(!kv_cache.is_none()) { + // Try to extract DynamicCache from Python object + if (py::hasattr(kv_cache, "_underlying")) { + kv_cache_ptr = kv_cache.attr("_underlying").cast(); + } else { + // Try direct cast + try { + kv_cache_ptr = kv_cache.cast(); + } catch (...) { + // If conversion fails, pass nullptr (cache will be ignored) + kv_cache_ptr = nullptr; + } + } + } + + return model.forward(infini_input_ids, infini_position_ids, kv_cache_ptr); }, py::arg("input_ids"), py::arg("position_ids"), py::arg("kv_caches") = py::none()); +} + +} // namespace infinilm::models::llama diff --git a/csrc/utils.hpp b/csrc/utils.hpp new file mode 100644 index 00000000..e8540940 --- /dev/null +++ b/csrc/utils.hpp @@ -0,0 +1,124 @@ +#pragma once +#include + +#include +#include +#include +#include + +inline void assertTrue(int expr, const char *msg, const char *file, int line) { + if (!expr) { + fprintf(stderr, "\033[31mAssertion failed:\033[0m %s at file %s, line %d\n", msg, file, line); + exit(EXIT_FAILURE); + } +} + +#define ASSERT(expr) assertTrue((expr), #expr " is false", __FILE__, __LINE__) +#define ASSERT_EQ(a, b) assertTrue((a) == (b), #a " != " #b, __FILE__, __LINE__) +#define ASSERT_VALID_PTR(a) assertTrue((a) != nullptr, #a " is nullptr", __FILE__, __LINE__) + +#define PANIC(EXPR) \ + printf("Error at %s:%d - %s\n", __FILE__, __LINE__, #EXPR); \ + exit(EXIT_FAILURE) + +#define RUN_INFINI(API) \ + do { \ + auto api_result_ = (API); \ + if (api_result_ != INFINI_STATUS_SUCCESS) { \ + std::cerr << "Error Code " << api_result_ << " in `" << #API << "`" \ + << " from " << __func__ \ + << " at " << __FILE__ << ":" << __LINE__ << std::endl; \ + exit(EXIT_FAILURE); \ + } \ + } while (0) + +inline float f16_to_f32(uint16_t h) { + uint32_t sign = (h & 0x8000) << 16; // Extract the sign bit + int32_t exponent = (h >> 10) & 0x1F; // Extract the exponent + uint32_t mantissa = h & 0x3FF; // Extract the mantissa (fraction part) + + if (exponent == 31) { // Special case for Inf and NaN + if (mantissa != 0) { + // NaN: Set float32 NaN + uint32_t f32 = sign | 0x7F800000 | (mantissa << 13); + return *(float *)&f32; + } else { + // Infinity + uint32_t f32 = sign | 0x7F800000; + return *(float *)&f32; + } + } else if (exponent == 0) { // Subnormal float16 or zero + if (mantissa == 0) { + // Zero (positive or negative) + uint32_t f32 = sign; // Just return signed zero + return *(float *)&f32; + } else { + // Subnormal: Convert to normalized float32 + exponent = -14; // Set exponent for subnormal numbers + while ((mantissa & 0x400) == 0) { // Normalize mantissa + mantissa <<= 1; + exponent--; + } + mantissa &= 0x3FF; // Clear the leading 1 bit + uint32_t f32 = sign | ((exponent + 127) << 23) | (mantissa << 13); + return *(float *)&f32; + } + } else { + // Normalized float16 + uint32_t f32 = sign | ((exponent + 127 - 15) << 23) | (mantissa << 13); + return *(float *)&f32; + } +} + +inline uint16_t f32_to_f16(float val) { + uint32_t f32; + memcpy(&f32, &val, sizeof(f32)); // Read the bits of the float32 + uint16_t sign = (f32 >> 16) & 0x8000; // Extract the sign bit + int32_t exponent = ((f32 >> 23) & 0xFF) - 127; // Extract and de-bias the exponent + uint32_t mantissa = f32 & 0x7FFFFF; // Extract the mantissa (fraction part) + + if (exponent >= 31) { // Special cases for Inf and NaN + // NaN + if (exponent == 128 && mantissa != 0) { + return static_cast(sign | 0x7E00); + } + // Infinity + return static_cast(sign | 0x7C00); + } else if (exponent >= -14) { // Normalized case + return (uint16_t)(sign | ((exponent + 15) << 
10) | (mantissa >> 13));
+    } else if (exponent >= -24) { // Subnormal float16 range
+        mantissa |= 0x800000; // Add implicit leading 1
+        mantissa >>= (-14 - exponent);
+        return (uint16_t)(sign | (mantissa >> 13));
+    } else {
+        // Too small for subnormal: return signed zero
+        return (uint16_t)sign;
+    }
+}
+
+inline float bf16_to_f32(uint16_t val) {
+    // Place the bf16 bits in the upper 16 bits of a float32 and zero the lower 16 bits.
+    uint32_t bits32 = static_cast<uint32_t>(val) << 16;
+
+    float out;
+    std::memcpy(&out, &bits32, sizeof(out));
+    return out;
+}
+
+inline uint16_t f32_to_bf16(float val) {
+    uint32_t bits32;
+    std::memcpy(&bits32, &val, sizeof(bits32));
+
+    // Add 0x7FFF before truncating, then use the parity of bit 16 (the lowest kept mantissa bit)
+    // to implement round-to-nearest-even.
+    const uint32_t rounding_bias = 0x00007FFF    // 0111 1111 1111 1111
+                                 + ((bits32 >> 16) & 1); // +1 when the lowest kept mantissa bit is odd
+
+    uint16_t bf16_bits = static_cast<uint16_t>((bits32 + rounding_bias) >> 16);
+
+    return bf16_bits;
+}
+
+// Hash combine utility (similar to boost::hash_combine)
+inline void hash_combine(size_t &seed, size_t value) {
+    seed ^= value + 0x9e3779b9 + (seed << 6) + (seed >> 2);
+}
diff --git a/helper.py b/helper.py
new file mode 100644
index 00000000..af9ac43e
--- /dev/null
+++ b/helper.py
@@ -0,0 +1,384 @@
+    auto nlayer = meta.nlayer;
+    auto nkvh = meta.nkvh / ndev;     // KV heads per device (grouped-query attention)
+    auto nh = meta.nh / ndev;         // attention heads per device
+    auto ngroup = nh / nkvh;          // groups per device (how many Q heads share one K/V head)
+    auto dctx = meta.dctx;            // maximum context length
+    auto dh = meta.dh;                // dimension of each head
+    auto d = meta.d;                  // model hidden size
+    auto dt_logits = meta.dt_logits;  // output data type
+    auto di_dense = meta.di_dense / ndev;   // dense FFN intermediate size per device
+    auto di_expert = meta.di_expert / ndev; // expert FFN intermediate size per device
+    auto dvoc = meta.dvoc;            // vocabulary size
+    auto stream = rsrc.stream;
+    bool has_qkv_bias = rsrc.b_attn_qkv.size() > 0; // whether a QKV bias is present
+    bool has_qk_norm = rsrc.w_attn_q_norm.size() > 0 && rsrc.w_attn_k_norm.size() > 0; // whether QK normalization is used
+
+    // Allocate buffers
+    std::cout << "Allocating Buffer " << std::endl;
+    auto logits_in = Tensor::buffer(dt_logits, {ntok, d}, rsrc.memory_pool);
+    auto logits_out = Tensor::buffer(dt_logits, {ntok, d}, rsrc.memory_pool);
+    auto qkv_buf = Tensor::buffer(dt_logits, {ntok, (nh + nkvh * 2) * dh}, rsrc.memory_pool);
+    auto gate_up_buf = Tensor::buffer(dt_logits, {ntok, 2 * di_expert}, rsrc.memory_pool); // ?
+    auto o_buf = Tensor::buffer(dt_logits, {ntok, nh * dh}, rsrc.memory_pool);
+    auto prob_buf = Tensor::buffer(dt_logits, {nreq, dvoc}, rsrc.memory_pool);
+    auto result_buf = Tensor::buffer(INFINI_DTYPE_I64, {nreq}, rsrc.memory_pool);
+    auto result_cpu = std::vector<int64_t>(nreq);
+    auto qkv_rope = qkv_buf->view({ntok, nh + nkvh * 2, dh});
+    auto q_buf = qkv_rope->slice(1, 0, nh);
+    auto k_buf = qkv_rope->slice(1, nh, nkvh);
+
+    // ====================================================================
+    // 1. Input preprocessing and token embedding
+    // ====================================================================
+
+    // 1.1 Position ID preparation - generate the position ID of every token in the batch
+    // requests = [
+    //     "Hello world",  # req_idx=0, length=2, start=0
+    //     "How are you",  # req_idx=1, length=3, start=2
+    //     "Fine thanks",  # req_idx=2, length=2, start=5
+    // ]
+
+    // # Inputs:
+    // req_pos  = [0, 2, 5]   # start position of each request
+    // req_lens = [2, 3, 2]   # length of each request
+    // nreq = 3               # 3 requests
+    // ntok = 7               # 7 tokens in total
+
+    // # Result:
+    // batch_pos_ids = [0, 1, 2, 3, 4, 5, 6]
+    auto batch_pos_ids = std::vector<uint32_t>(ntok);
+    // Fill in the position IDs per request (req_pos gives the start position of each request)
+    std::cout << "PreFill Id " << "req len is " << nreq << std::endl;
+    for (uint32_t req_idx = 0; req_idx < nreq; ++req_idx) {
+        uint32_t start_pos = req_pos[req_idx];
+        uint32_t req_len = req_lens[req_idx];
+        for (uint32_t i = 0; i < req_len; ++i) {
+            batch_pos_ids[start_pos + i] = start_pos + i;
+        }
+    }
+
+    // 1.2 Copy the position IDs to the GPU (when running on a GPU device)
+    std::shared_ptr<Tensor> pos_ids_buf;
+    if (rsrc.device == INFINI_DEVICE_CPU) {
+        pos_ids_buf = Tensor::weight(batch_pos_ids.data(), INFINI_DTYPE_U32, {ntok});
+    } else { // GPU
+        pos_ids_buf = Tensor::buffer(INFINI_DTYPE_U32, {ntok}, rsrc.memory_pool);
+        RUN_INFINI(infinirtMemcpyAsync(pos_ids_buf->data(), batch_pos_ids.data(), sizeof(uint32_t) * ntok,
+                                       INFINIRT_MEMCPY_H2D, stream));
+    }
+
+    // 1.3 Token embedding lookup - map each token ID to its embedding vector
+    // Input: tokens[token_count] -> output: [token_count, hidden_dim]
+    std::shared_ptr<Tensor> hidden_states = Tensor::buffer(dt_logits, {ntok, d}, rsrc.memory_pool);
+    for (uint32_t i = 0; i < ntok; i++) {
+        // Look up the embedding table: [vocab_size, d] -> embedding vector of token i
+        RUN_INFINI(infinirtMemcpyAsync(hidden_states->data(i * d),
+                                       rsrc.w_in_embd->data(tokens[i] * d),
+                                       dsize(dt_logits) * d, INFINIRT_MEMCPY_D2D, stream));
+    }
+
+    // ====================================================================
+    // 2. Positional encoding - RoPE (Rotary Position Embedding)
+    // ====================================================================
+
+    // 2.1 Gather the sine/cosine values for the current positions from the precomputed RoPE tables
+    // sin_table: [max_seq_len, head_dim/2], cos_table: [max_seq_len, head_dim/2]
+    std::shared_ptr<Tensor> sin_cos_for_pos = Tensor::buffer(dt_logits, {ntok, dh}, rsrc.memory_pool);
+    // TODO: gather the sin/cos values from sin_table and cos_table according to batch_pos_ids
+
+    // ====================================================================
+    // 3. Transformer layer loop
+    // ====================================================================
+
+    // Set up the inference context - must happen before rmsnorm is called
+
+    CacheManager cache_manager(100);
+    InferenceContext ctx(rsrc.handle, rsrc.memory_pool, &cache_manager, rsrc.stream);
+    setInferenceContext(&ctx);
+    std::cout << "[DEBUG] InferenceContext created and set successfully" << std::endl;
+
+    // Keep each layer's input around for the residual connection
+    std::shared_ptr<Tensor> layer_input = hidden_states;
+
+    for (uint32_t layer_idx = 0; layer_idx < nlayer; ++layer_idx) {
+        std::cout << "Processing layer " << layer_idx << std::endl;
+
+        // ====================================================================
+        // 3.1 Pre-attention normalization (input layernorm)
+        // ====================================================================
+        std::cout << "Input LayerNorm" << std::endl;
+        // RMSNorm: hidden_states = hidden_states * rms_norm_weight / sqrt(mean(hidden_states^2) + eps)
+        std::shared_ptr<Tensor> attn_input = Tensor::buffer(dt_logits, {ntok, d}, rsrc.memory_pool);
+        // RMSNorm implementation - see jiuge.cpp:218 for reference
+        std::cout << "Layer Idx " << layer_idx << " Do rmsnorm" << std::endl;
+
+        // Check that the weights and buffers are valid
+        if (!rsrc.w_attn_norm[layer_idx]) {
+            std::cerr << "ERROR: rsrc.w_attn_norm[" << layer_idx << "] is null!" << std::endl;
+            return;
+        }
+        if (!layer_input) {
+            std::cerr << "ERROR: layer_input is null!" << std::endl;
+            return;
+        }
+        if (!attn_input) {
+            std::cerr << "ERROR: attn_input is null!" << std::endl;
+            return;
+        }
+        if (!rsrc.w_attn_norm[layer_idx]->data()) {
+            std::cerr << "ERROR: rsrc.w_attn_norm[" << layer_idx << "] data pointer is null!" << std::endl;
+            return;
+        }
+
+        std::cout << "All pointers are valid, calling rmsnorm..." << std::endl;
+        rmsnorm(attn_input, layer_input, rsrc.w_attn_norm[layer_idx], meta.epsilon);
+
+        // ====================================================================
+        // 3.2 QKV linear projection
+        // ====================================================================
+        // Input: [ntok, d] -> Q: [ntok, nh, dh], K: [ntok, nkvh, dh], V: [ntok, nkvh, dh]
+        std::cout << "Initializing qkv buffer " << std::endl;
+        std::shared_ptr<Tensor> qkv_buf = Tensor::buffer(dt_logits, {ntok, (nh + nkvh * 2) * dh}, rsrc.memory_pool);
+
+        // Perform the linear projection: [ntok, d] x [d, (nh + 2*nkvh)*dh] = [ntok, (nh + 2*nkvh)*dh]
+        // TODO: implement the QKV matmul
+        // RUN_INFINI(infiniopGemm(..., attn_input, rsrc.w_attn_qkv[layer_idx], qkv_buf, ...));
+        std::cout << "Linear project qkv matrix" << std::endl;
+        linear(qkv_buf, attn_input, rsrc.w_attn_qkv[layer_idx], 1.0, 0.0, nullptr, nullptr);
+        // Add the QKV bias if present
+        // if (has_qkv_bias) {
+        //     // TODO: implement the bias addition
+        //     // RUN_INFINI(infiniopAdd(..., qkv_buf, rsrc.b_attn_qkv[layer_idx], qkv_buf, ...));
+        // }
+        // ====================================================================
+        // 3.3 Reshape and split QKV
+        // ====================================================================
+        std::cout << "Initializing qkv split " << std::endl;
+        auto qkv_rope = qkv_buf->view({ntok, nh + nkvh * 2, dh});
+        auto q_buf = qkv_rope->slice(1, 0, nh);           // [ntok, nh, dh]   - query part
+        auto k_buf = qkv_rope->slice(1, nh, nkvh);        // [ntok, nkvh, dh] - key part
+        auto v_buf = qkv_rope->slice(1, nh + nkvh, nkvh); // [ntok, nkvh, dh] - value part
+
+        // ====================================================================
+        // 3.4 QK normalization (optional)
+        // ====================================================================
+        std::cout << "QK Normalization" << std::endl;
+        if (has_qk_norm) {
+            // Apply RMSNorm to Q and K separately
+            // TODO: implement QK normalization
+            // RUN_INFINI(infiniopRMSNorm(..., q_buf, rsrc.w_attn_q_norm[layer_idx], q_buf, ...));
+            // RUN_INFINI(infiniopRMSNorm(..., k_buf, rsrc.w_attn_k_norm[layer_idx], k_buf, ...));
+            rmsnorm(q_buf, q_buf, rsrc.w_attn_q_norm[layer_idx], meta.epsilon);
+            rmsnorm(k_buf, k_buf, rsrc.w_attn_k_norm[layer_idx], meta.epsilon);
+        }
+
+        // ====================================================================
+        // 3.5 Apply RoPE positional encoding
+        // ====================================================================
+        // Apply the sine/cosine encoding to the Q and K tensors
+        // Inputs: q_buf: [ntok, nh, dh], k_buf: [ntok, nkvh, dh]
+        // Positional encoding: sin_cos_for_pos: [ntok, dh]
+        // TODO: implement RoPE rotary position embedding
+        std::cout << "Position Embedding" << std::endl;
+        rope(q_buf, q_buf, pos_ids_buf, rsrc.sin_table, rsrc.cos_table);
+        rope(k_buf, k_buf, pos_ids_buf, rsrc.sin_table, rsrc.cos_table);
+        // Note: if the rope kernel processes q and k together, the second call may be unnecessary
+        // ====================================================================
+        // 3.6 Attention computation
+        // ====================================================================
+        size_t token_offset = 0;
+        for (uint32_t req = 0; req < nreq; req++) {
+            auto past_len = req_pos[req];
+            auto seq_len = req_lens[req];
+            auto total = past_len + seq_len;
+            auto o = o_buf->slice({{0, token_offset, seq_len}})->view({seq_len, nkvh, ngroup, dh})->permute({1, 2, 0, 3});
+            auto q = 
qkv_rope->slice({{0, token_offset, seq_len}, {1, 0, nh}})->view({seq_len, nkvh, ngroup, dh})->permute({1, 2, 0, 3}); + auto k = qkv_rope->slice({{0, token_offset, seq_len}, {1, nh, nkvh}}); + auto v = qkv_rope->slice({{0, token_offset, seq_len}, {1, nh + nkvh, nkvh}}); + + // self attention + // concat + rearrange(kv_caches[req]->k[idev][layer_idx]->slice(0, past_len, seq_len), k); + rearrange(kv_caches[req]->v[idev][layer_idx]->slice(0, past_len, seq_len), v); + // qk + rearrange(q_rearrange->slice(2, 0, seq_len), q); + auto qk_gemm = qk_buf->slice(0, 0, nh * seq_len * total_len)->view({nkvh, ngroup * seq_len, total_len}); + auto k_gemm = kv_caches[req]->k[idev][layer]->slice(0, 0, total_len)->permute({1, 2, 0}); + linear(qk_gemm, rearrange_q_buf->slice(1, 0, ngroup * seq_len), k_gemm, 1.f / float(sqrt(dh)), 0.f, nullptr, nullptr); + // softmax + auto qk_softmax = qk_gemm->view({nh, seq_len, total_len}); + causalSoftmax(qk_softmax, qk_softmax); + auto v_gemm = kv_caches[req]->v[idev][layer]->slice(0, 0, total_len)->permute({1, 0, 2}); + linear(attn_val_buf->slice(1, 0, ngroup * seq_len), qk_gemm, v_gemm, 1.f, 0.f, nullptr, nullptr); + // rearrange attn val + rearrange(o, attn_val_gemm->slice(2, 0, seq_len)); + + token_offset += seq_len; + } + // 3.6.1 KV缓存处理 + // 将当前的K和V写入到KV缓存中,用于自回归生成 + // TODO: 实现KV缓存写入 + // update_kv_cache(kv_caches, layer_idx, k_buf, v_buf, ...); + + // 3.6.2 读取历史KV (如果有缓存) + // TODO: 从KV缓存中读取历史K和V + // [seq_len_cached + ntok, nkvh, dh] + + // 3.6.3 GQA处理 - 重复KV以匹配Q的维度 + // 将[n_tok, nkvh, dh]的K和V重复ngroup次变为[n_tok, nh, dh] + std::shared_ptr k_repeated = Tensor::buffer(dt_logits, {ntok, nh, dh}, rsrc.memory_pool); + std::shared_ptr v_repeated = Tensor::buffer(dt_logits, {ntok, nh, dh}, rsrc.memory_pool); + // TODO: 实现KV重复操作 + // repeat_kv(k_buf, k_repeated, ngroup); + // repeat_kv(v_buf, v_repeated, ngroup); + + // 3.6.4 注意力权重计算 + // Q @ K^T / sqrt(dh) -> [ntok, nh, seq_len] + std::shared_ptr attn_weights = Tensor::buffer(dt_logits, {ntok, nh, ntok}, rsrc.memory_pool); + // TODO: 实现注意力权重计算 + // RUN_INFINI(infiniopBatchedGemm(..., q_buf, k_repeated, attn_weights, ...)); + // attn_weights = attn_weights / sqrt(dh); + + // 3.6.5 注意力归一化 + // TODO: 实现Softmax归一化 + // RUN_INFINI(infiniopSoftmax(..., attn_weights, attn_weights, ...)); + + // 3.6.6 注意力输出计算 + // attn_weights @ V -> [ntok, nh, dh] + std::shared_ptr attn_output = Tensor::buffer(dt_logits, {ntok, nh, dh}, rsrc.memory_pool); + // TODO: 实现注意力输出计算 + // RUN_INFINI(infiniopBatchedGemm(..., attn_weights, v_repeated, attn_output, ...)); + + // 3.6.7 输出投影 + // 将多头输出合并: [ntok, nh, dh] -> [ntok, d] + // 然后通过输出投影层: [ntok, d] × [d, d] = [ntok, d] + std::shared_ptr attn_proj = Tensor::buffer(dt_logits, {ntok, d}, rsrc.memory_pool); + // TODO: 实现注意力输出投影 + // RUN_INFINI(infiniopGemm(..., attn_output, rsrc.w_attn_out[layer_idx], attn_proj, ...)); + + // ==================================================================== + // 3.7 残差连接和后注意力归一化 + // ==================================================================== + // 残差连接: layer_input + attn_proj + std::shared_ptr residual1 = Tensor::buffer(dt_logits, {ntok, d}, rsrc.memory_pool); + // TODO: 实现残差加法 + // RUN_INFINI(infiniopAdd(..., layer_input, attn_proj, residual1, ...)); + + // 后注意力归一化 + std::shared_ptr mlp_input = Tensor::buffer(dt_logits, {ntok, d}, rsrc.memory_pool); + // TODO: 实现RMSNorm + // RUN_INFINI(infiniopRMSNorm(..., residual1, rsrc.w_ffn_norm[layer_idx], mlp_input, ...)); + + // ==================================================================== + // 3.8 
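// ----------------------------------------------------------------------------
// [Editorial sketch - not part of the original diff] The TODOs in the MoE
// section (3.8) below (router GEMM, top-k + softmax, per-expert FFN, weighted
// sum) follow the standard MoE routing pattern. This self-contained host-side
// sketch shows that pattern with plain std::vector; `expert_ffn` stands in for
// the per-expert gate/up/SiLU/down computation, and nothing here is a real
// InfiniCore or infiniop API.
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <numeric>
#include <utility>
#include <vector>

// Softmax over one token's router logits, then pick the top_k experts.
inline std::vector<std::pair<size_t, float>> route_token(const std::vector<float> &router_logits, size_t top_k) {
    float max_logit = *std::max_element(router_logits.begin(), router_logits.end());
    std::vector<float> probs(router_logits.size());
    float sum = 0.f;
    for (size_t e = 0; e < router_logits.size(); ++e) {
        probs[e] = std::exp(router_logits[e] - max_logit);
        sum += probs[e];
    }
    for (auto &p : probs) { p /= sum; }

    top_k = std::min(top_k, probs.size());
    std::vector<size_t> order(probs.size());
    std::iota(order.begin(), order.end(), size_t{0});
    std::partial_sort(order.begin(), order.begin() + top_k, order.end(),
                      [&](size_t a, size_t b) { return probs[a] > probs[b]; });

    std::vector<std::pair<size_t, float>> selected;
    for (size_t i = 0; i < top_k; ++i) { selected.emplace_back(order[i], probs[order[i]]); }
    return selected;
}

// moe_output[token] = sum over the selected experts of gate_weight * expert_ffn(expert_idx, hidden).
// Like the PyTorch reference in this diff (norm_topk_prob = False), the top-k weights are not renormalized.
template <typename ExpertFFN>
inline std::vector<float> moe_combine(const std::vector<float> &token_hidden,
                                      const std::vector<float> &router_logits,
                                      size_t top_k, ExpertFFN expert_ffn) {
    std::vector<float> out(token_hidden.size(), 0.f);
    for (const auto &[expert_idx, weight] : route_token(router_logits, top_k)) {
        std::vector<float> expert_out = expert_ffn(expert_idx, token_hidden);
        for (size_t i = 0; i < out.size(); ++i) { out[i] += weight * expert_out[i]; }
    }
    return out;
}
// ----------------------------------------------------------------------------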
专家混合(MoE)处理 + // ==================================================================== + + // 3.8.1 路由器计算 + // 输入: [ntok, d] -> 输出: [ntok, num_experts] (每个token对每个专家的权重) + std::shared_ptr router_logits = Tensor::buffer(dt_logits, {ntok, meta.num_experts}, rsrc.memory_pool); + // TODO: 实现路由器线性层 + // RUN_INFINI(infiniopGemm(..., mlp_input, w_router, router_logits, ...)); + + // 3.8.2 TopK专家选择 + // 对每个token选择top_k个专家 + // TODO: 实现TopK选择和Softmax归一化 + // top_k_experts, top_k_weights = topk_softmax(router_logits, k=2); + + // 3.8.3 专家计算 + // 初始化MoE输出为0 + std::shared_ptr moe_output = Tensor::buffer(dt_logits, {ntok, d}, rsrc.memory_pool); + // RUN_INFINI(infinirtMemsetAsync(moe_output->data(), 0, ...)); + + // 对每个专家进行计算 + for (uint32_t expert_idx = 0; expert_idx < meta.num_experts; ++expert_idx) { + // 找到被路由到当前专家的token + // TODO: 找到被路由到当前专家的token索引 + + // 3.8.3.1 专家FFN计算 + // Gate + Up投影: [token_count, d] × [d, 2*di_expert] = [token_count, 2*di_expert] + // std::shared_ptr gate_up_output = Tensor::buffer(dt_logits, {1, 2 * di_expert}, rsrc.memory_pool); + // TODO: 实现Gate+Up投影 + // RUN_INFINI(infiniopGemm(..., expert_tokens, w_gate_up[layer_idx][expert_idx], gate_up_output, ...)); + + // 激活函数: 通常使用SiLU (Swish) + // TODO: 实现SiLU激活函数 + // RUN_INFINI(infiniopSilu(..., gate_up_output, gate_up_output, ...)); + + // Down投影: [token_count, 2*di_expert] × [2*di_expert, d] = [token_count, d] + // std::shared_ptr expert_output = Tensor::buffer(dt_logits, {expert_token_count, d}, rsrc.memory_pool); + // TODO: 实现Down投影 + // RUN_INFINI(infiniopGemm(..., gate_up_output, w_down[layer_idx][expert_idx], expert_output, ...)); + + // 3.8.3.2 加权求和 + // 将专家输出按路由权重加权加到总输出中 + // TODO: 实现加权求和 + // moe_output += expert_output * expert_weights + } + + // ==================================================================== + // 3.9 第二个残差连接 + // ==================================================================== + // residual1 + moe_output -> next_layer_input + // std::shared_ptr next_layer_input = Tensor::buffer(dt_logits, {ntok, d}, rsrc.memory_pool); + // TODO: 实现残差加法 + // RUN_INFINI(infiniopAdd(..., residual1, moe_output, next_layer_input, ...)); + + // 更新下一层的输入 + // layer_input = next_layer_input; + + // 同步当前层的计算 + // RUN_INFINI(infinirtStreamSynchronize(stream)); + } + + // ==================================================================== + // 4. 最终输出层 + // ==================================================================== + + // 4.1 最终归一化 + // std::shared_ptr final_norm = Tensor::buffer(dt_logits, {ntok, d}, rsrc.memory_pool); + // TODO: 实现最终RMSNorm + // RUN_INFINI(infiniopRMSNorm(..., layer_input, rsrc.w_out_norm, final_norm, ...)); + + // 4.2 输出投影到词汇表 + // 输入: [ntok, d] × 输出权重: [d, dvoc] = [ntok, dvoc] + // std::shared_ptr logits = Tensor::buffer(dt_logits, {ntok, dvoc}, rsrc.memory_pool); + // TODO: 实现输出投影 + // RUN_INFINI(infiniopGemm(..., final_norm, rsrc.w_out_embd, logits, ...)); + + // ==================================================================== + // 5. 
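// ----------------------------------------------------------------------------
// [Editorial sketch - not part of the original diff] The sampling TODOs in
// section 5 below (temperature scaling, top-k, argmax) can be prototyped on the
// host once the last-token logits are in CPU memory. This sketch uses plain
// std::vector and std::mt19937; it is only a reference for what the sampling
// kernel should compute, not the infiniop implementation. Top-p is omitted for brevity.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <numeric>
#include <random>
#include <vector>

// temperature == 0 (or top_k <= 1) -> greedy argmax; otherwise sample from the
// temperature-scaled softmax restricted to the top_k highest logits.
inline uint32_t sample_token(const std::vector<float> &logits, float temperature, size_t top_k, std::mt19937 &rng) {
    if (temperature == 0.0f || top_k <= 1) {
        auto best = std::max_element(logits.begin(), logits.end());
        return static_cast<uint32_t>(std::distance(logits.begin(), best));
    }

    // Indices of the top_k highest logits.
    top_k = std::min(top_k, logits.size());
    std::vector<size_t> order(logits.size());
    std::iota(order.begin(), order.end(), size_t{0});
    std::partial_sort(order.begin(), order.begin() + top_k, order.end(),
                      [&](size_t a, size_t b) { return logits[a] > logits[b]; });

    // Softmax over the kept logits with temperature scaling (max-subtracted for stability).
    std::vector<double> probs(top_k);
    double max_scaled = logits[order[0]] / temperature;
    double sum = 0.0;
    for (size_t i = 0; i < top_k; ++i) {
        probs[i] = std::exp(logits[order[i]] / temperature - max_scaled);
        sum += probs[i];
    }

    // Draw from the (unnormalized) categorical distribution.
    std::uniform_real_distribution<double> uniform(0.0, sum);
    double r = uniform(rng);
    for (size_t i = 0; i < top_k; ++i) {
        r -= probs[i];
        if (r <= 0.0) { return static_cast<uint32_t>(order[i]); }
    }
    return static_cast<uint32_t>(order[top_k - 1]);
}
// ----------------------------------------------------------------------------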
采样/解码 + // ==================================================================== + + // // 5.1 如果是推理模式,对每个序列的最后一个token进行采样 + // for (uint32_t req_idx = 0; req_idx < nreq; ++req_idx) { + // uint32_t last_token_pos = req_pos[req_idx] + req_lens[req_idx] - 1; + + // // 获取最后一个token的logits + // std::shared_ptr last_token_logits = logits->slice(0, last_token_pos, last_token_pos + 1); // [1, dvoc] + + // // 5.2 应用温度 + // if (temperature[req_idx] != 0.0) { + // // TODO: 实现温度缩放: logits = logits / temperature + // } + + // // 5.3 TopK/TopP采样 + // if (topk[req_idx] > 1) { + // // TODO: 实现TopK采样 + // } else if (topp[req_idx] < 1.0) { + // // TODO: 实现TopP (Nucleus) 采样 + // } else { + // // 直接取argmax + // // TODO: 实现Argmax + // // RUN_INFINI(infiniopArgmax(..., last_token_logits, &output[req_idx], ...)); + // } + // } + + // 同步所有计算 + RUN_INFINI(infinirtStreamSynchronize(stream)); + + // 清理推理上下文 + setInferenceContext(nullptr); + + std::cout << "InferDeviceBatch completed" << std::endl; \ No newline at end of file diff --git a/include/infinicore_infer/models/deepseek.h b/include/infinicore_infer/models/deepseek.h index 3924c5fe..635e9ec2 100644 --- a/include/infinicore_infer/models/deepseek.h +++ b/include/infinicore_infer/models/deepseek.h @@ -107,7 +107,7 @@ __C __export struct DeepSeekV3Model * createDeepSeekV3Model(const DeepSeekV3Meta *, const DeepSeekV3Weights *); -__C DeepSeekV3Weights * +__C __export DeepSeekV3Weights * createDeepSeekV3Weights(const DeepSeekV3Meta *meta, infiniDevice_t device, int ndev, diff --git a/include/infinicore_infer/models/jiuge.h b/include/infinicore_infer/models/jiuge.h index ee0a78c0..536e8e9d 100644 --- a/include/infinicore_infer/models/jiuge.h +++ b/include/infinicore_infer/models/jiuge.h @@ -1,14 +1,12 @@ #ifndef MODEL_JIUGE_H #define MODEL_JIUGE_H - #include #include #include - #include - struct JiugeModel; + typedef struct { infiniDtype_t dt_logits; @@ -19,6 +17,15 @@ typedef struct typedef struct { + + // ​**d**: 隐藏层维度(hidden dimension) + // ​**dvoc**: 词表大小(vocabulary dimension) + // ​**nlayer**: Transformer 层数 + // ​**nh**: 注意力头数(number of heads) + // ​**nkvh**: KV 头数(用于分组查询注意力) + // ​**dh**: 每个注意力头的维度(d_head = d / nh) + // ​**di**: FFN 中间层维度(通常 d_i = 4*d)QKVO + // ​**ndev**: 设备数量(用于模型并行) size_t nlayer; infiniDtype_t dt_norm, dt_mat; // 0 if linear weights are passed as W, any other value if passed as W^T (default format in pytorch) @@ -31,7 +38,7 @@ typedef struct const void *output_embd; // nlayer * [d] const void *const *attn_norm; - // nlayer * [ndev, (nh + 2 * nkvh) / ndev * dh, d] + // nlayer * [ndev, (nh + 2 * nkvh) / ndev * dh, d] each devide deal with equal head const void *const *attn_qkv; // nlayer * [ndev, (nh + 2 * nkvh) / ndev * dh] const void *const *attn_qkv_b; diff --git a/include/infinicore_infer/models/llada.h b/include/infinicore_infer/models/llada.h new file mode 100644 index 00000000..c700c59d --- /dev/null +++ b/include/infinicore_infer/models/llada.h @@ -0,0 +1,91 @@ +#ifndef MODEL_LLADAMOE_H +#define MODEL_LLADAMOE_H + +#include +#include +#include + +#include +#include + +struct LLaDAModel; + +typedef struct +{ + infiniDtype_t dt_logits; + size_t nlayer, d, nh, nkvh, dh, di_dense, di_expert, dctx, dvoc; + float epsilon, theta; + uint32_t end_token; + size_t num_experts; +} LLaDAMeta; + + +typedef struct{ + size_t nlayer; + infiniDtype_t dt_norm, dt_mat; + // 0 if linear weights are passed as W, any other value if passed as W^T (default format in pytorch) + int transpose_linear_weights; + // [dvoc, d] + const void *input_embd; + // [d] + 
const void *output_norm; + // [dvoc, d] + const void *output_embd; + // nlayer * [d] + const void *const *attn_norm; + // nlayer * [ndev, (nh + 2 * nkvh) / ndev * dh, d] each devide deal with equal head + const void *const *attn_qkv; + // nlayer * [ndev, (nh + 2 * nkvh) / ndev * dh] + const void *const *attn_qkv_b; + // nlayer * [dh] + const void *const *attn_q_norm; + // nlayer * [dh] + const void *const *attn_k_norm; + // nlayer * [ndev, d, nkvh / ndev * dh] + const void *const *attn_o; + // nlayer * [d] + const void *const *ffn_norm; + // nlayer * [ndev, 2 * di / ndev, d] + + const void *const *expert_gate; + + const void *const *expert_up; + + const void *const *expert_down; + + const void *const *expert_router; +} LLaDAWeights; + +// 改进后的权重加载器结构体 +typedef struct { + +} LLaDAWeightLoader; // # TODO: + + + +__C __export struct LLaDAModel * +createLLaDAModel(const LLaDAMeta *, + const LLaDAWeights *, + infiniDevice_t device, + int ndev, + const int *dev_ids); + +__C __export void +destroyLLaDAModel(); + +__C __export void +inferBatchLLaDA(struct LLaDAModel *model, + const uint32_t *tokens, uint32_t ntok, + const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos, + struct KVCache **kv_caches, + const float *temperature, const uint32_t *topk, const float *topp, + uint32_t *output); + +__C __export void +forwardBatchLLaDA(struct LLaDAModel *model, + const uint32_t *tokens, uint32_t ntok, + const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos, + struct KVCache **kv_caches, + void *logits); + +#endif \ No newline at end of file diff --git a/jiuge.sh b/jiuge.sh new file mode 100644 index 00000000..c186fc7c --- /dev/null +++ b/jiuge.sh @@ -0,0 +1,44 @@ +#!/bin/bash + +# Jiuge模型运行脚本 +# 使用NVIDIA显卡运行9G4B模型 + +set -e # 遇到错误立即退出 + +echo "==========================================" +echo "🚀 启动 Jiuge 模型 (9G4B) - NVIDIA版本" +echo "==========================================" +export INFINI_ROOT=/home/featurize/.infini +# 设置参数 +MODEL_DIR="/home/featurize/work/InfiniFamily/9G4B" +DEVICE="--nvidia" +N_DEVICE=1 +SCRIPT_PATH="python scripts/jiuge.py" + +# 检查模型目录是否存在 +if [ ! -d "$MODEL_DIR" ]; then + echo "❌ 错误: 模型目录不存在: $MODEL_DIR" + echo "请检查路径是否正确" + exit 1 +fi + +# 检查Python脚本是否存在 +if [ ! -f "scripts/jiuge.py" ]; then + echo "❌ 错误: 未找到jiuge.py脚本: scripts/jiuge.py" + echo "请确保在当前目录下运行此脚本" + exit 1 +fi + +echo "📁 模型路径: $MODEL_DIR" +echo "🎯 设备类型: NVIDIA GPU" +echo "💻 设备数量: $N_DEVICE" +echo "" + +# 运行模型 +echo "🔄 启动模型..." 
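+# scripts/jiuge.py is invoked with the device flag, the model directory, and the
+# number of devices as positional arguments (the variables defined above).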
+$SCRIPT_PATH $DEVICE $MODEL_DIR $N_DEVICE + +echo "" +echo "==========================================" +echo "✅ 模型运行完成" +echo "==========================================" \ No newline at end of file diff --git a/llada.sh b/llada.sh new file mode 100644 index 00000000..4ca95dbe --- /dev/null +++ b/llada.sh @@ -0,0 +1,2 @@ +export INFINI_ROOT=/home/featurize/.infini +python /home/featurize/work/My_InfiniLM/scripts/llada.py \ No newline at end of file diff --git a/llada_torch/LLaDA b/llada_torch/LLaDA new file mode 160000 index 00000000..570f2903 --- /dev/null +++ b/llada_torch/LLaDA @@ -0,0 +1 @@ +Subproject commit 570f29032d6824ea14977c89a8eb402e6eb25f96 diff --git a/modeling_lladamoe.py b/modeling_lladamoe.py new file mode 100644 index 00000000..868ee084 --- /dev/null +++ b/modeling_lladamoe.py @@ -0,0 +1,1186 @@ +"""LLaDA MoE model pytorch implementation""" + +import math +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn import CrossEntropyLoss + +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache, StaticCache +from transformers.modeling_attn_mask_utils import AttentionMaskConverter +from transformers.modeling_outputs import ( + MoeCausalLMOutputWithPast, + MoeModelOutputWithPast, +) +from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS +from transformers.modeling_utils import PreTrainedModel +from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) + +from .configuration_lladamoe import LLaDAConfig + + +if is_flash_attn_2_available(): + from transformers.modeling_flash_attention_utils import _flash_attention_forward + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "LLaDAConfig" + + +# Copied from transformers.models.mixtral.modeling_mixtral.load_balancing_loss_func +def load_balancing_loss_func( + gate_logits: torch.Tensor, num_experts: torch.Tensor = None, top_k=2, attention_mask: Optional[torch.Tensor] = None +) -> float: + r""" + Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch. + + See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss + function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between + experts is too unbalanced. + + Args: + gate_logits (Union[`torch.Tensor`, Tuple[torch.Tensor]): + Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of + shape [batch_size X sequence_length, num_experts]. + attention_mask (`torch.Tensor`, *optional*): + For diffusion language model, attention_mask is set to None by default. + If you pass an attention mask and expect the model to use it for computing other attention mechanisms, + it may lead to logits and aux_loss returned by the model being inconsistent with your expectations. + num_experts (`int`, *optional*): + Number of experts + + Returns: + The auxiliary loss. 
+ """ + if gate_logits is None or not isinstance(gate_logits, tuple): + return 0 + + if isinstance(gate_logits, tuple): + compute_device = gate_logits[0].device + concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0) + + routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1) + + _, selected_experts = torch.topk(routing_weights, top_k, dim=-1) + + expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts) + + if attention_mask is None: + # Compute the percentage of tokens routed to each experts + tokens_per_expert = torch.mean(expert_mask.float(), dim=0) + + # Compute the average probability of routing to these experts + router_prob_per_expert = torch.mean(routing_weights, dim=0) + else: + batch_size, sequence_length = attention_mask.shape + num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length) + + # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask + expert_attention_mask = ( + attention_mask[None, :, :, None, None] + .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts)) + .reshape(-1, top_k, num_experts) + .to(compute_device) + ) + + # Compute the percentage of tokens routed to each experts + tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum( + expert_attention_mask, dim=0 + ) + + # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert + router_per_expert_attention_mask = ( + attention_mask[None, :, :, None] + .expand((num_hidden_layers, batch_size, sequence_length, num_experts)) + .reshape(-1, num_experts) + .to(compute_device) + ) + + # Compute the average probability of routing to these experts + router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum( + router_per_expert_attention_mask, dim=0 + ) + + overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0)) + return overall_loss * num_experts + +# copied from transformers.models.olmoe.modeling_olmoe.OlmoeRMSNorm -> LLaDAMoERMSNorm +class LLaDAMoERMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-5): + """ + LLaDAMoERMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +ALL_LAYERNORM_LAYERS.append(LLaDAMoERMSNorm) + +# copied from transformers.models.olmoe.modeling_olmoe.OlmoeRotaryEmbedding -> LLaDAMoERotaryEmbedding +class LLaDAMoERotaryEmbedding(nn.Module): + def __init__( + self, + dim=None, + max_position_embeddings=2048, + base=10000, + device=None, + scaling_factor=1.0, + rope_type="default", + config: Optional[LLaDAConfig] = None, + ): + super().__init__() + # TODO (joao): remove the `if` below, only used for BC + self.rope_kwargs = {} + if config is None: + logger.warning_once( + "`LLaDAMoERotaryEmbedding` can now be fully parameterized by passing the model config through the " + "`config` argument. 
All other arguments will be removed in v4.46" + ) + self.rope_kwargs = { + "rope_type": rope_type, + "factor": scaling_factor, + "dim": dim, + "base": base, + "max_position_embeddings": max_position_embeddings, + } + self.rope_type = rope_type + self.max_seq_len_cached = max_position_embeddings + self.original_max_seq_len = max_position_embeddings + else: + # BC: "rope_type" was originally "type" + if config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn( + self.config, device, seq_len=seq_len, **self.rope_kwargs + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len + + @torch.no_grad() + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +# copied from transformers.models.olmoe.modeling_olmoe.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +# copied from transformers.models.olmoe.modeling_olmoe.apply_rotary_pos_emb +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. 
+ sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + rotary_dim = cos.shape[-1] + + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + + q_rot = q[..., :rotary_dim] + q_pass = q[..., rotary_dim:] + + k_rot = k[..., :rotary_dim] + k_pass = k[..., rotary_dim:] + + q_rotated = (q_rot * cos) + (rotate_half(q_rot) * sin) + k_rotated = (k_rot * cos) + (rotate_half(k_rot) * sin) + + q_final = torch.cat((q_rotated, q_pass), dim=-1) + k_final = torch.cat((k_rotated, k_pass), dim=-1) + + return q_final, k_final + + +# copied from transformers.models.olmoe.modeling_olmoe.OlmoeMLP with OlmoeMLP->LLaDAMoEMLP +class LLaDAMoEMLP(nn.Module): + def __init__(self, config, mlp_type): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + if mlp_type == 'dense': + self.intermediate_size = config.dense_intermediate_size + elif mlp_type == 'expert': + self.intermediate_size = config.expert_intermediate_size + elif mlp_type == 'shared_expert': + self.intermediate_size = config.shared_expert_intermediate_size + else: + assert False, "unknown mlp type" + + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] # "hidden_act": "silu", + + def forward(self, x): + return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + +# copied from transformers.models.olmoe.modeling_olmoe.repeat_kv +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +# copied from transformers.models.olmoe.modeling_olmoe.OlmoeAttention with OlmoeAttention->LLaDAMoEAttention +class LLaDAMoEAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: LLaDAConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + + # **For diffusion language model, we set is_causal to False by default.** + self.is_causal = False + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." 
+ ) + + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias) + if config.qk_layernorm: + self.q_norm = LLaDAMoERMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.k_norm = LLaDAMoERMSNorm( + self.head_dim, eps=config.rms_norm_eps + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + if 'q_norm' in dir(self): + query_states = self.q_norm(query_states.reshape(-1, self.head_dim)).reshape(bsz, q_len, -1) + key_states = self.k_norm(key_states.reshape(-1, self.head_dim)).reshape(bsz, q_len, -1) + value_states = self.v_proj(hidden_states) + + if self.config.clip_qkv is not None: + query_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) + key_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) + value_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + # **For diffusion language model, attention_mask is set to None(full attention) by default.** + attention_mask = None + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = 
attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +# copied from transformers.models.olmoe.modeling_olmoe.FlashAttention2 with OlmoeFlashAttention2->LLaDAMoEFlashAttention2 +class LLaDAMoEFlashAttention2(LLaDAMoEAttention): + """ + LLaDAMoE flash attention module. This module inherits from `LLaDAMoEAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + # copied from transformers.models.olmoe.modeling_olmoe.OlmoeFlashAttention2.__init__ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + if 'q_norm' in dir(self): + query_states = self.q_norm(query_states.reshape(-1, self.head_dim)).reshape(bsz, q_len, -1) + key_states = self.k_norm(key_states.reshape(-1, self.head_dim)).reshape(bsz, q_len, -1) + value_states = self.v_proj(hidden_states) + if self.config.clip_qkv is not None: + query_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) + key_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) + value_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, 
cache_kwargs) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (LLaDAMoERMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + # **For diffusion language model, attention_mask is set to None(full attention) by default.** + attention_mask = None + self.is_causal = False + + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=dropout_rate, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + is_causal=self.is_causal, + ) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +# copied from transformers.models.olmoe.modeling_olmoe.OlmoeSdpaAttention with OlmoeSdpaAttention->LLaDAMoESdpaAttention +class LLaDAMoESdpaAttention(LLaDAMoEAttention): + """ + LLaDAMoE attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `LLaDAMoEAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + logger.warning_once( + "LLaDAModel is using LLaDAMoESdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. 
This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + if 'q_norm' in dir(self): + query_states = self.q_norm(query_states.reshape(-1, self.head_dim)).reshape(bsz, q_len, -1) + key_states = self.k_norm(key_states.reshape(-1, self.head_dim)).reshape(bsz, q_len, -1) + value_states = self.v_proj(hidden_states) + + if self.config.clip_qkv is not None: + query_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) + key_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) + value_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + causal_mask = attention_mask + # if attention_mask is not None and cache_position is not None: + if attention_mask is not None: + causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and causal_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. 
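+        # NOTE (added for clarity): the `is_causal` computed on the next line follows the upstream
+        # OLMoE/SDPA dispatch logic, but it is overridden immediately below; the diffusion
+        # objective uses full bidirectional attention, so no causal mask is applied here.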
+ is_causal = True if causal_mask is None and q_len > 1 else False + + # **For diffusion language model, attention_mask is set to None(full attention) by default.** + is_causal = False + causal_mask = None + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +LLADAMOE_ATTENTION_CLASSES = { + "eager": LLaDAMoEAttention, + "flash_attention_2": LLaDAMoEFlashAttention2, + "sdpa": LLaDAMoESdpaAttention, +} + + +# copied from transformers.models.olmoe.modeling_olmoe.OlmoeSparseMoeBlock with OlmoeSparseMoeBlock->LLaDAMoESparseMoeBlock +class LLaDAMoESparseMoeBlock(nn.Module): + def __init__(self, config): + super().__init__() + self.num_experts = config.num_experts + self.top_k = config.num_experts_per_tok + self.norm_topk_prob = False + self.gate = nn.Linear(config.hidden_size, self.num_experts, bias=False) + self.experts = nn.ModuleList([LLaDAMoEMLP(config, 'expert') for _ in range(self.num_experts)]) + self.score_func = config.moe_router_score_function + if config.moe_router_enable_expert_bias: + self.register_buffer("expert_bias", torch.zeros(self.num_experts)) + else: + self.expert_bias = None + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + batch_size, sequence_length, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + # router_logits: (batch * sequence_length, n_experts) + router_logits = self.gate(hidden_states) # TODO: + + routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) + + if self.expert_bias is not None: + routing_weights += self.expert_bias + + routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) + if self.norm_topk_prob: + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + # we cast back to the input dtype + routing_weights = routing_weights.to(hidden_states.dtype) + + final_hidden_states = torch.zeros( + (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device + ) + + # One hot encode the selected experts to create an expert mask + # this will be used to easily index which expert is going to be selected + expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) + + # Loop over all available experts in the model and perform the computation on each expert + for expert_idx in range(self.num_experts): + expert_layer = self.experts[expert_idx] + idx, top_x = torch.where(expert_mask[expert_idx]) + + # Index the correct hidden states and compute the expert hidden state for + # the current expert. We need to make sure to multiply the output hidden + # states by `routing_weights` on the corresponding tokens (top-1 and top-2) + current_state = hidden_states[None, top_x].reshape(-1, hidden_dim) + current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None] + + # However `index_add_` only support torch tensors for indexing so we'll use + # the `top_x` tensor here. 
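+            # (added for clarity) After the permute above, expert_mask has shape
+            # (num_experts, top_k, num_tokens): `top_x` holds the token indices routed to this
+            # expert and `idx` holds which top-k slot selected it, so routing_weights[top_x, idx]
+            # is the gate weight applied to this expert's output before accumulation.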
+ final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) + final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) + return final_hidden_states + + +class LLaDAMoEDecoderLayer(nn.Module): + def __init__(self, config: LLaDAConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.mlp_type = 'dense' if config.moe_layer_freq[layer_idx] == 0 else 'moe' + + self.self_attn = LLADAMOE_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) + + self.mlp = LLaDAMoESparseMoeBlock(config) if self.mlp_type == 'moe' else LLaDAMoEMLP(config, 'dense') + self.input_layernorm = LLaDAMoERMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = LLaDAMoERMSNorm(config.hidden_size, eps=config.rms_norm_eps) + if config.shared_expert_intermediate_size is not None and self.mlp_type == 'moe': + self.shared_expert = LLaDAMoEMLP(config, 'shared_expert') + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + output_router_logits: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + For diffusion language model, attention_mask is set to None(full attention). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_router_logits (`bool`, *optional*): + Whether or not to return the logits of all the routers. They are useful for computing the router loss, + and should not be returned during inference. + use_cache (`bool`, *optional*): + For diffusion language model, use_cache is set to False by default. + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): + For diffusion language model, past_key_value is set to None by default. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + For diffusion language model, cache_position is set to None by default. + position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. 
+ kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # **For diffusion language model, attention_mask is set to None(full attention) by default.** + use_cache = False + attention_mask = None + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + shared_expert_states = hidden_states + + hidden_states = self.mlp(hidden_states) + + if hasattr(self, "shared_expert"): + hidden_states = hidden_states + self.shared_expert(shared_expert_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +LLADAMOE_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`LLaDAConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare LLaDAMoE Model outputting raw hidden-states without any specific head on top.", + LLADAMOE_START_DOCSTRING, +) +# copied from transformers.models.olmoe.modeling_olmoe.OlmoeModel with OlmoePreTrainedModel->LLaDAMoEPreTrainedModel +class LLaDAMoEPreTrainedModel(PreTrainedModel): + config_class = LLaDAConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["LLaDAMoEDecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +LLADAMOE_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. 
See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            [What are input IDs?](../glossary#input-ids)
+        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Mask to avoid performing attention on padding token indices.
+            **For diffusion language model, attention_mask is set to None (full attention) by default.**
+        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+            config.n_positions - 1]`.
+
+            [What are position IDs?](../glossary#position-ids)
+        past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
+            **For diffusion language model, past_key_values cannot be applied by default.**
+        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+            model's internal embedding lookup matrix.
+        use_cache (`bool`, *optional*):
+            For the diffusion language model, `use_cache` and `past_key_values` cannot be enabled by default.
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+        output_router_logits (`bool`, *optional*):
+            Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
+            should not be returned during inference.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+        cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+            **For diffusion language model, cache_position cannot be applied by default.**
+"""
+
+
+@add_start_docstrings(
+    "The bare LLaDAMoE Model outputting raw hidden-states without any specific head on top.",
+    LLADAMOE_START_DOCSTRING,
+)
+# copied from transformers.models.olmoe.modeling_olmoe.OlmoeModel with OlmoeModel->LLaDAMoEModel
+class LLaDAMoEModel(LLaDAMoEPreTrainedModel):
+    """
+    Transformer encoder consisting of *config.num_hidden_layers* layers.
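+    Although it reuses the decoder-layer structure, attention here is bidirectional: the
+    causal mask is disabled throughout, matching LLaDA's masked-diffusion formulation.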
Each layer is a [`LLaDAMoEDecoderLayer`] + + Args: + config: LLaDAConfig + """ + + def __init__(self, config: LLaDAConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [LLaDAMoEDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = LLaDAMoERMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = LLaDAMoERotaryEmbedding(config=config) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(LLADAMOE_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, MoeModelOutputWithPast]: + assert (not use_cache and past_key_values is None and cache_position is None), "The cache mechanism is not suppotred for LLaDA MoE by default." + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_router_logits = ( + output_router_logits if output_router_logits is not None else self.config.output_router_logits + ) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + return_legacy_cache = False + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = None #TODO: + logger.warning_once( + f"Please note that, unlike autoregressive models, LLaDA MoE employs a bidirectional attention mechanism. " + f"In the forward code in modeling_lladamoe.py, we set both attention_mask and causal_mask to None, " + f"which affects the default causal attention and causes the input attention_mask parameter to become ineffective. " + f"If you pass an attention mask and expect the model to use it for computing other attention mechanisms, " + f"it may lead to logits and aux_loss returned by the model being inconsistent with your expectations. 
" + ) + + # embed positions + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_router_logits = () if output_router_logits else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + output_router_logits, + use_cache, + cache_position, + position_embeddings, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + output_router_logits=output_router_logits, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if output_router_logits and layer_outputs[-1] is not None: + all_router_logits += (layer_outputs[-1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = next_cache.to_legacy_cache() + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return MoeModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + router_logits=all_router_logits, + ) + + +class LLaDAMoEModelLM(LLaDAMoEPreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = LLaDAMoEModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.router_aux_loss_coef = config.router_aux_loss_coef + self.num_experts = config.num_experts + self.num_experts_per_tok = config.num_experts_per_tok + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @add_start_docstrings_to_model_forward(LLADAMOE_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: 
Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, + ) -> Union[Tuple, MoeCausalLMOutputWithPast]: + r""" + For the current inference code of the diffusion language model, passing the parameters `labels` and `num_logits_to_keep` to compute loss is not supported. + Please note that for the diffusion language model, you cannot use model.generate() to generate responses. Please use the provided sampling code to generate model outputs. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, AutoModel + + >>> model = AutoModel.from_pretrained("path/to/LLaDAMoE") + >>> tokenizer = AutoTokenizer.from_pretrained("path/to/LLaDAMoE") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = generate() # Please use the customized generate method instead of model.generate(). + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + 'Hey, are you conscious? Can you talk to me?\nI’m not sure if you’re conscious of this, but I’m' + ``` + """ + assert (labels is None and num_logits_to_keep == 0), "LLaDAMoE model does not support calculate loss in the forward pass." + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_router_logits = ( + output_router_logits if output_router_logits is not None else self.config.output_router_logits + ) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + output_router_logits=output_router_logits, + return_dict=return_dict, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + + loss = None + + aux_loss = None + if output_router_logits: + aux_loss = load_balancing_loss_func( + outputs.router_logits if return_dict else outputs[-1], + self.num_experts, + self.num_experts_per_tok, + attention_mask, + ) + + if not return_dict: + output = (logits,) + outputs[1:] + if output_router_logits: + output = (aux_loss,) + output + return (loss,) + output if loss is not None else output + + return MoeCausalLMOutputWithPast( + loss=loss, + aux_loss=aux_loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + router_logits=outputs.router_logits, + ) diff --git a/rope.py b/rope.py new file mode 100644 index 00000000..eccf8ca0 --- /dev/null +++ b/rope.py @@ -0,0 +1,291 @@ +import torch +import torch.nn as nn +import math +from typing import Optional + +class RotaryPositionEmbeddingSimple(nn.Module): + """ + 修复bug后的直观RoPE实现 + """ + + def __init__(self, dim: int, base: int = 10000): + super().__init__() + + assert dim % 2 == 0, f"维度必须是偶数,当前dim={dim}" + + self.dim = dim + self.base = base + + print(f"🔧 RoPE初始化: dim={dim}, base={base}") + + def forward(self, x: 
torch.Tensor, positions: Optional[torch.Tensor] = None) -> torch.Tensor: + """ + 修复bug的RoPE前向传播 + """ + print(f"\n🔄 开始RoPE计算...") + print(f" 输入形状: {x.shape}") + + # 保存原始形状 + original_shape = x.shape + + # 获取输入信息 + if x.dim() == 3: # [batch, seq_len, dim] + batch_size, seq_len, dim = x.shape + num_heads = 1 + elif x.dim() == 4: # [batch, heads, seq_len, dim] + batch_size, num_heads, seq_len, dim = x.shape + else: + raise ValueError(f"不支持的输入维度: {x.dim()}") + + # 处理位置编码 + if positions is None: + positions = torch.arange(seq_len, device=x.device) + print(f" 自动生成位置: {positions.tolist()}") + else: + print(f" 使用自定义位置: {positions.tolist()}") + + # 1. 计算频率向量 + print(f"\n📊 步骤1: 计算频率向量") + indices = torch.arange(0, dim, 2).float().to(x.device) # [0, 2, 4, ..., dim-2] + inv_freq = 1.0 / (self.base ** (indices / dim)) + print(f" 频率向量: {inv_freq.tolist()}") + + # 2. 计算角度矩阵 + print(f"\n📊 步骤2: 计算角度矩阵") + # positions: [seq_len] -> [seq_len, 1] + # inv_freq: [dim/2] -> [1, dim/2] + positions_expanded = positions.unsqueeze(-1) # [seq_len, 1] + inv_freq_expanded = inv_freq.unsqueeze(0) # [1, dim/2] + + angles = positions_expanded * inv_freq_expanded # [seq_len, dim/2] + print(f" 角度矩阵形状: {angles.shape}") + + # 3. 扩展角度到每个维度 + angles_expanded = angles.repeat_interleave(2, dim=-1) # [seq_len, dim] + print(f" 扩展后角度形状: {angles_expanded.shape}") + + # 4. 计算正弦余弦 + sin = torch.sin(angles_expanded) # [seq_len, dim] + cos = torch.cos(angles_expanded) # [seq_len, dim] + print(f" 正弦形状: {sin.shape}, 余弦形状: {cos.shape}") + + # 5. 关键修复:正确调整形状以匹配输入 + print(f"\n📊 步骤3: 调整形状匹配输入") + if x.dim() == 3: # [batch, seq_len, dim] + # 扩展维度: [seq_len, dim] -> [batch, seq_len, dim] + sin = sin.unsqueeze(0).expand(batch_size, -1, -1) # 使用expand而不是repeat + cos = cos.unsqueeze(0).expand(batch_size, -1, -1) + elif x.dim() == 4: # [batch, heads, seq_len, dim] + # 扩展维度: [seq_len, dim] -> [batch, heads, seq_len, dim] + sin = sin.unsqueeze(0).unsqueeze(0).expand(batch_size, num_heads, -1, -1) + cos = cos.unsqueeze(0).unsqueeze(0).expand(batch_size, num_heads, -1, -1) + + print(f" 调整后正弦形状: {sin.shape}") + print(f" 调整后余弦形状: {cos.shape}") + print(f" 输入x形状: {x.shape}") + + # 6. 
应用旋转 + result = self._apply_rotation_detailed(x, sin, cos) + + print(f"✅ RoPE计算完成") + print(f" 输入: {original_shape} -> 输出: {result.shape}") + + return result + + def _apply_rotation_detailed(self, x: torch.Tensor, sin: torch.Tensor, cos: torch.Tensor) -> torch.Tensor: + """修复后的旋转操作""" + print(f"\n📊 步骤4: 应用旋转操作") + + # 检查形状是否匹配 + assert x.shape == sin.shape == cos.shape, f"形状不匹配: x{x.shape}, sin{sin.shape}, cos{cos.shape}" + + # 分割输入张量 + x1 = x[..., 0::2] # 所有偶数索引维度 + x2 = x[..., 1::2] # 所有奇数索引维度 + + print(f" x1形状 (偶数维度): {x1.shape}") + print(f" x2形状 (奇数维度): {x2.shape}") + + # 分割正弦余弦(确保形状匹配) + sin1 = sin[..., 0::2] # 对应x1的正弦 + cos1 = cos[..., 0::2] # 对应x1的余弦 + sin2 = sin[..., 1::2] # 对应x2的正弦 + cos2 = cos[..., 1::2] # 对应x2的余弦 + + print(f" sin1形状: {sin1.shape}, cos1形状: {cos1.shape}") + print(f" sin2形状: {sin2.shape}, cos2形状: {cos2.shape}") + + # 应用旋转公式(确保广播正确) + rotated_x1 = x1 * cos1 - x2 * sin2 + rotated_x2 = x1 * sin1 + x2 * cos2 + + print(f" rotated_x1形状: {rotated_x1.shape}") + print(f" rotated_x2形状: {rotated_x2.shape}") + + # 重新组合 + result = torch.stack([rotated_x1, rotated_x2], dim=-1) + result = result.flatten(start_dim=-2) + + print(f" 最终输出形状: {result.shape}") + return result + + +def test_fixed_version(): + """测试修复后的版本""" + print("=" * 60) + print("🧪 测试修复后的版本") + print("=" * 60) + + # 测试1: 3D输入 + print("测试1: 3D输入 [batch, seq_len, dim]") + dim = 6 + rope = RotaryPositionEmbeddingSimple(dim) + + x_3d = torch.randn(2, 3, dim) # [batch=2, seq_len=3, dim=6] + positions = torch.tensor([0, 1, 2]) + + try: + result_3d = rope(x_3d, positions) + print("✅ 3D输入测试通过") + except Exception as e: + print(f"❌ 3D输入测试失败: {e}") + + # 测试2: 4D输入 + print("\n测试2: 4D输入 [batch, heads, seq_len, dim]") + x_4d = torch.randn(2, 4, 3, dim) # [batch=2, heads=4, seq_len=3, dim=6] + + try: + result_4d = rope(x_4d, positions) + print("✅ 4D输入测试通过") + except Exception as e: + print(f"❌ 4D输入测试失败: {e}") + + +def debug_shape_issue(): + """调试原始的形状问题""" + print("\n" + "=" * 60) + print("🐛 调试原始的形状问题") + print("=" * 60) + + dim = 4 + batch_size, seq_len = 2, 3 + + # 创建测试数据 + x = torch.randn(batch_size, seq_len, dim) + positions = torch.arange(seq_len) + + print("原始问题分析:") + print(f"输入x形状: {x.shape}") # [2, 3, 4] + + # 计算正弦余弦(错误的方式) + indices = torch.arange(0, dim, 2).float() + inv_freq = 1.0 / (10000 ** (indices / dim)) + + angles = positions.unsqueeze(-1) * inv_freq.unsqueeze(0) # [3, 2] + angles_expanded = angles.repeat_interleave(2, dim=-1) # [3, 4] + + sin = torch.sin(angles_expanded) # [3, 4] + cos = torch.cos(angles_expanded) # [3, 4] + + print(f"计算出的sin形状: {sin.shape}") # [3, 4] + print(f"计算出的cos形状: {cos.shape}") # [3, 4] + + # 错误:直接使用会导致形状不匹配 + print(f"❌ 问题: sin{sin.shape} 与 x{x.shape} 形状不匹配") + print(f"❌ 需要将sin从[3,4]扩展到[2,3,4]") + + # 正确的方式 + sin_correct = sin.unsqueeze(0).expand(batch_size, -1, -1) # [2, 3, 4] + print(f"✅ 正确扩展后: {sin_correct.shape}") + + +def simple_demo(): + """简单的演示""" + print("\n" + "=" * 60) + print("🎯 简单演示") + print("=" * 60) + + # 使用更小的维度便于观察 + dim = 4 + rope = RotaryPositionEmbeddingSimple(dim, base=100) + + # 创建简单的测试数据 + x = torch.tensor([ + [[1.0, 0.0, 0.5, 0.5], # 第一个序列 + [0.0, 1.0, 0.3, 0.7]], + + [[0.5, 0.5, 1.0, 0.0], # 第二个序列 + [0.7, 0.3, 0.0, 1.0]] + ]) # [batch=2, seq_len=2, dim=4] + + print("输入数据:") + print(f"批次0, token0: {x[0,0].tolist()}") + print(f"批次0, token1: {x[0,1].tolist()}") + print(f"批次1, token0: {x[1,0].tolist()}") + + # 应用RoPE + result = rope(x) + + print("\n旋转后数据:") + print(f"批次0, token0: {result[0,0].tolist()}") + print(f"批次0, token1: {result[0,1].tolist()}") + 
print(f"批次1, token0: {result[1,0].tolist()}") + + +def verify_calculation(): + """验证计算正确性""" + print("\n" + "=" * 60) + print("✅ 验证计算正确性") + print("=" * 60) + + # 使用2维向量手动验证 + dim = 2 + rope = RotaryPositionEmbeddingSimple(dim, base=10000) + + # 创建简单的测试向量 + x = torch.tensor([[[1.0, 0.0]]]) # [1, 1, 2] + positions = torch.tensor([1]) + + # 手动计算期望结果 + # 对于2维向量,只有一个频率θ + theta = 1.0 / (10000 ** (0 / 2)) # i=0, θ=1.0 + angle = 1 * theta # 位置1,角度=1弧度 + + # 手动旋转计算 + x_manual = torch.tensor([ + [1.0 * math.cos(angle) - 0.0 * math.sin(angle), + 1.0 * math.sin(angle) + 0.0 * math.cos(angle)] + ]) + + # RoPE计算 + x_rope = rope(x, positions) + + print(f"手动计算: {x_manual.tolist()}") + print(f"RoPE计算: {x_rope[0,0].tolist()}") + + # 检查是否一致 + diff = torch.abs(x_manual - x_rope[0,0]).max().item() + if diff < 1e-6: + print("✅ 计算正确性验证通过") + else: + print(f"❌ 计算有差异: {diff}") + + +def main(): + """运行所有测试""" + print("🚀 开始修复后的RoPE测试") + + # 运行测试 + test_fixed_version() + debug_shape_issue() + simple_demo() + verify_calculation() + + print("\n" + "=" * 60) + print("🎉 所有测试完成!") + print("=" * 60) + + +if __name__ == "__main__": + torch.manual_seed(42) + main() \ No newline at end of file diff --git a/scripts/infer_task.py b/scripts/infer_task.py index 0d1231b7..7267fb01 100644 --- a/scripts/infer_task.py +++ b/scripts/infer_task.py @@ -3,11 +3,11 @@ def __init__(self, id, tokens, max_tokens, temperature, topk, topp, end_tokens): self.id = id self.finish_reason = None self.tokens = tokens - self.max_tokens = max_tokens + self.max_tokens = max_tokens # length ? self.temperature = temperature self.topk = topk self.topp = topp - self.end_tokens = end_tokens + self.end_tokens = end_tokens # tow type end symbol self._kv_cache = None self.pos = 0 @@ -38,7 +38,7 @@ def next(self, out_token): class KVCache: def __init__(self, model): - self._kvcache = model.create_kv_cache() + self._kvcache = model.create_kv_cache() # in c library self.tokens = [0 for _ in range(model.max_context_len())] def data(self): diff --git a/scripts/jiuge.py b/scripts/jiuge.py index 7c31baf8..6a96c7a5 100644 --- a/scripts/jiuge.py +++ b/scripts/jiuge.py @@ -874,4 +874,4 @@ def test(): if __name__ == "__main__": - test() + test() \ No newline at end of file diff --git a/scripts/libinfinicore_infer/__init__.py b/scripts/libinfinicore_infer/__init__.py index 8fc5f4db..638d60df 100644 --- a/scripts/libinfinicore_infer/__init__.py +++ b/scripts/libinfinicore_infer/__init__.py @@ -1,5 +1,6 @@ from .base import DataType, DeviceType, KVCacheCStruct from .jiuge import JiugeModel, JiugeMetaCStruct, JiugeWeightsCStruct +from .llada import LLaDAModel, LLaDAMetaCStruct, LLaDAWeightsCStruct from .jiuge_awq import JiugeAWQModel, JiugeAWQMetaCStruct, ModelWeightsCStruct from .deepseek_v3 import ( DeepSeekV3Model, @@ -24,4 +25,7 @@ "DeepSeekV3WeightsCStruct", "DeepSeekV3WeightLoaderCStruct", "ModelRegister", + "LLaDAModel", + "LLaDAMetaCStruct", + "LLaDAWeightsCStruct" ] diff --git a/scripts/libinfinicore_infer/base.py b/scripts/libinfinicore_infer/base.py index bed65b2e..6a4384ab 100644 --- a/scripts/libinfinicore_infer/base.py +++ b/scripts/libinfinicore_infer/base.py @@ -64,6 +64,7 @@ def __init__(self): register_lib_functions(self.lib) def _load_library(self): + lib_path = os.path.join( os.environ.get("INFINI_ROOT"), "lib", "libinfinicore_infer.so" ) diff --git a/scripts/libinfinicore_infer/jiuge.py b/scripts/libinfinicore_infer/jiuge.py index fe2abf10..a1a583f0 100644 --- a/scripts/libinfinicore_infer/jiuge.py +++ b/scripts/libinfinicore_infer/jiuge.py @@ -148,4 
+148,4 @@ def forward_batch( ): self.lib.forwardBatchJiuge( model, tokens, ntok, req_lens, nreq, req_pos, kv_caches, logits - ) + ) \ No newline at end of file diff --git a/scripts/libinfinicore_infer/llada.py b/scripts/libinfinicore_infer/llada.py new file mode 100644 index 00000000..866a7532 --- /dev/null +++ b/scripts/libinfinicore_infer/llada.py @@ -0,0 +1,164 @@ +from .base import BaseModel, DataType, DeviceType, KVCacheCStruct, register_model +from ctypes import c_size_t, c_uint, c_int, c_float, c_void_p, POINTER, Structure, byref + + +class LLaDAMetaCStruct(Structure): # from config file + _fields_ = [ # (name, size) + ("dt_logits", DataType), # 4 bytes + ("_pad0", c_uint), # 填充 4 bytes,使下一个 c_size_t 对齐 + ("nlayer", c_size_t), # 8 bytes + ("d", c_size_t), + ("nh", c_size_t), + ("nkvh", c_size_t), + ("dh", c_size_t), + ("di_dense", c_size_t), + ("di_expert", c_size_t), + ("dctx", c_size_t), + ("dvoc", c_size_t), + ("epsilon", c_float), # 4 bytes + ("theta", c_float), + ("end_token", c_uint), # 4 bytes + ("_pad1", c_uint), # 填充 4 bytes,使下一个 size_t 对齐 + ("num_experts", c_size_t), + ] # equal to c structure in c + + +class LLaDAWeightsCStruct(Structure): + _fields_ = [ + ("nlayer", c_size_t), + ("dt_norm", DataType), + ("dt_mat", DataType), + ("transpose_linear_weights", c_int), + ("input_embd", c_void_p), + ("output_norm", c_void_p), + ("output_embd", c_void_p), + ("attn_norm", POINTER(c_void_p)), + ("attn_qkv", POINTER(c_void_p)), + ("attn_qkv_b", POINTER(c_void_p)), + ("attn_q_norm", POINTER(c_void_p)), + ("attn_k_norm", POINTER(c_void_p)), + ("attn_o", POINTER(c_void_p)), + ("ffn_norm", POINTER(c_void_p)), + + ("expert_gate", POINTER(c_void_p)), + ("expert_up", POINTER(c_void_p)), + ("expert_down", POINTER(c_void_p)), + ("router", POINTER(c_void_p)), + ] + +class LLaDAModelCStruct(Structure): + pass + + +@register_model +class LLaDAModel(BaseModel): + @classmethod + def register_lib(cls, lib): + # 此处实现参数列表的对应 + # TODO: + # createJiugeModel C++ interface + # destoryJIugeModel C++ interface + # inferBatchJiuge C++ interface + # forwardBatchJiuge C++ interface + # TODO: 根据最后4个的实现完善参数列表 + lib.createLLaDAModel.restype = POINTER(LLaDAModelCStruct) # OK + lib.createLLaDAModel.argtypes = [ + POINTER(LLaDAMetaCStruct), + POINTER(LLaDAWeightsCStruct), + DeviceType, + c_int, + POINTER(c_int), # const --> Pointer + ] # OK + + # lib.destroyJiugeModel.argtypes = [POINTER(LLaDAModelCStruct)] + + lib.createKVCache.argtypes = [ + c_size_t, + c_size_t, + c_size_t, + c_size_t, + c_size_t, + DataType, + DeviceType, + POINTER(c_int), + c_size_t, + ] + lib.createKVCache.restype = POINTER(KVCacheCStruct) + + lib.dropKVCache.argtypes = [POINTER(KVCacheCStruct)] + + lib.inferBatchLLaDA.argtypes = [ + POINTER(LLaDAModelCStruct), + POINTER(c_uint), + c_uint, + POINTER(c_uint), + c_uint, + POINTER(c_uint), + POINTER(POINTER(KVCacheCStruct)), + POINTER(c_float), + POINTER(c_uint), + POINTER(c_float), + POINTER(c_uint), + ] + + lib.forwardBatchLLaDA.argtypes = [ + POINTER(LLaDAModelCStruct), + POINTER(c_uint), + c_uint, + POINTER(c_uint), + c_uint, + POINTER(c_uint), + POINTER(POINTER(KVCacheCStruct)), + c_void_p, + ] + def create_model(self, meta, weights, device_type, ndev, dev_ids): + # TODO: + return self.lib.createLLaDAModel(meta, weights, device_type, ndev, dev_ids) + + def destroy_model(self, model): + pass + + def create_kv_cache( + self, nlayer, max_len, nkvh, dk, dv, dtype, device, dev_ids, ndev + ): + return self.lib.createKVCache( + nlayer, max_len, nkvh, dk, dv, dtype, device, dev_ids, ndev + ) + + 
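+    # LLaDA's diffusion sampling recomputes full bidirectional attention over the whole
+    # sequence at every step (no incremental decoding), so the KV-cache helpers here
+    # appear to exist mainly for compatibility with the shared C interface.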
def drop_kv_cache(self, kv_cache): + self.lib.dropKVCache(kv_cache) + + def infer_batch( + self, + model, + tokens, + ntok, + req_lens, + nreq, + req_pos, + kv_caches, + temperature, + topk, + topp, + output, + ): + self.lib.inferBatchLLaDA( + model, + tokens, + ntok, + req_lens, + nreq, + req_pos, + kv_caches, + temperature, + topk, + topp, + output, + ) + + def forward_batch( + self, model, tokens, ntok, req_lens, nreq, req_pos, kv_caches, logits + ): + self.lib.forwardBatchLLaDA( + model, tokens, ntok, req_lens, nreq, req_pos, kv_caches, logits + ) diff --git a/scripts/llada.py b/scripts/llada.py new file mode 100644 index 00000000..19509965 --- /dev/null +++ b/scripts/llada.py @@ -0,0 +1,748 @@ +from typing import List, Sequence +import math +import os +from pathlib import Path +import safetensors +import sys +import time +import json +import torch +import transformers +from infer_task import InferTask, KVCache + +from libinfinicore_infer import ( + DeviceType, + KVCacheCStruct, + DataType +) +from libinfinicore_infer.llada import LLaDAModel, LLaDAMetaCStruct, LLaDAWeightsCStruct + +from ctypes import POINTER, c_float, c_int, c_uint, c_void_p, byref +import torch.nn.functional as F +import numpy as np + +class LLaDAWeifghtsNaming: + def input_embd(self): + return "model.embed_tokens.weight" + + def output_norm(self): + return "model.norm.weight" + + def output_embd(self): + return "lm_head.weight" + + def attn_norm(self, i): + return f"model.layers.{i}.input_layernorm.weight" + + def attn_q(self, i): + return f"model.layers.{i}.self_attn.q_proj.weight" + + def attn_k(self, i): + return f"model.layers.{i}.self_attn.k_proj.weight" + + def attn_v(self, i): + return f"model.layers.{i}.self_attn.v_proj.weight" + + def attn_o(self, i): + return f"model.layers.{i}.self_attn.o_proj.weight" + + def attn_q_b(self, i): + return f"model.layers.{i}.self_attn.q_proj.bias" + + def attn_k_b(self, i): + return f"model.layers.{i}.self_attn.k_proj.bias" + + def attn_v_b(self, i): + return f"model.layers.{i}.self_attn.v_proj.bias" + + def attn_q_norm(self, i): + return f"model.layers.{i}.self_attn.q_norm.weight" + + def attn_k_norm(self, i): + return f"model.layers.{i}.self_attn.k_norm.weight" + + def ffn_norm(self, i): + return f"model.layers.{i}.post_attention_layernorm.weight" + + def router(self, i): + return f"model.layers.{i}.mlp.gate.weight" + + def expert_gate(self, i, j): + return f"model.layers.{i}.mlp.experts.{j}.gate_proj.weight" + + def expert_up(self, i, j): + return f"model.layers.{i}.mlp.experts.{j}.up_proj.weight" + + def down(self, i, j): + return f"model.layers.{i}.mlp.experts.{j}.down_proj.weight" + + + +class LLaDAMetaFromLlama(LLaDAMetaCStruct): # model specific data: heads num .... 
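+    """
+    Builds the LLaDAMetaCStruct passed to the C library from a HuggingFace-style
+    config.json, mapping e.g. num_hidden_layers -> nlayer, hidden_size -> d and the
+    dense/expert intermediate sizes -> di_dense/di_expert.
+    """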
+ def __init__(self, config, dtype=torch.float16, max_tokens = None): + if dtype == torch.float16: + dt_ = DataType.INFINI_DTYPE_F16 + elif dtype == torch.float32: + dt_ = DataType.INFINI_DTYPE_F32 + elif dtype == torch.bfloat16: + dt_ = DataType.INFINI_DTYPE_BF16 + else: + dt_ = DataType.INFINI_DTYPE_F16 + + self.scale_input = 1.0 + self.scale_output = 1.0 + self.scale_o = 1.0 + self.scale_down = 1.0 + if ( + config["model_type"] in ["fm9g", "minicpm"] + and "scale_emb" in config + and "scale_depth" in config + and "dim_model_base" in config + ): + self.scale_input = config["scale_emb"] + self.scale_output = config["hidden_size"] //config["hidden_size"] + self.scale_o = config["scale_depth"] / math.sqrt( + config["num_hidden_layers"] + ) + self.scale_down = config["scale_depth"] / math.sqrt( + config["num_hidden_layers"] + ) + + super().__init__( + dt_logits=dt_, + _pad0 = 0, + nlayer=config["num_hidden_layers"], + d=config["hidden_size"], + nh=config["num_attention_heads"], + nkvh=( + config["num_key_value_heads"] + if "num_key_value_heads" in config + else config["num_attention_heads"] + ), + dh=( + config["head_dim"] + if "head_dim" in config + else config["hidden_size"] // config["num_attention_heads"] + ), + di_dense = config["dense_intermediate_size"], + di_expert = config["expert_intermediate_size"], + dctx=( + config["max_position_embeddings"] if max_tokens is None else max_tokens + ), + dvoc=config["vocab_size"], + epsilon=config["rms_norm_eps"], + theta=(config["rope_theta"] if "rope_theta" in config else 100000.0), + end_token=2, + _pad1=0, + num_experts=config["num_experts"] + ) + self.torch_dtype_logits = dtype + + +class LLaDAWeightsImpl(LLaDAWeightsCStruct): + def __init__(self, meta, naming, + state_dict, # 权重 + torch_dt_mat=torch.float16, + torch_dt_norm=torch.float32, + transpose_weight = None, + ndev=1, + ): + nlayer = meta.nlayer + nh = meta.nh + di_expert = meta.di_expert + di_dense = meta.di_dense + nkvh = meta.nkvh + dh = meta.dh + d = meta.d + num_experts = meta.num_experts + scale_input = meta.scale_input + scale_output = meta.scale_output + scale_o = meta.scale_o + scale_down = meta.scale_down + assert nh % nkvh == 0 + assert nh % ndev == 0 + assert nkvh % ndev == 0 + + torch_dt_logits = meta.torch_dtype_logits + if torch_dt_mat == torch.float16: + self.dt_mat = DataType.INFINI_DTYPE_F16 + elif torch_dt_mat == torch.float32: + self.dt_mat = DataType.INFINI_DTYPE_F32 + elif torch_dt_mat == torch.bfloat16: + self.dt_mat = DataType.INFINI_DTYPE_BF16 + else: + raise ValueError("Unsupported proj weight data type") + if torch_dt_norm == torch.float16: + self.dt_norm = DataType.INFINI_DTYPE_F16 + elif torch_dt_norm == torch.float32: + self.dt_norm = DataType.INFINI_DTYPE_F32 + elif torch_dt_norm == torch.bfloat16: + self.dt_norm = DataType.INFINI_DTYPE_BF16 + else: + raise ValueError("Unsupported norm weight data type") + + input_embd_naming = ( + naming.input_embd() + if naming.input_embd() in state_dict + else naming.output_embd() + ) + output_embd_naming = ( + naming.output_embd() + if naming.output_embd() in state_dict + else naming.input_embd() + ) + self.transpose_linear_weights = 1 if transpose_weight else 0 + self.nlayer = nlayer + self.input_embd_tensor = ( + state_dict[input_embd_naming].to(torch_dt_logits) + ) + self.input_embd = self.input_embd_tensor.data_ptr() + self.output_norm_tensor = ( + state_dict[naming.output_norm()].to(torch_dt_norm) * scale_output + ) + self.output_norm = self.output_norm_tensor.data_ptr() + self.output_embd_tensor = 
state_dict[output_embd_naming].to(torch_dt_mat) + if not transpose_weight: + self.output_embd_tensor = self.output_embd_tensor.transpose( + 0, 1 + ).contiguous() + self.output_embd = self.output_embd_tensor.data_ptr() + + self.attn_norm_tensors = [ + state_dict[naming.attn_norm(i)].to(torch_dt_norm) for i in range(nlayer) # each layer's weight + ] + self.attn_norm_ptrs = [ + self.attn_norm_tensors[i].data_ptr() for i in range(nlayer) + ] + self.attn_norm = (c_void_p * nlayer)(*self.attn_norm_ptrs) + + def qkv_slices(_i): + _Q = ( + state_dict[naming.attn_q(_i)] + .reshape([nh, 2, dh // 2, d]) + .transpose(1, 2) + ) # For RoPE + _K = ( + state_dict[naming.attn_k(_i)] + .reshape([nkvh, 2, dh // 2, d]) + .transpose(1, 2) + ) + _V = state_dict[naming.attn_v(_i)].reshape([nkvh, dh // 2, 2, d]) + _result = [] + _nh = nh // ndev + _nkvh = nkvh // ndev + for _idev in range(ndev): + _result.append(_Q[_idev * _nh : (_idev + 1) * _nh, :, :, :]) + _result.append(_K[_idev * _nkvh : (_idev + 1) * _nkvh, :, :, :]) + _result.append(_V[_idev * _nkvh : (_idev + 1) * _nkvh, :, :]) + return _result + + self.qkv_tensor = [ + torch.concat(qkv_slices(i)).to(torch_dt_mat) for i in range(nlayer) + ] + if not transpose_weight: + for i in range(nlayer): + self.qkv_tensor[i] = ( + self.qkv_tensor[i] + .reshape(ndev, (nh + 2 * nkvh) // ndev * dh, d) + .transpose(1, 2) + .contiguous() + ) + self.qkv_tensor_ptrs = [self.qkv_tensor[i].data_ptr() for i in range(nlayer)] + self.attn_qkv = (c_void_p * nlayer)(*self.qkv_tensor_ptrs) + + + if naming.attn_q_norm(0) in state_dict: + print("have norm") + self.attn_q_norm_tensors = [ + state_dict[naming.attn_q_norm(i)] + .reshape([2, dh // 2]) + .transpose(0, 1) + .contiguous() + .to(torch_dt_norm) + for i in range(nlayer) + ] + self.attn_q_norm_ptrs = [ + self.attn_q_norm_tensors[i].data_ptr() for i in range(nlayer) + ] + self.attn_q_norm = (c_void_p * nlayer)(*self.attn_q_norm_ptrs) + self.attn_k_norm_tensors = [ + state_dict[naming.attn_k_norm(i)] + .reshape([2, dh // 2]) + .transpose(0, 1) + .contiguous() + .to(torch_dt_norm) + for i in range(nlayer) + ] + self.attn_k_norm_ptrs = [ + self.attn_k_norm_tensors[i].data_ptr() for i in range(nlayer) + ] + self.attn_k_norm = (c_void_p * nlayer)(*self.attn_k_norm_ptrs) + else: + self.attn_q_norm = None + self.attn_k_norm = None + + self.attn_o_tensor = [ + ( + state_dict[naming.attn_o(i)] + .to(torch_dt_mat) + .reshape([d, ndev, nh // ndev * dh]) + .transpose(0, 1) + .contiguous() + if transpose_weight + else state_dict[naming.attn_o(i)] + .transpose(0, 1) + .to(torch_dt_mat) + .contiguous() + ) + * scale_o + for i in range(nlayer) + ] + self.attn_o_ptrs = [self.attn_o_tensor[i].data_ptr() for i in range(nlayer)] + self.attn_o = (c_void_p * nlayer)(*self.attn_o_ptrs) + + self.ffn_norm_tensors = [ + state_dict[naming.ffn_norm(i)].to(torch_dt_norm) for i in range(nlayer) + ] + self.ffn_norm_ptrs = [ + self.ffn_norm_tensors[i].data_ptr() for i in range(nlayer) + ] + self.ffn_norm = (c_void_p * nlayer)(*self.ffn_norm_ptrs) + + def expert_gate_slices(layer_id, num_experts): + """ + Extract expert gate and up weights for one layer. 
+ Compatible with keys like: + model.layers.{i}.mlp.experts.{e}.gate_proj.weight + model.layers.{i}.mlp.experts.{e}.up_proj.weight + """ + gate_list = [] + + for e in range(num_experts): + gate_key = naming.expert_gate(layer_id, e) + gate_w = state_dict[gate_key] # shape: [1024, 2048] + gate_list.append(gate_w) + return gate_list # list of num_experts tensors + + def expert_up_slices(layer_id, num_experts): + """ + Extract expert gate and up weights for one layer. + Compatible with keys like: + model.layers.{i}.mlp.experts.{e}.gate_proj.weight + model.layers.{i}.mlp.experts.{e}.up_proj.weight + """ + up_list = [] + + for e in range(num_experts): + up_key = naming.expert_up + up_key = f"model.layers.{layer_id}.mlp.experts.{e}.up_proj.weight" + up_w = state_dict[up_key] # shape: [1024, 2048] + up_list.append(up_w) + return up_list # list of num_experts tensors + + # memory: [gate_layer0_expert_gate0]...[gate_layer0_expert_gate63]......[gate_layer15_expert_gate63] + self.expert_gate_tensors = [ + torch.concat(expert_gate_slices(i, num_experts), dim=0).to(torch_dt_mat) + for i in range(nlayer) + ] + + # memory: [gate_layer0_expert_up0]...[gate_layer0_expert_up63]......[gate_layer15_expert_up63] + self.expert_up_tensors = [ + torch.concat(expert_up_slices(i, num_experts), dim=0).to(torch_dt_mat) + for i in range(nlayer) + ] + + self.expert_gate_ptrs = [self.expert_gate_tensors[i].data_ptr() for i in range(nlayer)] + self.expert_gate = (c_void_p * nlayer)(*self.expert_gate_ptrs) + + self.expert_up_ptrs = [self.expert_up_tensors[i].data_ptr() for i in range(nlayer)] + self.expert_up = (c_void_p * nlayer)(*self.expert_up_ptrs) + + def expert_down_slices(layer_id, num_experts): + """ + Extract expert gate and up weights for one layer. + Compatible with keys like: + model.layers.{i}.mlp.experts.{e}.gate_proj.weight + model.layers.{i}.mlp.experts.{e}.up_proj.weight + """ + down_list = [] + + for e in range(num_experts): + down_key = naming.down(layer_id, e) + down_w = state_dict[down_key] # shape: [1024, 2048] + # concat gate + up along dim 0 → shape: [2048, 2048] + down_list.append(down_w) + return down_list # list of num_experts tensors + + # memory: [gate_layer0_expert_down0]...[gate_layer0_expert_down63]......[gate_layer15_expert_down63] + self.expert_down_tensor = [ + torch.concat(expert_down_slices(i, num_experts), dim=0).to(torch_dt_mat) + for i in range(nlayer) + ] + self.expert_down_ptrs = [self.expert_down_tensor[i].data_ptr() for i in range(nlayer)] + self.expert_down = (c_void_p * nlayer)(*self.expert_down_ptrs) + + # Impl Python gate + def router_slices(): + """ + Extract expert gate and up weights for one layer. + Compatible with keys like: + model.layers.{i}.mlp.experts.{e}.gate_proj.weight + model.layers.{i}.mlp.experts.{e}.up_proj.weight + """ + router_list = [] + for i in range(nlayer): + gate_weight = state_dict[naming.router(i)].to(torch_dt_mat) + + router_list.append(gate_weight) + + return router_list # list of num_experts tensors + + + self.router_gate_tensor = router_slices() + # memory: [gate_layer0_router]......[gate_layer15_router] + self.router_ptrs = [self.router_gate_tensor[i].data_ptr() for i in range(nlayer)] + self.router = (c_void_p * nlayer)(*self.router_ptrs) + + + +class LLaDABatchedTask: + """ + Batch task handler for LLaDA model inference. + Similar to JiugeBatchedTask but adapted for LLaDA requirements. 
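+
+    Token lists from all requests are flattened into one `c_uint` array (`tokens`, `ntok`),
+    with per-request lengths, positions, KV-cache pointers and sampling parameters passed
+    as parallel ctypes arrays, mirroring the C batch interface.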
+ """ + def __init__(self, tasks: List[InferTask]): + self.tasks = tasks + self.nreq = len(tasks) + + # Precompute fields + token_lists = [t.tokens for t in tasks] + self.req_lens_list = [len(toks) for toks in token_lists] + self.req_pos_list = [t.pos for t in tasks] + self.kv_cache_ptrs = [t.kvcache().data() for t in tasks] + self.temperaturas_list = [t.temperature for t in tasks] + self.topks_list = [t.topk for t in tasks] + self.topps_list = [t.topp for t in tasks] + + + + flat_tokens = [] + for toks in token_lists: + if isinstance(toks, (list, tuple)): + flat_tokens.extend(toks) + else: + flat_tokens.append(toks) + + # Convert all tokens to int + flat_tokens = [int(tok) for tok in flat_tokens] + + self.ntok = len(flat_tokens) + print(f"Torch : flat_tokens : {flat_tokens}") + + # Convert to ctypes arrays in one pass + self.tokens = (c_uint * self.ntok)(*flat_tokens) + self.req_lens = (c_uint * self.nreq)(*self.req_lens_list) + self.req_pos = (c_uint * self.nreq)(*self.req_pos_list) + self.kv_caches = (POINTER(KVCacheCStruct) * self.nreq)(*self.kv_cache_ptrs) + self.temperaturas = (c_float * self.nreq)(*self.temperaturas_list) + self.topks = (c_uint * self.nreq)(*self.topks_list) + self.topps = (c_float * self.nreq)(*self.topps_list) + + def input_args(self): + return ( + self.tokens, + self.ntok, + self.req_lens, + self.nreq, + self.req_pos, + self.kv_caches, + self.temperaturas, + self.topks, + self.topps, + ) + + +class LLaDAForCauslLM: + def __init__( + self, model_dir_path, device=DeviceType.DEVICE_TYPE_CPU, ndev=1, max_tokens=None + ): + def load_all_safetensors_from_dir(dir_path_: str): #TODO: Load. Accelerate By Page Cache + tensors_ = {} + dir_path_ = Path(dir_path_) + print(f"load Dir path {dir_path_}") + for file in sorted(dir_path_.glob("*.safetensors")): + data_ = safetensors.safe_open(file, "pt") + for name_ in data_.keys(): + # print("Tensor Name ") + # print(name_) + tensors_[name_] = data_.get_tensor(name_) + return tensors_ + + print("Loading model weights to host...") + load_start_time = time.time() + + with open(os.path.join(model_dir_path, "config.json"), "r") as f: + config = json.load(f) + self.config = config + eos_token_id = self.config["eos_token_id"] + self.eos_token_id = ( + [eos_token_id] if type(eos_token_id) == int else eos_token_id + ) + + self.llada_model = LLaDAModel() # TODO: 实现LLaDAModel + + state_dict = load_all_safetensors_from_dir(model_dir_path) + # C Structure Meta and weights + self.meta = LLaDAMetaFromLlama(config, dtype=torch.bfloat16, max_tokens=max_tokens) + self.weights = LLaDAWeightsImpl( + self.meta, + LLaDAWeifghtsNaming(), + state_dict, + ndev=ndev, + transpose_weight=None, + ) # bottleneck + self.tokenizer = transformers.AutoTokenizer.from_pretrained( + model_dir_path, trust_remote_code=True + ) # bottleneck + load_end_time = time.time() + print(f"Time used: {load_end_time - load_start_time:.3f}s") + print(f"Creating model on {ndev} devices...") + load_start_time = time.time() + self.dev_ids = (c_int * ndev)(*[i for i in range(ndev)]) + self.ndev = ndev + self.device = device + print("--- start create model ---") + self.model_instance = self.llada_model.create_model( + byref(self.meta), + byref(self.weights), + device, + ndev, + self.dev_ids, + ) + self.model_ptr = self.model_instance + load_end_time = time.time() + print(f"Time used: {load_end_time - load_start_time:.3f}s") + + + # <-------------------------------------- Infer PipeLine ------------------------------------------------> + def max_context_len(self): + return 
self.meta.dctx + + def create_kv_cache(self): + """Create KV cache for the model""" + return self.llada_model.create_kv_cache( + self.meta.nlayer, + self.meta.dctx, + self.meta.nkvh, + self.meta.dh, + self.meta.dh, + self.meta.dt_logits, + self.device, + self.dev_ids, + self.ndev, + ) + + def drop_kv_cache(self, kv_cache): + """Drop KV cache""" + self.llada_model.drop_kv_cache(kv_cache) + + def batch_infer_one_round(self, tasks: List[InferTask]): + """ + Perform one round of batch inference using LLaDA model. + + Args: + tasks: List of InferTask objects containing input sequences and parameters + + Returns: + List of generated token IDs + """ + output = (c_uint * len(tasks))() + batch_inputs = LLaDABatchedTask(tasks) + self.llada_model.infer_batch( + self.model_instance, + *(batch_inputs.input_args()), + output, + ) + return list(output) + + def generate( + self, + prompts: str, + max_steps: int = 128, + gen_length: int = 128, + block_length: int = 128, + temperature_: float = 0., + cfg_scale: float = 0., + remasking: str = 'low_confidence', + mask_id: int = 126336, + logits_eos_inf: bool = False, + confidence_eos_eot_inf: bool = False, + verbose: bool = False, + topp_ = 1.0, + topk_ = 1 + ): + if isinstance(prompts, str): + prompts = [prompts] + + # Apply chat template and tokenize + print("Staring generate prepare") + messages = [{"role": "user", "content": prompt} for prompt in prompts] + formatted_prompts = [self.tokenizer.apply_chat_template([message], add_generation_prompt=True, tokenize=False) for message in messages] + encoded_outputs = self.tokenizer.batch_encode_plus( + formatted_prompts, + add_special_tokens=False, + padding=True, + return_tensors="pt" + ) + # Extract input_ids from batch encoding + input_ids = encoded_outputs['input_ids'] + + # For single prompt, get the first sequence + if len(prompts) == 1: + tokens = input_ids[0].tolist() + else: + # For batch, handle each sequence separately + tokens = [seq.tolist() for seq in input_ids] + print(len(tokens)) + print(f"Pytho Side Tokens type: {type(tokens)}, content: {tokens if isinstance(tokens, list) else 'Not a list'}") + + infer_task = InferTask( + 0, + tokens, + self.max_context_len(), + temperature_, + topk_, + topp_, + self.eos_token_id, + ) + + # Bind KV cache + kv_cache = KVCache(self) + infer_task.bind_kvcache(kv_cache) + print("Staring Infering") + output_tokens = self.batch_infer_one_round([infer_task]) + + + + + + def forward_logits_batch(self, input_ids_tensor, attention_mask_tensor=None): + """ + Forward pass to get logits for a batch of sequences using C++ model. 
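+        The C side writes logits flattened to (ntok, vocab_size); they are reshaped back
+        to (batch_size, seq_len, vocab_size) per request before being returned.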
+ + Args: + input_ids_tensor: Tensor of shape (batch_size, seq_len) with token IDs + attention_mask_tensor: Tensor of shape (batch_size, seq_len) with attention mask + + Returns: + logits: Tensor of shape (batch_size, seq_len, vocab_size) + """ + batch_size, seq_len = input_ids_tensor.shape + + # Create InferTask objects for each sequence in the batch + tasks = [] + for i in range(batch_size): + # Extract tokens for this sequence + seq_tokens = input_ids_tensor[i].tolist() + # Create KVCache for this sequence + kv_cache = KVCache( + self.meta.nlayer, + self.meta.dctx, + self.meta.nkvh, + self.meta.dh, + self.meta.dh, + self.meta.dt_logits, + self.device, + self.dev_ids, + self.ndev, + ) + # Create InferTask + task = InferTask( + tokens=seq_tokens, + pos=0, # Start position + temperature=0.0, # Will be handled by sampling logic + topk=1, # Will be handled by sampling logic + topp=1.0, # Will be handled by sampling logic + kvcache=kv_cache, + ) + tasks.append(task) + + # Create batched task + batch_inputs = LLaDABatchedTask(tasks) + + # Prepare output tensor for logits + vocab_size = self.config.get("vocab_size", 150528) + logits_tensor = torch.zeros( + batch_inputs.ntok, vocab_size, + dtype=self.meta.torch_dtype_logits, + device=torch.device("cpu") + ) + + # Call C++ forward_batch + self.llada_model.forward_batch( + self.model_instance, + batch_inputs.tokens, + batch_inputs.ntok, + batch_inputs.req_lens, + batch_inputs.nreq, + batch_inputs.req_pos, + batch_inputs.kv_caches, + logits_tensor.data_ptr(), + ) + + # Reshape logits to (batch_size, seq_len, vocab_size) + # Note: This requires careful handling of the flattened output + logits_reshaped = torch.zeros(batch_size, seq_len, vocab_size, dtype=logits_tensor.dtype, device=logits_tensor.device) + + # Copy logits back to batch format + token_offset = 0 + for req_idx, req_len in enumerate(batch_inputs.req_lens_list): + # Extract logits for this request + req_logits = logits_tensor[token_offset:token_offset + req_len] + logits_reshaped[req_idx, :req_len] = req_logits + token_offset += req_len + + # Clean up KV caches + for task in tasks: + task.kvcache().drop() + + return logits_reshaped + + + + +def test(): + + model_path = "/home/featurize/work/InfiniFamily/cache/models--inclusionAI--LLaDA-MoE-7B-A1B-Instruct/snapshots/783d3467f108d28ac0a78d3e41af16ab05cabd8d" + device_type = DeviceType.DEVICE_TYPE_NVIDIA + verbose = True + + # Number of devices + ndev = 1 + + print("Loading LLaDA model...") + model = LLaDAForCauslLM(model_path, device_type, ndev) + + # Test prompts + test_prompts = "Lily can run 12 kilometers per hour for 4 hours. After that, she runs 6 kilometers per hour." 
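+    # Note: generate() currently runs a single batched inference round and prints its
+    # progress, but does not return the decoded text, so `result` below is None.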
+ + print("\n=== Testing C++ Model Integration Function ===") + + result = model.generate( + prompts=test_prompts, + max_steps=16, # Reduced for faster testing + gen_length=32, # Shorter generation for testing + block_length=16, + temperature_=0.0, # Deterministic for testing + verbose=verbose + ) + # # print(f"Result: {result}") + # # except Exception as e: + # # print(f"Error in C++ model generation: {e}") + + + +if __name__ == "__main__": + import os + print(os.getpid()) + test() \ No newline at end of file diff --git a/setup.sh b/setup.sh new file mode 100644 index 00000000..95c7a451 --- /dev/null +++ b/setup.sh @@ -0,0 +1,16 @@ +sudo apt-get update +sudo apt-get install cmake +sudo apt-get install gdb + +wget https://xmake.io/shget.text -O - | bash +source ~/.xmake/profile +sudo sh cuda_12.0.0_525.60.13_linux.run +sudo cp /environment/miniconda3/lib/python3.11/site-packages/nvidia/cudnn/include/cudnn*.h /usr/include +sudo cp /environment/miniconda3/lib/python3.11/site-packages/nvidia/cudnn/lib/libcudnn* /usr/lib +sudo ln -s /usr/lib/libcudnn.so.8 /usr/lib/libcudnn.so +python InfiniCore/scripts/install.py --nv-gpu=y +source ./start.sh +export LD_LIBRARY_PATH=/usr/local/cuda-12.0/lib64:$LD_LIBRARY_PATH +CUDA12.0 in local environment +export INFINI_ROOT=/home/featurize/.infini/ +/home/featurize/work/InfiniRefresh/InfiniCore what you need \ No newline at end of file diff --git a/src/cache_manager/opcache_manager.hpp b/src/cache_manager/opcache_manager.hpp index 4c49e961..c90df9e7 100644 --- a/src/cache_manager/opcache_manager.hpp +++ b/src/cache_manager/opcache_manager.hpp @@ -140,6 +140,17 @@ class LRUDescriptorCache { // Helper macro to generate the destroy function name #define DESTROY_FUNC(OpType) infiniopDestroy##OpType##Descriptor +// 宏展开后的结果 +// LRUDescriptorCache Add_cache; + +// bool getAddDescriptor(size_t key, infiniopAddDescriptor_t &desc) { +// return Add_cache.get(key, desc); +// } + +// void putAddDescriptor(size_t key, const infiniopAddDescriptor_t &desc) { +// Add_cache.put(key, desc); +// } + // Declare cache and access functions #define DECLARE_OP_CACHE(OpType) \ LRUDescriptorCache OpType##_cache; \ @@ -162,6 +173,10 @@ class CacheManager { DECLARE_OP_CACHE(SwiGLU) DECLARE_OP_CACHE(RandomSample) DECLARE_OP_CACHE(DequantizeAWQ) + DECLARE_OP_CACHE(Softmax) // 新增 + DECLARE_OP_CACHE(BiAttention) + DECLARE_OP_CACHE(Topksoftmax) // 新增 topksoftmax + CacheManager(size_t capacity = 100) : Add_cache(capacity, DESTROY_FUNC(Add)), @@ -173,7 +188,10 @@ class CacheManager { Topkrouter_cache(capacity, DESTROY_FUNC(Topkrouter)), SwiGLU_cache(capacity, DESTROY_FUNC(SwiGLU)), RandomSample_cache(capacity, DESTROY_FUNC(RandomSample)), - DequantizeAWQ_cache(capacity, DESTROY_FUNC(DequantizeAWQ)) {} + DequantizeAWQ_cache(capacity, DESTROY_FUNC(DequantizeAWQ)), + Softmax_cache(capacity, DESTROY_FUNC(Softmax)), + BiAttention_cache(capacity, DESTROY_FUNC(BiAttention)), + Topksoftmax_cache(capacity, DESTROY_FUNC(Topksoftmax)){} template static size_t createDescriptorKey(Tensors... 
tensors) { diff --git a/src/models/deepseek_v3/deepseek_v3_weight.cpp b/src/models/deepseek_v3/deepseek_v3_weight.cpp index 846af633..906b76bf 100644 --- a/src/models/deepseek_v3/deepseek_v3_weight.cpp +++ b/src/models/deepseek_v3/deepseek_v3_weight.cpp @@ -436,7 +436,7 @@ static DeepSeekV3WeightLoader weight_loader = { .load_mlp_experts = load_mlp_experts, }; -__C DeepSeekV3Weights * +__C __export DeepSeekV3Weights * createDeepSeekV3Weights(const DeepSeekV3Meta *meta, infiniDevice_t device, int ndev, diff --git a/src/models/inference_context.cpp b/src/models/inference_context.cpp index db5fda11..5d56966f 100644 --- a/src/models/inference_context.cpp +++ b/src/models/inference_context.cpp @@ -143,6 +143,77 @@ void InferenceContext::causalSoftmax(std::shared_ptr y, y->data(), x->data(), stream)); } +void InferenceContext::BiAttention(std::shared_ptr out, + std::shared_ptr q, + std::shared_ptr k, + std::shared_ptr v, + std::shared_ptr k_cache, + std::shared_ptr v_cache, + int pos){ + size_t key = CacheManager::createDescriptorKey(out, q, k, v, k_cache, v_cache); + infiniopBiAttentionDescriptor_t desc; + if (!cache_manager->getBiAttentionDescriptor(key, desc)) { + RUN_INFINI(infiniopCreateBiAttentionDescriptor( + op_handle, &desc, out->desc(), q->desc(), k->desc(), v->desc(), k_cache->desc(), v_cache->desc(), pos)); + cache_manager->putBiAttentionDescriptor(key, desc); + } + + size_t workspace_size = 0; + RUN_INFINI(infiniopGetBiAttentionWorkspaceSize(desc, &workspace_size)); + ensure_workspace(workspace_size); + void *workspace = workspace_storage->memory(); + + RUN_INFINI(infiniopBiAttention( desc, workspace, workspace_size, + out->data(), q->data(), k->data(), v->data(), k_cache->data(), v_cache->data(), stream )); + +} + +void InferenceContext::softmax(std::shared_ptr y, + std::shared_ptr x, int dim) { + size_t key = CacheManager::createDescriptorKey(y, x); + + infiniopSoftmaxDescriptor_t desc; + if (!cache_manager->getSoftmaxDescriptor(key, desc)) { + RUN_INFINI(infiniopCreateSoftmaxDescriptor( + op_handle, &desc, y->desc(), x->desc(), dim)); + cache_manager->putSoftmaxDescriptor(key, desc); + } + + size_t workspace_size = 0; + RUN_INFINI(infiniopGetSoftmaxWorkspaceSize(desc, &workspace_size)); + ensure_workspace(workspace_size); + void *workspace = workspace_storage->memory(); + + RUN_INFINI(infiniopSoftmax(desc, workspace, workspace_size, + y->data(), x->data(), stream)); +} + +void InferenceContext::topksoftmax( + std::shared_ptr values, // F32 + std::shared_ptr indices, // I32 + std::shared_ptr x, + size_t topk, + bool norm + ){ + size_t key = CacheManager::createDescriptorKey(values, indices, x); + hash_combine(key, std::hash()(topk)); + hash_combine(key, std::hash()(norm)); + + infiniopTopksoftmaxDescriptor_t desc; + if (!cache_manager->getTopksoftmaxDescriptor(key, desc)) { + RUN_INFINI(infiniopCreateTopksoftmaxDescriptor(op_handle, &desc, x->desc())); + cache_manager->putTopksoftmaxDescriptor(key, desc); + } + + size_t workspace_size = 0; + RUN_INFINI(infiniopGetTopksoftmaxWorkspaceSize(desc, &workspace_size)); + ensure_workspace(workspace_size); + void *workspace = workspace_storage->memory(); + + RUN_INFINI(infiniopTopksoftmax(desc, workspace, workspace_size, + values->data(), indices->data(), x->data(), topk, norm, stream)); +} + void InferenceContext::topkrouter(std::shared_ptr values, // F32 std::shared_ptr indices, // I32 std::shared_ptr x, @@ -280,4 +351,4 @@ void InferenceContext::dequant(std::shared_ptr weight, RUN_INFINI(infiniopDequantizeAWQ( desc, workspace, 
workspace_size, weight->data(), in_w->data(), in_s->data(), in_z->data(), stream)); -} +} \ No newline at end of file diff --git a/src/models/inference_context.hpp b/src/models/inference_context.hpp index 0cf93f6f..57d59497 100644 --- a/src/models/inference_context.hpp +++ b/src/models/inference_context.hpp @@ -16,6 +16,15 @@ struct InferenceContext { void ensure_workspace(size_t required_size); + + void topksoftmax( + std::shared_ptr values, // F32 + std::shared_ptr indices, // I32 + std::shared_ptr x, + size_t topk, + bool norm + ); + void add(std::shared_ptr c, std::shared_ptr a, std::shared_ptr b); @@ -37,6 +46,17 @@ struct InferenceContext { infiniopRoPEAlgo_t algo); void causalSoftmax(std::shared_ptr y, std::shared_ptr x); + + void BiAttention( std::shared_ptr out, + std::shared_ptr q, + std::shared_ptr k, + std::shared_ptr v, + std::shared_ptr k_cache, + std::shared_ptr v_cache, + int pos); + + void softmax(std::shared_ptr y, + std::shared_ptr x, int dim); //新增 void topkrouter(std::shared_ptr values, // F32 std::shared_ptr indices, // I32 @@ -111,6 +131,20 @@ inline void causalSoftmax(std::shared_ptr y, std::shared_ptr x) getInferenceContext().causalSoftmax(y, x); } +inline void BiAttention(std::shared_ptr out, + std::shared_ptr q, + std::shared_ptr k, + std::shared_ptr v, + std::shared_ptr k_cache, + std::shared_ptr v_cache, + int pos){ + getInferenceContext().BiAttention(out, q, k, v, k_cache, v_cache, pos); +} + +inline void softmax(std::shared_ptr y, std::shared_ptr x, int dim) { + getInferenceContext().softmax(y, x, dim); +} + inline void topkrouter(std::shared_ptr values, // F32 std::shared_ptr indices, // I32 std::shared_ptr x, @@ -126,6 +160,18 @@ inline void topkrouter(std::shared_ptr values, // F32 topk); } +inline void topksoftmax(std::shared_ptr values, // F32 + std::shared_ptr indices, // I32 + std::shared_ptr x, + size_t topk, + bool norm) { + getInferenceContext().topksoftmax(values, + indices, + x, + topk, + norm); +} + inline void swiglu(std::shared_ptr out, std::shared_ptr up, std::shared_ptr gate) { getInferenceContext().swiglu(out, up, gate); @@ -148,4 +194,4 @@ inline void dequant_linear(std::shared_ptr out, std::shared_ptr auto w = Tensor::buffer(x->dtype(), {x->shape()[1], out->shape()[1]}, getInferenceContext().memory_pool); getInferenceContext().dequant(w, w_w, w_s, w_z); getInferenceContext().linear(out, x, w, alpha, beta, residual, bias); -} +} \ No newline at end of file diff --git a/src/models/jiuge/jiuge.cpp b/src/models/jiuge/jiuge.cpp index 41f8e5ea..6ed207fd 100644 --- a/src/models/jiuge/jiuge.cpp +++ b/src/models/jiuge/jiuge.cpp @@ -15,6 +15,7 @@ void createDeviceResource(JiugeDeviceResource *rsrc, const JiugeMeta *meta, infiniDevice_t device, int idev, int ndev, int dev_id, infinicclComm_t comm) { + std::cout << "Set Device" << std::endl; RUN_INFINI(infinirtSetDevice(device, dev_id)); infiniopHandle_t handle; infiniopCreateHandle(&handle); @@ -127,6 +128,10 @@ void inferDeviceBatch(const JiugeMeta &meta, JiugeDeviceResource &rsrc, struct KVCache **kv_caches, const float *temperature, const uint32_t *topk, const float *topp, uint32_t *output, void *last_logits) { + std::cout << "entering infer batch" << std::endl; + + std::cout << "Calucute Hyper Parameter" << std::endl; + auto nlayer = meta.nlayer; auto nkvh = meta.nkvh / ndev; auto nh = meta.nh / ndev; @@ -142,6 +147,7 @@ void inferDeviceBatch(const JiugeMeta &meta, JiugeDeviceResource &rsrc, bool has_qk_norm = rsrc.w_attn_q_norm.size() > 0 && rsrc.w_attn_k_norm.size() > 0; // Allocate buffers + 
std::cout << "Allocating Buffer " << std::endl; auto logits_in = Tensor::buffer(dt_logits, {ntok, d}, rsrc.memory_pool); auto logits_out = Tensor::buffer(dt_logits, {ntok, d}, rsrc.memory_pool); auto qkv_buf = Tensor::buffer(dt_logits, {ntok, (nh + nkvh * 2) * dh}, rsrc.memory_pool); @@ -156,6 +162,7 @@ void inferDeviceBatch(const JiugeMeta &meta, JiugeDeviceResource &rsrc, auto k_buf = qkv_rope->slice(1, nh, nkvh); // Prepare inputs + std::cout << "Preparing Input" << std::endl; auto batch_pos_ids = std::vector(ntok); size_t req_start = 0; for (uint32_t req = 0; req < nreq; req++) { @@ -203,12 +210,15 @@ void inferDeviceBatch(const JiugeMeta &meta, JiugeDeviceResource &rsrc, auto gate_buf = gate_up_buf->slice(1, 0, di); auto up_buf = gate_up_buf->slice(1, di, di); + std::cout << "Transformer Layer Stream" << std::endl; // Compute for (uint32_t layer = 0; layer < nlayer; layer++) { // 1. Attention // rms norm + std::cout << "Are you OK" << std::endl; rmsnorm(logits_out, logits_in, rsrc.w_attn_norm[layer], meta.epsilon); // qkv_proj + std::cout << "rmsnorm is OK" << std::endl; linear(qkv_buf, logits_out, rsrc.w_attn_qkv[layer], 1.0, 0.0, nullptr, has_qkv_bias ? rsrc.b_attn_qkv[layer] : nullptr); if (has_qk_norm) { @@ -216,7 +226,8 @@ void inferDeviceBatch(const JiugeMeta &meta, JiugeDeviceResource &rsrc, rmsnorm(k_buf, k_buf, rsrc.w_attn_k_norm[layer], meta.epsilon); } - // rope + // rope + std::cout << "Position Embedding" << std::endl; rope(q_buf, q_buf, pos_ids_buf, rsrc.sin_table, rsrc.cos_table); rope(k_buf, k_buf, pos_ids_buf, rsrc.sin_table, rsrc.cos_table); @@ -275,6 +286,7 @@ void inferDeviceBatch(const JiugeMeta &meta, JiugeDeviceResource &rsrc, } } // Sample and Output + std::cout << "starting reduce result" << std::endl; if (idev == 0) { if (last_logits != nullptr) { rmsnorm(logits_out, logits_in, rsrc.w_out_norm, meta.epsilon); @@ -380,12 +392,16 @@ forwardBatchJiuge(struct JiugeModel *model, } } -void launchDevice(const JiugeMeta &meta, const JiugeWeights *weights, JiugeDeviceResource *rsrc, InferState &state, InferRequest &req, +void launchDevice(const JiugeMeta & meta, const JiugeWeights *weights, JiugeDeviceResource *rsrc, InferState &state, InferRequest &req, infiniDevice_t device, int idev, int ndev, int dev_id, infinicclComm_t comm) { + std::cout << "launch device" << std::endl; // Create Device Resource createDeviceResource(rsrc, &meta, weights, device, idev, ndev, dev_id, comm); + + std::cout << "Cache Manager initing ..." 
<< std::endl; + CacheManager cache_manager(100); - CacheManager cache_manager(100); + std::cout << "Context Initing" << std::endl; InferenceContext ctx(rsrc->handle, rsrc->memory_pool, &cache_manager, rsrc->stream); // Set the inference context for this thread @@ -406,7 +422,7 @@ void launchDevice(const JiugeMeta &meta, const JiugeWeights *weights, JiugeDevic if (state.exit_flag) { break; } - + std::cout << "Infering Device Batch" << std::endl; inferDeviceBatch(meta, *rsrc, idev, ndev, req.tokens, req.ntok, req.req_lens, req.nreq, req.req_pos, req.kv_caches, req.temperature, req.topk, req.topp, req.output, req.logits); @@ -417,12 +433,15 @@ void launchDevice(const JiugeMeta &meta, const JiugeWeights *weights, JiugeDevic } // Clean-Up + std::cout << "Clearing Context" << std::endl; releaseDeviceResource(*rsrc); setInferenceContext(nullptr); // Clear the context when done } JiugeModel::JiugeModel(const JiugeMeta *_meta, const JiugeWeights *weights, infiniDevice_t device_, std::vector device_ids) : meta(*_meta) { + std::cout << "Starting Distri Deploy Model" << std::endl; int ndev = int(device_ids.size()); + std::cout << "The nums of dev is " << ndev << std::endl; device = device_; dev_ids = device_ids; dev_resources = std::vector(ndev); @@ -435,6 +454,7 @@ JiugeModel::JiugeModel(const JiugeMeta *_meta, const JiugeWeights *weights, infi } for (int i = 0; i < ndev; i++) { + std::cout << "Launch Device " << i << " Thread" << std::endl; threads[i] = std::thread(launchDevice, std::cref(meta), weights, &dev_resources[i], std::ref(states[i]), std::ref(req), device, i, ndev, dev_ids[i], comms[i]); } for (int i = 0; i < ndev; i++) { @@ -450,6 +470,7 @@ createJiugeModel(const JiugeMeta *meta, infiniDevice_t device, int ndev, const int *dev_ids) { + std::cout << "Start Create Model" << std::endl; std::vector device_ids(ndev); std::copy(dev_ids, dev_ids + ndev, device_ids.begin()); JiugeModel *model = new JiugeModel(meta, weights, device, device_ids); diff --git a/src/models/llada/llada.cpp b/src/models/llada/llada.cpp new file mode 100644 index 00000000..cfeb5545 --- /dev/null +++ b/src/models/llada/llada.cpp @@ -0,0 +1,618 @@ +#include "../../tensor.hpp" +#include "../../utils.hpp" +#include "../inference_context.hpp" +#include "../../cache.hpp" +#include "infinicore_infer.h" + +#include "llada_impl.hpp" +#include "llada_weight.hpp" +#include +#include +#include +#include +#include +#include +#include +#include // for memcpy + + +// #TODO:这个是草稿版 +void createDeviceResource(LLaDADeviceResource *rsrc, const LLaDAMeta * meta, + const LLaDAWeights *weights, infiniDevice_t device, int idev, + int ndev, int dev_id, + infinicclComm_t comm){ + std::cout << "Set Device" << std::endl; + //Print(meta); + RUN_INFINI(infinirtSetDevice(device, dev_id)); + infiniopHandle_t handle; + infiniopCreateHandle(&handle); + infinirtStream_t stream; + infinirtStreamCreate(&stream); + + std::cout << "Set Weight" << std::endl; // 逐层获取权重 + std::vector> w_attn_norm, w_attn_qkv, b_attn_qkv, w_attn_q_norm, w_attn_k_norm, w_attn_out, + w_ffn_norm, w_ffn_gate_up, w_ffn_down, w_expert_router, w_expert_gate, w_expert_up, w_expert_down; + for (size_t layer = 0; layer < meta->nlayer; layer++) { + w_attn_norm.push_back( + getAttnNorm(meta, weights, layer)); + w_attn_qkv.push_back( + getAttnQKV(meta, weights, layer, idev, ndev)); + + if (weights->attn_q_norm != nullptr) { + w_attn_q_norm.push_back( + getAttnQNorm(meta, weights, layer)); + w_attn_k_norm.push_back( + getAttnKNorm(meta, weights, layer)); + } + w_attn_out.push_back( + 
getAttnO(meta, weights, layer, idev, ndev) + ); + w_ffn_norm.push_back( + getFFNNorm(meta, weights, layer) + ); + w_expert_router.push_back( + getExpertRouter(meta, weights, layer, idev, ndev, meta->num_experts) + ); + + w_expert_gate.push_back( + getExpertGate(meta, weights, layer, idev, ndev) + ); + w_expert_down.push_back( + getExpertDown(meta, weights, layer, idev, ndev) + ); + w_expert_up.push_back( + getExpertUp(meta, weights, layer, idev, ndev) + ); + + // w_ffn_down.push_back( + // getFFNDown(meta, weights, layer, idev, ndev)); + } + std::cout << "Check out expert router size " << "Routers have " << w_expert_router.size() << " Shape is " << w_expert_router[0]->info() << std::endl; + + + + + std::cout << "Set Memory Pool" << std::endl; + auto memory_pool = std::make_shared(128 * 1024 * 1024); + + std::cout << "Set LLaDADeviceResource" << std::endl; + *rsrc = LLaDADeviceResource{ + .device = device, + .device_id = dev_id, + .handle = handle, + + .w_in_embd = getInEmbd(meta, weights), + .w_out_norm = getOutNorm(meta, weights), + .w_out_embd = getOutEmbd(meta, weights), + .sin_table = getSinTable(meta), + .cos_table = getCosTable(meta), + + .w_attn_norm = w_attn_norm, + .w_attn_qkv = w_attn_qkv, + .w_attn_q_norm = w_attn_q_norm, + .w_attn_k_norm = w_attn_k_norm, + .w_attn_out = w_attn_out, + .w_ffn_norm = w_ffn_norm, + .w_expert_gate = w_expert_gate, + .w_expert_up = w_expert_up, + .w_expert_down = w_expert_down, + .w_expert_router = w_expert_router, + .stream = stream, + .comm = comm, + .memory_pool = memory_pool, + }; + std::cout << "Over LLaDADeviceResource" << std::endl; + RUN_INFINI(infinirtDeviceSynchronize()); +} + +void releaseDeviceResource(LLaDADeviceResource &rsrc){ + std::cout << "Release" << std::endl; +} + +__C void +inferBatchLLaDA(struct LLaDAModel *model, + const uint32_t *tokens, uint32_t ntok, + const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos, + struct KVCache **kv_caches, + const float *temperature, const uint32_t *topk, const float *topp, + uint32_t *output) { + std::cout << "[DEBUG] inferBatchLLaDA called with single-threaded mode" << std::endl; + + // Set request data + model->req.tokens = tokens; + model->req.ntok = ntok; + model->req.req_lens = req_lens; + model->req.nreq = nreq; + model->req.req_pos = req_pos; + model->req.kv_caches = kv_caches; + model->req.output = output; + model->req.logits = nullptr; + model->req.temperature = temperature; + model->req.topk = topk; + model->req.topp = topp; + + // Single-threaded implementation - directly call inference + if (model->dev_ids.empty()) { + std::cerr << "[ERROR] No devices available!" 
<< std::endl; + return; + } + + // Use the first device (single-threaded) + int idev = 0; + int ndev = 1; + int dev_id = model->dev_ids[idev]; + + std::cout << "[DEBUG] Using device " << dev_id << " for inference" << std::endl; + + // Create device resource (temporary for single-threaded call) + LLaDADeviceResource rsrc; + + // Create communication handle (single device, no comm) + infinicclComm_t comm = nullptr; + + try { + // Direct call to inference function using model's device + launchDevice(model->meta, model->weights, &rsrc, model->states[idev], model->req, + model->device, idev, ndev, dev_id, comm); + + std::cout << "[DEBUG] Inference completed successfully" << std::endl; + + } catch (const std::exception& e) { + std::cerr << "[ERROR] Inference failed: " << e.what() << std::endl; + } +} + +void inferDeviceBatch(const LLaDAMeta &meta, LLaDADeviceResource &rsrc, + uint32_t idev, uint32_t ndev, + const uint32_t *tokens, uint32_t ntok, + const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos, + struct KVCache **kv_caches, + const float *temperature, const uint32_t *topk, const float *topp, + uint32_t *output, void *last_logits){ + CacheManager cache_manager(100); + InferenceContext ctx(rsrc.handle, rsrc.memory_pool, &cache_manager, rsrc.stream); + setInferenceContext(&ctx); + + auto nlayer = meta.nlayer; // 16 + auto nkvh = meta.nkvh / ndev; // KV heads per device (GQA grouped query) 16 + auto nh = meta.nh / ndev; // attention heads per device 16 + auto ngroup = nh / nkvh; // groups per device (how many Q heads share one K/V head) 1 + auto dctx = meta.dctx; // maximum context length 8192 + auto dh = meta.dh; // dimension per head. 128 + auto d = meta.d; // model hidden dimension 2048 (157184 * 2048 vocab embedding) + auto dt_logits = meta.dt_logits; // output data type 19 + auto di_dense = meta.di_dense / ndev; // dense FFN intermediate dimension per device 8192 + auto di_expert = meta.di_expert / ndev; // expert FFN intermediate dimension per device 1024 + auto dvoc = meta.dvoc; // vocabulary size 157184 + auto stream = rsrc.stream; + bool has_qkv_bias = rsrc.b_attn_qkv.size() > 0; // whether a QKV bias is present + bool has_qk_norm = rsrc.w_attn_q_norm.size() > 0 && rsrc.w_attn_k_norm.size() > 0; // whether QK normalization is present + + auto nexperts = meta.num_experts; + + // Allocate buffers + + auto logits_in = Tensor::buffer(dt_logits, {ntok, d}, rsrc.memory_pool); + auto logits_out = Tensor::buffer(dt_logits, {ntok, d}, rsrc.memory_pool); + auto qkv_buf = Tensor::buffer(dt_logits, {ntok, (nh + nkvh * 2) * dh}, rsrc.memory_pool); + auto gate_up_buf = Tensor::buffer(dt_logits, {ntok, 2 * di_dense}, rsrc.memory_pool); + auto o_buf = Tensor::buffer(dt_logits, {ntok, nh * dh}, rsrc.memory_pool); + auto prob_buf = Tensor::buffer(dt_logits, {nreq, dvoc}, rsrc.memory_pool); + auto result_buf = Tensor::buffer(INFINI_DTYPE_I64, {nreq}, rsrc.memory_pool); + auto result_cpu = std::vector(nreq); + auto attention_output = Tensor::buffer(dt_logits, {ntok, nh, dh}); + auto router_logits_buf = Tensor::buffer(dt_logits, {ntok, nexperts}, rsrc.memory_pool); + + + auto qkv_rope = qkv_buf->view({ntok, nh + nkvh * 2, dh}); + auto shape_qkv = qkv_rope->shape(); + auto q_buf = qkv_rope->slice(1, 0, nh); // 0 ---> nh q + auto k_buf = qkv_rope->slice(1, nh, nkvh); //nh ---> nh+nkvh k_buf + auto v_buf = qkv_rope->slice(1, nh + nkvh, nkvh); //nh+nkvh ---> nh+nkvh*2 v_buf + + + // Prepare inputs + std::cout << "Preparing Input" << std::endl; + auto batch_pos_ids = std::vector(ntok); + size_t req_start = 0; + for (uint32_t req = 0; req < nreq; req++) { + for (uint32_t i = 0; i < req_lens[req]; i++) { + batch_pos_ids[req_start + i] = req_pos[req] + i; + } + req_start += req_lens[req]; + } + std::shared_ptr pos_ids_buf; + pos_ids_buf =
Tensor::buffer(INFINI_DTYPE_U32, {ntok}, rsrc.memory_pool); + RUN_INFINI(infinirtMemcpyAsync(pos_ids_buf->data(), batch_pos_ids.data(), sizeof(uint32_t) * ntok, + INFINIRT_MEMCPY_H2D, stream)); + // std::shared_ptr hidden_states = Tensor::buffer(dt_logits, {ntok, d}, rsrc.memory_pool); + // for (uint32_t i = 0; i < ntok; i++) { + // RUN_INFINI(infinirtMemcpyAsync(hidden_states->data(i * d), + // rsrc.w_in_embd->data(tokens[i] * d), + // dsize(dt_logits) * d, INFINIRT_MEMCPY_D2D, stream)); + // } // embedding weights match the python side + for (uint32_t i = 0; i < ntok; i++) { + RUN_INFINI(infinirtMemcpyAsync(logits_in->data(i * d), + rsrc.w_in_embd->data(tokens[i] * d), + dsize(dt_logits) * d, INFINIRT_MEMCPY_D2D, stream)); + } // embedding weights match the python side + + for(uint32_t layer = 0; layer < nlayer; layer ++){ + // 1. Before Attention + // rms norm + rmsnorm(logits_out, logits_in, rsrc.w_attn_norm[layer], meta.epsilon); + // qkv_proj + std::cout << "QKV BUF IS " << qkv_buf->info() << std::endl; + std::cout << "Logits OUT BUF IS " << logits_out->info() << std::endl; + std::cout << "Attention BUF IS " << rsrc.w_attn_qkv[layer]->info() << std::endl; + linear(qkv_buf, logits_out, rsrc.w_attn_qkv[layer], 1.0, 0.0, nullptr, has_qkv_bias ? rsrc.b_attn_qkv[layer] : nullptr); + + if (has_qk_norm) { + rmsnorm(q_buf, q_buf, rsrc.w_attn_q_norm[layer], meta.epsilon); + rmsnorm(k_buf, k_buf, rsrc.w_attn_k_norm[layer], meta.epsilon); + } + // rope + std::cout << "Position Embedding" << std::endl; + rope(q_buf, q_buf, pos_ids_buf, rsrc.sin_table, rsrc.cos_table); + rope(k_buf, k_buf, pos_ids_buf, rsrc.sin_table, rsrc.cos_table); // llada modeling_lladamoe.py:390 + + std::cout << "q buf info " << q_buf->info() << std::endl; // [54, 16, 128] [ntok, nh, dh] + std::cout << "k buf info " << k_buf->info() << std::endl; // [54, 16, 128] [ntok, nh, dh] + std::cout << "v buf info " << v_buf->info() << std::endl; // [54, 16, 128] [ntok, nh, dh] + std::cout << "cache info " << std::endl; + std::cout << "req pos is " << req_pos[0] << std::endl; + std::cout << "req len is " << req_lens[0] << std::endl; + std::cout << "KV cache info "; + std::cout << kv_caches[0]->k[0][0]->permute({1, 0, 2})->info() << std::endl; // [16, 54, 128] + std::cout << "output info " << attention_output->info() << std::endl; + BiAttention(attention_output, q_buf->permute({1, 0, 2}), k_buf->permute({1, 0, 2}), v_buf->permute({1, 0, 2}), kv_caches[0]->k[0][0]->permute({1, 0, 2}), kv_caches[0]->v[0][0]->permute({1, 0, 2}), 0); + + + // Create a new tensor to hold the result of dimMerge + auto o_buf = Tensor::buffer(meta.dt_logits, {ntok, nh * dh}, rsrc.memory_pool); + rearrange(o_buf, attention_output->dimMerge(1, 2)); + + std::cout << logits_in->info() << std::endl; + std::cout << o_buf->info() << std::endl; + std::cout << rsrc.w_attn_out[layer]->info() << std::endl; + std::cout << "logits_in contiguous: " << logits_in->isContigous() << std::endl; + std::cout << "o_buf contiguous: " << o_buf->isContigous() << std::endl; + std::cout << "weight contiguous: " << rsrc.w_attn_out[layer]->isContigous() << std::endl; + linear(logits_in, o_buf, rsrc.w_attn_out[layer], 1.0, 0.0, idev == 0 ?
logits_in : nullptr, nullptr); //self.o_proj(attn_output) + + rmsnorm(logits_out, logits_in, rsrc.w_ffn_norm[layer], meta.epsilon); + + std::cout << "Starting MoE layer " << layer << std::endl; + std::cout << " logits_out shape: " << logits_out->info() << std::endl; + std::cout << " expert_router weight shape: " << rsrc.w_expert_router[layer]->info() << std::endl; + + linear(router_logits_buf, logits_out, rsrc.w_expert_router[layer], 1.0, 0.0, nullptr, nullptr); + + std::cout << "router_logits_buf info:" << router_logits_buf->info() << std::endl; + std::cout << "router_logits_buf contiguous: " << router_logits_buf->isContigous() << std::endl; + router_logits_buf->debug(); + + // ==================== MoE routing and expert computation ==================== + // Reference: modeling_lladamoe.py: LLaDAMoESparseMoeBlock::forward + // Routing: hidden_states -> gate -> router_logits -> topkrouter -> expert_indices, expert_weights + // Experts: hidden_states -> expert[gate][up] -> swiglu -> expert[down] -> weighted sum + + // Create the routing output buffers: [ntok, 8] + auto router_values = Tensor::buffer(meta.dt_logits, {ntok, 8}, rsrc.memory_pool); + auto router_indices = Tensor::buffer(INFINI_DTYPE_I32, {ntok, 8}, rsrc.memory_pool); + + // Create correction_bias; use a zero bias for now + auto correction_bias = Tensor::buffer(meta.dt_logits, {nexperts}, rsrc.memory_pool); + std::vector correction_bias_cpu(nexperts, 0.0f); + RUN_INFINI(infinirtMemcpy(correction_bias->data(), correction_bias_cpu.data(), + nexperts * sizeof(float), INFINIRT_MEMCPY_H2D)); + + // Call topkrouter to do the routing + topkrouter(router_values, router_indices, router_logits_buf, + correction_bias, 1.0f, 8); // topk=8, routed_scaling_factor=1.0f + + // Copy the routing results back to the CPU so experts can be visited token by token + std::vector router_values_cpu(ntok * 8); + std::vector router_indices_cpu(ntok * 8); + RUN_INFINI(infinirtMemcpy(router_values_cpu.data(), router_values->data(), + router_values_cpu.size() * sizeof(float), INFINIRT_MEMCPY_D2H)); + RUN_INFINI(infinirtMemcpy(router_indices_cpu.data(), router_indices->data(), + router_indices_cpu.size() * sizeof(int), INFINIRT_MEMCPY_D2H)); + + std::cout << "=== MoE Routing Info ===" << std::endl; + std::cout << " ntok: " << ntok << ", nexperts: " << nexperts << ", topk: 8" << std::endl; + std::cout << " router_values shape: " << router_values->info() << std::endl; + std::cout << " router_indices shape: " << router_indices->info() << std::endl; + + // Create the MoE output buffer: [ntok, d] + auto moe_output = Tensor::buffer(meta.dt_logits, {ntok, d}, rsrc.memory_pool); + + // Create temporary buffers for the expert computation + auto expert_gate_buf = Tensor::buffer(meta.dt_logits, {1, di_expert}, rsrc.memory_pool); // [1, 64] + auto expert_up_buf = Tensor::buffer(meta.dt_logits, {1, di_expert}, rsrc.memory_pool); // [1, 64] + auto expert_down_buf = Tensor::buffer(meta.dt_logits, {1, d}, rsrc.memory_pool); // [1, 2048] + + // Each token passes through its top-k experts; the outputs are weighted and summed + // Reference implementation: modeling_lladamoe.py:698-710 + for (size_t itok = 0; itok < ntok; ++itok) { + // Get the current token's hidden state: [1, d] + auto hidden_states_i = logits_out->slice(0, itok, 1); //[1, 2048] + // Get the current token's MoE output slot: [1, d] + auto moe_output_i = moe_output->slice(0, itok, 1); //[1, 2048] + // Get the temporary buffers: [1, di_expert], [1, d] + auto expert_gate_buf_i = expert_gate_buf->slice(0, 0, 1); // + auto expert_up_buf_i = expert_up_buf->slice(0, 0, 1); + auto expert_down_buf_i = expert_down_buf->slice(0, 0, 1); + + // Iterate over the top-k experts and accumulate with their weights + for (size_t k = 0; k < 8; ++k) { + // Get the expert index and weight + int expert_idx = router_indices_cpu[itok * 8 + k]; + float expert_weight = router_values_cpu[itok * 8 + k]; + + // Check that expert_idx is within the valid range + if
(expert_idx < 0 || expert_idx >= (int)nexperts) { + std::cerr << "ERROR: Invalid expert_idx=" << expert_idx + << " for token " << itok << ", expert " << k + << ", nexperts=" << nexperts << std::endl; + continue; + } + + // Debug: print weight shape info (first token, first expert) + if (itok == 0 && k == 0) { + std::cout << " expert_gate weight shape: " << rsrc.w_expert_gate[layer]->info() << std::endl; + std::cout << " expert_up weight shape: " << rsrc.w_expert_up[layer]->info() << std::endl; + std::cout << " expert_down weight shape: " << rsrc.w_expert_down[layer]->info() << std::endl; + std::cout << " expert_gate_w after slice: " << rsrc.w_expert_gate[layer]->slice(0, expert_idx, 1)->info() << std::endl; + std::cout << " expert_up_w after slice: " << rsrc.w_expert_up[layer]->slice(0, expert_idx, 1)->info() << std::endl; + std::cout << " expert_down_w after slice: " << rsrc.w_expert_down[layer]->slice(0, expert_idx, 1)->info() << std::endl; + } + + // Compute the expert output: hidden_states @ expert_gate -> silu * expert_up @ expert_down + // gate_proj: [d, di_expert] (sliced out of the weights and squeezed) + // After slicing the shape is [1, d, di_expert]; it must be reduced to [d, di_expert] + auto expert_gate_w = rsrc.w_expert_gate[layer]->slice(0, expert_idx, 1)->view({d, di_expert}); + // up_proj: [d, di_expert] (sliced out of the weights and squeezed) + auto expert_up_w = rsrc.w_expert_up[layer]->slice(0, expert_idx, 1)->view({d, di_expert}); + // down_proj: [di_expert, d] (sliced out of the weights and squeezed) + auto expert_down_w = rsrc.w_expert_down[layer]->slice(0, expert_idx, 1)->view({di_expert, d}); + + // gate_output = hidden_states @ expert_gate_w: [1, di_expert] + // Note: linear handles the dtype conversion automatically here + std::cout << "=== Before GEMM Call ===" << std::endl; + std::cout << "expert_gate_buf_i (output): " << expert_gate_buf_i->info() << std::endl; + std::cout << "expert_gate_buf_i contiguous: " << expert_gate_buf_i->isContigous() << std::endl; + std::cout << "hidden_states_i (input): " << hidden_states_i->info() << std::endl; + std::cout << "hidden_states_i contiguous: " << hidden_states_i->isContigous() << std::endl; + std::cout << "expert_gate_w (weight): " << expert_gate_w->info() << std::endl; + std::cout << "expert_gate_w contiguous: " << expert_gate_w->isContigous() << std::endl; + + std::cout << "1111" << std::endl; + linear(expert_gate_buf_i, hidden_states_i, expert_gate_w, 1.0, 0.0, nullptr, nullptr); + // up_output = hidden_states @ expert_up_w: [1, di_expert] + + std::cout << "1111" << std::endl; + linear(expert_up_buf_i, hidden_states_i, expert_up_w, 1.0, 0.0, nullptr, nullptr); + // swiglu = silu(gate_output) * up_output: [1, di_expert] + + std::cout << "1111" << std::endl; + swiglu(expert_gate_buf_i, expert_up_buf_i, expert_gate_buf_i); + + std::cout << "1111" << std::endl; + // expert_output = swiglu @ expert_down_w: [1, d] (weighted) + // When k==0, write directly into moe_output_i; when k>0, accumulate into moe_output_i + if (k == 0) { + std::cout << "1111" << std::endl; + std::cout << "Calling linear with moe_output_i, k=" << k << ", itok=" << itok << std::endl; + std::cout << " moe_output_i: " << moe_output_i->info() << std::endl; + std::cout << " expert_gate_buf_i: " << expert_gate_buf_i->info() << std::endl; + std::cout << " expert_down_w: " << expert_down_w->info() << std::endl; + // First expert: write directly, no accumulation + linear(moe_output_i, expert_gate_buf_i, expert_down_w, expert_weight, 0.0, nullptr, nullptr); + std::cout << "1111" << std::endl; + } else { + std::cout << "1111" << std::endl; + // Subsequent experts: accumulate into moe_output_i + // First compute expert_down_buf_i = expert_gate_buf_i @ expert_down_w * expert_weight + linear(expert_down_buf_i, expert_gate_buf_i, expert_down_w, expert_weight,
0.0, nullptr, nullptr); + // Then accumulate into moe_output_i + add(moe_output_i, moe_output_i, expert_down_buf_i); + } + } + } + std::cout << "2222" << std::endl; + // Residual connection: logits_in = logits_in + moe_output + // Use the add function directly for the addition + add(logits_in, logits_in, moe_output); + + std::cout << "=== MoE Computation Completed ===" << std::endl; + std::cout << " moe_output shape: " << moe_output->info() << std::endl; + std::cout << " Final logits_in shape: " << logits_in->info() << std::endl; + // ===================================================== + } + + + // Copy the final logits to host memory (if last_logits was provided) + if (last_logits != nullptr) { + RUN_INFINI(infinirtMemcpy(last_logits, logits_in->data(), + logits_in->shape()[0] * logits_in->shape()[1] * dsize(dt_logits), + INFINIRT_MEMCPY_D2H)); + } + + RUN_INFINI(infinirtStreamSynchronize(stream)); + // Clean up the inference context + setInferenceContext(nullptr); + std::cout << "InferDeviceBatch completed" << std::endl; +} +__C void +forwardBatchLLaDA(struct LLaDAModel *model, + const uint32_t *tokens, uint32_t ntok, + const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos, + struct KVCache **kv_caches, + void *logits){ + std::cout << "[DEBUG] forwardBatchLLaDA called with single-threaded mode" << std::endl; + + // Set request data + model->req.tokens = tokens; + model->req.ntok = ntok; + model->req.req_lens = req_lens; + model->req.nreq = nreq; + model->req.req_pos = req_pos; + model->req.kv_caches = kv_caches; + model->req.output = nullptr; + model->req.logits = logits; + model->req.temperature = nullptr; + model->req.topk = nullptr; + model->req.topp = nullptr; + + // Single-threaded implementation - directly call inference + if (model->dev_ids.empty()) { + std::cerr << "[ERROR] No devices available!" << std::endl; + return; + } + + // Use the first device (single-threaded) + int idev = 0; + int ndev = 1; + int dev_id = model->dev_ids[idev]; + + std::cout << "[DEBUG] Using device " << dev_id << " for forward pass" << std::endl; + + // Create device resource (temporary for single-threaded call) + LLaDADeviceResource rsrc; + + // Create communication handle (single device, no comm) + infinicclComm_t comm = nullptr; + + try { + // Direct call to inference function using model's device + launchDevice(model->meta, model->weights, &rsrc, model->states[idev], model->req, + model->device, idev, ndev, dev_id, comm); + + std::cout << "[DEBUG] Forward pass completed successfully" << std::endl; + + } catch (const std::exception& e) { + std::cerr << "[ERROR] Forward pass failed: " << e.what() << std::endl; + } +} + +// Currently only resource allocation is implemented here +void launchDevice(const LLaDAMeta & meta, const LLaDAWeights *weights, LLaDADeviceResource *rsrc, InferState &state, InferRequest &req, + infiniDevice_t device, int idev, int ndev, int dev_id, infinicclComm_t comm){ + std::cout << "launch device" << std::endl; + // Create Device Resource + createDeviceResource(rsrc, &meta, weights, device, idev, ndev, dev_id, comm); + + std::cout << "Cache Manager initing ..."
<< std::endl; + CacheManager cache_manager(100); + + std::cout << "Context Initing" << std::endl; + InferenceContext ctx(rsrc->handle, rsrc->memory_pool, &cache_manager, rsrc->stream); + + // Set the inference context for this thread + setInferenceContext(&ctx); + + // while (true) { + // std::unique_lock lock(state.mtx); + // state.cv_start.wait(lock, [&] { return state.proceed || state.exit_flag; }); + // // quit if exit_flag is set + // if (state.exit_flag) { + // break; + // } + // std::cout << "Infering Device Batch" << std::endl; + // inferDeviceBatch(meta, *rsrc, idev, ndev, req.tokens, req.ntok, + // req.req_lens, req.nreq, req.req_pos, req.kv_caches, + // req.temperature, req.topk, req.topp, req.output, req.logits); + + // state.proceed = false; + // lock.unlock(); + // state.cv_done.notify_one(); + // } + + // Check for potential out-of-bounds issues before accessing + std::cout << "[DEBUG] Checking for potential issues:" << std::endl; + for (uint32_t i = 0; i < req.nreq; ++i) { + std::cout << "[DEBUG] Request " << i << ": pos=" << req.req_pos[i] + << ", len=" << req.req_lens[i]; + if (req.req_pos[i] < 0 || req.req_pos[i] >= req.ntok) { + std::cout << " [ERROR: Invalid position! " << req.req_pos[i] << " >= " << req.ntok << "]"; + } + if (req.req_lens[i] < 0 || req.req_pos[i] + req.req_lens[i] > req.ntok) { + std::cout << " [ERROR: Length exceeds total tokens! " << req.req_pos[i] << " + " << req.req_lens[i] << " > " << req.ntok << "]"; + } + std::cout << std::endl; + } + + // Only output tokens if it's safe + if (req.ntok > 0) { + // Output first few tokens for each request + for (uint32_t i = 0; i < req.nreq; ++i) { + if (req.req_pos[i] >= 0 && req.req_pos[i] < req.ntok) { + uint32_t start_idx = req.req_pos[i]; + uint32_t available_tokens = req.ntok - start_idx; + uint32_t tokens_to_show = std::min(static_cast(10), std::min(available_tokens, req.req_lens[i])); + for (uint32_t j = start_idx; j < start_idx + tokens_to_show; ++j) { + std::cout << req.tokens[j]; + if (j < start_idx + tokens_to_show - 1) std::cout << ", "; + } + if (req.req_lens[i] > tokens_to_show) std::cout << "..."; + std::cout << std::endl; + } + } + } + + inferDeviceBatch(meta, *rsrc, idev, ndev, req.tokens, req.ntok, + req.req_lens, req.nreq, req.req_pos, req.kv_caches, + req.temperature, req.topk, req.topp, req.output, req.logits); + + // Clean-Up + std::cout << "Clearing Context" << std::endl; + releaseDeviceResource(*rsrc); + setInferenceContext(nullptr); // Clear the context when done + +} + + +// TODO: not void just for tmp +LLaDAModel::LLaDAModel(const LLaDAMeta *_meta, const LLaDAWeights *weights, infiniDevice_t device_, std::vector device_ids) { + std::cout << "Initing LLaDA model in Cpp side " << std::endl; + int ndev = int(device_ids.size()); + meta = *_meta; // Copy meta data + this->weights = weights; // Store weights pointer + device = device_; + dev_ids = device_ids; + dev_resources = std::vector(ndev); + states = std::vector(ndev); + threads.resize(ndev); + RUN_INFINI(infinirtInit()); + auto comms = std::vector(ndev, nullptr); + if (ndev > 1) { + RUN_INFINI(infinicclCommInitAll(device, comms.data(), ndev, dev_ids.data())); + } + std::cout << "OK LET's US ROCK! 
" << std::endl; + // for (int i = 0; i < ndev; i++) { + // std::cout << "Launch Device " << i << " Thread" << std::endl; + // threads[i] = std::thread(launchDevice, std::cref(*_meta), weights, &dev_resources[i], std::ref(states[i]), std::ref(req), device, i, ndev, dev_ids[i], comms[i]); + // } + //launchDevice(std::cref(*_meta), weights, &dev_resources[0], std::ref(states[0]), std::ref(req), device, 0, ndev, dev_ids[0], comms[0]); + +} + +// C++ interface for interacting with Python +__C struct LLaDAModel * createLLaDAModel(const LLaDAMeta * meta, + const LLaDAWeights * weights, + infiniDevice_t device, + int ndev, + const int *dev_ids) { + std::vector device_ids(ndev); + std::copy(dev_ids, dev_ids + ndev, device_ids.begin()); + + LLaDAModel *model = new LLaDAModel(meta, weights, device, device_ids); // test code can be written inside this function body + + return model; +} + + +__C void destroyLlaDAMoEModel(){ + +} \ No newline at end of file diff --git a/src/models/llada/llada_impl.hpp b/src/models/llada/llada_impl.hpp new file mode 100644 index 00000000..e9cbb432 --- /dev/null +++ b/src/models/llada/llada_impl.hpp @@ -0,0 +1,79 @@ +#ifndef LLADAMOE_IMPL_H +#define LLADAMOE_IMPL_H + +#include "infinicore_infer.h" +#include "../../../include/infinicore_infer/models/llada.h" // include the meta definition first +#include "../../allocator.hpp" +#include "../../tensor.hpp" + +#include +#include +#include +#include +#include +// #TODO: resources required to run the model +struct LLaDADeviceResource { + // Device + infiniDevice_t device; + int device_id; + infiniopHandle_t handle; + // Weights + std::shared_ptr w_in_embd, w_out_norm, w_out_embd, sin_table, + cos_table; + std::vector> w_attn_norm, w_attn_qkv, b_attn_qkv, w_attn_q_norm, w_attn_k_norm,w_attn_out, + w_ffn_norm; + std::vector> w_expert_gate; + std::vector> w_expert_up ; + std::vector> w_expert_down; + std::vector> w_expert_router; + + // Streams + infinirtStream_t stream; + // Communicator + infinicclComm_t comm; + + std::shared_ptr memory_pool; +}; + +struct InferState { + std::mutex mtx; + std::condition_variable cv_load, cv_start, cv_done; + bool loaded = false; + bool proceed = false; + bool exit_flag = false; +}; + +struct InferRequest { + const uint32_t *tokens; + uint32_t ntok; + const uint32_t *req_lens; + uint32_t nreq; + const uint32_t *req_pos; + struct KVCache **kv_caches; + const float *temperature; + const uint32_t *topk; + const float *topp; + uint32_t *output; + void *logits; +}; + +struct LLaDAModel { + LLaDAMeta meta; + const LLaDAWeights *weights; // Add weights pointer + infiniDevice_t device; + std::vector dev_ids; + std::vector dev_resources; + std::vector states; + std::vector threads; + InferRequest req; + + LLaDAModel(const LLaDAMeta *, const LLaDAWeights *, infiniDevice_t device, std::vector device_ids); +}; + +// Function declarations +void launchDevice(const LLaDAMeta &meta, const LLaDAWeights *weights, LLaDADeviceResource *rsrc, + InferState &state, InferRequest &req, infiniDevice_t device, int idev, int ndev, + int dev_id, infinicclComm_t comm); + + +#endif \ No newline at end of file diff --git a/src/models/llada/llada_weight.hpp b/src/models/llada/llada_weight.hpp new file mode 100644 index 00000000..d9435b6c --- /dev/null +++ b/src/models/llada/llada_weight.hpp @@ -0,0 +1,297 @@ +#ifndef LLADAMOE_WEIGHT_HPP +#define LLADAMOE_WEIGHT_HPP + + +#include +void debugPrint(LLaDAMeta const *meta) { + if (!meta) { + std::cout << "LLaDAMeta pointer = NULL" << std::endl; + return; + } + + std::cout << "===== LLaDAMeta DEBUG =====" << std::endl; + std::cout << "meta pointer = " << meta << std::endl; + + std::cout <<
"dt_logits = " << (int)meta->dt_logits << std::endl; + std::cout << "nlayer = " << meta->nlayer << std::endl; + std::cout << "d = " << meta->d << std::endl; + std::cout << "nh = " << meta->nh << std::endl; + std::cout << "nkvh = " << meta->nkvh << std::endl; + std::cout << "dh = " << meta->dh << std::endl; + + std::cout << "di_dense = " << meta->di_dense << std::endl; + std::cout << "di_expert = " << meta->di_expert << std::endl; + + std::cout << "dctx = " << meta->dctx << std::endl; + std::cout << "dvoc = " << meta->dvoc << std::endl; + + std::cout << "epsilon = " << meta->epsilon << std::endl; + std::cout << "theta = " << meta->theta << std::endl; + + std::cout << "end_token = " << meta->end_token << std::endl; + std::cout << "num_experts = " << meta->num_experts << std::endl; + + std::cout << "===========================" << std::endl; +} + +inline std::shared_ptr getInEmbd( + LLaDAMeta const * meta, + LLaDAWeights const * w) { + auto shape = std::vector({meta->dvoc, meta->d}); + return Tensor::weight((char *)w->input_embd, meta->dt_logits, shape); +} + +inline std::shared_ptr getOutNorm( + LLaDAMeta const * meta, + LLaDAWeights const * w){ + + std::cout << "Get In Embd112" << std::endl; + auto shape = std::vector({meta->d}); //TODO: + return Tensor::weight((char *)w->output_norm, w->dt_norm, shape); +} + +inline std::shared_ptr getOutEmbd( + LLaDAMeta const *meta, + LLaDAWeights const *w) { + std::cout << "Out Embd sd" << std::endl; + if (w->transpose_linear_weights != 0) { + auto shape = std::vector({meta->dvoc, meta->d}); + return Tensor::weight((char *)w->output_embd, meta->dt_logits, shape) + ->permute({1, 0}); + } else { + auto shape = std::vector({meta->d, meta->dvoc}); + return Tensor::weight((char *)w->output_embd, meta->dt_logits, shape); + } +} + +inline std::shared_ptr getAttnNorm( + LLaDAMeta const *meta, + LLaDAWeights const *w, + size_t layer) { + auto shape = std::vector({meta->d}); + return Tensor::weight((char *)(w->attn_norm[layer]), w->dt_norm, shape); +} + +inline std::shared_ptr getAttnQKV( + LLaDAMeta const *meta, + LLaDAWeights const *w, + size_t layer, size_t idev, size_t ndev) { + auto nkvh = meta->nkvh; + auto nh = meta->nh; + auto dh = meta->dh; + auto d = meta->d; + size_t offset = idev * ((nkvh * 2 + nh) / ndev * dh) * d * dsize(w->dt_mat); + if (w->transpose_linear_weights != 0) { + auto shape = std::vector({(nh + 2 * nkvh) / ndev * dh, d}); + return Tensor::weight((char *)(w->attn_qkv[layer]) + offset, w->dt_mat, shape) + ->permute({1, 0}); + } else { + auto shape = std::vector({d, (nh + 2 * nkvh) / ndev * dh}); + return Tensor::weight((char *)(w->attn_qkv[layer]) + offset, w->dt_mat, shape); + } +} + + + +inline std::shared_ptr getAttnQNorm( + LLaDAMeta const *meta, + LLaDAWeights const *w, + size_t layer) { + auto shape = std::vector({meta->dh}); + return Tensor::weight((char *)(w->attn_q_norm[layer]), w->dt_norm, shape); +} + +inline std::shared_ptr getAttnKNorm( + LLaDAMeta const *meta, + LLaDAWeights const *w, + size_t layer) { + auto shape = std::vector({meta->dh}); + return Tensor::weight((char *)(w->attn_k_norm[layer]), w->dt_norm, shape); +} + +inline std::shared_ptr getAttnO(LLaDAMeta const *meta, + LLaDAWeights const *w, size_t layer, + size_t idev, size_t ndev) { + auto nh = meta->nh; + auto dh = meta->dh; + auto d = meta->d; + size_t offset = idev * d * (nh / ndev * dh) * dsize(w->dt_mat); + if (w->transpose_linear_weights != 0) { + auto shape = std::vector({d, nh / ndev * dh}); + return Tensor::weight((char *)(w->attn_o[layer]) + offset, 
w->dt_mat, shape) + ->permute({1, 0}); + } else { + auto shape = std::vector({nh / ndev * dh, d}); + return Tensor::weight((char *)(w->attn_o[layer]) + offset, w->dt_mat, shape); + } +} + +inline std::shared_ptr getFFNNorm( + LLaDAMeta const *meta, + LLaDAWeights const *w, + size_t layer) { + auto shape = std::vector({meta->d}); + return Tensor::weight((char *)(w->ffn_norm[layer]), w->dt_norm, shape); +} + +// inline std::shared_ptr getFFNGateUp( +// LLaDAMeta const *meta, +// LLaDAWeights const *w, +// size_t layer, size_t idev, size_t ndev) { +// auto di = meta->di_expert; // TODO: 具体di还要区分 +// auto d = meta->d; +// size_t offset = idev * (2 * di / ndev) * d * dsize(w->dt_mat); +// if (w->transpose_linear_weights != 0) { +// auto shape = std::vector({2 * di / ndev, d}); +// return Tensor::weight((char *)(w->ffn_gate_up[layer]) + offset, +// w->dt_mat, shape) +// ->permute({1, 0}); +// } else { +// auto shape = std::vector({d, 2 * di / ndev}); +// return Tensor::weight((char *)(w->ffn_gate_up[layer]) + offset, +// w->dt_mat, shape); +// } +// } +// +inline std::shared_ptr getExpertRouter( + LLaDAMeta const *meta, + LLaDAWeights const *w, + size_t layer, size_t idev, size_t ndev, size_t num_experts) { + auto shape = std::vector({meta->d}); + auto di = meta->di_expert; // TODO: 具体di还要区分 + auto d = meta->d; + size_t offset = 0; + if (w->transpose_linear_weights != 0) { + auto shape = std::vector({num_experts, d}); + return Tensor::weight((char *)(w->expert_router[layer]) + offset, + w->dt_mat, shape) + ->permute({1, 0}); + } else { + auto shape = std::vector({num_experts, d}); + return Tensor::weight((char *)(w->expert_router[layer]) + offset, + w->dt_mat, shape)->permute({1, 0}); + } +} + +inline std::shared_ptr getExpertGate( + LLaDAMeta const *meta, + LLaDAWeights const *w, + size_t layer, size_t idev, size_t ndev){ + auto shape = std::vector({meta->d}); + auto di = meta->di_expert; // TODO: 具体di还要区分 + auto d = meta->d; + size_t offset = 0; + if (w->transpose_linear_weights != 0) { + auto shape = std::vector({meta->num_experts, di, d}); + return Tensor::weight((char *)(w->expert_gate[layer]) + offset, + w->dt_mat, shape) + ->permute({1, 0}); + } else { + auto shape = std::vector({meta->num_experts, di, d}); + return Tensor::weight((char *)(w->expert_gate[layer]) + offset, + w->dt_mat, shape); + } +} + +inline std::shared_ptr getExpertUp( + LLaDAMeta const *meta, + LLaDAWeights const *w, + size_t layer, size_t idev, size_t ndev) { + auto shape = std::vector({meta->d}); + auto di = meta->di_expert; // TODO: 具体di还要区分 + auto d = meta->d; + size_t offset = 0; + if (w->transpose_linear_weights != 0) { + auto shape = std::vector({meta->num_experts, di, d}); + return Tensor::weight((char *)(w->expert_up[layer]) + offset, + w->dt_mat, shape) + ->permute({1, 0}); + } else { + auto shape = std::vector({meta->num_experts, di, d}); + return Tensor::weight((char *)(w->expert_up[layer]) + offset, + w->dt_mat, shape); + } +} + + +inline std::shared_ptr getExpertDown( + LLaDAMeta const *meta, + LLaDAWeights const *w, + size_t layer, size_t idev, size_t ndev) { + auto shape = std::vector({meta->d}); + auto di = meta->di_expert; // TODO: 具体di还要区分 + auto d = meta->d; + size_t offset = 0; + if (w->transpose_linear_weights != 0) { + auto shape = std::vector({meta->num_experts, di, d}); + return Tensor::weight((char *)(w->expert_down[layer]) + offset, + w->dt_mat, shape) + ->permute({1, 0}); + } else { + auto shape = std::vector({meta->num_experts, di, d}); + return Tensor::weight((char 
*)(w->expert_down[layer]) + offset, + w->dt_mat, shape); + } +} + + + + +inline std::shared_ptr getSinTable(LLaDAMeta const *meta) { + std::cout << "Get Sin Table" << std::endl; + auto half_dh = meta->dh / 2; + auto unit = dsize(meta->dt_logits); + void *table = std::malloc(meta->dctx * half_dh * unit); + + for (size_t i = 0; i < meta->dctx; i++) { + for (size_t j = 0; j < half_dh; j++) { + float _sin = std::sin( + static_cast(i) / std::pow(meta->theta, static_cast(j) / half_dh)); + if (meta->dt_logits == INFINI_DTYPE_F16) { + ((uint16_t *)table)[i * half_dh + j] = f32_to_f16(_sin); + } else if (meta->dt_logits == INFINI_DTYPE_BF16) { + ((uint16_t *)table)[i * half_dh + j] = f32_to_bf16(_sin); + } else if (meta->dt_logits == INFINI_DTYPE_F32) { + ((float *)table)[i * half_dh + j] = _sin; + } else { + std::cout << "unsupported data type" << std::endl; + std::cout << meta->dt_logits << std::endl; + exit(1); + } + } + } + auto shape = std::vector({meta->dctx, half_dh}); + auto tensor = Tensor::weight(table, meta->dt_logits, shape); + std::free(table); + std::cout << "Sin Table Initing Over" << std::endl; + return tensor; +} + +inline std::shared_ptr getCosTable(LLaDAMeta const *meta) { + auto half_dh = meta->dh / 2; + auto unit = dsize(meta->dt_logits); + void *table = std::malloc(meta->dctx * half_dh * unit); + + for (size_t i = 0; i < meta->dctx; i++) { + for (size_t j = 0; j < half_dh; j++) { + float _cos = std::cos( + static_cast(i) / std::pow(meta->theta, static_cast(j) / half_dh)); + if (meta->dt_logits == INFINI_DTYPE_F16) { + ((uint16_t *)table)[i * half_dh + j] = f32_to_f16(_cos); + } else if (meta->dt_logits == INFINI_DTYPE_BF16) { + ((uint16_t *)table)[i * half_dh + j] = f32_to_bf16(_cos); + } else if (meta->dt_logits == INFINI_DTYPE_F32) { + ((float *)table)[i * half_dh + j] = _cos; + } else { + std::cout << "unsupported data type" << std::endl; + exit(1); + } + } + } + auto shape = std::vector({meta->dctx, half_dh}); + auto tensor = Tensor::weight(table, meta->dt_logits, shape); + std::free(table); + return tensor; +} + +#endif diff --git a/xmake.lua b/xmake.lua index 598ac534..85a95e40 100644 --- a/xmake.lua +++ b/xmake.lua @@ -1,5 +1,12 @@ +add_requires("pybind11") local INFINI_ROOT = os.getenv("INFINI_ROOT") or (os.getenv(is_host("windows") and "HOMEPATH" or "HOME") .. "/.infini") +set_toolchains("gcc") + +-- Add spdlog from third_party directory +add_includedirs("third_party/spdlog/include") +add_cxxflags("-Wno-unused-variable") + target("infinicore_infer") set_kind("shared") @@ -8,10 +15,8 @@ target("infinicore_infer") add_linkdirs(INFINI_ROOT.."/lib") add_links("infiniop", "infinirt", "infiniccl") - set_languages("cxx17") set_warnings("all", "error") - add_files("src/models/*.cpp") add_files("src/models/*/*.cpp") add_files("src/tensor/*.cpp") @@ -19,8 +24,32 @@ target("infinicore_infer") add_files("src/dataloader/*.cpp") add_files("src/cache_manager/*.cpp") add_includedirs("include") - set_installdir(INFINI_ROOT) add_installfiles("include/infinicore_infer.h", {prefixdir = "include"}) add_installfiles("include/infinicore_infer/models/*.h", {prefixdir = "include/infinicore_infer/models"}) target_end() + +target("_infinilm") + add_packages("pybind11") + set_default(false) + add_rules("python.module", {soabi = true}) + set_languages("cxx17") + set_kind("shared") + + local INFINI_ROOT = os.getenv("INFINI_ROOT") or (os.getenv(is_host("windows") and "HOMEPATH" or "HOME") .. 
"/.infini") + + -- add_includedirs("csrc", { public = false }) + -- add_includedirs("csrc/pybind11", { public = false }) + add_includedirs(INFINI_ROOT.."/include", { public = true }) + add_includedirs("include", { public = false }) + -- spdlog is already included globally via add_includedirs at the top + + add_linkdirs(INFINI_ROOT.."/lib") + add_links("infinicore_cpp_api", "infiniop", "infinirt", "infiniccl") + + -- Add src files + add_files("csrc/**.cpp") + add_files("csrc/**.cc") + + set_installdir("python/infinilm") +target_end() diff --git a/xmake.lua.bak b/xmake.lua.bak new file mode 100644 index 00000000..1892b7b5 --- /dev/null +++ b/xmake.lua.bak @@ -0,0 +1,38 @@ +local INFINI_ROOT = os.getenv("INFINI_ROOT") or (os.getenv(is_host("windows") and "HOMEPATH" or "HOME") .. "/.infini") + +target("infinicore_infer") + set_kind("shared") + -- debug/release 设置 + if is_mode("debug") then + set_symbols("debug") + set_optimize("none") + add_defines("DEBUG") + elseif is_mode("release") then + set_symbols("hidden") + set_optimize("fast") + add_defines("NDEBUG") + end + add_cxxflags("-Wno-unused-variable") + add_includedirs("include", { public = false }) + add_includedirs(INFINI_ROOT.."/include", { public = true }) + + add_linkdirs(INFINI_ROOT.."/lib") + add_links("infiniop", "infinirt", "infiniccl") + + set_languages("cxx17") + set_warnings("all", "error") + + + + add_files("src/models/*.cpp") + add_files("src/models/*/*.cpp") + add_files("src/tensor/*.cpp") + add_files("src/allocator/*.cpp") + add_files("src/dataloader/*.cpp") + add_files("src/cache_manager/*.cpp") + add_includedirs("include") + + set_installdir(INFINI_ROOT) + add_installfiles("include/infinicore_infer.h", {prefixdir = "include"}) + add_installfiles("include/infinicore_infer/models/*.h", {prefixdir = "include/infinicore_infer/models"}) +target_end()