@@ -129,6 +129,8 @@
 )
 if model_type == "llama" and args.vision_text_model:
     model_type = "mllama"
+if model_type == "maira-2":
+    model_type = "maira2"
 model_class = MODEL_CLASSES[model_type]
 if args.config_file is None:
     if model_type == "chatglm":
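
Note: the remapping above normalizes the hyphenated spelling (presumably how the model id or its config reports the type) to the unhyphenated key used by the script's MODEL_CLASSES table. A minimal sketch of the dispatch this enables; the table entry below is an assumption for illustration, not the script's real table:

    from transformers import AutoModelForCausalLM, AutoProcessor

    # Hypothetical shape of the dispatch table; the real MODEL_CLASSES maps
    # model_type -> (model class, tokenizer/processor class).
    MODEL_CLASSES = {"maira2": (AutoModelForCausalLM, AutoProcessor)}

    model_type = "maira-2"
    if model_type == "maira-2":
        model_type = "maira2"  # normalize the hyphenated name to the table key
    model_class = MODEL_CLASSES[model_type]
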
@@ -161,13 +163,15 @@
 
 if not hasattr(config, "lm_head_generation"):
     config.lm_head_generation = True
+if model_type == "maira2" and not hasattr(config.text_config, "lm_head_generation"):
+    config.text_config.lm_head_generation = True
 
 if model_type != "llava":
     model = model_class[0].from_pretrained(
         args.model_id,
         torch_dtype=amp_dtype,
         config=config,
-        low_cpu_mem_usage=True,
+        low_cpu_mem_usage=True if model_type != "maira2" else False,
         trust_remote_code=True,
     )
     tokenizer = model_class[1].from_pretrained(args.model_id, trust_remote_code=True)
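
Note: MAIRA-2 is a composite checkpoint (vision encoder plus language model) loaded via trust_remote_code, so the patch mirrors lm_head_generation onto the nested text_config and disables low_cpu_mem_usage, presumably because meta-device initialization is not reliable for this remote-code model. A standalone sketch of the same loading path; the checkpoint id and dtype are assumptions:

    import torch
    from transformers import AutoModelForCausalLM, AutoProcessor

    model_id = "microsoft/maira-2"  # assumed checkpoint id
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=False,  # skip meta-device init for this remote-code model
        trust_remote_code=True,
    )
    processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
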
@@ -228,6 +232,14 @@ def load_image(image_file):
         raw_image = Image.open(image_file)
         return raw_image
 
+elif re.search("maira2", model.config.architectures[0], re.IGNORECASE):
+    from PIL import Image
+    import requests
+
+    def download_and_open(url: str) -> Image.Image:
+        response = requests.get(url, headers={"User-Agent": "MAIRA-2"}, stream=True)
+        return Image.open(response.raw)
+
 
 if re.search("llava", model.config.architectures[0], re.IGNORECASE):
     model_name = get_model_name_from_path(args.model_id)
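
The helper streams the image rather than buffering the whole response: with stream=True, response.raw is a file-like object that PIL can decode directly, and the explicit User-Agent avoids the 403 responses some image hosts return to default client agents. A usage sketch for the patch's helper; the URL is a placeholder, not from the patch:

    # Placeholder URL for illustration only.
    frontal = download_and_open("https://example.com/chest_xray_frontal.png")
    print(frontal.size, frontal.mode)
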
@@ -305,6 +317,14 @@ def trace_handler(prof):
 elif model_type == "whisper":
     prompt = sample[0]
     generate_kwargs.pop("min_new_tokens", None)
+elif model_type == "maira2":
+    prompt = args.prompt
+    sample = download_and_open(args.image_url)
+    process_input_func = (
+        tokenizer.process_reporting_input
+        if hasattr(tokenizer, "process_reporting_input")
+        else tokenizer.format_and_preprocess_reporting_input
+    )
 else:
     # input prompt
     current_path = pathlib.Path(__file__).parent.resolve()
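
The hasattr probe keeps the script working across revisions of the MAIRA-2 remote-code processor, which has evidently exposed its reporting preprocessor under both names. The same shim written generically; the function name resolve_processing_fn is illustrative:

    def resolve_processing_fn(processor):
        # Prefer the newer method name, fall back to the older one;
        # these are the two names the patch probes for.
        for name in ("process_reporting_input", "format_and_preprocess_reporting_input"):
            fn = getattr(processor, name, None)
            if callable(fn):
                return fn
        raise AttributeError("no reporting-input preprocessor found on this processor")
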
@@ -375,12 +395,30 @@ def trace_handler(prof):
         inputs = tokenizer(raw_image, prompt, return_tensors="pt")
         input_ids = inputs["input_ids"]
         output = model.generate(**inputs, **generate_kwargs)
+    elif model_type == "maira2":
+        processed_inputs = process_input_func(
+            current_frontal=sample,
+            current_lateral=None,
+            prior_frontal=None,
+            indication=None,
+            technique=None,
+            comparison=None,
+            prior_report=None,
+            return_tensors="pt",
+            get_grounding=False,
+        )
+        input_ids = processed_inputs["input_ids"]
+        output = model.generate(**processed_inputs, **generate_kwargs)
     else:
         input_ids = tokenizer(prompt, return_tensors="pt").input_ids
         output = model.generate(input_ids, **generate_kwargs)
     gen_ids = output[0] if args.token_latency else output
     gen_text = tokenizer.batch_decode(
-        gen_ids[:, input_ids.shape[1] :] if model_type == "llava" else gen_ids,
+        (
+            gen_ids[:, input_ids.shape[1] :]
+            if model_type in ["llava", "maira2"]
+            else gen_ids
+        ),
         skip_special_tokens=True,
     )
     toc = time.time()
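
The widened decode slice reflects that maira2, like llava, echoes the prompt: generate() returns the prompt tokens followed by the new tokens, so the generated report is everything past input_ids.shape[1]. A toy illustration of the slice:

    import torch

    input_ids = torch.tensor([[1, 2, 3, 4]])         # 4 prompt tokens
    gen_ids = torch.tensor([[1, 2, 3, 4, 9, 8, 7]])  # prompt echoed + 3 new tokens
    new_tokens = gen_ids[:, input_ids.shape[1]:]
    assert new_tokens.tolist() == [[9, 8, 7]]
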
@@ -441,6 +479,19 @@ def trace_handler(prof):
         raw_image = [load_image(args.image_url)] * args.batch_size
         inputs = tokenizer(raw_image, prompt, return_tensors="pt")
         output = model.generate(**inputs, **generate_kwargs)
+    elif model_type == "maira2":
+        processed_inputs = process_input_func(
+            current_frontal=sample,
+            current_lateral=None,
+            prior_frontal=None,
+            indication=None,
+            technique=None,
+            comparison=None,
+            prior_report=None,
+            return_tensors="pt",
+            get_grounding=False,
+        )
+        output = model.generate(**processed_inputs, **generate_kwargs)
     else:
         input_ids = tokenizer(prompt, return_tensors="pt").input_ids
         output = model.generate(input_ids, **generate_kwargs)
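
Putting the pieces together, an end-to-end sketch of the MAIRA-2 path this patch wires up; the checkpoint id, image URL, and generation settings are illustrative assumptions, and download_and_open is the helper defined above:

    import torch
    from transformers import AutoModelForCausalLM, AutoProcessor

    model_id = "microsoft/maira-2"  # assumed checkpoint id
    processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=False,
        trust_remote_code=True,
    ).eval()

    frontal = download_and_open("https://example.com/frontal.png")  # placeholder URL
    process_input_func = (
        processor.process_reporting_input
        if hasattr(processor, "process_reporting_input")
        else processor.format_and_preprocess_reporting_input
    )
    inputs = process_input_func(
        current_frontal=frontal,
        current_lateral=None,
        prior_frontal=None,
        indication=None,
        technique=None,
        comparison=None,
        prior_report=None,
        return_tensors="pt",
        get_grounding=False,
    )
    with torch.no_grad():
        output = model.generate(**inputs, max_new_tokens=128)
    report = processor.batch_decode(
        output[:, inputs["input_ids"].shape[1]:], skip_special_tokens=True
    )[0]
    print(report)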