Skip to content

Commit 1db6831

Browse files
committed
Implement graph_construct using arrow v2
* bringing terminationFlag to GdsArrowClient (V2) to interrupt upload * job client also supports waiting for a given status
1 parent 5e359d6 commit 1db6831

File tree

12 files changed

+369
-29
lines changed

12 files changed

+369
-29
lines changed

graphdatascience/tests/unit/procedure_surface/arrow/__init__.py

Whitespace-only changes.
Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
from contextlib import ExitStack
2+
from unittest import mock
3+
4+
from pandas import DataFrame
5+
from pytest_mock import MockerFixture
6+
7+
from graphdatascience.arrow_client.authenticated_flight_client import AuthenticatedArrowClient
8+
from graphdatascience.arrow_client.v2.api_types import JobStatus
9+
from graphdatascience.procedure_surface.arrow.catalog.catalog_arrow_endpoints import CatalogArrowEndpoints
10+
from graphdatascience.tests.unit.arrow_client.arrow_test_utils import ArrowTestResult
11+
12+
13+
def test_construct_with_no_rels(mocker: MockerFixture) -> None:
    """Constructing a graph from a single node frame and no relationships succeeds."""
    arrow_client = mocker.Mock(spec=AuthenticatedArrowClient)
    job_id = "job-123"

    # Two polled statuses: first the server reports it is ready for relationships,
    # then that the construction job has finished.
    rel_phase = JobStatus(
        jobId=job_id,
        status="RELATIONSHIP_LOADING",
        progress=-1,
        description="",
    )
    finished = JobStatus(
        jobId=job_id,
        status="Done",
        progress=-1,
        description="",
    )

    arrow_client.do_action_with_retry = mocker.Mock(
        side_effect=[
            iter([ArrowTestResult(rel_phase.dump_camel())]),
            iter([ArrowTestResult(finished.dump_camel())]),
        ]
    )

    endpoints = CatalogArrowEndpoints(arrow_client=arrow_client)

    node_frame = DataFrame(
        {
            "nodeId": [0, 1],
            "labels": [["A"], ["B"]],
            "propA": [1337, 42.1],
        }
    )

    with patch_gds_arrow_client(job_id):
        G = endpoints.construct(graph_name="g", nodes=node_frame, relationships=[])

    assert G.name() == "g"
51+
52+
def test_construct_with_df_lists(mocker: MockerFixture) -> None:
    """Constructing a graph from lists of node and relationship frames succeeds."""
    arrow_client = mocker.Mock(spec=AuthenticatedArrowClient)
    job_id = "foo"

    # Server-side status sequence: relationship-loading phase, then completion.
    rel_phase = JobStatus(
        jobId=job_id,
        status="RELATIONSHIP_LOADING",
        progress=-1,
        description="",
    )
    finished = JobStatus(
        jobId=job_id,
        status="Done",
        progress=-1,
        description="",
    )

    arrow_client.do_action_with_retry = mocker.Mock(
        side_effect=[
            iter([ArrowTestResult(rel_phase.dump_camel())]),
            iter([ArrowTestResult(finished.dump_camel())]),
        ]
    )

    endpoints = CatalogArrowEndpoints(arrow_client=arrow_client)

    node_frames = [
        DataFrame({"nodeId": [0, 1], "labels": ["a", "a"], "property": [6.0, 7.0]}),
        DataFrame({"nodeId": [2, 3], "labels": ["b", "b"], "q": [-500, -400]}),
    ]
    rel_frames = [
        DataFrame(
            {"sourceNodeId": [0, 1], "targetNodeId": [1, 2], "relationshipType": ["A", "A"], "weights": [0.2, 0.3]}
        ),
        DataFrame({"sourceNodeId": [2, 3], "targetNodeId": [3, 0], "relationshipType": ["B", "B"]}),
    ]

    with patch_gds_arrow_client(job_id):
        G = endpoints.construct(graph_name="g", nodes=node_frames, relationships=rel_frames)

    assert G.name() == "g"
91+
92+
def patch_gds_arrow_client(create_graph_job_id: str) -> ExitStack:
    """Patch every GdsArrowClient method invoked during graph construction.

    ``create_graph`` is patched to return *create_graph_job_id*; the upload and
    load-done methods are patched to no-ops so no Arrow traffic happens.

    Returns an ``ExitStack`` holding all active patches; exiting it (e.g. via
    ``with``) restores the real methods.
    """
    client_path = "graphdatascience.arrow_client.v2.gds_arrow_client.GdsArrowClient"

    # Enter patches inside a managed stack so that, if any enter_context call
    # fails partway through, the already-applied patches are unwound instead of
    # leaking into subsequent tests. pop_all() transfers ownership to the caller.
    with ExitStack() as stack:
        stack.enter_context(
            mock.patch(f"{client_path}.create_graph", return_value=create_graph_job_id)
        )
        for method in (
            "upload_nodes",
            "upload_relationships",
            "node_load_done",
            "relationship_load_done",
        ):
            stack.enter_context(mock.patch(f"{client_path}.{method}", return_value=None))

        return stack.pop_all()

