Skip to content

Commit bba779b

Browse files
Harden page-batch callback contract validation
Co-authored-by: Shri Sukhani <shrisukhani@users.noreply.github.com>
1 parent 49b51cd commit bba779b

File tree

2 files changed

+98
-0
lines changed

2 files changed

+98
-0
lines changed

hyperbrowser/client/polling.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -397,7 +397,17 @@ def collect_paginated_results(
397397
operation_name=operation_name,
398398
)
399399
current_page_batch = get_current_page_batch(page_response)
400+
_ensure_non_awaitable(
401+
current_page_batch,
402+
callback_name="get_current_page_batch",
403+
operation_name=operation_name,
404+
)
400405
total_page_batches = get_total_page_batches(page_response)
406+
_ensure_non_awaitable(
407+
total_page_batches,
408+
callback_name="get_total_page_batches",
409+
operation_name=operation_name,
410+
)
401411
_validate_page_batch_values(
402412
operation_name=operation_name,
403413
current_page_batch=current_page_batch,
@@ -477,7 +487,17 @@ async def collect_paginated_results_async(
477487
operation_name=operation_name,
478488
)
479489
current_page_batch = get_current_page_batch(page_response)
490+
_ensure_non_awaitable(
491+
current_page_batch,
492+
callback_name="get_current_page_batch",
493+
operation_name=operation_name,
494+
)
480495
total_page_batches = get_total_page_batches(page_response)
496+
_ensure_non_awaitable(
497+
total_page_batches,
498+
callback_name="get_total_page_batches",
499+
operation_name=operation_name,
500+
)
481501
_validate_page_batch_values(
482502
operation_name=operation_name,
483503
current_page_batch=current_page_batch,

tests/test_polling.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -390,6 +390,40 @@ def on_page_success(response: dict) -> object:
390390
assert callback_attempts["count"] == 1
391391

392392

393+
def test_collect_paginated_results_rejects_awaitable_current_page_callback_result():
394+
with pytest.raises(
395+
HyperbrowserError,
396+
match="get_current_page_batch must return a non-awaitable result",
397+
):
398+
collect_paginated_results(
399+
operation_name="sync paginated awaitable current page callback",
400+
get_next_page=lambda page: {"current": 1, "total": 1, "items": []},
401+
get_current_page_batch=lambda response: asyncio.sleep(0), # type: ignore[return-value]
402+
get_total_page_batches=lambda response: response["total"],
403+
on_page_success=lambda response: None,
404+
max_wait_seconds=1.0,
405+
max_attempts=5,
406+
retry_delay_seconds=0.0001,
407+
)
408+
409+
410+
def test_collect_paginated_results_rejects_awaitable_total_pages_callback_result():
411+
with pytest.raises(
412+
HyperbrowserError,
413+
match="get_total_page_batches must return a non-awaitable result",
414+
):
415+
collect_paginated_results(
416+
operation_name="sync paginated awaitable total pages callback",
417+
get_next_page=lambda page: {"current": 1, "total": 1, "items": []},
418+
get_current_page_batch=lambda response: response["current"],
419+
get_total_page_batches=lambda response: asyncio.sleep(0), # type: ignore[return-value]
420+
on_page_success=lambda response: None,
421+
max_wait_seconds=1.0,
422+
max_attempts=5,
423+
retry_delay_seconds=0.0001,
424+
)
425+
426+
393427
def test_collect_paginated_results_allows_single_page_on_zero_max_wait():
394428
collected = []
395429

@@ -486,6 +520,50 @@ def on_page_success(response: dict) -> object:
486520
asyncio.run(run())
487521

488522

523+
def test_collect_paginated_results_async_rejects_awaitable_current_page_callback_result():
524+
async def run() -> None:
525+
with pytest.raises(
526+
HyperbrowserError,
527+
match="get_current_page_batch must return a non-awaitable result",
528+
):
529+
await collect_paginated_results_async(
530+
operation_name="async paginated awaitable current page callback",
531+
get_next_page=lambda page: asyncio.sleep(
532+
0, result={"current": 1, "total": 1, "items": []}
533+
),
534+
get_current_page_batch=lambda response: asyncio.sleep(0), # type: ignore[return-value]
535+
get_total_page_batches=lambda response: response["total"],
536+
on_page_success=lambda response: None,
537+
max_wait_seconds=1.0,
538+
max_attempts=5,
539+
retry_delay_seconds=0.0001,
540+
)
541+
542+
asyncio.run(run())
543+
544+
545+
def test_collect_paginated_results_async_rejects_awaitable_total_pages_callback_result():
546+
async def run() -> None:
547+
with pytest.raises(
548+
HyperbrowserError,
549+
match="get_total_page_batches must return a non-awaitable result",
550+
):
551+
await collect_paginated_results_async(
552+
operation_name="async paginated awaitable total pages callback",
553+
get_next_page=lambda page: asyncio.sleep(
554+
0, result={"current": 1, "total": 1, "items": []}
555+
),
556+
get_current_page_batch=lambda response: response["current"],
557+
get_total_page_batches=lambda response: asyncio.sleep(0), # type: ignore[return-value]
558+
on_page_success=lambda response: None,
559+
max_wait_seconds=1.0,
560+
max_attempts=5,
561+
retry_delay_seconds=0.0001,
562+
)
563+
564+
asyncio.run(run())
565+
566+
489567
def test_collect_paginated_results_async_allows_single_page_on_zero_max_wait():
490568
async def run() -> None:
491569
collected = []

0 commit comments

Comments
 (0)