44import torch
55from transformers import AutoProcessor
66
7+ from lmdeploy .utils import get_logger
78from lmdeploy .vl .model .base import VISION_MODELS , VisionModel
89
10+ logger = get_logger ('lmdeploy' )
11+
912
1013def check_transformers ():
1114 try :
@@ -29,23 +32,59 @@ def build_preprocessor(self):
2932 self .image_token_id = tokenizer .encode (self .image_token )[- 1 ]
3033 self .mm_processor_kwargs = None
3134
35+ def get_processor_args (self , mm_processor_kwargs : Optional [Dict [str , Any ]] = None ):
36+ min_pixels = self .processor .image_processor .size ['shortest_edge' ]
37+ max_pixels = self .processor .image_processor .size ['longest_edge' ]
38+
39+ if mm_processor_kwargs is None :
40+ return min_pixels , max_pixels
41+
42+ input_min_pixels = mm_processor_kwargs .get ('min_pixels' , None )
43+ input_max_pixels = mm_processor_kwargs .get ('max_pixels' , None )
44+
45+ # boundary check for min_pixels and max_pixels
46+ if input_min_pixels is None :
47+ if input_max_pixels is not None :
48+ # only max_pixels is given in the input
49+ if input_max_pixels < min_pixels :
50+ logger .warning (
51+ f'input max_pixels { input_max_pixels } < default min_pixels { min_pixels } , fall back to default.' )
52+ return min_pixels , max_pixels
53+ max_pixels = input_max_pixels
54+ else :
55+ if input_max_pixels is None :
56+ # only min_pixels is given in the input
57+ if input_min_pixels > max_pixels :
58+ logger .warning (
59+ f'input min_pixels { input_min_pixels } > default max_pixels { max_pixels } , fall back to default.' )
60+ return min_pixels , max_pixels
61+ else :
62+ if input_min_pixels > input_max_pixels :
63+ logger .warning (
64+ f'input min_pixels { input_min_pixels } > max_pixels { input_max_pixels } , fall back to default.' )
65+ return min_pixels , max_pixels
66+ max_pixels = input_max_pixels
67+ min_pixels = input_min_pixels
68+
69+ return min_pixels , max_pixels
70+
3271 def preprocess (self , messages : List [Dict ], mm_processor_kwargs : Optional [Dict [str , Any ]] = None ) -> List [Dict ]:
3372 """Refer to `super().preprocess()` for spec."""
34- if mm_processor_kwargs is None :
35- mm_processor_kwargs = {}
73+
74+ min_pixels , max_pixels = self . get_processor_args ( mm_processor_kwargs )
3675
3776 images = self .collect_images (messages )
38- optional_keys = {'resized_height' , 'resized_width' , 'min_pixels' , 'max_pixels' }
3977 outputs = []
4078 for image , params in images :
4179 image = image .convert ('RGB' )
4280
43- item = dict (type = 'image' , image = image )
44- item .update ({key : params [key ] for key in params .keys () if key in optional_keys })
4581 result = self .processor .image_processor (images = image ,
4682 videos = None ,
47- return_tensors = 'pt' ,
48- ** mm_processor_kwargs )
83+ size = {
84+ 'shortest_edge' : min_pixels ,
85+ 'longest_edge' : max_pixels
86+ },
87+ return_tensors = 'pt' )
4988 merge_length = self .processor .image_processor .merge_size ** 2
5089 image_tokens = result ['image_grid_thw' ].prod (dim = 1 ) // merge_length
5190 result .update (dict (image_size = image .size , image_tokens = image_tokens , image_token_id = self .image_token_id ))
0 commit comments