Skip to content

Commit 6fa2a74

Browse files
committed
Introduce progress bar class for shared logic
1 parent 6202d68 commit 6fa2a74

File tree

3 files changed

+82
-52
lines changed

3 files changed

+82
-52
lines changed
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
from typing import Any, Optional
2+
3+
from tqdm.auto import tqdm
4+
5+
from graphdatascience.query_runner.progress.progress_provider import TaskWithProgress
6+
7+
8+
class TqdmProgressBar:
9+
def __init__(self, task_name: str, relative_progress: Optional[float], bar_options: dict[str, Any] = {}):
10+
root_task_name = task_name
11+
if relative_progress is None: # Qualitative progress report
12+
self._tqdm_bar = tqdm(
13+
total=None,
14+
unit="",
15+
desc=root_task_name,
16+
bar_format="{desc} [elapsed: {elapsed} {postfix}]",
17+
**bar_options,
18+
)
19+
else:
20+
self._tqdm_bar = tqdm(
21+
total=100,
22+
unit="%",
23+
desc=root_task_name,
24+
**bar_options,
25+
)
26+
27+
def update(
28+
self,
29+
status: str,
30+
progress: Optional[float],
31+
sub_tasks_description: Optional[str] = None,
32+
) -> None:
33+
postfix = f"status: {status}, task: {sub_tasks_description}" if sub_tasks_description else f"status: {status}"
34+
self._tqdm_bar.set_postfix_str(postfix, refresh=False)
35+
if progress is not None:
36+
new_progress = progress - self._tqdm_bar.n
37+
self._tqdm_bar.update(new_progress)
38+
else:
39+
self._tqdm_bar.refresh()
40+
41+
def finish(self, success: bool) -> None:
42+
if not success:
43+
self._tqdm_bar.set_postfix_str("status: FAILED", refresh=True)
44+
return
45+
46+
if self._tqdm_bar.total is not None:
47+
self._tqdm_bar.update(self._tqdm_bar.total - self._tqdm_bar.n)
48+
self._tqdm_bar.set_postfix_str("status: FINISHED", refresh=True)
49+
50+
@staticmethod
51+
def _relative_progress(task: TaskWithProgress) -> Optional[float]:
52+
try:
53+
return float(task.progress_percent.removesuffix("%"))
54+
except ValueError:
55+
return None

graphdatascience/query_runner/progress/progress_provider.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,12 @@ class TaskWithProgress:
1010
status: str
1111
sub_tasks_description: Optional[str] = None
1212

13+
def relative_progress(self) -> Optional[float]:
14+
try:
15+
return float(self.progress_percent.removesuffix("%"))
16+
except ValueError:
17+
return None
18+
1319

1420
class ProgressProvider(ABC):
1521
@abstractmethod

graphdatascience/query_runner/progress/query_progress_logger.py

Lines changed: 21 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import warnings
22
from concurrent.futures import Future, ThreadPoolExecutor, wait
3-
from typing import Any, Callable, NoReturn, Optional
3+
from typing import Any, Callable, Optional
44

55
from pandas import DataFrame
6-
from tqdm.auto import tqdm
6+
7+
from graphdatascience.query_runner.progress.progress_bar import TqdmProgressBar
78

89
from ...server_version.server_version import ServerVersion
910
from .progress_provider import ProgressProvider, TaskWithProgress
@@ -28,6 +29,8 @@ def __init__(
2829
self._polling_interval = polling_interval
2930
self._progress_bar_options = progress_bar_options
3031

32+
self._progress_bar_options.setdefault("maxinterval", self._polling_interval)
33+
3134
def run_with_progress_logging(
3235
self, runnable: DataFrameProducer, job_id: str, database: Optional[str] = None
3336
) -> DataFrame:
@@ -58,7 +61,7 @@ def _select_progress_provider(self, job_id: str) -> ProgressProvider:
5861
def _log(
5962
self, future: Future[Any], job_id: str, progress_provider: ProgressProvider, database: Optional[str] = None
6063
) -> None:
61-
pbar: Optional[tqdm[NoReturn]] = None
64+
pbar: Optional[TqdmProgressBar] = None
6265
warn_if_failure = True
6366

6467
while wait([future], timeout=self._polling_interval).not_done:
@@ -83,53 +86,19 @@ def _log(
8386
if pbar is not None:
8487
self._finish_pbar(future, pbar)
8588

86-
def _init_pbar(self, task_with_progress: TaskWithProgress) -> tqdm: # type: ignore
87-
root_task_name = task_with_progress.task_name
88-
parsed_progress = QueryProgressLogger._relative_progress(task_with_progress)
89-
if parsed_progress is None: # Qualitative progress report
90-
return tqdm(
91-
total=None,
92-
unit="",
93-
desc=root_task_name,
94-
maxinterval=self._polling_interval,
95-
bar_format="{desc} [elapsed: {elapsed} {postfix}]",
96-
**self._progress_bar_options,
97-
)
98-
else:
99-
return tqdm(
100-
total=100,
101-
unit="%",
102-
desc=root_task_name,
103-
maxinterval=self._polling_interval,
104-
**self._progress_bar_options,
105-
)
106-
107-
def _update_pbar(self, pbar: tqdm, task_with_progress: TaskWithProgress) -> None: # type: ignore
108-
parsed_progress = QueryProgressLogger._relative_progress(task_with_progress)
109-
postfix = (
110-
f"status: {task_with_progress.status}, task: {task_with_progress.sub_tasks_description}"
111-
if task_with_progress.sub_tasks_description
112-
else f"status: {task_with_progress.status}"
89+
def _update_pbar(self, pbar: TqdmProgressBar, task: TaskWithProgress) -> None:
90+
pbar.update(
91+
task.status,
92+
task.relative_progress(),
93+
task.sub_tasks_description,
11394
)
114-
pbar.set_postfix_str(postfix, refresh=False)
115-
if parsed_progress is not None:
116-
new_progress = parsed_progress - pbar.n
117-
pbar.update(new_progress)
118-
else:
119-
pbar.refresh()
120-
121-
def _finish_pbar(self, future: Future[Any], pbar: tqdm) -> None: # type: ignore
122-
if future.exception():
123-
pbar.set_postfix_str("status: FAILED", refresh=True)
124-
return
125-
126-
if pbar.total is not None:
127-
pbar.update(pbar.total - pbar.n)
128-
pbar.set_postfix_str("status: FINISHED", refresh=True)
129-
130-
@staticmethod
131-
def _relative_progress(task: TaskWithProgress) -> Optional[float]:
132-
try:
133-
return float(task.progress_percent.removesuffix("%"))
134-
except ValueError:
135-
return None
95+
96+
def _init_pbar(self, task: TaskWithProgress) -> TqdmProgressBar:
97+
return TqdmProgressBar(
98+
task.task_name,
99+
task.relative_progress(),
100+
bar_options=self._progress_bar_options,
101+
)
102+
103+
def _finish_pbar(self, future: Future[Any], pbar: TqdmProgressBar) -> None:
104+
pbar.finish(future.exception() is None)

0 commit comments

Comments
 (0)