From 92545b5054ed1d07f6fbad6294df6bfff7db0239 Mon Sep 17 00:00:00 2001 From: Ceng23333 <441651826@qq.com> Date: Thu, 25 Dec 2025 09:42:10 +0800 Subject: [PATCH 01/12] =?UTF-8?q?=E6=9C=8D=E5=8A=A1=E7=AB=AF=E6=94=AF?= =?UTF-8?q?=E6=8C=81repetition=5Fpenalty?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Ceng23333 <441651826@qq.com> --- include/infinicore_infer/models/jiuge.h | 11 +- scripts/infer_task.py | 36 ++++++- scripts/jiuge.py | 68 +++++++++++- scripts/launch_server.py | 138 +++++++++++++++++++----- scripts/libinfinicore_infer/jiuge.py | 9 ++ src/models/inference_context.cpp | 22 ++-- src/models/inference_context.hpp | 17 +-- src/models/jiuge/jiuge.cpp | 75 ++++++++++--- src/models/jiuge/jiuge_impl.hpp | 3 + 9 files changed, 309 insertions(+), 70 deletions(-) diff --git a/include/infinicore_infer/models/jiuge.h b/include/infinicore_infer/models/jiuge.h index 8b09e1ac..26a970e3 100644 --- a/include/infinicore_infer/models/jiuge.h +++ b/include/infinicore_infer/models/jiuge.h @@ -101,8 +101,9 @@ __C __export struct KVCache *createPagedKVCache( /// @param temperature 采样温度(0. 表示贪心采样) /// @param topk 采样 topk(1 表示贪心采样) /// @param topp 采样 topp -/// @param is_prefill 是否按 prefill 流程处理,0 表示 decode,1 表示 prefill -/// @param enable_paged_attn 是否启用 paged attention +/// @param repetition_penalty 重复惩罚系数(1.0 表示无惩罚) +/// @param previous_tokens_per_req 每个请求的唯一 token ID 数组指针(vLLM-style,用于高效重复惩罚) +/// @param previous_tokens_len_per_req 每个请求的唯一 token 数量 /// @param output 输出 token 数组,每个请求一个输出,长度至少为nreq __C __export void inferBatchJiuge(struct JiugeModel *, @@ -110,6 +111,9 @@ inferBatchJiuge(struct JiugeModel *, const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos, struct KVCache **kv_caches, const float *temperature, const uint32_t *topk, const float *topp, + const float *repetition_penalty, + const uint32_t *const *previous_tokens_per_req, + const uint32_t *previous_tokens_len_per_req, uint32_t *output); __C __export void @@ -120,6 +124,9 @@ inferBatch(struct JiugeModel *, const int32_t *block_tables, const int32_t *slot_mapping, const float *temperature, const uint32_t *topk, const float *topp, + const float *repetition_penalty, + const uint32_t *const *previous_tokens_per_req, + const uint32_t *previous_tokens_len_per_req, const uint32_t is_prefill, const bool enable_paged_attn, uint32_t *output); diff --git a/scripts/infer_task.py b/scripts/infer_task.py index 0d1231b7..eb0137d5 100644 --- a/scripts/infer_task.py +++ b/scripts/infer_task.py @@ -1,5 +1,5 @@ class InferTask: - def __init__(self, id, tokens, max_tokens, temperature, topk, topp, end_tokens): + def __init__(self, id, tokens, max_tokens, temperature, topk, topp, end_tokens, repetition_penalty=1.0): self.id = id self.finish_reason = None self.tokens = tokens @@ -7,14 +7,28 @@ def __init__(self, id, tokens, max_tokens, temperature, topk, topp, end_tokens): self.temperature = temperature self.topk = topk self.topp = topp + self.repetition_penalty = repetition_penalty self.end_tokens = end_tokens self._kv_cache = None self.pos = 0 + # vLLM-style unique token tracking for efficient repetition penalty + # Track unique token IDs that have been generated (not the full sequence) + # Initialize with prompt tokens so they are also penalized + self._unique_generated_tokens = set(tokens) # Initialize with prompt tokens! 
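For reference, the rule that repetition_penalty and previous_tokens_per_req feed into is the usual HF/vLLM-style one: logits of already seen token ids are divided by the penalty when positive and multiplied by it when negative. A minimal NumPy sketch of that rule (illustrative only; the patch forwards the penalty and token list into the sampling call, so the actual application happens there):

import numpy as np

def apply_repetition_penalty(logits: np.ndarray, previous_tokens, penalty: float) -> np.ndarray:
    # Reference formula only; not the library's implementation.
    if penalty == 1.0 or not previous_tokens:
        return logits                      # 1.0 means "no penalty", matching the header docs
    out = logits.copy()
    ids = np.fromiter(previous_tokens, dtype=np.int64)
    seen = out[ids]
    # Positive logits are divided by the penalty, negative ones multiplied,
    # so every previously seen token becomes less likely either way.
    out[ids] = np.where(seen > 0, seen / penalty, seen * penalty)
    return out

Tracking only unique ids, as the _unique_generated_tokens set does, keeps the per-step cost proportional to the number of distinct tokens seen rather than to the full sequence length.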
+ self._unique_tokens_array = sorted(self._unique_generated_tokens) # Pre-sort for efficiency + self._unique_tokens_dirty = False # Already initialized, no need to rebuild + def bind_kvcache(self, kv_cache, pos=0): self._kv_cache = kv_cache self.pos = pos - self.tokens = self.tokens[pos:] + # Update tokens and add any new tokens to unique set + remaining_tokens = self.tokens[pos:] + for token in remaining_tokens: + if token not in self._unique_generated_tokens: + self._unique_generated_tokens.add(token) + self._unique_tokens_dirty = True + self.tokens = remaining_tokens def release_kvcache(self): cache = self._kv_cache @@ -34,6 +48,24 @@ def next(self, out_token): self.finish_reason = "length" else: self.tokens = [out_token] + # Incrementally update unique token set (vLLM-style) + # Only add if it's a new token (O(1) average) + if out_token not in self._unique_generated_tokens: + self._unique_generated_tokens.add(out_token) + self._unique_tokens_dirty = True + + def get_unique_previous_tokens(self): + """ + Returns a sorted list of unique token IDs that have been generated. + This is the vLLM-style "seen tokens" list for efficient repetition penalty. + + Returns: + tuple: (array, length) where array is sorted list of unique token IDs + """ + if self._unique_tokens_dirty: + self._unique_tokens_array = sorted(self._unique_generated_tokens) + self._unique_tokens_dirty = False + return self._unique_tokens_array, len(self._unique_tokens_array) class KVCache: diff --git a/scripts/jiuge.py b/scripts/jiuge.py index 676d7de7..6b7e4530 100644 --- a/scripts/jiuge.py +++ b/scripts/jiuge.py @@ -20,6 +20,7 @@ from infer_task import InferTask, KVCache from ctypes import POINTER, c_float, c_int, c_uint, c_void_p, byref +import ctypes torch.set_default_device("cpu") @@ -395,11 +396,55 @@ def __init__(self, tasks: List[InferTask]): self.temperaturas_list = [t.temperature for t in tasks] self.topks_list = [t.topk for t in tasks] self.topps_list = [t.topp for t in tasks] + self.repetition_penalties_list = [t.repetition_penalty for t in tasks] # Flatten token lists flat_tokens = [tok for toks in token_lists for tok in toks] self.ntok = len(flat_tokens) + # Collect unique tokens per request (vLLM-style for efficient repetition penalty) + # Each request has its own list of unique token IDs + self.unique_tokens_arrays = [] # List of arrays, one per request + self.unique_tokens_lens = [] # List of lengths, one per request + self.unique_tokens_flat = [] # Flattened array for C API + self.unique_tokens_offsets = [0] # Offsets into flat array + + total_unique_tokens = 0 + for task in tasks: + tokens_array, tokens_len = task.get_unique_previous_tokens() + self.unique_tokens_arrays.append(tokens_array) + self.unique_tokens_lens.append(tokens_len) + self.unique_tokens_flat.extend(tokens_array) + total_unique_tokens += tokens_len + self.unique_tokens_offsets.append(total_unique_tokens) + + # Convert to C-compatible arrays + if total_unique_tokens > 0: + self.unique_tokens_c = (c_uint * total_unique_tokens)(*self.unique_tokens_flat) + # Create array of pointers, one per request + self.unique_tokens_ptrs = [] + for req_idx in range(self.nreq): + offset = self.unique_tokens_offsets[req_idx] + length = self.unique_tokens_lens[req_idx] + if length > 0: + # Create pointer to the start of this request's tokens in the flat array + ptr = ctypes.cast( + ctypes.addressof(self.unique_tokens_c) + offset * ctypes.sizeof(c_uint), + POINTER(c_uint) + ) + else: + ptr = None + self.unique_tokens_ptrs.append(ptr) + # Create array of 
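The batch construction here packs every request's unique tokens into one flat c_uint buffer and then derives one pointer per request by offsetting into it. A stripped-down sketch of the same ctypes pattern with hypothetical data, in case the address arithmetic is hard to follow:

import ctypes
from ctypes import POINTER, c_uint

per_req_tokens = [[3, 7, 42], [], [7, 9]]            # hypothetical unique-token lists
flat = [t for toks in per_req_tokens for t in toks]
lens = [len(toks) for toks in per_req_tokens]
offsets = [0]
for n in lens:
    offsets.append(offsets[-1] + n)

flat_c = (c_uint * len(flat))(*flat)                 # keep a reference so it is not GC'd
ptrs = []
for i, n in enumerate(lens):
    if n == 0:
        ptrs.append(None)                            # ctypes turns None into a NULL pointer
    else:
        addr = ctypes.addressof(flat_c) + offsets[i] * ctypes.sizeof(c_uint)
        ptrs.append(ctypes.cast(addr, POINTER(c_uint)))

ptrs_c = (POINTER(c_uint) * len(per_req_tokens))(*ptrs)   # like previous_tokens_per_req
lens_c = (c_uint * len(per_req_tokens))(*lens)             # like previous_tokens_len_per_req

Note that flat_c must stay referenced for as long as ptrs_c is in use, which is why the batch object stores the flat buffer as an attribute.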
pointers (use None for empty requests) + self.unique_tokens_ptrs_array = (POINTER(c_uint) * self.nreq)(*self.unique_tokens_ptrs) + else: + self.unique_tokens_c = None + # All requests have no previous tokens + self.unique_tokens_ptrs_array = (POINTER(c_uint) * self.nreq)(*[None] * self.nreq) + + # Array of lengths per request + self.unique_tokens_lens_array = (c_uint * self.nreq)(*self.unique_tokens_lens) + # Convert to ctypes arrays in one pass self.tokens = (c_uint * self.ntok)(*flat_tokens) self.req_lens = (c_uint * self.nreq)(*self.req_lens_list) @@ -408,6 +453,7 @@ def __init__(self, tasks: List[InferTask]): self.temperaturas = (c_float * self.nreq)(*self.temperaturas_list) self.topks = (c_uint * self.nreq)(*self.topks_list) self.topps = (c_float * self.nreq)(*self.topps_list) + self.repetition_penalties = (c_float * self.nreq)(*self.repetition_penalties_list) def input_args(self): return ( @@ -420,6 +466,9 @@ def input_args(self): self.temperaturas, self.topks, self.topps, + self.repetition_penalties, + self.unique_tokens_ptrs_array, # Array of pointers to unique tokens per request + self.unique_tokens_lens_array, # Array of lengths per request ) @@ -534,7 +583,7 @@ def load_all_safetensors_from_dir(dir_path_: str): else: raise ValueError("Unsupported model architecture") - + if "llama" == config["model_type"]: from tokenizers import decoders as _dec backend = getattr(self.tokenizer, "backend_tokenizer", None) @@ -593,9 +642,21 @@ def drop_kv_cache(self, kv_cache): def batch_infer_one_round(self, tasks: List[InferTask]): output = (c_uint * len(tasks))() batch_inputs = JiugeBatchedTask(tasks) + args = batch_inputs.input_args() self.jiuge_model.infer_batch( self.model_instance, - *(batch_inputs.input_args()), + args[0], # tokens + args[1], # ntok + args[2], # req_lens + args[3], # nreq + args[4], # req_pos + args[5], # kv_caches + args[6], # temperature + args[7], # topk + args[8], # topp + args[9], # repetition_penalty + args[10], # previous_tokens_per_req + args[11], # previous_tokens_len_per_req output, ) return list(output) @@ -616,6 +677,7 @@ def generate(self, input_content, max_steps, topp_=1.0, topk_=1, temperature_=1. topk_, topp_, self.eos_token_id, + 1.0, # repetition_penalty default ) infer_task.bind_kvcache(KVCache(self)) @@ -648,7 +710,7 @@ def generate(self, input_content, max_steps, topp_=1.0, topk_=1, temperature_=1. def perplexity(self, test_sequences: List[Sequence[int]], batch_size=10): tasks = [ - InferTask(i, [], self.max_context_len(), 1.0, 1, 1.0, self.eos_token_id) + InferTask(i, [], self.max_context_len(), 1.0, 1, 1.0, self.eos_token_id, 1.0) for i in range(batch_size) ] kv_caches = [KVCache(self) for _ in range(batch_size)] diff --git a/scripts/launch_server.py b/scripts/launch_server.py index 115fbd0a..4b918b6d 100644 --- a/scripts/launch_server.py +++ b/scripts/launch_server.py @@ -69,6 +69,18 @@ def parse_args(): action="store_true", help="Whether to use AWQ quantized model (default: False)", ) + parser.add_argument( + "--port", + type=int, + default=8000, + help="Port to run the server on (default: 8000)", + ) + parser.add_argument( + "--host", + type=str, + default="0.0.0.0", + help="Host to bind the server to (default: 0.0.0.0)", + ) return parser.parse_args() @@ -79,6 +91,8 @@ def parse_args(): max_tokens = args.max_tokens USE_AWQ = args.awq MAX_BATCH = args.max_batch +SERVER_PORT = args.port +SERVER_HOST = args.host print( f"Using MAX_BATCH={MAX_BATCH}. Try reduce this value if out of memory error occurs." 
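With repetition_penalty threaded through the batch objects and --host/--port now configurable, a request can set the new field next to the usual sampling parameters. A hedged Python client sketch (the host, port and exact response shape are assumptions; the curl examples later in this patch show the same payload fields):

import requests  # third-party HTTP client, assumed available

payload = {
    "model": "jiuge",
    "messages": [{"role": "user", "content": "Introduce yourself"}],
    "temperature": 0.7,
    "top_p": 0.7,
    "repetition_penalty": 1.1,    # 1.0 leaves sampling unchanged
    "stream": False,
}
resp = requests.post("http://127.0.0.1:8000/chat/completions", json=payload, timeout=60)
print(resp.json())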
) @@ -110,8 +124,8 @@ def chunk_json(id_, content=None, role=None, finish_reason=None): # A wrapper for InferTask that supports async output queue class AsyncInferTask(InferTask): - def __init__(self, id, tokens, max_tokens, temperature, topk, topp, end_tokens): - super().__init__(id, tokens, max_tokens, temperature, topk, topp, end_tokens) + def __init__(self, id, tokens, max_tokens, temperature, topk, topp, end_tokens, repetition_penalty=1.0): + super().__init__(id, tokens, max_tokens, temperature, topk, topp, end_tokens, repetition_penalty) self.output_queue = janus.Queue() print(f"[INFO] Create InferTask {self.id}") @@ -185,11 +199,16 @@ def build_task(id_, request_data, request: Request): if "messages" in request_data: # Chat format messages = request_data.get("messages", []) - input_content = request.app.state.model.tokenizer.apply_chat_template( - conversation=messages, - add_generation_prompt=True, - tokenize=False, - ) + # Get chat_template_kwargs from request, default to empty dict + chat_template_kwargs = request_data.get("chat_template_kwargs", {}) + # Merge with default parameters, allowing chat_template_kwargs to override + template_params = { + "conversation": messages, + "add_generation_prompt": True, + "tokenize": False, + **chat_template_kwargs # Allow override of defaults + } + input_content = request.app.state.model.tokenizer.apply_chat_template(**template_params) tokens = request.app.state.model.tokenizer.encode(input_content) max_tokens = request_data.get("max_tokens", request.app.state.model.max_context_len()) else: @@ -197,15 +216,16 @@ def build_task(id_, request_data, request: Request): prompt = request_data.get("prompt", "") tokens = request.app.state.model.tokenizer.encode(prompt) max_tokens = request_data.get("max_tokens", 0) - + return AsyncInferTask( id_, tokens, max_tokens, request_data.get("temperature", 1.0), - request_data.get("top_k", 1), + request_data.get("top_k", 0), # Default to 0 (disabled) to consider all tokens, matching vLLM behavior request_data.get("top_p", 1.0), request.app.state.model.eos_token_id, + request_data.get("repetition_penalty", 1.0), ) @@ -318,16 +338,16 @@ async def completion(id_, request_data, request: Request): max_tokens = request_data.get("max_tokens", 0) if max_tokens > 0: return JSONResponse( - content={"error": "max_tokens > 0 is not supported yet. Please use max_tokens=0 for logprobs calculation."}, + content={"error": "max_tokens > 0 is not supported yet. 
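build_task merges any chat_template_kwargs from the request over the default apply_chat_template arguments via dict unpacking, so later keys win. A tiny sketch of that merge semantics with hypothetical override values:

defaults = {"conversation": [], "add_generation_prompt": True, "tokenize": False}
overrides = {"enable_thinking": False, "add_generation_prompt": False}  # hypothetical request payload

merged = {**defaults, **overrides}       # request-supplied keys override the defaults
assert merged["add_generation_prompt"] is False and merged["tokenize"] is False
# The merged dict is then passed as tokenizer.apply_chat_template(**merged).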
Please use max_tokens=0 for logprobs calculation."}, status_code=400 ) - + infer_task = build_task(id_, request_data, request) await request.app.state.kv_cache_pool.acquire(infer_task) - + output = [] logprobs = [] - + # Handle echo and logprobs calculation echo = request_data.get("echo", False) if echo: @@ -340,12 +360,12 @@ async def completion(id_, request_data, request: Request): .replace("<0x0A>", "\n") ) output.append(content) - + # Calculate logprobs for input tokens from jiuge import JiugeBatchedTask batch_inputs = JiugeBatchedTask([infer_task]) log_probs = torch.zeros( - (batch_inputs.ntok, request.app.state.model.meta.dvoc), + (batch_inputs.ntok, request.app.state.model.meta.dvoc), dtype=request.app.state.model.meta.torch_dtype_logits ) request.app.state.model.jiuge_model.forward_batch( @@ -358,39 +378,39 @@ async def completion(id_, request_data, request: Request): batch_inputs.kv_caches, log_probs.data_ptr(), ) - + log_probs = log_probs.float() - + # Calculate correct logprobs for input tokens token_logprobs = [] for i in range(len(infer_task.tokens) - 1): # Only up to second-to-last token next_token = infer_task.tokens[i+1] # Next token to predict logprob = log_probs[i, next_token].item() # Use position i logits to predict position i+1 token token_logprobs.append(logprob) - + # First token has no context, so logprob is None logprobs = [None] + token_logprobs else: # echo=false: don't calculate logprobs since user can't see input text logprobs = [] - + # For max_tokens=0, we need to manually release the KV cache since we don't go through worker await request.app.state.kv_cache_pool.release(infer_task) print(f"[DEBUG] {id_} Released KV cache for max_tokens=0") output_text = "".join(output).strip() - + # Prepare tokens list for logprobs tokens_list = [] text_offset_list = [] current_offset = 0 - + # Build tokens list and text offsets for i, content in enumerate(output): tokens_list.append(content) text_offset_list.append(current_offset) current_offset += len(content) - + # Build response according to DeepSeek API completion format response = { "id": id_, @@ -440,15 +460,56 @@ async def completions(request: Request): id_ = f"cmpl-{uuid.uuid4().hex}" response = await completion(id_, data, request) - + # Check if response is already a JSONResponse (error case) if isinstance(response, JSONResponse): return response else: return JSONResponse(content=response) + +@App.get("/models") +async def list_models(request: Request): + """ + OpenAI-compatible /models endpoint. + Returns a list of available models. 
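The echo branch scores each prompt token with the log-softmax row of the position before it: row i predicts token i+1, and the first token gets None because it has no context. A small torch sketch of that shift, assuming a (ntok, vocab) matrix like the one forward_batch fills:

import torch

def prompt_logprobs(log_probs: torch.Tensor, tokens: list) -> list:
    # log_probs: one log-softmax row per input position, shape (len(tokens), vocab).
    out = [None]                                     # first token has no preceding context
    for i in range(len(tokens) - 1):
        next_token = tokens[i + 1]
        out.append(log_probs[i, next_token].item())  # row i scores the token at position i+1
    return out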
+ """ + try: + # Get model information from app state + model = request.app.state.model + model_id = "jiuge" # Default model ID + + # Try to get model name from config if available + if hasattr(model, 'config') and model.config: + # Try model_type first + model_id = model.config.get("model_type", "jiuge") + # If model_type is not informative, try architectures + if model_id == "jiuge" and "architectures" in model.config: + architectures = model.config.get("architectures", []) + if architectures: + model_id = architectures[0].lower() + + return JSONResponse(content={ + "object": "list", + "data": [ + { + "id": model_id, + "object": "model", + "created": int(time.time()), + "owned_by": "infini", + "permission": [], + "root": model_id, + "parent": None + } + ] + }) + except Exception as e: + print(f"[Error] Exception in /models: {e}") + return JSONResponse(content={"error": str(e)}, status_code=500) + + if __name__ == "__main__": - uvicorn.run(App, host="0.0.0.0", port=8000) + uvicorn.run(App, host=SERVER_HOST, port=SERVER_PORT) """ curl -N -H "Content-Type: application/json" \ @@ -456,12 +517,31 @@ async def completions(request: Request): -d '{ "model": "jiuge", "messages": [ - {"role": "user", "content": "山东最高的山是?"} + {"role": "user", "content": "介绍你自己"} + ], + "temperature": 0.7, + "top_p": 0.7, + "repetition_penalty": 1.02, + "stream": false, + "chat_template_kwargs": {"enable_thinking": false} + }' + + +curl -N -H "Content-Type: application/json" \ + -X POST http://127.0.0.1:8000/chat/completions \ + -d '{ + "model": "jiuge", + "messages": [ + {"role": "system", "content": "你是一个由启元实验室开发的九格助手,你擅长中英文对话,能够理解并处理各种问题,提供安全、有帮助、准确的回答。当前时间:2025-12-24#注意:回复之前注意结合上下文和工具返回内容进行回复"}, + {"role": "user", "content": "怎么看待台海局势"} ], - "temperature": 1.0, - "top_k": 50, - "top_p": 0.8, + "temperature": 0.7, + "top_p": 0.7, "max_tokens": 512, - "stream": true + "repetition_penalty": 1.1, + "stream": false, + "chat_template_kwargs": {"enable_thinking": false} }' + + """ diff --git a/scripts/libinfinicore_infer/jiuge.py b/scripts/libinfinicore_infer/jiuge.py index 89553041..6338f2f9 100644 --- a/scripts/libinfinicore_infer/jiuge.py +++ b/scripts/libinfinicore_infer/jiuge.py @@ -86,6 +86,9 @@ def register_lib(cls, lib): POINTER(c_float), POINTER(c_uint), POINTER(c_float), + POINTER(c_float), + POINTER(POINTER(c_uint)), # previous_tokens_per_req: array of pointers + POINTER(c_uint), # previous_tokens_len_per_req: array of lengths POINTER(c_uint), ] @@ -128,6 +131,9 @@ def infer_batch( temperature, topk, topp, + repetition_penalty, + previous_tokens_per_req, + previous_tokens_len_per_req, output, ): self.lib.inferBatchJiuge( @@ -141,6 +147,9 @@ def infer_batch( temperature, topk, topp, + repetition_penalty, + previous_tokens_per_req, + previous_tokens_len_per_req, output, ) diff --git a/src/models/inference_context.cpp b/src/models/inference_context.cpp index 1cbc267a..1e5bd9bc 100644 --- a/src/models/inference_context.cpp +++ b/src/models/inference_context.cpp @@ -272,8 +272,11 @@ void InferenceContext::swiglu(std::shared_ptr out, } void InferenceContext::randomSample(std::shared_ptr out, - std::shared_ptr prob, - float random_val, float top_p, uint32_t top_k, float temperature) { + std::shared_ptr prob, + float random_val, float top_p, uint32_t top_k, float temperature, + float repetition_penalty, + const uint32_t *previous_tokens, + size_t previous_tokens_len) { size_t key = CacheManager::createDescriptorKey(out, prob); infiniopRandomSampleDescriptor_t desc; @@ -288,10 +291,12 @@ void 
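The new argtypes declare previous_tokens_per_req as POINTER(POINTER(c_uint)). ctypes converts None to a NULL pointer both for whole arguments and for slots of a pointer array, which is what the "use None for empty requests" comment relies on and what lets the C side skip the penalty. A short sketch of the no-history case (names are hypothetical):

from ctypes import POINTER, c_float, c_uint

nreq = 2
no_history_ptrs = (POINTER(c_uint) * nreq)(*([None] * nreq))  # every slot becomes NULL
no_history_lens = (c_uint * nreq)(0, 0)
unit_penalty = (c_float * nreq)(1.0, 1.0)
# Passed to inferBatchJiuge, this should reproduce the pre-patch behaviour:
# NULL previous-token pointers and penalty == 1.0 mean "no repetition penalty".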
InferenceContext::randomSample(std::shared_ptr out, ensure_workspace(workspace_size); void *workspace = workspace_storage->memory(); + RUN_INFINI(infiniopRandomSample( desc, workspace, workspace_size, out->data(), prob->data(), - random_val, top_p, top_k, temperature, + random_val, top_p, top_k, temperature, repetition_penalty, + previous_tokens, previous_tokens_len, stream)); } @@ -374,7 +379,7 @@ void InferenceContext::pagedCaching(std::shared_ptr k, infiniopPagedCachingDescriptor_t desc; if (!cache_manager->getPagedCachingDescriptor(key, desc)) { RUN_INFINI(infiniopCreatePagedCachingDescriptor( - op_handle, &desc, k->desc(), v->desc(), + op_handle, &desc, k->desc(), v->desc(), k_cache->desc(), v_cache->desc(), slot_mapping->desc())); cache_manager->putPagedCachingDescriptor(key, desc); } @@ -416,7 +421,7 @@ void InferenceContext::pagedAttention(std::shared_ptr out, RUN_INFINI(infiniopGetPagedAttentionWorkspaceSize(desc, &workspace_size)); ensure_workspace(workspace_size); void *workspace = workspace_storage->memory(); - + const void* alibi_data = alibi_slopes ? alibi_slopes->data() : nullptr; RUN_INFINI(infiniopPagedAttention( desc, workspace, workspace_size, @@ -424,10 +429,3 @@ void InferenceContext::pagedAttention(std::shared_ptr out, block_tables->data(), seq_lens->data(), alibi_data, stream)); } - - - - - - - diff --git a/src/models/inference_context.hpp b/src/models/inference_context.hpp index 96d8601c..12709095 100644 --- a/src/models/inference_context.hpp +++ b/src/models/inference_context.hpp @@ -61,7 +61,10 @@ struct InferenceContext { std::shared_ptr gate); void randomSample(std::shared_ptr out, std::shared_ptr prob, - float random_val, float top_p, uint32_t top_k, float temperature); + float random_val, float top_p, uint32_t top_k, float temperature, + float repetition_penalty = 1.0f, + const uint32_t *previous_tokens = nullptr, + size_t previous_tokens_len = 0); void linear(std::shared_ptr c, std::shared_ptr a, @@ -79,7 +82,7 @@ struct InferenceContext { std::shared_ptr k_cache, std::shared_ptr v_cache, std::shared_ptr slot_mapping); - + void pagedAttention(std::shared_ptr out, std::shared_ptr q, std::shared_ptr k_cache, @@ -180,8 +183,12 @@ inline void swiglu(std::shared_ptr out, std::shared_ptr up, } inline void randomSample(std::shared_ptr out, std::shared_ptr prob, - float random_val, float top_p, uint32_t top_k, float temperature) { - getInferenceContext().randomSample(out, prob, random_val, top_p, top_k, temperature); + float random_val, float top_p, uint32_t top_k, float temperature, + float repetition_penalty = 1.0f, + const uint32_t *previous_tokens = nullptr, + size_t previous_tokens_len = 0) { + getInferenceContext().randomSample(out, prob, random_val, top_p, top_k, temperature, repetition_penalty, + previous_tokens, previous_tokens_len); } inline void linear(std::shared_ptr c, std::shared_ptr a, @@ -211,5 +218,3 @@ inline void pagedAttention(std::shared_ptr out, std::shared_ptr std::shared_ptr alibi_slopes, float scale) { getInferenceContext().pagedAttention(out, q, k_cache, v_cache, block_tables, seq_lens, alibi_slopes, scale); } - - diff --git a/src/models/jiuge/jiuge.cpp b/src/models/jiuge/jiuge.cpp index 8432bfee..9c5f23ab 100644 --- a/src/models/jiuge/jiuge.cpp +++ b/src/models/jiuge/jiuge.cpp @@ -126,6 +126,9 @@ void inferDeviceBatch(const JiugeMeta &meta, JiugeDeviceResource &rsrc, const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos, struct KVCache **kv_caches, const float *temperature, const uint32_t *topk, const float *topp, + const float 
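randomSample now forwards repetition_penalty and the previous-token list into infiniopRandomSample, so the penalty is applied inside the fused sampling call rather than on the host. As a mental model only, and not the actual kernel, the expected per-request computation is roughly penalty, then temperature, then top-k/top-p truncation, then a draw using the caller-supplied random value:

import numpy as np

def sample_reference(logits, prev_tokens, penalty, temperature, top_k, top_p, random_val):
    # Host-side sketch of what the fused sampler is expected to compute; the real
    # infiniopRandomSample implementation is device code and may differ in detail.
    x = np.asarray(logits, dtype=np.float64).copy()
    if penalty != 1.0 and prev_tokens:
        ids = np.asarray(list(prev_tokens))
        x[ids] = np.where(x[ids] > 0, x[ids] / penalty, x[ids] * penalty)
    if temperature <= 0.0 or top_k == 1:
        return int(np.argmax(x))                                # greedy path
    z = x / temperature
    probs = np.exp(z - z.max())
    probs /= probs.sum()
    order = np.argsort(-probs)
    if top_k > 0:
        order = order[:top_k]
    cum = np.cumsum(probs[order])
    order = order[cum - probs[order] < top_p]                   # nucleus cut, keeps at least one token
    p = probs[order] / probs[order].sum()
    idx = min(int(np.searchsorted(np.cumsum(p), random_val)), len(order) - 1)
    return int(order[idx])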
*repetition_penalty, + const uint32_t *const *previous_tokens_per_req, + const uint32_t *previous_tokens_len_per_req, uint32_t *output, void *last_logits) { auto nlayer = meta.nlayer; auto nkvh = meta.nkvh / ndev; @@ -280,10 +283,10 @@ void inferDeviceBatch(const JiugeMeta &meta, JiugeDeviceResource &rsrc, rmsnorm(logits_out, logits_in, rsrc.w_out_norm, meta.epsilon); auto last_logits_buf = Tensor::buffer(dt_logits, {ntok, dvoc}, rsrc.memory_pool); linear(last_logits_buf, logits_out, rsrc.w_out_embd, 1.0, 0.0, nullptr, nullptr); - + auto log_logits_buf = Tensor::buffer(dt_logits, {ntok, dvoc}, rsrc.memory_pool); logSoftmax(log_logits_buf, last_logits_buf); - + RUN_INFINI(infinirtStreamSynchronize(stream)); RUN_INFINI(infinirtMemcpy(last_logits, log_logits_buf->data(), dsize(dt_logits) * ntok * dvoc, INFINIRT_MEMCPY_D2H)); } @@ -304,9 +307,19 @@ void inferDeviceBatch(const JiugeMeta &meta, JiugeDeviceResource &rsrc, for (uint32_t req = 0; req < nreq; req++) { auto seq_len = req_lens[req]; float random_val = std::uniform_real_distribution(0, 1)(gen); + float rep_penalty_val = (repetition_penalty != nullptr) ? repetition_penalty[req] : 1.0f; + // Get unique tokens for this request (vLLM-style) + const uint32_t *prev_tokens = nullptr; + size_t prev_tokens_len = 0; + if (previous_tokens_per_req != nullptr && previous_tokens_per_req[req] != nullptr) { + prev_tokens = previous_tokens_per_req[req]; + prev_tokens_len = (previous_tokens_len_per_req != nullptr) ? previous_tokens_len_per_req[req] : 0; + } randomSample(result_buf->slice(0, req, 1)->view_as({}, {}), prob_buf->slice(0, req, 1)->view_as({dvoc}, {1}), - random_val, topp[req], topk[req], temperature[req]); + random_val, topp[req], topk[req], temperature[req], + rep_penalty_val, + prev_tokens, prev_tokens_len); token_offset += seq_len; } RUN_INFINI(infinirtStreamSynchronize(stream)); @@ -327,6 +340,9 @@ void inferDeviceBatchPaged(const JiugeMeta &meta, JiugeDeviceResource &rsrc, const int32_t *block_tables, const int32_t *slot_mapping, const float *temperature, const uint32_t *topk, const float *topp, + const float *repetition_penalty, + const uint32_t *const *previous_tokens_per_req, + const uint32_t *previous_tokens_len_per_req, const uint32_t is_prefill, const bool enable_paged_attn, uint32_t *output, void *last_logits) { auto nlayer = meta.nlayer; @@ -425,7 +441,7 @@ void inferDeviceBatchPaged(const JiugeMeta &meta, JiugeDeviceResource &rsrc, auto attn_val_buf = Tensor::buffer(dt_logits, {nkvh, ngroup * max_seq_len, dh}, rsrc.memory_pool); auto attn_val_gemm = attn_val_buf->view({nkvh, ngroup, max_seq_len, dh}); - + // MLP buffers auto gate_buf = gate_up_buf->slice(1, 0, di); auto up_buf = gate_up_buf->slice(1, di, di); @@ -451,7 +467,7 @@ void inferDeviceBatchPaged(const JiugeMeta &meta, JiugeDeviceResource &rsrc, auto k = qkv_rope->slice({ {0, 0, ntok}, {1, nh, nkvh} }); auto v = qkv_rope->slice({ {0, 0, ntok}, {1, nh + nkvh, nkvh} }); - auto k_cache_pool = kv_caches[0]->k[idev][layer]; + auto k_cache_pool = kv_caches[0]->k[idev][layer]; auto v_cache_pool = kv_caches[0]->v[idev][layer]; pagedCaching(k, v, k_cache_pool, v_cache_pool, slot_mapping_buf); @@ -481,10 +497,10 @@ void inferDeviceBatchPaged(const JiugeMeta &meta, JiugeDeviceResource &rsrc, auto o = o_buf->slice({{0, 0, ntok}})->view({ntok, nh, dh}); auto q_batch = qkv_rope->slice({ {0, 0, ntok}, {1, 0, nh} })->view({ntok, nh, dh}); float scale = 1.f / float(sqrt(dh)); - pagedAttention(o, q_batch, k_cache_pool, v_cache_pool, + pagedAttention(o, q_batch, k_cache_pool, 
v_cache_pool, block_tables_buf, seq_lens_buf, nullptr /* alibi_slopes */, scale); - - + + } } else { @@ -550,10 +566,10 @@ void inferDeviceBatchPaged(const JiugeMeta &meta, JiugeDeviceResource &rsrc, rmsnorm(logits_out, logits_in, rsrc.w_out_norm, meta.epsilon); auto last_logits_buf = Tensor::buffer(dt_logits, {ntok, dvoc}, rsrc.memory_pool); linear(last_logits_buf, logits_out, rsrc.w_out_embd, 1.0, 0.0, nullptr, nullptr); - + auto log_logits_buf = Tensor::buffer(dt_logits, {ntok, dvoc}, rsrc.memory_pool); logSoftmax(log_logits_buf, last_logits_buf); - + RUN_INFINI(infinirtStreamSynchronize(stream)); RUN_INFINI(infinirtMemcpy(last_logits, log_logits_buf->data(), dsize(dt_logits) * ntok * dvoc, INFINIRT_MEMCPY_D2H)); } @@ -574,9 +590,19 @@ void inferDeviceBatchPaged(const JiugeMeta &meta, JiugeDeviceResource &rsrc, for (uint32_t req = 0; req < nreq; req++) { auto seq_len = req_lens[req]; float random_val = std::uniform_real_distribution(0, 1)(gen); + float rep_penalty_val = (repetition_penalty != nullptr) ? repetition_penalty[req] : 1.0f; + // Get unique tokens for this request (vLLM-style) + const uint32_t *prev_tokens = nullptr; + size_t prev_tokens_len = 0; + if (previous_tokens_per_req != nullptr && previous_tokens_per_req[req] != nullptr) { + prev_tokens = previous_tokens_per_req[req]; + prev_tokens_len = (previous_tokens_len_per_req != nullptr) ? previous_tokens_len_per_req[req] : 0; + } randomSample(result_buf->slice(0, req, 1)->view_as({}, {}), prob_buf->slice(0, req, 1)->view_as({dvoc}, {1}), - random_val, topp[req], topk[req], temperature[req]); + random_val, topp[req], topk[req], temperature[req], + rep_penalty_val, + prev_tokens, prev_tokens_len); token_offset += seq_len; } RUN_INFINI(infinirtStreamSynchronize(stream)); @@ -595,6 +621,9 @@ inferBatchJiuge(struct JiugeModel *model, const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos, struct KVCache **kv_caches, const float *temperature, const uint32_t *topk, const float *topp, + const float *repetition_penalty, + const uint32_t *const *previous_tokens_per_req, + const uint32_t *previous_tokens_len_per_req, uint32_t *output) { model->req.tokens = tokens; model->req.ntok = ntok; @@ -607,6 +636,9 @@ inferBatchJiuge(struct JiugeModel *model, model->req.temperature = temperature; model->req.topk = topk; model->req.topp = topp; + model->req.repetition_penalty = repetition_penalty; + model->req.previous_tokens_per_req = previous_tokens_per_req; + model->req.previous_tokens_len_per_req = previous_tokens_len_per_req; for (size_t idev = 0; idev < model->dev_ids.size(); idev++) { std::unique_lock lock(model->states[idev].mtx); @@ -639,6 +671,7 @@ forwardBatchJiuge(struct JiugeModel *model, model->req.temperature = nullptr; model->req.topk = nullptr; model->req.topp = nullptr; + model->req.repetition_penalty = nullptr; for (size_t idev = 0; idev < model->dev_ids.size(); idev++) { std::unique_lock lock(model->states[idev].mtx); @@ -658,10 +691,13 @@ __C void inferBatch(struct JiugeModel *model, const uint32_t *tokens, uint32_t ntok, const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos, - struct KVCache **kv_caches, + struct KVCache **kv_caches, const int32_t *block_tables, const int32_t *slot_mapping, const float *temperature, const uint32_t *topk, const float *topp, + const float *repetition_penalty, + const uint32_t *const *previous_tokens_per_req, + const uint32_t *previous_tokens_len_per_req, const uint32_t is_prefill, const bool enable_paged_attn, uint32_t *output) { model->req.tokens = tokens; @@ -677,6 +713,9 
@@ inferBatch(struct JiugeModel *model, model->req.temperature = temperature; model->req.topk = topk; model->req.topp = topp; + model->req.repetition_penalty = repetition_penalty; + model->req.previous_tokens_per_req = previous_tokens_per_req; + model->req.previous_tokens_len_per_req = previous_tokens_len_per_req; model->req.is_prefill = is_prefill; model->req.enable_paged_attn = enable_paged_attn; @@ -716,6 +755,7 @@ forwardBatch(struct JiugeModel *model, model->req.temperature = nullptr; model->req.topk = nullptr; model->req.topp = nullptr; + model->req.repetition_penalty = nullptr; model->req.is_prefill = is_prefill; model->req.enable_paged_attn = enable_paged_attn; @@ -764,15 +804,18 @@ void launchDevice(const JiugeMeta &meta, const JiugeWeights *weights, JiugeDevic if (enable_paged){ inferDeviceBatchPaged(meta, *rsrc, idev, ndev, req.tokens, req.ntok, req.req_lens, req.nreq, req.req_pos, req.kv_caches, - req.block_tables, req.slot_mapping, - req.temperature, req.topk, req.topp, + req.block_tables, req.slot_mapping, + req.temperature, req.topk, req.topp, req.repetition_penalty, + req.previous_tokens_per_req, req.previous_tokens_len_per_req, req.is_prefill, req.enable_paged_attn, req.output, req.logits); } else{ inferDeviceBatch(meta, *rsrc, idev, ndev, req.tokens, req.ntok, req.req_lens, req.nreq, req.req_pos, req.kv_caches, - req.temperature, req.topk, req.topp, req.output, req.logits); + req.temperature, req.topk, req.topp, req.repetition_penalty, + req.previous_tokens_per_req, req.previous_tokens_len_per_req, + req.output, req.logits); } @@ -800,7 +843,7 @@ JiugeModel::JiugeModel(const JiugeMeta *_meta, const JiugeWeights *weights, infi } for (int i = 0; i < ndev; i++) { - + threads[i] = std::thread(launchDevice, std::cref(meta), weights, &dev_resources[i], std::ref(states[i]), std::ref(req), device, i, ndev, dev_ids[i], comms[i]); } for (int i = 0; i < ndev; i++) { diff --git a/src/models/jiuge/jiuge_impl.hpp b/src/models/jiuge/jiuge_impl.hpp index 5b1fd2f8..022e34da 100644 --- a/src/models/jiuge/jiuge_impl.hpp +++ b/src/models/jiuge/jiuge_impl.hpp @@ -50,6 +50,9 @@ struct InferRequest { const float *temperature; const uint32_t *topk; const float *topp; + const float *repetition_penalty; + const uint32_t *const *previous_tokens_per_req; // Array of pointers to unique tokens per request (vLLM-style) + const uint32_t *previous_tokens_len_per_req; // Array of lengths per request uint32_t *output; uint32_t is_prefill; bool enable_paged_attn; From 85077bddbd61ac069c072a06f6954d777ec83870 Mon Sep 17 00:00:00 2001 From: Ceng23333 <441651826@qq.com> Date: Sun, 4 Jan 2026 15:37:02 +0800 Subject: [PATCH 02/12] fix scaling Signed-off-by: Ceng23333 <441651826@qq.com> --- include/infinicore_infer/models/jiuge.h | 2 +- python/icinfer/engine/libinfinicore_infer.py | 1 + python/icinfer/models/jiuge.py | 8 ++++++-- python/icinfer/utils/jiuge_weights_loader.py | 6 +++++- scripts/jiuge.py | 8 ++++++-- scripts/libinfinicore_infer/jiuge.py | 1 + src/models/jiuge/jiuge.cpp | 18 ++++++++++++++---- 7 files changed, 34 insertions(+), 10 deletions(-) diff --git a/include/infinicore_infer/models/jiuge.h b/include/infinicore_infer/models/jiuge.h index 26a970e3..4df1d3b3 100644 --- a/include/infinicore_infer/models/jiuge.h +++ b/include/infinicore_infer/models/jiuge.h @@ -12,7 +12,7 @@ struct JiugeModel; typedef struct { infiniDtype_t dt_logits; - size_t nlayer, d, nh, nkvh, dh, di, dctx, dvoc, kvcache_block_size; + size_t nlayer, d, nh, nkvh, dh, di, dctx, dvoc, kvcache_block_size, dim_model_base; float 
epsilon, theta; uint32_t end_token; } JiugeMeta; diff --git a/python/icinfer/engine/libinfinicore_infer.py b/python/icinfer/engine/libinfinicore_infer.py index 75f6e025..d7e4dca7 100644 --- a/python/icinfer/engine/libinfinicore_infer.py +++ b/python/icinfer/engine/libinfinicore_infer.py @@ -48,6 +48,7 @@ class JiugeMetaCStruct(ctypes.Structure): ("dctx", c_size_t), ("dvoc", c_size_t), ("kvcache_block_size", c_size_t), + ("dim_model_base", c_size_t), ("epsilon", c_float), ("theta", c_float), ("end_token", c_uint), diff --git a/python/icinfer/models/jiuge.py b/python/icinfer/models/jiuge.py index 34c7f60f..277a2db3 100644 --- a/python/icinfer/models/jiuge.py +++ b/python/icinfer/models/jiuge.py @@ -122,6 +122,9 @@ def __init__(self, config, dtype=torch.float16, max_tokens=None): config["num_hidden_layers"] ) + dim_model_base = ( + config["dim_model_base"] if "dim_model_base" in config else config["hidden_size"] + ) super().__init__( dt_logits=dt_, nlayer=config["num_hidden_layers"], @@ -138,7 +141,8 @@ def __init__(self, config, dtype=torch.float16, max_tokens=None): config["max_position_embeddings"] if max_tokens is None else max_tokens ), dvoc=config["vocab_size"], - block_size=config["block_size"], + kvcache_block_size=config["block_size"], + dim_model_base=dim_model_base, epsilon=config["rms_norm_eps"], theta=(config["rope_theta"] if "rope_theta" in config else 100000.0), end_token=2, @@ -206,7 +210,7 @@ def __init__( ) self.input_embd = self.input_embd_tensor.data_ptr() self.output_norm_tensor = ( - state_dict[naming.output_norm()].to(torch_dt_norm) * scale_output + state_dict[naming.output_norm()].to(torch_dt_norm) ) self.output_norm = self.output_norm_tensor.data_ptr() self.output_embd_tensor = state_dict[output_embd_naming].to(torch_dt_mat) diff --git a/python/icinfer/utils/jiuge_weights_loader.py b/python/icinfer/utils/jiuge_weights_loader.py index 7cff7bf2..c872a3ca 100644 --- a/python/icinfer/utils/jiuge_weights_loader.py +++ b/python/icinfer/utils/jiuge_weights_loader.py @@ -119,6 +119,9 @@ def __init__(self, config, dtype=torch.float16, max_tokens=None): self.scale_o = config.scale_depth / math.sqrt(config.num_hidden_layers) self.scale_down = config.scale_depth / math.sqrt(config.num_hidden_layers) + dim_model_base = ( + config.dim_model_base if hasattr(config, "dim_model_base") else config.hidden_size + ) super().__init__( dt_logits=dt_, nlayer=config.num_hidden_layers, @@ -134,6 +137,7 @@ def __init__(self, config, dtype=torch.float16, max_tokens=None): dctx=(config.max_position_embeddings if max_tokens is None else max_tokens), dvoc=config.vocab_size, kvcache_block_size=config.kvcache_block_size, + dim_model_base=dim_model_base, epsilon=config.rms_norm_eps, theta=(config.rope_theta if hasattr(config, "rope_theta") else 100000.0), end_token=2, @@ -201,7 +205,7 @@ def __init__( ) self.input_embd = self.input_embd_tensor.data_ptr() self.output_norm_tensor = ( - state_dict[naming.output_norm()].to(torch_dt_norm) * scale_output + state_dict[naming.output_norm()].to(torch_dt_norm) ) self.output_norm = self.output_norm_tensor.data_ptr() self.output_embd_tensor = state_dict[output_embd_naming].to(torch_dt_mat) diff --git a/scripts/jiuge.py b/scripts/jiuge.py index 6b7e4530..2cf3c433 100644 --- a/scripts/jiuge.py +++ b/scripts/jiuge.py @@ -114,6 +114,9 @@ def __init__(self, config, dtype=torch.float16, max_tokens=None): config["num_hidden_layers"] ) + dim_model_base = ( + config["dim_model_base"] if "dim_model_base" in config else config["hidden_size"] + ) super().__init__( 
dt_logits=dt_, nlayer=config["num_hidden_layers"], @@ -130,7 +133,8 @@ def __init__(self, config, dtype=torch.float16, max_tokens=None): config["max_position_embeddings"] if max_tokens is None else max_tokens ), dvoc=config["vocab_size"], - kvcache_block_size=0, + kvcache_block_size=config.get("block_size", 0), + dim_model_base=dim_model_base, epsilon=config["rms_norm_eps"], theta=(config["rope_theta"] if "rope_theta" in config else 100000.0), end_token=2, @@ -198,7 +202,7 @@ def __init__( ) self.input_embd = self.input_embd_tensor.data_ptr() self.output_norm_tensor = ( - state_dict[naming.output_norm()].to(torch_dt_norm) * scale_output + state_dict[naming.output_norm()].to(torch_dt_norm) ) self.output_norm = self.output_norm_tensor.data_ptr() self.output_embd_tensor = state_dict[output_embd_naming].to(torch_dt_mat) diff --git a/scripts/libinfinicore_infer/jiuge.py b/scripts/libinfinicore_infer/jiuge.py index 6338f2f9..ec01b610 100644 --- a/scripts/libinfinicore_infer/jiuge.py +++ b/scripts/libinfinicore_infer/jiuge.py @@ -14,6 +14,7 @@ class JiugeMetaCStruct(Structure): ("dctx", c_size_t), ("dvoc", c_size_t), ("kvcache_block_size", c_size_t), + ("dim_model_base", c_size_t), ("epsilon", c_float), ("theta", c_float), ("end_token", c_uint), diff --git a/src/models/jiuge/jiuge.cpp b/src/models/jiuge/jiuge.cpp index 9c5f23ab..e4764030 100644 --- a/src/models/jiuge/jiuge.cpp +++ b/src/models/jiuge/jiuge.cpp @@ -279,10 +279,15 @@ void inferDeviceBatch(const JiugeMeta &meta, JiugeDeviceResource &rsrc, } // Sample and Output if (idev == 0) { + // Calculate output scale: if dim_model_base != d, scale by dim_model_base/d, else no scaling + float output_scale = 1.0f; + if (meta.dim_model_base > 0 && meta.dim_model_base != meta.d && meta.d > 0) { + output_scale = meta.dim_model_base / float(meta.d); + } if (last_logits != nullptr) { rmsnorm(logits_out, logits_in, rsrc.w_out_norm, meta.epsilon); auto last_logits_buf = Tensor::buffer(dt_logits, {ntok, dvoc}, rsrc.memory_pool); - linear(last_logits_buf, logits_out, rsrc.w_out_embd, 1.0, 0.0, nullptr, nullptr); + linear(last_logits_buf, logits_out, rsrc.w_out_embd, output_scale, 0.0, nullptr, nullptr); auto log_logits_buf = Tensor::buffer(dt_logits, {ntok, dvoc}, rsrc.memory_pool); logSoftmax(log_logits_buf, last_logits_buf); @@ -300,7 +305,7 @@ void inferDeviceBatch(const JiugeMeta &meta, JiugeDeviceResource &rsrc, rsrc.w_out_norm, meta.epsilon); } - linear(prob_buf, logits_out->slice(0, 0, nreq), rsrc.w_out_embd, 1.0, 0.0, nullptr, nullptr); + linear(prob_buf, logits_out->slice(0, 0, nreq), rsrc.w_out_embd, output_scale, 0.0, nullptr, nullptr); std::random_device _rd; std::mt19937 gen(_rd()); token_offset = 0; @@ -562,10 +567,15 @@ void inferDeviceBatchPaged(const JiugeMeta &meta, JiugeDeviceResource &rsrc, // Sample and Output if (idev == 0) { + // Calculate output scale: if dim_model_base != d, scale by dim_model_base/d, else no scaling + float output_scale = 1.0f; + if (meta.dim_model_base > 0 && meta.dim_model_base != meta.d && meta.d > 0) { + output_scale = meta.dim_model_base / float(meta.d); + } if (last_logits != nullptr) { rmsnorm(logits_out, logits_in, rsrc.w_out_norm, meta.epsilon); auto last_logits_buf = Tensor::buffer(dt_logits, {ntok, dvoc}, rsrc.memory_pool); - linear(last_logits_buf, logits_out, rsrc.w_out_embd, 1.0, 0.0, nullptr, nullptr); + linear(last_logits_buf, logits_out, rsrc.w_out_embd, output_scale, 0.0, nullptr, nullptr); auto log_logits_buf = Tensor::buffer(dt_logits, {ntok, dvoc}, rsrc.memory_pool); logSoftmax(log_logits_buf, 
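The scaling fix replaces the old trick of pre-multiplying the output-norm weight by scale_output with an explicit logits scale of dim_model_base / d in the lm-head linear (dim_model_base falls back to hidden_size, so the scale stays 1.0 for ordinary checkpoints; the field comes from MiniCPM-style configs). Assuming scale_output equalled dim_model_base / d, the two placements are equivalent because RMSNorm output is linear in its weight; a quick torch sketch of that equivalence:

import torch

def logits_two_ways(h, w_norm, w_out, d, dim_model_base, eps=1e-6):
    # RMSNorm followed by the output projection, with the scale folded in either place.
    scale = dim_model_base / d if dim_model_base > 0 and dim_model_base != d else 1.0
    normed = h * torch.rsqrt(h.pow(2).mean(-1, keepdim=True) + eps)
    old = (normed * (w_norm * scale)) @ w_out.T     # old: scale baked into the norm weight
    new = ((normed * w_norm) @ w_out.T) * scale     # new: scale applied in the logits GEMM
    return old, new

h, w_norm, w_out = torch.randn(2, 8), torch.randn(8), torch.randn(16, 8)
a, b = logits_two_ways(h, w_norm, w_out, d=8, dim_model_base=4)
assert torch.allclose(a, b, atol=1e-5)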
last_logits_buf); @@ -583,7 +593,7 @@ void inferDeviceBatchPaged(const JiugeMeta &meta, JiugeDeviceResource &rsrc, rsrc.w_out_norm, meta.epsilon); } - linear(prob_buf, logits_out->slice(0, 0, nreq), rsrc.w_out_embd, 1.0, 0.0, nullptr, nullptr); + linear(prob_buf, logits_out->slice(0, 0, nreq), rsrc.w_out_embd, output_scale, 0.0, nullptr, nullptr); std::random_device _rd; std::mt19937 gen(_rd()); token_offset = 0; From 9e26503c8ba8e63775928c919fc07e655e20a905 Mon Sep 17 00:00:00 2001 From: Ceng23333 <441651826@qq.com> Date: Mon, 5 Jan 2026 11:29:57 +0800 Subject: [PATCH 03/12] support longrope fractor Signed-off-by: Ceng23333 <441651826@qq.com> --- include/infinicore_infer/models/jiuge.h | 5 ++ python/icinfer/engine/libinfinicore_infer.py | 4 ++ python/icinfer/models/jiuge.py | 41 ++++++++++++++++ python/icinfer/utils/jiuge_weights_loader.py | 50 ++++++++++++++++++++ scripts/jiuge.py | 41 ++++++++++++++++ scripts/libinfinicore_infer/jiuge.py | 4 ++ src/models/jiuge/jiuge_weight.hpp | 48 +++++++++++++++++-- 7 files changed, 189 insertions(+), 4 deletions(-) diff --git a/include/infinicore_infer/models/jiuge.h b/include/infinicore_infer/models/jiuge.h index 4df1d3b3..b7bd3363 100644 --- a/include/infinicore_infer/models/jiuge.h +++ b/include/infinicore_infer/models/jiuge.h @@ -15,6 +15,11 @@ typedef struct size_t nlayer, d, nh, nkvh, dh, di, dctx, dvoc, kvcache_block_size, dim_model_base; float epsilon, theta; uint32_t end_token; + // Longrope support + uint32_t rope_type; // 0 = standard, 1 = longrope + size_t original_max_position_embeddings; + const float *short_factor; // Array of dh/2 floats, nullptr if not longrope + const float *long_factor; // Array of dh/2 floats, nullptr if not longrope } JiugeMeta; typedef struct diff --git a/python/icinfer/engine/libinfinicore_infer.py b/python/icinfer/engine/libinfinicore_infer.py index d7e4dca7..2e44e198 100644 --- a/python/icinfer/engine/libinfinicore_infer.py +++ b/python/icinfer/engine/libinfinicore_infer.py @@ -52,6 +52,10 @@ class JiugeMetaCStruct(ctypes.Structure): ("epsilon", c_float), ("theta", c_float), ("end_token", c_uint), + ("rope_type", c_uint), + ("original_max_position_embeddings", c_size_t), + ("short_factor", POINTER(c_float)), + ("long_factor", POINTER(c_float)), ] diff --git a/python/icinfer/models/jiuge.py b/python/icinfer/models/jiuge.py index 277a2db3..e132c21c 100644 --- a/python/icinfer/models/jiuge.py +++ b/python/icinfer/models/jiuge.py @@ -3,6 +3,7 @@ from sympy import true from ctypes import POINTER, c_float, c_int, c_uint, c_void_p, byref +import ctypes import os from pathlib import Path import safetensors @@ -125,6 +126,42 @@ def __init__(self, config, dtype=torch.float16, max_tokens=None): dim_model_base = ( config["dim_model_base"] if "dim_model_base" in config else config["hidden_size"] ) + + # Load longrope configuration + rope_type = 0 # 0 = standard, 1 = longrope + original_max_position_embeddings = 0 + short_factor_ptr = None + long_factor_ptr = None + self._short_factor_array = None # Keep reference to prevent GC + self._long_factor_array = None # Keep reference to prevent GC + + rope_scaling = config.get("rope_scaling", {}) + if isinstance(rope_scaling, dict): + rope_scaling_type = rope_scaling.get("rope_type") or rope_scaling.get("type", "") + if rope_scaling_type == "longrope": + rope_type = 1 + original_max_position_embeddings = rope_scaling.get( + "original_max_position_embeddings", + config.get("original_max_position_embeddings", 0) + ) + + short_factor_list = rope_scaling.get("short_factor", []) + 
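The longrope branch reads the HF-style rope_scaling block: a rope type of "longrope" (with "type" accepted as a fallback key), per-frequency short_factor / long_factor lists of length head_dim / 2, and original_max_position_embeddings. For orientation, the kind of config fragment this parsing expects (all values hypothetical):

# Hypothetical config.json fragment accepted by the rope_scaling parsing above.
config = {
    "hidden_size": 2048,
    "num_attention_heads": 16,             # head_dim = 128, so each factor list needs 64 entries
    "max_position_embeddings": 131072,
    "rope_theta": 10000.0,
    "rope_scaling": {
        "rope_type": "longrope",
        "original_max_position_embeddings": 4096,
        "short_factor": [1.0] * 64,        # real checkpoints ship hand-tuned per-frequency values
        "long_factor": [4.0] * 64,
    },
}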
long_factor_list = rope_scaling.get("long_factor", []) + + if short_factor_list and long_factor_list: + # Convert to ctypes arrays + half_dh = (config["hidden_size"] // config["num_attention_heads"]) // 2 + if len(short_factor_list) == half_dh and len(long_factor_list) == half_dh: + self._short_factor_array = (c_float * half_dh)(*short_factor_list) + self._long_factor_array = (c_float * half_dh)(*long_factor_list) + short_factor_ptr = ctypes.cast(self._short_factor_array, POINTER(c_float)) + long_factor_ptr = ctypes.cast(self._long_factor_array, POINTER(c_float)) + else: + logger.warning( + f"Longrope factor arrays have wrong length: " + f"short={len(short_factor_list)}, long={len(long_factor_list)}, expected={half_dh}" + ) + super().__init__( dt_logits=dt_, nlayer=config["num_hidden_layers"], @@ -146,6 +183,10 @@ def __init__(self, config, dtype=torch.float16, max_tokens=None): epsilon=config["rms_norm_eps"], theta=(config["rope_theta"] if "rope_theta" in config else 100000.0), end_token=2, + rope_type=rope_type, + original_max_position_embeddings=original_max_position_embeddings, + short_factor=short_factor_ptr, + long_factor=long_factor_ptr, ) self.torch_dtype_logits = dtype diff --git a/python/icinfer/utils/jiuge_weights_loader.py b/python/icinfer/utils/jiuge_weights_loader.py index c872a3ca..0de730fb 100644 --- a/python/icinfer/utils/jiuge_weights_loader.py +++ b/python/icinfer/utils/jiuge_weights_loader.py @@ -7,6 +7,7 @@ from typing import Tuple import math from ctypes import POINTER, c_float, c_int, c_uint, c_void_p, byref +import ctypes import os from pathlib import Path import safetensors @@ -122,6 +123,51 @@ def __init__(self, config, dtype=torch.float16, max_tokens=None): dim_model_base = ( config.dim_model_base if hasattr(config, "dim_model_base") else config.hidden_size ) + + # Load longrope configuration + rope_type = 0 # 0 = standard, 1 = longrope + original_max_position_embeddings = 0 + short_factor_ptr = None + long_factor_ptr = None + self._short_factor_array = None # Keep reference to prevent GC + self._long_factor_array = None # Keep reference to prevent GC + + # Handle both dict and object config + if hasattr(config, "rope_scaling"): + rope_scaling = config.rope_scaling + elif isinstance(config, dict) and "rope_scaling" in config: + rope_scaling = config["rope_scaling"] + else: + rope_scaling = {} + + if isinstance(rope_scaling, dict): + rope_scaling_type = rope_scaling.get("rope_type") or rope_scaling.get("type", "") + if rope_scaling_type == "longrope": + rope_type = 1 + original_max_position_embeddings = rope_scaling.get( + "original_max_position_embeddings", + getattr(config, "original_max_position_embeddings", 0) if not isinstance(config, dict) else config.get("original_max_position_embeddings", 0) + ) + + short_factor_list = rope_scaling.get("short_factor", []) + long_factor_list = rope_scaling.get("long_factor", []) + + if short_factor_list and long_factor_list: + # Convert to ctypes arrays + half_dh = (config.hidden_size // config.num_attention_heads) // 2 + if len(short_factor_list) == half_dh and len(long_factor_list) == half_dh: + self._short_factor_array = (c_float * half_dh)(*short_factor_list) + self._long_factor_array = (c_float * half_dh)(*long_factor_list) + short_factor_ptr = ctypes.cast(self._short_factor_array, POINTER(c_float)) + long_factor_ptr = ctypes.cast(self._long_factor_array, POINTER(c_float)) + else: + import logging + logger = logging.getLogger(__name__) + logger.warning( + f"Longrope factor arrays have wrong length: " + 
f"short={len(short_factor_list)}, long={len(long_factor_list)}, expected={half_dh}" + ) + super().__init__( dt_logits=dt_, nlayer=config.num_hidden_layers, @@ -141,6 +187,10 @@ def __init__(self, config, dtype=torch.float16, max_tokens=None): epsilon=config.rms_norm_eps, theta=(config.rope_theta if hasattr(config, "rope_theta") else 100000.0), end_token=2, + rope_type=rope_type, + original_max_position_embeddings=original_max_position_embeddings, + short_factor=short_factor_ptr, + long_factor=long_factor_ptr, ) self.torch_dtype_logits = dtype diff --git a/scripts/jiuge.py b/scripts/jiuge.py index 2cf3c433..2dcda072 100644 --- a/scripts/jiuge.py +++ b/scripts/jiuge.py @@ -117,6 +117,43 @@ def __init__(self, config, dtype=torch.float16, max_tokens=None): dim_model_base = ( config["dim_model_base"] if "dim_model_base" in config else config["hidden_size"] ) + + # Load longrope configuration + rope_type = 0 # 0 = standard, 1 = longrope + original_max_position_embeddings = 0 + short_factor_ptr = None + long_factor_ptr = None + self._short_factor_array = None # Keep reference to prevent GC + self._long_factor_array = None # Keep reference to prevent GC + + rope_scaling = config.get("rope_scaling", {}) + if isinstance(rope_scaling, dict): + rope_scaling_type = rope_scaling.get("rope_type") or rope_scaling.get("type", "") + if rope_scaling_type == "longrope": + rope_type = 1 + original_max_position_embeddings = rope_scaling.get( + "original_max_position_embeddings", + config.get("original_max_position_embeddings", 0) + ) + + short_factor_list = rope_scaling.get("short_factor", []) + long_factor_list = rope_scaling.get("long_factor", []) + + if short_factor_list and long_factor_list: + # Convert to ctypes arrays + dh = config["head_dim"] if "head_dim" in config else config["hidden_size"] // config["num_attention_heads"] + half_dh = dh // 2 + if len(short_factor_list) == half_dh and len(long_factor_list) == half_dh: + self._short_factor_array = (c_float * half_dh)(*short_factor_list) + self._long_factor_array = (c_float * half_dh)(*long_factor_list) + short_factor_ptr = ctypes.cast(self._short_factor_array, POINTER(c_float)) + long_factor_ptr = ctypes.cast(self._long_factor_array, POINTER(c_float)) + else: + print( + f"Warning: Longrope factor arrays have wrong length: " + f"short={len(short_factor_list)}, long={len(long_factor_list)}, expected={half_dh}" + ) + super().__init__( dt_logits=dt_, nlayer=config["num_hidden_layers"], @@ -138,6 +175,10 @@ def __init__(self, config, dtype=torch.float16, max_tokens=None): epsilon=config["rms_norm_eps"], theta=(config["rope_theta"] if "rope_theta" in config else 100000.0), end_token=2, + rope_type=rope_type, + original_max_position_embeddings=original_max_position_embeddings, + short_factor=short_factor_ptr, + long_factor=long_factor_ptr, ) self.torch_dtype_logits = dtype diff --git a/scripts/libinfinicore_infer/jiuge.py b/scripts/libinfinicore_infer/jiuge.py index ec01b610..38cf536d 100644 --- a/scripts/libinfinicore_infer/jiuge.py +++ b/scripts/libinfinicore_infer/jiuge.py @@ -18,6 +18,10 @@ class JiugeMetaCStruct(Structure): ("epsilon", c_float), ("theta", c_float), ("end_token", c_uint), + ("rope_type", c_uint), + ("original_max_position_embeddings", c_size_t), + ("short_factor", POINTER(c_float)), + ("long_factor", POINTER(c_float)), ] diff --git a/src/models/jiuge/jiuge_weight.hpp b/src/models/jiuge/jiuge_weight.hpp index 7ee10155..b36a513b 100644 --- a/src/models/jiuge/jiuge_weight.hpp +++ b/src/models/jiuge/jiuge_weight.hpp @@ -152,10 +152,30 @@ 
inline std::shared_ptr getSinTable(JiugeMeta const *meta) { auto unit = dsize(meta->dt_logits); void *table = std::malloc(meta->dctx * half_dh * unit); + bool is_longrope = (meta->rope_type == 1) && (meta->short_factor != nullptr) && (meta->long_factor != nullptr); + const float *factors = nullptr; + if (is_longrope) { + // Use long_factor if max context exceeds original_max_position_embeddings, otherwise short_factor + // Since we generate a single table at model creation time, choose based on dctx + if (meta->original_max_position_embeddings > 0 && meta->dctx > meta->original_max_position_embeddings) { + factors = meta->long_factor; + } else { + factors = meta->short_factor; + } + } + for (size_t i = 0; i < meta->dctx; i++) { for (size_t j = 0; j < half_dh; j++) { - float _sin = std::sin( - static_cast(i) / std::pow(meta->theta, static_cast(j) / half_dh)); + float angle; + if (is_longrope && factors != nullptr) { + // Longrope: apply per-frequency scaling factor + float inv_freq = 1.0f / (factors[j] * std::pow(meta->theta, static_cast(j) / half_dh)); + angle = static_cast(i) * inv_freq; + } else { + // Standard RoPE + angle = static_cast(i) / std::pow(meta->theta, static_cast(j) / half_dh); + } + float _sin = std::sin(angle); if (meta->dt_logits == INFINI_DTYPE_F16) { ((uint16_t *)table)[i * half_dh + j] = f32_to_f16(_sin); } else if (meta->dt_logits == INFINI_DTYPE_BF16) { @@ -179,10 +199,30 @@ inline std::shared_ptr getCosTable(JiugeMeta const *meta) { auto unit = dsize(meta->dt_logits); void *table = std::malloc(meta->dctx * half_dh * unit); + bool is_longrope = (meta->rope_type == 1) && (meta->short_factor != nullptr) && (meta->long_factor != nullptr); + const float *factors = nullptr; + if (is_longrope) { + // Use long_factor if max context exceeds original_max_position_embeddings, otherwise short_factor + // Since we generate a single table at model creation time, choose based on dctx + if (meta->original_max_position_embeddings > 0 && meta->dctx > meta->original_max_position_embeddings) { + factors = meta->long_factor; + } else { + factors = meta->short_factor; + } + } + for (size_t i = 0; i < meta->dctx; i++) { for (size_t j = 0; j < half_dh; j++) { - float _cos = std::cos( - static_cast(i) / std::pow(meta->theta, static_cast(j) / half_dh)); + float angle; + if (is_longrope && factors != nullptr) { + // Longrope: apply per-frequency scaling factor + float inv_freq = 1.0f / (factors[j] * std::pow(meta->theta, static_cast(j) / half_dh)); + angle = static_cast(i) * inv_freq; + } else { + // Standard RoPE + angle = static_cast(i) / std::pow(meta->theta, static_cast(j) / half_dh); + } + float _cos = std::cos(angle); if (meta->dt_logits == INFINI_DTYPE_F16) { ((uint16_t *)table)[i * half_dh + j] = f32_to_f16(_cos); } else if (meta->dt_logits == INFINI_DTYPE_BF16) { From 9a37b3864ce325949c9ec720a4d84f696baa2837 Mon Sep 17 00:00:00 2001 From: Ceng23333 <441651826@qq.com> Date: Tue, 6 Jan 2026 16:37:24 +0800 Subject: [PATCH 04/12] add timeout checker Signed-off-by: Ceng23333 <441651826@qq.com> --- scripts/launch_server.py | 280 +++++++++++++++++++++++++++++++++++++-- 1 file changed, 266 insertions(+), 14 deletions(-) diff --git a/scripts/launch_server.py b/scripts/launch_server.py index 4b918b6d..71970f23 100644 --- a/scripts/launch_server.py +++ b/scripts/launch_server.py @@ -7,7 +7,7 @@ import argparse import queue -from fastapi import FastAPI, Request +from fastapi import FastAPI, Request, HTTPException from fastapi.responses import StreamingResponse, JSONResponse import contextlib 
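getSinTable and getCosTable fold the chosen factor list into the inverse frequencies, i.e. inv_freq[j] = 1 / (factor[j] * theta^(j / (dh/2))), and pick long_factor over short_factor once at table-build time when dctx exceeds original_max_position_embeddings. A NumPy sketch of the same table construction:

import numpy as np

def rope_tables(dctx, dh, theta, factors=None):
    # Returns (dctx, dh//2) sin and cos tables; factors=None reproduces standard RoPE.
    half = dh // 2
    j = np.arange(half)
    inv_freq = 1.0 / (theta ** (j / half))
    if factors is not None:                          # longrope: per-frequency rescaling
        inv_freq = inv_freq / np.asarray(factors, dtype=np.float64)
    angles = np.arange(dctx)[:, None] * inv_freq[None, :]
    return np.sin(angles), np.cos(angles)

# Factor selection mirrors the C++ above: a single table per model, chosen from dctx:
#   factors = long_factor if dctx > original_max_position_embeddings else short_factor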
import uvicorn @@ -16,6 +16,8 @@ import json import threading import janus +import os +import signal DEVICE_TYPE_MAP = { @@ -81,6 +83,12 @@ def parse_args(): default="0.0.0.0", help="Host to bind the server to (default: 0.0.0.0)", ) + parser.add_argument( + "--request-timeout", + type=int, + default=30, + help="Request timeout in seconds. Process will exit if a request hangs longer than this (default: 30)", + ) return parser.parse_args() @@ -93,9 +101,13 @@ def parse_args(): MAX_BATCH = args.max_batch SERVER_PORT = args.port SERVER_HOST = args.host +REQUEST_TIMEOUT = args.request_timeout print( f"Using MAX_BATCH={MAX_BATCH}. Try reduce this value if out of memory error occurs." ) +print( + f"Request timeout: {REQUEST_TIMEOUT}s. Process will exit if a request hangs longer than this." +) def chunk_json(id_, content=None, role=None, finish_reason=None): @@ -124,15 +136,29 @@ def chunk_json(id_, content=None, role=None, finish_reason=None): # A wrapper for InferTask that supports async output queue class AsyncInferTask(InferTask): - def __init__(self, id, tokens, max_tokens, temperature, topk, topp, end_tokens, repetition_penalty=1.0): + def __init__(self, id, tokens, max_tokens, temperature, topk, topp, end_tokens, repetition_penalty=1.0, test_hang_seconds=0): super().__init__(id, tokens, max_tokens, temperature, topk, topp, end_tokens, repetition_penalty) self.output_queue = janus.Queue() - print(f"[INFO] Create InferTask {self.id}") + self.last_activity_time = time.time() # Track when task was last active + self.test_hang_seconds = test_hang_seconds # Test parameter: sleep for this many seconds to simulate hang (set to 0 after first use) + self.timed_out = False # Flag to mark if task has timed out + print(f"[INFO] Create InferTask {self.id}" + (f" (TEST: will hang for {test_hang_seconds}s once)" if test_hang_seconds > 0 else "")) def output(self, out_token): self.next(out_token) + self.last_activity_time = time.time() # Update activity time when output is generated self.output_queue.sync_q.put(out_token) + def signal_timeout(self): + """Signal that this task has timed out""" + self.timed_out = True + self.finish_reason = "timeout" + + def signal_internal_error(self): + """Signal that an internal error occurred (process will be killed)""" + self.timed_out = True # Reuse timed_out flag to trigger error response + self.finish_reason = "internal_error" + @contextlib.asynccontextmanager async def lifespan(app: FastAPI): @@ -147,9 +173,15 @@ async def lifespan(app: FastAPI): ) app.state.kv_cache_pool = KVCachePool(app.state.model, MAX_BATCH) app.state.request_queue = janus.Queue() + app.state.active_tasks = {} # Track active tasks: task_id -> task object + app.state.task_lock = threading.Lock() # Lock for accessing active_tasks worker_thread = threading.Thread(target=worker_loop, args=(app,), daemon=True) worker_thread.start() + # Start timeout checker thread + timeout_checker_thread = threading.Thread(target=timeout_checker_loop, args=(app,), daemon=True) + timeout_checker_thread.start() + try: yield # The app runs here finally: @@ -165,6 +197,61 @@ async def lifespan(app: FastAPI): App = FastAPI(lifespan=lifespan) +# Timeout checker: monitors active tasks and kills process if any task hangs +def timeout_checker_loop(app): + """Monitor active tasks and kill the process if any task hangs beyond timeout""" + while True: + try: + time.sleep(5) # Check every 5 seconds + + current_time = time.time() + hung_tasks = [] + + with app.state.task_lock: + # Check all active tasks for timeout + for task_id, 
task in list(app.state.active_tasks.items()): + time_since_activity = current_time - task.last_activity_time + if time_since_activity > REQUEST_TIMEOUT: + hung_tasks.append((task_id, time_since_activity)) + + # If we found hung tasks, signal all active tasks and then kill the process + if hung_tasks: + print(f"[ERROR] Detected {len(hung_tasks)} hung task(s) exceeding timeout of {REQUEST_TIMEOUT}s:") + for task_id, hang_time in hung_tasks: + print(f" - Task {task_id}: hung for {hang_time:.1f}s") + + # Signal all active tasks (not just hung ones) to send error responses to clients + # This ensures all processing requests get error responses before process is killed + with app.state.task_lock: + all_active_tasks = list(app.state.active_tasks.items()) + print(f"[ERROR] Signaling {len(all_active_tasks)} active task(s) to send error responses...") + for task_id, task in all_active_tasks: + if task_id in [tid for tid, _ in hung_tasks]: + # Hung tasks get timeout error + task.signal_timeout() + print(f"[ERROR] Signaled timeout to hung task {task_id}") + else: + # Other active tasks get internal error (process will be killed) + task.signal_internal_error() + print(f"[ERROR] Signaled internal error to active task {task_id}") + + # Give a short time for error responses to be sent to clients + print(f"[ERROR] Waiting 2 seconds for error responses to be sent to clients...") + time.sleep(2) + + print(f"[ERROR] Killing process to trigger recovery mechanism...") + # Kill the process - this will be detected by the babysitter and trigger restart + os.kill(os.getpid(), signal.SIGTERM) + # If SIGTERM doesn't work, use SIGKILL as fallback after a delay + time.sleep(2) + os.kill(os.getpid(), signal.SIGKILL) + break + + except Exception as e: + print(f"[ERROR] Exception in timeout checker: {e}") + time.sleep(5) + + # App loop: take requests from the queue, do inference, and put unfinished requests back into the queue. 
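# The timeout checker above is a watchdog: a daemon thread compares each task's
# last_activity_time against REQUEST_TIMEOUT and, when a stall is detected,
# escalates SIGTERM -> SIGKILL so an external supervisor can restart the server.
# A minimal, self-contained sketch of that escalation step (the function name and
# grace period here are illustrative, not lifted verbatim from this server):
import os
import signal
import time

def _kill_self_for_restart(grace_seconds: float = 2.0) -> None:
    """Request a clean exit, then force-kill if SIGTERM was ignored."""
    os.kill(os.getpid(), signal.SIGTERM)   # polite shutdown request first
    time.sleep(grace_seconds)              # give handlers time to flush error responses
    os.kill(os.getpid(), signal.SIGKILL)   # uncatchable fallback if still alive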
def worker_loop(app): while True: @@ -176,22 +263,65 @@ def worker_loop(app): if task is None: return + # Register task as active + with app.state.task_lock: + app.state.active_tasks[task.id] = task + task.last_activity_time = time.time() + batch = [task] while len(batch) < MAX_BATCH: try: req = app.state.request_queue.sync_q.get_nowait() if req is not None: batch.append(req) + # Register additional tasks as active + with app.state.task_lock: + app.state.active_tasks[req.id] = req + req.last_activity_time = time.time() except queue.Empty: break + + # Update activity time before inference + batch_start_time = time.time() + with app.state.task_lock: + for t in batch: + t.last_activity_time = batch_start_time + + # Test hang simulation: if any task has test_hang_seconds > 0, sleep to simulate hang + # Only apply once per task by setting test_hang_seconds to 0 after use + tasks_needing_hang = [t for t in batch if t.test_hang_seconds > 0] + if tasks_needing_hang: + max_hang_time = max(t.test_hang_seconds for t in tasks_needing_hang) + print(f"[TEST] Simulating hang for {max_hang_time}s (task will exceed timeout if timeout < {max_hang_time}s)") + time.sleep(max_hang_time) + print(f"[TEST] Hang simulation complete, continuing with inference...") + # Reset test_hang_seconds to 0 for all tasks that used it (so it won't hang again) + for t in tasks_needing_hang: + t.test_hang_seconds = 0 + output_tokens = app.state.model.batch_infer_one_round(batch) + + # Update activity time after inference (critical: if batch_infer_one_round hangs, + # this won't execute, and timeout checker will detect it) + batch_end_time = time.time() + with app.state.task_lock: + for task, token in zip(batch, output_tokens): + task.last_activity_time = batch_end_time + task.output(token) + if task.finish_reason is None: + # Task continues, keep it tracked but update activity time + # It will be put back in queue and processed again + pass + else: + print(f"[INFO] Task {task.id} finished infer.") + app.state.kv_cache_pool.release_sync(task) + # Remove task from active tracking when finished + app.state.active_tasks.pop(task.id, None) + + # Put unfinished tasks back in queue (outside lock to avoid deadlock) for task, token in zip(batch, output_tokens): - task.output(token) if task.finish_reason is None: app.state.request_queue.sync_q.put(task) - else: - print(f"[INFO] Task {task.id} finished infer.") - app.state.kv_cache_pool.release_sync(task) def build_task(id_, request_data, request: Request): @@ -217,6 +347,12 @@ def build_task(id_, request_data, request: Request): tokens = request.app.state.model.tokenizer.encode(prompt) max_tokens = request_data.get("max_tokens", 0) + # Test parameter: test_hang_seconds - sleep for this many seconds to simulate hang + # This is useful for testing the timeout checker mechanism. 
+ # Example: Set "test_hang_seconds": 350 in request to test timeout (if timeout is 300s) + # The sleep happens in the worker loop before batch_infer_one_round, simulating a hang + test_hang_seconds = request_data.get("test_hang_seconds", 0) + return AsyncInferTask( id_, tokens, @@ -226,14 +362,31 @@ def build_task(id_, request_data, request: Request): request_data.get("top_p", 1.0), request.app.state.model.eos_token_id, request_data.get("repetition_penalty", 1.0), + test_hang_seconds=test_hang_seconds, ) async def chat_stream(id_, request_data, request: Request): try: infer_task = build_task(id_, request_data, request) + # Track task from creation + with request.app.state.task_lock: + request.app.state.active_tasks[infer_task.id] = infer_task + infer_task.last_activity_time = time.time() + await request.app.state.kv_cache_pool.acquire(infer_task) + # Check if task already timed out before starting stream + if infer_task.timed_out: + raise HTTPException( + status_code=504, + detail={ + "message": f"Request timeout: task exceeded {REQUEST_TIMEOUT}s timeout", + "type": "timeout_error", + "code": "timeout" + } + ) + # Initial empty content chunk = json.dumps( chunk_json(id_, content="", role="assistant"), ensure_ascii=False @@ -250,29 +403,75 @@ async def chat_stream(id_, request_data, request: Request): infer_task.finish_reason is not None and infer_task.output_queue.async_q.empty() ): - chunk = json.dumps( - chunk_json(id_, finish_reason=infer_task.finish_reason), - ensure_ascii=False, - ) - yield f"data: {chunk}\n\n" + # Check if timed out or internal error - raise HTTPException for proper error code + if infer_task.timed_out: + # Both timeout and internal_error result in internal error response + # because the process will be killed and restarted + raise HTTPException( + status_code=500, + detail={ + "message": "Internal server error: process will be restarted", + "type": "internal_error", + "code": "internal_error" + } + ) + else: + chunk = json.dumps( + chunk_json(id_, finish_reason=infer_task.finish_reason), + ensure_ascii=False, + ) + yield f"data: {chunk}\n\n" break + # Check for timeout or internal error before getting next token + if infer_task.timed_out: + # Both timeout and internal_error result in internal error response + # because the process will be killed and restarted + raise HTTPException( + status_code=500, + detail={ + "message": "Internal server error: process will be restarted", + "type": "internal_error", + "code": "internal_error" + } + ) + token = await infer_task.output_queue.async_q.get() content = request.app.state.model.tokenizer.decode(token) chunk = json.dumps(chunk_json(id_, content=content), ensure_ascii=False) yield f"data: {chunk}\n\n" + except HTTPException: + # Re-raise HTTPException to propagate error status code + raise except Exception as e: print(f"[Error] ID : {id_} Exception: {e}") + raise HTTPException( + status_code=500, + detail={ + "message": str(e), + "type": "internal_error", + "code": "internal_error" + } + ) finally: - if infer_task.finish_reason is None: - infer_task.finish_reason = "cancel" + if infer_task: + if infer_task.finish_reason is None: + infer_task.finish_reason = "cancel" + # Clean up task from active tracking + with request.app.state.task_lock: + request.app.state.active_tasks.pop(infer_task.id, None) async def chat(id_, request_data, request: Request): try: infer_task = build_task(id_, request_data, request) + # Track task from creation + with request.app.state.task_lock: + request.app.state.active_tasks[infer_task.id] = 
infer_task + infer_task.last_activity_time = time.time() + await request.app.state.kv_cache_pool.acquire(infer_task) request.app.state.request_queue.sync_q.put(infer_task) output = [] @@ -283,10 +482,40 @@ async def chat(id_, request_data, request: Request): ): break + # Check for timeout or internal error before getting next token + if infer_task.timed_out: + # Both timeout and internal_error result in internal error response + # because the process will be killed and restarted + return JSONResponse( + content={ + "error": { + "message": "Internal server error: process will be restarted", + "type": "internal_error", + "code": "internal_error" + } + }, + status_code=500 # Internal Server Error + ) + token = await infer_task.output_queue.async_q.get() content = request.app.state.model.tokenizer.decode(token) output.append(content) + # Check if timed out or internal error before returning response + if infer_task.timed_out: + # Both timeout and internal_error result in internal error response + # because the process will be killed and restarted + return JSONResponse( + content={ + "error": { + "message": "Internal server error: process will be restarted", + "type": "internal_error", + "code": "internal_error" + } + }, + status_code=500 # Internal Server Error + ) + output_text = "".join(output).strip() response = chunk_json( id_, @@ -302,6 +531,9 @@ async def chat(id_, request_data, request: Request): finally: if infer_task.finish_reason is None: infer_task.finish_reason = "cancel" + # Clean up task from active tracking + with request.app.state.task_lock: + request.app.state.active_tasks.pop(infer_task.id, None) @App.post("/chat/completions") @@ -320,11 +552,15 @@ async def chat_completions(request: Request): stream = data.get("stream", False) id_ = f"cmpl-{uuid.uuid4().hex}" if stream: + # FastAPI's exception handler will catch HTTPException raised from the generator return StreamingResponse( chat_stream(id_, data, request), media_type="text/event-stream" ) else: response = await chat(id_, data, request) + # If response is already a JSONResponse (error case), return it directly + if isinstance(response, JSONResponse): + return response return JSONResponse(content=response) @@ -543,5 +779,21 @@ async def list_models(request: Request): "chat_template_kwargs": {"enable_thinking": false} }' +# Test timeout checker: simulate a hang that exceeds the timeout +# This will cause the process to be killed by the timeout checker +# (assuming --request-timeout is set to a value less than test_hang_seconds) +# Example: if --request-timeout=300, use test_hang_seconds=350 to trigger timeout +curl -N -H "Content-Type: application/json" \ + -X POST http://127.0.0.1:8000/chat/completions \ + -d '{ + "model": "jiuge", + "messages": [ + {"role": "user", "content": "Hello"} + ], + "temperature": 0.7, + "test_hang_seconds": 350, + "stream": false + }' + """ From 3d39ebcd2b1050e06822cce5a3e83323933beffc Mon Sep 17 00:00:00 2001 From: Ceng23333 <441651826@qq.com> Date: Wed, 7 Jan 2026 11:38:01 +0800 Subject: [PATCH 05/12] fix streaming exception throwing Signed-off-by: Ceng23333 <441651826@qq.com> --- scripts/launch_server.py | 44 ++++++++++++++++++++++++++++++++-------- 1 file changed, 35 insertions(+), 9 deletions(-) diff --git a/scripts/launch_server.py b/scripts/launch_server.py index 71970f23..211ebaad 100644 --- a/scripts/launch_server.py +++ b/scripts/launch_server.py @@ -403,18 +403,31 @@ async def chat_stream(id_, request_data, request: Request): infer_task.finish_reason is not None and 
infer_task.output_queue.async_q.empty() ): - # Check if timed out or internal error - raise HTTPException for proper error code + # Check if timed out or internal error - yield error chunk instead of raising HTTPException + # (can't raise HTTPException after streaming has started) if infer_task.timed_out: # Both timeout and internal_error result in internal error response # because the process will be killed and restarted - raise HTTPException( - status_code=500, - detail={ + # Yield error chunk in SSE format + error_chunk = { + "id": id_, + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": "unknown", + "choices": [{ + "index": 0, + "delta": {}, + "finish_reason": None + }], + "error": { "message": "Internal server error: process will be restarted", "type": "internal_error", "code": "internal_error" } - ) + } + chunk = json.dumps(error_chunk, ensure_ascii=False) + yield f"data: {chunk}\n\n" + yield "data: [DONE]\n\n" else: chunk = json.dumps( chunk_json(id_, finish_reason=infer_task.finish_reason), @@ -427,14 +440,27 @@ async def chat_stream(id_, request_data, request: Request): if infer_task.timed_out: # Both timeout and internal_error result in internal error response # because the process will be killed and restarted - raise HTTPException( - status_code=500, - detail={ + # Yield error chunk in SSE format (can't raise HTTPException after streaming started) + error_chunk = { + "id": id_, + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": "unknown", + "choices": [{ + "index": 0, + "delta": {}, + "finish_reason": None + }], + "error": { "message": "Internal server error: process will be restarted", "type": "internal_error", "code": "internal_error" } - ) + } + chunk = json.dumps(error_chunk, ensure_ascii=False) + yield f"data: {chunk}\n\n" + yield "data: [DONE]\n\n" + break token = await infer_task.output_queue.async_q.get() content = request.app.state.model.tokenizer.decode(token) From e89073d57d59d00c51808f325232d8acb3d79b88 Mon Sep 17 00:00:00 2001 From: Ceng23333 <441651826@qq.com> Date: Fri, 9 Jan 2026 13:30:40 +0800 Subject: [PATCH 06/12] update /models Signed-off-by: Ceng23333 <441651826@qq.com> --- scripts/launch_server.py | 74 ++++++++++++++++++++++++++++++++-------- 1 file changed, 60 insertions(+), 14 deletions(-) diff --git a/scripts/launch_server.py b/scripts/launch_server.py index 211ebaad..a05c69a8 100644 --- a/scripts/launch_server.py +++ b/scripts/launch_server.py @@ -18,6 +18,7 @@ import janus import os import signal +from pathlib import Path DEVICE_TYPE_MAP = { @@ -89,6 +90,12 @@ def parse_args(): default=30, help="Request timeout in seconds. Process will exit if a request hangs longer than this (default: 30)", ) + parser.add_argument( + "--model-name", + type=str, + default=None, + help="Model name to return in /models endpoint. 
If not specified, will use the directory name from --model-path (like vLLM/llama.cpp)", + ) return parser.parse_args() @@ -102,6 +109,26 @@ def parse_args(): SERVER_PORT = args.port SERVER_HOST = args.host REQUEST_TIMEOUT = args.request_timeout + +# Derive model name from model path directory name (like vLLM and llama.cpp) +# Use --model-name if explicitly provided, otherwise use directory name +if args.model_name: + MODEL_NAME = args.model_name +elif model_path: + # Extract directory name from model path + # This follows the same convention as vLLM and llama.cpp + model_path_obj = Path(model_path).resolve() # Resolve to absolute path + if model_path_obj.is_dir(): + MODEL_NAME = model_path_obj.name + elif model_path_obj.is_file(): + # If it's a file, use the parent directory name + MODEL_NAME = model_path_obj.parent.name + else: + # Path doesn't exist yet, but extract name from path string + # Use the last component of the path + MODEL_NAME = model_path_obj.name or model_path_obj.parent.name +else: + MODEL_NAME = None print( f"Using MAX_BATCH={MAX_BATCH}. Try reduce this value if out of memory error occurs." ) @@ -734,22 +761,34 @@ async def completions(request: Request): async def list_models(request: Request): """ OpenAI-compatible /models endpoint. - Returns a list of available models. + Returns a list of available models (single model specified by --model-name argument). """ try: - # Get model information from app state - model = request.app.state.model - model_id = "jiuge" # Default model ID - - # Try to get model name from config if available - if hasattr(model, 'config') and model.config: - # Try model_type first - model_id = model.config.get("model_type", "jiuge") - # If model_type is not informative, try architectures - if model_id == "jiuge" and "architectures" in model.config: - architectures = model.config.get("architectures", []) - if architectures: - model_id = architectures[0].lower() + # Check if model is loaded (server is ready) + if not hasattr(request.app.state, 'model') or request.app.state.model is None: + # Server not ready yet - return 503 Service Unavailable + return JSONResponse( + content={"error": "Service not ready yet, model still loading"}, + status_code=503 + ) + + # Use model name from argument/directory name, otherwise try to detect from config + model_id = MODEL_NAME + + if not model_id: + # Get model information from app state + model = request.app.state.model + model_id = "unknown" # Default model ID + + # Try to get model name from config if available + if hasattr(model, 'config') and model.config: + # Try model_type first + model_id = model.config.get("model_type", "unknown") + # If model_type is not informative, try architectures + if model_id == "unknown" and "architectures" in model.config: + architectures = model.config.get("architectures", []) + if architectures: + model_id = architectures[0].lower() return JSONResponse(content={ "object": "list", @@ -765,6 +804,13 @@ async def list_models(request: Request): } ] }) + except AttributeError as e: + # Model not loaded yet + print(f"[Error] Model not loaded in /models: {e}") + return JSONResponse( + content={"error": "Service not ready yet, model still loading"}, + status_code=503 + ) except Exception as e: print(f"[Error] Exception in /models: {e}") return JSONResponse(content={"error": str(e)}, status_code=500) From f52dd5922878780ec69d6c34f2fe6702e0dc7981 Mon Sep 17 00:00:00 2001 From: Ceng23333 <441651826@qq.com> Date: Fri, 9 Jan 2026 16:57:07 +0800 Subject: [PATCH 07/12] fix format of response 
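This change switches the non-streaming /chat/completions response from the SSE
chunk format to a proper OpenAI-style "chat.completion" object with a usage
block, via the new chat_completion_json() helper added below. A sketch of the
intended call (argument values are illustrative, not taken from a real request):

    response = chat_completion_json(
        "cmpl-example",
        content="Hello!",
        finish_reason="stop",
        prompt_tokens=12,
        completion_tokens=3,
        model="jiuge",
    )
    # response["object"] == "chat.completion"
    # response["usage"]["total_tokens"] == 15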
Signed-off-by: Ceng23333 <441651826@qq.com> --- scripts/launch_server.py | 118 ++++++++++++++++++++++++++++++++++++--- 1 file changed, 109 insertions(+), 9 deletions(-) diff --git a/scripts/launch_server.py b/scripts/launch_server.py index a05c69a8..dfbba96a 100644 --- a/scripts/launch_server.py +++ b/scripts/launch_server.py @@ -137,7 +137,11 @@ def parse_args(): ) -def chunk_json(id_, content=None, role=None, finish_reason=None): +def chunk_json(id_, content=None, role=None, finish_reason=None, model="jiuge"): + """ + Generate SSE chunk format for streaming responses. + Used for Server-Sent Events (SSE) streaming mode. + """ delta = {} if content: delta["content"] = content @@ -147,12 +151,11 @@ def chunk_json(id_, content=None, role=None, finish_reason=None): "id": id_, "object": "chat.completion.chunk", "created": int(time.time()), - "model": "jiuge", + "model": model, "system_fingerprint": None, "choices": [ { "index": 0, - "text": content, "delta": delta, "logprobs": None, "finish_reason": finish_reason, @@ -161,6 +164,36 @@ def chunk_json(id_, content=None, role=None, finish_reason=None): } +def chat_completion_json(id_, content, role="assistant", finish_reason=None, prompt_tokens=0, completion_tokens=0, model="jiuge"): + """ + Generate OpenAI-compatible non-streaming chat completion response. + Used for non-streaming (stream=False) mode. + """ + return { + "id": id_, + "object": "chat.completion", + "created": int(time.time()), + "model": model, + "system_fingerprint": None, + "choices": [ + { + "index": 0, + "message": { + "role": role, + "content": content + }, + "logprobs": None, + "finish_reason": finish_reason, + } + ], + "usage": { + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + "total_tokens": prompt_tokens + completion_tokens + } + } + + # A wrapper for InferTask that supports async output queue class AsyncInferTask(InferTask): def __init__(self, id, tokens, max_tokens, temperature, topk, topp, end_tokens, repetition_penalty=1.0, test_hang_seconds=0): @@ -169,11 +202,15 @@ def __init__(self, id, tokens, max_tokens, temperature, topk, topp, end_tokens, self.last_activity_time = time.time() # Track when task was last active self.test_hang_seconds = test_hang_seconds # Test parameter: sleep for this many seconds to simulate hang (set to 0 after first use) self.timed_out = False # Flag to mark if task has timed out + self.initial_prompt_tokens = len(tokens) # Track initial prompt token count for usage statistics + self.generated_tokens = [] # Track generated token IDs for counting completion tokens print(f"[INFO] Create InferTask {self.id}" + (f" (TEST: will hang for {test_hang_seconds}s once)" if test_hang_seconds > 0 else "")) def output(self, out_token): self.next(out_token) self.last_activity_time = time.time() # Update activity time when output is generated + if out_token is not None: # Track non-None tokens for completion count + self.generated_tokens.append(out_token) self.output_queue.sync_q.put(out_token) def signal_timeout(self): @@ -414,9 +451,12 @@ async def chat_stream(id_, request_data, request: Request): } ) + # Get model name from request or use global MODEL_NAME + model_name = request_data.get("model", MODEL_NAME or "jiuge") + # Initial empty content chunk = json.dumps( - chunk_json(id_, content="", role="assistant"), ensure_ascii=False + chunk_json(id_, content="", role="assistant", model=model_name), ensure_ascii=False ) yield f"data: {chunk}\n\n" @@ -440,7 +480,7 @@ async def chat_stream(id_, request_data, request: Request): 
"id": id_, "object": "chat.completion.chunk", "created": int(time.time()), - "model": "unknown", + "model": model_name, "choices": [{ "index": 0, "delta": {}, @@ -456,11 +496,14 @@ async def chat_stream(id_, request_data, request: Request): yield f"data: {chunk}\n\n" yield "data: [DONE]\n\n" else: + # Final chunk: empty delta with finish_reason (OpenAI API spec) chunk = json.dumps( - chunk_json(id_, finish_reason=infer_task.finish_reason), + chunk_json(id_, finish_reason=infer_task.finish_reason, model=model_name), ensure_ascii=False, ) yield f"data: {chunk}\n\n" + # Send [DONE] marker to indicate stream completion (OpenAI API spec) + yield "data: [DONE]\n\n" break # Check for timeout or internal error before getting next token @@ -472,7 +515,7 @@ async def chat_stream(id_, request_data, request: Request): "id": id_, "object": "chat.completion.chunk", "created": int(time.time()), - "model": "unknown", + "model": model_name, "choices": [{ "index": 0, "delta": {}, @@ -490,9 +533,21 @@ async def chat_stream(id_, request_data, request: Request): break token = await infer_task.output_queue.async_q.get() + + # Skip EOS tokens - don't include them in the stream + # The finish_reason will be set in the final chunk + if token is None: + continue + # Handle end_tokens as list, tuple, or single value + if isinstance(infer_task.end_tokens, (list, tuple)): + if token in infer_task.end_tokens: + continue + elif token == infer_task.end_tokens: + continue + content = request.app.state.model.tokenizer.decode(token) - chunk = json.dumps(chunk_json(id_, content=content), ensure_ascii=False) + chunk = json.dumps(chunk_json(id_, content=content, model=model_name), ensure_ascii=False) yield f"data: {chunk}\n\n" except HTTPException: @@ -551,6 +606,17 @@ async def chat(id_, request_data, request: Request): ) token = await infer_task.output_queue.async_q.get() + + # Skip EOS tokens - don't include them in the output + if token is None: + continue + # Handle end_tokens as list, tuple, or single value + if isinstance(infer_task.end_tokens, (list, tuple)): + if token in infer_task.end_tokens: + continue + elif token == infer_task.end_tokens: + continue + content = request.app.state.model.tokenizer.decode(token) output.append(content) @@ -570,11 +636,45 @@ async def chat(id_, request_data, request: Request): ) output_text = "".join(output).strip() - response = chunk_json( + + # Strip EOS token strings from the end of output (defensive check) + # Decode EOS tokens to get their string representations + eos_token_strings = [] + end_tokens_list = [] + if isinstance(infer_task.end_tokens, (list, tuple)): + end_tokens_list = list(infer_task.end_tokens) + else: + end_tokens_list = [infer_task.end_tokens] + + for eos_token_id in end_tokens_list: + try: + eos_str = request.app.state.model.tokenizer.decode([eos_token_id]) + if eos_str: + eos_token_strings.append(eos_str) + except Exception: + pass + + # Remove EOS token strings from the end of output + for eos_str in eos_token_strings: + if output_text.endswith(eos_str): + output_text = output_text[:-len(eos_str)].rstrip() + + # Calculate token usage + prompt_tokens = infer_task.initial_prompt_tokens + completion_tokens = len(infer_task.generated_tokens) + + # Get model name from request data or use global MODEL_NAME + model_name = request_data.get("model", MODEL_NAME or "jiuge") + + # Use correct OpenAI-compatible format for non-streaming response + response = chat_completion_json( id_, content=output_text, role="assistant", finish_reason=infer_task.finish_reason or "stop", + 
prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + model=model_name ) return response From 1f52f246c2dcca82f2baf22b26182c54c444edd3 Mon Sep 17 00:00:00 2001 From: Ceng23333 <441651826@qq.com> Date: Sat, 10 Jan 2026 12:06:56 +0800 Subject: [PATCH 08/12] fix utf-8 decode issue Signed-off-by: Ceng23333 <441651826@qq.com> --- scripts/launch_server.py | 67 +++++++++++++++++++++++++++++++++++----- 1 file changed, 60 insertions(+), 7 deletions(-) diff --git a/scripts/launch_server.py b/scripts/launch_server.py index dfbba96a..8ab44368 100644 --- a/scripts/launch_server.py +++ b/scripts/launch_server.py @@ -462,6 +462,12 @@ async def chat_stream(id_, request_data, request: Request): request.app.state.request_queue.sync_q.put(infer_task) + # For streaming: accumulate tokens and decode incrementally to handle UTF-8 properly + # We maintain a buffer and decode the full buffer each time, only yielding new characters + # This ensures multi-byte UTF-8 sequences (emojis, etc.) are decoded correctly + token_buffer = [] + last_yielded_length = 0 + while True: if await request.is_disconnected(): print("Client disconnected. Aborting stream.") @@ -470,6 +476,17 @@ async def chat_stream(id_, request_data, request: Request): infer_task.finish_reason is not None and infer_task.output_queue.async_q.empty() ): + # Decode any remaining tokens in buffer before finishing + if token_buffer: + try: + decoded_text = request.app.state.model.tokenizer.decode(token_buffer, skip_special_tokens=False) + remaining_content = decoded_text[last_yielded_length:] + if remaining_content: + chunk = json.dumps(chunk_json(id_, content=remaining_content, model=model_name), ensure_ascii=False) + yield f"data: {chunk}\n\n" + except Exception: + pass + # Check if timed out or internal error - yield error chunk instead of raising HTTPException # (can't raise HTTPException after streaming has started) if infer_task.timed_out: @@ -545,10 +562,40 @@ async def chat_stream(id_, request_data, request: Request): elif token == infer_task.end_tokens: continue - content = request.app.state.model.tokenizer.decode(token) + # Accumulate token in buffer + token_buffer.append(token) - chunk = json.dumps(chunk_json(id_, content=content, model=model_name), ensure_ascii=False) - yield f"data: {chunk}\n\n" + # Decode the entire buffer each time to ensure proper UTF-8 handling + # The tokenizer handles multi-byte sequences correctly when decoding token lists + try: + decoded_text = request.app.state.model.tokenizer.decode(token_buffer, skip_special_tokens=False) + + # Calculate new content by comparing current decode with what we've already yielded + if len(decoded_text) > last_yielded_length: + new_content = decoded_text[last_yielded_length:] + + # Only yield if we have new content + if new_content: + chunk = json.dumps(chunk_json(id_, content=new_content, model=model_name), ensure_ascii=False) + yield f"data: {chunk}\n\n" + last_yielded_length = len(decoded_text) + + # Prevent buffer from growing too large by periodically flushing + # Keep last 5 tokens for multi-token character sequences + if len(token_buffer) > 20: + # Keep last 5 tokens and adjust last_yielded_length accordingly + tokens_to_keep = token_buffer[-5:] + decoded_kept = request.app.state.model.tokenizer.decode(tokens_to_keep, skip_special_tokens=False) + # Adjust: subtract the length of removed tokens' decoded text + removed_text = request.app.state.model.tokenizer.decode(token_buffer[:-5], skip_special_tokens=False) + last_yielded_length = len(decoded_text) - 
len(removed_text) + token_buffer = tokens_to_keep + + except Exception as e: + # If decoding fails, skip this token and log (don't break the stream) + print(f"[Warning] Failed to decode token {token}: {e}") + if len(token_buffer) > 0: + token_buffer.pop() # Remove the problematic token except HTTPException: # Re-raise HTTPException to propagate error status code @@ -582,7 +629,8 @@ async def chat(id_, request_data, request: Request): await request.app.state.kv_cache_pool.acquire(infer_task) request.app.state.request_queue.sync_q.put(infer_task) - output = [] + # Collect all tokens first, then decode at once to preserve UTF-8 sequences + tokens = [] while True: if ( infer_task.finish_reason is not None @@ -617,8 +665,8 @@ async def chat(id_, request_data, request: Request): elif token == infer_task.end_tokens: continue - content = request.app.state.model.tokenizer.decode(token) - output.append(content) + # Collect tokens - decode all at once to preserve multi-byte UTF-8 characters + tokens.append(token) # Check if timed out or internal error before returning response if infer_task.timed_out: @@ -635,7 +683,12 @@ async def chat(id_, request_data, request: Request): status_code=500 # Internal Server Error ) - output_text = "".join(output).strip() + # Decode all tokens at once to preserve multi-byte UTF-8 sequences (emojis, etc.) + # This is critical for proper handling of characters that span multiple tokens + if tokens: + output_text = request.app.state.model.tokenizer.decode(tokens, skip_special_tokens=False).strip() + else: + output_text = "" # Strip EOS token strings from the end of output (defensive check) # Decode EOS tokens to get their string representations From a1702820414a513d16f1d938ccff99647432d137 Mon Sep 17 00:00:00 2001 From: Ceng23333 <441651826@qq.com> Date: Sat, 10 Jan 2026 20:06:09 +0800 Subject: [PATCH 09/12] add max_cocurrency and replacement_char remove Signed-off-by: Ceng23333 <441651826@qq.com> --- scripts/launch_server.py | 250 +++++++++++++++++++++++++++++++++++++-- 1 file changed, 242 insertions(+), 8 deletions(-) diff --git a/scripts/launch_server.py b/scripts/launch_server.py index 8ab44368..43b61e4b 100644 --- a/scripts/launch_server.py +++ b/scripts/launch_server.py @@ -18,6 +18,7 @@ import janus import os import signal +import asyncio from pathlib import Path @@ -96,6 +97,19 @@ def parse_args(): default=None, help="Model name to return in /models endpoint. If not specified, will use the directory name from --model-path (like vLLM/llama.cpp)", ) + parser.add_argument( + "--max-concurrency", + type=int, + default=None, + help="Maximum number of concurrent requests. Requests exceeding this limit will wait in queue. Default: unlimited", + ) + parser.add_argument( + "--fix-replacement-chars", + type=lambda x: x.lower() in ('true', '1', 'yes', 'on'), + default=None, + help="Whether to automatically remove replacement characters (U+FFFD) from responses. " + "Default: True. 
Set to false to only log detections without removing them.", + ) return parser.parse_args() @@ -109,6 +123,15 @@ def parse_args(): SERVER_PORT = args.port SERVER_HOST = args.host REQUEST_TIMEOUT = args.request_timeout +MAX_CONCURRENCY = args.max_concurrency + +# Fix replacement characters config: CLI arg takes precedence, then env var, default to False +if args.fix_replacement_chars is not None: + FIX_REPLACEMENT_CHARS = args.fix_replacement_chars +else: + # Check environment variable, default to False if not set + env_value = os.environ.get("FIX_REPLACEMENT_CHARS", "false").lower() + FIX_REPLACEMENT_CHARS = env_value in ('true', '1', 'yes', 'on') # Derive model name from model path directory name (like vLLM and llama.cpp) # Use --model-name if explicitly provided, otherwise use directory name @@ -135,6 +158,11 @@ def parse_args(): print( f"Request timeout: {REQUEST_TIMEOUT}s. Process will exit if a request hangs longer than this." ) +if MAX_CONCURRENCY is not None and MAX_CONCURRENCY > 0: + print(f"Max concurrency: {MAX_CONCURRENCY}. Requests exceeding this limit will wait in queue.") +else: + print("Max concurrency: unlimited") +print(f"Fix replacement characters: {FIX_REPLACEMENT_CHARS} (U+FFFD will be {'removed' if FIX_REPLACEMENT_CHARS else 'logged only'})") def chunk_json(id_, content=None, role=None, finish_reason=None, model="jiuge"): @@ -239,6 +267,15 @@ async def lifespan(app: FastAPI): app.state.request_queue = janus.Queue() app.state.active_tasks = {} # Track active tasks: task_id -> task object app.state.task_lock = threading.Lock() # Lock for accessing active_tasks + + # Initialize concurrency control semaphore + if MAX_CONCURRENCY is not None and MAX_CONCURRENCY > 0: + app.state.concurrency_semaphore = asyncio.Semaphore(MAX_CONCURRENCY) + print(f"Max concurrency: {MAX_CONCURRENCY}. Requests exceeding this limit will wait in queue.") + else: + app.state.concurrency_semaphore = None + print("Max concurrency: unlimited") + worker_thread = threading.Thread(target=worker_loop, args=(app,), daemon=True) worker_thread.start() @@ -317,6 +354,66 @@ def timeout_checker_loop(app): # App loop: take requests from the queue, do inference, and put unfinished requests back into the queue. +def detect_and_fix_replacement_character(text, context="", request_id=None, fix=True): + """ + Detect and optionally fix replacement character (U+FFFD, ) in text. + + Args: + text: Text to check for replacement characters + context: Context string for logging (e.g., "streaming response", "non-streaming response") + request_id: Optional request ID for logging + fix: If True, remove replacement characters from the text. If False, only detect and log. + + Returns: + tuple: (cleaned_text: str, has_replacement: bool) + """ + replacement_char = '\ufffd' # U+FFFD + positions = [] + + for i, char in enumerate(text): + if ord(char) == 0xFFFD: # Unicode replacement character + positions.append(i) + + if positions: + # Log detailed information about replacement characters + log_prefix = f"[REPLACEMENT_CHAR_DETECTED]" + if request_id: + log_prefix += f" Request ID: {request_id}" + if context: + log_prefix += f" Context: {context}" + + print(f"{log_prefix} Found {len(positions)} replacement character(s) (U+FFFD)") + print(f" Positions: {positions[:10]}{'...' 
if len(positions) > 10 else ''}") + + # Collect and log context samples around first few replacement characters + for pos in positions[:3]: # Log first 3 occurrences + start = max(0, pos - 30) + end = min(len(text), pos + 30) + context_sample = text[start:end] + before = text[max(0, pos-10):pos] if pos > 0 else '' + after = text[pos+1:min(len(text), pos+11)] if pos < len(text)-1 else '' + print(f" Position {pos}:") + print(f" Before: {repr(before)}") + print(f" After: {repr(after)}") + print(f" Full context: {repr(context_sample)}") + + # Also print a snippet of the text for debugging + snippet_start = max(0, positions[0] - 100) + snippet_end = min(len(text), positions[0] + 100) + print(f" Text snippet (around first occurrence): {repr(text[snippet_start:snippet_end])}") + + # Fix: Remove all replacement characters + if fix: + # Remove all U+FFFD characters + cleaned_text = ''.join(char for char in text if ord(char) != 0xFFFD) + print(f" [FIXED] Removed {len(positions)} replacement character(s) from output") + return cleaned_text, True + else: + return text, True + + return text, False + + def worker_loop(app): while True: try: @@ -431,6 +528,21 @@ def build_task(id_, request_data, request: Request): async def chat_stream(id_, request_data, request: Request): + # Acquire concurrency semaphore if configured (waits if max concurrency reached) + semaphore = request.app.state.concurrency_semaphore + if semaphore: + await semaphore.acquire() + try: + async for item in _chat_stream_impl(id_, request_data, request): + yield item + finally: + semaphore.release() + else: + async for item in _chat_stream_impl(id_, request_data, request): + yield item + + +async def _chat_stream_impl(id_, request_data, request: Request): try: infer_task = build_task(id_, request_data, request) # Track task from creation @@ -480,6 +592,12 @@ async def chat_stream(id_, request_data, request: Request): if token_buffer: try: decoded_text = request.app.state.model.tokenizer.decode(token_buffer, skip_special_tokens=False) + + # Detect and fix replacement characters in the decoded text + decoded_text, has_replacement = detect_and_fix_replacement_character( + decoded_text, context="streaming response (final)", request_id=id_, fix=FIX_REPLACEMENT_CHARS + ) + remaining_content = decoded_text[last_yielded_length:] if remaining_content: chunk = json.dumps(chunk_json(id_, content=remaining_content, model=model_name), ensure_ascii=False) @@ -570,6 +688,12 @@ async def chat_stream(id_, request_data, request: Request): try: decoded_text = request.app.state.model.tokenizer.decode(token_buffer, skip_special_tokens=False) + # Detect and fix replacement characters in the decoded text + # We fix the entire decoded_text to maintain consistency with last_yielded_length tracking + decoded_text, has_replacement = detect_and_fix_replacement_character( + decoded_text, context="streaming response chunk", request_id=id_, fix=FIX_REPLACEMENT_CHARS + ) + # Calculate new content by comparing current decode with what we've already yielded if len(decoded_text) > last_yielded_length: new_content = decoded_text[last_yielded_length:] @@ -582,14 +706,43 @@ async def chat_stream(id_, request_data, request: Request): # Prevent buffer from growing too large by periodically flushing # Keep last 5 tokens for multi-token character sequences + # CRITICAL: When trimming, we need to calculate how many characters from the kept tokens + # have already been yielded. We do this by: + # 1. 
Calculating what portion of the full decoded text corresponds to removed tokens + # 2. The remaining portion (kept tokens) may have already been partially yielded + # 3. When we decode only the kept tokens, we need to figure out which portion was already sent if len(token_buffer) > 20: - # Keep last 5 tokens and adjust last_yielded_length accordingly - tokens_to_keep = token_buffer[-5:] - decoded_kept = request.app.state.model.tokenizer.decode(tokens_to_keep, skip_special_tokens=False) - # Adjust: subtract the length of removed tokens' decoded text - removed_text = request.app.state.model.tokenizer.decode(token_buffer[:-5], skip_special_tokens=False) - last_yielded_length = len(decoded_text) - len(removed_text) - token_buffer = tokens_to_keep + # Calculate what portion of decoded_text corresponds to removed tokens + num_removed = len(token_buffer) - 5 + removed_tokens = token_buffer[:num_removed] + kept_tokens = token_buffer[-5:] + + # Decode removed and kept portions separately + try: + removed_text = request.app.state.model.tokenizer.decode(removed_tokens, skip_special_tokens=False) + # The removed portion starts from the beginning, so its length is what we've already fully processed + # Calculate how many characters from kept tokens were already yielded + # We know: decoded_text = removed_text + kept_portion_from_full_decode + # And we've yielded up to last_yielded_length characters + if last_yielded_length > len(removed_text): + # Some portion of the kept tokens was already yielded + # Calculate the offset into the kept portion that was already sent + chars_yielded_from_kept = last_yielded_length - len(removed_text) + else: + # Nothing from kept tokens was yielded yet + chars_yielded_from_kept = 0 + + # Decode the kept tokens to get their full text + decoded_kept = request.app.state.model.tokenizer.decode(kept_tokens, skip_special_tokens=False) + + # Update last_yielded_length to point to the character position in the new buffer + # that corresponds to what we've already sent + last_yielded_length = min(chars_yielded_from_kept, len(decoded_kept)) + token_buffer = kept_tokens + except Exception: + # If decoding fails during trimming, don't trim (safer to keep buffer) + # Log but continue with current buffer + print(f"[Warning] Failed to decode during buffer trim, keeping full buffer") except Exception as e: # If decoding fails, skip this token and log (don't break the stream) @@ -620,6 +773,19 @@ async def chat_stream(id_, request_data, request: Request): async def chat(id_, request_data, request: Request): + # Acquire concurrency semaphore if configured (waits if max concurrency reached) + semaphore = request.app.state.concurrency_semaphore + if semaphore: + await semaphore.acquire() + try: + return await _chat_impl(id_, request_data, request) + finally: + semaphore.release() + else: + return await _chat_impl(id_, request_data, request) + + +async def _chat_impl(id_, request_data, request: Request): try: infer_task = build_task(id_, request_data, request) # Track task from creation @@ -685,8 +851,39 @@ async def chat(id_, request_data, request: Request): # Decode all tokens at once to preserve multi-byte UTF-8 sequences (emojis, etc.) 
# This is critical for proper handling of characters that span multiple tokens + # CRITICAL: Ensure proper UTF-8 handling to avoid replacement characters in markdown references if tokens: - output_text = request.app.state.model.tokenizer.decode(tokens, skip_special_tokens=False).strip() + try: + # Decode tokens to string - tokenizer should handle UTF-8 correctly + decoded_output = request.app.state.model.tokenizer.decode(tokens, skip_special_tokens=False) + # Handle both bytes and str return types from tokenizer + if isinstance(decoded_output, bytes): + # If tokenizer returns bytes, decode with UTF-8 + output_text = decoded_output.decode('utf-8', errors='strict').strip() + else: + # If tokenizer returns str, verify it's valid UTF-8 by re-encoding/decoding + # This catches any invalid surrogate pairs or other UTF-8 issues + # Use 'surrogatepass' to preserve surrogates, then 'strict' for final decode + try: + # Validate UTF-8 by encoding and decoding - if it fails, there's an issue + validated = decoded_output.encode('utf-8', errors='strict').decode('utf-8', errors='strict') + output_text = validated.strip() + except UnicodeError: + # If validation fails, try with error handling to recover + print(f"[Warning] UTF-8 validation failed, using replacement for invalid sequences") + output_text = decoded_output.encode('utf-8', errors='replace').decode('utf-8', errors='replace').strip() + except Exception as e: + print(f"[Error] Token decoding error: {e}") + # Last resort: try with error replacement + try: + decoded_output = request.app.state.model.tokenizer.decode(tokens, skip_special_tokens=False) + if isinstance(decoded_output, bytes): + output_text = decoded_output.decode('utf-8', errors='replace').strip() + else: + output_text = decoded_output.encode('utf-8', errors='replace').decode('utf-8', errors='replace').strip() + except Exception as e2: + print(f"[Error] Failed to decode tokens even with error handling: {e2}") + output_text = "" else: output_text = "" @@ -712,6 +909,11 @@ async def chat(id_, request_data, request: Request): if output_text.endswith(eos_str): output_text = output_text[:-len(eos_str)].rstrip() + # Detect and fix replacement characters (U+FFFD) in output + output_text, has_replacement = detect_and_fix_replacement_character( + output_text, context="non-streaming response", request_id=id_, fix=FIX_REPLACEMENT_CHARS + ) + # Calculate token usage prompt_tokens = infer_task.initial_prompt_tokens completion_tokens = len(infer_task.generated_tokens) @@ -842,6 +1044,11 @@ async def completion(id_, request_data, request: Request): output_text = "".join(output).strip() + # Detect and fix replacement characters in completion output + output_text, has_replacement = detect_and_fix_replacement_character( + output_text, context="completion response", request_id=id_, fix=FIX_REPLACEMENT_CHARS + ) + # Prepare tokens list for logprobs tokens_list = [] text_offset_list = [] @@ -1020,5 +1227,32 @@ async def list_models(request: Request): "stream": false }' +# Test UTF-8 decoding fix: Markdown reference with special characters +# This validates that characters don't appear in markdown references +curl -s -X POST http://127.0.0.1:8000/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "Qwen3-32B", + "messages": [ + {"role": "user", "content": "请写一段包含markdown链接的文本,例如 [链接文本](https://example.com) 和引用 [^1]"} + ], + "temperature": 0.7, + "max_tokens": 2000, + "stream": false + }' | python3 -c "import sys, json; data = json.load(sys.stdin); content = 
data['choices'][0]['message']['content'] if 'choices' in data else ''; print('PASS' if '' not in content and '' not in content else 'FAIL: Found replacement characters'); print(content[:200])" + +# Test UTF-8 decoding: Chinese with emojis +curl -s -X POST http://127.0.0.1:8000/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "jiuge", + "messages": [ + {"role": "user", "content": "请用中文回答:什么是人工智能?使用一些表情符号。"} + ], + "temperature": 0.7, + "max_tokens": 150, + "stream": false + }' | python3 -c "import sys, json; data = json.load(sys.stdin); content = data['choices'][0]['message']['content'] if 'choices' in data else ''; print('PASS' if '' not in content and '' not in content else 'FAIL: Found replacement characters'); print(content[:200])" + """ From 2d778e54cfd708fe1f6a10748490b92fe7130b04 Mon Sep 17 00:00:00 2001 From: Ceng23333 <441651826@qq.com> Date: Sun, 11 Jan 2026 09:59:43 +0800 Subject: [PATCH 10/12] optimize uffd checker Signed-off-by: Ceng23333 <441651826@qq.com> --- scripts/launch_server.py | 91 ++++++++++++++++++++++++---------------- 1 file changed, 56 insertions(+), 35 deletions(-) diff --git a/scripts/launch_server.py b/scripts/launch_server.py index 43b61e4b..4c467e2c 100644 --- a/scripts/launch_server.py +++ b/scripts/launch_server.py @@ -354,59 +354,78 @@ def timeout_checker_loop(app): # App loop: take requests from the queue, do inference, and put unfinished requests back into the queue. -def detect_and_fix_replacement_character(text, context="", request_id=None, fix=True): +def detect_and_fix_replacement_character(text, context="", request_id=None, fix=True, verbose_log=True): """ Detect and optionally fix replacement character (U+FFFD, ) in text. + Optimized for performance - fast path when no replacement chars found. Args: text: Text to check for replacement characters context: Context string for logging (e.g., "streaming response", "non-streaming response") request_id: Optional request ID for logging fix: If True, remove replacement characters from the text. If False, only detect and log. + verbose_log: If False, only log a brief message (for streaming chunks) Returns: tuple: (cleaned_text: str, has_replacement: bool) """ + # Fast path: Use in operator first (faster than character-by-character scan for most cases) + if '\ufffd' not in text: + return text, False + + # Found potential replacement chars, now count and get positions replacement_char = '\ufffd' # U+FFFD positions = [] - for i, char in enumerate(text): - if ord(char) == 0xFFFD: # Unicode replacement character - positions.append(i) + # Optimize: Use faster method to find all positions + # Using str.find in a loop is faster than enumerate for long strings + start = 0 + while True: + pos = text.find(replacement_char, start) + if pos == -1: + break + positions.append(pos) + start = pos + 1 if positions: - # Log detailed information about replacement characters - log_prefix = f"[REPLACEMENT_CHAR_DETECTED]" - if request_id: - log_prefix += f" Request ID: {request_id}" - if context: - log_prefix += f" Context: {context}" - - print(f"{log_prefix} Found {len(positions)} replacement character(s) (U+FFFD)") - print(f" Positions: {positions[:10]}{'...' 
if len(positions) > 10 else ''}") - - # Collect and log context samples around first few replacement characters - for pos in positions[:3]: # Log first 3 occurrences - start = max(0, pos - 30) - end = min(len(text), pos + 30) - context_sample = text[start:end] - before = text[max(0, pos-10):pos] if pos > 0 else '' - after = text[pos+1:min(len(text), pos+11)] if pos < len(text)-1 else '' - print(f" Position {pos}:") - print(f" Before: {repr(before)}") - print(f" After: {repr(after)}") - print(f" Full context: {repr(context_sample)}") - - # Also print a snippet of the text for debugging - snippet_start = max(0, positions[0] - 100) - snippet_end = min(len(text), positions[0] + 100) - print(f" Text snippet (around first occurrence): {repr(text[snippet_start:snippet_end])}") + # Logging (only if verbose_log is True, or for non-streaming responses) + if verbose_log: + log_prefix = f"[REPLACEMENT_CHAR_DETECTED]" + if request_id: + log_prefix += f" Request ID: {request_id}" + if context: + log_prefix += f" Context: {context}" + + print(f"{log_prefix} Found {len(positions)} replacement character(s) (U+FFFD)") + print(f" Positions: {positions[:10]}{'...' if len(positions) > 10 else ''}") + + # Collect and log context samples around first few replacement characters + for pos in positions[:3]: # Log first 3 occurrences + start = max(0, pos - 30) + end = min(len(text), pos + 30) + context_sample = text[start:end] + before = text[max(0, pos-10):pos] if pos > 0 else '' + after = text[pos+1:min(len(text), pos+11)] if pos < len(text)-1 else '' + print(f" Position {pos}:") + print(f" Before: {repr(before)}") + print(f" After: {repr(after)}") + print(f" Full context: {repr(context_sample)}") + + # Also print a snippet of the text for debugging + snippet_start = max(0, positions[0] - 100) + snippet_end = min(len(text), positions[0] + 100) + print(f" Text snippet (around first occurrence): {repr(text[snippet_start:snippet_end])}") + else: + # # Brief log for streaming chunks + # print(f"[REPLACEMENT_CHAR_DETECTED] {len(positions)} replacement char(s) in {context} (Request: {request_id or 'N/A'})") + pass # Fix: Remove all replacement characters if fix: - # Remove all U+FFFD characters - cleaned_text = ''.join(char for char in text if ord(char) != 0xFFFD) - print(f" [FIXED] Removed {len(positions)} replacement character(s) from output") + # Optimize: Use str.replace which is faster than list comprehension for this case + cleaned_text = text.replace(replacement_char, '') + if verbose_log: + print(f" [FIXED] Removed {len(positions)} replacement character(s) from output") return cleaned_text, True else: return text, True @@ -594,8 +613,9 @@ async def _chat_stream_impl(id_, request_data, request: Request): decoded_text = request.app.state.model.tokenizer.decode(token_buffer, skip_special_tokens=False) # Detect and fix replacement characters in the decoded text + # Use verbose_log=False for streaming to reduce logging overhead decoded_text, has_replacement = detect_and_fix_replacement_character( - decoded_text, context="streaming response (final)", request_id=id_, fix=FIX_REPLACEMENT_CHARS + decoded_text, context="streaming response (final)", request_id=id_, fix=FIX_REPLACEMENT_CHARS, verbose_log=False ) remaining_content = decoded_text[last_yielded_length:] @@ -690,8 +710,9 @@ async def _chat_stream_impl(id_, request_data, request: Request): # Detect and fix replacement characters in the decoded text # We fix the entire decoded_text to maintain consistency with last_yielded_length tracking + # Use verbose_log=False 
for streaming chunks to reduce logging overhead decoded_text, has_replacement = detect_and_fix_replacement_character( - decoded_text, context="streaming response chunk", request_id=id_, fix=FIX_REPLACEMENT_CHARS + decoded_text, context="streaming response chunk", request_id=id_, fix=FIX_REPLACEMENT_CHARS, verbose_log=False ) # Calculate new content by comparing current decode with what we've already yielded From 7da9d5b75e0bbf5784c6e8e95df9b156dedce62d Mon Sep 17 00:00:00 2001 From: Ceng23333 <441651826@qq.com> Date: Sun, 11 Jan 2026 10:37:27 +0800 Subject: [PATCH 11/12] revert fix_replacement and apply vllm-like handling Signed-off-by: Ceng23333 <441651826@qq.com> --- scripts/launch_server.py | 159 ++++++++------------------------------- 1 file changed, 33 insertions(+), 126 deletions(-) diff --git a/scripts/launch_server.py b/scripts/launch_server.py index 4c467e2c..43098b60 100644 --- a/scripts/launch_server.py +++ b/scripts/launch_server.py @@ -103,13 +103,6 @@ def parse_args(): default=None, help="Maximum number of concurrent requests. Requests exceeding this limit will wait in queue. Default: unlimited", ) - parser.add_argument( - "--fix-replacement-chars", - type=lambda x: x.lower() in ('true', '1', 'yes', 'on'), - default=None, - help="Whether to automatically remove replacement characters (U+FFFD) from responses. " - "Default: True. Set to false to only log detections without removing them.", - ) return parser.parse_args() @@ -125,14 +118,6 @@ def parse_args(): REQUEST_TIMEOUT = args.request_timeout MAX_CONCURRENCY = args.max_concurrency -# Fix replacement characters config: CLI arg takes precedence, then env var, default to False -if args.fix_replacement_chars is not None: - FIX_REPLACEMENT_CHARS = args.fix_replacement_chars -else: - # Check environment variable, default to False if not set - env_value = os.environ.get("FIX_REPLACEMENT_CHARS", "false").lower() - FIX_REPLACEMENT_CHARS = env_value in ('true', '1', 'yes', 'on') - # Derive model name from model path directory name (like vLLM and llama.cpp) # Use --model-name if explicitly provided, otherwise use directory name if args.model_name: @@ -162,7 +147,6 @@ def parse_args(): print(f"Max concurrency: {MAX_CONCURRENCY}. Requests exceeding this limit will wait in queue.") else: print("Max concurrency: unlimited") -print(f"Fix replacement characters: {FIX_REPLACEMENT_CHARS} (U+FFFD will be {'removed' if FIX_REPLACEMENT_CHARS else 'logged only'})") def chunk_json(id_, content=None, role=None, finish_reason=None, model="jiuge"): @@ -354,85 +338,6 @@ def timeout_checker_loop(app): # App loop: take requests from the queue, do inference, and put unfinished requests back into the queue. -def detect_and_fix_replacement_character(text, context="", request_id=None, fix=True, verbose_log=True): - """ - Detect and optionally fix replacement character (U+FFFD, ) in text. - Optimized for performance - fast path when no replacement chars found. - - Args: - text: Text to check for replacement characters - context: Context string for logging (e.g., "streaming response", "non-streaming response") - request_id: Optional request ID for logging - fix: If True, remove replacement characters from the text. If False, only detect and log. 
-        verbose_log: If False, only log a brief message (for streaming chunks)
-
-    Returns:
-        tuple: (cleaned_text: str, has_replacement: bool)
-    """
-    # Fast path: Use in operator first (faster than character-by-character scan for most cases)
-    if '\ufffd' not in text:
-        return text, False
-
-    # Found potential replacement chars, now count and get positions
-    replacement_char = '\ufffd'  # U+FFFD
-    positions = []
-
-    # Optimize: Use faster method to find all positions
-    # Using str.find in a loop is faster than enumerate for long strings
-    start = 0
-    while True:
-        pos = text.find(replacement_char, start)
-        if pos == -1:
-            break
-        positions.append(pos)
-        start = pos + 1
-
-    if positions:
-        # Logging (only if verbose_log is True, or for non-streaming responses)
-        if verbose_log:
-            log_prefix = f"[REPLACEMENT_CHAR_DETECTED]"
-            if request_id:
-                log_prefix += f" Request ID: {request_id}"
-            if context:
-                log_prefix += f" Context: {context}"
-
-            print(f"{log_prefix} Found {len(positions)} replacement character(s) (U+FFFD)")
-            print(f"  Positions: {positions[:10]}{'...' if len(positions) > 10 else ''}")
-
-            # Collect and log context samples around first few replacement characters
-            for pos in positions[:3]:  # Log first 3 occurrences
-                start = max(0, pos - 30)
-                end = min(len(text), pos + 30)
-                context_sample = text[start:end]
-                before = text[max(0, pos-10):pos] if pos > 0 else ''
-                after = text[pos+1:min(len(text), pos+11)] if pos < len(text)-1 else ''
-                print(f"  Position {pos}:")
-                print(f"    Before: {repr(before)}")
-                print(f"    After: {repr(after)}")
-                print(f"    Full context: {repr(context_sample)}")
-
-            # Also print a snippet of the text for debugging
-            snippet_start = max(0, positions[0] - 100)
-            snippet_end = min(len(text), positions[0] + 100)
-            print(f"  Text snippet (around first occurrence): {repr(text[snippet_start:snippet_end])}")
-        else:
-            # # Brief log for streaming chunks
-            # print(f"[REPLACEMENT_CHAR_DETECTED] {len(positions)} replacement char(s) in {context} (Request: {request_id or 'N/A'})")
-            pass
-
-        # Fix: Remove all replacement characters
-        if fix:
-            # Optimize: Use str.replace which is faster than list comprehension for this case
-            cleaned_text = text.replace(replacement_char, '')
-            if verbose_log:
-                print(f"  [FIXED] Removed {len(positions)} replacement character(s) from output")
-            return cleaned_text, True
-        else:
-            return text, True
-
-    return text, False
-
-
 def worker_loop(app):
     while True:
         try:
@@ -612,14 +517,13 @@ async def _chat_stream_impl(id_, request_data, request: Request):
             try:
                 decoded_text = request.app.state.model.tokenizer.decode(token_buffer, skip_special_tokens=False)

-                # Detect and fix replacement characters in the decoded text
-                # Use verbose_log=False for streaming to reduce logging overhead
-                decoded_text, has_replacement = detect_and_fix_replacement_character(
-                    decoded_text, context="streaming response (final)", request_id=id_, fix=FIX_REPLACEMENT_CHARS, verbose_log=False
-                )
-
+                # On final chunk, yield everything including any trailing replacement chars
+                # (they're no longer incomplete sequences since this is the end)
                 remaining_content = decoded_text[last_yielded_length:]
                 if remaining_content:
+                    # Log if replacement chars present (for monitoring)
+                    if '\ufffd' in remaining_content:
+                        print(f"[REPLACEMENT_CHAR_DETECTED] Found replacement char(s) in final streaming chunk (Request: {id_})")
                     chunk = json.dumps(chunk_json(id_, content=remaining_content, model=model_name), ensure_ascii=False)
                     yield f"data: {chunk}\n\n"
         except Exception:
@@ -708,31 +612,33 @@ async def _chat_stream_impl(id_, request_data, request: Request):
                 try:
                     decoded_text = request.app.state.model.tokenizer.decode(token_buffer, skip_special_tokens=False)

-                    # Detect and fix replacement characters in the decoded text
-                    # We fix the entire decoded_text to maintain consistency with last_yielded_length tracking
-                    # Use verbose_log=False for streaming chunks to reduce logging overhead
-                    decoded_text, has_replacement = detect_and_fix_replacement_character(
-                        decoded_text, context="streaming response chunk", request_id=id_, fix=FIX_REPLACEMENT_CHARS, verbose_log=False
-                    )
-
-                    # Calculate new content by comparing current decode with what we've already yielded
-                    if len(decoded_text) > last_yielded_length:
-                        new_content = decoded_text[last_yielded_length:]
+                    # vLLM-style UTF-8 buffering: if text ends with replacement char (U+FFFD),
+                    # it's likely an incomplete UTF-8 byte sequence - hold it back until more tokens arrive
+                    # Only check the end, not the middle (middle replacement chars are real invalid tokens)
+                    holds_back_incomplete_utf8 = False
+                    if decoded_text and decoded_text.endswith('\ufffd'):
+                        # Incomplete UTF-8 sequence - hold back this chunk
+                        holds_back_incomplete_utf8 = True
+                    else:
+                        # Calculate new content by comparing current decode with what we've already yielded
+                        if len(decoded_text) > last_yielded_length:
+                            new_content = decoded_text[last_yielded_length:]

-                        # Only yield if we have new content
-                        if new_content:
-                            chunk = json.dumps(chunk_json(id_, content=new_content, model=model_name), ensure_ascii=False)
-                            yield f"data: {chunk}\n\n"
-                            last_yielded_length = len(decoded_text)
+                            # Only yield if we have new content
+                            if new_content:
+                                chunk = json.dumps(chunk_json(id_, content=new_content, model=model_name), ensure_ascii=False)
+                                yield f"data: {chunk}\n\n"
+                                last_yielded_length = len(decoded_text)

                     # Prevent buffer from growing too large by periodically flushing
                     # Keep last 5 tokens for multi-token character sequences
-                    # CRITICAL: When trimming, we need to calculate how many characters from the kept tokens
+                    # CRITICAL: Don't trim if we're holding back incomplete UTF-8 (could break sequence)
+                    # When trimming, we need to calculate how many characters from the kept tokens
                     # have already been yielded. We do this by:
                     # 1. Calculating what portion of the full decoded text corresponds to removed tokens
                     # 2. The remaining portion (kept tokens) may have already been partially yielded
                     # 3. When we decode only the kept tokens, we need to figure out which portion was already sent
-                    if len(token_buffer) > 20:
+                    if not holds_back_incomplete_utf8 and len(token_buffer) > 20:
                         # Calculate what portion of decoded_text corresponds to removed tokens
                         num_removed = len(token_buffer) - 5
                         removed_tokens = token_buffer[:num_removed]
@@ -930,10 +836,12 @@ async def _chat_impl(id_, request_data, request: Request):
         if output_text.endswith(eos_str):
             output_text = output_text[:-len(eos_str)].rstrip()

-        # Detect and fix replacement characters (U+FFFD) in output
-        output_text, has_replacement = detect_and_fix_replacement_character(
-            output_text, context="non-streaming response", request_id=id_, fix=FIX_REPLACEMENT_CHARS
-        )
+        # vLLM-style: Log replacement characters for monitoring
+        # Replacement chars at the end would have been incomplete UTF-8, but for non-streaming
+        # we decode the full sequence, so any replacement chars are real invalid tokens
+        if '\ufffd' in output_text:
+            # Log for monitoring/debugging
+            print(f"[REPLACEMENT_CHAR_DETECTED] Found replacement char(s) in non-streaming response (Request: {id_})")

         # Calculate token usage
         prompt_tokens = infer_task.initial_prompt_tokens
@@ -1065,10 +973,9 @@ async def completion(id_, request_data, request: Request):

     output_text = "".join(output).strip()

-    # Detect and fix replacement characters in completion output
-    output_text, has_replacement = detect_and_fix_replacement_character(
-        output_text, context="completion response", request_id=id_, fix=FIX_REPLACEMENT_CHARS
-    )
+    # vLLM-style: Log replacement characters for monitoring
+    if '\ufffd' in output_text:
+        print(f"[REPLACEMENT_CHAR_DETECTED] Found replacement char(s) in completion response (Request: {id_})")

     # Prepare tokens list for logprobs
     tokens_list = []

From 3a570a816fdd95c9b1408b3177f822db43147ea4 Mon Sep 17 00:00:00 2001
From: Ceng23333 <441651826@qq.com>
Date: Tue, 13 Jan 2026 10:25:18 +0800
Subject: [PATCH 12/12] update ceval script

Signed-off-by: Ceng23333 <441651826@qq.com>
---
 scripts/test_ceval.py | 233 +++++++++++++++++++++++++++++++-----------
 1 file changed, 174 insertions(+), 59 deletions(-)

diff --git a/scripts/test_ceval.py b/scripts/test_ceval.py
index 83365f4a..b040e4a4 100644
--- a/scripts/test_ceval.py
+++ b/scripts/test_ceval.py
@@ -10,13 +10,16 @@ def __init__(
         super().__init__(model_dir_path, device, ndev, max_tokens)
         pass

-    def generate(self, conversation, max_steps, topp_=1.0, topk_=1, temperature_=1.0):
+    def generate(self, conversation, max_steps, topp_=1.0, topk_=0, temperature_=1.0, repetition_penalty_=1.03):
+        # Align with launch_server.py: use apply_chat_template with enable_thinking=False
+        template_params = {
+            "conversation": conversation,
+            "add_generation_prompt": True,
+            "tokenize": False,
+            "enable_thinking": False  # Disable thinking mode
+        }
         input_content = (
-            self.tokenizer.apply_chat_template(
-                conversation=conversation,
-                add_generation_prompt=True,
-                tokenize=False,
-            )
+            self.tokenizer.apply_chat_template(**template_params)
             + "正确答案是"
         )

@@ -31,30 +34,42 @@ def generate(self, conversation, max_steps, topp_=1.0, topk_=1, temperature_=1.0
                 topk_,
                 topp_,
                 self.eos_token_id,
+                repetition_penalty_,
             )

         infer_task.bind_kvcache(KVCache(self))

         steps = 0
         total_time = 0
-        output_content = ""
+        # Collect all tokens first, then decode at once to preserve UTF-8 sequences (aligned with launch_server)
+        output_tokens = []

         for step_i in range(max_steps):
             start_time = time.time()
-            output_tokens = self.batch_infer_one_round([infer_task])
+            step_output_tokens = self.batch_infer_one_round([infer_task])
             end_time = time.time()
             steps += 1
-            output_str = self.tokenizer.decode(output_tokens[0])
-            output_content += output_str
-            print(output_str, end="", flush=True)
-            if output_tokens[0] in self.eos_token_id:
+
+            token = step_output_tokens[0]
+
+            # Check for EOS before adding to buffer
+            if token is None or token in self.eos_token_id:
                 break
-            infer_task.next(output_tokens[0])
+
+            output_tokens.append(token)
+            infer_task.next(token)
             if step_i > 0:
                 total_time += end_time - start_time

+        # Decode all tokens at once to preserve multi-byte UTF-8 sequences (aligned with launch_server non-streaming)
+        if output_tokens:
+            output_content = self.tokenizer.decode(output_tokens, skip_special_tokens=False).strip()
+            print(output_content, end="", flush=True)
+        else:
+            output_content = ""
+
         print("\n")
-        avg_time = total_time * 1000 / (steps - 1 + 1e-9)
+        avg_time = total_time * 1000 / (steps - 1 + 1e-9) if steps > 1 else 0
         print(f"Time per step: {avg_time:.3f}ms")

         infer_task._kv_cache.drop(self)
@@ -94,60 +109,160 @@ def test():
         )
         sys.exit(1)

-    # https://huggingface.co/datasets/ceval/ceval-exam/tree/main/middle_school_geography
+    # Full list of subjects from https://huggingface.co/datasets/ceval/ceval-exam/tree/main
+    ALL_SUBJECTS = [
+        "accountant",
+        "advanced_mathematics",
+        "art_studies",
+        "basic_medicine",
+        "business_administration",
+        "chinese_language_and_literature",
+        "civil_servant",
+        "clinical_medicine",
+        "college_chemistry",
+        "college_economics",
+        "college_physics",
+        "college_programming",
+        "computer_architecture",
+        "computer_network",
+        "discrete_mathematics",
+        "education_science",
+        "electrical_engineer",
+        "environmental_impact_assessment_engineer",
+        "fire_engineer",
+        "high_school_biology",
+        "high_school_chemistry",
+        "high_school_chinese",
+        "high_school_geography",
+        "high_school_history",
+        "high_school_mathematics",
+        "high_school_physics",
+        "high_school_politics",
+        "ideological_and_moral_cultivation",
+        "law",
+        "legal_professional",
+        "logic",
+        "mao_zedong_thought",
+        "marxism",
+        "metrology_engineer",
+        "middle_school_biology",
+        "middle_school_chemistry",
+        "middle_school_geography",
+        "middle_school_history",
+        "middle_school_mathematics",
+        "middle_school_physics",
+        "middle_school_politics",
+        "modern_chinese_history",
+        "operating_system",
+        "physician",
+        "plant_protection",
+        "probability_and_statistics",
+        "professional_tour_guide",
+        "sports_science",
+        "tax_accountant",
+        "teacher_qualification",
+        "urban_and_rural_planner",
+        "veterinary_medicine"
+    ]

-    dataset = load_dataset(r"ceval/ceval-exam", name="middle_school_mathematics")
-    # dataset = load_dataset(r"ceval/ceval-exam", name="high_school_history")
-    # dataset = load_dataset(r"ceval/ceval-exam", name="high_school_chinese")
-    # dataset = load_dataset(r"ceval/ceval-exam", name="high_school_physics")
-    # dataset = load_dataset(r"ceval/ceval-exam", name="middle_school_geography")
-    # dataset = load_dataset(r"ceval/ceval-exam", name="middle_school_physics")
-
-    samples = dataset["val"]
     ndev = int(sys.argv[3]) if len(sys.argv) > 3 else 1
     model = JiugeForCeval(model_path, device_type, ndev)

-    answers_list = []
-    for sample in samples:
-        input_content = f"'question':{sample['question']},'A': {sample['A']}, 'B':{sample['B']}, 'C': {sample['C']},'D': {sample['D']}。"
-        conversation = [
-            {
-                "role": "system",
-                "content": "请从question的A,B,C,D四个选项中选择正确的选项。例如,标准答案:A。",
-            },
-            {"role": "user", "content": input_content},
-        ]
-
-        answer = sample["answer"]
-        output_content, avg_time = model.generate(
-            conversation, 500, topp_=1.0, topk_=1, temperature_=1.0
-        )
-        print("标准答案:", answer)
-        answers_list.append(
-            {"id": sample["id"], "output_content": output_content, "answer": answer}
-        )
+    # Overall statistics across all subjects
+    overall_true_num = 0
+    overall_all_num = 0
+    subject_results = {}
+
+    # Test each subject
+    for subject in ALL_SUBJECTS:
+        print("=" * 80)
+        print(f"Testing subject: {subject}")
+        print("=" * 80)
+        try:
+            # Load dataset for this subject
+            dataset = load_dataset(r"ceval/ceval-exam", name=subject)
+            samples = dataset["test"]
+
+            answers_list = []
+            for sample in samples:
+                input_content = f"'question':{sample['question']},'A': {sample['A']}, 'B':{sample['B']}, 'C': {sample['C']},'D': {sample['D']}。"
+                conversation = [
+                    {
+                        "role": "system",
+                        "content": "请从question的A,B,C,D四个选项中选择正确的选项。例如,标准答案:A。",
+                    },
+                    {"role": "user", "content": input_content},
+                ]
+
+                answer = sample["answer"]
+                output_content, avg_time = model.generate(
+                    conversation, 1000, topp_=0.1, topk_=0, temperature_=0.9
+                )
+                print("标准答案:", answer)
+                answers_list.append(
+                    {"id": sample["id"], "output_content": output_content, "answer": answer}
+                )
+
+            # Calculate accuracy for this subject
+            true_num = 0
+            all_num = 0
+            for cont in answers_list:
+                id = cont["id"]
+                output = cont["output_content"]
+                answer = cont["answer"]
+
+                all_num = all_num + 1
+                position = 0
+                ABCD = output[position : position + 2]
+                if answer in ABCD:
+                    true_num = true_num + 1
+                    print(f"id {id} : ", "正确")
+                else:
+                    print(f"id {id}: ", "错误")
+
+            accuracy = true_num / all_num if all_num > 0 else 0.0
+            print(f"\nSubject: {subject}")
+            print(f"成绩: {true_num}/{all_num} = {accuracy:.4f} ({accuracy*100:.2f}%)")
+
+            # Store results
+            subject_results[subject] = {
+                "true_num": true_num,
+                "all_num": all_num,
+                "accuracy": accuracy
+            }
+            overall_true_num += true_num
+            overall_all_num += all_num
+
+        except Exception as e:
+            print(f"Error testing subject {subject}: {e}")
+            subject_results[subject] = {
+                "true_num": 0,
+                "all_num": 0,
+                "accuracy": 0.0,
+                "error": str(e)
+            }
+
+    # Destroy model instance
     model.destroy_model_instance()
-    print("-------------------------------------------------------------")
-
-    true_num = 0
-    all_num = 0
-    for cont in answers_list:
-        id = cont["id"]
-        output = cont["output_content"]
-        answer = cont["answer"]
-
-        all_num = all_num + 1
-        position = 0
-        ABCD = output[position : position + 2]
-        if answer in ABCD:
-            true_num = true_num + 1
-            print(f"id {id} : ", "正确")
+    # Print summary
+    print("\n" + "=" * 80)
+    print("SUMMARY - All Subjects")
+    print("=" * 80)
+
+    # Print results for each subject
+    for subject, result in subject_results.items():
+        if "error" in result:
+            print(f"{subject:45s}: ERROR - {result['error']}")
         else:
-            print(f"id {id}: ", "错误")
+            print(f"{subject:45s}: {result['true_num']:4d}/{result['all_num']:4d} = {result['accuracy']*100:6.2f}%")

-    print(f"成绩: {true_num}/{all_num}", true_num / all_num)
+    # Print overall statistics
+    overall_accuracy = overall_true_num / overall_all_num if overall_all_num > 0 else 0.0
+    print("=" * 80)
+    print(f"Overall: {overall_true_num}/{overall_all_num} = {overall_accuracy:.4f} ({overall_accuracy*100:.2f}%)")
+    print("=" * 80)

 if __name__ == "__main__":
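For reference, the trailing-replacement-character hold-back that the streaming changes above rely on can be sketched in isolation. The snippet below is an illustration only, not code from this series: the byte-level decode() is a hypothetical stand-in for the real tokenizer, and only the hold-back and final-flush logic mirror the patched behavior in launch_server.py.

# Minimal sketch (illustration, assumed stand-in tokenizer): vLLM-style incremental
# decoding where a chunk whose decoded text ends in U+FFFD is held back until more
# tokens complete the multi-byte UTF-8 character. Each "token" here is a raw byte.

def decode(tokens):
    # Stand-in for tokenizer.decode(); truncated sequences yield U+FFFD.
    return bytes(tokens).decode("utf-8", errors="replace")

def stream_decode(token_stream):
    buffer, yielded = [], 0
    for tok in token_stream:
        buffer.append(tok)
        text = decode(buffer)
        if text.endswith("\ufffd"):
            continue  # likely an incomplete UTF-8 sequence - wait for more tokens
        if len(text) > yielded:
            yield text[yielded:]
            yielded = len(text)
    # Final flush: emit whatever remains, replacement chars included.
    text = decode(buffer)
    if len(text) > yielded:
        yield text[yielded:]

# "你" is 0xE4 0xBD 0xA0 in UTF-8; the first two bytes alone decode to U+FFFD
# and are held back until the third byte arrives.
assert "".join(stream_decode([0xE4, 0xBD, 0xA0, 0x21])) == "你!"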