src/graphdatascience/arrow_client/v2/gds_arrow_client.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from graphdatascience.arrow_client.arrow_endpoint_version import ArrowEndpointVersion
1313
from graphdatascience.arrow_client.authenticated_flight_client import AuthenticatedArrowClient, ConnectionInfo
14+
from graphdatascience.query_runner.termination_flag import TerminationFlag
1415

1516
from ...procedure_surface.api.default_values import ALL_TYPES
1617
from ...procedure_surface.utils.config_converter import ConfigConverter
@@ -328,6 +329,7 @@ def upload_nodes(
328329
data: pyarrow.Table | list[pyarrow.RecordBatch] | pandas.DataFrame,
329330
batch_size: int = 10000,
330331
progress_callback: Callable[[int], None] = lambda x: None,
332+
termination_flag: TerminationFlag | None = None,
331333
) -> None:
332334
"""
333335
Uploads node data to the server for a given job.
@@ -342,15 +344,20 @@ def upload_nodes(
342344
The number of rows per batch
343345
progress_callback
344346
A callback function that is called with the number of rows uploaded after each batch
347+
termination_flag
348+
A termination flag to cancel the upload if requested
345349
"""
346-
self._upload_data("graph.project.fromTables.nodes", job_id, data, batch_size, progress_callback)
350+
self._upload_data(
351+
"graph.project.fromTables.nodes", job_id, data, batch_size, progress_callback, termination_flag
352+
)
347353

348354
def upload_relationships(
349355
self,
350356
job_id: str,
351357
data: pyarrow.Table | list[pyarrow.RecordBatch] | pandas.DataFrame,
352358
batch_size: int = 10000,
353359
progress_callback: Callable[[int], None] = lambda x: None,
360+
termination_flag: TerminationFlag | None = None,
354361
) -> None:
355362
"""
356363
Uploads relationship data to the server for a given job.
@@ -365,15 +372,20 @@ def upload_relationships(
365372
The number of rows per batch
366373
progress_callback
367374
A callback function that is called with the number of rows uploaded after each batch
375+
termination_flag
376+
A termination flag to cancel the upload if requested
368377
"""
369-
self._upload_data("graph.project.fromTables.relationships", job_id, data, batch_size, progress_callback)
378+
self._upload_data(
379+
"graph.project.fromTables.relationships", job_id, data, batch_size, progress_callback, termination_flag
380+
)
370381

371382
def upload_triplets(
372383
self,
373384
job_id: str,
374385
data: pyarrow.Table | list[pyarrow.RecordBatch] | pandas.DataFrame,
375386
batch_size: int = 10000,
376387
progress_callback: Callable[[int], None] = lambda x: None,
388+
termination_flag: TerminationFlag | None = None,
377389
) -> None:
378390
"""
379391
Uploads triplet data to the server for a given job.
@@ -388,8 +400,10 @@ def upload_triplets(
388400
The number of rows per batch
389401
progress_callback
390402
A callback function that is called with the number of rows uploaded after each batch
403+
termination_flag
404+
A termination flag to cancel the upload if requested
391405
"""
392-
self._upload_data("graph.project.fromTriplets", job_id, data, batch_size, progress_callback)
406+
self._upload_data("graph.project.fromTriplets", job_id, data, batch_size, progress_callback, termination_flag)
393407

394408
def abort_job(self, job_id: str) -> None:
395409
"""
@@ -464,6 +478,7 @@ def _upload_data(
464478
data: pyarrow.Table | list[pyarrow.RecordBatch] | pandas.DataFrame,
465479
batch_size: int = 10000,
466480
progress_callback: Callable[[int], None] = lambda x: None,
481+
termination_flag: TerminationFlag | None = None,
467482
) -> None:
468483
match data:
469484
case pyarrow.Table():
@@ -490,6 +505,10 @@ def upload_batch(p: RecordBatch) -> None:
490505

491506
with put_stream:
492507
for partition in batches:
508+
if termination_flag is not None and termination_flag.is_set():
509+
self.abort_job(job_id) # closing the put_stream will raise an error
510+
break
511+
493512
upload_batch(partition)
494513
ack_stream.read()
495514
progress_callback(partition.num_rows)

src/graphdatascience/arrow_client/v2/job_client.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,20 +39,26 @@ def wait_for_job(
3939
client: AuthenticatedArrowClient,
4040
job_id: str,
4141
show_progress: bool,
42+
expected_status: str | None = None,
4243
termination_flag: TerminationFlag | None = None,
4344
) -> None:
4445
progress_bar: TqdmProgressBar | None = None
4546

47+
def check_expected_status(status: JobStatus) -> bool:
48+
return job_status.succeeded() if expected_status is None else status.status == expected_status
49+
4650
if termination_flag is None:
4751
termination_flag = TerminationFlag.create()
4852

49-
for attempt in Retrying(retry=retry_if_result(lambda _: True), wait=wait_exponential(min=0.1, max=5)):
53+
for attempt in Retrying(
54+
retry=retry_if_result(lambda _: True), wait=wait_exponential(min=0.1, max=5), reraise=True
55+
):
5056
with attempt:
5157
termination_flag.assert_running()
5258

5359
job_status = self.get_job_status(client, job_id)
5460

55-
if job_status.succeeded() or job_status.aborted():
61+
if check_expected_status(job_status) or job_status.aborted():
5662
if progress_bar:
5763
progress_bar.finish(success=job_status.succeeded())
5864
return

src/graphdatascience/procedure_surface/api/catalog/catalog_endpoints.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def construct(
2222
graph_name: str,
2323
nodes: DataFrame | list[DataFrame],
2424
relationships: DataFrame | list[DataFrame] | None = None,
25-
concurrency: int = 4,
25+
concurrency: int | None = None,
2626
undirected_relationship_types: list[str] | None = None,
2727
) -> GraphV2:
2828
"""Construct a graph from a list of node and relationship dataframes.

