     add_start_docstrings,
 )
 from transformers.configuration_utils import PretrainedConfig
+from transformers.tokenization_utils import PreTrainedTokenizer
 from transformers.utils import is_offline_mode
 
 from executorch.extension.pybindings.portable_lib import ExecuTorchModule, _load_for_executorch
 from executorch.kernels import quantized  # noqa
 
 from ..exporters import TasksManager
 from ..exporters.executorch import main_export
-from ..exporters.executorch.utils import verify_eos_tokens_in_tokenizer
+from ..exporters.executorch.utils import verify_eos_tokens_in_pretrained_tokenizer
 from ..modeling_base import FROM_PRETRAINED_START_DOCSTRING, OptimizedModel
 from ..utils.file_utils import find_files_matching_pattern
 from .stats import Stats
@@ -525,7 +526,7 @@ def generate(
 
     def text_generation(
         self,
-        tokenizer: "PreTrainedTokenizer",
+        tokenizer: PreTrainedTokenizer,
         prompt: str,
         echo: bool = True,
         max_seq_len: Optional[int] = None,
@@ -745,7 +746,7 @@ def generate(
 
     def text_generation(
         self,
-        tokenizer: "PreTrainedTokenizer",
+        tokenizer: PreTrainedTokenizer,
         prompt: str,
         echo: bool = True,
         max_seq_len: Optional[int] = None,
@@ -772,7 +773,7 @@ def text_generation(
             raise ValueError(
                 f"The tokenizer's bos_token_id={self.tokenizer.bos_token_id} must be the same as the model's bos_token_id={self.bos_token_id}."
             )
-        if not verify_eos_tokens_in_tokenizer(self.eos_token_ids, self.tokenizer):
+        if not verify_eos_tokens_in_pretrained_tokenizer(self.eos_token_ids, self.tokenizer):
             raise ValueError(
                 f"The tokenizer's eos_token_id does not match with the model's eos_token_ids={self.eos_token_ids}."
             )
@@ -1066,7 +1067,7 @@ def generate(
 
     def transcribe(
         self,
-        tokenizer: "PreTrainedTokenizer",
+        tokenizer: PreTrainedTokenizer,
         input_features: torch.Tensor,
         echo: bool = True,
         max_seq_len: Optional[int] = None,
@@ -1197,7 +1198,7 @@ def forward(
 
     def generate(
         self,
-        tokenizer: "PretrainedTokenizer",
+        tokenizer: PreTrainedTokenizer,
         input_ids: torch.LongTensor,
         pixel_values: Optional[torch.FloatTensor] = None,
         max_new_tokens: int = 100,
@@ -1237,31 +1238,37 @@ class ExecuTorchModelForMultiModalToText(ExecuTorchModelBase):
 
     def __init__(self, models: Dict[str, "ExecuTorchModule"], config: "PretrainedConfig"):
         super().__init__(models=models, config=config)
-        if not hasattr(self, "decoder"):
-            raise AttributeError("Expected attribute 'decoder' not found in the instance.")
-        if not hasattr(self, "token_embeddings"):
-            raise AttributeError("Expected attribute 'token_embeddings' not found in the instance.")
-        if not hasattr(self, "audio_encoder"):
-            raise AttributeError("Expected attribute 'audio_encoder' not found in the instance.")
-        metadata = self.decoder.method_names()
+        # if not hasattr(self, "decoder"):
+        #     raise AttributeError("Expected attribute 'decoder' not found in the instance.")
+        # if not hasattr(self, "token_embeddings"):
+        #     raise AttributeError("Expected attribute 'token_embeddings' not found in the instance.")
+        # if not hasattr(self, "audio_encoder"):
+        #     raise AttributeError("Expected attribute 'audio_encoder' not found in the instance.")
+
+        # required_methods = ["decoder", "token_embeddings", "audio_encoder"]
+        # for required_method in required_methods:
+        #     if required_method not in self.model.method_names():
+        #         raise ValueError("Exported .pte file needs to contain 'decoder', 'token_embeddings', and 'audio_encoder' methods.")
+
+        metadata = self.model.method_names()
         if "use_kv_cache" in metadata:
-            self.use_kv_cache = self.decoder.run_method("use_kv_cache")[0]
+            self.use_kv_cache = self.model.run_method("use_kv_cache")[0]
         if "get_max_seq_len" in metadata:
-            self.max_cache_size = self.decoder.run_method("get_max_seq_len")[0]
+            self.max_cache_size = self.model.run_method("get_max_seq_len")[0]
         if "get_max_batch_size" in metadata:
-            self.max_batch_size = self.decoder.run_method("get_max_batch_size")[0]
+            self.max_batch_size = self.model.run_method("get_max_batch_size")[0]
         if "get_dtype" in metadata:
-            self.dtype = self.decoder.run_method("get_dtype")[0]
+            self.dtype = self.model.run_method("get_dtype")[0]
         if "get_bos_id" in metadata:
-            self.bos_token_id = self.decoder.run_method("get_bos_id")[0]
+            self.bos_token_id = self.model.run_method("get_bos_id")[0]
         if "get_eos_id" in metadata:
-            self.eos_token_id = self.decoder.run_method("get_eos_id")[0]
+            self.eos_token_id = self.model.run_method("get_eos_id")[0]
         if "get_vocab_size" in metadata:
-            self.vocab_size = self.decoder.run_method("get_vocab_size")[0]
+            self.vocab_size = self.model.run_method("get_vocab_size")[0]
         if "max_hidden_seq_length" in metadata:
-            self.max_hidden_seq_length = self.decoder.run_method("max_hidden_seq_length")[0]
+            self.max_hidden_seq_length = self.model.run_method("max_hidden_seq_length")[0]
         if "decoder_start_token_id" in metadata:
-            self.decoder_start_token_id = self.decoder.run_method("decoder_start_token_id")[0]
+            self.decoder_start_token_id = self.model.run_method("decoder_start_token_id")[0]
 
     def forward(
         self,
@@ -1300,25 +1307,28 @@ def generate(
             )
             max_seq_len = self.max_cache_size
 
+        # Prefill.
         self.stats.on_sampling_begin()
         logits = self.forward(
-            input_ids=torch.tensor(prompt_tokens, dtype=torch.long, device=self.device).unsqueeze(0),
-            cache_position=torch.arange(len(prompt_tokens), dtype=torch.long, device=self.device),
+            input_ids=torch.tensor(prompt_tokens, dtype=torch.long, device=self.device),
+            cache_position=torch.arange(len(prompt_tokens[0]), dtype=torch.long, device=self.device),
             input_features=input_features,
         )
         self.stats.on_sampling_end()
-        next_token = torch.argmax(logits, dim=-1)[0, -1].item()
         self.stats.on_prompt_eval_end()
-        first_token_generated = False
 
-        generated_tokens = prompt_tokens + [next_token]
+        next_token = torch.argmax(logits[:, -1, :], dim=-1).item()
+        generated_tokens = [next_token]
+        print(self.tokenizer.decode([next_token]), end="")
 
-        while len(generated_tokens) < max_seq_len:
+        # Token-by-token generation.
+        first_token_generated = False
+        while len(generated_tokens) + len(prompt_tokens[0]) < max_seq_len:
             self.stats.on_sampling_begin()
             logits = self.forward(
                 input_ids=torch.tensor([next_token], dtype=torch.long, device=self.device).unsqueeze(0),
                 cache_position=torch.tensor(
-                    [pos_base + len(generated_tokens) - 1],
+                    [pos_base + len(generated_tokens) + len(prompt_tokens[0]) - 1],
                     dtype=torch.long,
                     device=self.device,
                 ),
@@ -1328,20 +1338,20 @@ def generate(
                 self.stats.on_first_token()
                 first_token_generated = True
 
-            next_token = torch.argmax(logits, dim=-1).item()
+            next_token = torch.argmax(logits[:, -1, :], dim=-1).item()
             generated_tokens.append(next_token)
+            print(self.tokenizer.decode([next_token]), end="")
 
-            if next_token in self.eos_token_ids:
+            if next_token == self.eos_token_id:
                 break
 
         self.stats.set_num_generated_tokens(len(generated_tokens) - len(prompt_tokens))
-
         return generated_tokens if echo else generated_tokens[len(prompt_tokens) :]
 
     def text_generation(
         self,
         processor: "ProcessorMixin",
-        tokenizer: "PreTrainedTokenizer",
+        tokenizer: PreTrainedTokenizer,
         input_conversation: List[Dict],
         echo: bool = True,
         max_seq_len: Optional[int] = None,
@@ -1368,22 +1378,21 @@ def text_generation(
             raise ValueError(
                 f"The tokenizer's bos_token_id={self.tokenizer.bos_token_id} must be the same as the model's bos_token_id={self.bos_token_id}."
            )
-        if not verify_eos_tokens_in_tokenizer(self.eos_token_ids, self.tokenizer):
+        if isinstance(self.tokenizer, PreTrainedTokenizer) and not verify_eos_tokens_in_pretrained_tokenizer(self.eos_token_id, self.tokenizer):
             raise ValueError(
-                f"The tokenizer's eos_token_id does not match with the model's eos_token_ids={self.eos_token_ids}."
+                f"The tokenizer's eos_token_id does not match with the model's eos_token_id={self.eos_token_id}."
             )
 
         # Reset stats for a new generation
         self.stats.reset()
         self.stats.on_inference_start()
 
         inputs = processor.apply_chat_template(input_conversation)
-        prompt_tokens = self.tokenizer.encode(inputs["input_ids"])
         self.stats.on_token_encode_end()
-        self.stats.set_num_prompt_tokens(len(prompt_tokens))
+        self.stats.set_num_prompt_tokens(len(inputs["input_ids"][0]))
 
         generated_tokens = self.generate(
-            prompt_tokens=prompt_tokens,
+            prompt_tokens=inputs["input_ids"],
             input_features=inputs["input_features"],
            echo=echo,
            max_seq_len=max_seq_len,
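# ---------------------------------------------------------------------------
# Usage sketch (not part of the commit): a minimal illustration of how the
# updated ExecuTorchModelForMultiModalToText.text_generation API in the diff
# above might be called. The checkpoint name, the from_pretrained arguments,
# the import path, and the conversation layout are illustrative assumptions;
# only the text_generation(processor, tokenizer, input_conversation, echo,
# max_seq_len) signature and the fact that the processor's chat template
# yields "input_ids" and "input_features" come from the change itself.
# ---------------------------------------------------------------------------
from transformers import AutoProcessor, AutoTokenizer

from optimum.executorch import ExecuTorchModelForMultiModalToText  # assumed export path

model_id = "org/audio-text-model"  # hypothetical audio+text checkpoint

processor = AutoProcessor.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Export-and-load arguments are assumptions; actual recipes/kwargs depend on
# the optimum-executorch version in use.
model = ExecuTorchModelForMultiModalToText.from_pretrained(model_id, recipe="xnnpack")

# Illustrative conversation; the exact content schema is defined by the
# model's chat template, not by this diff.
conversation = [
    {
        "role": "user",
        "content": [
            {"type": "audio", "path": "sample.wav"},
            {"type": "text", "text": "Transcribe and summarize this clip."},
        ],
    }
]

output = model.text_generation(
    processor=processor,
    tokenizer=tokenizer,
    input_conversation=conversation,
    echo=False,
    max_seq_len=256,
)
print(output)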