Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 79 additions & 0 deletions chebai/callbacks/save_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import logging
import os

from lightning import LightningModule, Trainer
from lightning.pytorch.cli import SaveConfigCallback
from lightning.pytorch.loggers import WandbLogger

logger = logging.getLogger(__name__)


class CustomSaveConfigCallback(SaveConfigCallback):
    """
    Custom SaveConfigCallback that uploads the Lightning config file to W&B.

    This callback extends the default SaveConfigCallback to automatically upload
    the lightning_config.yaml file to Weights & Biases online run logs when using
    WandbLogger. This ensures better traceability and reproducibility of experiments.

    The config file is uploaded using wandb.save(), which makes it available in the
    W&B web interface under the "Files" tab of the run.
    """

    def save_config(
        self, trainer: Trainer, pl_module: LightningModule, stage: str
    ) -> None:
        """
        Save the config to W&B if a WandbLogger is being used.

        This method is called after the config file has been saved to the log directory.
        It checks if the trainer is using a WandbLogger and, if so, uploads the config
        file to W&B using wandb.save().

        Note:
            We don't call super().save_config() because the parent class implementation
            is empty. The actual config file saving to disk happens in the setup() method
            before this method is called.

        This method uses the following attributes from the parent SaveConfigCallback:
            - self.save_to_log_dir: Whether to save config to the log directory
            - self.config_filename: Name of the config file to upload

        Args:
            trainer: The PyTorch Lightning Trainer instance.
            pl_module: The LightningModule being trained.
            stage: The current stage of training (e.g., 'fit', 'validate', 'test').
        """
        # Only proceed if we're saving to log_dir and have a valid trainer
        if not self.save_to_log_dir or trainer.log_dir is None:
            return

        # Find a WandbLogger among the trainer's loggers, if any.
        # BUGFIX: the loop variable must NOT be named ``logger`` — that would
        # shadow the module-level logging.Logger and break the warning call in
        # the except block below (Lightning loggers have no ``warning`` method).
        wandb_logger = None
        for trainer_logger in getattr(trainer, "loggers", []):
            if isinstance(trainer_logger, WandbLogger):
                wandb_logger = trainer_logger
                break

        # If WandbLogger is not found, skip uploading
        if wandb_logger is None:
            return

        # Check if the config file exists
        config_path = os.path.join(trainer.log_dir, self.config_filename)
        if not os.path.exists(config_path):
            return

        # Upload the config file to W&B
        try:
            import wandb

            # Upload the config file to W&B
            # This will make it available in the W&B web interface
            wandb.save(config_path, base_path=trainer.log_dir, policy="now")
        except ImportError:
            # wandb is not installed, skip uploading
            pass
        except Exception as e:
            # Log the error but don't fail the training run
            logger.warning(f"Failed to upload {self.config_filename} to W&B: {e}")
2 changes: 2 additions & 0 deletions chebai/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from lightning.pytorch.cli import LightningArgumentParser, LightningCLI

from chebai.callbacks.save_config import CustomSaveConfigCallback
from chebai.preprocessing.datasets.base import XYBaseDataModule
from chebai.trainer.CustomTrainer import CustomTrainer

