Skip to content

Commit 10e84dd

Browse files
authored
Merge pull request #317 from sfu-db/feat/progress-bar
feat(eda): add progress bar for dask local scheduler
2 parents 2735787 + e13257c commit 10e84dd

File tree

7 files changed

+225
-54
lines changed

7 files changed

+225
-54
lines changed

dataprep/eda/__init__.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,25 +2,24 @@
22
dataprep.eda
33
============
44
"""
5-
import tempfile
5+
from bokeh.io import output_notebook
66

7-
from bokeh.io import output_file, output_notebook
8-
from .distribution import compute, plot, render
97
from .correlation import compute_correlation, plot_correlation, render_correlation
10-
from .missing import compute_missing, plot_missing, render_missing
118
from .create_report import create_report
12-
from .utils import is_notebook
9+
from .distribution import compute, plot, render
1310
from .dtypes import (
14-
DType,
1511
Categorical,
16-
Nominal,
17-
Ordinal,
18-
Numerical,
1912
Continuous,
20-
Discrete,
2113
DateTime,
14+
Discrete,
15+
DType,
16+
Nominal,
17+
Numerical,
18+
Ordinal,
2219
Text,
2320
)
21+
from .missing import compute_missing, plot_missing, render_missing
22+
from .utils import is_notebook
2423

2524
__all__ = [
2625
"plot_correlation",

dataprep/eda/correlation/__init__.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,15 @@
22
This module implements the plot_correlation(df) function.
33
"""
44

5-
from typing import Any, List, Optional, Tuple, Union
5+
from typing import Optional, Tuple, Union
66

77
import dask.dataframe as dd
88
import pandas as pd
9-
from bokeh.io import show
109

10+
from ..progress_bar import ProgressBar
11+
from ..report import Report
1112
from .compute import compute_correlation
1213
from .render import render_correlation
13-
from ..report import Report
1414

1515
__all__ = ["render_correlation", "compute_correlation", "plot_correlation"]
1616

