
Commit be28f31

add TensorBoardLogger to transformers integration (#912)
1 parent 76fe720

File tree: 2 files changed, +31 −2 lines


src/sparseml/pytorch/utils/logger.py

Lines changed: 10 additions & 0 deletions
@@ -778,6 +778,16 @@ def __len__(self):
     def __iter__(self):
         return iter(self.loggers)
 
+    def add_logger(self, logger: BaseLogger):
+        """
+        add a BaseLogger implementation to the loggers of this manager
+
+        :param logger: logger object to add
+        """
+        if not isinstance(logger, BaseLogger):
+            raise ValueError(f"logger {type(logger)} must be of type BaseLogger")
+        self._loggers.append(logger)
+
     def log_ready(self, epoch, last_log_epoch):
         """
         Check if there is a logger that is ready to accept a log
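
A minimal usage sketch for the new add_logger hook, independent of the Trainer integration. The TensorBoardLogger(writer=...) call mirrors the one this commit adds in trainer.py; constructing LoggerManager with an initial list of loggers is an assumption about its signature, not something shown in this diff.

# Hedged sketch: attach an extra logger to an existing LoggerManager.
from torch.utils.tensorboard import SummaryWriter

from sparseml.pytorch.utils import LoggerManager, TensorBoardLogger

manager = LoggerManager([])  # assumption: accepts an initial list of loggers
writer = SummaryWriter(log_dir="./runs")  # illustrative log directory

# add_logger type-checks its argument, so only BaseLogger instances attach
manager.add_logger(TensorBoardLogger(writer=writer))

try:
    manager.add_logger(writer)  # a raw SummaryWriter is not a BaseLogger
except ValueError as err:
    print(err)  # logger <class ...> must be of type BaseLogger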

src/sparseml/transformers/sparsification/trainer.py

Lines changed: 21 additions & 2 deletions
@@ -32,6 +32,7 @@
 from transformers import Trainer as TransformersTrainer
 from transformers import TrainerCallback, TrainerControl, TrainingArguments
 from transformers.file_utils import WEIGHTS_NAME
+from transformers.integrations import TensorBoardCallback
 from transformers.trainer_callback import TrainerState
 from transformers.trainer_utils import get_last_checkpoint
 
@@ -40,6 +41,7 @@
     GradSampler,
     LoggerManager,
     ModuleSparsificationInfo,
+    TensorBoardLogger,
     WANDBLogger,
 )
 from sparseml.transformers.utils import SparseAutoModel
@@ -154,7 +156,7 @@ def __init__(
         self.criterion = torch.nn.CrossEntropyLoss()
         self.callback_disable_fp16 = DisableHalfPrecisionCallback(self)
         self.callback_handler.add_callback(self.callback_disable_fp16)
-
+        self._add_tensorboard_logger_if_available()
         self.grad_sampler = GradSampler(
             self._mfac_data_loader(), self._mfac_loss_function
         )
@@ -263,7 +265,6 @@ def create_optimizer(self):
         self.manager_steps_per_epoch = math.ceil(
             len(self.train_dataset) / total_batch_size
         )
-
         if hasattr(self, "scaler"):
             wrap_optim_key = "scaler"
             self.scaler = self.manager.modify(
@@ -702,6 +703,24 @@ def _mfac_loss_function(self, model_outputs, loss_target):
             )
         return loss
 
+    def _add_tensorboard_logger_if_available(self):
+        tensorboard_callback = None
+        for callback in self.callback_handler.callbacks:
+            if isinstance(callback, TensorBoardCallback):
+                tensorboard_callback = callback
+                break
+        if tensorboard_callback is None:
+            return
+
+        if tensorboard_callback.tb_writer is None:
+            tensorboard_callback._init_summary_writer(
+                self.args, log_dir=self.args.logging_dir
+            )
+
+        self.logger_manager.add_logger(
+            TensorBoardLogger(writer=tensorboard_callback.tb_writer)
+        )
+
 
 class TrainerInterface(RecipeManagerTrainerInterface):
     """
