138 changes: 101 additions & 37 deletions chebai/cli.py
@@ -31,34 +31,112 @@ def __init__(self, *args, **kwargs):
"""
super().__init__(trainer_class=CustomTrainer, *args, **kwargs)

def before_instantiate_classes(self) -> None:
"""
Hook called before classes are instantiated (Lightning 2.6+ compatible).
Instantiates the datamodule early to compute num_of_labels and feature_vector_size.
"""
# Get the current subcommand config (fit, test, validate, predict, etc.)
subcommand = self.config.get(self.config["subcommand"])
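# self.config here is the parsed (not yet instantiated) jsonargparse
# namespace: self.config["subcommand"] is e.g. "fit", and self.config["fit"]
# holds the data/model/trainer sub-configs for that run.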

if not (subcommand and "data" in subcommand):
return

data_config = subcommand["data"]
if "class_path" not in data_config:
return

# Import and instantiate the datamodule class
import importlib

module_path, class_name = data_config["class_path"].rsplit(".", 1)
module = importlib.import_module(module_path)
data_class = getattr(module, class_name)

# Instantiate with init_args
init_args = data_config.get("init_args", {})
data_instance = data_class(**init_args)

# Call prepare_data and setup to initialize dynamic properties.
# Check the private attribute here: the num_of_labels property asserts that
# the value is already set, so it cannot be used for this test.
if hasattr(data_instance, "_num_of_labels") and data_instance._num_of_labels is None:
data_instance.prepare_data()
data_instance.setup()

num_labels = data_instance.num_of_labels
feature_vector_size = data_instance.feature_vector_size

# Update model init args
self._update_model_args(subcommand, num_labels, feature_vector_size)

# Update trainer callbacks
self._update_trainer_callbacks(subcommand, num_labels)

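# NOTE: the helpers below only fill values that are still None, so anything
# set explicitly in a config file or on the command line is left untouched.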
def _update_model_args(self, subcommand: dict, num_labels: int, feature_vector_size: int) -> None:
"""Helper method to update model initialization arguments."""
if "model" not in subcommand or "init_args" not in subcommand["model"]:
return

model_init_args = subcommand["model"]["init_args"]

# Set out_dim and input_dim if not already set
if model_init_args.get("out_dim") is None:
model_init_args["out_dim"] = num_labels
if model_init_args.get("input_dim") is None:
model_init_args["input_dim"] = feature_vector_size

# Update metrics num_labels in all metrics configurations
for kind in ("train", "val", "test"):
metrics_key = f"{kind}_metrics"
metrics_config = model_init_args.get(metrics_key)
if metrics_config:
self._update_metrics_num_labels(metrics_config, num_labels)

def _update_metrics_num_labels(self, metrics_config: dict, num_labels: int) -> None:
"""Helper method to update num_labels in metrics configuration."""
init_args = metrics_config.get("init_args", {})
metrics_dict = init_args.get("metrics", {})

for metric_config in metrics_dict.values():
metric_init_args = metric_config.get("init_args", {})
if "num_labels" in metric_init_args and metric_init_args["num_labels"] is None:
metric_init_args["num_labels"] = num_labels

def _update_trainer_callbacks(self, subcommand: dict, num_labels: int) -> None:
"""Helper method to update num_labels in trainer callbacks."""
if "trainer" not in subcommand or "callbacks" not in subcommand["trainer"]:
return

callbacks = subcommand["trainer"]["callbacks"]

if isinstance(callbacks, list):
for callback in callbacks:
self._set_callback_num_labels(callback, num_labels)
else:
self._set_callback_num_labels(callbacks, num_labels)

def _set_callback_num_labels(self, callback: dict, num_labels: int) -> None:
"""Helper method to set num_labels in a single callback configuration."""
init_args = callback.get("init_args", {})
if "num_labels" in init_args and init_args["num_labels"] is None:
init_args["num_labels"] = num_labels

def add_arguments_to_parser(self, parser: LightningArgumentParser):
"""
Link input parameters that are used by different classes (e.g. the number of labels);
see https://lightning.ai/docs/pytorch/stable/cli/lightning_cli_expert.html#argument-linking

Args:
parser (LightningArgumentParser): Argument parser instance.

Note:
In Lightning 2.6+, model.init_args.out_dim is used as the source for linking
because it is set during before_instantiate_classes() from the computed
num_of_labels. This avoids linking from data.num_of_labels, which is a
property that requires the datamodule to be instantiated.
"""

def call_data_methods(data: Type[XYBaseDataModule]):
if data._num_of_labels is None:
data.prepare_data()
data.setup()
return data.num_of_labels

parser.link_arguments(
"data",
"model.init_args.out_dim",
apply_on="instantiate",
compute_fn=call_data_methods,
)

parser.link_arguments(
"data.feature_vector_size",
"model.init_args.input_dim",
apply_on="instantiate",
)

# Link num_labels (via out_dim) to metrics configurations
# out_dim is set in before_instantiate_classes() from data.num_of_labels
for kind in ("train", "val", "test"):
for average in (
"micro-f1",
@@ -70,31 +148,17 @@ def call_data_methods(data: Type[XYBaseDataModule]):
"rmse",
"r2",
):
# When using lightning > 2.5.1, keep only the metrics that are actually used:
# comment out the unused entries above and use the line matching your task, e.g.
# for average in ("mse", "rmse", "r2"):  # for regression
# for average in ("f1", "roc-auc"):  # for binary classification
# for average in ("micro-f1", "macro-f1", "roc-auc"):  # for multilabel classification
# for average in ("micro-f1", "macro-f1", "balanced-accuracy", "roc-auc"):  # for multilabel classification with balanced accuracy
parser.link_arguments(
"data.num_of_labels",
"model.init_args.out_dim",
f"model.init_args.{kind}_metrics.init_args.metrics.{average}.init_args.num_labels",
apply_on="instantiate",
)

# Link out_dim to trainer callbacks
parser.link_arguments(
"data.num_of_labels", "trainer.callbacks.init_args.num_labels"
"model.init_args.out_dim", "trainer.callbacks.init_args.num_labels"
)
# parser.link_arguments(
# "model.init_args.out_dim", "trainer.callbacks.init_args.num_labels"
# )
# parser.link_arguments(
# "data", "model.init_args.criterion.init_args.data_extractor"
# )
# parser.link_arguments(
# "data.init_args.chebi_version",
# "model.init_args.criterion.init_args.data_extractor.init_args.chebi_version",
# )

# Link datamodule to criterion's data extractor
parser.link_arguments(
"data", "model.init_args.criterion.init_args.data_extractor"
)
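
For reviewers, a minimal self-contained sketch (not part of the PR) of what the new placeholder-filling helpers do to a parsed config. The dict shape and the values 854 and 1400 are hypothetical stand-ins for what jsonargparse actually produces:

# Hypothetical parsed sub-config, shaped like the one that
# before_instantiate_classes() reads for a subcommand; all values here
# are illustrative stand-ins.
subcommand = {
    "model": {
        "init_args": {
            "out_dim": None,
            "input_dim": None,
            "train_metrics": {
                "init_args": {
                    "metrics": {"micro-f1": {"init_args": {"num_labels": None}}}
                }
            },
        }
    },
    "trainer": {"callbacks": [{"init_args": {"num_labels": None}}]},
}

num_labels, feature_vector_size = 854, 1400  # as computed by the datamodule

# Essentially the same placeholder-filling walk as _update_model_args,
# _update_metrics_num_labels, and _set_callback_num_labels:
model_args = subcommand["model"]["init_args"]
if model_args.get("out_dim") is None:
    model_args["out_dim"] = num_labels
if model_args.get("input_dim") is None:
    model_args["input_dim"] = feature_vector_size
for metric in model_args["train_metrics"]["init_args"]["metrics"].values():
    if metric["init_args"].get("num_labels") is None:
        metric["init_args"]["num_labels"] = num_labels
for callback in subcommand["trainer"]["callbacks"]:
    if callback["init_args"].get("num_labels") is None:
        callback["init_args"]["num_labels"] = num_labels

assert model_args["out_dim"] == 854 and model_args["input_dim"] == 1400

With the placeholders filled in before instantiation, the link_arguments calls above can treat model.init_args.out_dim as an always-populated source.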
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -22,7 +22,7 @@ dependencies = [
"transformers",
"pysmiles==1.1.2",
"rdkit==2024.3.6",
"lightning==2.5.1",
"lightning==2.6.1",
]

[project.optional-dependencies]