Skip to content

Commit 6152a24

Browse files
committed
Add progress bar to job client
1 parent 6fa2a74 commit 6152a24

File tree

15 files changed

+266
-34
lines changed

15 files changed

+266
-34
lines changed

graphdatascience/arrow_client/v2/api_types.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,42 @@ class JobIdConfig(ArrowBaseModel):
55
job_id: str
66

77

8+
UNKNOWN_PROGRESS = -1
9+
10+
811
class JobStatus(ArrowBaseModel):
    """Status report of a single job.

    Besides the raw fields, the class offers small helpers to interpret the
    progress value, to split the task description into its base task and
    sub tasks, and to test for the "Aborted", "Done" and "Running" states.
    """

    job_id: str
    status: str
    # Equals UNKNOWN_PROGRESS when no progress is available; multiplied by 100
    # to obtain a percentage (presumably a fraction in [0, 1] -- TODO confirm).
    progress: float
    # Task description; "::" separates the base task from its sub tasks.
    description: str

    def progress_known(self) -> bool:
        """Return True unless ``progress`` equals the ``UNKNOWN_PROGRESS`` sentinel."""
        return self.progress != UNKNOWN_PROGRESS

    def progress_percent(self) -> float | None:
        """Return ``progress`` scaled by 100, or None when the progress is unknown."""
        return self.progress * 100 if self.progress_known() else None

    def base_task(self) -> str:
        """Return the stripped part of ``description`` before the first "::"."""
        head, _, _ = self.description.partition("::")
        return head.strip()

    def sub_tasks(self) -> str | None:
        """Return the stripped part of ``description`` after the first "::".

        Returns None when ``description`` contains no "::" separator.
        """
        _, separator, remainder = self.description.partition("::")
        return remainder.strip() if separator else None

    def aborted(self) -> bool:
        """Return True when ``status`` is exactly "Aborted"."""
        return self.status == "Aborted"

    def succeeded(self) -> bool:
        """Return True when ``status`` is exactly "Done"."""
        return self.status == "Done"

    def running(self) -> bool:
        """Return True when ``status`` is exactly "Running"."""
        return self.status == "Running"
1244

1345

1446
class MutateResult(ArrowBaseModel):

graphdatascience/arrow_client/v2/job_client.py

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,28 @@
11
import json
2-
from typing import Any
2+
from typing import Any, Optional
33

44
from pandas import ArrowDtype, DataFrame
55
from pyarrow._flight import Ticket
66

77
from graphdatascience.arrow_client.authenticated_flight_client import AuthenticatedArrowClient
88
from graphdatascience.arrow_client.v2.api_types import JobIdConfig, JobStatus
99
from graphdatascience.arrow_client.v2.data_mapper_utils import deserialize_single
10+
from graphdatascience.query_runner.progress.progress_bar import TqdmProgressBar
1011

1112
JOB_STATUS_ENDPOINT = "v2/jobs.status"
1213
RESULTS_SUMMARY_ENDPOINT = "v2/results.summary"
1314

1415

1516
class JobClient:
17+
def __init__(self, progress_bar_options: dict[str, Any] | None = None):
18+
self._progress_bar_options = progress_bar_options or {}
19+
1620
@staticmethod
def run_job_and_wait(
    client: AuthenticatedArrowClient, endpoint: str, config: dict[str, Any], show_progress: bool
) -> str:
    """Start a job on ``endpoint`` and block until it has finished.

    :param client: Arrow client used to start the job and poll its status.
    :param endpoint: action endpoint the job is started on.
    :param config: job configuration handed to ``run_job``.
    :param show_progress: forwarded to ``wait_for_job``.
    :return: the id of the started job.
    """
    started_job_id = JobClient.run_job(client, endpoint, config)
    # A fresh JobClient is used for waiting, i.e. with default (empty) progress bar options.
    waiter = JobClient()
    waiter.wait_for_job(client, started_job_id, show_progress=show_progress)
    return started_job_id
