This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit 68f2a2f

Support to Generate Fake Sample/Inputs (#1180) (#1197)
* Support to Generate Fake Sample/Inputs and outputs if no `--data_args` supplied in export script
* Address all review comments:
  - Simplify call to `_get_fake_inputs`
  - Save inputs/outputs to `sample-inputs`/`sample-outputs`
1 parent 238b6ef commit 68f2a2f
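
For orientation, a rough sketch of the call path this commit changes follows, assuming the `export_transformer_to_onnx` helper shown in the diff accepts `task`, `model_path`, `num_export_samples`, and `data_args` as keyword arguments (the full signature is not part of this diff, and all argument values below are hypothetical):

    from sparseml.transformers.export import export_transformer_to_onnx

    # Before this commit, requesting sample export without --data_args raised a
    # ValueError; with it, a message is logged and synthetic samples are built
    # from the tokenizer and the model's forward() signature instead.
    export_transformer_to_onnx(
        task="question-answering",       # hypothetical task name
        model_path="./my_sparse_model",  # hypothetical local checkpoint
        num_export_samples=20,           # request sample inputs/outputs
        data_args=None,                  # no real dataset arguments supplied
    )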

File tree: 2 files changed (+76 lines, -15 lines)

  src/sparseml/transformers/export.py
  src/sparseml/transformers/sparsification/trainer.py

src/sparseml/transformers/export.py

Lines changed: 13 additions & 7 deletions
@@ -246,9 +246,10 @@ def export_transformer_to_onnx(
     )

     if num_export_samples > 0 and data_args is None:
-        raise ValueError(
+        _LOGGER.info(
             f"--data_args is needed for exporting {num_export_samples} "
-            f"samples but got {data_args}"
+            "real samples but got None, synthetic data samples will be "
+            "generated based on model input/output shapes"
         )
     data_args: Dict[str, Any] = _parse_data_args(data_args)

@@ -265,7 +266,7 @@ def export_transformer_to_onnx(
     _LOGGER.info(f"loaded model, config, and tokenizer from {model_path}")

     eval_dataset = None
-    if num_export_samples > 0:
+    if num_export_samples > 0 and data_args:
         tokenized_dataset = load_task_dataset(
             task=task,
             tokenizer=tokenizer,

@@ -316,12 +317,16 @@ def export_transformer_to_onnx(
     # Rearrange inputs' keys to match those defined by model foward func, which
     # seem to define how the order of inputs is determined in the exported model
     forward_args_spec = inspect.getfullargspec(model.__class__.forward)
-    dropped = [f for f in inputs.keys() if f not in forward_args_spec.args]
+    dropped = [
+        input_key
+        for input_key in inputs.keys()
+        if input_key not in forward_args_spec.args
+    ]
     inputs = collections.OrderedDict(
         [
-            (f, inputs[f][0].reshape(1, -1))
-            for f in forward_args_spec.args
-            if f in inputs
+            (func_input_arg_name, inputs[func_input_arg_name][0].reshape(1, -1))
+            for func_input_arg_name in forward_args_spec.args
+            if func_input_arg_name in inputs
         ]
     )
     if dropped:

@@ -362,6 +367,7 @@ def export_transformer_to_onnx(
     _LOGGER.info(f"Exporting {num_export_samples} sample inputs/outputs")
     trainer.save_sample_inputs_outputs(
         num_samples_to_export=num_export_samples,
+        tokenizer=tokenizer,
     )

     _LOGGER.info(f"{num_export_samples} sample inputs/outputs exported")

src/sparseml/transformers/sparsification/trainer.py

Lines changed: 63 additions & 8 deletions
@@ -16,13 +16,15 @@
 SparseML transformers trainer classes and interfaces to be plugged in with
 existing or similiar HF trainer flows
 """
+import collections
 import inspect
 import logging
 import math
 import os
 import warnings
+from contextlib import suppress
 from dataclasses import asdict
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, Final, List, Optional, Tuple, Union

 import datasets
 import numpy

@@ -36,6 +38,7 @@
 from transformers.trainer_callback import TrainerState
 from transformers.trainer_pt_utils import reissue_pt_warnings
 from transformers.trainer_utils import ShardedDDPOption, get_last_checkpoint
+from transformers.utils import PaddingStrategy

 from sparseml.pytorch.optim import ScheduledModifierManager, ScheduledOptimizer
 from sparseml.pytorch.utils import (

@@ -56,7 +59,6 @@
     "TransformersTrainer",
 ]

-
 _LOGGER = logging.getLogger(__name__)
 TRAINER_STATE_NAME = "trainer_state.json"
 OPTIMIZER_NAME = "optimizer.pt"

@@ -489,31 +491,56 @@ def log_model_sparsification(self):
         )

     def save_sample_inputs_outputs(
-        self, num_samples_to_export: int = 100, output_dir: Optional[str] = None
+        self,
+        num_samples_to_export: int = 100,
+        output_dir: Optional[str] = None,
+        tokenizer: Optional[Any] = None,
     ):
         """
         Save sample inputs/outputs/labels in save_dir as .npz arrays

         :param num_samples_to_export: Number of samples to export.
             Defaults to 100
         :param output_dir: The directory to store sample inputs and outputs in
+        :param tokenizer: if eval and train dataset cannot be generated, then
+            the tokenizer is used to generate fake inputs
         """
         num_samples = 0
-        output_dir = output_dir or self.args.output_dir or ""

-        sample_in_dir = os.path.join(output_dir, "sample_inputs")
-        sample_out_dir = os.path.join(output_dir, "sample_outputs")
+        if output_dir is None:
+            output_dir = (
+                self.args.output_dir if hasattr(self.args, "output_dir") else ""
+            )
+
+        sample_in_dir = os.path.join(output_dir, "sample-inputs")
+        sample_out_dir = os.path.join(output_dir, "sample-outputs")

         os.makedirs(sample_in_dir, exist_ok=True)
         os.makedirs(sample_out_dir, exist_ok=True)
         device = self.model.device

+        dataloader = None
         try:
             dataloader = self.get_eval_dataloader()
         except Exception:
-            dataloader = self.get_train_dataloader()
+            with suppress(ValueError):
+                dataloader = self.get_train_dataloader()
+
+        if not dataloader and not tokenizer:
+            raise ValueError(
+                "tokenizer is needed to generate fake sample inputs when Trainer is "
+                "not initialized with a train or eval dataset"
+            )
+        if dataloader is None:
+            # we have the tokenizer so use it
+            dataloader = self._get_fake_dataloader(
+                num_samples=num_samples_to_export, tokenizer=tokenizer
+            )

-        _LOGGER.info(f"Exporting {num_samples_to_export} samples to {output_dir}")
+        _LOGGER.info(
+            f"Exporting {num_samples_to_export} samples to "
+            f"{os.path.abspath(output_dir)}"
+        )
         for _, sample_batch in enumerate(dataloader):
             sample_batch.pop("labels", None)
             input_names = list(sample_batch.keys())

@@ -725,6 +752,34 @@ def _add_tensorboard_logger_if_available(self):
             TensorBoardLogger(writer=tensorboard_callback.tb_writer)
         )

+    def _get_fake_dataloader(
+        self,
+        num_samples: int,
+        tokenizer: "PreTrainedTokenizerBase",  # noqa: F821
+    ):
+
+        # Rearrange inputs' keys to match those defined by model foward func, which
+        # seem to define how the order of inputs is determined in the exported model
+        forward_args_spec = inspect.getfullargspec(self.model.__class__.forward)
+        synthetic_input: Final = self._get_fake_input(
+            forward_func_input_keys=forward_args_spec.args,
+            tokenizer=tokenizer,
+        )
+        return (synthetic_input for _ in range(num_samples))
+
+    def _get_fake_input(self, forward_func_input_keys, tokenizer):
+        inputs = tokenizer(
+            "", return_tensors="pt", padding=PaddingStrategy.MAX_LENGTH.value
+        ).data  # Dict[Tensor]
+        inputs = collections.OrderedDict(
+            [
+                (input_key, inputs[input_key][0].reshape(1, -1))
+                for input_key in forward_func_input_keys
+                if input_key in inputs
+            ]
+        )
+        return inputs
+

 class TrainerInterface(RecipeManagerTrainerInterface):
     """
