diff --git a/chebai/callbacks/save_config.py b/chebai/callbacks/save_config.py
new file mode 100644
index 00000000..92dfadbd
--- /dev/null
+++ b/chebai/callbacks/save_config.py
@@ -0,0 +1,79 @@
+import logging
+import os
+
+from lightning import LightningModule, Trainer
+from lightning.pytorch.cli import SaveConfigCallback
+from lightning.pytorch.loggers import WandbLogger
+
+logger = logging.getLogger(__name__)
+
+
+class CustomSaveConfigCallback(SaveConfigCallback):
+    """
+    Custom SaveConfigCallback that uploads the Lightning config file to W&B.
+
+    This callback extends the default SaveConfigCallback to automatically upload
+    the lightning_config.yaml file to the Weights & Biases run when a WandbLogger
+    is in use, improving the traceability and reproducibility of experiments.
+
+    The config file is uploaded using wandb.save(), which makes it available in
+    the W&B web interface under the "Files" tab of the run.
+    """
+
+    def save_config(
+        self, trainer: Trainer, pl_module: LightningModule, stage: str
+    ) -> None:
+        """
+        Upload the config file to W&B if a WandbLogger is being used.
+
+        This method is called after the config file has been saved to the log
+        directory. It checks whether the trainer uses a WandbLogger and, if so,
+        uploads the config file to W&B via wandb.save().
+
+        Note:
+            super().save_config() is not called because the parent implementation
+            is empty; the config file is written to disk in the parent's setup()
+            method before this method runs.
+
+        This method uses the following attributes of the parent SaveConfigCallback:
+        - self.save_to_log_dir: Whether the config is saved to the log directory
+        - self.config_filename: Name of the config file to upload
+
+        Args:
+            trainer: The PyTorch Lightning Trainer instance.
+            pl_module: The LightningModule being trained.
+            stage: The current stage of training (e.g., 'fit', 'validate', 'test').
+        """
+        # Only proceed if we're saving to log_dir and have a valid log directory
+        if not self.save_to_log_dir or trainer.log_dir is None:
+            return
+
+        # Find a WandbLogger; the loop variable must not shadow the module logger
+        wandb_logger = None
+        for trainer_logger in getattr(trainer, "loggers", []):
+            if isinstance(trainer_logger, WandbLogger):
+                wandb_logger = trainer_logger
+                break
+
+        # If no WandbLogger is found, skip uploading
+        if wandb_logger is None:
+            return
+
+        # Check that the config file exists
+        config_path = os.path.join(trainer.log_dir, self.config_filename)
+        if not os.path.exists(config_path):
+            return
+
+        # Upload the config file to W&B
+        try:
+            import wandb
+
+            # policy="now" uploads immediately instead of waiting for run end;
+            # the file then appears under the run's "Files" tab
+            wandb.save(config_path, base_path=trainer.log_dir, policy="now")
+        except ImportError:
+            # wandb is not installed, skip uploading
+            pass
+        except Exception as e:
+            # Log the error but don't fail the training run
+            logger.warning(f"Failed to upload {self.config_filename} to W&B: {e}")
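
Once uploaded with policy="now", the config appears under the run's "Files" tab and can be pulled back later through the W&B public API. A minimal retrieval sketch, assuming a placeholder "entity/project/run_id" run path:

    import wandb

    api = wandb.Api()
    run = api.run("entity/project/run_id")  # placeholder run path
    # Download the config that CustomSaveConfigCallback uploaded
    run.file("lightning_config.yaml").download(replace=True)
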
""" ChebaiCLI( + save_config_callback=CustomSaveConfigCallback, save_config_kwargs={"config_filename": "lightning_config.yaml"}, parser_kwargs={"parser_mode": "omegaconf"}, ) diff --git a/tests/unit/callbacks/__init__.py b/tests/unit/callbacks/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit/callbacks/test_save_config.py b/tests/unit/callbacks/test_save_config.py new file mode 100644 index 00000000..f53bbe0d --- /dev/null +++ b/tests/unit/callbacks/test_save_config.py @@ -0,0 +1,189 @@ +import os +import tempfile +import unittest +from unittest.mock import MagicMock, patch + +from lightning import LightningModule, Trainer +from lightning.pytorch.loggers import WandbLogger + +from chebai.callbacks.save_config import CustomSaveConfigCallback + + +class DummyModule(LightningModule): + """Dummy module for testing.""" + + def __init__(self): + super().__init__() + self.layer = None + + def forward(self, x): + return x + + +class TestCustomSaveConfigCallback(unittest.TestCase): + """Test CustomSaveConfigCallback functionality.""" + + def setUp(self): + """Set up test fixtures.""" + self.temp_dir = tempfile.mkdtemp() + + def test_callback_uploads_config_with_wandb_logger(self): + """Test that the callback uploads config when WandbLogger is present.""" + # Create a mock parser and config + mock_parser = MagicMock() + mock_config = MagicMock() + + # Create the callback + callback = CustomSaveConfigCallback( + parser=mock_parser, + config=mock_config, + config_filename="lightning_config.yaml", + overwrite=True, + ) + + # Create a config file in the temp directory + config_path = os.path.join(self.temp_dir, "lightning_config.yaml") + with open(config_path, "w") as f: + f.write("test: config\n") + + # Create a mock WandbLogger + mock_wandb_logger = MagicMock(spec=WandbLogger) + + # Create a mock trainer with the WandbLogger + mock_trainer = MagicMock(spec=Trainer) + mock_trainer.log_dir = self.temp_dir + mock_trainer.loggers = [mock_wandb_logger] + mock_trainer.is_global_zero = True + + # Create a dummy module + pl_module = DummyModule() + + # Mock wandb module + with patch("wandb.save") as mock_wandb_save: + # Call save_config + callback.save_config(mock_trainer, pl_module, "fit") + + # Verify wandb.save was called with the correct arguments + mock_wandb_save.assert_called_once_with( + config_path, base_path=self.temp_dir, policy="now" + ) + + def test_callback_skips_upload_without_wandb_logger(self): + """Test that the callback skips upload when no WandbLogger is present.""" + # Create a mock parser and config + mock_parser = MagicMock() + mock_config = MagicMock() + + # Create the callback + callback = CustomSaveConfigCallback( + parser=mock_parser, + config=mock_config, + config_filename="lightning_config.yaml", + overwrite=True, + ) + + # Create a config file in the temp directory + config_path = os.path.join(self.temp_dir, "lightning_config.yaml") + with open(config_path, "w") as f: + f.write("test: config\n") + + # Create a mock trainer WITHOUT WandbLogger + mock_trainer = MagicMock(spec=Trainer) + mock_trainer.log_dir = self.temp_dir + mock_trainer.loggers = [] # No loggers + mock_trainer.is_global_zero = True + + # Create a dummy module + pl_module = DummyModule() + + # Mock wandb module + with patch("wandb.save") as mock_wandb_save: + # Call save_config + callback.save_config(mock_trainer, pl_module, "fit") + + # Verify wandb.save was NOT called + mock_wandb_save.assert_not_called() + + def test_callback_handles_missing_config_file(self): + """Test that the 
diff --git a/tests/unit/callbacks/__init__.py b/tests/unit/callbacks/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/unit/callbacks/test_save_config.py b/tests/unit/callbacks/test_save_config.py
new file mode 100644
index 00000000..f53bbe0d
--- /dev/null
+++ b/tests/unit/callbacks/test_save_config.py
@@ -0,0 +1,193 @@
+import os
+import tempfile
+import unittest
+from unittest.mock import MagicMock, patch
+
+from lightning import LightningModule, Trainer
+from lightning.pytorch.loggers import WandbLogger
+
+from chebai.callbacks.save_config import CustomSaveConfigCallback
+
+
+class DummyModule(LightningModule):
+    """Dummy module for testing."""
+
+    def __init__(self):
+        super().__init__()
+        self.layer = None
+
+    def forward(self, x):
+        return x
+
+
+class TestCustomSaveConfigCallback(unittest.TestCase):
+    """Test CustomSaveConfigCallback functionality."""
+
+    def setUp(self):
+        """Set up test fixtures."""
+        self.temp_dir = tempfile.mkdtemp()
+
+    def test_callback_uploads_config_with_wandb_logger(self):
+        """Test that the callback uploads the config when a WandbLogger is present."""
+        # Create a mock parser and config
+        mock_parser = MagicMock()
+        mock_config = MagicMock()
+
+        # Create the callback
+        callback = CustomSaveConfigCallback(
+            parser=mock_parser,
+            config=mock_config,
+            config_filename="lightning_config.yaml",
+            overwrite=True,
+        )
+
+        # Create a config file in the temp directory
+        config_path = os.path.join(self.temp_dir, "lightning_config.yaml")
+        with open(config_path, "w") as f:
+            f.write("test: config\n")
+
+        # Create a mock WandbLogger
+        mock_wandb_logger = MagicMock(spec=WandbLogger)
+
+        # Create a mock trainer with the WandbLogger
+        mock_trainer = MagicMock(spec=Trainer)
+        mock_trainer.log_dir = self.temp_dir
+        mock_trainer.loggers = [mock_wandb_logger]
+        mock_trainer.is_global_zero = True
+
+        # Create a dummy module
+        pl_module = DummyModule()
+
+        # Mock wandb.save
+        with patch("wandb.save") as mock_wandb_save:
+            # Call save_config
+            callback.save_config(mock_trainer, pl_module, "fit")
+
+            # Verify wandb.save was called with the correct arguments
+            mock_wandb_save.assert_called_once_with(
+                config_path, base_path=self.temp_dir, policy="now"
+            )
+
+    def test_callback_skips_upload_without_wandb_logger(self):
+        """Test that the callback skips the upload when no WandbLogger is present."""
+        # Create a mock parser and config
+        mock_parser = MagicMock()
+        mock_config = MagicMock()
+
+        # Create the callback
+        callback = CustomSaveConfigCallback(
+            parser=mock_parser,
+            config=mock_config,
+            config_filename="lightning_config.yaml",
+            overwrite=True,
+        )
+
+        # Create a config file in the temp directory
+        config_path = os.path.join(self.temp_dir, "lightning_config.yaml")
+        with open(config_path, "w") as f:
+            f.write("test: config\n")
+
+        # Create a mock trainer WITHOUT a WandbLogger
+        mock_trainer = MagicMock(spec=Trainer)
+        mock_trainer.log_dir = self.temp_dir
+        mock_trainer.loggers = []  # No loggers
+        mock_trainer.is_global_zero = True
+
+        # Create a dummy module
+        pl_module = DummyModule()
+
+        # Mock wandb.save
+        with patch("wandb.save") as mock_wandb_save:
+            # Call save_config
+            callback.save_config(mock_trainer, pl_module, "fit")
+
+            # Verify wandb.save was NOT called
+            mock_wandb_save.assert_not_called()
+
+    def test_callback_handles_missing_config_file(self):
+        """Test that the callback handles a missing config file gracefully."""
+        # Create a mock parser and config
+        mock_parser = MagicMock()
+        mock_config = MagicMock()
+
+        # Create the callback
+        callback = CustomSaveConfigCallback(
+            parser=mock_parser,
+            config=mock_config,
+            config_filename="nonexistent_config.yaml",
+            overwrite=True,
+        )
+
+        # Create a mock WandbLogger
+        mock_wandb_logger = MagicMock(spec=WandbLogger)
+
+        # Create a mock trainer with the WandbLogger
+        mock_trainer = MagicMock(spec=Trainer)
+        mock_trainer.log_dir = self.temp_dir
+        mock_trainer.loggers = [mock_wandb_logger]
+        mock_trainer.is_global_zero = True
+
+        # Create a dummy module
+        pl_module = DummyModule()
+
+        # Mock wandb.save
+        with patch("wandb.save") as mock_wandb_save:
+            # Call save_config - should not raise an error
+            callback.save_config(mock_trainer, pl_module, "fit")
+
+            # Verify wandb.save was NOT called (because the file doesn't exist)
+            mock_wandb_save.assert_not_called()
+
+    def test_callback_handles_wandb_not_installed(self):
+        """Test that the callback handles a missing wandb package gracefully."""
+        # Create a mock parser and config
+        mock_parser = MagicMock()
+        mock_config = MagicMock()
+
+        # Create the callback
+        callback = CustomSaveConfigCallback(
+            parser=mock_parser,
+            config=mock_config,
+            config_filename="lightning_config.yaml",
+            overwrite=True,
+        )
+
+        # Create a config file in the temp directory
+        config_path = os.path.join(self.temp_dir, "lightning_config.yaml")
+        with open(config_path, "w") as f:
+            f.write("test: config\n")
+
+        # Create a mock WandbLogger
+        mock_wandb_logger = MagicMock(spec=WandbLogger)
+
+        # Create a mock trainer with the WandbLogger
+        mock_trainer = MagicMock(spec=Trainer)
+        mock_trainer.log_dir = self.temp_dir
+        mock_trainer.loggers = [mock_wandb_logger]
+        mock_trainer.is_global_zero = True
+
+        # Create a dummy module
+        pl_module = DummyModule()
+
+        # Keep a reference to the real __import__ so the side effect below
+        # can delegate to it without recursing into the patched builtin
+        real_import = __import__
+
+        # Mock wandb import to raise ImportError
+        # This simulates wandb not being installed
+        with patch("builtins.__import__") as mock_import:
+
+            def import_side_effect(name, *args, **kwargs):
+                if name == "wandb":
+                    raise ImportError("No module named 'wandb'")
+                return real_import(name, *args, **kwargs)
+
+            mock_import.side_effect = import_side_effect
+
+            # Call save_config - should not raise an error
+            # The callback should catch the ImportError and continue gracefully
+            callback.save_config(mock_trainer, pl_module, "fit")
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/unit/cli/testCLI.py b/tests/unit/cli/testCLI.py
index 863a6df3..4e221a9b 100644
--- a/tests/unit/cli/testCLI.py
+++ b/tests/unit/cli/testCLI.py
@@ -1,5 +1,6 @@
 import unittest
 
+from chebai.callbacks.save_config import CustomSaveConfigCallback
 from chebai.cli import ChebaiCLI
 
 
@@ -23,6 +24,7 @@ def test_mlp_on_chebai_cli(self):
         try:
             ChebaiCLI(
                 args=self.cli_args,
+                save_config_callback=CustomSaveConfigCallback,
                 save_config_kwargs={"config_filename": "lightning_config.yaml"},
                 parser_kwargs={"parser_mode": "omegaconf"},
             )
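
The tests mock the trainer, the loggers, and wandb.save, so no live W&B session is needed; with the repository root on PYTHONPATH they should run via, e.g., python -m unittest tests.unit.callbacks.test_save_config -v (the wandb.save patches still require the wandb package to be importable).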