Skip to content

Commit 04f3448

Browse files
[pre-commit.ci] pre-commit suggestions (#338)
* [pre-commit.ci] pre-commit suggestions updates: - [github.com/psf/black: 23.11.0 → 23.12.1](psf/black@23.11.0...23.12.1) - [github.com/astral-sh/ruff-pre-commit: v0.1.6 → v0.1.11](astral-sh/ruff-pre-commit@v0.1.6...v0.1.11) - [github.com/pre-commit/mirrors-prettier: v3.1.0 → v4.0.0-alpha.8](pre-commit/mirrors-prettier@v3.1.0...v4.0.0-alpha.8) * Apply suggestions from code review * long lines --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com> Co-authored-by: Jirka <jirka.borovec@seznam.cz>
1 parent 8956c48 commit 04f3448

File tree

2 files changed

+9
-8
lines changed

2 files changed

+9
-8
lines changed

.pre-commit-config.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ repos:
4040
args: [--in-place, --wrap-summaries=120, --wrap-descriptions=120]
4141

4242
- repo: https://github.com/psf/black
43-
rev: 23.11.0
43+
rev: 23.12.1
4444
hooks:
4545
- id: black
4646
name: Black code
@@ -64,7 +64,7 @@ repos:
6464
- id: yesqa
6565

6666
- repo: https://github.com/astral-sh/ruff-pre-commit
67-
rev: v0.1.6
67+
rev: v0.1.11
6868
hooks:
6969
- id: ruff
7070
args: ["--fix"]

src/pytorch_tabular/tabular_model_sweep.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -76,9 +76,10 @@ def _validate_args(
7676
assert all(
7777
isinstance(m, (str, ModelConfig)) for m in model_list
7878
), f"models must be a list of strings or ModelConfigs, but got {model_list}"
79-
assert all(
80-
task == m.task for m in model_list if isinstance(m, ModelConfig)
81-
), f"task must be the same as the task in ModelConfig, but got {task} and {[m.task for m in model_list if isinstance(m, ModelConfig)]}"
79+
assert all(task == m.task for m in model_list if isinstance(m, ModelConfig)), (
80+
"task must be the same as the task in ModelConfig,"
81+
f" but got {task} and {[m.task for m in model_list if isinstance(m, ModelConfig)]}"
82+
)
8283
if metrics is not None:
8384
assert isinstance(metrics, list), f"metrics must be a list of strings or callables, but got {type(metrics)}"
8485
assert all(
@@ -154,9 +155,9 @@ def model_sweep(
154155
155156
trainer_config (Union[TrainerConfig, str]): TrainerConfig object or path to the yaml file.
156157
157-
model_list (Union[str, List[Union[ModelConfig, str]]], optional): The list of models to compare. This can be one of
158-
the presets defined in ``pytorch_tabular.tabular_model_sweep.MODEL_SWEEP_PRESETS`` or a list of ``ModelConfig`` objects.
159-
Defaults to "lite".
158+
model_list (Union[str, List[Union[ModelConfig, str]]], optional): The list of models to compare.
159+
This can be one of the presets defined in ``pytorch_tabular.tabular_model_sweep.MODEL_SWEEP_PRESETS``
160+
or a list of ``ModelConfig`` objects. Defaults to "lite".
160161
161162
metrics (Optional[List[str]]): the list of metrics you need to track during training. The metrics
162163
should be one of the functional metrics implemented in ``torchmetrics``. By default, it is

0 commit comments

Comments
 (0)