Skip to content

Commit 0a5ee6d

Browse files
authored
removed tqdm and switched over to rich (#137)
* removed tqdm and switched over to rich * fixed issue with rich progressbar
1 parent 4ab4545 commit 0a5ee6d

File tree

3 files changed

+7
-7
lines changed

3 files changed

+7
-7
lines changed

src/pytorch_tabular/categorical_encoders.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@
1212

1313
import numpy as np
1414
import pandas as pd
15+
from rich.progress import track
1516
from sklearn.base import BaseEstimator, TransformerMixin
16-
from tqdm.autonotebook import tqdm
1717

1818
from pytorch_tabular.utils import get_logger
1919

@@ -198,9 +198,9 @@ def transform(self, X: pd.DataFrame, y=None) -> pd.DataFrame:
198198
assert all(c in X.columns for c in self.cols)
199199

200200
X_encoded = X.copy(deep=True)
201-
for col, mapping in tqdm(
201+
for col, mapping in track(
202202
self._mapping.items(),
203-
desc="Encoding the data...",
203+
description="Encoding the data...",
204204
total=len(self._mapping.values()),
205205
):
206206
for dim in range(mapping[self.NAN_CATEGORY].shape[0]):

src/pytorch_tabular/feature_extractor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
from collections import defaultdict
55

66
import pandas as pd
7+
from rich.progress import track
78
from sklearn.base import BaseEstimator, TransformerMixin
8-
from tqdm.autonotebook import tqdm
99

1010
from pytorch_tabular.models import NODEModel, TabNetModel
1111
from pytorch_tabular.models.mixture_density import MDNModel
@@ -57,7 +57,7 @@ def transform(self, X: pd.DataFrame, y=None) -> pd.DataFrame:
5757
self.tabular_model.model.eval()
5858
inference_dataloader = self.tabular_model.datamodule.prepare_inference_dataloader(X_encoded)
5959
logits_predictions = defaultdict(list)
60-
for batch in tqdm(inference_dataloader, desc="Generating Features..."):
60+
for batch in track(inference_dataloader, description="Generating Features..."):
6161
for k, v in batch.items():
6262
if isinstance(v, list) and (len(v) == 0):
6363
# Skipping empty list

src/pytorch_tabular/tabular_model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,10 @@
2222
from pytorch_lightning.callbacks import RichProgressBar
2323
from pytorch_lightning.callbacks.gradient_accumulation_scheduler import GradientAccumulationScheduler
2424
from pytorch_lightning.utilities.model_summary import summarize
25+
from rich.progress import track
2526
from sklearn.base import TransformerMixin
2627
from sklearn.preprocessing import LabelEncoder
2728
from torch import nn
28-
from tqdm.autonotebook import tqdm
2929

3030
from pytorch_tabular.config import (
3131
DataConfig,
@@ -1084,7 +1084,7 @@ def predict(
10841084
quantile_predictions = []
10851085
logits_predictions = defaultdict(list)
10861086
is_probabilistic = hasattr(self.model.hparams, "_probabilistic") and self.model.hparams._probabilistic
1087-
for batch in tqdm(inference_dataloader, desc="Generating Predictions..."):
1087+
for batch in track(inference_dataloader, description="Generating Predictions..."):
10881088
for k, v in batch.items():
10891089
if isinstance(v, list) and (len(v) == 0):
10901090
# Skipping empty list

0 commit comments

Comments
 (0)