From 1f955bcb3f72351bd521b1cef1705257b3c2eed9 Mon Sep 17 00:00:00 2001
From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com>
Date: Tue, 17 Feb 2026 20:42:43 +0000
Subject: [PATCH 1/5] Initial plan

From d44e7845284b427b55521e3879e5da756878d725 Mon Sep 17 00:00:00 2001
From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com>
Date: Tue, 17 Feb 2026 20:52:17 +0000
Subject: [PATCH 2/5] Add CustomSaveConfigCallback to upload
 lightning_config.yaml to W&B

Co-authored-by: aditya0by0 <65857172+aditya0by0@users.noreply.github.com>
---
 chebai/callbacks/save_config.py          |  67 +++++++++
 chebai/cli.py                            |   2 +
 tests/unit/callbacks/__init__.py         |   0
 tests/unit/callbacks/test_save_config.py | 182 +++++++++++++++++++++++
 tests/unit/cli/testCLI.py                |   2 +
 5 files changed, 253 insertions(+)
 create mode 100644 chebai/callbacks/save_config.py
 create mode 100644 tests/unit/callbacks/__init__.py
 create mode 100644 tests/unit/callbacks/test_save_config.py

diff --git a/chebai/callbacks/save_config.py b/chebai/callbacks/save_config.py
new file mode 100644
index 00000000..7318b548
--- /dev/null
+++ b/chebai/callbacks/save_config.py
@@ -0,0 +1,67 @@
+import os
+from typing import TYPE_CHECKING
+
+from lightning import LightningModule, Trainer
+from lightning.pytorch.cli import SaveConfigCallback
+from lightning.pytorch.loggers import WandbLogger
+
+if TYPE_CHECKING:
+    pass
+
+
+class CustomSaveConfigCallback(SaveConfigCallback):
+    """
+    Custom SaveConfigCallback that uploads the Lightning config file to W&B.
+
+    This callback extends the default SaveConfigCallback to automatically upload
+    the lightning_config.yaml file to Weights & Biases online run logs when using
+    WandbLogger. This ensures better traceability and reproducibility of experiments.
+
+    The config file is uploaded using wandb.save(), which makes it available in the
+    W&B web interface under the "Files" tab of the run.
+    """
+
+    def save_config(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
+        """
+        Save the config to W&B if a WandbLogger is being used.
+
+        This method is called after the config file has been saved to the log directory.
+        It checks if the trainer is using a WandbLogger and, if so, uploads the config
+        file to W&B using wandb.save().
+
+        Args:
+            trainer: The PyTorch Lightning Trainer instance.
+            pl_module: The LightningModule being trained.
+            stage: The current stage of training (e.g., 'fit', 'validate', 'test').
+        """
+        # Only proceed if we're saving to log_dir and have a valid trainer
+        if not self.save_to_log_dir or trainer.log_dir is None:
+            return
+
+        # Check if we're using WandbLogger
+        wandb_logger = None
+        for logger in trainer.loggers if hasattr(trainer, "loggers") else []:
+            if isinstance(logger, WandbLogger):
+                wandb_logger = logger
+                break
+
+        # If WandbLogger is found, upload the config file
+        if wandb_logger is not None:
+            config_path = os.path.join(trainer.log_dir, self.config_filename)
+            if os.path.exists(config_path):
+                try:
+                    import wandb
+
+                    # Upload the config file to W&B
+                    # This will make it available in the W&B web interface
+                    wandb.save(config_path, base_path=trainer.log_dir, policy="now")
+                except ImportError:
+                    # wandb is not installed, skip uploading
+                    pass
+                except Exception as e:
+                    # Log the error but don't fail the training run
+                    import logging
+
+                    logging.warning(
+                        f"Failed to upload {self.config_filename} to W&B: {e}"
+                    )
diff --git a/chebai/cli.py b/chebai/cli.py
index 1aaba53c..8ca5cd58 100644
--- a/chebai/cli.py
+++ b/chebai/cli.py
@@ -2,6 +2,7 @@
 
 from lightning.pytorch.cli import LightningArgumentParser, LightningCLI
 
+from chebai.callbacks.save_config import CustomSaveConfigCallback
 from chebai.preprocessing.datasets.base import XYBaseDataModule
 from chebai.trainer.CustomTrainer import CustomTrainer
 
@@ -121,6 +122,7 @@ def cli():
     Main function to instantiate and run the ChebaiCLI.
     """
     ChebaiCLI(
+        save_config_callback=CustomSaveConfigCallback,
         save_config_kwargs={"config_filename": "lightning_config.yaml"},
         parser_kwargs={"parser_mode": "omegaconf"},
     )
diff --git a/tests/unit/callbacks/__init__.py b/tests/unit/callbacks/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/unit/callbacks/test_save_config.py b/tests/unit/callbacks/test_save_config.py
new file mode 100644
index 00000000..4e3c3a0f
--- /dev/null
+++ b/tests/unit/callbacks/test_save_config.py
@@ -0,0 +1,182 @@
+import os
+import tempfile
+import unittest
+from unittest.mock import MagicMock, patch
+
+from lightning import LightningModule, Trainer
+from lightning.pytorch.loggers import WandbLogger
+
+from chebai.callbacks.save_config import CustomSaveConfigCallback
+
+
+class DummyModule(LightningModule):
+    """Dummy module for testing."""
+
+    def __init__(self):
+        super().__init__()
+        self.layer = None
+
+    def forward(self, x):
+        return x
+
+
+class TestCustomSaveConfigCallback(unittest.TestCase):
+    """Test CustomSaveConfigCallback functionality."""
+
+    def setUp(self):
+        """Set up test fixtures."""
+        self.temp_dir = tempfile.mkdtemp()
+
+    def test_callback_uploads_config_with_wandb_logger(self):
+        """Test that the callback uploads config when WandbLogger is present."""
+        # Create a mock parser and config
+        mock_parser = MagicMock()
+        mock_config = MagicMock()
+
+        # Create the callback
+        callback = CustomSaveConfigCallback(
+            parser=mock_parser,
+            config=mock_config,
+            config_filename="lightning_config.yaml",
+            overwrite=True,
+        )
+
+        # Create a config file in the temp directory
+        config_path = os.path.join(self.temp_dir, "lightning_config.yaml")
+        with open(config_path, "w") as f:
+            f.write("test: config\n")
+
+        # Create a mock WandbLogger
+        mock_wandb_logger = MagicMock(spec=WandbLogger)
+
+        # Create a mock trainer with the WandbLogger
+        mock_trainer = MagicMock(spec=Trainer)
+        mock_trainer.log_dir = self.temp_dir
+        mock_trainer.loggers = [mock_wandb_logger]
+        mock_trainer.is_global_zero = True
+
+        # Create a dummy module
+        pl_module = DummyModule()
+
+        # Mock wandb module
+        with patch("wandb.save") as mock_wandb_save:
+            # Call save_config
+            callback.save_config(mock_trainer, pl_module, "fit")
+
+            # Verify wandb.save was called with the correct arguments
+            mock_wandb_save.assert_called_once_with(
+                config_path, base_path=self.temp_dir, policy="now"
+            )
+
+    def test_callback_skips_upload_without_wandb_logger(self):
+        """Test that the callback skips upload when no WandbLogger is present."""
+        # Create a mock parser and config
+        mock_parser = MagicMock()
+        mock_config = MagicMock()
+
+        # Create the callback
+        callback = CustomSaveConfigCallback(
+            parser=mock_parser,
+            config=mock_config,
+            config_filename="lightning_config.yaml",
+            overwrite=True,
+        )
+
+        # Create a config file in the temp directory
+        config_path = os.path.join(self.temp_dir, "lightning_config.yaml")
+        with open(config_path, "w") as f:
+            f.write("test: config\n")
+
+        # Create a mock trainer WITHOUT WandbLogger
+        mock_trainer = MagicMock(spec=Trainer)
+        mock_trainer.log_dir = self.temp_dir
+        mock_trainer.loggers = []  # No loggers
+        mock_trainer.is_global_zero = True
+
+        # Create a dummy module
+        pl_module = DummyModule()
+
+        # Mock wandb module
+        with patch("wandb.save") as mock_wandb_save:
+            # Call save_config
+            callback.save_config(mock_trainer, pl_module, "fit")
+
+            # Verify wandb.save was NOT called
+            mock_wandb_save.assert_not_called()
+
+    def test_callback_handles_missing_config_file(self):
+        """Test that the callback handles missing config file gracefully."""
+        # Create a mock parser and config
+        mock_parser = MagicMock()
+        mock_config = MagicMock()
+
+        # Create the callback
+        callback = CustomSaveConfigCallback(
+            parser=mock_parser,
+            config=mock_config,
+            config_filename="nonexistent_config.yaml",
+            overwrite=True,
+        )
+
+        # Create a mock WandbLogger
+        mock_wandb_logger = MagicMock(spec=WandbLogger)
+
+        # Create a mock trainer with the WandbLogger
+        mock_trainer = MagicMock(spec=Trainer)
+        mock_trainer.log_dir = self.temp_dir
+        mock_trainer.loggers = [mock_wandb_logger]
+        mock_trainer.is_global_zero = True
+
+        # Create a dummy module
+        pl_module = DummyModule()
+
+        # Mock wandb module
+        with patch("wandb.save") as mock_wandb_save:
+            # Call save_config - should not raise an error
+            callback.save_config(mock_trainer, pl_module, "fit")
+
+            # Verify wandb.save was NOT called (because file doesn't exist)
+            mock_wandb_save.assert_not_called()
+
+    def test_callback_handles_wandb_not_installed(self):
+        """Test that the callback handles missing wandb package gracefully."""
+        # Create a mock parser and config
+        mock_parser = MagicMock()
+        mock_config = MagicMock()
+
+        # Create the callback
+        callback = CustomSaveConfigCallback(
+            parser=mock_parser,
+            config=mock_config,
+            config_filename="lightning_config.yaml",
+            overwrite=True,
+        )
+
+        # Create a config file in the temp directory
+        config_path = os.path.join(self.temp_dir, "lightning_config.yaml")
+        with open(config_path, "w") as f:
+            f.write("test: config\n")
+
+        # Create a mock WandbLogger
+        mock_wandb_logger = MagicMock(spec=WandbLogger)
+
+        # Create a mock trainer with the WandbLogger
+        mock_trainer = MagicMock(spec=Trainer)
+        mock_trainer.log_dir = self.temp_dir
+        mock_trainer.loggers = [mock_wandb_logger]
+        mock_trainer.is_global_zero = True
+
+        # Create a dummy module
+        pl_module = DummyModule()
+
+        # Mock wandb import to raise ImportError
+        with patch("builtins.__import__", side_effect=ImportError("No wandb")):
+            # Call save_config - should not raise an error
+            try:
+                callback.save_config(mock_trainer, pl_module, "fit")
+            except ImportError:
+                self.fail("save_config should handle ImportError gracefully")
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/unit/cli/testCLI.py b/tests/unit/cli/testCLI.py
index 863a6df3..4e221a9b 100644
--- a/tests/unit/cli/testCLI.py
+++ b/tests/unit/cli/testCLI.py
@@ -1,5 +1,6 @@
 import unittest
 
+from chebai.callbacks.save_config import CustomSaveConfigCallback
 from chebai.cli import ChebaiCLI
 
 
@@ -23,6 +24,7 @@ def test_mlp_on_chebai_cli(self):
         try:
             ChebaiCLI(
                 args=self.cli_args,
+                save_config_callback=CustomSaveConfigCallback,
                 save_config_kwargs={"config_filename": "lightning_config.yaml"},
                 parser_kwargs={"parser_mode": "omegaconf"},
             )

From 9cdca2e9a57bf698bf5b70353438b6f5198de241 Mon Sep 17 00:00:00 2001
From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com>
Date: Tue, 17 Feb 2026 20:54:17 +0000
Subject: [PATCH 3/5] Address code review feedback - improve test and add
 clarifying comment

Co-authored-by: aditya0by0 <65857172+aditya0by0@users.noreply.github.com>
---
 chebai/callbacks/save_config.py          | 9 ++++++++-
 tests/unit/callbacks/test_save_config.py | 8 +++++---
 2 files changed, 13 insertions(+), 4 deletions(-)

diff --git a/chebai/callbacks/save_config.py b/chebai/callbacks/save_config.py
index 7318b548..bc3e7861 100644
--- a/chebai/callbacks/save_config.py
+++ b/chebai/callbacks/save_config.py
@@ -21,7 +21,9 @@ class CustomSaveConfigCallback(SaveConfigCallback):
     W&B web interface under the "Files" tab of the run.
     """
 
-    def save_config(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
+    def save_config(
+        self, trainer: Trainer, pl_module: LightningModule, stage: str
+    ) -> None:
         """
         Save the config to W&B if a WandbLogger is being used.
 
@@ -29,6 +31,11 @@ def save_config(self, trainer: Trainer, pl_module: LightningModule, stage: str)
         It checks if the trainer is using a WandbLogger and, if so, uploads the config
         file to W&B using wandb.save().
 
+        Note:
+            We don't call super().save_config() because the parent class implementation
+            is empty. The actual config file saving to disk happens in the setup() method
+            before this method is called.
+
         Args:
             trainer: The PyTorch Lightning Trainer instance.
             pl_module: The LightningModule being trained.
diff --git a/tests/unit/callbacks/test_save_config.py b/tests/unit/callbacks/test_save_config.py
index 4e3c3a0f..b19c21f5 100644
--- a/tests/unit/callbacks/test_save_config.py
+++ b/tests/unit/callbacks/test_save_config.py
@@ -169,12 +169,14 @@ def test_callback_handles_wandb_not_installed(self):
         # Create a dummy module
         pl_module = DummyModule()
 
-        # Mock wandb import to raise ImportError
-        with patch("builtins.__import__", side_effect=ImportError("No wandb")):
+        # Mock wandb import to raise ImportError by patching sys.modules
+        import sys
+
+        with patch.dict(sys.modules, {"wandb": None}):
             # Call save_config - should not raise an error
             try:
                 callback.save_config(mock_trainer, pl_module, "fit")
-            except ImportError:
+            except (ImportError, AttributeError):
                 self.fail("save_config should handle ImportError gracefully")
 
 

From bcf457d64c0d6a82607cffe9373c1cb2558ebded Mon Sep 17 00:00:00 2001
From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com>
Date: Tue, 17 Feb 2026 20:56:08 +0000
Subject: [PATCH 4/5] Address additional code review feedback - improve code
 organization and clarity

Co-authored-by: aditya0by0 <65857172+aditya0by0@users.noreply.github.com>
---
 chebai/callbacks/save_config.py          | 49 ++++++++++++++----------
 tests/unit/callbacks/test_save_config.py |  8 +++-
 2 files changed, 35 insertions(+), 22 deletions(-)

diff --git a/chebai/callbacks/save_config.py b/chebai/callbacks/save_config.py
index bc3e7861..1c3fcb8a 100644
--- a/chebai/callbacks/save_config.py
+++ b/chebai/callbacks/save_config.py
@@ -1,3 +1,4 @@
+import logging
 import os
 from typing import TYPE_CHECKING
 
@@ -8,6 +9,8 @@
 if TYPE_CHECKING:
     pass
 
+logger = logging.getLogger(__name__)
+
 
 class CustomSaveConfigCallback(SaveConfigCallback):
     """
@@ -36,6 +39,10 @@ def save_config(
         is empty. The actual config file saving to disk happens in the setup() method
         before this method is called.
 
+        This method uses the following attributes from the parent SaveConfigCallback:
+        - self.save_to_log_dir: Whether to save config to the log directory
+        - self.config_filename: Name of the config file to upload
+
         Args:
             trainer: The PyTorch Lightning Trainer instance.
             pl_module: The LightningModule being trained.
@@ -52,23 +59,25 @@
                 wandb_logger = logger
                 break
 
-        # If WandbLogger is found, upload the config file
-        if wandb_logger is not None:
-            config_path = os.path.join(trainer.log_dir, self.config_filename)
-            if os.path.exists(config_path):
-                try:
-                    import wandb
-
-                    # Upload the config file to W&B
-                    # This will make it available in the W&B web interface
-                    wandb.save(config_path, base_path=trainer.log_dir, policy="now")
-                except ImportError:
-                    # wandb is not installed, skip uploading
-                    pass
-                except Exception as e:
-                    # Log the error but don't fail the training run
-                    import logging
-
-                    logging.warning(
-                        f"Failed to upload {self.config_filename} to W&B: {e}"
-                    )
+        # If WandbLogger is not found, skip uploading
+        if wandb_logger is None:
+            return
+
+        # Check if the config file exists
+        config_path = os.path.join(trainer.log_dir, self.config_filename)
+        if not os.path.exists(config_path):
+            return
+
+        # Upload the config file to W&B
+        try:
+            import wandb
+
+            # Upload the config file to W&B
+            # This will make it available in the W&B web interface
+            wandb.save(config_path, base_path=trainer.log_dir, policy="now")
+        except ImportError:
+            # wandb is not installed, skip uploading
+            pass
+        except Exception as e:
+            # Log the error but don't fail the training run
+            logger.warning(f"Failed to upload {self.config_filename} to W&B: {e}")
diff --git a/tests/unit/callbacks/test_save_config.py b/tests/unit/callbacks/test_save_config.py
index b19c21f5..44f5ceeb 100644
--- a/tests/unit/callbacks/test_save_config.py
+++ b/tests/unit/callbacks/test_save_config.py
@@ -174,10 +174,14 @@ def test_callback_handles_wandb_not_installed(self):
 
         with patch.dict(sys.modules, {"wandb": None}):
             # Call save_config - should not raise an error
+            # Note: This may raise AttributeError when trying to access wandb.save
+            # since wandb is None, which is expected behavior we want to handle gracefully
             try:
                 callback.save_config(mock_trainer, pl_module, "fit")
-            except (ImportError, AttributeError):
-                self.fail("save_config should handle ImportError gracefully")
+            except (ImportError, AttributeError) as e:
+                self.fail(
+                    f"save_config should handle missing wandb package gracefully, got: {e}"
+                )
 
 
 if __name__ == "__main__":

From 00c636644f1d9aeefacfd93bfdaefeaff12cb7b8 Mon Sep 17 00:00:00 2001
From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com>
Date: Tue, 17 Feb 2026 20:57:37 +0000
Subject: [PATCH 5/5] Remove unused TYPE_CHECKING and improve test for wandb
 import handling

Co-authored-by: aditya0by0 <65857172+aditya0by0@users.noreply.github.com>
---
 chebai/callbacks/save_config.py          |  4 ----
 tests/unit/callbacks/test_save_config.py | 23 ++++++++++++-----------
 2 files changed, 12 insertions(+), 15 deletions(-)

diff --git a/chebai/callbacks/save_config.py b/chebai/callbacks/save_config.py
index 1c3fcb8a..92dfadbd 100644
--- a/chebai/callbacks/save_config.py
+++ b/chebai/callbacks/save_config.py
@@ -1,14 +1,10 @@
 import logging
 import os
-from typing import TYPE_CHECKING
 
 from lightning import LightningModule, Trainer
 from lightning.pytorch.cli import SaveConfigCallback
 from lightning.pytorch.loggers import WandbLogger
 
-if TYPE_CHECKING:
-    pass
-
 logger = logging.getLogger(__name__)
 
 
diff --git a/tests/unit/callbacks/test_save_config.py b/tests/unit/callbacks/test_save_config.py
index 44f5ceeb..f53bbe0d 100644
--- a/tests/unit/callbacks/test_save_config.py
+++ b/tests/unit/callbacks/test_save_config.py
@@ -169,19 +169,20 @@ def test_callback_handles_wandb_not_installed(self):
         # Create a dummy module
         pl_module = DummyModule()
 
-        # Mock wandb import to raise ImportError by patching sys.modules
-        import sys
+        # Mock wandb import to raise ImportError
+        # This simulates wandb not being installed
+        with patch("builtins.__import__") as mock_import:
+
+            def import_side_effect(name, *args, **kwargs):
+                if name == "wandb":
+                    raise ImportError("No module named 'wandb'")
+                return __import__(name, *args, **kwargs)
+
+            mock_import.side_effect = import_side_effect
 
-        with patch.dict(sys.modules, {"wandb": None}):
             # Call save_config - should not raise an error
-            # Note: This may raise AttributeError when trying to access wandb.save
-            # since wandb is None, which is expected behavior we want to handle gracefully
-            try:
-                callback.save_config(mock_trainer, pl_module, "fit")
-            except (ImportError, AttributeError) as e:
-                self.fail(
-                    f"save_config should handle missing wandb package gracefully, got: {e}"
-                )
+            # The callback should catch the ImportError and continue gracefully
+            callback.save_config(mock_trainer, pl_module, "fit")
 
 
 if __name__ == "__main__":
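
For reference, a minimal usage sketch (not part of the patch series) of how the callback added here is wired into a LightningCLI entry point, mirroring the change to chebai/cli.py above. The bare LightningCLI call is an assumption for illustration; model, data, and logger (e.g. a WandbLogger, which triggers the upload) are expected to come from the YAML/CLI config.

# Illustrative sketch: pass CustomSaveConfigCallback to LightningCLI so that
# lightning_config.yaml is written to the run's log directory and, when a
# WandbLogger is attached, uploaded to the W&B run's "Files" tab.
from lightning.pytorch.cli import LightningCLI

from chebai.callbacks.save_config import CustomSaveConfigCallback


def main() -> None:
    # Model and data module classes are resolved from the config, as in chebai's CLI.
    LightningCLI(
        save_config_callback=CustomSaveConfigCallback,
        save_config_kwargs={"config_filename": "lightning_config.yaml"},
        parser_kwargs={"parser_mode": "omegaconf"},
    )


if __name__ == "__main__":
    main()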