
Commit a11f736

support chat_template_kwargs in v1/chat/completions (#4201)
* support chat_template_kwargs in v1/chat/completions
* fix
* fix
* fix
* fix ut
1 parent 32f1f0c commit a11f736

File tree

11 files changed: +74 -61 lines changed

lmdeploy/model.py

Lines changed: 3 additions & 6 deletions
@@ -768,18 +768,16 @@ def messages2prompt(self, messages, sequence_start=True, **kwargs):
                 'Each message should be a dict with "role" and "content" keys.'

         if 'enable_thinking' in kwargs and kwargs['enable_thinking'] is None:
-            # Workaround for internlm/Intern-S1: the chat template expects a <think> tag appended,
-            # but when enable_thinking=None is specified, the <think> tag is omitted.
+            # Workaround for internlm/Intern-S1: when enable_thinking=None is passed to apply_chat_template,
+            # the <think> tag is not generated.
             kwargs.pop('enable_thinking')
-        if 'reasoning_effort' in kwargs and kwargs.get('reasoning_effort', None) is None:
+        if 'reasoning_effort' in kwargs and kwargs['reasoning_effort'] is None:
             kwargs.pop('reasoning_effort')
-        add_vision_id = kwargs.pop('add_vision_id', False)
         add_generation_prompt = messages[-1]['role'] != 'assistant'
         if sequence_start:
             prompt = self.tokenizer.apply_chat_template(messages,
                                                         tokenize=False,
                                                         add_generation_prompt=add_generation_prompt,
-                                                        add_vision_id=add_vision_id,
                                                         **kwargs)
         else:
             # Use a sentinel position to avoid the influence of default system role in the tokenizer's chat template
@@ -790,7 +788,6 @@ def messages2prompt(self, messages, sequence_start=True, **kwargs):
             prompt = self.tokenizer.apply_chat_template(sentinel_messages + messages,
                                                         tokenize=False,
                                                         add_generation_prompt=add_generation_prompt,
-                                                        add_vision_id=add_vision_id,
                                                         **kwargs)
             # remove the sentinel part
             prompt = prompt[len(sentinel_prompt):]
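
In practice, whatever dict the caller supplies as chat_template_kwargs is now forwarded verbatim to tokenizer.apply_chat_template (only None-valued enable_thinking/reasoning_effort are dropped first), so template flags such as add_vision_id no longer need dedicated plumbing. A minimal sketch of the resulting call, assuming a Hugging Face tokenizer whose chat template reads enable_thinking; the checkpoint name is purely illustrative:

from transformers import AutoTokenizer

# Illustrative checkpoint; any chat model whose template understands these keys behaves the same way.
tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen3-8B')

messages = [{'role': 'user', 'content': 'hi'}]
chat_template_kwargs = {'enable_thinking': False}

# messages2prompt forwards the dict untouched, so this mirrors what the template receives;
# extra keys are simply exposed as variables to the Jinja template.
prompt = tokenizer.apply_chat_template(messages,
                                       tokenize=False,
                                       add_generation_prompt=True,
                                       **chat_template_kwargs)
print(prompt)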

lmdeploy/serve/async_engine.py

Lines changed: 17 additions & 8 deletions
@@ -196,13 +196,15 @@ def __call__(self,
                  gen_config: Optional[GenerationConfig] = None,
                  stream_response: bool = True,
                  do_preprocess: bool = True,
-                 adapter_name: str = None) -> Union[Response, Iterator[Response]]:
+                 adapter_name: str = None,
+                 **kwargs) -> Union[Response, Iterator[Response]]:
         self._engine.chat(prompt,
                           gen_config=gen_config or self._gen_config,
                           stream_response=stream_response,
                           do_preprocess=do_preprocess,
                           session=self,
-                          adapter_name=adapter_name)
+                          adapter_name=adapter_name,
+                          **kwargs)
         if stream_response:
             return self.generator
         else:
@@ -691,7 +693,7 @@ async def _get_prompt_input(self,
                                 adapter_name: str,
                                 tools: Optional[List[object]] = None,
                                 reasoning_effort: Optional[Literal['low', 'medium', 'high']] = None,
-                                enable_thinking: Optional[bool] = None,
+                                chat_template_kwargs: Optional[Dict] = None,
                                 **kwargs):
         # Change multimodal data to openai text messages, i.e.,
         # [{'role': 'user', 'content': [{'type': 'text', 'text': 'hi'}]}] ->
@@ -706,12 +708,12 @@ async def _get_prompt_input(self,
             chat_template = MODELS.module_dict[adapter_name]()
         else:
             chat_template = BaseChatTemplate()
+        chat_template_kwargs = chat_template_kwargs or {}
         prompt = chat_template.messages2prompt(prompt,
                                                sequence_start,
                                                tools=tools,
-                                               enable_thinking=enable_thinking,
                                                reasoning_effort=reasoning_effort,
-                                               **kwargs)
+                                               **chat_template_kwargs)
         if prompt is None:
             raise ValueError(
                 f'You are using base template to handle chat task. Please specify a `--chat-template` name chosen from `lmdeploy list` if you want to use OpenAI messages input.'  # noqa
@@ -768,7 +770,7 @@ async def generate(
             rewind_stop_tokens: bool = False,
             input_ids: Optional[List] = None,
             enable_thinking: Optional[bool] = None,
-            add_vision_id: Optional[bool] = False,
+            chat_template_kwargs: Optional[Dict] = None,
             mm_processor_kwargs: Optional[Dict[str, Any]] = None,
             **kwargs):
         """Generate responses.
@@ -811,6 +813,14 @@ async def generate(
         if gen_config.n > 1:
             logger.warning(f'n({gen_config.n}) > 1 hasn\'t been supported yet. Fallback to 1')
             gen_config.n = 1
+        chat_template_kwargs = chat_template_kwargs or {}
+        if enable_thinking is not None:
+            logger.warning('enable_thinking is deprecated, use chat_template_kwargs["enable_thinking"] instead')
+            if chat_template_kwargs.get('enable_thinking') is None:
+                chat_template_kwargs['enable_thinking'] = enable_thinking
+            else:
+                logger.warning('chat_template_kwargs["enable_thinking"] is already set, '
+                               'the value will not be overwritten by enable_thinking')
         if messages:
             prompt = messages
             self.request_logger.log_prompt(session_id=session_id, prompt=prompt)
@@ -820,9 +830,8 @@ async def generate(
                                                        adapter_name,
                                                        tools=tools,
                                                        reasoning_effort=reasoning_effort,
-                                                       enable_thinking=enable_thinking,
-                                                       add_vision_id=add_vision_id,
                                                        mm_processor_kwargs=mm_processor_kwargs,
+                                                       chat_template_kwargs=chat_template_kwargs,
                                                        **kwargs)
         prompt = prompt_input['prompt']
         input_ids = prompt_input['input_ids']
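
Note the precedence in the new generate() logic: the deprecated enable_thinking argument only backfills chat_template_kwargs['enable_thinking'] when the caller has not set it, and a warning is logged either way. A hedged usage sketch against the updated signature; the model path, session id, and generation settings below are illustrative and not part of this commit:

import asyncio

from lmdeploy import GenerationConfig
from lmdeploy.serve.async_engine import AsyncEngine

engine = AsyncEngine('Qwen/Qwen3-8B')  # illustrative model path

async def main():
    messages = [{'role': 'user', 'content': 'hi'}]
    # Preferred form after this commit: template flags travel in chat_template_kwargs
    # instead of the deprecated top-level enable_thinking argument.
    async for out in engine.generate(messages,
                                     session_id=0,
                                     gen_config=GenerationConfig(max_new_tokens=128),
                                     chat_template_kwargs={'enable_thinking': False}):
        print(out.response, end='')

asyncio.run(main())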

lmdeploy/serve/openai/api_server.py

Lines changed: 12 additions & 5 deletions
@@ -457,6 +457,15 @@ async def chat_completions_v1(request: ChatCompletionRequest, raw_request: Reque
         tools = [item.function.model_dump() for item in request.tools]
     # text completion for string input
     do_preprocess = False if isinstance(request.messages, str) else request.do_preprocess
+    chat_template_kwargs = request.chat_template_kwargs or {}
+    if request.enable_thinking is not None:
+        logger.warning('`enable_thinking` will be deprecated in the future, '
+                       'please use `chat_template_kwargs` instead.')
+        if chat_template_kwargs.get('enable_thinking') is None:
+            chat_template_kwargs['enable_thinking'] = request.enable_thinking
+        else:
+            logger.warning('`enable_thinking` in `chat_template_kwargs` will override the value in request.')
+    enable_thinking = chat_template_kwargs.get('enable_thinking', None)
     result_generator = VariableInterface.async_engine.generate(
         request.messages,
         request.session_id,
@@ -468,8 +477,7 @@ async def chat_completions_v1(request: ChatCompletionRequest, raw_request: Reque
         sequence_end=True,
         do_preprocess=do_preprocess,
         adapter_name=adapter_name,
-        enable_thinking=request.enable_thinking,
-        add_vision_id=request.add_vision_id,
+        chat_template_kwargs=chat_template_kwargs or None,
         mm_processor_kwargs=request.mm_processor_kwargs)

     def create_stream_response_json(index: int,
@@ -543,8 +551,7 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]:
             elif (request.tool_choice != 'none' and request.tools is not None
                   and VariableInterface.tool_parser is None):
                 logger.error('Please launch the api_server with --tool-call-parser if you want to use tool.')
-
-            if VariableInterface.reasoning_parser is not None and request.enable_thinking is not False:
+            if VariableInterface.reasoning_parser is not None and enable_thinking is not False:
                 reasoning_delta = VariableInterface.reasoning_parser.extract_reasoning_content_streaming(
                     previous_text=previous_text,
                     current_text=current_text,
@@ -617,7 +624,7 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]:
     elif request.tool_choice != 'none' and request.tools is not None and VariableInterface.tool_parser is None:
         logger.error('Please launch the api_server with --tool-call-parser if you want to use tool.')

-    if VariableInterface.reasoning_parser is not None and request.enable_thinking is not False:
+    if VariableInterface.reasoning_parser is not None and enable_thinking is not False:
         reasoning_content, text = VariableInterface.reasoning_parser.extract_reasoning_content(text, request)

     message = ChatMessage(role='assistant',
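
From a client's point of view, chat_template_kwargs is just an extra JSON field on /v1/chat/completions; the server merges it with the deprecated enable_thinking field before calling generate, and the reasoning parser keys off the merged value. A sketch with the official openai client, where base_url, api_key, and the served model name are placeholders for your own deployment:

from openai import OpenAI

client = OpenAI(base_url='http://0.0.0.0:23333/v1', api_key='dummy')  # default api_server port

resp = client.chat.completions.create(
    model=client.models.list().data[0].id,
    messages=[{'role': 'user', 'content': 'hi'}],
    # chat_template_kwargs is not part of the OpenAI schema, so it rides in extra_body.
    extra_body={'chat_template_kwargs': {'enable_thinking': False}},
)
print(resp.choices[0].message.content)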

lmdeploy/serve/openai/protocol.py

Lines changed: 6 additions & 2 deletions
@@ -149,10 +149,14 @@ class ChatCompletionRequest(BaseModel):
     seed: Optional[int] = None
     min_new_tokens: Optional[int] = Field(default=None, examples=[None])
     min_p: float = 0.0
-    enable_thinking: Optional[bool] = None
-    add_vision_id: Optional[bool] = False
+    enable_thinking: Optional[bool] = None  # will be deprecated in the future
     return_token_ids: Optional[bool] = False
     include_stop_str_in_output: Optional[bool] = False
+    chat_template_kwargs: dict[str, Any] | None = Field(
+        default=None,
+        description=('Additional keyword args to pass to the template renderer. '
+                     'Will be accessible by the chat template.'),
+    )
     # kwargs for hf processor
     mm_processor_kwargs: Optional[dict[str, Any]] = Field(
         default=None,
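
Because the new field is typed dict[str, Any] | None, any JSON object passes validation and is handed to the chat template as-is. A quick schema sketch; the model name and the keys are illustrative, and which keys a template actually honors depends on the model:

from lmdeploy.serve.openai.protocol import ChatCompletionRequest

req = ChatCompletionRequest(
    model='internlm/Intern-S1',  # illustrative
    messages=[{'role': 'user', 'content': 'hi'}],
    chat_template_kwargs={'enable_thinking': False, 'add_vision_id': True},
)
# Pydantic keeps the dict untouched; downstream code decides what each key means.
print(req.chat_template_kwargs)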

lmdeploy/serve/vl_async_engine.py

Lines changed: 5 additions & 10 deletions
@@ -55,8 +55,7 @@ async def _get_prompt_input(self,
                                 sequence_start: bool,
                                 adapter_name: str,
                                 tools: Optional[List[object]] = None,
-                                enable_thinking: Optional[bool] = None,
-                                add_vision_id: Optional[bool] = False,
+                                chat_template_kwargs: Optional[Dict] = None,
                                 mm_processor_kwargs: Optional[Dict[str, Any]] = None,
                                 **kwargs):
         """Process messages and return the required data for the inference
@@ -71,8 +70,7 @@ async def _get_prompt_input(self,
                 sequence_start,
                 adapter_name,
                 tools=tools,
-                enable_thinking=enable_thinking,
-                add_vision_id=add_vision_id,
+                chat_template_kwargs=chat_template_kwargs,
                 **kwargs)
         elif isinstance(messages, List):
             has_multimodal_input = any(
@@ -84,8 +82,7 @@ async def _get_prompt_input(self,
                 sequence_start,
                 adapter_name,
                 tools,
-                enable_thinking=enable_thinking,
-                add_vision_id=add_vision_id,
+                chat_template_kwargs=chat_template_kwargs,
                 **kwargs)
         else:
             raise RuntimeError(f'unsupported messages {messages}')
@@ -105,8 +102,7 @@ async def _get_prompt_input(self,
                 self.tokenizer,
                 sequence_start,
                 tools=tools,
-                enable_thinking=enable_thinking,
-                add_vision_id=add_vision_id)
+                chat_template_kwargs=chat_template_kwargs)
         elif self.backend == 'pytorch':
             # for pt engine, this module only conduct the image preprocessing
             # It leaves the vision embedding to the pt engine
@@ -115,8 +111,7 @@ async def _get_prompt_input(self,
                 self.tokenizer,
                 sequence_start,
                 tools=tools,
-                enable_thinking=enable_thinking,
-                add_vision_id=add_vision_id)
+                chat_template_kwargs=chat_template_kwargs)
         return results

     @classmethod

lmdeploy/vl/engine.py

Lines changed: 4 additions & 8 deletions
@@ -80,8 +80,7 @@ async def wrap_for_pytorch(
         tokenizer,
         sequence_start,
         tools: Optional[List[object]] = None,
-        enable_thinking: Optional[bool] = None,
-        add_vision_id: Optional[bool] = False,
+        chat_template_kwargs: Optional[Dict] = None,
     ) -> List[Dict]:
         """
         Args:
@@ -106,8 +105,7 @@ async def wrap_for_pytorch(
                 tokenizer,
                 sequence_start,
                 tools=tools,
-                enable_thinking=enable_thinking,
-                add_vision_id=add_vision_id)
+                chat_template_kwargs=chat_template_kwargs)
         else:
             result = self.model.to_pytorch_with_input_ids(messages)
         # clear data
@@ -123,8 +121,7 @@ async def wrap_for_turbomind(
         tokenizer,
         sequence_start,
         tools: Optional[List[object]] = None,
-        enable_thinking: Optional[bool] = None,
-        add_vision_id: Optional[bool] = False,
+        chat_template_kwargs: Optional[Dict] = None,
     ) -> Dict:
         """
         Args:
@@ -145,8 +142,7 @@ async def wrap_for_turbomind(
                 tokenizer,
                 sequence_start,
                 tools=tools,
-                enable_thinking=enable_thinking,
-                add_vision_id=add_vision_id)
+                chat_template_kwargs=chat_template_kwargs)
         # clear data
         for i, message in enumerate(messages):
             if isinstance(message['content'], List):

lmdeploy/vl/model/base.py

Lines changed: 6 additions & 2 deletions
@@ -132,7 +132,7 @@ def forward(self, messages: List[Dict], max_batch_size: int = 1) -> List[Dict]:
         if self.backend == 'turbomind':
             raise NotImplementedError()

-    def to_pytorch(self, messages, chat_template, tokenizer, sequence_start, **kwargs):
+    def to_pytorch(self, messages, chat_template, tokenizer, sequence_start, chat_template_kwargs=None, **kwargs):
         """Pack the preprocessing results in a format compatible with what is
         required by pytorch engine. ONLY implement it when the backend is
         pytorch engine.
@@ -142,11 +142,13 @@ def to_pytorch(self, messages, chat_template, tokenizer, sequence_start, **kwarg
             chat_template: the chat template defined in `lmdeploy/model.py`
             tokenzer: the tokenizer model
             sequence_start: starting flag of a sequence
+            chat_template_kwargs: additional arguments for chat template
+                processing, such as `add_vision_id` and `enable_thinking`
         """
         if self.backend == 'pytorch':
             raise NotImplementedError()

-    def to_turbomind(self, messages, chat_template, tokenizer, sequence_start, **kwargs):
+    def to_turbomind(self, messages, chat_template, tokenizer, sequence_start, chat_template_kwargs=None, **kwargs):
         """Pack the forwarding results in a format compatible with what is
         required by turbomind engine. ONLY implement it when the backend is
         turbomind engine.
@@ -156,6 +158,8 @@ def to_turbomind(self, messages, chat_template, tokenizer, sequence_start, **kwa
             chat_template: the chat template defined in `lmdeploy/model.py`
             tokenzer: the tokenizer model
             sequence_start: starting flag of a sequence
+            chat_template_kwargs: additional arguments for chat template
+                processing, such as `add_vision_id` and `enable_thinking`
         """
         if self.backend == 'turbomind':
             raise NotImplementedError()
