
Commit 4679764

Fix compatibility issues with Transformers v4.46.2 (#3389)
1 parent c1851ee

2 files changed: 90 additions & 0 deletions

intel_extension_for_pytorch/transformers/models/reference/models.py

Lines changed: 75 additions & 0 deletions
@@ -5768,6 +5768,81 @@ def prepare_inputs_for_generation_chatglm(
     }
 
 
+def prepare_inputs_for_generation_opt_mpt(
+    self,
+    input_ids,
+    past_key_values=None,
+    attention_mask=None,
+    inputs_embeds=None,
+    **kwargs,
+):
+    if past_key_values is not None:
+        past_length = past_key_values[0][0].shape[2]
+
+        # Some generation methods already pass only the last input ID
+        if input_ids.shape[1] > past_length:
+            remove_prefix_length = past_length
+        else:
+            # Default to old behavior: keep only final ID
+            remove_prefix_length = input_ids.shape[1] - 1
+
+        input_ids = input_ids[:, remove_prefix_length:]
+
+    # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+    if inputs_embeds is not None and past_key_values is None:
+        model_inputs = {"inputs_embeds": inputs_embeds}
+    else:
+        model_inputs = {"input_ids": input_ids}
+
+    model_inputs.update(
+        {
+            "past_key_values": past_key_values,
+            "use_cache": kwargs.get("use_cache"),
+            "attention_mask": attention_mask,
+        }
+    )
+    return model_inputs
+
+
+def prepare_inputs_for_generation_t5(
+    self,
+    input_ids,
+    past_key_values=None,
+    attention_mask=None,
+    head_mask=None,
+    decoder_head_mask=None,
+    decoder_attention_mask=None,
+    cross_attn_head_mask=None,
+    use_cache=None,
+    encoder_outputs=None,
+    **kwargs,
+):
+    # cut decoder_input_ids if past_key_values is used
+    if past_key_values is not None:
+        past_length = past_key_values[0][0].shape[2]
+
+        # Some generation methods already pass only the last input ID
+        if input_ids.shape[1] > past_length:
+            remove_prefix_length = past_length
+        else:
+            # Default to old behavior: keep only final ID
+            remove_prefix_length = input_ids.shape[1] - 1
+
+        input_ids = input_ids[:, remove_prefix_length:]
+
+    return {
+        "decoder_input_ids": input_ids,
+        "past_key_values": past_key_values,
+        "encoder_outputs": encoder_outputs,
+        "attention_mask": attention_mask,
+        "head_mask": head_mask,
+        "decoder_head_mask": decoder_head_mask,
+        "decoder_attention_mask": decoder_attention_mask,
+        "cross_attn_head_mask": cross_attn_head_mask,
+        "use_cache": use_cache,
+    }
+
+
 def prepare_inputs_for_generation_llama(
     self,
     input_ids,

intel_extension_for_pytorch/transformers/optimize.py

Lines changed: 15 additions & 0 deletions
@@ -219,6 +219,8 @@ def model_convert_reference(_model):
         prepare_inputs_for_generation_gptneox,
         prepare_inputs_for_generation_git,
         prepare_inputs_for_generation_llava,
+        prepare_inputs_for_generation_opt_mpt,
+        prepare_inputs_for_generation_t5,
         detect_language,
         _postprocess_outputs_whisper,
         _prepare_encoder_decoder_kwargs_for_generation,
@@ -362,6 +364,11 @@ def model_convert_reference(_model):
             "forward",
             OPTForCausalLM_forward,
         )
+        convert_function(
+            _model,
+            "prepare_inputs_for_generation",
+            prepare_inputs_for_generation_opt_mpt,
+        )
     elif (
         hasattr(_model, "__class__")
         and _model.__class__
@@ -439,6 +446,9 @@ def model_convert_reference(_model):
             "forward",
             T5DenseGatedActDense_forward,
         )
+        convert_function(
+            _model, "prepare_inputs_for_generation", prepare_inputs_for_generation_t5
+        )
 
     # checking if model has been wrapped by deepspeed (distributed or not)
     try:
@@ -742,6 +752,11 @@ def model_convert_reference(_model):
                 _model.config,
                 distributed=distributed,
             )
+            convert_function(
+                _model,
+                "prepare_inputs_for_generation",
+                prepare_inputs_for_generation_opt_mpt,
+            )
         elif _model.config.architectures[0] == "MixtralForCausalLM":
             convert_function(_model, "forward", MixtralForCausalLM_forward)
             convert_function(_model.model, "forward", MixtralModel_forward)
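Each registration goes through convert_function, which swaps a method on the model instance for one of the new helpers. Its implementation is not part of this diff; what follows is a plausible sketch, assuming it simply rebinds the replacement as a bound method, and the actual IPEX helper may differ.

# Hedged sketch of a convert_function-style helper; the real IPEX
# implementation is not shown in this diff and may differ.
import types

def convert_function(obj, func_name, new_function):
    # Rebind new_function as a bound method on obj, shadowing the class
    # attribute of the same name for this instance only.
    setattr(obj, func_name, types.MethodType(new_function, obj))

# Usage mirroring the diff:
# convert_function(_model, "prepare_inputs_for_generation",
#                  prepare_inputs_for_generation_opt_mpt)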
