Skip to content

Commit beb29d8

Browse files
Fix sweep to keep the best model and add best_score of the first model (#371)
* Fix sweep to keep the best model and add best_score of the first model
* Remove datamodule before deepcopy best_model
* Add sanity check

---------

Co-authored-by: Manu Joseph V <manujosephv@gmail.com>
1 parent 2ca211d commit beb29d8

File tree

1 file changed

+20
-16
lines changed

1 file changed

+20
-16
lines changed

src/pytorch_tabular/tabular_model_sweep.py

Lines changed: 20 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -359,27 +359,31 @@ def _init_tabular_model(m):
359359
res_dict["time_taken"] = time.time() - start_time
360360
res_dict["time_taken_per_epoch"] = res_dict["time_taken"] / res_dict["epochs"]
361361

362-
if verbose:
363-
logger.info(f"Finished Training {name}")
364-
logger.info("Results:" f" {', '.join([f'{k}: {v}' for k,v in res_dict.items()])}")
365-
res_dict["params"] = params
366-
results.append(res_dict)
367-
if best_model is None:
368-
best_model = tabular_model
369-
else:
370-
if is_lower_better:
371-
if res_dict[f"test_{rank_metric[0]}"] < best_score:
372-
best_model = tabular_model
373-
best_score = res_dict[f"test_{rank_metric[0]}"]
374-
else:
375-
if res_dict[f"test_{rank_metric[0]}"] > best_score:
376-
best_model = tabular_model
362+
if verbose:
363+
logger.info(f"Finished Training {name}")
364+
logger.info("Results:" f" {', '.join([f'{k}: {v}' for k,v in res_dict.items()])}")
365+
res_dict["params"] = params
366+
results.append(res_dict)
367+
if return_best_model:
368+
tabular_model.datamodule = None
369+
if best_model is None:
370+
best_model = copy.deepcopy(tabular_model)
377371
best_score = res_dict[f"test_{rank_metric[0]}"]
372+
else:
373+
if is_lower_better:
374+
if res_dict[f"test_{rank_metric[0]}"] < best_score:
375+
best_model = copy.deepcopy(tabular_model)
376+
best_score = res_dict[f"test_{rank_metric[0]}"]
377+
else:
378+
if res_dict[f"test_{rank_metric[0]}"] > best_score:
379+
best_model = copy.deepcopy(tabular_model)
380+
best_score = res_dict[f"test_{rank_metric[0]}"]
378381
if verbose:
379382
logger.info("Model Sweep Finished")
380383
logger.info(f"Best Model: {best_model.name}")
381384
results = pd.DataFrame(results).sort_values(by=f"test_{rank_metric[0]}", ascending=is_lower_better)
382-
if return_best_model:
385+
if return_best_model and best_model is not None:
386+
best_model.datamodule = datamodule
383387
return results, best_model
384388
else:
385389
return results

0 commit comments

Comments
 (0)