[Feature]: Generalize Prediction pipeline for Lightning CLI models #148
aditya0by0 wants to merge 18 commits into dev from
Conversation
Could you confirm our agreed approach for handling
I'm in favor of Option 2 to avoid carrying technical debt in the prediction logic. Does this match your understanding? I'm willing to add this script to the repo, along with a small README note about old checkpoints for Option 2.

```python
import sys

import torch


def add_class_labels_to_checkpoint(input_path, classes_file_path):
    with open(classes_file_path, "r") as f:
        class_labels = [line.strip() for line in f.readlines()]
    assert len(class_labels) > 0, "The classes file is empty."

    # 1. Load the checkpoint
    checkpoint = torch.load(
        input_path, map_location=torch.device("cpu"), weights_only=False
    )
    if "classification_labels" in checkpoint:
        print(
            "Warning: 'classification_labels' key already exists in the checkpoint "
            "and will be overwritten."
        )

    # 2. Add the custom key/value pair
    checkpoint["classification_labels"] = class_labels

    # 3. Save the modified checkpoint
    output_path = input_path.replace(".ckpt", "_modified.ckpt")
    torch.save(checkpoint, output_path)
    print(f"Successfully added classification_labels and saved to {output_path}")


if __name__ == "__main__":
    if len(sys.argv) < 3:  # both the checkpoint and the classes file are required
        print("Usage: python modify_checkpoints.py <input_checkpoint> <classes_file>")
        sys.exit(1)
    input_ckpt = sys.argv[1]
    classes_file = sys.argv[2]
    add_class_labels_to_checkpoint(
        input_path=input_ckpt, classes_file_path=classes_file
    )
```
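On the consumer side, a minimal sketch of what the prediction code could do with the stored key (all names and values here are illustrative, not the project's actual API; the dict stands in for the result of `torch.load`):

```python
# Hypothetical sketch: once "classification_labels" is stored in the checkpoint,
# a prediction script can pair each per-class score with its label name.
checkpoint = {
    "classification_labels": ["CHEBI:33429", "CHEBI:24867", "CHEBI:50906"],
    # ... "state_dict", "hyper_parameters", etc.
}

scores = [0.91, 0.07, 0.64]  # e.g. one sigmoid score per class from the model

labels = checkpoint["classification_labels"]
assert len(labels) == len(scores), "label count must match model output dim"

# Map each label name to its score for human-readable prediction output
named = dict(zip(labels, scores))
print(named)  # {'CHEBI:33429': 0.91, 'CHEBI:24867': 0.07, 'CHEBI:50906': 0.64}
```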
Pull request overview
This PR introduces a new generalized prediction pipeline intended to work with LightningCLI-saved models/checkpoints, including persisting classification label names into checkpoints for consistent prediction output formatting.
Changes:
- Add checkpoint persistence of `classification_labels` (derived from a dataset `classes.txt`) and wire the dataset path into model init via LightningCLI argument linking.
- Introduce a new SMILES prediction entrypoint (`chebai/result/prediction.py`) that reconstructs model/datamodule from checkpoint hyperparameters.
- Refactor `XYBaseDataModule.predict_dataloader` to build a prediction dataloader from an in-memory SMILES list, plus update docs/tests and add VS Code workspace files.
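Building a dataloader from an in-memory list is straightforward with plain torch; a minimal sketch (the SMILES strings are examples, and this bypasses chebai's tokenization, which the real `predict_dataloader` would apply):

```python
from torch.utils.data import DataLoader

# An in-memory SMILES list can serve directly as a map-style dataset;
# the default collate function batches strings into lists of strings.
smiles = ["CCO", "c1ccccc1", "CC(=O)O"]
loader = DataLoader(smiles, batch_size=2)

batches = [list(batch) for batch in loader]
print(batches)  # [['CCO', 'c1ccccc1'], ['CC(=O)O']]
```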
Reviewed changes
Copilot reviewed 11 out of 12 changed files in this pull request and generated 11 comments.
| File | Description |
|---|---|
| tests/unit/cli/testCLI.py | Adjusts CLI unit test model args (smaller hidden layer). |
| tests/unit/cli/mock_dm.py | Adds classes_txt_file_path for CLI linking in tests. |
| tests/unit/cli/classification_labels.txt | Adds sample classification labels used by CLI unit tests. |
| chebai/trainer/CustomTrainer.py | Removes prior bespoke prediction logic and overrides predict(). |
| chebai/result/prediction.py | Adds new prediction script/class for SMILES/file inference from checkpoint. |
| chebai/preprocessing/datasets/base.py | Refactors prediction dataloader flow and adds classes_txt_file_path. |
| chebai/models/base.py | Adds label-file loading + saving classification_labels into checkpoints. |
| chebai/cli.py | Links data.classes_txt_file_path into model.init_args.classes_txt_file_path. |
| README.md | Updates prediction instructions to use the new prediction script. |
| .vscode/settings.json | Adds VS Code project settings (currently invalid JSON). |
| .vscode/extensions.json | Adds recommended VS Code extensions. |
| .gitignore | Stops ignoring the entire .vscode directory (only ignores launch.json). |
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Pull request overview
Copilot reviewed 12 out of 13 changed files in this pull request and generated 8 comments.
@copilot open a new pull request to apply changes based on the comments in this thread
@aditya0by0 I've opened a new pull request, #152, to work on those changes. Once the pull request is ready, I'll request review from you.
It is possible to add a test for the prediction pipeline for Electra; we can limit the vocab size and labels for such a mock model pipeline. @sfluegel05, do you think there is a need for such a test case, or is the existing test from this PR sufficient?

```python
import os
import tempfile

import torch

from chebai.models.electra import Electra

# Smallest viable config
model = Electra(
    model_type="classification",
    config={
        "vocab_size": 10,
        "max_position_embeddings": 1,
        "num_attention_heads": 1,
        "num_hidden_layers": 1,
        "type_vocab_size": 1,
        "hidden_size": 1,
        "intermediate_size": 1,
    },
    out_dim=10,
    input_dim=10,
)

# Count parameters
num_params = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {num_params}")

# Save checkpoint and measure size
with tempfile.TemporaryDirectory() as tmpdir:
    ckpt_path = os.path.join(tmpdir, "electra_small.ckpt")
    torch.save({"state_dict": model.state_dict()}, ckpt_path)
    size_bytes = os.path.getsize(ckpt_path)
    print(f"Checkpoint size: {size_bytes} bytes")
```

Output:

```
Input dimension for the model: 10
Output dimension for the model: 10
Total parameters: 1959
Checkpoint size: 18367 bytes (0.018 MB)
```
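The measurement pattern itself does not depend on chebai; a sketch with a plain torch module (the `Linear` stand-in is an assumption for illustration, not the project's model, and its parameter count is easy to verify by hand):

```python
import os
import tempfile

import torch

# Tiny stand-in model: 10*10 weights + 10 biases = 110 parameters.
model = torch.nn.Linear(10, 10)

num_params = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {num_params}")  # 110

# Save a checkpoint to a temp dir and measure its on-disk size.
with tempfile.TemporaryDirectory() as tmpdir:
    ckpt_path = os.path.join(tmpdir, "linear_small.ckpt")
    torch.save({"state_dict": model.state_dict()}, ckpt_path)
    print(f"Checkpoint size: {os.path.getsize(ckpt_path)} bytes")
```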
Generalize prediction logic
Please merge the PRs below after this PR:
Related Discussion
Related bugs rectified in Lightning for the pipeline
- `LightningDataModule.load_from_checkpoint` does not restore the subclass from `datamodule_hyper_parameters` (Lightning-AI/pytorch-lightning#21477)
- `save_hyperparameters(ignore=...)` is not persistent across inheritance; ignored params reappear when the base class also calls `save_hyperparameters` (Lightning-AI/pytorch-lightning#21488)

Additional changes
- Save class labels in the checkpoint under the key `"classification_labels"`
- Wrap inference with `torch.inference_mode()` to avoid gradient tracking (see Avoid gradient tracking, python-chebifier#21; related reading: `model.eval()` in PyTorch, `torch.no_grad()` and `torch.inference_mode()`)
- Use `torch.compile` for faster inference
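A minimal sketch of the inference wrapping (the `Linear` model is illustrative, not the project's; only the `eval()` + `inference_mode()` pattern is the point):

```python
import torch

model = torch.nn.Linear(4, 2)
model.eval()  # switch off training-only behaviour (dropout, batch-norm updates)

x = torch.randn(1, 4)
# inference_mode() is like no_grad() but also skips autograd version tracking,
# so it is the preferred wrapper for pure prediction code.
with torch.inference_mode():
    out = model(x)

print(out.requires_grad)  # tensors created under inference_mode never require grad
```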