2127

2228
@staticmethod
@@ -26,14 +32,30 @@ def run_job(client: AuthenticatedArrowClient, endpoint: str, config: dict[str, A
2632
single = deserialize_single(res)
2733
return JobIdConfig(**single).job_id
2834

29-
@staticmethod
30-
def wait_for_job(client: AuthenticatedArrowClient, job_id: str) -> None:
35+
def wait_for_job(self, client: AuthenticatedArrowClient, job_id: str, show_progress: bool) -> None:
    """Poll the job status endpoint until the job is done or aborted.

    While the job is still running and ``show_progress`` is set, a progress bar is
    created lazily (as soon as the status carries a non-empty base task) and updated
    on every poll.

    :param client: Arrow client used to query ``JOB_STATUS_ENDPOINT``.
    :param job_id: id of the job to wait for.
    :param show_progress: whether to render a progress bar while waiting.
    """
    progress_bar: Optional[TqdmProgressBar] = None
    try:
        while True:
            arrow_res = client.do_action_with_retry(JOB_STATUS_ENDPOINT, JobIdConfig(jobId=job_id).dump_camel())
            job_status = JobStatus(**deserialize_single(arrow_res))

            if job_status.succeeded() or job_status.aborted():
                # NOTE(review): an aborted job returns normally, just like a successful one;
                # confirm that callers do not need an error here.
                if progress_bar is not None:
                    # Clear the reference first so the ``finally`` clause cannot finish the bar twice.
                    finished_bar, progress_bar = progress_bar, None
                    finished_bar.finish(success=job_status.succeeded())
                break

            if not show_progress:
                continue

            if progress_bar is None:
                base_task = job_status.base_task()
                if base_task:
                    progress_bar = TqdmProgressBar(
                        task_name=base_task,
                        relative_progress=job_status.progress_percent(),
                        bar_options=self._progress_bar_options,
                    )

            if progress_bar is not None:
                progress_bar.update(job_status.status, job_status.progress_percent(), job_status.sub_tasks())
    finally:
        # Only reached with a live bar when polling raised (including KeyboardInterrupt):
        # finish the bar as failed instead of leaving it dangling.
        if progress_bar is not None:
            progress_bar.finish(success=False)
58+
3759
@staticmethod
3860
def get_summary(client: AuthenticatedArrowClient, job_id: str) -> dict[str, Any]:
3961
res = client.do_action_with_retry(RESULTS_SUMMARY_ENDPOINT, JobIdConfig(jobId=job_id).dump_camel())

graphdatascience/procedure_surface/arrow/catalog/node_label_arrow_endpoints.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ def __init__(
1919
):
2020
self._arrow_client = arrow_client
2121
self._node_property_endpoints = NodePropertyEndpoints(arrow_client, write_back_client)
22+
self._show_progress = False # TODO add option to show progress
2223

2324
def mutate(
2425
self,
@@ -45,7 +46,10 @@ def mutate(
4546
job_id=job_id,
4647
)
4748

48-
job_id = JobClient.run_job_and_wait(self._arrow_client, "v2/graph.nodeLabel.mutate", config)
49+
show_progress = self._show_progress and log_progress if log_progress is not None else self._show_progress
50+
job_id = JobClient.run_job_and_wait(
51+
self._arrow_client, "v2/graph.nodeLabel.mutate", config, show_progress=show_progress
52+
)
4953
return NodeLabelMutateResult(**JobClient.get_summary(self._arrow_client, job_id))
5054

5155
def write(

graphdatascience/procedure_surface/arrow/catalog/relationship_arrow_endpoints.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,15 @@
1919

2020

2121
class RelationshipArrowEndpoints(RelationshipsEndpoints):
22-
def __init__(self, arrow_client: AuthenticatedArrowClient, write_back_client: Optional[RemoteWriteBackClient]):
22+
def __init__(
23+
self,
24+
arrow_client: AuthenticatedArrowClient,
25+
write_back_client: Optional[RemoteWriteBackClient],
26+
show_progress: bool = False,
27+
):
2328
self._arrow_client = arrow_client
2429
self._write_back_client = write_back_client
30+
self._show_progress = show_progress
2531

2632
def stream(
2733
self,
@@ -154,7 +160,10 @@ def index_inverse(
154160
job_id=job_id,
155161
)
156162

157-
job_id = JobClient.run_job_and_wait(self._arrow_client, "v2/graph.relationships.indexInverse", config)
163+
show_progress = self._show_progress and log_progress if log_progress is not None else self._show_progress
164+
job_id = JobClient.run_job_and_wait(
165+
self._arrow_client, "v2/graph.relationships.indexInverse", config, show_progress=show_progress
166+
)
158167
result = JobClient.get_summary(self._arrow_client, job_id)
159168
return RelationshipsInverseIndexResult(**result)
160169

@@ -182,7 +191,10 @@ def to_undirected(
182191
username=username,
183192
job_id=job_id,
184193
)
194+
show_progress = self._show_progress and log_progress if log_progress is not None else self._show_progress
185195

186-
job_id = JobClient.run_job_and_wait(self._arrow_client, "v2/graph.relationships.toUndirected", config)
196+
job_id = JobClient.run_job_and_wait(
197+
self._arrow_client, "v2/graph.relationships.toUndirected", config, show_progress=show_progress
198+
)
187199
result = JobClient.get_summary(self._arrow_client, job_id)
188200
return RelationshipsToUndirectedResult(**result)

graphdatascience/procedure_surface/arrow/catalog_arrow_endpoints.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,16 @@
3737
class CatalogArrowEndpoints(CatalogEndpoints):
3838
GDS_REMOTE_PROJECTION_PROC_NAME = "gds.arrow.project"
3939

40-
def __init__(self, arrow_client: AuthenticatedArrowClient, query_runner: Optional[QueryRunner] = None):
40+
def __init__(
41+
self,
42+
arrow_client: AuthenticatedArrowClient,
43+
query_runner: Optional[QueryRunner] = None,
44+
show_progress: bool = False,
45+
):
4146
self._arrow_client = arrow_client
4247
self._query_runner = query_runner
4348
self._graph_backend = GraphOpsArrow(arrow_client)
49+
self._show_progress = show_progress
4450
if query_runner is not None:
4551
protocol_version = ProtocolVersionResolver(query_runner).resolve()
4652
self._project_protocol = ProjectProtocol.select(protocol_version)
@@ -152,7 +158,9 @@ def filter(
152158
job_id=job_id,
153159
)
154160

155-
job_id = JobClient.run_job_and_wait(self._arrow_client, "v2/graph.project.filter", config)
161+
job_id = JobClient.run_job_and_wait(
162+
self._arrow_client, "v2/graph.project.filter", config, show_progress=self._show_progress
163+
)
156164

157165
return GraphWithFilterResult(
158166
get_graph(graph_name, self._arrow_client),
@@ -192,7 +200,10 @@ def generate(
192200
username=username,
193201
)
194202

195-
job_id = JobClient.run_job_and_wait(self._arrow_client, "v2/graph.generate", config)
203+
show_progress = self._show_progress and log_progress if log_progress is not None else self._show_progress
204+
job_id = JobClient.run_job_and_wait(
205+
self._arrow_client, "v2/graph.generate", config, show_progress=show_progress
206+
)
196207

197208
return GraphWithGenerationStats(
198209
get_graph(graph_name, self._arrow_client),
@@ -201,7 +212,7 @@ def generate(
201212

202213
@property
203214
def sample(self) -> GraphSamplingEndpoints:
204-
return GraphSamplingArrowEndpoints(self._arrow_client)
215+
return GraphSamplingArrowEndpoints(self._arrow_client, show_progress=self._show_progress)
205216

206217
@property
207218
def node_labels(self) -> NodeLabelArrowEndpoints:
@@ -218,6 +229,7 @@ def relationships(self) -> RelationshipArrowEndpoints:
218229
return RelationshipArrowEndpoints(
219230
self._arrow_client,
220231
RemoteWriteBackClient(self._arrow_client, self._query_runner) if self._query_runner else None,
232+
show_progress=self._show_progress,
221233
)
222234

223235
def _arrow_config(self) -> dict[str, Any]:

graphdatascience/procedure_surface/arrow/graph_sampling_arrow_endpoints.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,9 @@
1515

1616

1717
class GraphSamplingArrowEndpoints(GraphSamplingEndpoints):
18-
def __init__(self, arrow_client: AuthenticatedArrowClient):
18+
def __init__(self, arrow_client: AuthenticatedArrowClient, show_progress: bool = False):
1919
self._arrow_client = arrow_client
20+
self._show_progress = show_progress
2021

2122
def rwr(
2223
self,
@@ -52,7 +53,10 @@ def rwr(
5253
job_id=job_id,
5354
)
5455

55-
job_id = JobClient.run_job_and_wait(self._arrow_client, "v2/graph.sample.rwr", config)
56+
show_progress = self._show_progress and log_progress if log_progress is not None else self._show_progress
57+
job_id = JobClient.run_job_and_wait(
58+
self._arrow_client, "v2/graph.sample.rwr", config, show_progress=show_progress
59+
)
5660

5761
return GraphWithSamplingResult(
5862
get_graph(graph_name, self._arrow_client),
@@ -93,7 +97,10 @@ def cnarw(
9397
job_id=job_id,
9498
)
9599

96-
job_id = JobClient.run_job_and_wait(self._arrow_client, "v2/graph.sample.cnarw", config)
100+
show_progress = self._show_progress and log_progress if log_progress is not None else self._show_progress
101+
job_id = JobClient.run_job_and_wait(
102+
self._arrow_client, "v2/graph.sample.cnarw", config, show_progress=show_progress
103+
)
97104

98105
return GraphWithSamplingResult(
99106
get_graph(graph_name, self._arrow_client),

graphdatascience/procedure_surface/arrow/node_property_endpoints.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,21 +20,26 @@ class NodePropertyEndpoints:
2020
"""
2121

2222
def __init__(
23-
self, arrow_client: AuthenticatedArrowClient, write_back_client: Optional[RemoteWriteBackClient] = None
23+
self,
24+
arrow_client: AuthenticatedArrowClient,
25+
write_back_client: Optional[RemoteWriteBackClient] = None,
26+
show_progress: bool = True,
2427
):
2528
self._arrow_client = arrow_client
2629
self._write_back_client = write_back_client
30+
self._show_progress = show_progress
2731

2832
def run_job_and_get_summary(self, endpoint: str, G: GraphV2, config: Dict[str, Any]) -> Dict[str, Any]:
2933
"""Run a job and return the computation summary."""
30-
job_id = JobClient.run_job_and_wait(self._arrow_client, endpoint, config)
34+
job_id = JobClient.run_job_and_wait(self._arrow_client, endpoint, config, self._show_progress)
3135
return JobClient.get_summary(self._arrow_client, job_id)
3236

37+
# TODO expose show_progress option to endpoints
3338
def run_job_and_mutate(
3439
self, endpoint: str, G: GraphV2, config: Dict[str, Any], mutate_property: str
3540
) -> Dict[str, Any]:
3641
"""Run a job, mutate node properties, and return summary with mutation result."""
37-
job_id = JobClient.run_job_and_wait(self._arrow_client, endpoint, config)
42+
job_id = JobClient.run_job_and_wait(self._arrow_client, endpoint, config, self._show_progress)
3843
mutate_result = MutationClient.mutate_node_property(self._arrow_client, job_id, mutate_property)
3944
computation_result = JobClient.get_summary(self._arrow_client, job_id)
4045

@@ -51,11 +56,13 @@ def run_job_and_mutate(
5156

5257
return computation_result
5358

59+
# TODO expose show_progress option to endpoints
5460
def run_job_and_stream(self, endpoint: str, G: GraphV2, config: Dict[str, Any]) -> DataFrame:
5561
"""Run a job and return streamed results."""
56-
job_id = JobClient.run_job_and_wait(self._arrow_client, endpoint, config)
62+
job_id = JobClient.run_job_and_wait(self._arrow_client, endpoint, config, show_progress=self._show_progress)
5763
return JobClient.stream_results(self._arrow_client, G.name(), job_id)
5864

65+
# TODO expose show_progress option to endpoints
5966
def run_job_and_write(
6067
self,
6168
endpoint: str,
@@ -66,7 +73,7 @@ def run_job_and_write(
6673
property_overwrites: Optional[Union[str, dict[str, str]]] = None,
6774
) -> Dict[str, Any]:
6875
"""Run a job, write results, and return summary with write time."""
69-
job_id = JobClient.run_job_and_wait(self._arrow_client, endpoint, config)
76+
job_id = JobClient.run_job_and_wait(self._arrow_client, endpoint, config, show_progress=self._show_progress)
7077
computation_result = JobClient.get_summary(self._arrow_client, job_id)
7178

7279
if self._write_back_client is None:

graphdatascience/procedure_surface/arrow/pagerank_arrow_endpoints.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,14 @@
1818

1919
class PageRankArrowEndpoints(PageRankEndpoints):
2020
def __init__(
21-
self, arrow_client: AuthenticatedArrowClient, write_back_client: Optional[RemoteWriteBackClient] = None
21+
self,
22+
arrow_client: AuthenticatedArrowClient,
23+
write_back_client: Optional[RemoteWriteBackClient] = None,
24+
show_progress: bool = False,
2225
):
23-
self._node_property_endpoints = NodePropertyEndpoints(arrow_client, write_back_client)
26+
self._node_property_endpoints = NodePropertyEndpoints(
27+
arrow_client, write_back_client, show_progress=show_progress
28+
)
2429

2530
def mutate(
2631
self,

graphdatascience/query_runner/progress/progress_bar.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77

88
class TqdmProgressBar:
9+
# TODO: add a helper method for creating a progress bar for tests with observable progress
910
def __init__(self, task_name: str, relative_progress: Optional[float], bar_options: dict[str, Any] = {}):
1011
root_task_name = task_name
1112
if relative_progress is None: # Qualitative progress report
@@ -21,6 +22,7 @@ def __init__(self, task_name: str, relative_progress: Optional[float], bar_optio
2122
total=100,
2223
unit="%",
2324
desc=root_task_name,
25+
initial=relative_progress,
2426
**bar_options,
2527
)
2628

graphdatascience/session/aura_graph_data_science.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,15 +96,17 @@ def create(
9696
query_runner=session_query_runner,
9797
delete_fn=delete_fn,
9898
gds_version=gds_version,
99-
v2_endpoints=SessionV2Endpoints(session_auth_arrow_client, db_bolt_query_runner),
99+
v2_endpoints=SessionV2Endpoints(
100+
session_auth_arrow_client, db_bolt_query_runner, show_progress=show_progress
101+
),
100102
)
101103
else:
102104
standalone_query_runner = StandaloneSessionQueryRunner(session_arrow_query_runner)
103105
return cls(
104106
query_runner=standalone_query_runner,
105107
delete_fn=delete_fn,
106108
gds_version=gds_version,
107-
v2_endpoints=SessionV2Endpoints(session_auth_arrow_client, None),
109+
v2_endpoints=SessionV2Endpoints(session_auth_arrow_client, None, show_progress=show_progress),
108110
)
109111

110112
def __init__(

0 commit comments

Comments
 (0)