@@ -5768,6 +5768,81 @@ def prepare_inputs_for_generation_chatglm(
57685768 }
57695769
57705770
def prepare_inputs_for_generation_opt_mpt(
    self,
    input_ids,
    past_key_values=None,
    attention_mask=None,
    inputs_embeds=None,
    **kwargs,
):
    """Build the model-input dict for one OPT/MPT decoding step.

    Trims ``input_ids`` down to the tokens not yet covered by the KV cache,
    then packages them (or ``inputs_embeds`` on the very first step) together
    with the cache, attention mask, and ``use_cache`` flag.

    Args:
        input_ids: (batch, seq) token ids accumulated so far.
        past_key_values: layer-wise KV cache; entry ``[0][0]`` has the cached
            sequence length at ``shape[2]``. ``None`` on the first step.
        attention_mask: forwarded unchanged to the model.
        inputs_embeds: optional precomputed embeddings, honoured only while
            there is no cache yet (first generation step).
        **kwargs: may carry ``use_cache``, forwarded via ``kwargs.get``.

    Returns:
        dict with either ``input_ids`` or ``inputs_embeds``, plus
        ``past_key_values``, ``use_cache`` and ``attention_mask``.
    """
    if past_key_values is not None:
        cached_len = past_key_values[0][0].shape[2]
        # Drop the prefix already held in the cache. If the caller has
        # already trimmed (or the cache is somehow longer), fall back to
        # keeping only the final token — the legacy behavior.
        cut = cached_len if input_ids.shape[1] > cached_len else input_ids.shape[1] - 1
        input_ids = input_ids[:, cut:]

    # Embeddings are only usable before any cache exists; afterwards the
    # freshly generated ids must be fed instead.
    use_embeds = inputs_embeds is not None and past_key_values is None
    model_inputs = {"inputs_embeds": inputs_embeds} if use_embeds else {"input_ids": input_ids}

    model_inputs["past_key_values"] = past_key_values
    model_inputs["use_cache"] = kwargs.get("use_cache")
    model_inputs["attention_mask"] = attention_mask
    return model_inputs
5805+
5806+
def prepare_inputs_for_generation_t5(
    self,
    input_ids,
    past_key_values=None,
    attention_mask=None,
    head_mask=None,
    decoder_head_mask=None,
    decoder_attention_mask=None,
    cross_attn_head_mask=None,
    use_cache=None,
    encoder_outputs=None,
    **kwargs,
):
    """Build the decoder-step input dict for a T5-style encoder-decoder model.

    When a KV cache is present, ``input_ids`` is cut down to the tokens the
    cache does not yet cover; all masks and the encoder outputs are passed
    through unchanged.

    Args:
        input_ids: (batch, seq) decoder token ids accumulated so far.
        past_key_values: layer-wise KV cache; entry ``[0][0]`` carries the
            cached sequence length at ``shape[2]``. ``None`` on step one.
        attention_mask / head_mask / decoder_head_mask /
        decoder_attention_mask / cross_attn_head_mask: forwarded as-is.
        use_cache: forwarded as-is.
        encoder_outputs: forwarded as-is.
        **kwargs: accepted and ignored.

    Returns:
        dict keyed for a T5 forward call, with trimmed ``decoder_input_ids``.
    """
    if past_key_values is not None:
        cached_len = past_key_values[0][0].shape[2]
        # Keep only the not-yet-cached suffix; if the ids were already
        # trimmed by the caller, keep just the final id (legacy behavior).
        keep_from = cached_len if input_ids.shape[1] > cached_len else input_ids.shape[1] - 1
        input_ids = input_ids[:, keep_from:]

    return {
        "decoder_input_ids": input_ids,
        "past_key_values": past_key_values,
        "encoder_outputs": encoder_outputs,
        "attention_mask": attention_mask,
        "head_mask": head_mask,
        "decoder_head_mask": decoder_head_mask,
        "decoder_attention_mask": decoder_attention_mask,
        "cross_attn_head_mask": cross_attn_head_mask,
        "use_cache": use_cache,
    }
5844+
5845+
57715846def prepare_inputs_for_generation_llama (
57725847 self ,
57735848 input_ids ,
0 commit comments