@@ -61,8 +61,8 @@ def plot_correlation(
6161
This function only supports numerical or categorical data,
6262
and it is better to drop None, Nan and Null value before using it
6363
"""
64-
65-
intermediate = compute_correlation(df, x=x, y=y, value_range=value_range, k=k)
64+
with ProgressBar(minimum=1):
65+
intermediate = compute_correlation(df, x=x, y=y, value_range=value_range, k=k)
6666
figure = render_correlation(intermediate)
6767

6868
return Report(figure)

dataprep/eda/correlation/compute/common.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def rankdata(data: np.ndarray, axis: int = 0) -> np.ndarray:
2828
name="rankdata-bottleneck", pure=True
2929
)
3030
def nanrankdata(data: np.ndarray, axis: int = 0) -> np.ndarray:
31-
"""delayed version of rankdata"""
31+
"""delayed version of rankdata."""
3232
return nanrankdata_(data, axis=axis)
3333

3434

@@ -38,6 +38,13 @@ def nanrankdata(data: np.ndarray, axis: int = 0) -> np.ndarray:
3838
def kendalltau( # pylint: disable=invalid-name
3939
a: np.ndarray, b: np.ndarray
4040
) -> np.ndarray:
41-
"""delayed version of kendalltau"""
41+
"""delayed version of kendalltau."""
4242
corr = kendalltau_(a, b).correlation
4343
return np.float64(corr) # Sometimes corr is a float, causes dask error
44+
45+
46+
@dask.delayed
47+
def corrcoef(arr: np.ndarray) -> np.ndarray:
48+
"""delayed version of np.corrcoef."""
49+
_, (corr, _) = np.corrcoef(arr, rowvar=False)
50+
return corr

dataprep/eda/correlation/compute/univariate.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@
99
import numpy as np
1010
import pandas as pd
1111

12-
from ...intermediate import Intermediate
1312
from ...data_array import DataArray
14-
from .common import CorrelationMethod, kendalltau, nanrankdata
13+
from ...intermediate import Intermediate
14+
from .common import CorrelationMethod, kendalltau, nanrankdata, corrcoef
1515

1616

1717
def _calc_univariate(
@@ -74,17 +74,17 @@ def _calc_univariate(
7474
def _pearson_1xn(x: da.Array, data: da.Array) -> da.Array:
7575
_, ncols = data.shape
7676

77-
datamask = da.isnan(data)
78-
xmask = da.isnan(x)[:, 0]
77+
fused = da.concatenate([data, x], axis=1)
78+
mask = ~da.isnan(data)
7979

8080
corrs = []
8181
for j in range(ncols):
82-
y = data[:, [j]]
83-
84-
mask = ~(xmask | datamask[:, j])
85-
xy = np.concatenate([x, y], axis=1)[mask]
86-
xy.compute_chunk_sizes() # Not optimal here
87-
_, (corr, _) = da.corrcoef(xy, rowvar=False)
82+
xy = fused[:, [-1, j]]
83+
mask_ = mask[:, -1] & mask[:, j]
84+
xy = xy[mask_]
85+
corr = da.from_delayed(corrcoef(xy), dtype=np.float, shape=())
86+
# not usable because xy has unknown rows due to the null filter
87+
# _, (corr, _) = da.corrcoef(xy, rowvar=False)
8888
corrs.append(corr)
8989

9090
return da.stack(corrs)

dataprep/eda/distribution/__init__.py

Lines changed: 25 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,17 @@
22
This module implements the plot(df) function.
33
"""
44

5-
from typing import Optional, Tuple, Union, Dict
5+
from typing import Optional, Tuple, Union
66

77
import dask.dataframe as dd
88
import pandas as pd
99

10+
from ..container import Container
11+
from ..dtypes import DTypeDef
12+
from ..progress_bar import ProgressBar
13+
from ..report import Report
1014
from .compute import compute
1115
from .render import render
12-
from ..report import Report
13-
from ..dtypes import DTypeDef
14-
from ..container import Container
1516

1617
__all__ = ["plot", "compute", "render"]
1718

@@ -143,25 +144,26 @@ def plot(
143144
"""
144145
# pylint: disable=too-many-locals,line-too-long
145146

146-
intermediate = compute(
147-
df,
148-
x=x,
149-
y=y,
150-
z=z,
151-
bins=bins,
152-
ngroups=ngroups,
153-
largest=largest,
154-
nsubgroups=nsubgroups,
155-
timeunit=timeunit.lower(),
156-
agg=agg,
157-
sample_size=sample_size,
158-
top_words=top_words,
159-
stopword=stopword,
160-
lemmatize=lemmatize,
161-
stem=stem,
162-
value_range=value_range,
163-
dtype=dtype,
164-
)
147+
with ProgressBar(minimum=1):
148+
intermediate = compute(
149+
df,
150+
x=x,
151+
y=y,
152+
z=z,
153+
bins=bins,
154+
ngroups=ngroups,
155+
largest=largest,
156+
nsubgroups=nsubgroups,
157+
timeunit=timeunit.lower(),
158+
agg=agg,
159+
sample_size=sample_size,
160+
top_words=top_words,
161+
stopword=stopword,
162+
lemmatize=lemmatize,
163+
stem=stem,
164+
value_range=value_range,
165+
dtype=dtype,
166+
)
165167
figure = render(intermediate, yscale=yscale, tile_size=tile_size)
166168
if intermediate.visual_type == "distribution_grid":
167169
return Container(figure)

dataprep/eda/missing/__init__.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,12 @@
66

77
import dask.dataframe as dd
88
import pandas as pd
9-
from bokeh.io import show
109

10+
from ..dtypes import DTypeDef
11+
from ..progress_bar import ProgressBar
12+
from ..report import Report
1113
from .compute import compute_missing
1214
from .render import render_missing
13-
from ..report import Report
14-
from ..dtypes import DTypeDef
1515

1616
__all__ = ["render_missing", "compute_missing", "plot_missing"]
1717

@@ -56,6 +56,10 @@ def plot_missing(
5656
>>> plot_missing(df, "HDI_for_year")
5757
>>> plot_missing(df, "HDI_for_year", "population")
5858
"""
59-
itmdt = compute_missing(df, x, y, dtype=dtype, bins=bins, ndist_sample=ndist_sample)
59+
60+
with ProgressBar(minimum=1):
61+
itmdt = compute_missing(
62+
df, x, y, dtype=dtype, bins=bins, ndist_sample=ndist_sample
63+
)
6064
fig = render_missing(itmdt)
6165
return Report(fig)

dataprep/eda/progress_bar.py

Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
"""ProgressBar shows the how many dask tasks finished/remains using tqdm."""
2+
3+
from typing import Any, Optional, Dict, Tuple, Union
4+
from time import time
5+
6+
from dask.callbacks import Callback
7+
8+
from .utils import is_notebook
9+
10+
if is_notebook():
11+
from tqdm.notebook import tqdm
12+
else:
13+
from tqdm import tqdm
14+
15+
# pylint: disable=method-hidden,too-many-instance-attributes
16+
class ProgressBar(Callback): # type: ignore
17+
"""A progress bar for DataPrep.EDA.
18+
19+
Parameters
20+
----------
21+
minimum : int, optional
22+
Minimum time threshold in seconds before displaying a progress bar.
23+
Default is 0 (always display)
24+
_min_tasks : int, optional
25+
Minimum graph size to show a progress bar, default is 5
26+
width : int, optional
27+
Width of the bar. None means auto width.
28+
interval : float, optional
29+
Update resolution in seconds, default is 0.1 seconds
30+
"""
31+
32+
_minimum: float = 0
33+
_min_tasks: int = 5
34+
_width: Optional[int] = None
35+
_interval: float = 0.1
36+
_last_duration: float = 0
37+
_pbar: Optional[tqdm] = None
38+
_state: Optional[Dict[str, Any]] = None
39+
_started: Optional[float] = None
40+
_last_task: Optional[str] = None # in case we initialize the pbar in _finish
41+
42+
def __init__(
43+
self,
44+
minimum: float = 0,
45+
min_tasks: int = 5,
46+
width: Optional[int] = None,
47+
interval: float = 0.1,
48+
) -> None:
49+
super().__init__()
50+
self._minimum = minimum
51+
self._min_tasks = min_tasks
52+
self._width = width
53+
self._interval = interval
54+
55+
def _start(self, _dsk: Any) -> None:
56+
"""A hook to start this callback."""
57+
58+
def _start_state(self, _dsk: Any, state: Dict[str, Any]) -> None:
59+
"""A hook called before every task gets executed."""
60+
self._started = time()
61+
self._state = state
62+
_, ntasks = self._count_tasks()
63+
64+
if ntasks > self._min_tasks:
65+
self._init_bar()
66+
67+
def _pretask(
68+
self, key: Union[str, Tuple[str, ...]], _dsk: Any, _state: Dict[str, Any]
69+
) -> None:
70+
"""A hook called before one task gets executed."""
71+
if self._started is None:
72+
raise ValueError("ProgressBar not started properly")
73+
74+
if self._pbar is None and time() - self._started > self._minimum:
75+
self._init_bar()
76+
77+
if isinstance(key, tuple):
78+
key = key[0]
79+
80+
if self._pbar is not None:
81+
self._pbar.set_description(f"Computing {key}")
82+
else:
83+
self._last_task = key
84+
85+
def _posttask(
86+
self,
87+
_key: str,
88+
_result: Any,
89+
_dsk: Any,
90+
_state: Dict[str, Any],
91+
_worker_id: Any,
92+
) -> None:
93+
"""A hook called after one task gets executed."""
94+
95+
if self._pbar is not None:
96+
self._update_bar()
97+
98+
def _finish(self, _dsk: Any, _state: Dict[str, Any], _errored: bool) -> None:
99+
"""A hook called after all tasks get executed."""
100+
if self._started is None:
101+
raise ValueError("ProgressBar not started properly")
102+
103+
if self._pbar is None and time() - self._started > self._minimum:
104+
self._init_bar()
105+
106+
if self._pbar is not None:
107+
self._update_bar()
108+
self._pbar.close()
109+
110+
self._state = None
111+
self._started = None
112+
self._pbar = None
113+
114+
def _update_bar(self) -> None:
115+
if self._pbar is None:
116+
return
117+
ndone, _ = self._count_tasks()
118+
119+
self._pbar.update(max(0, ndone - self._pbar.n))
120+
121+
def _init_bar(self) -> None:
122+
if self._pbar is not None:
123+
raise ValueError("ProgressBar already initialized.")
124+
ndone, ntasks = self._count_tasks()
125+
126+
if self._last_task is not None:
127+
desc = f"Computing {self._last_task}"
128+
else:
129+
desc = ""
130+
131+
if self._width is None:
132+
self._pbar = tqdm(
133+
total=ntasks,
134+
dynamic_ncols=True,
135+
mininterval=self._interval,
136+
initial=ndone,
137+
desc=desc,
138+
)
139+
else:
140+
self._pbar = tqdm(
141+
total=ntasks,
142+
ncols=self._width,
143+
mininterval=self._interval,
144+
initial=ndone,
145+
desc=desc,
146+
)
147+
148+
self._pbar.start_t = self._started
149+
self._pbar.refresh()
150+
151+
def _count_tasks(self) -> Tuple[int, int]:
152+
if self._state is None:
153+
raise ValueError("ProgressBar not started properly")
154+
155+
state = self._state
156+
ndone = len(state["finished"])
157+
ntasks = sum(len(state[k]) for k in ["ready", "waiting", "running"]) + ndone
158+
159+
return ndone, ntasks

0 commit comments

Comments
 (0)