18 changes: 15 additions & 3 deletions include/infinicore_infer/models/jiuge.h
@@ -12,9 +12,14 @@ struct JiugeModel;
typedef struct
{
infiniDtype_t dt_logits;
size_t nlayer, d, nh, nkvh, dh, di, dctx, dvoc, kvcache_block_size;
size_t nlayer, d, nh, nkvh, dh, di, dctx, dvoc, kvcache_block_size, dim_model_base;
float epsilon, theta;
uint32_t end_token;
// Longrope support
uint32_t rope_type; // 0 = standard, 1 = longrope
size_t original_max_position_embeddings;
const float *short_factor; // Array of dh/2 floats, nullptr if not longrope
const float *long_factor; // Array of dh/2 floats, nullptr if not longrope
} JiugeMeta;

typedef struct
@@ -101,15 +106,19 @@ __C __export struct KVCache *createPagedKVCache(
/// @param temperature Sampling temperature (0. means greedy sampling)
/// @param topk Sampling top-k (1 means greedy sampling)
/// @param topp Sampling top-p
/// @param is_prefill Whether to run the prefill path; 0 means decode, 1 means prefill
/// @param enable_paged_attn Whether to enable paged attention
/// @param repetition_penalty Repetition penalty factor (1.0 means no penalty)
/// @param previous_tokens_per_req Per-request pointers to arrays of unique token IDs (vLLM-style, for efficient repetition penalty)
/// @param previous_tokens_len_per_req Number of unique tokens for each request
/// @param output Output token array, one token per request, length at least nreq
__C __export void
inferBatchJiuge(struct JiugeModel *,
const uint32_t *tokens, uint32_t ntok,
const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos,
struct KVCache **kv_caches,
const float *temperature, const uint32_t *topk, const float *topp,
const float *repetition_penalty,
const uint32_t *const *previous_tokens_per_req,
const uint32_t *previous_tokens_len_per_req,
uint32_t *output);

__C __export void
@@ -120,6 +129,9 @@ inferBatch(struct JiugeModel *,
const int32_t *block_tables,
const int32_t *slot_mapping,
const float *temperature, const uint32_t *topk, const float *topp,
const float *repetition_penalty,
const uint32_t *const *previous_tokens_per_req,
const uint32_t *previous_tokens_len_per_req,
const uint32_t is_prefill, const bool enable_paged_attn,
uint32_t *output);

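The four new meta fields describe Phi-3-style longrope scaling. Below is a minimal sketch of how a backend could consume them, assuming the usual longrope recipe (per-dimension rescaling of the dh/2 inverse frequencies, plus a magnitude correction past the original window); the function names are illustrative, not this repo's kernels.

import math

def longrope_inv_freq(dh, theta, seq_len, original_max_pos, short_factor, long_factor):
    # Past the original window, switch from short_factor to long_factor.
    factors = long_factor if seq_len > original_max_pos else short_factor
    assert len(factors) == dh // 2  # matches the header's documented array length
    return [1.0 / (factors[i] * theta ** (2.0 * i / dh)) for i in range(dh // 2)]

def longrope_attention_scale(dctx, original_max_pos):
    # Phi-3-style magnitude correction applied to cos/sin when extended.
    ratio = dctx / original_max_pos
    if ratio <= 1.0:
        return 1.0
    return math.sqrt(1.0 + math.log(ratio) / math.log(original_max_pos))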
5 changes: 5 additions & 0 deletions python/icinfer/engine/libinfinicore_infer.py
@@ -48,9 +48,14 @@ class JiugeMetaCStruct(ctypes.Structure):
("dctx", c_size_t),
("dvoc", c_size_t),
("kvcache_block_size", c_size_t),
("dim_model_base", c_size_t),
("epsilon", c_float),
("theta", c_float),
("end_token", c_uint),
("rope_type", c_uint),
("original_max_position_embeddings", c_size_t),
("short_factor", POINTER(c_float)),
("long_factor", POINTER(c_float)),
]


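The ctypes mirror must keep field order and types in lockstep with the C JiugeMeta above. A hedged sketch of filling the new fields with toy values (not a real model configuration):

import ctypes
from ctypes import POINTER, c_float

half_dh = 4  # dh // 2 for a toy head dim of 8; real models pass dh // 2 floats
short = (c_float * half_dh)(1.0, 1.0, 1.2, 1.5)  # keep a Python reference alive
long_ = (c_float * half_dh)(2.0, 2.5, 3.0, 4.0)

meta = JiugeMetaCStruct()
meta.rope_type = 1  # 0 = standard, 1 = longrope
meta.original_max_position_embeddings = 4096
meta.short_factor = ctypes.cast(short, POINTER(c_float))
meta.long_factor = ctypes.cast(long_, POINTER(c_float))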
49 changes: 47 additions & 2 deletions python/icinfer/models/jiuge.py
@@ -3,6 +3,7 @@
from sympy import true

from ctypes import POINTER, c_float, c_int, c_uint, c_void_p, byref
import ctypes
import os
from pathlib import Path
import safetensors
@@ -122,6 +123,45 @@ def __init__(self, config, dtype=torch.float16, max_tokens=None):
config["num_hidden_layers"]
)

dim_model_base = (
config["dim_model_base"] if "dim_model_base" in config else config["hidden_size"]
)

# Load longrope configuration
rope_type = 0 # 0 = standard, 1 = longrope
original_max_position_embeddings = 0
short_factor_ptr = None
long_factor_ptr = None
self._short_factor_array = None # Keep reference to prevent GC
self._long_factor_array = None # Keep reference to prevent GC

rope_scaling = config.get("rope_scaling", {})
if isinstance(rope_scaling, dict):
rope_scaling_type = rope_scaling.get("rope_type") or rope_scaling.get("type", "")
if rope_scaling_type == "longrope":
rope_type = 1
original_max_position_embeddings = rope_scaling.get(
"original_max_position_embeddings",
config.get("original_max_position_embeddings", 0)
)

short_factor_list = rope_scaling.get("short_factor", [])
long_factor_list = rope_scaling.get("long_factor", [])

if short_factor_list and long_factor_list:
# Convert to ctypes arrays
half_dh = (config["hidden_size"] // config["num_attention_heads"]) // 2
if len(short_factor_list) == half_dh and len(long_factor_list) == half_dh:
self._short_factor_array = (c_float * half_dh)(*short_factor_list)
self._long_factor_array = (c_float * half_dh)(*long_factor_list)
short_factor_ptr = ctypes.cast(self._short_factor_array, POINTER(c_float))
long_factor_ptr = ctypes.cast(self._long_factor_array, POINTER(c_float))
else:
logger.warning(
f"Longrope factor arrays have wrong length: "
f"short={len(short_factor_list)}, long={len(long_factor_list)}, expected={half_dh}"
)

super().__init__(
dt_logits=dt_,
nlayer=config["num_hidden_layers"],
@@ -138,10 +178,15 @@ def __init__(self, config, dtype=torch.float16, max_tokens=None):
config["max_position_embeddings"] if max_tokens is None else max_tokens
),
dvoc=config["vocab_size"],
block_size=config["block_size"],
kvcache_block_size=config["block_size"],
dim_model_base=dim_model_base,
epsilon=config["rms_norm_eps"],
theta=(config["rope_theta"] if "rope_theta" in config else 100000.0),
end_token=2,
rope_type=rope_type,
original_max_position_embeddings=original_max_position_embeddings,
short_factor=short_factor_ptr,
long_factor=long_factor_ptr,
)
self.torch_dtype_logits = dtype

@@ -206,7 +251,7 @@ def __init__(
)
self.input_embd = self.input_embd_tensor.data_ptr()
self.output_norm_tensor = (
state_dict[naming.output_norm()].to(torch_dt_norm) * scale_output
state_dict[naming.output_norm()].to(torch_dt_norm)
)
self.output_norm = self.output_norm_tensor.data_ptr()
self.output_embd_tensor = state_dict[output_embd_naming].to(torch_dt_mat)
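For reference, a hypothetical config.json excerpt (shown as a Python dict) that this parsing path would accept; the values are illustrative, not from a real checkpoint:

config = {
    "hidden_size": 3072,
    "num_attention_heads": 24,  # dh = 128, so dh // 2 = 64 factors expected
    "max_position_embeddings": 131072,
    "rope_scaling": {
        "rope_type": "longrope",  # the older "type" key is also accepted
        "original_max_position_embeddings": 4096,
        "short_factor": [1.0] * 64,  # must be exactly dh // 2 entries,
        "long_factor": [4.0] * 64,   # otherwise the loader logs a warning
    },
}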
56 changes: 55 additions & 1 deletion python/icinfer/utils/jiuge_weights_loader.py
@@ -7,6 +7,7 @@
from typing import Tuple
import math
from ctypes import POINTER, c_float, c_int, c_uint, c_void_p, byref
import ctypes
import os
from pathlib import Path
import safetensors
@@ -119,6 +120,54 @@ def __init__(self, config, dtype=torch.float16, max_tokens=None):
self.scale_o = config.scale_depth / math.sqrt(config.num_hidden_layers)
self.scale_down = config.scale_depth / math.sqrt(config.num_hidden_layers)

dim_model_base = (
config.dim_model_base if hasattr(config, "dim_model_base") else config.hidden_size
)

# Load longrope configuration
rope_type = 0 # 0 = standard, 1 = longrope
original_max_position_embeddings = 0
short_factor_ptr = None
long_factor_ptr = None
self._short_factor_array = None # Keep reference to prevent GC
self._long_factor_array = None # Keep reference to prevent GC

# Handle both dict and object config
if hasattr(config, "rope_scaling"):
rope_scaling = config.rope_scaling
elif isinstance(config, dict) and "rope_scaling" in config:
rope_scaling = config["rope_scaling"]
else:
rope_scaling = {}

if isinstance(rope_scaling, dict):
rope_scaling_type = rope_scaling.get("rope_type") or rope_scaling.get("type", "")
if rope_scaling_type == "longrope":
rope_type = 1
original_max_position_embeddings = rope_scaling.get(
"original_max_position_embeddings",
getattr(config, "original_max_position_embeddings", 0) if not isinstance(config, dict) else config.get("original_max_position_embeddings", 0)
)

short_factor_list = rope_scaling.get("short_factor", [])
long_factor_list = rope_scaling.get("long_factor", [])

if short_factor_list and long_factor_list:
# Convert to ctypes arrays
half_dh = (config.hidden_size // config.num_attention_heads) // 2
if len(short_factor_list) == half_dh and len(long_factor_list) == half_dh:
self._short_factor_array = (c_float * half_dh)(*short_factor_list)
self._long_factor_array = (c_float * half_dh)(*long_factor_list)
short_factor_ptr = ctypes.cast(self._short_factor_array, POINTER(c_float))
long_factor_ptr = ctypes.cast(self._long_factor_array, POINTER(c_float))
else:
import logging
logger = logging.getLogger(__name__)
logger.warning(
f"Longrope factor arrays have wrong length: "
f"short={len(short_factor_list)}, long={len(long_factor_list)}, expected={half_dh}"
)

super().__init__(
dt_logits=dt_,
nlayer=config.num_hidden_layers,
@@ -134,9 +183,14 @@ def __init__(self, config, dtype=torch.float16, max_tokens=None):
dctx=(config.max_position_embeddings if max_tokens is None else max_tokens),
dvoc=config.vocab_size,
kvcache_block_size=config.kvcache_block_size,
dim_model_base=dim_model_base,
epsilon=config.rms_norm_eps,
theta=(config.rope_theta if hasattr(config, "rope_theta") else 100000.0),
end_token=2,
rope_type=rope_type,
original_max_position_embeddings=original_max_position_embeddings,
short_factor=short_factor_ptr,
long_factor=long_factor_ptr,
)
self.torch_dtype_logits = dtype

@@ -201,7 +255,7 @@ def __init__(
)
self.input_embd = self.input_embd_tensor.data_ptr()
self.output_norm_tensor = (
state_dict[naming.output_norm()].to(torch_dt_norm) * scale_output
state_dict[naming.output_norm()].to(torch_dt_norm)
)
self.output_norm = self.output_norm_tensor.data_ptr()
self.output_embd_tensor = state_dict[output_embd_naming].to(torch_dt_mat)
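Both loaders stash the factor arrays on self before casting, because a bare POINTER(c_float) does not own its buffer. A minimal sketch of the failure mode those "prevent GC" comments guard against (the helper name is hypothetical):

import ctypes
from ctypes import POINTER, c_float

def make_factor_ptr(values):
    buf = (c_float * len(values))(*values)
    # The caller must hold on to `buf`: if it is garbage-collected,
    # `ptr` dangles and the C side reads freed memory.
    return ctypes.cast(buf, POINTER(c_float)), buf

ptr, keepalive = make_factor_ptr([1.0, 1.2, 1.5, 2.0])
# `keepalive` must outlive every C call that dereferences `ptr`.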
36 changes: 34 additions & 2 deletions scripts/infer_task.py
@@ -1,20 +1,34 @@
class InferTask:
def __init__(self, id, tokens, max_tokens, temperature, topk, topp, end_tokens):
def __init__(self, id, tokens, max_tokens, temperature, topk, topp, end_tokens, repetition_penalty=1.0):
self.id = id
self.finish_reason = None
self.tokens = tokens
self.max_tokens = max_tokens
self.temperature = temperature
self.topk = topk
self.topp = topp
self.repetition_penalty = repetition_penalty
self.end_tokens = end_tokens
self._kv_cache = None
self.pos = 0

# vLLM-style unique token tracking for efficient repetition penalty
# Track unique token IDs that have been generated (not the full sequence)
# Initialize with prompt tokens so they are also penalized
self._unique_generated_tokens = set(tokens) # Initialize with prompt tokens!
self._unique_tokens_array = sorted(self._unique_generated_tokens) # Pre-sort for efficiency
self._unique_tokens_dirty = False # Already initialized, no need to rebuild

def bind_kvcache(self, kv_cache, pos=0):
self._kv_cache = kv_cache
self.pos = pos
self.tokens = self.tokens[pos:]
# Update tokens and add any new tokens to unique set
remaining_tokens = self.tokens[pos:]
for token in remaining_tokens:
if token not in self._unique_generated_tokens:
self._unique_generated_tokens.add(token)
self._unique_tokens_dirty = True
self.tokens = remaining_tokens

def release_kvcache(self):
cache = self._kv_cache
@@ -34,6 +48,24 @@ def next(self, out_token):
self.finish_reason = "length"
else:
self.tokens = [out_token]
# Incrementally update unique token set (vLLM-style)
# Only add if it's a new token (O(1) average)
if out_token not in self._unique_generated_tokens:
self._unique_generated_tokens.add(out_token)
self._unique_tokens_dirty = True

def get_unique_previous_tokens(self):
"""
Returns the unique token IDs seen so far (prompt and generated), as a sorted array.
This is the vLLM-style "seen tokens" list for efficient repetition penalty.

Returns:
tuple: (array, length) where array is sorted list of unique token IDs
"""
if self._unique_tokens_dirty:
self._unique_tokens_array = sorted(self._unique_generated_tokens)
self._unique_tokens_dirty = False
return self._unique_tokens_array, len(self._unique_tokens_array)


class KVCache:
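The unique-token lists produced here feed the new previous_tokens_per_req arguments of inferBatchJiuge. A hedged sketch of the vLLM-style penalty the sampler presumably applies with them (the real penalty lives on the C++ side; this Python version is illustrative only):

def apply_repetition_penalty(logits, unique_prev_tokens, penalty):
    # Divide positive logits and multiply negative ones, so seen tokens
    # always become less likely; penalty == 1.0 is a no-op.
    if penalty == 1.0:
        return logits
    for tok in unique_prev_tokens:
        logits[tok] = logits[tok] / penalty if logits[tok] > 0 else logits[tok] * penalty
    return logits

task = InferTask(0, [5, 3, 5], 16, 0.8, 50, 0.9, end_tokens=[2], repetition_penalty=1.2)
prev, n = task.get_unique_previous_tokens()  # -> ([3, 5], 2): prompt tokens count too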