
Commit fbd0c03

Enable optimized Maira2 and upgrade transformers to v4.46.2 (#3376)
1 parent b376a2c commit fbd0c03

File tree

29 files changed: +1375 −23 lines changed

docs/tutorials/features/fast_bert.md

Lines changed: 1 addition & 1 deletion

@@ -9,7 +9,7 @@ Currently `ipex.fast_bert` API is only well optimized for training. For inferenc
 
 ### Prerequisite
 
-- Transformers 4.6.0 ~ 4.45.0
+- Transformers 4.6.0 ~ 4.46.2
 
 ### Usage Example
 
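For reference, a minimal runtime check of the documented range is sketched below. The 4.6.0 ~ 4.46.2 bounds come from the prerequisite above; the use of the `packaging` package is an assumption (it is commonly installed alongside pip).

```python
# Minimal sketch: warn if the installed transformers version is outside the
# range documented for ipex.fast_bert (4.6.0 ~ 4.46.2, per this commit).
# Assumes the `packaging` package is available.
from packaging import version
import transformers

MIN_SUPPORTED = version.parse("4.6.0")
MAX_SUPPORTED = version.parse("4.46.2")
installed = version.parse(transformers.__version__)

if not (MIN_SUPPORTED <= installed <= MAX_SUPPORTED):
    print(
        f"Warning: transformers {installed} is outside the range "
        f"verified for ipex.fast_bert ({MIN_SUPPORTED} ~ {MAX_SUPPORTED})."
    )
```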
examples/cpu/features/fast_bert/README.md

Lines changed: 1 addition & 1 deletion

@@ -5,7 +5,7 @@
 Currently `ipex.fast_bert` API is only well optimized for training. For inference, it ensures functionality, while to get peak perf, please use `ipex.optimize` API + torchscript.
 
 # Prerequisite:
-Transformers 4.6.0 ~ 4.45.0
+Transformers 4.6.0 ~ 4.46.2
 
 # Usage Example:
 Training:

examples/cpu/llm/fine-tuning/requirements.txt

Lines changed: 1 addition & 1 deletion

@@ -6,6 +6,6 @@ black[jupyter]
 datasets
 fire
 peft
-transformers==4.45.0
+transformers==4.46.2
 gradio
 sentencepiece

examples/cpu/llm/inference/README.md

Lines changed: 1 addition & 0 deletions

@@ -44,6 +44,7 @@
 |Phi| microsoft/Phi-3-medium-4k-instruct | 🟩 | 🟩 | 🟩 | 🟩 | 🟩 |
 |Phi| microsoft/Phi-3-medium-128k-instruct | 🟩 | 🟩 | 🟩 | 🟩 | 🟩 |
 |Whisper| openai/whisper-large-v2 | 🟩 | 🟩 | 🟩 | 🟩 | |
+|Maira| microsoft/maira-2 | 🟩 | 🟩 | | 🟩 | |
 
 ## 1.2 Verified for distributed inference mode via DeepSpeed

examples/cpu/llm/inference/distributed/run_generation_with_deepspeed.py

Lines changed: 45 additions & 2 deletions

@@ -298,6 +298,8 @@ def get_checkpoint_files(model_name_or_path):
 model_type = next((x for x in MODEL_CLASSES.keys() if x in model_name.lower()), "auto")
 if model_type == "llama" and args.vision_text_model:
     model_type = "mllama"
+if model_type == "maira-2":
+    model_type = "maira2"
 model_class = MODEL_CLASSES[model_type]
 tokenizer = model_class[1].from_pretrained(model_name, trust_remote_code=True)

@@ -350,6 +352,8 @@ def get_checkpoint_files(model_name_or_path):
 
 if not hasattr(config, "lm_head_generation"):
     config.lm_head_generation = True
+if model_type == "maira2" and not hasattr(config.text_config, "lm_head_generation"):
+    config.text_config.lm_head_generation = True
 num_beams = 1 if args.greedy else 4
 if model_type in ["git", "llava"]:
     config.batch_size = int(args.batch_size) * num_beams

@@ -389,7 +393,13 @@ def get_checkpoint_files(model_name_or_path):
     model = model_class[0].from_pretrained(
         model_name,
         config=config,
-        low_cpu_mem_usage=True,
+        low_cpu_mem_usage=True if model_type != "maira2" else False,
+        torch_dtype=load_dtype,
+        trust_remote_code=True,
+    )
+elif model_type == "maira2":
+    model = model_class[0].from_pretrained(
+        model_name,
         torch_dtype=load_dtype,
         trust_remote_code=True,
     )

@@ -653,6 +663,22 @@ def load_image(image_file):
     input_size = inputs["input_ids"].size(dim=1)
     print("---- Prompt size:", input_size)
     inputs = [prompt] * args.batch_size
+elif model_type == "maira2":
+    from PIL import Image
+    import requests
+
+    def download_and_open(url: str) -> Image.Image:
+        response = requests.get(url, headers={"User-Agent": "MAIRA-2"}, stream=True)
+        return Image.open(response.raw)
+
+    prompt = args.prompt
+    sample = download_and_open(args.image_url)
+    process_input_func = (
+        tokenizer.process_reporting_input
+        if hasattr(tokenizer, "process_reporting_input")
+        else tokenizer.format_and_preprocess_reporting_input
+    )
+    inputs = [prompt] * args.batch_size
 else:
     # input tokens
     input_sentences = []

@@ -719,6 +745,19 @@ def generate():
         raw_image = [raw_image] * args.batch_size
         input_tokens = tokenizer(raw_image, prompt, return_tensors="pt")
         input_ids = input_tokens["input_ids"]
+    elif model_type == "maira2":
+        input_tokens = process_input_func(
+            current_frontal=sample,
+            current_lateral=None,
+            prior_frontal=None,
+            indication=None,
+            technique=None,
+            comparison=None,
+            prior_report=None,
+            return_tensors="pt",
+            get_grounding=False,
+        )
+        input_ids = input_tokens["input_ids"]
     else:
         input_tokens = tokenizer.batch_encode_plus(
             inputs, return_token_type_ids=False, return_tensors="pt"

@@ -743,7 +782,11 @@ def generate():
         for i, o in zip(input_tokens_lengths, output_tokens_lengths)
     ]
     gen_text = tokenizer.batch_decode(
-        gen_ids[:, input_ids.shape[1] :] if model_type == "llava" else gen_ids,
+        (
+            gen_ids[:, input_ids.shape[1] :]
+            if model_type in ["llava", "maira2"]
+            else gen_ids
+        ),
         skip_special_tokens=True,
     )
 
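The MAIRA-2 input-preparation flow used by the runner above can also be exercised standalone. The sketch below mirrors the diff: it assumes `microsoft/maira-2` resolves through the `AutoProcessor`/`AutoModelForCausalLM` classes with `trust_remote_code=True`, and the image URL is a placeholder to replace with a real frontal chest X-ray; the reporting helper names and keyword arguments are the ones shown in the diff.

```python
# Standalone sketch of the MAIRA-2 flow mirrored by the runner above.
# Assumptions: the Auto* classes load microsoft/maira-2 with
# trust_remote_code=True, and the URL below is only a placeholder.
import requests
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor

model_id = "microsoft/maira-2"
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)
model.eval()


def download_and_open(url: str) -> Image.Image:
    # Same helper as in the diff: stream the image and open it with PIL.
    response = requests.get(url, headers={"User-Agent": "MAIRA-2"}, stream=True)
    return Image.open(response.raw)


# Placeholder URL; point this at a real frontal chest X-ray.
frontal = download_and_open("https://example.com/frontal_cxr.png")

# Prefer the newer helper name when present, exactly as the runner does.
process_input_func = (
    processor.process_reporting_input
    if hasattr(processor, "process_reporting_input")
    else processor.format_and_preprocess_reporting_input
)

inputs = process_input_func(
    current_frontal=frontal,
    current_lateral=None,
    prior_frontal=None,
    indication=None,
    technique=None,
    comparison=None,
    prior_report=None,
    return_tensors="pt",
    get_grounding=False,
)

with torch.no_grad():
    output = model.generate(**inputs, max_new_tokens=150)

# Decode only the newly generated tokens, as the runner does for maira2.
report = processor.batch_decode(
    output[:, inputs["input_ids"].shape[1]:], skip_special_tokens=True
)
print(report[0])
```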
Lines changed: 1 addition & 1 deletion

@@ -1,2 +1,2 @@
-transformers==4.45.0
+transformers==4.46.2
 neural-compressor==2.4.1

examples/cpu/llm/inference/run.py

Lines changed: 1 addition & 0 deletions

@@ -579,6 +579,7 @@ def main(args_in: Optional[List[str]] = None) -> None:
         "phi-3": ("/phi-3_local_shard"),
         "phi": ("/phi_local_shard"),
         "whisper": ("/whisper_local_shard"),
+        "maira": ("/maira2_local_shard"),
     }
     model_type = next(
         (

examples/cpu/llm/inference/single_instance/run_generation.py

Lines changed: 53 additions & 2 deletions

@@ -129,6 +129,8 @@
 )
 if model_type == "llama" and args.vision_text_model:
     model_type = "mllama"
+if model_type == "maira-2":
+    model_type = "maira2"
 model_class = MODEL_CLASSES[model_type]
 if args.config_file is None:
     if model_type == "chatglm":

@@ -161,13 +163,15 @@
 
 if not hasattr(config, "lm_head_generation"):
     config.lm_head_generation = True
+if model_type == "maira2" and not hasattr(config.text_config, "lm_head_generation"):
+    config.text_config.lm_head_generation = True
 
 if model_type != "llava":
     model = model_class[0].from_pretrained(
         args.model_id,
         torch_dtype=amp_dtype,
         config=config,
-        low_cpu_mem_usage=True,
+        low_cpu_mem_usage=True if model_type != "maira2" else False,
         trust_remote_code=True,
     )
     tokenizer = model_class[1].from_pretrained(args.model_id, trust_remote_code=True)

@@ -228,6 +232,14 @@ def load_image(image_file):
             raw_image = Image.open(image_file)
         return raw_image
 
+elif re.search("maira2", model.config.architectures[0], re.IGNORECASE):
+    from PIL import Image
+    import requests
+
+    def download_and_open(url: str) -> Image.Image:
+        response = requests.get(url, headers={"User-Agent": "MAIRA-2"}, stream=True)
+        return Image.open(response.raw)
+
 
 if re.search("llava", model.config.architectures[0], re.IGNORECASE):
     model_name = get_model_name_from_path(args.model_id)

@@ -305,6 +317,14 @@ def trace_handler(prof):
 elif model_type == "whisper":
     prompt = sample[0]
     generate_kwargs.pop("min_new_tokens", None)
+elif model_type == "maira2":
+    prompt = args.prompt
+    sample = download_and_open(args.image_url)
+    process_input_func = (
+        tokenizer.process_reporting_input
+        if hasattr(tokenizer, "process_reporting_input")
+        else tokenizer.format_and_preprocess_reporting_input
+    )
 else:
     # input prompt
     current_path = pathlib.Path(__file__).parent.resolve()

@@ -375,12 +395,30 @@ def trace_handler(prof):
         inputs = tokenizer(raw_image, prompt, return_tensors="pt")
         input_ids = inputs["input_ids"]
         output = model.generate(**inputs, **generate_kwargs)
+    elif model_type == "maira2":
+        processed_inputs = process_input_func(
+            current_frontal=sample,
+            current_lateral=None,
+            prior_frontal=None,
+            indication=None,
+            technique=None,
+            comparison=None,
+            prior_report=None,
+            return_tensors="pt",
+            get_grounding=False,
+        )
+        input_ids = processed_inputs["input_ids"]
+        output = model.generate(**processed_inputs, **generate_kwargs)
     else:
         input_ids = tokenizer(prompt, return_tensors="pt").input_ids
         output = model.generate(input_ids, **generate_kwargs)
     gen_ids = output[0] if args.token_latency else output
     gen_text = tokenizer.batch_decode(
-        gen_ids[:, input_ids.shape[1] :] if model_type == "llava" else gen_ids,
+        (
+            gen_ids[:, input_ids.shape[1] :]
+            if model_type in ["llava", "maira2"]
+            else gen_ids
+        ),
         skip_special_tokens=True,
     )
     toc = time.time()

@@ -441,6 +479,19 @@ def trace_handler(prof):
         raw_image = [load_image(args.image_url)] * args.batch_size
         inputs = tokenizer(raw_image, prompt, return_tensors="pt")
         output = model.generate(**inputs, **generate_kwargs)
+    elif model_type == "maira2":
+        processed_inputs = process_input_func(
+            current_frontal=sample,
+            current_lateral=None,
+            prior_frontal=None,
+            indication=None,
+            technique=None,
+            comparison=None,
+            prior_report=None,
+            return_tensors="pt",
+            get_grounding=False,
+        )
+        output = model.generate(**processed_inputs, **generate_kwargs)
     else:
         input_ids = tokenizer(prompt, return_tensors="pt").input_ids
         output = model.generate(input_ids, **generate_kwargs)
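
The decode change above matches the DeepSpeed runner: for llava and maira2, `generate` returns the prompt tokens followed by the new tokens, so the prompt prefix is sliced off before decoding. A toy illustration of that slicing (token IDs are made up):

```python
# Toy illustration of the slicing applied above for llava/maira2:
# generate() output = [prompt tokens | new tokens], so decoding the whole
# sequence would echo the prompt back into gen_text.
import torch

input_ids = torch.tensor([[101, 7592, 2088, 102]])            # 4 prompt tokens (made-up IDs)
gen_ids = torch.tensor([[101, 7592, 2088, 102, 11, 12, 13]])  # prompt + 3 new tokens

new_only = gen_ids[:, input_ids.shape[1]:]  # drop the prompt prefix
assert new_only.tolist() == [[11, 12, 13]]
```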