Expand Down Expand Up @@ -121,6 +122,7 @@ def cli():
Main function to instantiate and run the ChebaiCLI.
"""
ChebaiCLI(
save_config_callback=CustomSaveConfigCallback,
save_config_kwargs={"config_filename": "lightning_config.yaml"},
parser_kwargs={"parser_mode": "omegaconf"},
)
Empty file.
189 changes: 189 additions & 0 deletions tests/unit/callbacks/test_save_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
import os
import tempfile
import unittest
from unittest.mock import MagicMock, patch

from lightning import LightningModule, Trainer
from lightning.pytorch.loggers import WandbLogger

from chebai.callbacks.save_config import CustomSaveConfigCallback


class DummyModule(LightningModule):
    """Trivial identity module standing in for a real model in callback tests."""

    def __init__(self) -> None:
        super().__init__()
        # No real layers are needed; the callback under test never touches the model.
        self.layer = None

    def forward(self, x):
        # Identity pass-through — the tests never run actual data through this.
        return x


class TestCustomSaveConfigCallback(unittest.TestCase):
    """Test CustomSaveConfigCallback functionality."""

    def setUp(self):
        """Create a temporary log directory, removed automatically on teardown."""
        import shutil

        self.temp_dir = tempfile.mkdtemp()
        # The original fixture leaked the directory after every run.
        self.addCleanup(shutil.rmtree, self.temp_dir, ignore_errors=True)

    # ------------------------------------------------------------------ #
    # Fixture helpers (shared by all tests)                              #
    # ------------------------------------------------------------------ #

    def _make_callback(self, config_filename: str) -> CustomSaveConfigCallback:
        """Build a callback instance with mocked parser/config."""
        return CustomSaveConfigCallback(
            parser=MagicMock(),
            config=MagicMock(),
            config_filename=config_filename,
            overwrite=True,
        )

    def _make_trainer(self, loggers: list) -> MagicMock:
        """Build a mock Trainer whose log_dir points at the temp directory."""
        trainer = MagicMock(spec=Trainer)
        trainer.log_dir = self.temp_dir
        trainer.loggers = loggers
        trainer.is_global_zero = True
        return trainer

    def _write_config(self, filename: str) -> str:
        """Write a dummy config file into the temp dir and return its path."""
        path = os.path.join(self.temp_dir, filename)
        with open(path, "w") as f:
            f.write("test: config\n")
        return path

    # ------------------------------------------------------------------ #
    # Tests                                                              #
    # ------------------------------------------------------------------ #

    def test_callback_uploads_config_with_wandb_logger(self):
        """Config is uploaded via wandb.save when a WandbLogger is present."""
        callback = self._make_callback("lightning_config.yaml")
        config_path = self._write_config("lightning_config.yaml")
        trainer = self._make_trainer([MagicMock(spec=WandbLogger)])

        with patch("wandb.save") as mock_wandb_save:
            callback.save_config(trainer, DummyModule(), "fit")

        # Verify wandb.save was called with the correct arguments
        mock_wandb_save.assert_called_once_with(
            config_path, base_path=self.temp_dir, policy="now"
        )

    def test_callback_skips_upload_without_wandb_logger(self):
        """Upload is skipped when no WandbLogger is attached to the trainer."""
        callback = self._make_callback("lightning_config.yaml")
        self._write_config("lightning_config.yaml")
        trainer = self._make_trainer([])  # no loggers at all

        with patch("wandb.save") as mock_wandb_save:
            callback.save_config(trainer, DummyModule(), "fit")

        mock_wandb_save.assert_not_called()

    def test_callback_handles_missing_config_file(self):
        """A missing config file is tolerated and nothing is uploaded."""
        callback = self._make_callback("nonexistent_config.yaml")
        trainer = self._make_trainer([MagicMock(spec=WandbLogger)])

        with patch("wandb.save") as mock_wandb_save:
            # Must not raise even though the file does not exist.
            callback.save_config(trainer, DummyModule(), "fit")

        mock_wandb_save.assert_not_called()

    def test_callback_handles_wandb_not_installed(self):
        """A missing wandb package (ImportError) is swallowed gracefully."""
        import builtins

        callback = self._make_callback("lightning_config.yaml")
        self._write_config("lightning_config.yaml")
        trainer = self._make_trainer([MagicMock(spec=WandbLogger)])

        # BUGFIX: capture the real import function BEFORE patching. The original
        # test called ``__import__`` from inside the side effect while
        # ``builtins.__import__`` was patched, which recurses infinitely for any
        # module name other than "wandb".
        real_import = builtins.__import__

        def import_side_effect(name, *args, **kwargs):
            if name == "wandb":
                raise ImportError("No module named 'wandb'")
            return real_import(name, *args, **kwargs)

        with patch("builtins.__import__", side_effect=import_side_effect):
            # Must not raise: the callback catches ImportError and continues.
            callback.save_config(trainer, DummyModule(), "fit")


if __name__ == "__main__":
    unittest.main()
2 changes: 2 additions & 0 deletions tests/unit/cli/testCLI.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import unittest

from chebai.callbacks.save_config import CustomSaveConfigCallback
from chebai.cli import ChebaiCLI


Expand All @@ -23,6 +24,7 @@ def test_mlp_on_chebai_cli(self):
try:
ChebaiCLI(
args=self.cli_args,
save_config_callback=CustomSaveConfigCallback,
save_config_kwargs={"config_filename": "lightning_config.yaml"},
parser_kwargs={"parser_mode": "omegaconf"},
)
Expand Down
Loading