From 7dbd707097038f350c72c4ce98d051c2409acff1 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 17 Feb 2026 19:39:07 +0000 Subject: [PATCH 1/2] Initial plan From 969cf53c8b36b2db461224c34a0d77a0a33d48d3 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 17 Feb 2026 19:40:47 +0000 Subject: [PATCH 2/2] Address review comments from PR #148 Co-authored-by: aditya0by0 <65857172+aditya0by0@users.noreply.github.com> --- README.md | 2 +- chebai/preprocessing/datasets/base.py | 4 ++-- chebai/preprocessing/migration/migrate_checkpoints.py | 4 ++-- chebai/result/prediction.py | 7 +++---- 4 files changed, 8 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index e650ad91..401c7324 100644 --- a/README.md +++ b/README.md @@ -63,7 +63,7 @@ python -m chebai fit --trainer=configs/training/default_trainer.yml --model=conf ``` A command with additional options may look like this: ``` -python3 -m chebai fit --trainer=configs/training/default_trainer.yml --model=configs/model/electra.yml --model.train_metrics=configs/metrics/micro-macro-f1.yml --model.test_metrics=configs/metrics/micro-macro-f1.yml --model.val_metrics=configs/metrics/micro-macro-f1.yml --model.pretrained_checkpoint=electra_pretrained.ckpt --model.load_prefix=generator. --data=configs/data/chebi/chebi50.yml --model.criterion=configs/loss/bce_weighted.yml --data.init_args.batch_size=10 --trainer.logger.init_args.name=chebi50_bce_unweighted --data.init_args.num_workers=9 --model.pass_loss_kwargs=false --data.init_args.chebi_version=231 --data.init_args.data_limit=1000 +python3 -m chebai fit --trainer=configs/training/default_trainer.yml --model=configs/model/electra.yml --model.train_metrics=configs/metrics/micro-macro-f1.yml --model.test_metrics=configs/metrics/micro-macro-f1.yml --model.val_metrics=configs/metrics/micro-macro-f1.yml --model.pretrained_checkpoint=electra_pretrained.ckpt --model.load_prefix=generator. --data=configs/data/chebi/chebi50.yml --model.criterion=configs/loss/bce_weighted.yml --data.init_args.batch_size=10 --trainer.logger.init_args.name=chebi50_bce_weighted --data.init_args.num_workers=9 --model.pass_loss_kwargs=false --data.init_args.chebi_version=231 --data.init_args.data_limit=1000 ``` ### Fine-tuning for classification tasks, e.g. Toxicity prediction diff --git a/chebai/preprocessing/datasets/base.py b/chebai/preprocessing/datasets/base.py index f8d5653e..e295a3ed 100644 --- a/chebai/preprocessing/datasets/base.py +++ b/chebai/preprocessing/datasets/base.py @@ -445,7 +445,7 @@ def _process_input_for_prediction( Args: smiles_list (List[str]): List of SMILES strings. - model_hparams (Optional[dict]): Model hyperparameters. + model_hparams (dict): Model hyperparameters. Some prediction pre-processing pipelines may require these. Returns: @@ -467,7 +467,7 @@ def _process_input_for_prediction( return data, valid_indices def _preprocess_smiles_for_pred( - self, idx, smiles: str, model_hparams: Optional[dict] = None + self, idx: int, smiles: str, model_hparams: Optional[dict] = None ) -> dict: """Preprocess prediction data.""" # Add dummy labels because the collate function requires them. diff --git a/chebai/preprocessing/migration/migrate_checkpoints.py b/chebai/preprocessing/migration/migrate_checkpoints.py index dae1d9a6..00fbde38 100644 --- a/chebai/preprocessing/migration/migrate_checkpoints.py +++ b/chebai/preprocessing/migration/migrate_checkpoints.py @@ -41,8 +41,8 @@ def add_class_labels_to_checkpoint(input_path, classes_file_path): if __name__ == "__main__": - if len(sys.argv) < 2: - print("Usage: python modify_checkpoints.py ") + if len(sys.argv) < 3: + print("Usage: python migrate_checkpoints.py ") sys.exit(1) input_ckpt = sys.argv[1] diff --git a/chebai/result/prediction.py b/chebai/result/prediction.py index cad8e067..60548c5e 100644 --- a/chebai/result/prediction.py +++ b/chebai/result/prediction.py @@ -71,9 +71,8 @@ def __init__( self._model.to(self.device) print("*" * 10, f"Loaded model class: {self._model.__class__.__name__}") - try: - self._classification_labels: list = ckpt_file.get("classification_labels") - except KeyError: + self._classification_labels: list = ckpt_file.get("classification_labels") + if self._classification_labels is None: raise KeyError( "The checkpoint does not contain 'classification_labels'. " "Make sure the checkpoint is compatible with python-chebai version 1.2.1 or later." @@ -140,7 +139,7 @@ def predict_smiles( Returns: A tensor containing the predictions. """ - # For certain data prediction piplines, we may need model hyperparameters + # For certain data prediction pipelines, we may need model hyperparameters pred_dl, valid_indices = self._dm.predict_dataloader( smiles_list=smiles, model_hparams=self._model_hparams )