Skip to content

Commit 7a1ac00

Browse files
committed
Add progress bar to remote write-back
1 parent 65ad5ff commit 7a1ac00

File tree

11 files changed

+120
-17
lines changed

11 files changed

+120
-17
lines changed

graphdatascience/arrow_client/v2/remote_write_back_client.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,14 @@ def __init__(self, arrow_client: AuthenticatedArrowClient, query_runner: QueryRu
2020
protocol_version = ProtocolVersionResolver(query_runner).resolve()
2121
self._write_protocol = WriteProtocol.select(protocol_version)
2222

23-
# TODO: Add progress logging
2423
def write(
2524
self,
2625
graph_name: str,
2726
job_id: str,
2827
concurrency: Optional[int] = None,
2928
property_overwrites: Optional[dict[str, str]] = None,
3029
relationship_type_overwrite: Optional[str] = None,
30+
log_progress: bool = True,
3131
) -> WriteBackResult:
3232
arrow_config = self._arrow_configuration()
3333

@@ -49,7 +49,11 @@ def write(
4949
start_time = time.time()
5050

5151
result = self._write_protocol.run_write_back(
52-
self._query_runner, write_back_params, None, TerminationFlagNoop()
52+
self._query_runner,
53+
write_back_params,
54+
None,
55+
log_progress=log_progress,
56+
terminationFlag=TerminationFlagNoop(),
5357
).squeeze()
5458
write_millis = int((time.time() - start_time) * 1000)
5559

graphdatascience/procedure_surface/arrow/catalog/node_properties_arrow_endpoints.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ def __init__(
3434
self._node_property_endpoints = NodePropertyEndpoints(
3535
arrow_client, self._write_back_client, show_progress=show_progress
3636
)
37+
self._show_progress = show_progress
3738

3839
def stream(
3940
self,
@@ -112,6 +113,7 @@ def write(
112113
job_id,
113114
concurrency=write_concurrency if write_concurrency is not None else concurrency,
114115
property_overwrites=node_property_spec.to_dict(),
116+
log_progress=self._show_progress and log_progress,
115117
)
116118

117119
return NodePropertiesWriteResult(

graphdatascience/procedure_surface/arrow/catalog/relationship_arrow_endpoints.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ def write(
101101
job_id,
102102
concurrency=write_concurrency if write_concurrency is not None else concurrency,
103103
relationship_type_overwrite=relationship_type,
104+
log_progress=log_progress and self._show_progress,
104105
)
105106

106107
written_relationships = (

graphdatascience/procedure_surface/arrow/catalog_arrow_endpoints.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ def project(
128128
TerminationFlag.create(),
129129
None,
130130
None,
131-
logging,
131+
self._show_progress and logging,
132132
)
133133

134134
job_result = ProjectionResult(**JobClient.get_summary(self._arrow_client, job_id))

graphdatascience/procedure_surface/arrow/node_property_endpoints.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ def run_job_and_write(
9090
job_id,
9191
concurrency=write_concurrency if write_concurrency is not None else concurrency,
9292
property_overwrites=property_overwrites,
93+
log_progress=show_progress,
9394
)
9495

9596
# modify computation result to include write details

graphdatascience/query_runner/progress/progress_bar.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
1-
from typing import Any, Optional
1+
from __future__ import annotations
2+
3+
from types import TracebackType
4+
from typing import Any, Optional, Type
25

36
from tqdm.auto import tqdm
47

58
from graphdatascience.query_runner.progress.progress_provider import TaskWithProgress
69

710

811
class TqdmProgressBar:
9-
# TODO helper method for creating for a test with obserable progress
1012
def __init__(self, task_name: str, relative_progress: Optional[float], bar_options: dict[str, Any] = {}):
1113
root_task_name = task_name
1214
if relative_progress is None: # Qualitative progress report
@@ -26,6 +28,17 @@ def __init__(self, task_name: str, relative_progress: Optional[float], bar_optio
2628
**bar_options,
2729
)
2830

31+
def __enter__(self: TqdmProgressBar) -> TqdmProgressBar:
32+
return self
33+
34+
def __exit__(
35+
self,
36+
exception_type: Optional[Type[BaseException]],
37+
exception_value: Optional[BaseException],
38+
traceback: Optional[TracebackType],
39+
) -> None:
40+
self.finish(success=exception_value is None)
41+
2942
def update(
3043
self,
3144
status: str,
@@ -43,11 +56,11 @@ def update(
4356
def finish(self, success: bool) -> None:
4457
if not success:
4558
self._tqdm_bar.set_postfix_str("status: FAILED", refresh=True)
46-
return
47-
48-
if self._tqdm_bar.total is not None:
49-
self._tqdm_bar.update(self._tqdm_bar.total - self._tqdm_bar.n)
50-
self._tqdm_bar.set_postfix_str("status: FINISHED", refresh=True)
59+
else:
60+
if self._tqdm_bar.total is not None:
61+
self._tqdm_bar.update(self._tqdm_bar.total - self._tqdm_bar.n)
62+
self._tqdm_bar.set_postfix_str("status: FINISHED", refresh=True)
63+
self._tqdm_bar.close()
5164

5265
@staticmethod
5366
def _relative_progress(task: TaskWithProgress) -> Optional[float]:

graphdatascience/query_runner/protocol/write_protocols.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from tenacity import retry, retry_if_result, wait_incrementing
77

88
from graphdatascience.call_parameters import CallParameters
9+
from graphdatascience.query_runner.progress.progress_bar import TqdmProgressBar
910
from graphdatascience.query_runner.protocol.status import Status
1011
from graphdatascience.query_runner.query_mode import QueryMode
1112
from graphdatascience.query_runner.query_runner import QueryRunner
@@ -33,6 +34,7 @@ def run_write_back(
3334
query_runner: QueryRunner,
3435
parameters: CallParameters,
3536
yields: Optional[list[str]],
37+
log_progress: bool,
3638
terminationFlag: TerminationFlag,
3739
) -> DataFrame:
3840
"""Executes the write-back procedure"""
@@ -68,6 +70,7 @@ def run_write_back(
6870
query_runner: QueryRunner,
6971
parameters: CallParameters,
7072
yields: Optional[list[str]],
73+
log_progress: bool,
7174
terminationFlag: TerminationFlag,
7275
) -> DataFrame:
7376
return query_runner.call_procedure(
@@ -108,6 +111,7 @@ def run_write_back(
108111
query_runner: QueryRunner,
109112
parameters: CallParameters,
110113
yields: Optional[list[str]],
114+
log_progress: bool,
111115
terminationFlag: TerminationFlag,
112116
) -> DataFrame:
113117
return query_runner.call_procedure(
@@ -123,6 +127,9 @@ def run_write_back(
123127

124128

125129
class RemoteWriteBackV3(WriteProtocol):
130+
def __init__(self, progress_bar_options: dict[str, Any] | None = None):
131+
self._progress_bar_options = progress_bar_options or {}
132+
126133
def write_back_params(
127134
self,
128135
graph_name: str,
@@ -138,6 +145,7 @@ def run_write_back(
138145
query_runner: QueryRunner,
139146
parameters: CallParameters,
140147
yields: Optional[list[str]],
148+
log_progress: bool,
141149
terminationFlag: TerminationFlag,
142150
) -> DataFrame:
143151
def is_not_completed(result: DataFrame) -> bool:
@@ -156,9 +164,9 @@ def is_not_completed(result: DataFrame) -> bool:
156164
logging.DEBUG,
157165
),
158166
)
159-
def write_fn() -> DataFrame:
167+
def write_fn(progress_bar: Optional[TqdmProgressBar]) -> DataFrame:
160168
terminationFlag.assert_running()
161-
return query_runner.call_procedure(
169+
result = query_runner.call_procedure(
162170
ProtocolVersion.V3.versioned_procedure_name("gds.arrow.write"),
163171
parameters,
164172
yields,
@@ -168,4 +176,17 @@ def write_fn() -> DataFrame:
168176
custom_error=False,
169177
)
170178

171-
return write_fn()
179+
if progress_bar:
180+
progress_bar.update(status=result.squeeze()["status"], progress=result.squeeze()["progress"] * 100)
181+
182+
return result
183+
184+
if log_progress:
185+
with TqdmProgressBar(
186+
task_name=f"Write-Back (graph: {parameters['graphName']})",
187+
relative_progress=0.0,
188+
bar_options=self._progress_bar_options,
189+
) as progress_bar:
190+
return write_fn(progress_bar)
191+
else:
192+
return write_fn(None)

graphdatascience/query_runner/session_query_runner.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,9 @@ def _remote_write_back(
226226
write_back_start = time.time()
227227

228228
def run_write_back() -> DataFrame:
229-
return write_protocol.run_write_back(self._db_query_runner, write_back_params, yields, terminationFlag)
229+
return write_protocol.run_write_back(
230+
self._db_query_runner, write_back_params, yields, log_progress=logging, terminationFlag=terminationFlag
231+
)
230232

231233
try:
232234
# Skipping progress for now as export has a different jobId

graphdatascience/tests/integrationV2/procedure_surface/arrow/test_pagerank_arrow_endpoints.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,9 @@ def test_pagerank_mutate(pagerank_endpoints: PageRankArrowEndpoints, sample_grap
9494
@pytest.mark.db_integration
9595
def test_pagerank_write(arrow_client: AuthenticatedArrowClient, query_runner: QueryRunner, db_graph: GraphV2) -> None:
9696
"""Test PageRank write operation."""
97-
endpoints = PageRankArrowEndpoints(arrow_client, RemoteWriteBackClient(arrow_client, query_runner))
97+
endpoints = PageRankArrowEndpoints(
98+
arrow_client, RemoteWriteBackClient(arrow_client, query_runner), show_progress=True
99+
)
98100
result = endpoints.write(G=db_graph, write_property="pagerank")
99101

100102
assert isinstance(result, PageRankWriteResult)

graphdatascience/tests/unit/conftest.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -169,11 +169,11 @@ def create_graph_constructor(
169169
def cloneWithoutRouting(self, host: str, port: int) -> QueryRunner:
170170
return self
171171

172-
def set__mock_result(self, result: DataFrame) -> None:
172+
def set__mock_result(self, result: QueryResult) -> None:
173173
self._result_map.clear()
174174
self._result_map[""] = result
175175

176-
def add__mock_result(self, query_sub_string: str, result: DataFrame) -> None:
176+
def add__mock_result(self, query_sub_string: str, result: QueryResult) -> None:
177177
self._result_map[query_sub_string] = result
178178

179179
def get_mock_result(self, query: str) -> QueryResult:
@@ -190,6 +190,9 @@ def get_mock_result(self, query: str) -> QueryResult:
190190
)
191191
if len(matched_results) == 0:
192192
return DataFrame()
193+
194+
if isinstance(matched_results[0][1], Exception):
195+
raise matched_results[0][1]
193196
return matched_results[0][1]
194197

195198

0 commit comments

Comments
 (0)