Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit 401af9f

Browse files
corey-nmrahul-tuli
authored andcommitted
Adding config to wandb logger and calling LoggerManager.save(recipe) (#1318)
Co-authored-by: Rahul Tuli <rahul@neuralmagic.com>
1 parent fa3a376 commit 401af9f

File tree

1 file changed

+7
-1
lines changed
  • src/sparseml/pytorch/torchvision

1 file changed

+7
-1
lines changed

src/sparseml/pytorch/torchvision/train.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -545,13 +545,19 @@ def collate_fn(batch):
545545
TensorBoardLogger(log_path=args.output_dir),
546546
]
547547
try:
548-
loggers.append(WANDBLogger())
548+
config = vars(args)
549+
if manager is not None:
550+
config["manager"] = str(manager)
551+
loggers.append(WANDBLogger(init_kwargs=dict(config=config)))
549552
except ImportError:
550553
warnings.warn("Unable to import wandb for logging")
551554
logger = LoggerManager(loggers)
552555
else:
553556
logger = LoggerManager(log_python=False)
554557

558+
if args.recipe is not None:
559+
logger.save(args.recipe)
560+
555561
steps_per_epoch = len(data_loader) / args.gradient_accum_steps
556562

557563
def log_metrics(tag: str, metrics: utils.MetricLogger, epoch: int, epoch_step: int):

0 commit comments

Comments
 (0)