
Commit a34db18

Add mm_processor_args for Qwen3-VL (#4196)
* add mm processor args
* move args passing pos
* fix None case
* remove deprecated env
* add for /generate
* fix arg no
1 parent 1426ea4 commit a34db18

7 files changed: 38 additions, 13 deletions

lmdeploy/pytorch/envs.py

Lines changed: 0 additions & 1 deletion
@@ -125,7 +125,6 @@ def _patched_get_env(
 # dlblas
 # we don't need to read this, it would be passed to ray workers
 # If Ray is launched from outside, it may fail to access the environment variables.
-os.getenv('DEEPEP_MAX_BATCH_SIZE', None)
 os.getenv('DEEPEP_MAX_TOKENS_PER_RANK', None)
 os.getenv('DEEPEP_ENABLE_MNNVL', None)
 os.getenv('DEEPEP_MODE', 'auto')

lmdeploy/serve/async_engine.py

Lines changed: 2 additions & 0 deletions
@@ -769,6 +769,7 @@ async def generate(
             input_ids: Optional[List] = None,
             enable_thinking: Optional[bool] = None,
             add_vision_id: Optional[bool] = False,
+            mm_processor_kwargs: Optional[Dict[str, Any]] = None,
             **kwargs):
         """Generate responses.

@@ -821,6 +822,7 @@
             reasoning_effort=reasoning_effort,
             enable_thinking=enable_thinking,
             add_vision_id=add_vision_id,
+            mm_processor_kwargs=mm_processor_kwargs,
             **kwargs)
         prompt = prompt_input['prompt']
         input_ids = prompt_input['input_ids']
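
With this change, mm_processor_kwargs passed to AsyncEngine.generate() is forwarded verbatim into _get_prompt_input(). A minimal, hedged sketch of how a caller might use it follows; the engine and messages objects, the session id, and the max_pixels value are illustrative assumptions, not part of this diff:

# Hedged usage sketch: pass HF-processor kwargs through AsyncEngine.generate().
# `engine` is assumed to be an lmdeploy AsyncEngine serving a Qwen3-VL model.
import asyncio


async def demo(engine, messages):
    async for out in engine.generate(messages,
                                     session_id=1,
                                     mm_processor_kwargs={'max_pixels': 1280 * 28 * 28}):  # illustrative value
        print(out.response, end='', flush=True)

# asyncio.run(demo(engine, messages))  # needs a real engine and real messages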

lmdeploy/serve/openai/api_server.py

Lines changed: 2 additions & 3 deletions
@@ -470,7 +470,7 @@ async def chat_completions_v1(request: ChatCompletionRequest, raw_request: Request = None):
         adapter_name=adapter_name,
         enable_thinking=request.enable_thinking,
         add_vision_id=request.add_vision_id,
-    )
+        mm_processor_kwargs=request.mm_processor_kwargs)

     def create_stream_response_json(index: int,
                                     delta_message: DeltaMessage,
@@ -911,7 +911,6 @@ async def _inner_call(i, generator):

 @router.post('/generate', dependencies=[Depends(check_api_key)])
 async def generate(request: GenerateReqInput, raw_request: Request = None):
-
     if request.session_id == -1:
         VariableInterface.session_id += 1
         request.session_id = VariableInterface.session_id
@@ -965,7 +964,7 @@ async def generate(request: GenerateReqInput, raw_request: Request = None):
         sequence_start=True,
         sequence_end=True,
         do_preprocess=False,
-    )
+        mm_processor_kwargs=request.mm_processor_kwargs)

     def create_generate_response_json(res, text, output_ids, logprobs, finish_reason, routed_experts=None):
         # only output router experts in last chunk
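
Both /v1/chat/completions and /generate now hand request.mm_processor_kwargs to the engine call. A hedged client-side sketch, assuming a server listening on localhost:23333 and using min_pixels/max_pixels purely as example processor kwargs:

# Hypothetical client sketch: send mm_processor_kwargs with a chat completion request.
import requests

payload = {
    'model': 'Qwen3-VL',  # placeholder model name
    'messages': [{
        'role': 'user',
        'content': [
            {'type': 'text', 'text': 'Describe this image.'},
            {'type': 'image_url', 'image_url': {'url': 'https://example.com/cat.jpg'}},
        ],
    }],
    # extra kwargs forwarded to the HF processor; valid keys depend on the model's processor
    'mm_processor_kwargs': {'min_pixels': 256 * 28 * 28, 'max_pixels': 1280 * 28 * 28},
}
resp = requests.post('http://localhost:23333/v1/chat/completions', json=payload, timeout=300)
print(resp.json())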

lmdeploy/serve/openai/protocol.py

Lines changed: 5 additions & 0 deletions
@@ -153,6 +153,11 @@ class ChatCompletionRequest(BaseModel):
     add_vision_id: Optional[bool] = False
     return_token_ids: Optional[bool] = False
     include_stop_str_in_output: Optional[bool] = False
+    # kwargs for hf processor
+    mm_processor_kwargs: Optional[dict[str, Any]] = Field(
+        default=None,
+        description=('Additional kwargs to pass to the HF processor'),
+    )


 class FunctionCall(BaseModel):

lmdeploy/serve/vl_async_engine.py

Lines changed: 3 additions & 2 deletions
@@ -1,7 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.

 import asyncio
-from typing import Dict, List, Literal, Optional, Tuple, Union
+from typing import Any, Dict, List, Literal, Optional, Tuple, Union

 import PIL

@@ -57,6 +57,7 @@ async def _get_prompt_input(self,
                                 tools: Optional[List[object]] = None,
                                 enable_thinking: Optional[bool] = None,
                                 add_vision_id: Optional[bool] = False,
+                                mm_processor_kwargs: Optional[Dict[str, Any]] = None,
                                 **kwargs):
         """Process messages and return the required data for the inference
         engines.
@@ -91,7 +92,7 @@ async def _get_prompt_input(self,

         chat_template = self.chat_template if do_preprocess else BaseChatTemplate()
         messages = await self.async_convert_to_pil_images(messages)
-        results = await self.vl_encoder.preprocess(messages)
+        results = await self.vl_encoder.preprocess(messages, mm_processor_kwargs)
         if self.backend == 'turbomind':
             # for tm engine, this module perform vision embedding after image
             # preprocessing. It utilizes the hf model's vision embeddings

lmdeploy/vl/engine.py

Lines changed: 15 additions & 3 deletions
@@ -1,8 +1,9 @@
 # Copyright (c) OpenMMLab. All rights reserved.

 import asyncio
+import inspect
 from concurrent.futures import ThreadPoolExecutor
-from typing import Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Union

 import torch

@@ -23,6 +24,11 @@ def _raise_exception_on_finish(task: asyncio.Task) -> None:
         raise e


+def _accepts_arg(func, arg_name: str) -> bool:
+    """Check if a function accepts a specific keyword argument."""
+    return arg_name in inspect.signature(func).parameters
+
+
 class ImageEncoder:
     """Image encoder."""

@@ -41,9 +47,15 @@ def __init__(
         self.executor = ThreadPoolExecutor(max_workers=1)
         torch.cuda.empty_cache()

-    async def preprocess(self, messages: List[Dict]) -> List[Dict]:
+    async def preprocess(self,
+                         messages: List[Dict],
+                         mm_processor_kwargs: Optional[Dict[str, Any]] = None) -> List[Dict]:
         """Preprocess multimodal data in the messages."""
-        future = asyncio.get_event_loop().run_in_executor(self.executor, self.model.preprocess, messages)
+        if _accepts_arg(self.model.preprocess, 'mm_processor_kwargs'):
+            future = asyncio.get_event_loop().run_in_executor(self.executor, self.model.preprocess, messages,
+                                                              mm_processor_kwargs)
+        else:
+            future = asyncio.get_event_loop().run_in_executor(self.executor, self.model.preprocess, messages)
         future.add_done_callback(_raise_exception_on_finish)
         outputs = await future
         return outputs
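
The inspect-based check keeps older vision models working: only a preprocess() that declares an mm_processor_kwargs parameter receives the extra argument. A small standalone sketch of the same technique (the two preprocess functions below are made up for illustration):

# Standalone illustration of the signature check used by ImageEncoder.preprocess.
import inspect


def _accepts_arg(func, arg_name: str) -> bool:
    """Check if a function accepts a specific keyword argument."""
    return arg_name in inspect.signature(func).parameters


def old_preprocess(messages):  # legacy vision model: no extra kwargs
    return messages


def new_preprocess(messages, mm_processor_kwargs=None):  # updated model (e.g. Qwen3-VL)
    return messages, mm_processor_kwargs


print(_accepts_arg(old_preprocess, 'mm_processor_kwargs'))  # False -> called without kwargs
print(_accepts_arg(new_preprocess, 'mm_processor_kwargs'))  # True  -> kwargs are forwarded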

lmdeploy/vl/model/qwen3.py

Lines changed: 11 additions & 4 deletions
@@ -1,7 +1,8 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from typing import Dict, List, Optional
+from typing import Any, Dict, List, Optional

 import torch
+from transformers import AutoProcessor

 from lmdeploy.vl.model.base import VISION_MODELS, VisionModel

@@ -22,14 +23,17 @@ class Qwen3VLModel(VisionModel):

     def build_preprocessor(self):
         check_transformers()
-        from transformers import AutoProcessor
         self.processor = AutoProcessor.from_pretrained(self.model_path)
         tokenizer = self.processor.tokenizer
         self.image_token = self.processor.image_token
         self.image_token_id = tokenizer.encode(self.image_token)[-1]
+        self.mm_processor_kwargs = None

-    def preprocess(self, messages: List[Dict]) -> List[Dict]:
+    def preprocess(self, messages: List[Dict], mm_processor_kwargs: Optional[Dict[str, Any]] = None) -> List[Dict]:
         """Refer to `super().preprocess()` for spec."""
+        if mm_processor_kwargs is None:
+            mm_processor_kwargs = {}
+
         images = self.collect_images(messages)
         optional_keys = {'resized_height', 'resized_width', 'min_pixels', 'max_pixels'}
         outputs = []
@@ -38,7 +42,10 @@ def preprocess(self, messages: List[Dict]) -> List[Dict]:

             item = dict(type='image', image=image)
             item.update({key: params[key] for key in params.keys() if key in optional_keys})
-            result = self.processor.image_processor(images=image, videos=None, return_tensors='pt')
+            result = self.processor.image_processor(images=image,
+                                                    videos=None,
+                                                    return_tensors='pt',
+                                                    **mm_processor_kwargs)
             merge_length = self.processor.image_processor.merge_size**2
             image_tokens = result['image_grid_thw'].prod(dim=1) // merge_length
             result.update(dict(image_size=image.size, image_tokens=image_tokens, image_token_id=self.image_token_id))
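
Inside Qwen3VLModel.preprocess the dict is splatted straight into the HF image processor call, so any keyword that processor understands can now be tuned per request. A hedged sketch of the effect, assuming a local Qwen3-VL checkpoint whose image processor accepts the min_pixels/max_pixels resizing knobs (the same names already listed in optional_keys); the checkpoint name and image path are placeholders:

# Hypothetical sketch of what the forwarded kwargs do inside preprocess().
from PIL import Image
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained('Qwen/Qwen3-VL-8B-Instruct')  # placeholder checkpoint
image = Image.open('demo.jpg').convert('RGB')  # any local RGB image

mm_processor_kwargs = {'min_pixels': 256 * 28 * 28, 'max_pixels': 1280 * 28 * 28}
result = processor.image_processor(images=image, videos=None, return_tensors='pt',
                                    **mm_processor_kwargs)
# A smaller max_pixels shrinks the image grid, i.e. fewer vision tokens per image.
print(result['image_grid_thw'])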
