
Commit d925a89

Add_parameters (#141)
* added early_stopping kwargs
* added data aware batch size
* added device to prediction
* added checkpoint kwargs
* pre-commit fix
* pushing a linting error correction
1 parent 5d962d5 commit d925a89

File tree: 5 files changed, +60 −9 lines


examples/__only_for_dev__/adhoc_scaffold.py

Lines changed: 4 additions & 0 deletions
@@ -85,7 +85,11 @@ def print_metrics(y_true, y_pred, tag):
 tabular_model.fit(train=train, validation=val)
 test.drop(columns=["target"], inplace=True)
 pred_df = tabular_model.predict(test)
+pred_df = tabular_model.predict(test, device="cpu")
+pred_df = tabular_model.predict(test, device="cuda")
+import torch  # noqa: E402

+pred_df = tabular_model.predict(test, device=torch.device("cuda"))
 # tabular_model.fit(train=train, validation=val)
 # tabular_model.fit(train=train, validation=val, max_epochs=5)
 # tabular_model.fit(train=train, validation=val, max_epochs=5, reset=True)

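The scaffold above exercises the new `device` argument of `TabularModel.predict`. A minimal usage sketch, assuming a fitted `TabularModel` named `tabular_model` and a feature-only DataFrame `test` as in the scaffold; the CUDA availability guard is added here for illustration only:

    import torch

    # Force inference onto the CPU regardless of where the model was trained.
    pred_cpu = tabular_model.predict(test, device="cpu")

    # Both strings and torch.device objects are accepted; guard the GPU calls
    # so the sketch also runs on CPU-only machines.
    if torch.cuda.is_available():
        pred_gpu = tabular_model.predict(test, device="cuda")
        pred_gpu = tabular_model.predict(test, device=torch.device("cuda"))
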
src/pytorch_tabular/config/config.py

Lines changed: 36 additions & 0 deletions
@@ -237,6 +237,8 @@ class TrainerConfig:
     Args:
         batch_size (int): Number of samples in each batch of training

+        data_aware_init_batch_size (int): Number of samples in each batch of training for the data-aware initialization, when applicable. Defaults to 2000
+
         fast_dev_run (bool): runs n if set to ``n`` (int) else 1 if set to ``True`` batch(es) of train, val
             and test to find any bugs (ie: a sort of unit test).

@@ -296,6 +298,9 @@ class TrainerConfig:
         early_stopping_patience (int): The number of epochs to wait until there is no further improvements
             in loss/metric

+        early_stopping_kwargs (Optional[Dict]): Additional keyword arguments for the early stopping callback.
+            See the documentation for the PyTorch Lightning EarlyStopping callback for more details.
+
         checkpoints (Optional[str]): The loss/metric that needed to be monitored for checkpoints. If None,
             there will be no checkpoints

@@ -311,6 +316,9 @@ class TrainerConfig:

         checkpoints_save_top_k (int): The number of best models to save

+        checkpoints_kwargs (Optional[Dict]): Additional keyword arguments for the checkpoints callback.
+            See the documentation for the PyTorch Lightning ModelCheckpoint callback for more details.
+
         load_best (bool): Flag to load the best model saved during training

         track_grad_norm (int): Track and Log Gradient Norms in the logger. -1 by default means no tracking.
@@ -328,6 +336,12 @@ class TrainerConfig:
     """

     batch_size: int = field(default=64, metadata={"help": "Number of samples in each batch of training"})
+    data_aware_init_batch_size: int = field(
+        default=2000,
+        metadata={
+            "help": "Number of samples in each batch of training for the data-aware initialization, when applicable. Defaults to 2000"
+        },
+    )
     fast_dev_run: bool = field(
         default=False,
         metadata={
@@ -429,6 +443,12 @@ class TrainerConfig:
         default=3,
         metadata={"help": "The number of epochs to wait until there is no further improvements in loss/metric"},
     )
+    early_stopping_kwargs: Optional[Dict[str, Any]] = field(
+        default_factory=lambda: dict(),
+        metadata={
+            "help": "Additional keyword arguments for the early stopping callback. See the documentation for the PyTorch Lightning EarlyStopping callback for more details."
+        },
+    )
     checkpoints: Optional[str] = field(
         default="valid_loss",
         metadata={
@@ -457,6 +477,12 @@ class TrainerConfig:
         default=1,
         metadata={"help": "The number of best models to save"},
     )
+    checkpoints_kwargs: Optional[Dict[str, Any]] = field(
+        default_factory=lambda: dict(),
+        metadata={
+            "help": "Additional keyword arguments for the checkpoints callback. See the documentation for the PyTorch Lightning ModelCheckpoint callback for more details."
+        },
+    )
     load_best: bool = field(
         default=True,
         metadata={"help": "Flag to load the best model saved during training"},
@@ -508,6 +534,16 @@ def __post_init__(self):
             warnings.warn("Ignoring devices in favor of devices_list")
             self.devices = self.devices_list
             delattr(self, "devices_list")
+        for key in self.early_stopping_kwargs.keys():
+            if key in ["min_delta", "mode", "patience"]:
+                raise ValueError(
+                    f"Cannot override {key} in early_stopping_kwargs. Please use the appropriate argument in `TrainerConfig`"
+                )
+        for key in self.checkpoints_kwargs.keys():
+            if key in ["dirpath", "filename", "monitor", "save_top_k", "mode", "every_n_epochs"]:
+                raise ValueError(
+                    f"Cannot override {key} in checkpoints_kwargs. Please use the appropriate argument in `TrainerConfig`"
+                )


 @dataclass

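Taken together, the TrainerConfig changes add three knobs: `data_aware_init_batch_size`, `early_stopping_kwargs`, and `checkpoints_kwargs`. A minimal sketch of setting them, assuming the usual `TrainerConfig` import; the specific kwargs values (`check_finite`, `save_weights_only`) are only illustrative PyTorch Lightning callback arguments, not defaults:

    from pytorch_tabular.config import TrainerConfig

    trainer_config = TrainerConfig(
        batch_size=1024,
        data_aware_init_batch_size=4096,  # used by the NODE/GATE data-aware initialization
        early_stopping="valid_loss",
        early_stopping_patience=5,
        # Forwarded verbatim to pytorch_lightning.callbacks.EarlyStopping
        early_stopping_kwargs={"check_finite": True},
        checkpoints="valid_loss",
        # Forwarded verbatim to pytorch_lightning.callbacks.ModelCheckpoint
        checkpoints_kwargs={"save_weights_only": True},
    )

Note that `__post_init__` rejects kwargs that shadow first-class fields, so for example `early_stopping_kwargs={"patience": 10}` raises a `ValueError` asking you to use the dedicated `TrainerConfig` argument instead.
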
src/pytorch_tabular/models/gate/gate_model.py

Lines changed: 1 addition & 1 deletion
@@ -220,6 +220,6 @@ def data_aware_initialization(self, datamodule):
         if self.hparams.task == "regression":
             logger.info("Data Aware Initialization of T0")
             # Need a big batch to initialize properly
-            alt_loader = datamodule.train_dataloader(batch_size=2000)
+            alt_loader = datamodule.train_dataloader(batch_size=self.hparams.data_aware_init_batch_size)
             batch = next(iter(alt_loader))
             self.head.T0.data = torch.mean(batch["target"], dim=0)

src/pytorch_tabular/models/node/node_model.py

Lines changed: 1 addition & 1 deletion
@@ -81,7 +81,7 @@ def data_aware_initialization(self, datamodule):
         """Performs data-aware initialization for NODE"""
         logger.info("Data Aware Initialization of NODE using a forward pass with 2000 batch size....")
         # Need a big batch to initialize properly
-        alt_loader = datamodule.train_dataloader(batch_size=2000)
+        alt_loader = datamodule.train_dataloader(batch_size=self.hparams.data_aware_init_batch_size)
         batch = next(iter(alt_loader))
         for k, v in batch.items():
             if isinstance(v, list) and (len(v) == 0):

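Both GATE and NODE previously hard-coded a 2000-row batch for their data-aware initialization; they now read `self.hparams.data_aware_init_batch_size`, which is populated from `TrainerConfig`. A sketch of raising it for a NODE regression model, assuming `data_config`, `optimizer_config`, `train`, and `val` are defined as usual elsewhere:

    from pytorch_tabular import TabularModel
    from pytorch_tabular.config import TrainerConfig
    from pytorch_tabular.models import NodeConfig

    trainer_config = TrainerConfig(data_aware_init_batch_size=8192)
    model_config = NodeConfig(task="regression")

    tabular_model = TabularModel(
        data_config=data_config,            # assumed defined elsewhere
        model_config=model_config,
        optimizer_config=optimizer_config,  # assumed defined elsewhere
        trainer_config=trainer_config,
    )
    # During fit(), data_aware_initialization() now draws one batch of 8192 rows
    # instead of the previously hard-coded 2000.
    tabular_model.fit(train=train, validation=val)
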
src/pytorch_tabular/tabular_model.py

Lines changed: 18 additions & 7 deletions
@@ -225,8 +225,8 @@ def _prepare_callbacks(self, callbacks=None) -> List:
                 monitor=self.config.early_stopping,
                 min_delta=self.config.early_stopping_min_delta,
                 patience=self.config.early_stopping_patience,
-                verbose=False,
                 mode=self.config.early_stopping_mode,
+                **self.config.early_stopping_kwargs,
             )
             callbacks.append(early_stop_callback)
         if self.config.checkpoints:
@@ -239,6 +239,7 @@ def _prepare_callbacks(self, callbacks=None) -> List:
                 save_top_k=self.config.checkpoints_save_top_k,
                 mode=self.config.checkpoints_mode,
                 every_n_epochs=self.config.checkpoints_every_n_epochs,
+                **self.config.checkpoints_kwargs,
             )
             callbacks.append(model_checkpoint)
             self.config.enable_checkpointing = True
@@ -1061,6 +1062,7 @@ def predict(
         n_samples: Optional[int] = 100,
         ret_logits=False,
         include_input_features: bool = True,
+        device: Optional[torch.device] = None,
     ) -> pd.DataFrame:
         """Uses the trained model to predict on new data and return as a dataframe

@@ -1085,26 +1087,35 @@ def predict(
                 DeprecationWarning,
             )
         assert all([q <= 1 and q >= 0 for q in quantiles]), "Quantiles should be a decimal between 0 and 1"
-        self.model.eval()
+        if device is not None:
+            if isinstance(device, str):
+                device = torch.device(device)
+            if self.model.device != device:
+                model = self.model.to(device)
+            else:
+                model = self.model
+        else:
+            model = self.model
+        model.eval()
         inference_dataloader = self.datamodule.prepare_inference_dataloader(test)
         point_predictions = []
         quantile_predictions = []
         logits_predictions = defaultdict(list)
-        is_probabilistic = hasattr(self.model.hparams, "_probabilistic") and self.model.hparams._probabilistic
+        is_probabilistic = hasattr(model.hparams, "_probabilistic") and model.hparams._probabilistic
         for batch in track(inference_dataloader, description="Generating Predictions..."):
             for k, v in batch.items():
                 if isinstance(v, list) and (len(v) == 0):
                     # Skipping empty list
                     continue
-                batch[k] = v.to(self.model.device)
+                batch[k] = v.to(model.device)
             if is_probabilistic:
-                samples, ret_value = self.model.sample(batch, n_samples, ret_model_output=True)
+                samples, ret_value = model.sample(batch, n_samples, ret_model_output=True)
                 y_hat = torch.mean(samples, dim=-1)
                 quantile_preds = []
                 for q in quantiles:
                     quantile_preds.append(torch.quantile(samples, q=q, dim=-1).unsqueeze(1))
             else:
-                y_hat, ret_value = self.model.predict(batch, ret_model_output=True)
+                y_hat, ret_value = model.predict(batch, ret_model_output=True)
             if ret_logits:
                 for k, v in ret_value.items():
                     # if k == "backbone_features":
@@ -1121,7 +1132,7 @@ def predict(
             if quantile_predictions.ndim == 2:
                 quantile_predictions = quantile_predictions.unsqueeze(-1)
         if include_input_features:
-            pred_df = test.copy()  # TODO Add option to switch between including the entire input DF or not.
+            pred_df = test.copy()
         else:
             pred_df = pd.DataFrame(index=test.index)
         if self.config.task == "regression":

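In `_prepare_callbacks`, the new kwargs are unpacked into the Lightning `EarlyStopping` and `ModelCheckpoint` constructors; dropping the hard-coded `verbose=False` means verbosity can now be controlled through `early_stopping_kwargs`. In `predict`, a string `device` is converted with `torch.device`, the model is moved only when it is not already on the requested device, and omitting `device` keeps the previous behaviour of predicting on whatever device the model currently occupies. A small sketch of that resolution logic in isolation, written as a hypothetical standalone helper rather than the library API:

    import torch
    import torch.nn as nn

    def resolve_model_for_device(model: nn.Module, device=None) -> nn.Module:
        # Mirrors the predict() logic above: accept a str or torch.device and
        # only move the model when it sits on a different device.
        if device is None:
            return model
        if isinstance(device, str):
            device = torch.device(device)
        current = next(model.parameters()).device
        return model.to(device) if current != device else model

    # Usage: pin a toy module to the CPU explicitly.
    net = resolve_model_for_device(nn.Linear(4, 1), device="cpu")
    print(next(net.parameters()).device)  # cpu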