diff --git a/tests/test_query.py b/tests/test_query.py index 16c088b1..59c1b90e 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -547,7 +547,7 @@ async def _test(): class TestQueryCrossTaskCleanup: - """Tests for cross-task cleanup of Query task groups (issue #454). + """Tests for cross-task cleanup of Query task groups (issues #454, #983). When a user breaks out of an async for loop over process_query(), Python finalizes the async generator in a different task than the one that called @@ -598,6 +598,86 @@ async def _test(): anyio.run(_test) + def test_close_during_generator_teardown_asyncio(self): + """Regression test for #983: close() called from an async generator's + finally block (triggered by GeneratorExit when the consumer breaks the + loop) must not raise RuntimeError about cross-task cancel scope exit. + + This simulates the exact pattern in _process_query_inner where + query.close() is called in a finally block during generator teardown + on a different task than Query.start(). + """ + import asyncio + + async def _test(): + mock_transport = _make_mock_transport(messages=[]) + q = Query(transport=mock_transport, is_streaming_mode=True) + + await q.start() + + errors: list[BaseException] = [] + + # Wrapping generator that simulates _process_query_inner's + # try/finally pattern: close() runs in the finally block when + # GeneratorExit is thrown by the consumer breaking the loop. + async def wrapping_gen(): + try: + async for msg in q.receive_messages(): + yield msg + finally: + await q.close() + + async def consumer(): + try: + async for _ in wrapping_gen(): + break # GeneratorExit -> wrapping_gen.finally -> q.close() + except Exception as e: + errors.append(e) + + # Run the consumer on a separate task so close() runs on a + # different task than start() -- the exact scenario from #983. + task = asyncio.create_task(consumer()) + await task + + assert errors == [], ( + f"close() during generator teardown raised: {errors}" + ) + + asyncio.run(_test()) + + def test_close_during_generator_teardown_trio(self): + """Trio parity for the #983 regression test above.""" + async def _test(): + mock_transport = _make_mock_transport(messages=[]) + q = Query(transport=mock_transport, is_streaming_mode=True) + + await q.start() + + errors: list[BaseException] = [] + + async def wrapping_gen(): + try: + async for msg in q.receive_messages(): + yield msg + finally: + await q.close() + + async def consumer(): + try: + async for _ in wrapping_gen(): + break + except Exception as e: + errors.append(e) + + async with anyio.create_task_group() as tg: + tg.start_soon(consumer) + + assert errors == [], ( + f"close() during generator teardown raised: {errors}" + ) + + anyio.run(_test, backend="trio") + @pytest.mark.filterwarnings( "ignore:Unclosed