From ff5cc753de0a6f1ee4bac4007efc20fc12039f81 Mon Sep 17 00:00:00 2001 From: scbz4learning Date: Sun, 11 Jan 2026 11:29:44 +0800 Subject: [PATCH] Add support for Qwen2.5 GPTQ 4 bit Model --- include/infinicore_infer/models/jiuge_gptq.h | 79 ++++ scripts/jiuge_gptq.py | 369 +++++++++++++++++ scripts/launch_server.py | 11 + scripts/libinfinicore_infer/__init__.py | 11 +- scripts/libinfinicore_infer/jiuge_awq.py | 4 +- scripts/libinfinicore_infer/jiuge_gptq.py | 167 ++++++++ scripts/test_ceval.py | 85 +++- src/cache_manager/opcache_manager.hpp | 4 +- src/models/inference_context.cpp | 49 ++- src/models/inference_context.hpp | 14 +- src/models/jiuge_awq/jiuge_awq.cpp | 36 +- src/models/jiuge_awq/jiuge_awq.hpp | 4 +- src/models/jiuge_awq/jiuge_awq_weight.cpp | 6 + src/models/jiuge_gptq/jiuge_gptq.cpp | 405 +++++++++++++++++++ src/models/jiuge_gptq/jiuge_gptq.hpp | 82 ++++ src/models/jiuge_gptq/jiuge_gptq_weight.cpp | 137 +++++++ 16 files changed, 1404 insertions(+), 59 deletions(-) create mode 100644 include/infinicore_infer/models/jiuge_gptq.h create mode 100644 scripts/jiuge_gptq.py create mode 100644 scripts/libinfinicore_infer/jiuge_gptq.py create mode 100644 src/models/jiuge_gptq/jiuge_gptq.cpp create mode 100644 src/models/jiuge_gptq/jiuge_gptq.hpp create mode 100644 src/models/jiuge_gptq/jiuge_gptq_weight.cpp diff --git a/include/infinicore_infer/models/jiuge_gptq.h b/include/infinicore_infer/models/jiuge_gptq.h new file mode 100644 index 00000000..2a71df2c --- /dev/null +++ b/include/infinicore_infer/models/jiuge_gptq.h @@ -0,0 +1,79 @@ +#ifndef MODEL_JIUGE_GPTQ_H +#define MODEL_JIUGE_GPTQ_H + +#include +#include +#include + +#include + +#include "../weights_loader.h" + +struct JiugeGPTQModel; + +typedef struct +{ + infiniDtype_t dt_logits; + infiniDtype_t dt_linear_w; + infiniDtype_t dt_norm_w; + size_t nlayer, d, nh, nkvh, dh, di, dctx, dvoc; + float epsilon, theta; + uint32_t end_token; + size_t nbit; + size_t quant_group_size; + char has_qkv_bias; +} JiugeGPTQMeta; + +//////////////////// APIs /////////////////////// +__C __export struct ModelWeights * +createJiugeGPTQWeights(const JiugeGPTQMeta *, + infiniDevice_t device, + int ndev, + const int *dev_ids); +/// @brief 创建模型 +/// @param device 协处理器种类 +/// @param ndev 协处理器数量 +/// @param dev_ids 协处理器编号,长度为 ndev +__C __export struct JiugeGPTQModel * +createJiugeGPTQModel(const JiugeGPTQMeta *, + const ModelWeights *); + +/// @brief 销毁模型 +__C __export void +destroyJiugeGPTQModel(struct JiugeGPTQModel *); + +/// @brief 批次推理一轮,并采样出新的 token +/// @param tokens 输入 token 地址 +/// @param ntok 输入 token 数量 +/// @param nreq 请求数量 +/// @param req_lens 每个请求的 token 数量 +/// @param req_pos 每个请求的起始位置 +/// @param kv_caches 每个请求的 KV Cache +/// @param temperature 采样温度(0. 
表示贪心采样) +/// @param topk 采样 topk(1 表示贪心采样) +/// @param topp 采样 topp +/// @param output 输出 token 数组,每个请求一个输出,长度至少为nreq +__C __export void +inferBatchJiugeGPTQ(struct JiugeGPTQModel *, + const uint32_t *tokens, uint32_t ntok, + const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos, + struct KVCache **kv_caches, + const float *temperature, const uint32_t *topk, const float *topp, + uint32_t *output); + +/// @brief 批次推理一轮,输出 output embedding 后的 logits +/// @param tokens 输入 token 地址 +/// @param ntok 输入 token 数量 +/// @param nreq 请求数量 +/// @param req_lens 每个请求的 token 数量 +/// @param req_pos 每个请求的起始位置 +/// @param kv_caches 每个请求的 KV Cache +/// @param logits 输出 token 数组,每个请求一个输出,长度至少为nreq +__C __export void +forwardBatchJiugeGPTQ(struct JiugeGPTQModel *, + const uint32_t *tokens, uint32_t ntok, + const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos, + struct KVCache **kv_caches, + void *logits); + +#endif diff --git a/scripts/jiuge_gptq.py b/scripts/jiuge_gptq.py new file mode 100644 index 00000000..e510847c --- /dev/null +++ b/scripts/jiuge_gptq.py @@ -0,0 +1,369 @@ +from typing import List, Sequence +import math +import os +from pathlib import Path +import safetensors +import sys +import time +import json +import torch +import transformers + +from libinfinicore_infer import ( + JiugeGPTQModel, + JiugeGPTQMetaCStruct, + DataType, + DeviceType, + KVCacheCStruct, +) +from infer_task import InferTask, KVCache + +from ctypes import POINTER, c_float, c_int, c_uint, c_void_p, byref + +torch.set_default_device("cpu") + + +class JiugeGPTQMetaFromConfig(JiugeGPTQMetaCStruct): + def __init__(self, config, dtype=torch.float16, max_tokens=None): + if config["torch_dtype"] == "float16": + dt_ = DataType.INFINI_DTYPE_F16 + elif config["torch_dtype"] == "float32": + dt_ = DataType.INFINI_DTYPE_F32 + elif config["torch_dtype"] == "bfloat16": + dt_ = DataType.INFINI_DTYPE_BF16 + else: + dt_ = DataType.INFINI_DTYPE_F16 + + # GPTQ 没有 scale 字段 + + has_qkv_bias = 1 if ("attention_bias" in config and config["attention_bias"]) else 0 + if config["model_type"] in ["qwen2", "qwen3"]: + has_qkv_bias = 1 + eos_token_id = ( + config["eos_token_id"][0] + if type(config["eos_token_id"]) == list + else config["eos_token_id"] + ) + + super().__init__( + dt_logits=dt_, + dt_linear_w=DataType.INFINI_DTYPE_I32, + dt_norm_w=dt_, + nlayer=config["num_hidden_layers"], + d=config["hidden_size"], + nh=config["num_attention_heads"], + nkvh=( + config["num_key_value_heads"] + if "num_key_value_heads" in config + else config["num_attention_heads"] + ), + dh=config["hidden_size"] // config["num_attention_heads"], + di=config["intermediate_size"], + dctx=( + config["max_position_embeddings"] if max_tokens is None else max_tokens + ), + dvoc=config["vocab_size"], + epsilon=config["rms_norm_eps"], + theta=(config["rope_theta"] if "rope_theta" in config else 100000.0), + end_token=eos_token_id, + nbit=config["quantization_config"]["bits"], + quant_group_size=config["quantization_config"]["group_size"], + has_qkv_bias=has_qkv_bias, + ) + self.torch_dtype_logits = dtype + + +class JiugeGPTQBatchedTask: + def __init__(self, tasks: List[InferTask]): + self.tasks = tasks + self.nreq = len(tasks) + + # Precompute fields + token_lists = [t.tokens for t in tasks] + self.req_lens_list = [len(toks) for toks in token_lists] + self.req_pos_list = [t.pos for t in tasks] + self.kv_cache_ptrs = [t.kvcache().data() for t in tasks] + self.temperaturas_list = [t.temperature for t in tasks] + self.topks_list = [t.topk for t in tasks] + 
self.topps_list = [t.topp for t in tasks] + + # Flatten token lists + flat_tokens = [tok for toks in token_lists for tok in toks] + self.ntok = len(flat_tokens) + + # Convert to ctypes arrays in one pass + self.tokens = (c_uint * self.ntok)(*flat_tokens) + self.req_lens = (c_uint * self.nreq)(*self.req_lens_list) + self.req_pos = (c_uint * self.nreq)(*self.req_pos_list) + self.kv_caches = (POINTER(KVCacheCStruct) * self.nreq)(*self.kv_cache_ptrs) + self.temperaturas = (c_float * self.nreq)(*self.temperaturas_list) + self.topks = (c_uint * self.nreq)(*self.topks_list) + self.topps = (c_float * self.nreq)(*self.topps_list) + + def input_args(self): + return ( + self.tokens, + self.ntok, + self.req_lens, + self.nreq, + self.req_pos, + self.kv_caches, + self.temperaturas, + self.topks, + self.topps, + ) + + +class JiugeGPTQForCausalLM: + def __init__( + self, model_dir_path, device=DeviceType.DEVICE_TYPE_CPU, ndev=1, max_tokens=None + ): + + load_start_time = time.time() + print(f"Creating model on {ndev} devices...") + with open(os.path.join(model_dir_path, "config.json"), "r") as f: + config = json.load(f) + self.config = config + eos_token_id = self.config["eos_token_id"] + self.eos_token_id = ( + [eos_token_id] if type(eos_token_id) == int else eos_token_id + ) + self.dev_ids = (c_int * ndev)(*[i for i in range(ndev)]) + self.ndev = ndev + self.device = device + self.meta = JiugeGPTQMetaFromConfig(config, max_tokens=max_tokens) + + self.jiuge_gptq_model = JiugeGPTQModel() + + self.weights = self.jiuge_gptq_model.create_weights( + byref(self.meta), + self.device, + ndev, + self.dev_ids, + ) + self.tokenizer = transformers.AutoTokenizer.from_pretrained( + model_dir_path, trust_remote_code=True + ) + + load_end_time = time.time() + print(f"Time used: {load_end_time - load_start_time:.3f}s") + + load_start_time = time.time() + print("Loading model weights to host...") + + self.load_all_safetensors_from_dir(os.path.join(model_dir_path)) + + self.model_instance = self.jiuge_gptq_model.create_model( + byref(self.meta), + self.weights, + ) + load_end_time = time.time() + print(f"Time used: {load_end_time - load_start_time:.3f}s") + + def load_all_safetensors_from_dir(self, dir_path_: str): + dir_path_ = Path(dir_path_) + lm_head_loaded = False + embed_tokens_tensor = None + for file in sorted(dir_path_.glob("*.safetensors")): + with safetensors.safe_open(file, framework="pt", device="cpu") as f: + for key in f.keys(): + tensor = f.get_tensor(key) + if key == "lm_head.weight": + lm_head_loaded = True + elif key == "model.embed_tokens.weight": + embed_tokens_tensor = tensor + # Some GPTQ exports (e.g. Qwen2.5 GPTQ with desc_act=False) + # store a dummy g_idx tensor (often all zeros). Our backend + # still consumes g_idx, so we convert the dummy to identity + # to keep dequant correct. 
+ tensor_ptr = tensor.data_ptr() + if key.endswith(".g_idx") and tensor.dtype == torch.int32 and tensor.numel() > 1: + try: + if int(tensor.max().item()) == 0: + if not hasattr(self, "_gidx_fix_buffers"): + self._gidx_fix_buffers = [] + fixed = torch.arange(tensor.numel(), dtype=torch.int32) + self._gidx_fix_buffers.append(fixed) + tensor_ptr = fixed.data_ptr() + except Exception: + pass + + self.jiuge_gptq_model.load_weight( + self.weights, key, tensor_ptr + ) + if not lm_head_loaded and embed_tokens_tensor is not None: + print("lm_head.weight missing, tying to embed_tokens.weight") + self.jiuge_gptq_model.load_weight( + self.weights, "lm_head.weight", embed_tokens_tensor.data_ptr() + ) + elif not lm_head_loaded: + raise RuntimeError("lm_head.weight missing and embed_tokens.weight not found") + + def max_context_len(self): + return self.meta.dctx + + def create_kv_cache(self): + return self.jiuge_gptq_model.create_kv_cache( + self.meta.nlayer, + self.meta.dctx, + self.meta.nkvh, + self.meta.dh, + self.meta.dh, + self.meta.dt_logits, + self.device, + self.dev_ids, + self.ndev, + ) + + def drop_kv_cache(self, kv_cache): + self.jiuge_gptq_model.drop_kv_cache(kv_cache) + + def batch_infer_one_round(self, tasks: List[InferTask]): + output = (c_uint * len(tasks))() + batch_inputs = JiugeGPTQBatchedTask(tasks) + self.jiuge_gptq_model.infer_batch( + self.model_instance, + *(batch_inputs.input_args()), + output, + ) + return list(output) + + def generate(self, input_content, max_steps, topp_=1.0, topk_=1, temperature_=1.0): + input_content = self.tokenizer.apply_chat_template( + conversation=[{"role": "user", "content": input_content}], + add_generation_prompt=True, + tokenize=False, + ) + print(input_content, end="", flush=True) + tokens = self.tokenizer.encode(input_content) + infer_task = InferTask( + 0, + tokens, + self.max_context_len(), + temperature_, + topk_, + topp_, + self.eos_token_id, + ) + infer_task.bind_kvcache(KVCache(self)) + + steps = 0 + total_time = 0 + output_content = "" + + for step_i in range(max_steps): + start_time = time.time() + output_tokens = self.batch_infer_one_round([infer_task]) + end_time = time.time() + steps += 1 + output_str = self.tokenizer.decode(output_tokens[0]) + output_content += output_str + print(output_str, end="", flush=True) + if output_tokens[0] in self.eos_token_id: + break + infer_task.next(output_tokens[0]) + + if step_i > 0: + total_time += end_time - start_time + + print("\n") + avg_time = total_time * 1000 / (steps - 1) + print(f"Time per step: {avg_time:.3f}ms") + + infer_task._kv_cache.drop(self) + return output_content, avg_time + + def perplexity(self, test_sequences: List[Sequence[int]], batch_size=10): + tasks = [ + InferTask(i, [], self.max_context_len(), 1.0, 1, 1.0, self.eos_token_id) + for i in range(batch_size) + ] + kv_caches = [KVCache(self) for _ in range(batch_size)] + + nll = 0.0 + total_len = 0 + + for i in range(0, len(test_sequences), batch_size): + batch_id = 0 + true_tokens = [] + while batch_id < batch_size and batch_id + i < len(test_sequences): + input_tokens = test_sequences[i + batch_id][:-1] + true_tokens.extend(test_sequences[i + batch_id][1:]) + tasks[batch_id].tokens = input_tokens + tasks[batch_id].bind_kvcache(kv_caches[batch_id]) + batch_id += 1 + + batch_inputs = JiugeGPTQBatchedTask(tasks[:batch_id]) + logits = torch.zeros( + (batch_inputs.ntok, self.meta.dvoc), dtype=self.meta.torch_dtype_logits + ) + self.jiuge_gptq_model.forward_batch( + self.model_instance, + batch_inputs.tokens, + batch_inputs.ntok, + 
batch_inputs.req_lens, + batch_inputs.nreq, + batch_inputs.req_pos, + batch_inputs.kv_caches, + logits.data_ptr(), + ) + + logits = logits.float() + token_ids = torch.tensor(true_tokens, dtype=torch.int64) # [ntok,] + log_probs = torch.nn.functional.log_softmax(logits, dim=-1) # (ntok, vocab) + token_logprobs = log_probs[ + torch.arange(batch_inputs.ntok), token_ids + ] # (ntok,) + + start = 0 + for l in batch_inputs.req_lens_list: + nll += -token_logprobs[start : start + l].sum().item() + start += l + total_len += token_logprobs.numel() + + for task in tasks: + task.release_kvcache() + + return math.exp(nll / total_len) + + def destroy_model_instance(self): + self.jiuge_gptq_model.destroy_model(self.model_instance) + print("Model destroyed") + + +def test(): + if len(sys.argv) < 3: + print( + "Usage: python jiuge_gptq.py [--cpu | --nvidia | --cambricon | --ascend | --metax | --moore | --iluvatar] [n_device]" + ) + sys.exit(1) + model_path = sys.argv[2] + device_type = DeviceType.DEVICE_TYPE_CPU + if sys.argv[1] == "--cpu": + device_type = DeviceType.DEVICE_TYPE_CPU + elif sys.argv[1] == "--nvidia": + device_type = DeviceType.DEVICE_TYPE_NVIDIA + elif sys.argv[1] == "--cambricon": + device_type = DeviceType.DEVICE_TYPE_CAMBRICON + elif sys.argv[1] == "--ascend": + device_type = DeviceType.DEVICE_TYPE_ASCEND + elif sys.argv[1] == "--metax": + device_type = DeviceType.DEVICE_TYPE_METAX + elif sys.argv[1] == "--moore": + device_type = DeviceType.DEVICE_TYPE_MOORE + elif sys.argv[1] == "--iluvatar": + device_type = DeviceType.DEVICE_TYPE_ILUVATAR + else: + print( + "Usage: python jiuge_gptq.py [--cpu | --nvidia | --cambricon | --ascend | --metax | --moore | --iluvatar] [n_device]" + ) + sys.exit(1) + + ndev = int(sys.argv[3]) if len(sys.argv) > 3 else 1 + model = JiugeGPTQForCausalLM(model_path, device_type, ndev) + model.generate("山东最高的山是?", 500) + model.destroy_model_instance() + + +if __name__ == "__main__": + test() diff --git a/scripts/launch_server.py b/scripts/launch_server.py index 659163c6..d04d4f69 100644 --- a/scripts/launch_server.py +++ b/scripts/launch_server.py @@ -69,6 +69,12 @@ def parse_args(): action="store_true", help="Whether to use AWQ quantized model (default: False)", ) + # Add support for GPTQ + parser.add_argument( + "--gptq", + action="store_true", + help="Whether to use GPTQ quantized model (default: False)", + ) return parser.parse_args() @@ -78,6 +84,7 @@ def parse_args(): ndev = args.ndev max_tokens = args.max_tokens USE_AWQ = args.awq +USE_GPTQ = args.gptq MAX_BATCH = args.max_batch print( f"Using MAX_BATCH={MAX_BATCH}. Try reduce this value if out of memory error occurs." 
@@ -127,6 +134,10 @@ async def lifespan(app: FastAPI): app.state.model = JiugeAWQForCausalLM( model_path, device_type, ndev, max_tokens=max_tokens ) + elif USE_GPTQ: + app.state.model = JiugeGPTQForCausalLM( + model_path, device_type, ndev, max_tokens=max_tokens + ) else: app.state.model = JiugeForCauslLM( model_path, device_type, ndev, max_tokens=max_tokens diff --git a/scripts/libinfinicore_infer/__init__.py b/scripts/libinfinicore_infer/__init__.py index 8fc5f4db..84db059a 100644 --- a/scripts/libinfinicore_infer/__init__.py +++ b/scripts/libinfinicore_infer/__init__.py @@ -1,6 +1,9 @@ from .base import DataType, DeviceType, KVCacheCStruct from .jiuge import JiugeModel, JiugeMetaCStruct, JiugeWeightsCStruct -from .jiuge_awq import JiugeAWQModel, JiugeAWQMetaCStruct, ModelWeightsCStruct +# 为了区分, 给 ModelWeightsCStruct 别名 +from .jiuge_awq import JiugeAWQModel, JiugeAWQMetaCStruct, ModelWeightsCStruct as AWQModelWeightsCStruct +# 添加 GPTQ 模块 +from .jiuge_gptq import JiugeGPTQModel, JiugeGPTQMetaCStruct, ModelWeightsCStruct as GPTQModelWeightsCStruct from .deepseek_v3 import ( DeepSeekV3Model, DeepSeekV3MetaCStruct, @@ -18,7 +21,11 @@ "JiugeWeightsCStruct", "JiugeAWQModel", "JiugeAWQMetaCStruct", - "ModelWeightsCStruct", + "AWQModelWeightsCStruct", + # Add GPTQ module + "JiugeGPTQModel", + "JiugeGPTQMetaCStruct", + "GPTQModelWeightsCStruct", "DeepSeekV3Model", "DeepSeekV3MetaCStruct", "DeepSeekV3WeightsCStruct", diff --git a/scripts/libinfinicore_infer/jiuge_awq.py b/scripts/libinfinicore_infer/jiuge_awq.py index 2f47ca8c..9d587395 100644 --- a/scripts/libinfinicore_infer/jiuge_awq.py +++ b/scripts/libinfinicore_infer/jiuge_awq.py @@ -103,7 +103,7 @@ def register_lib(cls, lib): c_void_p, ] - lib.loadModelWeight.argtypes = [ + lib.JiugeAWQLoadWeight.argtypes = [ POINTER(ModelWeightsCStruct), c_char_p, c_void_p, @@ -129,7 +129,7 @@ def drop_kv_cache(self, kv_cache): self.lib.dropKVCache(kv_cache) def load_weight(self, weights, name, data): - self.lib.loadModelWeight(weights, name.encode("utf-8"), data) + self.lib.JiugeAWQLoadWeight(weights, name.encode("utf-8"), data) def infer_batch( self, diff --git a/scripts/libinfinicore_infer/jiuge_gptq.py b/scripts/libinfinicore_infer/jiuge_gptq.py new file mode 100644 index 00000000..59d9af45 --- /dev/null +++ b/scripts/libinfinicore_infer/jiuge_gptq.py @@ -0,0 +1,167 @@ +from .base import BaseModel, DataType, DeviceType, KVCacheCStruct, register_model +from ctypes import ( + c_size_t, + c_uint, + c_int, + c_float, + c_void_p, + POINTER, + Structure, + c_char, + c_char_p, +) + + +class JiugeGPTQMetaCStruct(Structure): + _fields_ = [ + ("dt_logits", DataType), + ("dt_linear_w", DataType), + ("dt_norm_w", DataType), + ("nlayer", c_size_t), + ("d", c_size_t), + ("nh", c_size_t), + ("nkvh", c_size_t), + ("dh", c_size_t), + ("di", c_size_t), + ("dctx", c_size_t), + ("dvoc", c_size_t), + ("epsilon", c_float), + ("theta", c_float), + ("end_token", c_uint), + ("nbit", c_size_t), + ("quant_group_size", c_size_t), + ("has_qkv_bias", c_char), + ] + + +class ModelWeightsCStruct(Structure): + pass + + +class JiugeGPTQModelCStruct(Structure): + pass + + +@register_model +class JiugeGPTQModel(BaseModel): + @classmethod + def register_lib(cls, lib): + """Register JiugeGPTQ model functions with the library""" + lib.createJiugeGPTQWeights.restype = POINTER(ModelWeightsCStruct) + lib.createJiugeGPTQWeights.argtypes = [ + POINTER(JiugeGPTQMetaCStruct), + DeviceType, + c_int, + POINTER(c_int), + ] + + lib.createJiugeGPTQModel.restype = POINTER(JiugeGPTQModelCStruct) + 
lib.createJiugeGPTQModel.argtypes = [ + POINTER(JiugeGPTQMetaCStruct), + POINTER(ModelWeightsCStruct), + ] + + lib.destroyJiugeGPTQModel.argtypes = [POINTER(JiugeGPTQModelCStruct)] + + lib.createKVCache.argtypes = [ + c_size_t, + c_size_t, + c_size_t, + c_size_t, + c_size_t, + DataType, + DeviceType, + POINTER(c_int), + c_size_t, + ] + lib.createKVCache.restype = POINTER(KVCacheCStruct) + + lib.dropKVCache.argtypes = [POINTER(KVCacheCStruct)] + + lib.inferBatchJiugeGPTQ.argtypes = [ + POINTER(JiugeGPTQModelCStruct), + POINTER(c_uint), + c_uint, + POINTER(c_uint), + c_uint, + POINTER(c_uint), + POINTER(POINTER(KVCacheCStruct)), + POINTER(c_float), + POINTER(c_uint), + POINTER(c_float), + POINTER(c_uint), + ] + + lib.forwardBatchJiugeGPTQ.argtypes = [ + POINTER(JiugeGPTQModelCStruct), + POINTER(c_uint), + c_uint, + POINTER(c_uint), + c_uint, + POINTER(c_uint), + POINTER(POINTER(KVCacheCStruct)), + c_void_p, + ] + + lib.JiugeGPTQLoadWeight.argtypes = [ + POINTER(ModelWeightsCStruct), + c_char_p, + c_void_p, + ] + + def create_weights(self, meta, device_type, ndev, dev_ids): + return self.lib.createJiugeGPTQWeights(meta, device_type, ndev, dev_ids) + + def create_model(self, meta, weights): + return self.lib.createJiugeGPTQModel(meta, weights) + + def destroy_model(self, model): + self.lib.destroyJiugeGPTQModel(model) + + def create_kv_cache( + self, nlayer, max_len, nkvh, dk, dv, dtype, device, dev_ids, ndev + ): + return self.lib.createKVCache( + nlayer, max_len, nkvh, dk, dv, dtype, device, dev_ids, ndev + ) + + def drop_kv_cache(self, kv_cache): + self.lib.dropKVCache(kv_cache) + + def load_weight(self, weights, name, data): + self.lib.JiugeGPTQLoadWeight(weights, name.encode("utf-8"), data) + + def infer_batch( + self, + model, + tokens, + ntok, + req_lens, + nreq, + req_pos, + kv_caches, + temperature, + topk, + topp, + output, + ): + self.lib.inferBatchJiugeGPTQ( + model, + tokens, + ntok, + req_lens, + nreq, + req_pos, + kv_caches, + temperature, + topk, + topp, + output, + ) + + def forward_batch( + self, model, tokens, ntok, req_lens, nreq, req_pos, kv_caches, logits + ): + self.lib.forwardBatchJiugeGPTQ( + model, tokens, ntok, req_lens, nreq, req_pos, kv_caches, logits + ) diff --git a/scripts/test_ceval.py b/scripts/test_ceval.py index 83365f4a..0b05dacf 100644 --- a/scripts/test_ceval.py +++ b/scripts/test_ceval.py @@ -2,6 +2,18 @@ from jiuge import * from datasets import load_dataset +# Support AWQ and GPTQ +from libinfinicore_infer import ( + JiugeAWQModel, + JiugeAWQMetaCStruct, + JiugeGPTQModel, + JiugeGPTQMetaCStruct, + DeviceType, + KVCacheCStruct, +) + +# Import missing classes +from jiuge import KVCache, InferTask class JiugeForCeval(JiugeForCauslLM): def __init__( @@ -68,25 +80,50 @@ def test(): ) sys.exit(1) - model_path = sys.argv[2] + # ----------------------------- + # 支持 AWQ/GPTQ 参数 + # ----------------------------- + use_awq = False + use_gptq = False + device_arg = None + model_path = None + ndev = 1 + + # 解析 sys.argv + for arg in sys.argv[1:]: + if arg.lower() == "--awq": + use_awq = True + elif arg.lower() == "--gptq": + use_gptq = True + elif arg.startswith("--"): + device_arg = arg.lower() + else: + if model_path is None: + model_path = arg + else: + ndev = int(arg) + + # ----------------------------- + # 设置设备类型 + # ----------------------------- device_type = DeviceType.DEVICE_TYPE_CPU - if sys.argv[1] == "--cpu": + if device_arg == "--cpu": device_type = DeviceType.DEVICE_TYPE_CPU - elif sys.argv[1] == "--nvidia": + elif device_arg == "--nvidia": device_type = 
DeviceType.DEVICE_TYPE_NVIDIA - elif sys.argv[1] == "--cambricon": + elif device_arg == "--cambricon": device_type = DeviceType.DEVICE_TYPE_CAMBRICON - elif sys.argv[1] == "--ascend": + elif device_arg == "--ascend": device_type = DeviceType.DEVICE_TYPE_ASCEND - elif sys.argv[1] == "--metax": + elif device_arg == "--metax": device_type = DeviceType.DEVICE_TYPE_METAX - elif sys.argv[1] == "--moore": + elif device_arg == "--moore": device_type = DeviceType.DEVICE_TYPE_MOORE - elif sys.argv[1] == "--iluvatar": + elif device_arg == "--iluvatar": device_type = DeviceType.DEVICE_TYPE_ILUVATAR - elif sys.argv[1] == "--kunlun": + elif device_arg == "--kunlun": device_type = DeviceType.DEVICE_TYPE_KUNLUN - elif sys.argv[1] == "--hygon": + elif device_arg == "--hygon": device_type = DeviceType.DEVICE_TYPE_HYGON else: print( @@ -94,8 +131,9 @@ def test(): ) sys.exit(1) - # https://huggingface.co/datasets/ceval/ceval-exam/tree/main/middle_school_geography - + # ----------------------------- + # 加载 CEval 数据集 + # ----------------------------- dataset = load_dataset(r"ceval/ceval-exam", name="middle_school_mathematics") # dataset = load_dataset(r"ceval/ceval-exam", name="high_school_history") # dataset = load_dataset(r"ceval/ceval-exam", name="high_school_chinese") @@ -104,8 +142,16 @@ def test(): # dataset = load_dataset(r"ceval/ceval-exam", name="middle_school_physics") samples = dataset["val"] - ndev = int(sys.argv[3]) if len(sys.argv) > 3 else 1 - model = JiugeForCeval(model_path, device_type, ndev) + + # ----------------------------- + # 初始化模型 + # ----------------------------- + if use_awq: + model = JiugeAWQModel(model_path, device_type, ndev) + elif use_gptq: + model = JiugeGPTQModel(model_path, device_type, ndev) + else: + model = JiugeForCeval(model_path, device_type, ndev) answers_list = [] for sample in samples: @@ -131,6 +177,10 @@ def test(): print("-------------------------------------------------------------") + # ----------------------------- + # 计算正确率 + # ----------------------------- + import re true_num = 0 all_num = 0 for cont in answers_list: @@ -139,9 +189,10 @@ def test(): answer = cont["answer"] all_num = all_num + 1 - position = 0 - ABCD = output[position : position + 2] - if answer in ABCD: + # 提取模型输出的第一个选项 A/B/C/D + match = re.search(r"[A-D]", output) + model_answer = match.group(0) if match else "" + if model_answer == answer: true_num = true_num + 1 print(f"id {id} : ", "正确") else: diff --git a/src/cache_manager/opcache_manager.hpp b/src/cache_manager/opcache_manager.hpp index 4c49e961..fecaeb5a 100644 --- a/src/cache_manager/opcache_manager.hpp +++ b/src/cache_manager/opcache_manager.hpp @@ -162,6 +162,7 @@ class CacheManager { DECLARE_OP_CACHE(SwiGLU) DECLARE_OP_CACHE(RandomSample) DECLARE_OP_CACHE(DequantizeAWQ) + DECLARE_OP_CACHE(DequantizeGPTQ) CacheManager(size_t capacity = 100) : Add_cache(capacity, DESTROY_FUNC(Add)), @@ -173,7 +174,8 @@ class CacheManager { Topkrouter_cache(capacity, DESTROY_FUNC(Topkrouter)), SwiGLU_cache(capacity, DESTROY_FUNC(SwiGLU)), RandomSample_cache(capacity, DESTROY_FUNC(RandomSample)), - DequantizeAWQ_cache(capacity, DESTROY_FUNC(DequantizeAWQ)) {} + DequantizeAWQ_cache(capacity, DESTROY_FUNC(DequantizeAWQ)), + DequantizeGPTQ_cache(capacity, DESTROY_FUNC(DequantizeGPTQ)) {} template static size_t createDescriptorKey(Tensors... 
tensors) { diff --git a/src/models/inference_context.cpp b/src/models/inference_context.cpp index db5fda11..7aca3fcb 100644 --- a/src/models/inference_context.cpp +++ b/src/models/inference_context.cpp @@ -262,22 +262,35 @@ void InferenceContext::linear(std::shared_ptr c, void InferenceContext::dequant(std::shared_ptr weight, std::shared_ptr in_w, std::shared_ptr in_s, - std::shared_ptr in_z) { - - size_t key = CacheManager::createDescriptorKey(weight, in_w, in_s, in_z); - - infiniopDequantizeAWQDescriptor_t desc; - if (!cache_manager->getDequantizeAWQDescriptor(key, desc)) { - RUN_INFINI(infiniopCreateDequantizeAWQDescriptor(op_handle, &desc, weight->desc(), in_w->desc(), in_s->desc(), in_z->desc())); - cache_manager->putDequantizeAWQDescriptor(key, desc); + std::shared_ptr in_z, + QuantType type, + std::shared_ptr in_g_idx) { + size_t key = CacheManager::createDescriptorKey(weight, in_w, in_s, in_z, in_g_idx); + if (type == QuantType::AWQ) { + // unchanged + infiniopDequantizeAWQDescriptor_t desc; + if (!cache_manager->getDequantizeAWQDescriptor(key, desc)) { + RUN_INFINI(infiniopCreateDequantizeAWQDescriptor(op_handle, &desc, + weight->desc(), in_w->desc(), in_s->desc(), in_z->desc())); + cache_manager->putDequantizeAWQDescriptor(key, desc); + } + size_t workspace_size = 0; + RUN_INFINI(infiniopGetDequantizeAWQWorkspaceSize(desc, &workspace_size)); + ensure_workspace(workspace_size); + RUN_INFINI(infiniopDequantizeAWQ(desc, workspace_storage->memory(), workspace_size, + weight->data(), in_w->data(), in_s->data(), in_z->data(), stream)); + } else if (type == QuantType::GPTQ) { + ASSERT(in_g_idx && "GPTQ dequant requires g_idx"); + infiniopDequantizeGPTQDescriptor_t desc; + if (!cache_manager->getDequantizeGPTQDescriptor(key, desc)) { + RUN_INFINI(infiniopCreateDequantizeGPTQDescriptor(op_handle, &desc, + weight->desc(), in_w->desc(), in_s->desc(), in_z->desc(), in_g_idx->desc())); + cache_manager->putDequantizeGPTQDescriptor(key, desc); + } + size_t workspace_size = 0; + RUN_INFINI(infiniopGetDequantizeGPTQWorkspaceSize(desc, &workspace_size)); + ensure_workspace(workspace_size); + RUN_INFINI(infiniopDequantizeGPTQ(desc, workspace_storage->memory(), workspace_size, + weight->data(), in_w->data(), in_s->data(), in_z->data(), in_g_idx->data(), stream)); } - - size_t workspace_size = 0; - RUN_INFINI(infiniopGetDequantizeAWQWorkspaceSize(desc, &workspace_size)); - ensure_workspace(workspace_size); - void *workspace = workspace_storage->memory(); - - RUN_INFINI(infiniopDequantizeAWQ( - desc, workspace, workspace_size, - weight->data(), in_w->data(), in_s->data(), in_z->data(), stream)); -} +} \ No newline at end of file diff --git a/src/models/inference_context.hpp b/src/models/inference_context.hpp index 0cf93f6f..4f307026 100644 --- a/src/models/inference_context.hpp +++ b/src/models/inference_context.hpp @@ -4,6 +4,11 @@ #include +enum class QuantType { + AWQ, + GPTQ +}; + struct InferenceContext { infiniopHandle_t op_handle; std::shared_ptr memory_pool; @@ -61,7 +66,9 @@ struct InferenceContext { void dequant(std::shared_ptr weight, std::shared_ptr in_w, std::shared_ptr in_s, - std::shared_ptr in_z); + std::shared_ptr in_z, + QuantType type, + std::shared_ptr in_g_idx = nullptr); }; namespace { @@ -144,8 +151,9 @@ inline void linear(std::shared_ptr c, std::shared_ptr a, inline void dequant_linear(std::shared_ptr out, std::shared_ptr x, std::shared_ptr w_w, std::shared_ptr w_s, std::shared_ptr w_z, - float alpha, float beta, std::shared_ptr residual, std::shared_ptr bias) { + float alpha, 
float beta, std::shared_ptr residual, std::shared_ptr bias, + QuantType type = QuantType::AWQ, std::shared_ptr w_g_idx = nullptr) { auto w = Tensor::buffer(x->dtype(), {x->shape()[1], out->shape()[1]}, getInferenceContext().memory_pool); - getInferenceContext().dequant(w, w_w, w_s, w_z); + getInferenceContext().dequant(w, w_w, w_s, w_z, type, w_g_idx); getInferenceContext().linear(out, x, w, alpha, beta, residual, bias); } diff --git a/src/models/jiuge_awq/jiuge_awq.cpp b/src/models/jiuge_awq/jiuge_awq.cpp index 4452c400..3b095be4 100644 --- a/src/models/jiuge_awq/jiuge_awq.cpp +++ b/src/models/jiuge_awq/jiuge_awq.cpp @@ -8,7 +8,7 @@ #include #include -void createDeviceResource(DeviceResource *rsrc, const JiugeAWQMeta *meta, +void createDeviceResource(AWQDeviceResource *rsrc, const JiugeAWQMeta *meta, std::shared_ptr weights, infiniDevice_t device, int idev, int ndev, int dev_id, @@ -21,7 +21,7 @@ void createDeviceResource(DeviceResource *rsrc, const JiugeAWQMeta *meta, auto memory_pool = std::make_shared(128 * 1024 * 1024); - *rsrc = DeviceResource{ + *rsrc = AWQDeviceResource{ device, dev_id, handle, @@ -33,7 +33,7 @@ void createDeviceResource(DeviceResource *rsrc, const JiugeAWQMeta *meta, RUN_INFINI(infinirtDeviceSynchronize()); } -void releaseDeviceResource(DeviceResource &res) { +void releaseDeviceResource(AWQDeviceResource &res) { infinirtDeviceSynchronize(); // Release individual Tensors @@ -45,7 +45,7 @@ void releaseDeviceResource(DeviceResource &res) { res.comm = nullptr; } -void inferDeviceBatch(const JiugeAWQMeta *meta, DeviceResource &rsrc, +void inferDeviceBatch(const JiugeAWQMeta *meta, AWQDeviceResource &rsrc, uint32_t idev, uint32_t ndev, const uint32_t *tokens, uint32_t ntok, const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos, @@ -132,13 +132,16 @@ void inferDeviceBatch(const JiugeAWQMeta *meta, DeviceResource &rsrc, // qkv_proj dequant_linear(q_buf, logits_out, weight->w_attn_q[layer]->w, weight->w_attn_q[layer]->s, weight->w_attn_q[layer]->z, - 1.0, 0.0, nullptr, has_qkv_bias ? weight->b_attn_q[layer] : nullptr); + 1.0, 0.0, nullptr, has_qkv_bias ? weight->b_attn_q[layer] : nullptr, + QuantType::AWQ); dequant_linear(k_buf, logits_out, weight->w_attn_k[layer]->w, weight->w_attn_k[layer]->s, weight->w_attn_k[layer]->z, - 1.0, 0.0, nullptr, has_qkv_bias ? weight->b_attn_k[layer] : nullptr); + 1.0, 0.0, nullptr, has_qkv_bias ? weight->b_attn_k[layer] : nullptr, + QuantType::AWQ); dequant_linear(v_buf, logits_out, weight->w_attn_v[layer]->w, weight->w_attn_v[layer]->s, weight->w_attn_v[layer]->z, - 1.0, 0.0, nullptr, has_qkv_bias ? weight->b_attn_v[layer] : nullptr); + 1.0, 0.0, nullptr, has_qkv_bias ? weight->b_attn_v[layer] : nullptr, + QuantType::AWQ); // rope rope_v2(q_buf->view({ntok, nh, dh}), q_buf->view({ntok, nh, dh}), pos_ids_buf, weight->sin_table, weight->cos_table); rope_v2(k_buf->view({ntok, nkvh, dh}), k_buf->view({ntok, nkvh, dh}), pos_ids_buf, weight->sin_table, weight->cos_table); @@ -172,8 +175,10 @@ void inferDeviceBatch(const JiugeAWQMeta *meta, DeviceResource &rsrc, token_offset += seq_len; } // o_proj - dequant_linear(logits_in, o_buf, weight->w_attn_out[layer]->w, weight->w_attn_out[layer]->s, weight->w_attn_out[layer]->z, - 1.0, 0.0, idev == 0 ? logits_in : nullptr, nullptr); // only rank 0 adds residual + dequant_linear(logits_in, o_buf, + weight->w_attn_out[layer]->w, weight->w_attn_out[layer]->s, weight->w_attn_out[layer]->z, + 1.0, 0.0, idev == 0 ? 
logits_in : nullptr, nullptr, + QuantType::AWQ); // All_reduce if distributed if (rsrc.comm != nullptr) { RUN_INFINI(infinicclAllReduce( @@ -185,14 +190,17 @@ void inferDeviceBatch(const JiugeAWQMeta *meta, DeviceResource &rsrc, rmsnorm(logits_out, logits_in, weight->w_ffn_norm[layer], meta->epsilon); dequant_linear(gate_buf, logits_out, weight->w_ffn_gate[layer]->w, weight->w_ffn_gate[layer]->s, weight->w_ffn_gate[layer]->z, - 1.0, 0.0, nullptr, nullptr); + 1.0, 0.0, nullptr, nullptr, + QuantType::AWQ); dequant_linear(up_buf, logits_out, weight->w_ffn_up[layer]->w, weight->w_ffn_up[layer]->s, weight->w_ffn_up[layer]->z, - 1.0, 0.0, nullptr, nullptr); + 1.0, 0.0, nullptr, nullptr, + QuantType::AWQ); swiglu(gate_buf, up_buf, gate_buf); dequant_linear(logits_in, gate_buf, weight->w_ffn_down[layer]->w, weight->w_ffn_down[layer]->s, weight->w_ffn_down[layer]->z, - 1.0, 0.0, idev == 0 ? logits_in : nullptr, nullptr); // only rank 0 adds residual + 1.0, 0.0, idev == 0 ? logits_in : nullptr, nullptr, + QuantType::AWQ); // only rank 0 adds residual // All_reduce if distributed if (rsrc.comm != nullptr) { RUN_INFINI(infinicclAllReduce( @@ -307,7 +315,7 @@ forwardBatchJiugeAWQ(struct JiugeAWQModel *model, } } -void launchDevice(const JiugeAWQMeta *meta, std::shared_ptr weights, DeviceResource *rsrc, InferState &state, InferRequest &req, +void launchDevice(const JiugeAWQMeta *meta, std::shared_ptr weights, AWQDeviceResource *rsrc, InferState &state, InferRequest &req, infiniDevice_t device, int idev, int ndev, int dev_id, infinicclComm_t comm) { // Create Device Resource createDeviceResource(rsrc, meta, weights, device, idev, ndev, dev_id, comm); @@ -353,7 +361,7 @@ JiugeAWQModel::JiugeAWQModel(const JiugeAWQMeta *meta, const ModelWeights *weigh device = weights->device(); dev_ids = weights->devIds(); int ndev = int(dev_ids.size()); - dev_resources = std::vector(ndev); + dev_resources = std::vector(ndev); states = std::vector(ndev); threads.resize(ndev); diff --git a/src/models/jiuge_awq/jiuge_awq.hpp b/src/models/jiuge_awq/jiuge_awq.hpp index 9a06f790..7a1f05d9 100644 --- a/src/models/jiuge_awq/jiuge_awq.hpp +++ b/src/models/jiuge_awq/jiuge_awq.hpp @@ -32,7 +32,7 @@ class JiugeAWQWeights : public infinicore::weights::Loader { } }; -struct DeviceResource { +struct AWQDeviceResource { // Device infiniDevice_t device; int device_id; @@ -73,7 +73,7 @@ struct JiugeAWQModel { JiugeAWQMeta meta; infiniDevice_t device; std::vector dev_ids; - std::vector dev_resources; + std::vector dev_resources; std::vector states; std::vector threads; InferRequest req; diff --git a/src/models/jiuge_awq/jiuge_awq_weight.cpp b/src/models/jiuge_awq/jiuge_awq_weight.cpp index b01735d0..38bdf882 100644 --- a/src/models/jiuge_awq/jiuge_awq_weight.cpp +++ b/src/models/jiuge_awq/jiuge_awq_weight.cpp @@ -126,3 +126,9 @@ createJiugeAWQWeights(const JiugeAWQMeta *meta, JiugeAWQWeights *weights = new JiugeAWQWeights(meta, device, std::vector(dev_ids, dev_ids + ndev)); return (struct ModelWeights *)weights; } + +// 建立外部别名, 防止python端冲突 +__C void JiugeAWQLoadWeight(struct ModelWeights *weights, const char *name, void *data) { + // 直接转发调用通用的 loadModelWeight + loadModelWeight(weights, name, data); +} diff --git a/src/models/jiuge_gptq/jiuge_gptq.cpp b/src/models/jiuge_gptq/jiuge_gptq.cpp new file mode 100644 index 00000000..e631beed --- /dev/null +++ b/src/models/jiuge_gptq/jiuge_gptq.cpp @@ -0,0 +1,405 @@ +#include "jiuge_gptq.hpp" + +#include "../../tensor.hpp" +#include "../../utils.hpp" +#include "../inference_context.hpp" + +#include 
+#include +#include + +void createDeviceResource(GPTQDeviceResource *rsrc, const JiugeGPTQMeta *meta, + std::shared_ptr weights, + infiniDevice_t device, int idev, + int ndev, int dev_id, + infinicclComm_t comm) { + RUN_INFINI(infinirtSetDevice(device, dev_id)); + infiniopHandle_t handle; + infiniopCreateHandle(&handle); + infinirtStream_t stream; + infinirtStreamCreate(&stream); + + auto memory_pool = std::make_shared(128 * 1024 * 1024); + + *rsrc = GPTQDeviceResource{ + device, + dev_id, + handle, + weights, + stream, + comm, + memory_pool, + }; + RUN_INFINI(infinirtDeviceSynchronize()); +} + +void releaseDeviceResource(GPTQDeviceResource &res) { + infinirtDeviceSynchronize(); + // Release individual Tensors + + infiniopDestroyHandle(res.handle); + res.handle = nullptr; + infinirtStreamDestroy(res.stream); + res.stream = nullptr; + infinicclCommDestroy(res.comm); + res.comm = nullptr; +} + +void inferDeviceBatch(const JiugeGPTQMeta *meta, GPTQDeviceResource &rsrc, + uint32_t idev, uint32_t ndev, + const uint32_t *tokens, uint32_t ntok, + const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos, + struct KVCache **kv_caches, + const float *temperature, const uint32_t *topk, const float *topp, + uint32_t *output, void *last_logits) { + auto nlayer = meta->nlayer; + auto nkvh = meta->nkvh / ndev; + auto nh = meta->nh / ndev; + auto ngroup = nh / nkvh; + // auto dctx = meta.dctx; + auto dh = meta->dh; + auto d = meta->d; + auto dt_logits = meta->dt_logits; + auto di = meta->di / ndev; + auto dvoc = meta->dvoc; + auto stream = rsrc.stream; + auto weight = rsrc.weights; + bool has_qkv_bias = meta->has_qkv_bias; + + // Allocate buffers + auto logits_in = Tensor::buffer(dt_logits, {ntok, d}, rsrc.memory_pool); + auto logits_out = Tensor::buffer(dt_logits, {ntok, d}, rsrc.memory_pool); + auto q_buf = Tensor::buffer(dt_logits, {ntok, nh * dh}, rsrc.memory_pool); + auto k_buf = Tensor::buffer(dt_logits, {ntok, nkvh * dh}, rsrc.memory_pool); + auto v_buf = Tensor::buffer(dt_logits, {ntok, nkvh * dh}, rsrc.memory_pool); + + auto gate_buf = Tensor::buffer(dt_logits, {ntok, di}, rsrc.memory_pool); + auto up_buf = Tensor::buffer(dt_logits, {ntok, di}, rsrc.memory_pool); + + auto o_buf = Tensor::buffer(dt_logits, {ntok, nh * dh}, rsrc.memory_pool); + auto prob_buf = Tensor::buffer(dt_logits, {nreq, dvoc}, rsrc.memory_pool); + auto result_buf = Tensor::buffer(INFINI_DTYPE_I64, {nreq}, rsrc.memory_pool); + auto result_cpu = std::vector(nreq); + + // Prepare inputs + auto batch_pos_ids = std::vector(ntok); + size_t req_start = 0; + for (uint32_t req = 0; req < nreq; req++) { + for (uint32_t i = 0; i < req_lens[req]; i++) { + batch_pos_ids[req_start + i] = req_pos[req] + i; + } + req_start += req_lens[req]; + } + + std::shared_ptr pos_ids_buf; + if (rsrc.device == INFINI_DEVICE_CPU) { + pos_ids_buf = Tensor::weight(batch_pos_ids.data(), INFINI_DTYPE_U32, {ntok}); + } else { + pos_ids_buf = Tensor::buffer(INFINI_DTYPE_U32, {ntok}, rsrc.memory_pool); + RUN_INFINI(infinirtMemcpyAsync(pos_ids_buf->data(), batch_pos_ids.data(), sizeof(uint32_t) * ntok, + INFINIRT_MEMCPY_H2D, stream)); + } + for (uint32_t i = 0; i < ntok; i++) { + RUN_INFINI(infinirtMemcpyAsync(logits_in->data(i * d), + weight->w_in_embd->data(tokens[i] * d), + dsize(dt_logits) * d, INFINIRT_MEMCPY_D2D, stream)); + } + // Attention + // attention inner + size_t max_qk_size = 0; + size_t max_seq_len = 0; + + for (uint32_t req = 0; req < nreq; req++) { + auto past_len = req_pos[req]; + auto seq_len = req_lens[req]; + auto total_len = 
past_len + seq_len; + + max_qk_size = std::max(max_qk_size, size_t(seq_len * total_len)); + max_seq_len = std::max(max_seq_len, size_t(seq_len)); + } + + auto qk_buf = Tensor::buffer(dt_logits, {nh * max_qk_size}, rsrc.memory_pool); + auto rearrange_q_buf = Tensor::buffer(dt_logits, {nkvh, ngroup * max_seq_len, dh}, rsrc.memory_pool); + auto q_rearrange = rearrange_q_buf->view({nkvh, ngroup, max_seq_len, dh}); + auto attn_val_buf = Tensor::buffer(dt_logits, {nkvh, ngroup * max_seq_len, dh}, rsrc.memory_pool); + auto attn_val_gemm = attn_val_buf->view({nkvh, ngroup, max_seq_len, dh}); + + // Compute + for (uint32_t layer = 0; layer < nlayer; layer++) { + // 1. Attention + // rms norm + rmsnorm(logits_out, logits_in, weight->w_attn_norm[layer], meta->epsilon); + // qkv_proj + dequant_linear(q_buf, logits_out, + weight->w_attn_q[layer]->w, weight->w_attn_q[layer]->s, weight->w_attn_q[layer]->z, + 1.0, 0.0, nullptr, has_qkv_bias ? weight->b_attn_q[layer] : nullptr, + QuantType::GPTQ, weight->w_attn_q[layer]->g_idx); + dequant_linear(k_buf, logits_out, + weight->w_attn_k[layer]->w, weight->w_attn_k[layer]->s, weight->w_attn_k[layer]->z, + 1.0, 0.0, nullptr, has_qkv_bias ? weight->b_attn_k[layer] : nullptr, + QuantType::GPTQ, weight->w_attn_k[layer]->g_idx); + dequant_linear(v_buf, logits_out, + weight->w_attn_v[layer]->w, weight->w_attn_v[layer]->s, weight->w_attn_v[layer]->z, + 1.0, 0.0, nullptr, has_qkv_bias ? weight->b_attn_v[layer] : nullptr, + QuantType::GPTQ, weight->w_attn_v[layer]->g_idx); + // rope + rope_v2(q_buf->view({ntok, nh, dh}), q_buf->view({ntok, nh, dh}), pos_ids_buf, weight->sin_table, weight->cos_table); + rope_v2(k_buf->view({ntok, nkvh, dh}), k_buf->view({ntok, nkvh, dh}), pos_ids_buf, weight->sin_table, weight->cos_table); + size_t token_offset = 0; + for (uint32_t req = 0; req < nreq; req++) { + auto past_len = req_pos[req]; + auto seq_len = req_lens[req]; + auto total_len = past_len + seq_len; + auto o = o_buf->slice({{0, token_offset, seq_len}})->view({seq_len, nkvh, ngroup, dh})->permute({1, 2, 0, 3}); + auto q = q_buf->slice({{0, token_offset, seq_len}})->view({seq_len, nkvh, ngroup, dh})->permute({1, 2, 0, 3}); + auto k = k_buf->slice({{0, token_offset, seq_len}})->view({seq_len, nkvh, dh}); + auto v = v_buf->slice({{0, token_offset, seq_len}})->view({seq_len, nkvh, dh}); + + // self attention + // concat + rearrange(kv_caches[req]->k[idev][layer]->slice(0, past_len, seq_len), k); + rearrange(kv_caches[req]->v[idev][layer]->slice(0, past_len, seq_len), v); + // qk + rearrange(q_rearrange->slice(2, 0, seq_len), q); + auto qk_gemm = qk_buf->slice(0, 0, nh * seq_len * total_len)->view({nkvh, ngroup * seq_len, total_len}); + auto k_gemm = kv_caches[req]->k[idev][layer]->slice(0, 0, total_len)->permute({1, 2, 0}); + linear(qk_gemm, rearrange_q_buf->slice(1, 0, ngroup * seq_len), k_gemm, 1.f / float(sqrt(dh)), 0.f, nullptr, nullptr); + // softmax + auto qk_softmax = qk_gemm->view({nh, seq_len, total_len}); + causalSoftmax(qk_softmax, qk_softmax); + auto v_gemm = kv_caches[req]->v[idev][layer]->slice(0, 0, total_len)->permute({1, 0, 2}); + linear(attn_val_buf->slice(1, 0, ngroup * seq_len), qk_gemm, v_gemm, 1.f, 0.f, nullptr, nullptr); + // rearrange attn val + rearrange(o, attn_val_gemm->slice(2, 0, seq_len)); + + token_offset += seq_len; + } + // o_proj + dequant_linear(logits_in, o_buf, + weight->w_attn_out[layer]->w, weight->w_attn_out[layer]->s, weight->w_attn_out[layer]->z, + 1.0, 0.0, idev == 0 ? 
logits_in : nullptr, nullptr, + QuantType::GPTQ, weight->w_attn_out[layer]->g_idx); + // All_reduce if distributed + if (rsrc.comm != nullptr) { + RUN_INFINI(infinicclAllReduce( + logits_in->data(), logits_in->data(), ntok * d, dt_logits, + INFINICCL_SUM, rsrc.comm, stream)); + RUN_INFINI(infinirtStreamSynchronize(stream)); + } + // 2. FFN + rmsnorm(logits_out, logits_in, weight->w_ffn_norm[layer], meta->epsilon); + dequant_linear(gate_buf, logits_out, + weight->w_ffn_gate[layer]->w, weight->w_ffn_gate[layer]->s, weight->w_ffn_gate[layer]->z, + 1.0, 0.0, nullptr, nullptr, + QuantType::GPTQ, weight->w_ffn_gate[layer]->g_idx); + dequant_linear(up_buf, logits_out, + weight->w_ffn_up[layer]->w, weight->w_ffn_up[layer]->s, weight->w_ffn_up[layer]->z, + 1.0, 0.0, nullptr, nullptr, + QuantType::GPTQ, weight->w_ffn_up[layer]->g_idx); + swiglu(gate_buf, up_buf, gate_buf); + dequant_linear(logits_in, gate_buf, + weight->w_ffn_down[layer]->w, weight->w_ffn_down[layer]->s, weight->w_ffn_down[layer]->z, + 1.0, 0.0, idev == 0 ? logits_in : nullptr, nullptr, + QuantType::GPTQ, weight->w_ffn_down[layer]->g_idx); + // All_reduce if distributed + if (rsrc.comm != nullptr) { + RUN_INFINI(infinicclAllReduce( + logits_in->data(), logits_in->data(), ntok * d, dt_logits, + INFINICCL_SUM, rsrc.comm, stream)); + RUN_INFINI(infinirtStreamSynchronize(stream)); + } + } + // Sample and Output + if (idev == 0) { + if (last_logits != nullptr) { + rmsnorm(logits_out, logits_in, weight->w_out_norm, meta->epsilon); + auto last_logits_buf = Tensor::buffer(dt_logits, {ntok, dvoc}, rsrc.memory_pool); + linear(last_logits_buf, logits_out, weight->w_out_embd, 1.0, 0.0, nullptr, nullptr); + RUN_INFINI(infinirtStreamSynchronize(stream)); + RUN_INFINI(infinirtMemcpy(last_logits, last_logits_buf->data(), dsize(dt_logits) * ntok * dvoc, INFINIRT_MEMCPY_D2H)); + } + if (output != nullptr) { + size_t token_offset = 0; + for (uint32_t req = 0; req < nreq; req++) { + auto seq_len = req_lens[req]; + token_offset += seq_len; + rmsnorm(logits_out->slice(0, req, 1), + logits_in->slice(0, token_offset - 1, 1), + weight->w_out_norm, + meta->epsilon); + } + linear(prob_buf, logits_out->slice(0, 0, nreq), weight->w_out_embd, 1.0, 0.0, nullptr, nullptr); + std::random_device _rd; + std::mt19937 gen(_rd()); + token_offset = 0; + for (uint32_t req = 0; req < nreq; req++) { + auto seq_len = req_lens[req]; + float random_val = std::uniform_real_distribution(0, 1)(gen); + randomSample(result_buf->slice(0, req, 1)->view_as({}, {}), + prob_buf->slice(0, req, 1)->view_as({dvoc}, {1}), + random_val, topp[req], topk[req], temperature[req]); + token_offset += seq_len; + } + RUN_INFINI(infinirtStreamSynchronize(stream)); + RUN_INFINI(infinirtMemcpy(result_cpu.data(), result_buf->data(), + sizeof(int64_t) * nreq, INFINIRT_MEMCPY_D2H)); + for (uint32_t req = 0; req < nreq; req++) { + output[req] = uint32_t(result_cpu[req]); + } + } + } +} + +__C void +inferBatchJiugeGPTQ(struct JiugeGPTQModel *model, + const uint32_t *tokens, uint32_t ntok, + const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos, + struct KVCache **kv_caches, + const float *temperature, const uint32_t *topk, const float *topp, + uint32_t *output) { + model->req.tokens = tokens; + model->req.ntok = ntok; + model->req.req_lens = req_lens; + model->req.nreq = nreq; + model->req.req_pos = req_pos; + model->req.kv_caches = kv_caches; + model->req.output = output; + model->req.logits = nullptr; + model->req.temperature = temperature; + model->req.topk = topk; + model->req.topp = topp; + 
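+    // The request fields staged above are read by the per-device worker threads:
+    // wake each worker via cv_start, then wait on cv_done until every worker has
+    // cleared `proceed`, i.e. finished inferDeviceBatch for this batch.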
+ for (size_t idev = 0; idev < model->dev_ids.size(); idev++) { + std::unique_lock lock(model->states[idev].mtx); + model->states[idev].proceed = true; + lock.unlock(); + model->states[idev].cv_start.notify_one(); + } + for (size_t i = model->dev_ids.size(); i > 0; i--) { + auto idev = i - 1; + std::unique_lock lock(model->states[idev].mtx); + model->states[idev].cv_done.wait(lock, [&] { return !(model->states[idev].proceed); }); + lock.unlock(); + } +} + +__C void +forwardBatchJiugeGPTQ(struct JiugeGPTQModel *model, + const uint32_t *tokens, uint32_t ntok, + const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos, + struct KVCache **kv_caches, + void *logits) { + model->req.tokens = tokens; + model->req.ntok = ntok; + model->req.req_lens = req_lens; + model->req.nreq = nreq; + model->req.req_pos = req_pos; + model->req.kv_caches = kv_caches; + model->req.output = nullptr; + model->req.logits = logits; + model->req.temperature = nullptr; + model->req.topk = nullptr; + model->req.topp = nullptr; + + for (size_t idev = 0; idev < model->dev_ids.size(); idev++) { + std::unique_lock lock(model->states[idev].mtx); + model->states[idev].proceed = true; + lock.unlock(); + model->states[idev].cv_start.notify_one(); + } + for (size_t i = model->dev_ids.size(); i > 0; i--) { + auto idev = i - 1; + std::unique_lock lock(model->states[idev].mtx); + model->states[idev].cv_done.wait(lock, [&] { return !(model->states[idev].proceed); }); + lock.unlock(); + } +} + +void launchDevice(const JiugeGPTQMeta *meta, std::shared_ptr weights, GPTQDeviceResource *rsrc, InferState &state, InferRequest &req, + infiniDevice_t device, int idev, int ndev, int dev_id, infinicclComm_t comm) { + // Create Device Resource + createDeviceResource(rsrc, meta, weights, device, idev, ndev, dev_id, comm); + + CacheManager cache_manager(100); + InferenceContext ctx(rsrc->handle, rsrc->memory_pool, &cache_manager, rsrc->stream); + + // Set the inference context for this thread + setInferenceContext(&ctx); + + { + std::unique_lock lock(state.mtx); + state.loaded = true; + lock.unlock(); + state.cv_load.notify_one(); + } + + // Infer Loop + while (true) { + std::unique_lock lock(state.mtx); + state.cv_start.wait(lock, [&] { return state.proceed || state.exit_flag; }); + // quit if exit_flag is set + if (state.exit_flag) { + break; + } + + inferDeviceBatch(meta, *rsrc, idev, ndev, req.tokens, req.ntok, + req.req_lens, req.nreq, req.req_pos, req.kv_caches, + req.temperature, req.topk, req.topp, req.output, req.logits); + + state.proceed = false; + lock.unlock(); + state.cv_done.notify_one(); + } + + // Clean-Up + releaseDeviceResource(*rsrc); + setInferenceContext(nullptr); // Clear the context when done +} + +JiugeGPTQModel::JiugeGPTQModel(const JiugeGPTQMeta *meta, const ModelWeights *weights_) { + auto weights = (JiugeGPTQWeights *)(weights_); + device = weights->device(); + dev_ids = weights->devIds(); + int ndev = int(dev_ids.size()); + dev_resources = std::vector(ndev); + states = std::vector(ndev); + threads.resize(ndev); + + auto comms = std::vector(ndev, nullptr); + if (ndev > 1) { + RUN_INFINI(infinicclCommInitAll(device, comms.data(), ndev, dev_ids.data())); + } + + for (int i = 0; i < ndev; i++) { + threads[i] = std::thread(launchDevice, meta, weights->device_weights()[i], &dev_resources[i], std::ref(states[i]), std::ref(req), device, i, ndev, dev_ids[i], comms[i]); + } + for (int i = 0; i < ndev; i++) { + std::unique_lock lock(states[i].mtx); + states[i].cv_load.wait(lock, [&] { return states[i].loaded; }); + 
lock.unlock(); + } +} + +__C struct JiugeGPTQModel * +createJiugeGPTQModel(const JiugeGPTQMeta *meta, + const ModelWeights *weights) { + JiugeGPTQModel *model = new JiugeGPTQModel(meta, weights); + return model; +} + +__C void destroyJiugeGPTQModel(struct JiugeGPTQModel *model) { + auto ndev = model->dev_resources.size(); + + for (size_t idev = 0; idev < ndev; idev++) { + std::unique_lock lock(model->states[idev].mtx); + model->states[idev].exit_flag = true; + lock.unlock(); + model->states[idev].cv_start.notify_one(); + } + + for (size_t idev = 0; idev < ndev; idev++) { + model->threads[idev].join(); + } + + delete model; +} diff --git a/src/models/jiuge_gptq/jiuge_gptq.hpp b/src/models/jiuge_gptq/jiuge_gptq.hpp new file mode 100644 index 00000000..77dd9ead --- /dev/null +++ b/src/models/jiuge_gptq/jiuge_gptq.hpp @@ -0,0 +1,82 @@ +#pragma once +#include "infinicore_infer/models/jiuge_gptq.h" + +#include "../../cache.hpp" +#include "../../dataloader/weights_loader.hpp" + +#include +#include +#include + +struct QuantInt4Weight { + std::shared_ptr w, s, z, g_idx; // add g_idx +}; + +struct JiugeGPTQDeviceWeight { + std::shared_ptr w_in_embd, w_out_norm, w_out_embd, sin_table, + cos_table; + std::vector> w_attn_norm, b_attn_q, b_attn_k, b_attn_v, w_ffn_norm; + std::vector> w_attn_q, w_attn_k, w_attn_v, w_attn_out, w_ffn_gate, w_ffn_up, w_ffn_down; +}; + +class JiugeGPTQWeights : public infinicore::weights::Loader { +private: + std::vector> _device_weights; + +public: + JiugeGPTQWeights(const JiugeGPTQMeta *meta, + infiniDevice_t device, + const std::vector &dev_ids); + std::vector> &device_weights() { + return _device_weights; + } +}; + +struct GPTQDeviceResource { + // Device + infiniDevice_t device; + int device_id; + infiniopHandle_t handle; + // Weights + std::shared_ptr weights; + // Streams + infinirtStream_t stream; + // Communicator + infinicclComm_t comm; + + std::shared_ptr memory_pool; +}; + +struct InferRequest { + const uint32_t *tokens; + uint32_t ntok; + const uint32_t *req_lens; + uint32_t nreq; + const uint32_t *req_pos; + struct KVCache **kv_caches; + const float *temperature; + const uint32_t *topk; + const float *topp; + uint32_t *output; + void *logits; +}; + +struct InferState { + std::mutex mtx; + std::condition_variable cv_load, cv_start, cv_done; + bool loaded = false; + bool proceed = false; + bool exit_flag = false; +}; + +struct JiugeGPTQModel { + JiugeGPTQMeta meta; + infiniDevice_t device; + std::vector dev_ids; + std::vector dev_resources; + std::vector states; + std::vector threads; + InferRequest req; + + JiugeGPTQModel(const JiugeGPTQMeta *, const ModelWeights *); +}; \ No newline at end of file diff --git a/src/models/jiuge_gptq/jiuge_gptq_weight.cpp b/src/models/jiuge_gptq/jiuge_gptq_weight.cpp new file mode 100644 index 00000000..84abd23e --- /dev/null +++ b/src/models/jiuge_gptq/jiuge_gptq_weight.cpp @@ -0,0 +1,137 @@ +#include "jiuge_gptq.hpp" + +#include + +inline std::shared_ptr getSinTable(size_t dctx, size_t dh, float theta) { + auto half_dh = dh / 2; + auto unit = dsize(INFINI_DTYPE_F16); + void *table = std::malloc(dctx * half_dh * unit); + + for (size_t i = 0; i < dctx; i++) { + for (size_t j = 0; j < half_dh; j++) { + float _sin = std::sin( + static_cast(i) / std::pow(theta, static_cast(j) / half_dh)); + + ((uint16_t *)table)[i * half_dh + j] = f32_to_f16(_sin); + } + } + auto shape = std::vector({dctx, half_dh}); + auto tensor = Tensor::weight(table, INFINI_DTYPE_F16, shape); + std::free(table); + return tensor; +} + +inline std::shared_ptr 
getCosTable(size_t dctx, size_t dh, float theta) { + auto half_dh = dh / 2; + auto unit = dsize(INFINI_DTYPE_F16); + void *table = std::malloc(dctx * half_dh * unit); + + for (size_t i = 0; i < dctx; i++) { + for (size_t j = 0; j < half_dh; j++) { + float _cos = std::cos( + static_cast(i) / std::pow(theta, static_cast(j) / half_dh)); + + ((uint16_t *)table)[i * half_dh + j] = f32_to_f16(_cos); + } + } + auto shape = std::vector({dctx, half_dh}); + auto tensor = Tensor::weight(table, INFINI_DTYPE_F16, shape); + std::free(table); + return tensor; +} + +JiugeGPTQWeights::JiugeGPTQWeights( + const JiugeGPTQMeta *meta, + infiniDevice_t device, + const std::vector &dev_ids) : infinicore::weights::Loader(device, dev_ids) { + auto ndev = dev_ids.size(); + _device_weights.resize(ndev); + infiniDtype_t dt_logits = meta->dt_logits; + infiniDtype_t dt_norm_w = meta->dt_norm_w; + size_t nlayer = meta->nlayer; + size_t d = meta->d; + size_t nh = meta->nh / ndev; + size_t nkvh = meta->nkvh / ndev; + size_t dh = meta->dh; + size_t di = meta->di / ndev; + size_t dctx = meta->dctx; + size_t dvoc = meta->dvoc; + size_t nbit = meta->nbit; + size_t quant_group_size = meta->quant_group_size; + + for (size_t i = 0; i < ndev; i++) { + RUN_INFINI(infinirtSetDevice(device, dev_ids[i])); + + auto weight = std::make_shared(); + _device_weights[i] = weight; + + auto w_in_embd = Tensor::weight(nullptr, dt_logits, {dvoc, d}); + this->register_weight("model.embed_tokens.weight", w_in_embd, i); + weight->w_in_embd = w_in_embd; + + auto w_out_norm = Tensor::weight(nullptr, dt_norm_w, {d}); + this->register_weight("model.norm.weight", w_out_norm, i); + weight->w_out_norm = w_out_norm; + + auto w_out_embd = Tensor::weight(nullptr, dt_logits, {dvoc, d})->permute({1, 0}); + this->register_weight("lm_head.weight", w_out_embd, i); + weight->w_out_embd = w_out_embd; + + weight->sin_table = getSinTable(dctx, dh, meta->theta); + weight->cos_table = getCosTable(dctx, dh, meta->theta); + + for (size_t layer = 0; layer < nlayer; layer++) { + +#define RIGISTER_LAYER_WEIGHT(W_NAME, W_VAR, W_SHAPE, W_DTYPE, W_DIST_TYPE) \ + auto W_VAR = Tensor::weight(nullptr, W_DTYPE, W_SHAPE); \ + this->register_weight(W_NAME, W_VAR, i, infinicore::weights::DistributionType::W_DIST_TYPE); \ + weight->W_VAR.push_back(W_VAR); + + RIGISTER_LAYER_WEIGHT("model.layers." + std::to_string(layer) + ".input_layernorm.weight", w_attn_norm, {d}, dt_norm_w, FULL); + +#define REGISTER_LAYER_QUANT_WEIGHT(W_NAME, W_VAR, W_IN, W_OUT, W_DIST_TYPE) \ + auto W_VAR = std::make_shared(); \ + /* GPTQ layout: qweight[in_packed=W_IN/8, out_features=W_OUT]; zeros/scales grouped by input */ \ + W_VAR->w = Tensor::weight(nullptr, INFINI_DTYPE_I32, {(W_IN)*nbit / 32, (W_OUT)}); \ + this->register_weight(W_NAME + ".qweight", W_VAR->w, i, infinicore::weights::DistributionType::W_DIST_TYPE); \ + W_VAR->s = Tensor::weight(nullptr, INFINI_DTYPE_F16, {(W_IN) / quant_group_size, (W_OUT)}); \ + this->register_weight(W_NAME + ".scales", W_VAR->s, i, infinicore::weights::DistributionType::W_DIST_TYPE); \ + W_VAR->z = Tensor::weight(nullptr, INFINI_DTYPE_I32, {(W_IN) / quant_group_size, (W_OUT)*nbit / 32}); \ + this->register_weight(W_NAME + ".qzeros", W_VAR->z, i, infinicore::weights::DistributionType::W_DIST_TYPE); \ + W_VAR->g_idx = Tensor::weight(nullptr, INFINI_DTYPE_I32, {(W_IN)}); \ + auto W_VAR##_gidx_dist = (infinicore::weights::DistributionType::W_DIST_TYPE == infinicore::weights::DistributionType::ROW) ? 
infinicore::weights::DistributionType::ROW : infinicore::weights::DistributionType::FULL; \ + this->register_weight(W_NAME + ".g_idx", W_VAR->g_idx, i, W_VAR##_gidx_dist); \ + weight->W_VAR.push_back(W_VAR); + + REGISTER_LAYER_QUANT_WEIGHT("model.layers." + std::to_string(layer) + ".self_attn.q_proj", w_attn_q, d, nh * dh, COLUMN); + REGISTER_LAYER_QUANT_WEIGHT("model.layers." + std::to_string(layer) + ".self_attn.k_proj", w_attn_k, d, nkvh * dh, COLUMN); + REGISTER_LAYER_QUANT_WEIGHT("model.layers." + std::to_string(layer) + ".self_attn.v_proj", w_attn_v, d, nkvh * dh, COLUMN); + RIGISTER_LAYER_WEIGHT("model.layers." + std::to_string(layer) + ".self_attn.q_proj.bias", b_attn_q, {nh * dh}, INFINI_DTYPE_F16, COLUMN); + RIGISTER_LAYER_WEIGHT("model.layers." + std::to_string(layer) + ".self_attn.k_proj.bias", b_attn_k, {nkvh * dh}, INFINI_DTYPE_F16, COLUMN); + RIGISTER_LAYER_WEIGHT("model.layers." + std::to_string(layer) + ".self_attn.v_proj.bias", b_attn_v, {nkvh * dh}, INFINI_DTYPE_F16, COLUMN); + REGISTER_LAYER_QUANT_WEIGHT("model.layers." + std::to_string(layer) + ".self_attn.o_proj", w_attn_out, nh * dh, d, ROW); + + RIGISTER_LAYER_WEIGHT("model.layers." + std::to_string(layer) + ".post_attention_layernorm.weight", w_ffn_norm, {d}, dt_norm_w, FULL); + REGISTER_LAYER_QUANT_WEIGHT("model.layers." + std::to_string(layer) + ".mlp.gate_proj", w_ffn_gate, d, di, COLUMN); + REGISTER_LAYER_QUANT_WEIGHT("model.layers." + std::to_string(layer) + ".mlp.up_proj", w_ffn_up, d, di, COLUMN); + REGISTER_LAYER_QUANT_WEIGHT("model.layers." + std::to_string(layer) + ".mlp.down_proj", w_ffn_down, di, d, ROW); + } + } + + #undef RIGISTER_LAYER_WEIGHT +#undef REGISTER_LAYER_QUANT_WEIGHT +} + +__C struct ModelWeights * +createJiugeGPTQWeights(const JiugeGPTQMeta *meta, + infiniDevice_t device, + int ndev, + const int *dev_ids) { + JiugeGPTQWeights *weights = new JiugeGPTQWeights(meta, device, std::vector(dev_ids, dev_ids + ndev)); + return (struct ModelWeights *)weights; +} + +// 建立外部别名, 防止python端冲突 +__C void JiugeGPTQLoadWeight(struct ModelWeights *weights, const char *name, void *data) { + loadModelWeight(weights, name, data); +}
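
Note on the config fields consumed above: JiugeGPTQMetaFromConfig only reads "bits" and "group_size" from quantization_config, and dt_linear_w is fixed to INFINI_DTYPE_I32 because GPTQ checkpoints store packed int32 tensors. For reference, a typical Qwen2.5 GPTQ-Int4 export carries a quantization_config along the following lines; the exact values are an illustrative assumption, not copied from any specific checkpoint:

quantization_config = {
    "quant_method": "gptq",
    "bits": 4,           # -> meta.nbit
    "group_size": 128,   # -> meta.quant_group_size
    "desc_act": False,   # when False, some exports ship a dummy g_idx (handled in load_all_safetensors_from_dir)
    "sym": True,
}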
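
On the tensor layout registered in jiuge_gptq_weight.cpp: qweight packs eight 4-bit values per int32 along the input dimension ([in*nbit/32, out]), qzeros packs along the output dimension ([in/group_size, out*nbit/32]), scales is fp16 [in/group_size, out], and g_idx maps each input channel to its quantization group. The actual dequantization is performed by infiniop's DequantizeGPTQ operator; the slow PyTorch sketch below only illustrates the conventional AutoGPTQ dequant that these shapes imply (including the historical "+1" zero-point offset, which some newer exporters omit). It is an assumption for readers, not the kernel's implementation:

import torch

def dequant_gptq_reference(qweight, qzeros, scales, g_idx, bits=4):
    """Slow reference dequant for AutoGPTQ-style packed tensors (illustrative only)."""
    pack = 32 // bits                     # 8 nibbles per int32 for 4-bit
    mask = (1 << bits) - 1
    shifts = torch.arange(pack, dtype=torch.int32) * bits

    # qweight [in*bits/32, out] -> unpack along the input dim -> [in, out]
    w = (qweight.unsqueeze(1) >> shifts.view(1, pack, 1)) & mask
    w = w.reshape(-1, qweight.shape[1])

    # qzeros [in/group, out*bits/32] -> unpack along the output dim -> [groups, out]
    z = (qzeros.unsqueeze(-1) >> shifts.view(1, 1, pack)) & mask
    z = z.reshape(qzeros.shape[0], -1) + 1   # "+1" offset used by classic AutoGPTQ exports

    g = g_idx.long()                                    # group index per input channel
    return (w - z[g]).to(scales.dtype) * scales[g]      # fp16 [in, out]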
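
For a quick end-to-end check, scripts/jiuge_gptq.py follows jiuge.py and jiuge_awq.py: it can be run from the command line as `python jiuge_gptq.py --nvidia <model_dir> [n_device]`, or driven programmatically from within scripts/. A minimal sketch, assuming a local Qwen2.5 GPTQ-Int4 checkpoint directory (the path below is hypothetical):

from libinfinicore_infer import DeviceType
from jiuge_gptq import JiugeGPTQForCausalLM

# Hypothetical checkpoint path; any AutoGPTQ-style 4-bit Qwen2.5 export should work.
model = JiugeGPTQForCausalLM(
    "/models/Qwen2.5-7B-Instruct-GPTQ-Int4",
    device=DeviceType.DEVICE_TYPE_NVIDIA,
    ndev=1,
)
output, avg_ms = model.generate("山东最高的山是?", 500)
model.destroy_model_instance()

The server path is the same: launch_server.py now accepts --gptq and constructs JiugeGPTQForCausalLM instead of the fp16 or AWQ variants.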