src/graphdatascience/procedure_surface/arrow/catalog/catalog_arrow_endpoints.py

Lines changed: 67 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from pandas import DataFrame
99

1010
from graphdatascience.arrow_client.authenticated_flight_client import AuthenticatedArrowClient
11+
from graphdatascience.arrow_client.v2.gds_arrow_client import GdsArrowClient
1112
from graphdatascience.arrow_client.v2.job_client import JobClient
1213
from graphdatascience.arrow_client.v2.remote_write_back_client import RemoteWriteBackClient
1314
from graphdatascience.procedure_surface.api.base_result import BaseResult
@@ -31,6 +32,7 @@
3132
)
3233
from graphdatascience.procedure_surface.arrow.catalog.relationship_arrow_endpoints import RelationshipArrowEndpoints
3334
from graphdatascience.procedure_surface.utils.config_converter import ConfigConverter
35+
from graphdatascience.query_runner.progress.progress_bar import NoOpProgressBar, ProgressBar, TqdmProgressBar
3436
from graphdatascience.query_runner.protocol.project_protocols import ProjectProtocol
3537
from graphdatascience.query_runner.query_runner import QueryRunner
3638
from graphdatascience.query_runner.termination_flag import TerminationFlag
@@ -135,10 +137,73 @@ def construct(
135137
graph_name: str,
136138
nodes: DataFrame | list[DataFrame],
137139
relationships: DataFrame | list[DataFrame] | None = None,
138-
concurrency: int = 4,
140+
concurrency: int | None = None,
139141
undirected_relationship_types: list[str] | None = None,
140142
) -> GraphV2:
141-
raise NotImplementedError("Graph construction is not yet supported via V2 endpoints.")
143+
gds_arrow_client = GdsArrowClient(self._arrow_client)
144+
job_client = JobClient()
145+
termination_flag = TerminationFlag.create()
146+
147+
if self._show_progress:
148+
progress_bar: ProgressBar = TqdmProgressBar(task_name="Constructing graph", relative_progress=0.0)
149+
else:
150+
progress_bar = NoOpProgressBar()
151+
152+
with progress_bar:
153+
create_job_id: str = gds_arrow_client.create_graph(
154+
graph_name=graph_name,
155+
undirected_relationship_types=undirected_relationship_types or [],
156+
concurrency=concurrency,
157+
)
158+
node_count = nodes.shape[0] if isinstance(nodes, DataFrame) else sum(df.shape[0] for df in nodes)
159+
if isinstance(relationships, DataFrame):
160+
rel_count = relationships.shape[0]
161+
elif relationships is None:
162+
rel_count = 0
163+
relationships = []
164+
else:
165+
rel_count = sum(df.shape[0] for df in relationships)
166+
total_count = node_count + rel_count
167+
168+
gds_arrow_client.upload_nodes(
169+
create_job_id,
170+
nodes,
171+
progress_callback=lambda rows_imported: progress_bar.update(
172+
sub_tasks_description="Uploading nodes", progress=rows_imported / total_count, status="Running"
173+
),
174+
termination_flag=termination_flag,
175+
)
176+
177+
gds_arrow_client.node_load_done(create_job_id)
178+
179+
# skipping progress bar here as we have our own for the overall process
180+
job_client.wait_for_job(
181+
self._arrow_client,
182+
create_job_id,
183+
expected_status="RELATIONSHIP_LOADING",
184+
termination_flag=termination_flag,
185+
show_progress=False,
186+
)
187+
188+
if rel_count > 0:
189+
gds_arrow_client.upload_relationships(
190+
create_job_id,
191+
relationships,
192+
progress_callback=lambda rows_imported: progress_bar.update(
193+
sub_tasks_description="Uploading relationships",
194+
progress=rows_imported / total_count,
195+
status="Running",
196+
),
197+
termination_flag=termination_flag,
198+
)
199+
200+
gds_arrow_client.relationship_load_done(create_job_id)
201+
202+
# will produce a second progress bar to show graph construction on the server side
203+
job_client.wait_for_job(
204+
self._arrow_client, create_job_id, termination_flag=termination_flag, show_progress=True
205+
)
206+
return get_graph(graph_name, self._arrow_client)
142207

143208
def drop(self, G: GraphV2 | str, fail_if_missing: bool = True) -> GraphInfo | None:
144209
graph_name = G.name() if isinstance(G, GraphV2) else G

0 commit comments

Comments
 (0)