Skip to content

Commit 49b51cd

Browse files
Fail fast on polling callback contract violations
Co-authored-by: Shri Sukhani <shrisukhani@users.noreply.github.com>
1 parent 2af7dad commit 49b51cd

2 files changed

Lines changed: 80 additions & 25 deletions

File tree

hyperbrowser/client/polling.py

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@
1515
_MAX_OPERATION_NAME_LENGTH = 200
1616

1717

18+
class _NonRetryablePollingError(HyperbrowserError):
19+
pass
20+
21+
1822
def _validate_non_negative_real(value: float, *, field_name: str) -> None:
1923
if isinstance(value, bool) or not isinstance(value, Real):
2024
raise HyperbrowserError(f"{field_name} must be a number")
@@ -52,7 +56,7 @@ def _ensure_boolean_terminal_result(result: object, *, operation_name: str) -> b
5256
result, callback_name="is_terminal_status", operation_name=operation_name
5357
)
5458
if not isinstance(result, bool):
55-
raise HyperbrowserError(
59+
raise _NonRetryablePollingError(
5660
f"is_terminal_status must return a boolean for {operation_name}"
5761
)
5862
return result
@@ -63,15 +67,17 @@ def _ensure_status_string(status: object, *, operation_name: str) -> str:
6367
status, callback_name="get_status", operation_name=operation_name
6468
)
6569
if not isinstance(status, str):
66-
raise HyperbrowserError(f"get_status must return a string for {operation_name}")
70+
raise _NonRetryablePollingError(
71+
f"get_status must return a string for {operation_name}"
72+
)
6773
return status
6874

6975

