Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ python -m chebai fit --trainer=configs/training/default_trainer.yml --model=conf
```
A command with additional options may look like this:
```
python3 -m chebai fit --trainer=configs/training/default_trainer.yml --model=configs/model/electra.yml --model.train_metrics=configs/metrics/micro-macro-f1.yml --model.test_metrics=configs/metrics/micro-macro-f1.yml --model.val_metrics=configs/metrics/micro-macro-f1.yml --model.pretrained_checkpoint=electra_pretrained.ckpt --model.load_prefix=generator. --data=configs/data/chebi/chebi50.yml --model.criterion=configs/loss/bce_weighted.yml --data.init_args.batch_size=10 --trainer.logger.init_args.name=chebi50_bce_unweighted --data.init_args.num_workers=9 --model.pass_loss_kwargs=false --data.init_args.chebi_version=231 --data.init_args.data_limit=1000
python3 -m chebai fit --trainer=configs/training/default_trainer.yml --model=configs/model/electra.yml --model.train_metrics=configs/metrics/micro-macro-f1.yml --model.test_metrics=configs/metrics/micro-macro-f1.yml --model.val_metrics=configs/metrics/micro-macro-f1.yml --model.pretrained_checkpoint=electra_pretrained.ckpt --model.load_prefix=generator. --data=configs/data/chebi/chebi50.yml --model.criterion=configs/loss/bce_weighted.yml --data.init_args.batch_size=10 --trainer.logger.init_args.name=chebi50_bce_weighted --data.init_args.num_workers=9 --model.pass_loss_kwargs=false --data.init_args.chebi_version=231 --data.init_args.data_limit=1000
```

### Fine-tuning for classification tasks, e.g. Toxicity prediction
Expand Down
4 changes: 2 additions & 2 deletions chebai/preprocessing/datasets/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,7 @@ def _process_input_for_prediction(

Args:
smiles_list (List[str]): List of SMILES strings.
model_hparams (Optional[dict]): Model hyperparameters.
model_hparams (dict): Model hyperparameters.
Some prediction pre-processing pipelines may require these.

Returns:
Expand All @@ -467,7 +467,7 @@ def _process_input_for_prediction(
return data, valid_indices

def _preprocess_smiles_for_pred(
self, idx, smiles: str, model_hparams: Optional[dict] = None
self, idx: int, smiles: str, model_hparams: Optional[dict] = None
) -> dict:
"""Preprocess prediction data."""
# Add dummy labels because the collate function requires them.
Expand Down
4 changes: 2 additions & 2 deletions chebai/preprocessing/migration/migrate_checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ def add_class_labels_to_checkpoint(input_path, classes_file_path):


if __name__ == "__main__":
if len(sys.argv) < 2:
print("Usage: python modify_checkpoints.py <input_checkpoint> <classes_file>")
if len(sys.argv) < 3:
print("Usage: python migrate_checkpoints.py <input_checkpoint> <classes_file>")
sys.exit(1)

input_ckpt = sys.argv[1]
Expand Down
7 changes: 3 additions & 4 deletions chebai/result/prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,8 @@ def __init__(
self._model.to(self.device)
print("*" * 10, f"Loaded model class: {self._model.__class__.__name__}")

try:
self._classification_labels: list = ckpt_file.get("classification_labels")
except KeyError:
self._classification_labels: list = ckpt_file.get("classification_labels")
if self._classification_labels is None:
raise KeyError(
"The checkpoint does not contain 'classification_labels'. "
"Make sure the checkpoint is compatible with python-chebai version 1.2.1 or later."
Expand Down Expand Up @@ -140,7 +139,7 @@ def predict_smiles(
Returns:
A tensor containing the predictions.
"""
# For certain data prediction piplines, we may need model hyperparameters
# For certain data prediction pipelines, we may need model hyperparameters
pred_dl, valid_indices = self._dm.predict_dataloader(
smiles_list=smiles, model_hparams=self._model_hparams
)
Expand Down