|
32 | 32 | from transformers import Trainer as TransformersTrainer |
33 | 33 | from transformers import TrainerCallback, TrainerControl, TrainingArguments |
34 | 34 | from transformers.file_utils import WEIGHTS_NAME |
| 35 | +from transformers.integrations import TensorBoardCallback |
35 | 36 | from transformers.trainer_callback import TrainerState |
36 | 37 | from transformers.trainer_utils import get_last_checkpoint |
37 | 38 |
|
|
40 | 41 | GradSampler, |
41 | 42 | LoggerManager, |
42 | 43 | ModuleSparsificationInfo, |
| 44 | + TensorBoardLogger, |
43 | 45 | WANDBLogger, |
44 | 46 | ) |
45 | 47 | from sparseml.transformers.utils import SparseAutoModel |
@@ -154,7 +156,7 @@ def __init__( |
154 | 156 | self.criterion = torch.nn.CrossEntropyLoss() |
155 | 157 | self.callback_disable_fp16 = DisableHalfPrecisionCallback(self) |
156 | 158 | self.callback_handler.add_callback(self.callback_disable_fp16) |
157 | | - |
| 159 | + self._add_tensorboard_logger_if_available() |
158 | 160 | self.grad_sampler = GradSampler( |
159 | 161 | self._mfac_data_loader(), self._mfac_loss_function |
160 | 162 | ) |
@@ -263,7 +265,6 @@ def create_optimizer(self): |
263 | 265 | self.manager_steps_per_epoch = math.ceil( |
264 | 266 | len(self.train_dataset) / total_batch_size |
265 | 267 | ) |
266 | | - |
267 | 268 | if hasattr(self, "scaler"): |
268 | 269 | wrap_optim_key = "scaler" |
269 | 270 | self.scaler = self.manager.modify( |
@@ -702,6 +703,24 @@ def _mfac_loss_function(self, model_outputs, loss_target): |
702 | 703 | ) |
703 | 704 | return loss |
704 | 705 |
|
| 706 | + def _add_tensorboard_logger_if_available(self): |
| 707 | + tensorboard_callback = None |
| 708 | + for callback in self.callback_handler.callbacks: |
| 709 | + if isinstance(callback, TensorBoardCallback): |
| 710 | + tensorboard_callback = callback |
| 711 | + break |
| 712 | + if tensorboard_callback is None: |
| 713 | + return |
| 714 | + |
| 715 | + if tensorboard_callback.tb_writer is None: |
| 716 | + tensorboard_callback._init_summary_writer( |
| 717 | + self.args, log_dir=self.args.logging_dir |
| 718 | + ) |
| 719 | + |
| 720 | + self.logger_manager.add_logger( |
| 721 | + TensorBoardLogger(writer=tensorboard_callback.tb_writer) |
| 722 | + ) |
| 723 | + |
705 | 724 |
|
706 | 725 | class TrainerInterface(RecipeManagerTrainerInterface): |
707 | 726 | """ |
|
0 commit comments