
[Feature]: Generalize Prediction pipeline for Lightning CLI models #148

Open
aditya0by0 wants to merge 18 commits into dev from feature/general_pred_pipeline

Conversation

@aditya0by0
Member

@aditya0by0 aditya0by0 commented Jan 30, 2026

Generalize prediction logic

Please merge the PRs below after this PR:

Related Discussion

Related bugs rectified in Lightning for the pipeline

Additional changes

@aditya0by0 aditya0by0 added the enhancement New feature or request label Jan 30, 2026
@aditya0by0 aditya0by0 requested a review from sfluegel05 February 3, 2026 10:12
@aditya0by0 aditya0by0 marked this pull request as ready for review February 3, 2026 10:12
@aditya0by0
Member Author

@sfluegel05,

Could you confirm our agreed approach for handling old checkpoint files that don't have classification labels stored in them?

  1. Update code to handle legacy checkpoints: This requires adding logic to prediction.py and the chebifier repo to ingest external class files.

    • Concerns: Adds boilerplate and permanent complexity to handle a temporary issue.
  2. Patch old checkpoints (Preferred): Use the one-time script below to inject labels into the existing files.

    • Benefits: Keeps the codebase clean and ensures all checkpoints follow a standardized schema.

I’m in favor of Option 2 to avoid carrying technical debt in the prediction logic. Does this match your understanding?

For Option 2, I'm willing to add this script to the repo along with a small README note for old checkpoints.

import sys

import torch


def add_class_labels_to_checkpoint(input_path, classes_file_path):
    with open(classes_file_path, "r") as f:
        class_labels = [line.strip() for line in f.readlines()]

    assert len(class_labels) > 0, "The classes file is empty."

    # 1. Load the checkpoint
    checkpoint = torch.load(
        input_path, map_location=torch.device("cpu"), weights_only=False
    )

    if "classification_labels" in checkpoint:
        print(
            "Warning: 'classification_labels' key already exists in the checkpoint and will be overwritten."
        )


    # 2. Add your custom key/value pair
    checkpoint["classification_labels"] = class_labels

    # 3. Save the modified checkpoint
    output_path = input_path.replace(".ckpt", "_modified.ckpt")
    torch.save(checkpoint, output_path)
    print(f"Successfully added classification_labels and saved to {output_path}")


if __name__ == "__main__":
    if len(sys.argv) != 3:
        print("Usage: python modify_checkpoints.py <input_checkpoint> <classes_file>")
        sys.exit(1)

    input_ckpt = sys.argv[1]
    classes_file = sys.argv[2]

    add_class_labels_to_checkpoint(
        input_path=input_ckpt, classes_file_path=classes_file
    )


Copilot AI left a comment


Pull request overview

This PR introduces a new generalized prediction pipeline intended to work with LightningCLI-saved models/checkpoints, including persisting classification label names into checkpoints for consistent prediction output formatting.

Changes:

  • Add checkpoint persistence of classification_labels (derived from a dataset classes.txt) and wire the dataset path into model init via LightningCLI argument linking.
  • Introduce a new SMILES prediction entrypoint (chebai/result/prediction.py) that reconstructs model/datamodule from checkpoint hyperparameters.
  • Refactor XYBaseDataModule.predict_dataloader to build a prediction dataloader from an in-memory SMILES list, plus update docs/tests and add VS Code workspace files.
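The label round-trip described in the first bullet can be sketched as follows. This is a minimal illustration assuming the checkpoint schema from this PR; the helper name and the example ChEBI IDs are hypothetical, and the plain dict stands in for what `torch.load(ckpt_path, map_location="cpu")` returns:

```python
# Sketch only: 'get_classification_labels' is a hypothetical helper; the dict
# below mimics a checkpoint produced after this PR.

def get_classification_labels(checkpoint: dict) -> list:
    """Return the label names stored in a checkpoint dict, or fail loudly."""
    labels = checkpoint.get("classification_labels")
    if not labels:
        raise KeyError(
            "Checkpoint has no 'classification_labels'; patch legacy "
            "checkpoints before running the prediction pipeline."
        )
    return labels

# Example: a fake checkpoint dict with two hypothetical ChEBI labels
ckpt = {"state_dict": {}, "classification_labels": ["CHEBI:25350", "CHEBI:35366"]}
print(get_classification_labels(ckpt))  # ['CHEBI:25350', 'CHEBI:35366']
```

Failing loudly on legacy checkpoints keeps the prediction entrypoint free of the Option 1 fallback logic discussed above.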

Reviewed changes

Copilot reviewed 11 out of 12 changed files in this pull request and generated 11 comments.

  • tests/unit/cli/testCLI.py: Adjusts CLI unit test model args (smaller hidden layer).
  • tests/unit/cli/mock_dm.py: Adds classes_txt_file_path for CLI linking in tests.
  • tests/unit/cli/classification_labels.txt: Adds sample classification labels used by CLI unit tests.
  • chebai/trainer/CustomTrainer.py: Removes prior bespoke prediction logic and overrides predict().
  • chebai/result/prediction.py: Adds new prediction script/class for SMILES/file inference from checkpoint.
  • chebai/preprocessing/datasets/base.py: Refactors prediction dataloader flow and adds classes_txt_file_path.
  • chebai/models/base.py: Adds label-file loading + saving classification_labels into checkpoints.
  • chebai/cli.py: Links data.classes_txt_file_path into model.init_args.classes_txt_file_path.
  • README.md: Updates prediction instructions to use the new prediction script.
  • .vscode/settings.json: Adds VS Code project settings (currently invalid JSON).
  • .vscode/extensions.json: Adds recommended VS Code extensions.
  • .gitignore: Stops ignoring the entire .vscode directory (only ignores launch.json).
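The label persistence added in chebai/models/base.py presumably relies on Lightning's on_save_checkpoint hook. A minimal sketch of the idea, with a plain class standing in for LightningModule (the class name and demo file are hypothetical; only the hook logic is illustrated):

```python
import os
import tempfile
from pathlib import Path

class LabelPersistingModel:
    """Plain class standing in for a LightningModule; only the hook is sketched."""

    def __init__(self, classes_txt_file_path=None):
        # In the PR, this path is linked from the datamodule via LightningCLI.
        self.classes_txt_file_path = classes_txt_file_path

    def on_save_checkpoint(self, checkpoint: dict) -> None:
        # Lightning passes the checkpoint dict to this hook just before saving;
        # here we inject the label names read from classes.txt.
        if self.classes_txt_file_path:
            lines = Path(self.classes_txt_file_path).read_text().splitlines()
            checkpoint["classification_labels"] = [ln.strip() for ln in lines if ln.strip()]

# Demo with a throwaway classes file (hypothetical ChEBI IDs)
with tempfile.TemporaryDirectory() as tmpdir:
    classes_path = os.path.join(tmpdir, "classes.txt")
    Path(classes_path).write_text("CHEBI:1\nCHEBI:2\n")
    checkpoint = {"state_dict": {}}
    LabelPersistingModel(classes_path).on_save_checkpoint(checkpoint)
    print(checkpoint["classification_labels"])  # ['CHEBI:1', 'CHEBI:2']
```

With this hook in place, every newly trained checkpoint carries its labels, which is what makes the one-time patch script for legacy checkpoints sufficient.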


aditya0by0 and others added 3 commits February 10, 2026 00:32
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
@aditya0by0 aditya0by0 marked this pull request as draft February 17, 2026 12:52

Copilot AI left a comment


Pull request overview

Copilot reviewed 12 out of 13 changed files in this pull request and generated 8 comments.



@aditya0by0
Member Author

@copilot open a new pull request to apply changes based on the comments in this thread


Copilot AI commented Feb 17, 2026

@aditya0by0 I've opened a new pull request, #152, to work on those changes. Once the pull request is ready, I'll request review from you.

Copilot AI added a commit that referenced this pull request Feb 17, 2026
Co-authored-by: aditya0by0 <65857172+aditya0by0@users.noreply.github.com>
Copilot AI and others added 2 commits February 17, 2026 20:44

* Initial plan

* Address review comments from PR #148

Co-authored-by: aditya0by0 <65857172+aditya0by0@users.noreply.github.com>


Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: aditya0by0 <65857172+aditya0by0@users.noreply.github.com>
@aditya0by0
Member Author

It is possible to add a test for the prediction pipeline for Electra.
The pipeline would train Electra for 1 epoch, save and load the model, and make a prediction.

We can limit the vocab size and number of labels for such a mock model pipeline.

@sfluegel05, do you think there is a need for such a test case, or is the existing test from this PR sufficient?
If yes, should we treat it in another issue in another PR?

import os
import tempfile

import torch

from chebai.models.electra import Electra

# Smallest viable config
model = Electra(
    model_type="classification",
    config={
        "vocab_size": 10,
        "max_position_embeddings": 1,
        "num_attention_heads": 1,
        "num_hidden_layers": 1,
        "type_vocab_size": 1,
        "hidden_size": 1,
        "intermediate_size": 1,
    },
    out_dim=10,
    input_dim=10,
)

# Count parameters
num_params = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {num_params}")

# Save checkpoint and measure size
with tempfile.TemporaryDirectory() as tmpdir:
    ckpt_path = os.path.join(tmpdir, "electra_small.ckpt")
    torch.save({"state_dict": model.state_dict()}, ckpt_path)
    size_bytes = os.path.getsize(ckpt_path)
    print(f"Checkpoint size: {size_bytes} bytes")
Output:
Input dimension for the model: 10 Output dimension for the model: 10
Total parameters: 1959
Checkpoint size: 18367 bytes (0.018 MB)
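The proposed end-to-end mock test could be structured roughly like this. The TinyModel stub, its "training" update, and the pickle round trip are all hypothetical stand-ins (the real test would use the minimal Electra config above with trainer.save_checkpoint / load_from_checkpoint); only the train, save, load, predict skeleton is illustrated:

```python
import os
import pickle
import tempfile

class TinyModel:
    """Stub standing in for a 1-layer Electra; 'weights' is a single bias list."""

    def __init__(self, out_dim: int):
        self.out_dim = out_dim
        self.weights = [0.0] * out_dim

    def train_one_epoch(self, data):
        # Placeholder update; the real test would run one Trainer epoch.
        self.weights = [w + 0.1 for w in self.weights]

    def predict(self, smiles: str):
        # The real test would tokenize the SMILES string and run a forward pass.
        return list(self.weights)

with tempfile.TemporaryDirectory() as tmpdir:
    model = TinyModel(out_dim=3)
    model.train_one_epoch(data=["C", "CC"])          # 1-epoch mock training
    ckpt_path = os.path.join(tmpdir, "mock.ckpt")
    with open(ckpt_path, "wb") as f:
        pickle.dump(model, f)                        # real test: trainer.save_checkpoint
    with open(ckpt_path, "rb") as f:
        restored = pickle.load(f)                    # real test: load_from_checkpoint
    preds = restored.predict("CCO")
    assert len(preds) == restored.out_dim
    print(preds)  # [0.1, 0.1, 0.1]
```

Given the ~18 KB checkpoint measured above, such a fixture would be cheap enough to commit or regenerate per test run.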

@aditya0by0 aditya0by0 marked this pull request as ready for review February 17, 2026 20:33

Labels

enhancement New feature or request priority: high
