
Commit 0165fb7

Working E2E pybind runner

1 parent a681d18 commit 0165fb7

File tree

6 files changed: +74 -58 lines changed

optimum/executorch/modeling.py

Lines changed: 47 additions & 38 deletions
@@ -35,14 +35,15 @@
     add_start_docstrings,
 )
 from transformers.configuration_utils import PretrainedConfig
+from transformers.tokenization_utils import PreTrainedTokenizer
 from transformers.utils import is_offline_mode

 from executorch.extension.pybindings.portable_lib import ExecuTorchModule, _load_for_executorch
 from executorch.kernels import quantized  # noqa

 from ..exporters import TasksManager
 from ..exporters.executorch import main_export
-from ..exporters.executorch.utils import verify_eos_tokens_in_tokenizer
+from ..exporters.executorch.utils import verify_eos_tokens_in_pretrained_tokenizer
 from ..modeling_base import FROM_PRETRAINED_START_DOCSTRING, OptimizedModel
 from ..utils.file_utils import find_files_matching_pattern
 from .stats import Stats
@@ -525,7 +526,7 @@ def generate(

     def text_generation(
         self,
-        tokenizer: "PreTrainedTokenizer",
+        tokenizer: PreTrainedTokenizer,
         prompt: str,
         echo: bool = True,
         max_seq_len: Optional[int] = None,
@@ -745,7 +746,7 @@ def generate(

     def text_generation(
         self,
-        tokenizer: "PreTrainedTokenizer",
+        tokenizer: PreTrainedTokenizer,
         prompt: str,
         echo: bool = True,
         max_seq_len: Optional[int] = None,
@@ -772,7 +773,7 @@ def text_generation(
            raise ValueError(
                f"The tokenizer's bos_token_id={self.tokenizer.bos_token_id} must be the same as the model's bos_token_id={self.bos_token_id}."
            )
-        if not verify_eos_tokens_in_tokenizer(self.eos_token_ids, self.tokenizer):
+        if not verify_eos_tokens_in_pretrained_tokenizer(self.eos_token_ids, self.tokenizer):
            raise ValueError(
                f"The tokenizer's eos_token_id does not match with the model's eos_token_ids={self.eos_token_ids}."
            )
@@ -1066,7 +1067,7 @@ def generate(

     def transcribe(
         self,
-        tokenizer: "PreTrainedTokenizer",
+        tokenizer: PreTrainedTokenizer,
         input_features: torch.Tensor,
         echo: bool = True,
         max_seq_len: Optional[int] = None,
@@ -1197,7 +1198,7 @@ def forward(

     def generate(
         self,
-        tokenizer: "PretrainedTokenizer",
+        tokenizer: PreTrainedTokenizer,
         input_ids: torch.LongTensor,
         pixel_values: Optional[torch.FloatTensor] = None,
         max_new_tokens: int = 100,
@@ -1237,31 +1238,37 @@ class ExecuTorchModelForMultiModalToText(ExecuTorchModelBase):

     def __init__(self, models: Dict[str, "ExecuTorchModule"], config: "PretrainedConfig"):
         super().__init__(models=models, config=config)
-        if not hasattr(self, "decoder"):
-            raise AttributeError("Expected attribute 'decoder' not found in the instance.")
-        if not hasattr(self, "token_embeddings"):
-            raise AttributeError("Expected attribute 'token_embeddings' not found in the instance.")
-        if not hasattr(self, "audio_encoder"):
-            raise AttributeError("Expected attribute 'audio_encoder' not found in the instance.")
-        metadata = self.decoder.method_names()
+        # if not hasattr(self, "decoder"):
+        #     raise AttributeError("Expected attribute 'decoder' not found in the instance.")
+        # if not hasattr(self, "token_embeddings"):
+        #     raise AttributeError("Expected attribute 'token_embeddings' not found in the instance.")
+        # if not hasattr(self, "audio_encoder"):
+        #     raise AttributeError("Expected attribute 'audio_encoder' not found in the instance.")
+
+        # required_methods = ["decoder", "token_embeddings", "audio_encoder"]
+        # for required_method in required_methods:
+        #     if required_method not in self.model.method_names():
+        #         raise ValueError("Exported .pte file needs to contain 'decoder', 'token_embeddings', and 'audio_encoder' methods.")
+
+        metadata = self.model.method_names()
         if "use_kv_cache" in metadata:
-            self.use_kv_cache = self.decoder.run_method("use_kv_cache")[0]
+            self.use_kv_cache = self.model.run_method("use_kv_cache")[0]
         if "get_max_seq_len" in metadata:
-            self.max_cache_size = self.decoder.run_method("get_max_seq_len")[0]
+            self.max_cache_size = self.model.run_method("get_max_seq_len")[0]
         if "get_max_batch_size" in metadata:
-            self.max_batch_size = self.decoder.run_method("get_max_batch_size")[0]
+            self.max_batch_size = self.model.run_method("get_max_batch_size")[0]
         if "get_dtype" in metadata:
-            self.dtype = self.decoder.run_method("get_dtype")[0]
+            self.dtype = self.model.run_method("get_dtype")[0]
         if "get_bos_id" in metadata:
-            self.bos_token_id = self.decoder.run_method("get_bos_id")[0]
+            self.bos_token_id = self.model.run_method("get_bos_id")[0]
         if "get_eos_id" in metadata:
-            self.eos_token_id = self.decoder.run_method("get_eos_id")[0]
+            self.eos_token_id = self.model.run_method("get_eos_id")[0]
         if "get_vocab_size" in metadata:
-            self.vocab_size = self.decoder.run_method("get_vocab_size")[0]
+            self.vocab_size = self.model.run_method("get_vocab_size")[0]
         if "max_hidden_seq_length" in metadata:
-            self.max_hidden_seq_length = self.decoder.run_method("max_hidden_seq_length")[0]
+            self.max_hidden_seq_length = self.model.run_method("max_hidden_seq_length")[0]
         if "decoder_start_token_id" in metadata:
-            self.decoder_start_token_id = self.decoder.run_method("decoder_start_token_id")[0]
+            self.decoder_start_token_id = self.model.run_method("decoder_start_token_id")[0]

     def forward(
         self,
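Note: the constructor above now reads its metadata from the single bundled program via constant methods instead of per-component attributes. A minimal sketch of that lookup, assuming a .pte exported with such methods (the path below is a placeholder):

from executorch.extension.pybindings.portable_lib import _load_for_executorch

# Placeholder path; any .pte exported with constant metadata methods works here.
module = _load_for_executorch("model.pte")

metadata = module.method_names()
if "get_max_seq_len" in metadata:
    # Constant methods return a single-element list holding the baked-in value.
    max_cache_size = module.run_method("get_max_seq_len")[0]
    print("get_max_seq_len:", max_cache_size)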
@@ -1300,25 +1307,28 @@ def generate(
            )
            max_seq_len = self.max_cache_size

+        # Prefill.
         self.stats.on_sampling_begin()
         logits = self.forward(
-            input_ids=torch.tensor(prompt_tokens, dtype=torch.long, device=self.device).unsqueeze(0),
-            cache_position=torch.arange(len(prompt_tokens), dtype=torch.long, device=self.device),
+            input_ids=torch.tensor(prompt_tokens, dtype=torch.long, device=self.device),
+            cache_position=torch.arange(len(prompt_tokens[0]), dtype=torch.long, device=self.device),
             input_features=input_features,
         )
         self.stats.on_sampling_end()
-        next_token = torch.argmax(logits, dim=-1)[0, -1].item()
         self.stats.on_prompt_eval_end()
-        first_token_generated = False

-        generated_tokens = prompt_tokens + [next_token]
+        next_token = torch.argmax(logits[:, -1, :], dim=-1).item()
+        generated_tokens = [next_token]
+        print(self.tokenizer.decode([next_token]), end="")

-        while len(generated_tokens) < max_seq_len:
+        # Token-by-token generation.
+        first_token_generated = False
+        while len(generated_tokens) + len(prompt_tokens) < max_seq_len:
             self.stats.on_sampling_begin()
             logits = self.forward(
                 input_ids=torch.tensor([next_token], dtype=torch.long, device=self.device).unsqueeze(0),
                 cache_position=torch.tensor(
-                    [pos_base + len(generated_tokens) - 1],
+                    [pos_base + len(generated_tokens) + len(prompt_tokens) - 1],
                     dtype=torch.long,
                     device=self.device,
                 ),
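Note: the prefill above now keeps the already-batched prompt tensor and picks the next token from the logits at the last sequence position. A minimal sketch of that greedy selection, assuming logits of shape (batch, seq_len, vocab_size):

import torch

logits = torch.randn(1, 4, 10)  # toy logits: batch 1, 4 positions, vocabulary of 10
# Greedy decoding looks only at the final position and takes its argmax.
next_token = torch.argmax(logits[:, -1, :], dim=-1).item()
assert 0 <= next_token < 10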
@@ -1328,20 +1338,20 @@ def generate(
                 self.stats.on_first_token()
                 first_token_generated = True

-            next_token = torch.argmax(logits, dim=-1).item()
+            next_token = torch.argmax(logits[:, -1, :], dim=-1).item()
             generated_tokens.append(next_token)
+            print(self.tokenizer.decode([next_token]), end="")

-            if next_token in self.eos_token_ids:
+            if next_token == self.eos_token_id:
                 break

         self.stats.set_num_generated_tokens(len(generated_tokens) - len(prompt_tokens))
-
         return generated_tokens if echo else generated_tokens[len(prompt_tokens) :]

     def text_generation(
         self,
         processor: "ProcessorMixin",
-        tokenizer: "PreTrainedTokenizer",
+        tokenizer: PreTrainedTokenizer,
         input_conversation: List[Dict],
         echo: bool = True,
         max_seq_len: Optional[int] = None,
@@ -1368,22 +1378,21 @@ def text_generation(
            raise ValueError(
                f"The tokenizer's bos_token_id={self.tokenizer.bos_token_id} must be the same as the model's bos_token_id={self.bos_token_id}."
            )
-        if not verify_eos_tokens_in_tokenizer(self.eos_token_ids, self.tokenizer):
+        if isinstance(self.tokenizer, PreTrainedTokenizer) and verify_eos_tokens_in_pretrained_tokenizer(self.eos_token_id, self.tokenizer):
            raise ValueError(
-                f"The tokenizer's eos_token_id does not match with the model's eos_token_ids={self.eos_token_ids}."
+                f"The tokenizer's eos_token_id does not match with the model's eos_token_id={self.eos_token_id}."
            )

        # Reset stats for a new generation
        self.stats.reset()
        self.stats.on_inference_start()

        inputs = processor.apply_chat_template(input_conversation)
-        prompt_tokens = self.tokenizer.encode(inputs["input_ids"])
        self.stats.on_token_encode_end()
-        self.stats.set_num_prompt_tokens(len(prompt_tokens))
+        self.stats.set_num_prompt_tokens(len(inputs["input_ids"][0]))

        generated_tokens = self.generate(
-            prompt_tokens=prompt_tokens,
+            prompt_tokens=inputs["input_ids"],
            input_features=inputs["input_features"],
            echo=echo,
            max_seq_len=max_seq_len,

optimum/exporters/executorch/integrations.py

Lines changed: 4 additions & 2 deletions
@@ -671,17 +671,19 @@ def export(
        exported_programs["token_embeddings"] = token_embeddings_exported_program

        # 3. Export encoder.
+        input_ids = torch.zeros_like(inputs_embeds[:, :, 0], dtype=torch.long)
+        input_ids[0, 1] = self.config.audio_token_id  # Make sure we don't have an all-false mask for the input_embeds.
        if isinstance(self.model, VoxtralForConditionalGeneration):
            # TODO(JZ): specific to Voxtral, should generalize.
            chunk_length = self.model.audio_tower.config.max_source_positions * self.model.audio_tower.conv1.stride[0] * self.model.audio_tower.conv2.stride[0]
            encoder_input_kwargs = {
                "input_features": torch.rand(3, 128, chunk_length),  # (bsz, features, seq_len)
                "inputs_embeds": inputs_embeds,
-                "input_ids": inputs_embeds[:, :, 0],
+                "input_ids": input_ids,
            }

            max_audio_len = 150  # In seconds, should be a multiple of 30. TODO(JZ): make this configurable top-level.
-            max_seq_len = self.metadata.get("get_max_seq_len") - 1  # TODO(JZ): why - 1? Copied from Gemma3 draft PR.
+            max_seq_len = self.metadata.get("get_max_seq_len")
            dynamic_shapes = {
                "input_features": {
                    0: torch.export.Dim("enc_batch_size_dim", min=1, max=max_audio_len // 30),
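Note: a minimal sketch of the dynamic-shape spec used for the encoder export above, assuming torch.export's Dim API; the dimension name and bounds mirror the snippet:

import torch

max_audio_len = 150  # seconds, bounded to whole 30 s chunks as above
enc_batch_dim = torch.export.Dim("enc_batch_size_dim", min=1, max=max_audio_len // 30)
# Only the batch dimension of input_features is dynamic in this sketch.
dynamic_shapes = {"input_features": {0: enc_batch_dim}}
print(dynamic_shapes)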

optimum/exporters/executorch/recipes/xnnpack.py

Lines changed: 11 additions & 9 deletions
@@ -70,7 +70,7 @@ def _lower_to_executorch(
    ) -> Dict[str, ExecutorchProgram]:
        backend_config_dict = {
            "extract_delegate_segments": True,
-            # "memory_planning_pass": MemoryPlanningPass(alloc_graph_input=False),
+            "memory_planning_pass": MemoryPlanningPass(alloc_graph_input=False),
        }
        if parse(executorch_version.__version__).base_version > "0.6.0":
            backend_config_dict["do_quant_fusion_and_const_prop"] = True
@@ -89,14 +89,16 @@ def _lower_to_executorch(
        et_prog = et_prog.to_executorch(
            config=ExecutorchBackendConfig(**backend_config_dict),
        )
-        logging.debug(
-            f"\nExecuTorch program for {pte_name}.pte: {et_prog.exported_program().graph_module}"
-        )
-        delegation_info = get_delegation_info(et_prog.exported_program().graph_module)
-        logging.debug(f"\nDelegation info Summary for {pte_name}.pte: {delegation_info.get_summary()}")
-        logging.debug(
-            f"\nDelegation info for {pte_name}.pte: {tabulate(delegation_info.get_operator_delegation_dataframe(), headers='keys', tablefmt='fancy_grid')}"
-        )
+        for method in et_prog.methods:
+            logging.debug(f"---------------------- Method: {method} ----------------------")
+            logging.debug(
+                f"\nExecuTorch program for {pte_name}.pte: {et_prog.exported_program(method).graph_module}"
+            )
+            delegation_info = get_delegation_info(et_prog.exported_program(method).graph_module)
+            logging.debug(f"\nDelegation info Summary for {pte_name}.pte: {delegation_info.get_summary()}")
+            logging.debug(
+                f"\nDelegation info for {pte_name}.pte: {tabulate(delegation_info.get_operator_delegation_dataframe(), headers='keys', tablefmt='fancy_grid')}"
+            )
        return {pte_name: et_prog}

    exported_progs = model.export()
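Note: enabling the memory planning pass above leaves graph inputs unplanned so the runtime supplies them directly. A sketch of the resulting backend config, assuming the usual executorch.exir import paths:

from executorch.exir import ExecutorchBackendConfig
from executorch.exir.passes import MemoryPlanningPass

# Mirrors backend_config_dict above: delegate segments are extracted and
# graph inputs are excluded from memory planning (alloc_graph_input=False).
backend_config = ExecutorchBackendConfig(
    extract_delegate_segments=True,
    memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False),
)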

optimum/exporters/executorch/tasks/multimodal_text_to_text.py

Lines changed: 2 additions & 2 deletions
@@ -61,6 +61,8 @@ def load_multimodal_text_to_text_model(model_name_or_path: str, **kwargs):
    attn_implementation = kwargs.get("attn_implementation", "custom_sdpa" if use_custom_sdpa else "sdpa")
    cache_implementation = kwargs.get("cache_implementation", "static")
    use_custom_sdpa = use_custom_sdpa or attn_implementation == "custom_sdpa"
+    qlinear_config = kwargs.get("qlinear", None)
+    qembedding_config = kwargs.get("qembedding", None)
    max_length = kwargs.get("max_length", 2048)
    config = kwargs.get("config") or AutoConfig.from_pretrained(model_name_or_path)

@@ -111,8 +113,6 @@ def load_multimodal_text_to_text_model(model_name_or_path: str, **kwargs):

    # TODO: Move quantization recipe out for better composability.
    # TODO: Should switch to `TorchAoConfig` once the quant issue on final lm_head layer is fixed.
-    qlinear_config = kwargs.get("qlinear", None)
-    qembedding_config = kwargs.get("qembedding", None)
    if qlinear_config or qembedding_config:
        # TODO: Update torchao to use 0.11.0 once released
        if parse(torchao.__version__) < parse("0.11.0.dev0"):
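Note: moving the kwargs.get calls up means the quantization flags are defined before their first use. A hedged usage sketch (import path assumed from this file's location in the repo):

from optimum.exporters.executorch.tasks.multimodal_text_to_text import load_multimodal_text_to_text_model

# qlinear / qembedding are now read near the top of the loader via kwargs.get(...).
module = load_multimodal_text_to_text_model(
    "mistralai/Voxtral-Mini-3B-2507",
    use_custom_sdpa=True,
    use_custom_kv_cache=True,
    qlinear=True,
    qembedding=True,
)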

optimum/exporters/executorch/utils.py

Lines changed: 2 additions & 1 deletion
@@ -16,6 +16,7 @@

 import torch
 from transformers import GenerationConfig, PretrainedConfig
+from transformers.tokenization_utils import PreTrainedTokenizer


 def save_config_to_constant_methods(
@@ -65,7 +66,7 @@ def save_config_to_constant_methods(
     return {k: v for k, v in {**metadata, **kwargs}.items() if v is not None}


-def verify_eos_tokens_in_tokenizer(model_eos_ids: List[int], tokenizer) -> bool:
+def verify_eos_tokens_in_pretrained_tokenizer(model_eos_ids: List[int], tokenizer: PreTrainedTokenizer) -> bool:
     """
     Verifies that the model's EOS token IDs are present in the tokenizer's
     set of potential end-of-sequence tokens.
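Note: this hunk only shows the rename and the tightened tokenizer annotation; the body is unchanged and not shown here. A hypothetical re-implementation matching the docstring, for illustration only:

from typing import List

from transformers.tokenization_utils import PreTrainedTokenizer


def verify_eos_tokens_in_pretrained_tokenizer(model_eos_ids: List[int], tokenizer: PreTrainedTokenizer) -> bool:
    # Hypothetical sketch: accept the model's EOS ids only if each one maps to a
    # token the tokenizer could emit as an end-of-sequence marker.
    candidate_ids = {tokenizer.eos_token_id}
    candidate_ids.update(tokenizer.convert_tokens_to_ids(t) for t in tokenizer.all_special_tokens)
    return all(eos_id in candidate_ids for eos_id in model_eos_ids)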

tests/models/test_modeling_voxtral.py

Lines changed: 8 additions & 6 deletions
@@ -27,7 +27,7 @@
 import transformers
 from executorch.extension.pybindings.portable_lib import ExecuTorchModule
 from packaging.version import parse
-from transformers import AutoTokenizer, AutoProcessor
+from transformers import AutoConfig, AutoTokenizer, AutoProcessor
 from transformers.testing_utils import slow

 from optimum.utils.import_utils import is_transformers_version
@@ -42,7 +42,7 @@

 os.environ["TOKENIZERS_PARALLELISM"] = "false"

-logging.basicConfig(level=logging.INFO)
+logging.basicConfig(level=logging.DEBUG)


 @pytest.mark.skipif(
@@ -71,15 +71,16 @@ def __init__(self, *args, **kwargs):
    #     reason="Only available on transformers >= 4.53.0.dev0 and torchao >= 0.11.0",
    # )
    # @pytest.mark.skipif(is_linux_ci, reason="OOM on linux runner")
-    @pytest.mark.skip()
+    # @pytest.mark.skip()
    def test_voxtral_audio_text_to_text_generation_with_custom_sdpa_kv_cache_8da4w_8we_exported_program(self):
        model_id = "mistralai/Voxtral-Mini-3B-2507"
+        config = AutoConfig.from_pretrained(model_id)
        module = load_multimodal_text_to_text_model(
            model_id,
            use_custom_sdpa=True,
            use_custom_kv_cache=True,
            qlinear=True,
-            qembedding_config=True,
+            qembedding=True,
        )

        res = module.export()
@@ -166,11 +167,12 @@ def test_voxtral_audio_text_to_text_generation_with_custom_sdpa_kv_cache_8da4w_8
        ]

        model = ExecuTorchModelForMultiModalToText.from_pretrained(
-            model_id,
+            # model_id,
+            "/Users/jackzhxng/Documents/voxtral",  # Load already exported model in local file path.
            recipe="xnnpack",
            attn_implementation="custom_sdpa",
            use_custom_kv_cache=True,
-            **{"qlinear": True, "qembeeding": True, "task": "multimodal-text-to-text"},
+            **{"qlinear": True, "qembedding": True, "task": "multimodal-text-to-text"},
        )
        self.assertIsInstance(model, ExecuTorchModelForMultiModalToText)
        self.assertIsInstance(model.model, ExecuTorchModule)
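Note: an end-to-end usage sketch of the pybind runner this test exercises; the export directory is a placeholder, the import path for the model class is assumed, and the conversation structure is illustrative of what the processor's chat template expects:

from transformers import AutoProcessor, AutoTokenizer

from optimum.executorch import ExecuTorchModelForMultiModalToText  # assumed import path

model_id = "mistralai/Voxtral-Mini-3B-2507"
model = ExecuTorchModelForMultiModalToText.from_pretrained(
    "/path/to/exported/voxtral",  # placeholder for a directory holding the exported .pte
    recipe="xnnpack",
    attn_implementation="custom_sdpa",
    use_custom_kv_cache=True,
    **{"qlinear": True, "qembedding": True, "task": "multimodal-text-to-text"},
)
processor = AutoProcessor.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Illustrative conversation; the exact content schema depends on the processor.
conversation = [
    {"role": "user", "content": [
        {"type": "audio", "path": "sample.wav"},
        {"type": "text", "text": "What is being said in this clip?"},
    ]}
]
output = model.text_generation(
    processor=processor,
    tokenizer=tokenizer,
    input_conversation=conversation,
    max_seq_len=1024,
)
print(output)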
