From be773870d4fbd325e14ccacb0d52a4afbb611581 Mon Sep 17 00:00:00 2001 From: gyx09212214-prog Date: Thu, 18 Jun 2026 15:05:39 +0800 Subject: [PATCH 1/4] fix(realtime): await background tasks during cleanup --- src/agents/realtime/session.py | 32 ++++++++++++++++-------- tests/realtime/test_session.py | 45 ++++++++++++++++++++++++++++++++++ 2 files changed, 67 insertions(+), 10 deletions(-) diff --git a/src/agents/realtime/session.py b/src/agents/realtime/session.py index b8eec22a37..c899c901f5 100644 --- a/src/agents/realtime/session.py +++ b/src/agents/realtime/session.py @@ -1246,11 +1246,26 @@ def _on_guardrail_task_done(self, task: asyncio.Task[Any]) -> None: ) ) - def _cleanup_guardrail_tasks(self) -> None: - for task in self._guardrail_tasks: + async def _cancel_and_await_tasks(self, tasks: set[asyncio.Task[Any]]) -> None: + if not tasks: + return + + current_task = asyncio.current_task() + tasks_to_await: list[asyncio.Task[Any]] = [] + for task in list(tasks): + if task is current_task: + continue if not task.done(): task.cancel() - self._guardrail_tasks.clear() + tasks_to_await.append(task) + + if tasks_to_await: + await asyncio.gather(*tasks_to_await, return_exceptions=True) + + tasks.clear() + + async def _cleanup_guardrail_tasks(self) -> None: + await self._cancel_and_await_tasks(self._guardrail_tasks) def _enqueue_tool_call_task( self, @@ -1316,11 +1331,8 @@ def _on_tool_call_task_done(self, task: asyncio.Task[Any]) -> None: ) ) - def _cleanup_tool_call_tasks(self) -> None: - for task in self._tool_call_tasks: - if not task.done(): - task.cancel() - self._tool_call_tasks.clear() + async def _cleanup_tool_call_tasks(self) -> None: + await self._cancel_and_await_tasks(self._tool_call_tasks) def _wake_event_iterators(self) -> None: for _ in range(self._event_iterator_waiters): @@ -1333,8 +1345,8 @@ async def _cleanup(self) -> None: return # Cancel and cleanup guardrail tasks - self._cleanup_guardrail_tasks() - self._cleanup_tool_call_tasks() + await self._cleanup_guardrail_tasks() + await self._cleanup_tool_call_tasks() # Remove ourselves as a listener self._model.remove_listener(self) diff --git a/tests/realtime/test_session.py b/tests/realtime/test_session.py index 0e4f88bee7..f004477d8d 100644 --- a/tests/realtime/test_session.py +++ b/tests/realtime/test_session.py @@ -157,6 +157,51 @@ async def test_aiter_exits_waiting_iterators_when_session_closes(): task.result() +@pytest.mark.asyncio +async def test_cleanup_awaits_cancelled_background_tasks(): + model = _DummyModel() + agent = RealtimeAgent(name="agent") + session = RealtimeSession(model, agent, None) + + guardrail_started = asyncio.Event() + guardrail_finished = asyncio.Event() + tool_started = asyncio.Event() + tool_finished = asyncio.Event() + + async def guardrail_task(): + guardrail_started.set() + try: + await asyncio.Event().wait() + finally: + await asyncio.sleep(0) + guardrail_finished.set() + + async def tool_call_task(): + tool_started.set() + try: + await asyncio.Event().wait() + finally: + await asyncio.sleep(0) + tool_finished.set() + + guardrail = asyncio.create_task(guardrail_task()) + tool_call = asyncio.create_task(tool_call_task()) + session._guardrail_tasks.add(guardrail) + session._tool_call_tasks.add(tool_call) + + await guardrail_started.wait() + await tool_started.wait() + + await session._cleanup() + + assert guardrail.done() + assert tool_call.done() + assert guardrail_finished.is_set() + assert tool_finished.is_set() + assert session._guardrail_tasks == set() + assert session._tool_call_tasks == set() + + @pytest.mark.asyncio async def test_transcription_completed_adds_new_user_item(): model = _DummyModel() From 14d12e901f40d503153dcf500473ab970eb450b6 Mon Sep 17 00:00:00 2001 From: gyx09212214-prog Date: Thu, 18 Jun 2026 15:29:42 +0800 Subject: [PATCH 2/4] fix(realtime): drain cleanup task tracking --- src/agents/realtime/session.py | 34 +++++++++++++++-------------- tests/realtime/test_session.py | 39 ++++++++++++++++++++++++++++++++++ 2 files changed, 57 insertions(+), 16 deletions(-) diff --git a/src/agents/realtime/session.py b/src/agents/realtime/session.py index c899c901f5..1f918ba8b7 100644 --- a/src/agents/realtime/session.py +++ b/src/agents/realtime/session.py @@ -1247,22 +1247,24 @@ def _on_guardrail_task_done(self, task: asyncio.Task[Any]) -> None: ) async def _cancel_and_await_tasks(self, tasks: set[asyncio.Task[Any]]) -> None: - if not tasks: - return - current_task = asyncio.current_task() - tasks_to_await: list[asyncio.Task[Any]] = [] - for task in list(tasks): - if task is current_task: - continue - if not task.done(): - task.cancel() - tasks_to_await.append(task) - if tasks_to_await: - await asyncio.gather(*tasks_to_await, return_exceptions=True) + while tasks: + tasks_to_await: list[asyncio.Task[Any]] = [] + for task in list(tasks): + if task is current_task: + tasks.discard(task) + continue + if not task.done(): + task.cancel() + tasks_to_await.append(task) + + if not tasks_to_await: + return - tasks.clear() + await asyncio.gather(*tasks_to_await, return_exceptions=True) + for task in tasks_to_await: + tasks.discard(task) async def _cleanup_guardrail_tasks(self) -> None: await self._cancel_and_await_tasks(self._guardrail_tasks) @@ -1344,13 +1346,13 @@ async def _cleanup(self) -> None: self._wake_event_iterators() return + # Remove ourselves as a listener + self._model.remove_listener(self) + # Cancel and cleanup guardrail tasks await self._cleanup_guardrail_tasks() await self._cleanup_tool_call_tasks() - # Remove ourselves as a listener - self._model.remove_listener(self) - # Close the model connection await self._model.close() diff --git a/tests/realtime/test_session.py b/tests/realtime/test_session.py index f004477d8d..c45d7079f1 100644 --- a/tests/realtime/test_session.py +++ b/tests/realtime/test_session.py @@ -202,6 +202,45 @@ async def tool_call_task(): assert session._tool_call_tasks == set() +@pytest.mark.asyncio +async def test_cleanup_awaits_background_tasks_added_during_cancellation(): + model = _DummyModel() + agent = RealtimeAgent(name="agent") + session = RealtimeSession(model, agent, None) + + first_started = asyncio.Event() + second_started = asyncio.Event() + second_finished = asyncio.Event() + + async def second_task(): + second_started.set() + try: + await asyncio.Event().wait() + finally: + await asyncio.sleep(0) + second_finished.set() + + async def first_task(): + first_started.set() + try: + await asyncio.Event().wait() + finally: + task = asyncio.create_task(second_task()) + session._guardrail_tasks.add(task) + await second_started.wait() + + first = asyncio.create_task(first_task()) + session._guardrail_tasks.add(first) + + await first_started.wait() + + await session._cleanup() + + assert first.done() + assert second_finished.is_set() + assert session._guardrail_tasks == set() + + @pytest.mark.asyncio async def test_transcription_completed_adds_new_user_item(): model = _DummyModel() From 1133fb50715a6322cef52a91f8328ae2f8fd6125 Mon Sep 17 00:00:00 2001 From: gyx09212214-prog Date: Thu, 18 Jun 2026 15:38:22 +0800 Subject: [PATCH 3/4] fix(realtime): block cleanup race enqueues --- src/agents/realtime/session.py | 18 +++++++++++++++ tests/realtime/test_session.py | 42 ++++++++++++++++++++++++++++++++++ 2 files changed, 60 insertions(+) diff --git a/src/agents/realtime/session.py b/src/agents/realtime/session.py index 1f918ba8b7..5e0fbc9269 100644 --- a/src/agents/realtime/session.py +++ b/src/agents/realtime/session.py @@ -174,6 +174,7 @@ def __init__( asyncio.Queue() ) self._event_iterator_waiters = 0 + self._closing = False self._closed = False self._stored_exception: BaseException | None = None self._pending_tool_calls: dict[ @@ -291,8 +292,14 @@ async def update_agent(self, agent: RealtimeAgent) -> None: ) async def on_event(self, event: RealtimeModelEvent) -> None: + if self._closing or self._closed: + return + await self._put_event(RealtimeRawModelEvent(data=event, info=self._event_info)) + if self._closing or self._closed: + return + if event.type == "error": await self._put_event(RealtimeError(info=self._event_info, error=event.error)) elif event.type == "function_call": @@ -466,6 +473,8 @@ async def on_event(self, event: RealtimeModelEvent) -> None: async def _put_event(self, event: RealtimeSessionEvent) -> None: """Put an event into the queue.""" + if self._closed: + return await self._event_queue.put(event) async def _function_needs_approval( @@ -1220,6 +1229,8 @@ async def _run_output_guardrails(self, text: str, response_id: str) -> bool: def _enqueue_guardrail_task(self, text: str, response_id: str) -> None: # Runs the guardrails in a separate task to avoid blocking the main loop + if self._closing or self._closed: + return task = asyncio.create_task(self._run_output_guardrails(text, response_id)) self._guardrail_tasks.add(task) @@ -1278,6 +1289,11 @@ def _enqueue_tool_call_task( call_id_reserved: bool = False, ) -> None: """Run tool calls in the background to avoid blocking realtime transport.""" + if self._closing or self._closed: + if call_id_reserved: + self._finish_tool_call(event.call_id, mark_completed=False) + return + handle_kwargs: dict[str, Any] = {"agent_snapshot": agent_snapshot} if from_pending_approval: handle_kwargs["from_pending_approval"] = True @@ -1346,6 +1362,8 @@ async def _cleanup(self) -> None: self._wake_event_iterators() return + self._closing = True + # Remove ourselves as a listener self._model.remove_listener(self) diff --git a/tests/realtime/test_session.py b/tests/realtime/test_session.py index c45d7079f1..d807434758 100644 --- a/tests/realtime/test_session.py +++ b/tests/realtime/test_session.py @@ -769,6 +769,48 @@ async def test_function_call_event_runs_async_by_default(self, mock_model, mock_ assert isinstance(raw_event, RealtimeRawModelEvent) assert raw_event.data == function_call_event + @pytest.mark.asyncio + async def test_cleanup_prevents_in_flight_function_call_from_enqueuing_task( + self, mock_model, mock_agent + ): + session = RealtimeSession(mock_model, mock_agent, None) + function_call_event = RealtimeModelToolCallEvent( + name="test_function", + call_id="call_cleanup_race", + arguments="{}", + ) + + first_put_started = asyncio.Event() + release_first_put = asyncio.Event() + tool_task_started = asyncio.Event() + original_put_event = session._put_event + + async def blocked_put_event(event): + first_put_started.set() + await release_first_put.wait() + await original_put_event(event) + + async def blocked_handle_tool_call(*_args, **_kwargs): + tool_task_started.set() + await asyncio.Event().wait() + + with pytest.MonkeyPatch().context() as m: + handle_tool_call_mock = AsyncMock(side_effect=blocked_handle_tool_call) + m.setattr(session, "_put_event", blocked_put_event) + m.setattr(session, "_handle_tool_call", handle_tool_call_mock) + + on_event_task = asyncio.create_task(session.on_event(function_call_event)) + await first_put_started.wait() + + await session._cleanup() + release_first_put.set() + await on_event_task + await asyncio.sleep(0) + + assert session._tool_call_tasks == set() + assert not tool_task_started.is_set() + handle_tool_call_mock.assert_not_awaited() + class TestHistoryManagement: """Test suite for history management and audio transcription in From 598f16977ec883d664a1ab15568fa8351e00275a Mon Sep 17 00:00:00 2001 From: gyx09212214-prog Date: Tue, 23 Jun 2026 00:21:21 +0800 Subject: [PATCH 4/4] fix(realtime): bound background cleanup during close --- src/agents/realtime/session.py | 145 +++++++++++++++++++++++++++++---- tests/realtime/test_session.py | 131 +++++++++++++++++++++++++++-- 2 files changed, 250 insertions(+), 26 deletions(-) diff --git a/src/agents/realtime/session.py b/src/agents/realtime/session.py index 5e0fbc9269..facc587172 100644 --- a/src/agents/realtime/session.py +++ b/src/agents/realtime/session.py @@ -80,6 +80,7 @@ class _RealtimeSessionClosedSentinel: _REALTIME_SESSION_CLOSED_SENTINEL = _RealtimeSessionClosedSentinel() +_CLEANUP_BACKGROUND_TASK_TIMEOUT = 5.0 def _serialize_tool_output(output: Any) -> str: @@ -473,7 +474,7 @@ async def on_event(self, event: RealtimeModelEvent) -> None: async def _put_event(self, event: RealtimeSessionEvent) -> None: """Put an event into the queue.""" - if self._closed: + if self._closing or self._closed: return await self._event_queue.put(event) @@ -541,6 +542,8 @@ async def _maybe_request_tool_approval( ) needs_approval = await self._function_needs_approval(function_tool, tool_call) + if self._closing or self._closed: + return None if not needs_approval: return True @@ -561,6 +564,8 @@ async def _maybe_request_tool_approval( tool_call=tool_call, agent=agent, ) + if self._closing or self._closed: + return None if rejected_message is not None: return self._build_realtime_tool_output( tool=function_tool, @@ -569,6 +574,8 @@ async def _maybe_request_tool_approval( output=rejected_message, ) + if self._closing or self._closed: + return None self._pending_tool_calls[tool_call.call_id] = ( tool_call, agent, @@ -675,17 +682,27 @@ async def _send_tool_rejection( ) async def _send_tool_output_completion(self, pending_output: _PendingToolOutput) -> None: + if self._closing or self._closed: + return + call_id = pending_output.tool_call.call_id self._pending_tool_outputs[call_id] = pending_output try: await self._send_pending_tool_output(pending_output) except Exception as exc: + if self._closing or self._closed: + self._pending_tool_outputs.pop(call_id, None) + return raise _PendingToolOutputSendError(call_id, exc) from exc self._pending_tool_outputs.pop(call_id, None) async def _send_pending_tool_output(self, pending_output: _PendingToolOutput) -> None: + if self._closing or self._closed: + return if pending_output.session_update is not None: await self._model.send_event(pending_output.session_update) + if self._closing or self._closed: + return await self._model.send_event( RealtimeModelSendToolOutput( tool_call=pending_output.tool_call, @@ -693,6 +710,8 @@ async def _send_pending_tool_output(self, pending_output: _PendingToolOutput) -> start_response=pending_output.start_response, ) ) + if self._closing or self._closed: + return if pending_output.tool_end_event is not None: await self._put_event(pending_output.tool_end_event) @@ -826,6 +845,9 @@ async def _handle_tool_call( agent.get_all_tools(self._context_wrapper), self._get_handoffs(agent, self._context_wrapper), ) + if self._closing or self._closed: + return + function_map = {tool.name: tool for tool in tools if isinstance(tool, FunctionTool)} handoff_map = {handoff.tool_name: handoff for handoff in handoffs} @@ -834,6 +856,8 @@ async def _handle_tool_call( approval_status = await self._maybe_request_tool_approval( event, function_tool=func_tool, agent=agent ) + if self._closing or self._closed: + return if isinstance(approval_status, _PendingToolOutput): await self._send_tool_output_completion(approval_status) mark_completed = True @@ -850,6 +874,8 @@ async def _handle_tool_call( tool_call=event, agent=agent, ) + if self._closing or self._closed: + return if rejected_message is not None: await self._send_tool_output_completion( self._build_realtime_tool_output( @@ -870,6 +896,8 @@ async def _handle_tool_call( arguments=event.arguments, ) ) + if self._closing or self._closed: + return tool_context = ToolContext( context=self._context_wrapper.context, @@ -884,6 +912,8 @@ async def _handle_tool_call( context=tool_context, arguments=event.arguments, ) + if self._closing or self._closed: + return await self._send_tool_output_completion( _PendingToolOutput( @@ -913,6 +943,8 @@ async def _handle_tool_call( # Execute the handoff to get the new agent result = await handoff.on_invoke_handoff(self._context_wrapper, event.arguments) + if self._closing or self._closed: + return if not isinstance(result, RealtimeAgent): raise UserError( f"Handoff {handoff.tool_name} returned invalid result: {type(result)}" @@ -929,6 +961,8 @@ async def _handle_tool_call( starting_settings=None, agent=self._current_agent, ) + if self._closing or self._closed: + return # Send handoff event await self._put_event( @@ -972,6 +1006,8 @@ async def _handle_tool_call( self._finish_tool_call(event.call_id, mark_completed=mark_completed) def _begin_tool_call(self, call_id: str, *, from_pending_approval: bool) -> bool: + if self._closing or self._closed: + return False if call_id in self._active_tool_call_ids or call_id in self._completed_tool_call_ids: return False if not from_pending_approval and call_id in self._pending_tool_calls: @@ -981,7 +1017,7 @@ def _begin_tool_call(self, call_id: str, *, from_pending_approval: bool) -> bool def _finish_tool_call(self, call_id: str, *, mark_completed: bool) -> None: self._active_tool_call_ids.discard(call_id) - if mark_completed: + if mark_completed and not self._closing and not self._closed: self._completed_tool_call_ids.add(call_id) @classmethod @@ -1197,7 +1233,7 @@ async def _run_output_guardrails(self, text: str, response_id: str) -> bool: if triggered_results: # Double-check: bail if already interrupted for this response - if response_id in self._interrupted_response_ids: + if response_id in self._interrupted_response_ids or self._closing or self._closed: return False # Mark as interrupted immediately (before any awaits) to minimize race window @@ -1213,9 +1249,13 @@ async def _run_output_guardrails(self, text: str, response_id: str) -> bool: ) # Interrupt the model + if self._closing or self._closed: + return False await self._model.send_event(RealtimeModelSendInterrupt(force_response_cancel=True)) # Send guardrail triggered message + if self._closing or self._closed: + return False guardrail_names = [result.guardrail.get_name() for result in triggered_results] await self._model.send_event( RealtimeModelSendUserInput( @@ -1243,6 +1283,10 @@ def _on_guardrail_task_done(self, task: asyncio.Task[Any]) -> None: # Remove from tracking set self._guardrail_tasks.discard(task) + if self._closing or self._closed: + self._retrieve_task_exception(task) + return + # Check for exceptions and propagate as events if not task.cancelled(): exception = task.exception() @@ -1257,28 +1301,88 @@ def _on_guardrail_task_done(self, task: asyncio.Task[Any]) -> None: ) ) - async def _cancel_and_await_tasks(self, tasks: set[asyncio.Task[Any]]) -> None: - current_task = asyncio.current_task() + def _retrieve_task_exception(self, task: asyncio.Task[Any]) -> None: + if task.cancelled(): + return + task.exception() - while tasks: - tasks_to_await: list[asyncio.Task[Any]] = [] - for task in list(tasks): + def _discard_task_from_sets( + self, + task: asyncio.Task[Any], + task_sets: Sequence[set[asyncio.Task[Any]]], + ) -> None: + for task_set in task_sets: + task_set.discard(task) + + def _cancel_and_detach_remaining_tasks( + self, + task_sets: Sequence[set[asyncio.Task[Any]]], + current_task: asyncio.Task[Any] | None, + ) -> int: + detached = 0 + for task_set in task_sets: + for task in list(task_set): if task is current_task: - tasks.discard(task) + task_set.discard(task) continue if not task.done(): task.cancel() - tasks_to_await.append(task) + task_set.discard(task) + detached += 1 + return detached + + async def _cancel_and_await_tasks( + self, + task_sets: Sequence[set[asyncio.Task[Any]]], + *, + timeout: float | None = None, + ) -> None: + timeout = _CLEANUP_BACKGROUND_TASK_TIMEOUT if timeout is None else timeout + current_task = asyncio.current_task() + loop = asyncio.get_running_loop() + deadline = loop.time() + timeout + + while any(task_sets): + tasks_to_await: list[asyncio.Task[Any]] = [] + for task_set in task_sets: + for task in list(task_set): + if task is current_task: + task_set.discard(task) + continue + if not task.done(): + task.cancel() + tasks_to_await.append(task) if not tasks_to_await: return - await asyncio.gather(*tasks_to_await, return_exceptions=True) - for task in tasks_to_await: - tasks.discard(task) + remaining_timeout = deadline - loop.time() + if remaining_timeout <= 0: + detached = self._cancel_and_detach_remaining_tasks(task_sets, current_task) + if detached: + logger.warning( + "Timed out waiting for %d realtime background task(s) to cancel.", + detached, + ) + return + + done, pending = await asyncio.wait(tasks_to_await, timeout=remaining_timeout) + for task in done: + self._discard_task_from_sets(task, task_sets) + self._retrieve_task_exception(task) + + if pending: + for task in pending: + self._discard_task_from_sets(task, task_sets) + detached = self._cancel_and_detach_remaining_tasks(task_sets, current_task) + logger.warning( + "Timed out waiting for %d realtime background task(s) to cancel.", + len(pending) + detached, + ) + return async def _cleanup_guardrail_tasks(self) -> None: - await self._cancel_and_await_tasks(self._guardrail_tasks) + await self._cancel_and_await_tasks((self._guardrail_tasks,)) def _enqueue_tool_call_task( self, @@ -1307,6 +1411,10 @@ def _enqueue_tool_call_task( def _on_tool_call_task_done(self, task: asyncio.Task[Any]) -> None: self._tool_call_tasks.discard(task) + if self._closing or self._closed: + self._retrieve_task_exception(task) + return + if task.cancelled(): return @@ -1350,7 +1458,7 @@ def _on_tool_call_task_done(self, task: asyncio.Task[Any]) -> None: ) async def _cleanup_tool_call_tasks(self) -> None: - await self._cancel_and_await_tasks(self._tool_call_tasks) + await self._cancel_and_await_tasks((self._tool_call_tasks,)) def _wake_event_iterators(self) -> None: for _ in range(self._event_iterator_waiters): @@ -1367,9 +1475,8 @@ async def _cleanup(self) -> None: # Remove ourselves as a listener self._model.remove_listener(self) - # Cancel and cleanup guardrail tasks - await self._cleanup_guardrail_tasks() - await self._cleanup_tool_call_tasks() + # Cancel and cleanup guardrail/tool-call tasks together so one group cannot block the other. + await self._cancel_and_await_tasks((self._guardrail_tasks, self._tool_call_tasks)) # Close the model connection await self._model.close() @@ -1377,6 +1484,8 @@ async def _cleanup(self) -> None: # Clear pending approval tracking self._pending_tool_calls.clear() self._pending_tool_outputs.clear() + self._active_tool_call_ids.clear() + self._completed_tool_call_ids.clear() # Mark as closed self._closed = True diff --git a/tests/realtime/test_session.py b/tests/realtime/test_session.py index d807434758..cac7f7518e 100644 --- a/tests/realtime/test_session.py +++ b/tests/realtime/test_session.py @@ -158,22 +158,28 @@ async def test_aiter_exits_waiting_iterators_when_session_closes(): @pytest.mark.asyncio -async def test_cleanup_awaits_cancelled_background_tasks(): - model = _DummyModel() - agent = RealtimeAgent(name="agent") - session = RealtimeSession(model, agent, None) - +async def test_cleanup_cancels_task_groups_concurrently_and_awaits_finalizers(): guardrail_started = asyncio.Event() guardrail_finished = asyncio.Event() tool_started = asyncio.Event() + tool_cancelled = asyncio.Event() tool_finished = asyncio.Event() + class CloseAssertingModel(_DummyModel): + async def close(self): + assert guardrail_finished.is_set() + assert tool_finished.is_set() + + model = CloseAssertingModel() + agent = RealtimeAgent(name="agent") + session = RealtimeSession(model, agent, None) + async def guardrail_task(): guardrail_started.set() try: await asyncio.Event().wait() finally: - await asyncio.sleep(0) + await tool_cancelled.wait() guardrail_finished.set() async def tool_call_task(): @@ -181,7 +187,7 @@ async def tool_call_task(): try: await asyncio.Event().wait() finally: - await asyncio.sleep(0) + tool_cancelled.set() tool_finished.set() guardrail = asyncio.create_task(guardrail_task()) @@ -192,7 +198,7 @@ async def tool_call_task(): await guardrail_started.wait() await tool_started.wait() - await session._cleanup() + await asyncio.wait_for(session._cleanup(), timeout=1) assert guardrail.done() assert tool_call.done() @@ -202,6 +208,115 @@ async def tool_call_task(): assert session._tool_call_tasks == set() +@pytest.mark.asyncio +async def test_cleanup_timeout_bounds_cancellation_resistant_task(): + model = _DummyModel() + agent = RealtimeAgent(name="agent") + session = RealtimeSession(model, agent, None) + + task_started = asyncio.Event() + cancel_seen = asyncio.Event() + release_task = asyncio.Event() + + async def cancellation_resistant_task(): + task_started.set() + try: + await asyncio.Event().wait() + except asyncio.CancelledError: + cancel_seen.set() + await release_task.wait() + raise + + task = asyncio.create_task(cancellation_resistant_task()) + session._guardrail_tasks.add(task) + + await task_started.wait() + + with pytest.MonkeyPatch().context() as m: + m.setattr("agents.realtime.session._CLEANUP_BACKGROUND_TASK_TIMEOUT", 0.01) + await asyncio.wait_for(session._cleanup(), timeout=1) + + await asyncio.wait_for(cancel_seen.wait(), timeout=1) + assert session._closed + assert session._guardrail_tasks == set() + assert not task.done() + + release_task.set() + await asyncio.gather(task, return_exceptions=True) + + +@pytest.mark.asyncio +async def test_cleanup_skips_tracked_task_that_calls_close(): + model = _DummyModel() + agent = RealtimeAgent(name="agent") + session = RealtimeSession(model, agent, None) + + close_finished = asyncio.Event() + + async def tracked_task(): + await session.close() + close_finished.set() + + task = asyncio.create_task(tracked_task()) + session._tool_call_tasks.add(task) + task.add_done_callback(session._on_tool_call_task_done) + + await asyncio.wait_for(close_finished.wait(), timeout=1) + await task + + assert session._closed + assert session._tool_call_tasks == set() + + +@pytest.mark.asyncio +async def test_late_task_failures_after_cleanup_timeout_do_not_mutate_session(): + model = _DummyModel() + agent = RealtimeAgent(name="agent") + session = RealtimeSession(model, agent, None) + + guardrail_started = asyncio.Event() + guardrail_cancelled = asyncio.Event() + tool_started = asyncio.Event() + tool_cancelled = asyncio.Event() + release_tasks = asyncio.Event() + + async def fail_after_timeout(started: asyncio.Event, cancelled: asyncio.Event): + started.set() + try: + await asyncio.Event().wait() + except asyncio.CancelledError: + cancelled.set() + await release_tasks.wait() + raise RuntimeError("late failure") from None + + guardrail = asyncio.create_task(fail_after_timeout(guardrail_started, guardrail_cancelled)) + tool_call = asyncio.create_task(fail_after_timeout(tool_started, tool_cancelled)) + session._guardrail_tasks.add(guardrail) + session._tool_call_tasks.add(tool_call) + guardrail.add_done_callback(session._on_guardrail_task_done) + tool_call.add_done_callback(session._on_tool_call_task_done) + + await guardrail_started.wait() + await tool_started.wait() + + with pytest.MonkeyPatch().context() as m: + m.setattr("agents.realtime.session._CLEANUP_BACKGROUND_TASK_TIMEOUT", 0.01) + await asyncio.wait_for(session._cleanup(), timeout=1) + + await asyncio.wait_for(guardrail_cancelled.wait(), timeout=1) + await asyncio.wait_for(tool_cancelled.wait(), timeout=1) + + release_tasks.set() + results = await asyncio.gather(guardrail, tool_call, return_exceptions=True) + await asyncio.sleep(0) + + assert all(isinstance(result, RuntimeError) for result in results) + assert session._stored_exception is None + assert session._event_queue.empty() + assert session._guardrail_tasks == set() + assert session._tool_call_tasks == set() + + @pytest.mark.asyncio async def test_cleanup_awaits_background_tasks_added_during_cancellation(): model = _DummyModel()