Skip to content

Commit aa337e2

Browse files
committed
Add safeguard to raise exception on termination
CI showed the closing the put-stream does not always throw an exception
1 parent 0f61e16 commit aa337e2

File tree

2 files changed

+5
-7
lines changed

2 files changed

+5
-7
lines changed

src/graphdatascience/arrow_client/v2/gds_arrow_client.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -506,8 +506,9 @@ def upload_batch(p: RecordBatch) -> None:
506506
with put_stream:
507507
for partition in batches:
508508
if termination_flag is not None and termination_flag.is_set():
509-
self.abort_job(job_id) # closing the put_stream will raise an error
510-
break
509+
self.abort_job(job_id)
510+
# closing the put_stream should raise an error. this is a safeguard to always signal the termination to the user.
511+
raise RuntimeError(f"Upload for job '{job_id}' was aborted via termination flag.")
511512

512513
upload_batch(partition)
513514
ack_stream.read()

tests/integrationV2/arrow_client/v2/test_gds_arrow_client_v2.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import numpy as np
55
import pandas as pd
66
import pytest
7-
from pyarrow.flight import FlightCancelledError
87
from testcontainers.core.network import Network
98

109
from graphdatascience.arrow_client.authenticated_flight_client import AuthenticatedArrowClient
@@ -107,9 +106,7 @@ def test_project_from_triplets(arrow_client: AuthenticatedArrowClient, gds_arrow
107106
assert G.name() == graph_name
108107

109108

110-
def test_project_from_triplets_interrupted(
111-
arrow_client: AuthenticatedArrowClient, gds_arrow_client: GdsArrowClient
112-
) -> None:
109+
def test_project_from_triplets_interrupted(gds_arrow_client: GdsArrowClient) -> None:
113110
df = pd.DataFrame(
114111
{"sourceNode": np.array([1, 2, 3], dtype=np.int64), "targetNode": np.array([4, 5, 6], dtype=np.int64)}
115112
)
@@ -118,7 +115,7 @@ def test_project_from_triplets_interrupted(
118115
termination_flag.set()
119116

120117
job_id = gds_arrow_client.create_graph_from_triplets("triplets")
121-
with pytest.raises(FlightCancelledError, match=".*Arrow process 'triplets' was aborted.*"):
118+
with pytest.raises(Exception, match=".*was aborted.*"):
122119
gds_arrow_client.upload_triplets(job_id, df, termination_flag=termination_flag)
123120

124121

0 commit comments

Comments
 (0)