7076
def _ensure_awaitable(
7177
result: object, *, callback_name: str, operation_name: str
7278
) -> Awaitable[object]:
7379
if not inspect.isawaitable(result):
74-
raise HyperbrowserError(
80+
raise _NonRetryablePollingError(
7581
f"{callback_name} must return an awaitable for {operation_name}"
7682
)
7783
return result
@@ -83,7 +89,7 @@ def _ensure_non_awaitable(
8389
if inspect.isawaitable(result):
8490
if inspect.iscoroutine(result):
8591
result.close()
86-
raise HyperbrowserError(
92+
raise _NonRetryablePollingError(
8793
f"{callback_name} must return a non-awaitable result for {operation_name}"
8894
)
8995

@@ -182,9 +188,6 @@ def poll_until_terminal_status(
182188
while True:
183189
try:
184190
status = get_status()
185-
_ensure_non_awaitable(
186-
status, callback_name="get_status", operation_name=operation_name
187-
)
188191
failures = 0
189192
except Exception as exc:
190193
failures += 1
@@ -227,19 +230,21 @@ def retry_operation(
227230
while True:
228231
try:
229232
operation_result = operation()
230-
_ensure_non_awaitable(
231-
operation_result,
232-
callback_name="operation",
233-
operation_name=operation_name,
234-
)
235-
return operation_result
236233
except Exception as exc:
237234
failures += 1
238235
if failures >= max_attempts:
239236
raise HyperbrowserError(
240237
f"{operation_name} failed after {max_attempts} attempts: {exc}"
241238
) from exc
242239
time.sleep(retry_delay_seconds)
240+
continue
241+
242+
_ensure_non_awaitable(
243+
operation_result,
244+
callback_name="operation",
245+
operation_name=operation_name,
246+
)
247+
return operation_result
243248

244249

245250
async def poll_until_terminal_status_async(
@@ -412,6 +417,8 @@ def collect_paginated_results(
412417
else:
413418
stagnation_failures = 0
414419
should_sleep = current_page_batch < total_page_batches
420+
except _NonRetryablePollingError:
421+
raise
415422
except HyperbrowserPollingError:
416423
raise
417424
except Exception as exc:
@@ -490,6 +497,8 @@ async def collect_paginated_results_async(
490497
else:
491498
stagnation_failures = 0
492499
should_sleep = current_page_batch < total_page_batches
500+
except _NonRetryablePollingError:
501+
raise
493502
except HyperbrowserPollingError:
494503
raise
495504
except Exception as exc:

tests/test_polling.py

Lines changed: 58 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -120,18 +120,26 @@ def test_poll_until_terminal_status_rejects_awaitable_status_callback_result():
120120
async def async_get_status() -> str:
121121
return "completed"
122122

123+
attempts = {"count": 0}
124+
125+
def get_status() -> object:
126+
attempts["count"] += 1
127+
return async_get_status()
128+
123129
with pytest.raises(
124130
HyperbrowserError, match="get_status must return a non-awaitable result"
125131
):
126132
poll_until_terminal_status(
127133
operation_name="sync poll awaitable callback",
128-
get_status=lambda: async_get_status(), # type: ignore[return-value]
134+
get_status=get_status, # type: ignore[arg-type]
129135
is_terminal_status=lambda value: value == "completed",
130136
poll_interval_seconds=0.0001,
131137
max_wait_seconds=1.0,
132-
max_status_failures=1,
138+
max_status_failures=5,
133139
)
134140

141+
assert attempts["count"] == 1
142+
135143

136144
def test_retry_operation_retries_and_returns_value():
137145
attempts = {"count": 0}
@@ -166,16 +174,24 @@ def test_retry_operation_rejects_awaitable_operation_result():
166174
async def async_operation() -> str:
167175
return "ok"
168176

177+
attempts = {"count": 0}
178+
179+
def operation() -> object:
180+
attempts["count"] += 1
181+
return async_operation()
182+
169183
with pytest.raises(
170184
HyperbrowserError, match="operation must return a non-awaitable result"
171185
):
172186
retry_operation(
173187
operation_name="sync retry awaitable callback",
174-
operation=lambda: async_operation(), # type: ignore[return-value]
175-
max_attempts=2,
188+
operation=operation, # type: ignore[arg-type]
189+
max_attempts=5,
176190
retry_delay_seconds=0.0001,
177191
)
178192

193+
assert attempts["count"] == 1
194+
179195

180196
def test_async_polling_and_retry_helpers():
181197
async def run() -> None:
@@ -327,22 +343,36 @@ def test_collect_paginated_results_rejects_awaitable_page_callback_result():
327343
async def async_get_page() -> dict:
328344
return {"current": 1, "total": 1, "items": []}
329345

346+
attempts = {"count": 0}
347+
348+
def get_next_page(page: int) -> object:
349+
attempts["count"] += 1
350+
return async_get_page()
351+
330352
with pytest.raises(
331353
HyperbrowserError, match="get_next_page must return a non-awaitable result"
332354
):
333355
collect_paginated_results(
334356
operation_name="sync paginated awaitable page callback",
335-
get_next_page=lambda page: async_get_page(), # type: ignore[return-value]
357+
get_next_page=get_next_page, # type: ignore[arg-type]
336358
get_current_page_batch=lambda response: response["current"],
337359
get_total_page_batches=lambda response: response["total"],
338360
on_page_success=lambda response: None,
339361
max_wait_seconds=1.0,
340-
max_attempts=2,
362+
max_attempts=5,
341363
retry_delay_seconds=0.0001,
342364
)
343365

366+
assert attempts["count"] == 1
367+
344368

345369
def test_collect_paginated_results_rejects_awaitable_on_page_success_result():
370+
callback_attempts = {"count": 0}
371+
372+
def on_page_success(response: dict) -> object:
373+
callback_attempts["count"] += 1
374+
return asyncio.sleep(0)
375+
346376
with pytest.raises(
347377
HyperbrowserError, match="on_page_success must return a non-awaitable result"
348378
):
@@ -351,12 +381,14 @@ def test_collect_paginated_results_rejects_awaitable_on_page_success_result():
351381
get_next_page=lambda page: {"current": 1, "total": 1, "items": []},
352382
get_current_page_batch=lambda response: response["current"],
353383
get_total_page_batches=lambda response: response["total"],
354-
on_page_success=lambda response: asyncio.sleep(0), # type: ignore[return-value]
384+
on_page_success=on_page_success, # type: ignore[arg-type]
355385
max_wait_seconds=1.0,
356-
max_attempts=2,
386+
max_attempts=5,
357387
retry_delay_seconds=0.0001,
358388
)
359389

390+
assert callback_attempts["count"] == 1
391+
360392

361393
def test_collect_paginated_results_allows_single_page_on_zero_max_wait():
362394
collected = []
@@ -401,25 +433,38 @@ async def run() -> None:
401433

402434
def test_collect_paginated_results_async_rejects_non_awaitable_page_callback_result():
403435
async def run() -> None:
436+
attempts = {"count": 0}
437+
438+
def get_next_page(page: int) -> object:
439+
attempts["count"] += 1
440+
return {"current": 1, "total": 1, "items": []}
441+
404442
with pytest.raises(
405443
HyperbrowserError, match="get_next_page must return an awaitable"
406444
):
407445
await collect_paginated_results_async(
408446
operation_name="async paginated awaitable validation",
409-
get_next_page=lambda page: {"current": 1, "total": 1, "items": []}, # type: ignore[return-value]
447+
get_next_page=get_next_page, # type: ignore[arg-type]
410448
get_current_page_batch=lambda response: response["current"],
411449
get_total_page_batches=lambda response: response["total"],
412450
on_page_success=lambda response: None,
413451
max_wait_seconds=1.0,
414-
max_attempts=2,
452+
max_attempts=5,
415453
retry_delay_seconds=0.0001,
416454
)
455+
assert attempts["count"] == 1
417456

418457
asyncio.run(run())
419458

420459

421460
def test_collect_paginated_results_async_rejects_awaitable_on_page_success_result():
422461
async def run() -> None:
462+
callback_attempts = {"count": 0}
463+
464+
def on_page_success(response: dict) -> object:
465+
callback_attempts["count"] += 1
466+
return asyncio.sleep(0)
467+
423468
with pytest.raises(
424469
HyperbrowserError,
425470
match="on_page_success must return a non-awaitable result",
@@ -431,11 +476,12 @@ async def run() -> None:
431476
),
432477
get_current_page_batch=lambda response: response["current"],
433478
get_total_page_batches=lambda response: response["total"],
434-
on_page_success=lambda response: asyncio.sleep(0), # type: ignore[return-value]
479+
on_page_success=on_page_success, # type: ignore[arg-type]
435480
max_wait_seconds=1.0,
436-
max_attempts=2,
481+
max_attempts=5,
437482
retry_delay_seconds=0.0001,
438483
)
484+
assert callback_attempts["count"] == 1
439485

440486
asyncio.run(run())
441487

0 commit comments

Comments
 (0)