|
16 | 16 | SparseML transformers trainer classes and interfaces to be plugged in with
17 | 17 | existing or similar HF trainer flows
18 | 18 | """ |
| 19 | +import collections |
19 | 20 | import inspect |
20 | 21 | import logging |
21 | 22 | import math |
22 | 23 | import os |
23 | 24 | import warnings |
| 25 | +from contextlib import suppress |
24 | 26 | from dataclasses import asdict |
25 | | -from typing import Any, Dict, List, Optional, Tuple, Union |
| 27 | +from typing import Any, Dict, Final, List, Optional, Tuple, Union |
26 | 28 |
|
27 | 29 | import datasets |
28 | 30 | import numpy |
|
36 | 38 | from transformers.trainer_callback import TrainerState |
37 | 39 | from transformers.trainer_pt_utils import reissue_pt_warnings |
38 | 40 | from transformers.trainer_utils import ShardedDDPOption, get_last_checkpoint |
| 41 | +from transformers.utils import PaddingStrategy |
39 | 42 |
|
40 | 43 | from sparseml.pytorch.optim import ScheduledModifierManager, ScheduledOptimizer |
41 | 44 | from sparseml.pytorch.utils import ( |
|
56 | 59 | "TransformersTrainer", |
57 | 60 | ] |
58 | 61 |
|
59 | | - |
60 | 62 | _LOGGER = logging.getLogger(__name__) |
61 | 63 | TRAINER_STATE_NAME = "trainer_state.json" |
62 | 64 | OPTIMIZER_NAME = "optimizer.pt" |
@@ -489,31 +491,56 @@ def log_model_sparsification(self): |
489 | 491 | ) |
490 | 492 |
|
491 | 493 | def save_sample_inputs_outputs( |
492 | | - self, num_samples_to_export: int = 100, output_dir: Optional[str] = None |
| 494 | + self, |
| 495 | + num_samples_to_export: int = 100, |
| 496 | + output_dir: Optional[str] = None, |
| 497 | + tokenizer: Optional[Any] = None, |
493 | 498 | ): |
494 | 499 | """ |
495 | 500 | Save sample inputs/outputs/labels in save_dir as .npz arrays |
496 | 501 |
|
497 | 502 | :param num_samples_to_export: Number of samples to export. |
498 | 503 | Defaults to 100 |
499 | 504 | :param output_dir: The directory to store sample inputs and outputs in |
| 505 | + :param tokenizer: if the eval and train dataloaders cannot be created,
| 506 | + the tokenizer is used to generate fake sample inputs instead
500 | 507 | """ |
501 | 508 | num_samples = 0 |
502 | | - output_dir = output_dir or self.args.output_dir or "" |
503 | 509 |
|
504 | | - sample_in_dir = os.path.join(output_dir, "sample_inputs") |
505 | | - sample_out_dir = os.path.join(output_dir, "sample_outputs") |
| 510 | + if output_dir is None: |
| 511 | + output_dir = ( |
| 512 | + self.args.output_dir if hasattr(self.args, "output_dir") else "" |
| 513 | + ) |
| 514 | + |
| 515 | + sample_in_dir = os.path.join(output_dir, "sample-inputs") |
| 516 | + sample_out_dir = os.path.join(output_dir, "sample-outputs") |
506 | 517 |
|
507 | 518 | os.makedirs(sample_in_dir, exist_ok=True) |
508 | 519 | os.makedirs(sample_out_dir, exist_ok=True) |
509 | 520 | device = self.model.device |
510 | 521 |
|
| 522 | + dataloader = None |
511 | 523 | try: |
512 | 524 | dataloader = self.get_eval_dataloader() |
513 | 525 | except Exception: |
514 | | - dataloader = self.get_train_dataloader() |
| 526 | + with suppress(ValueError): |
| 527 | + dataloader = self.get_train_dataloader() |
| 528 | + |
| 529 | + if not dataloader and not tokenizer: |
| 530 | + raise ValueError( |
| 531 | + "tokenizer is needed to generate fake sample inputs when Trainer is " |
| 532 | + "not initialized with a train or eval dataset" |
| 533 | + ) |
| 534 | + if dataloader is None: |
| 535 | + # we have the tokenizer so use it |
| 536 | + dataloader = self._get_fake_dataloader( |
| 537 | + num_samples=num_samples_to_export, tokenizer=tokenizer |
| 538 | + ) |
515 | 539 |
|
516 | | - _LOGGER.info(f"Exporting {num_samples_to_export} samples to {output_dir}") |
| 540 | + _LOGGER.info( |
| 541 | + f"Exporting {num_samples_to_export} samples to " |
| 542 | + f"{os.path.abspath(output_dir)}" |
| 543 | + ) |
517 | 544 | for _, sample_batch in enumerate(dataloader): |
518 | 545 | sample_batch.pop("labels", None) |
519 | 546 | input_names = list(sample_batch.keys()) |
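
A minimal usage sketch of the new tokenizer fallback, assuming trainer is an already-constructed SparseML transformers trainer that was created without a train or eval dataset; the checkpoint name, max length, and output directory below are illustrative placeholders, not values from this PR.

    from transformers import AutoTokenizer

    # cap the sequence length so MAX_LENGTH padding of the fake inputs stays small
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", model_max_length=128)

    # with no datasets attached, the trainer falls back to tokenizer-built fake inputs
    trainer.save_sample_inputs_outputs(
        num_samples_to_export=10,
        output_dir="sample_exports",
        tokenizer=tokenizer,
    )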
@@ -725,6 +752,34 @@ def _add_tensorboard_logger_if_available(self): |
725 | 752 | TensorBoardLogger(writer=tensorboard_callback.tb_writer) |
726 | 753 | ) |
727 | 754 |
|
| 755 | + def _get_fake_dataloader( |
| 756 | + self, |
| 757 | + num_samples: int, |
| 758 | + tokenizer: "PreTrainedTokenizerBase", # noqa: F821 |
| 759 | + ): |
| 760 | + |
| 761 | + # Rearrange input keys to match those defined by the model's forward func,
| 762 | + # which appears to determine the input order in the exported model
| 763 | + forward_args_spec = inspect.getfullargspec(self.model.__class__.forward) |
| 764 | + synthetic_input: Final = self._get_fake_input( |
| 765 | + forward_func_input_keys=forward_args_spec.args, |
| 766 | + tokenizer=tokenizer, |
| 767 | + ) |
| 768 | + return (synthetic_input for _ in range(num_samples)) |
| 769 | + |
| 770 | + def _get_fake_input(self, forward_func_input_keys, tokenizer): |
| 771 | + inputs = tokenizer( |
| 772 | + "", return_tensors="pt", padding=PaddingStrategy.MAX_LENGTH.value |
| 773 | + ).data # Dict[str, Tensor]
| 774 | + inputs = collections.OrderedDict( |
| 775 | + [ |
| 776 | + (input_key, inputs[input_key][0].reshape(1, -1)) |
| 777 | + for input_key in forward_func_input_keys |
| 778 | + if input_key in inputs |
| 779 | + ] |
| 780 | + ) |
| 781 | + return inputs |
| 782 | + |
728 | 783 |
|
729 | 784 | class TrainerInterface(RecipeManagerTrainerInterface): |
730 | 785 | """ |
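
A standalone sketch of the key-ordering idea behind _get_fake_input, using a mock forward signature and plain lists in place of tensors; the class and values below are illustrative only, not part of this PR. Note that _get_fake_dataloader returns a generator that yields the same synthetic input num_samples times, so it can only be iterated once.

    import collections
    import inspect


    class MockModel:
        # hypothetical forward signature mirroring a BERT-style model
        def forward(self, input_ids=None, attention_mask=None, token_type_ids=None):
            pass


    # stand-in for tokenizer("...").data; real values would be torch tensors
    tokenizer_output = {
        "attention_mask": [[1, 1, 1]],
        "token_type_ids": [[0, 0, 0]],
        "input_ids": [[101, 2023, 102]],
    }

    # order inputs by the forward signature, dropping keys forward() does not accept
    forward_args = inspect.getfullargspec(MockModel.forward).args
    ordered = collections.OrderedDict(
        (key, tokenizer_output[key]) for key in forward_args if key in tokenizer_output
    )
    print(list(ordered))  # ['input_ids', 'attention_mask', 'token_type_ids']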
|