Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
171 changes: 156 additions & 15 deletions src/agents/realtime/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ class _RealtimeSessionClosedSentinel:


_REALTIME_SESSION_CLOSED_SENTINEL = _RealtimeSessionClosedSentinel()
# Default time budget, in seconds, that ``_cancel_and_await_tasks`` gives
# cancelled background tasks to finish when the caller passes no ``timeout``.
_CLEANUP_BACKGROUND_TASK_TIMEOUT = 5.0


def _serialize_tool_output(output: Any) -> str:
Expand Down Expand Up @@ -174,6 +175,7 @@ def __init__(
asyncio.Queue()
)
self._event_iterator_waiters = 0
self._closing = False
self._closed = False
self._stored_exception: BaseException | None = None
self._pending_tool_calls: dict[
Expand Down Expand Up @@ -291,8 +293,14 @@ async def update_agent(self, agent: RealtimeAgent) -> None:
)

async def on_event(self, event: RealtimeModelEvent) -> None:
if self._closing or self._closed:
return

await self._put_event(RealtimeRawModelEvent(data=event, info=self._event_info))

if self._closing or self._closed:
return

if event.type == "error":
await self._put_event(RealtimeError(info=self._event_info, error=event.error))
elif event.type == "function_call":
Expand Down Expand Up @@ -466,6 +474,8 @@ async def on_event(self, event: RealtimeModelEvent) -> None:

async def _put_event(self, event: RealtimeSessionEvent) -> None:
    """Enqueue ``event`` on the session event queue.

    The event is silently dropped when the session is closing or already
    closed, so nothing new is queued once shutdown has started.
    """
    session_is_open = not (self._closing or self._closed)
    if session_is_open:
        await self._event_queue.put(event)

async def _function_needs_approval(
Expand Down Expand Up @@ -532,6 +542,8 @@ async def _maybe_request_tool_approval(
)

needs_approval = await self._function_needs_approval(function_tool, tool_call)
if self._closing or self._closed:
return None
if not needs_approval:
return True

Expand All @@ -552,6 +564,8 @@ async def _maybe_request_tool_approval(
tool_call=tool_call,
agent=agent,
)
if self._closing or self._closed:
return None
if rejected_message is not None:
return self._build_realtime_tool_output(
tool=function_tool,
Expand All @@ -560,6 +574,8 @@ async def _maybe_request_tool_approval(
output=rejected_message,
)

if self._closing or self._closed:
return None
self._pending_tool_calls[tool_call.call_id] = (
tool_call,
agent,
Expand Down Expand Up @@ -666,24 +682,36 @@ async def _send_tool_rejection(
)

async def _send_tool_output_completion(self, pending_output: _PendingToolOutput) -> None:
    """Send a tool output while tracking it in ``_pending_tool_outputs``.

    The output is registered under its call id before sending and removed
    again once the send returns. If the send raises while the session is still
    open, the entry is left in place and the error is re-raised wrapped in
    ``_PendingToolOutputSendError``. If the session is closing or closed, the
    call is a no-op up front, and a send failure is swallowed after the entry
    has been removed.
    """
    if self._closing or self._closed:
        return

    call_id = pending_output.tool_call.call_id
    self._pending_tool_outputs[call_id] = pending_output
    try:
        await self._send_pending_tool_output(pending_output)
    except Exception as exc:
        session_still_open = not (self._closing or self._closed)
        if session_still_open:
            # Keep the entry registered; only the wrapped error leaves here.
            raise _PendingToolOutputSendError(call_id, exc) from exc
        # Shutting down: fall through and drop the entry without raising.

    self._pending_tool_outputs.pop(call_id, None)

async def _send_pending_tool_output(self, pending_output: _PendingToolOutput) -> None:
    """Send the events that make up one pending tool output to the model.

    In order: the optional session update, the tool output itself, and then
    the optional tool-end event, which is queued for session consumers. The
    closing/closed flags are re-checked after every await so the sequence
    stops as soon as shutdown begins.
    """

    def session_ending() -> bool:
        return self._closing or self._closed

    if session_ending():
        return

    update = pending_output.session_update
    if update is not None:
        await self._model.send_event(update)
        if session_ending():
            return

    await self._model.send_event(
        RealtimeModelSendToolOutput(
            tool_call=pending_output.tool_call,
            output=pending_output.output,
            start_response=pending_output.start_response,
        )
    )
    if session_ending():
        return

    end_event = pending_output.tool_end_event
    if end_event is not None:
        await self._put_event(end_event)

Expand Down Expand Up @@ -817,6 +845,9 @@ async def _handle_tool_call(
agent.get_all_tools(self._context_wrapper),
self._get_handoffs(agent, self._context_wrapper),
)
if self._closing or self._closed:
return

function_map = {tool.name: tool for tool in tools if isinstance(tool, FunctionTool)}
handoff_map = {handoff.tool_name: handoff for handoff in handoffs}

Expand All @@ -825,6 +856,8 @@ async def _handle_tool_call(
approval_status = await self._maybe_request_tool_approval(
event, function_tool=func_tool, agent=agent
)
if self._closing or self._closed:
return
if isinstance(approval_status, _PendingToolOutput):
await self._send_tool_output_completion(approval_status)
mark_completed = True
Expand All @@ -841,6 +874,8 @@ async def _handle_tool_call(
tool_call=event,
agent=agent,
)
if self._closing or self._closed:
return
if rejected_message is not None:
await self._send_tool_output_completion(
self._build_realtime_tool_output(
Expand All @@ -861,6 +896,8 @@ async def _handle_tool_call(
arguments=event.arguments,
)
)
if self._closing or self._closed:
return

tool_context = ToolContext(
context=self._context_wrapper.context,
Expand All @@ -875,6 +912,8 @@ async def _handle_tool_call(
context=tool_context,
arguments=event.arguments,
)
if self._closing or self._closed:
return

await self._send_tool_output_completion(
_PendingToolOutput(
Expand Down Expand Up @@ -904,6 +943,8 @@ async def _handle_tool_call(

# Execute the handoff to get the new agent
result = await handoff.on_invoke_handoff(self._context_wrapper, event.arguments)
if self._closing or self._closed:
return
if not isinstance(result, RealtimeAgent):
raise UserError(
f"Handoff {handoff.tool_name} returned invalid result: {type(result)}"
Expand All @@ -920,6 +961,8 @@ async def _handle_tool_call(
starting_settings=None,
agent=self._current_agent,
)
if self._closing or self._closed:
return

# Send handoff event
await self._put_event(
Expand Down Expand Up @@ -963,6 +1006,8 @@ async def _handle_tool_call(
self._finish_tool_call(event.call_id, mark_completed=mark_completed)

def _begin_tool_call(self, call_id: str, *, from_pending_approval: bool) -> bool:
if self._closing or self._closed:
return False
if call_id in self._active_tool_call_ids or call_id in self._completed_tool_call_ids:
return False
if not from_pending_approval and call_id in self._pending_tool_calls:
Expand All @@ -972,7 +1017,7 @@ def _begin_tool_call(self, call_id: str, *, from_pending_approval: bool) -> bool

def _finish_tool_call(self, call_id: str, *, mark_completed: bool) -> None:
    """Release ``call_id`` from the active set and optionally record it as completed.

    Args:
        call_id: Identifier of the tool call that has finished running.
        mark_completed: When true, add ``call_id`` to the completed set so it
            is not started again. Ignored once the session is closing or
            closed.
    """
    self._active_tool_call_ids.discard(call_id)
    # Single guard: the earlier bare ``if mark_completed:`` line was redundant.
    if mark_completed and not self._closing and not self._closed:
        self._completed_tool_call_ids.add(call_id)

@classmethod
Expand Down Expand Up @@ -1188,7 +1233,7 @@ async def _run_output_guardrails(self, text: str, response_id: str) -> bool:

if triggered_results:
# Double-check: bail if already interrupted for this response
if response_id in self._interrupted_response_ids:
if response_id in self._interrupted_response_ids or self._closing or self._closed:
return False

# Mark as interrupted immediately (before any awaits) to minimize race window
Expand All @@ -1204,9 +1249,13 @@ async def _run_output_guardrails(self, text: str, response_id: str) -> bool:
)

# Interrupt the model
if self._closing or self._closed:
return False
await self._model.send_event(RealtimeModelSendInterrupt(force_response_cancel=True))

# Send guardrail triggered message
if self._closing or self._closed:
return False
guardrail_names = [result.guardrail.get_name() for result in triggered_results]
await self._model.send_event(
RealtimeModelSendUserInput(
Expand All @@ -1220,6 +1269,8 @@ async def _run_output_guardrails(self, text: str, response_id: str) -> bool:

def _enqueue_guardrail_task(self, text: str, response_id: str) -> None:
# Runs the guardrails in a separate task to avoid blocking the main loop
if self._closing or self._closed:
return

task = asyncio.create_task(self._run_output_guardrails(text, response_id))
self._guardrail_tasks.add(task)
Expand All @@ -1232,6 +1283,10 @@ def _on_guardrail_task_done(self, task: asyncio.Task[Any]) -> None:
# Remove from tracking set
self._guardrail_tasks.discard(task)

if self._closing or self._closed:
self._retrieve_task_exception(task)
return

# Check for exceptions and propagate as events
if not task.cancelled():
exception = task.exception()
Expand All @@ -1246,11 +1301,88 @@ def _on_guardrail_task_done(self, task: asyncio.Task[Any]) -> None:
)
)

def _cleanup_guardrail_tasks(self) -> None:
    """Request cancellation of unfinished guardrail tasks and stop tracking all of them.

    This synchronous variant only calls ``cancel()``; it does not wait for the
    tasks to finish.
    """
    unfinished = (task for task in self._guardrail_tasks if not task.done())
    for task in unfinished:
        task.cancel()
    self._guardrail_tasks.clear()
def _retrieve_task_exception(self, task: asyncio.Task[Any]) -> None:
    """Read and discard the exception of a non-cancelled task.

    Cancelled tasks are skipped because ``Task.exception()`` raises
    ``CancelledError`` for them. The value is deliberately ignored;
    presumably the call exists so asyncio treats the exception as retrieved.
    """
    if not task.cancelled():
        task.exception()

def _discard_task_from_sets(
    self,
    task: asyncio.Task[Any],
    task_sets: Sequence[set[asyncio.Task[Any]]],
) -> None:
    """Remove ``task`` from every set in ``task_sets``; sets lacking it are left untouched."""
    only_this_task = (task,)
    for tracked in task_sets:
        tracked.difference_update(only_this_task)

def _cancel_and_detach_remaining_tasks(
    self,
    task_sets: Sequence[set[asyncio.Task[Any]]],
    current_task: asyncio.Task[Any] | None,
) -> int:
    """Empty ``task_sets``, cancelling whatever is still running.

    Args:
        task_sets: Tracking sets to drain; they are mutated in place.
        current_task: The task performing the cleanup, if any. It is removed
            from the sets but never cancelled or counted.

    Returns:
        The number of tasks that were still running when they were detached.
        Tasks that had already finished are removed without being counted, so
        a caller reporting this number as "timed out" tasks does not overcount.
    """
    detached = 0
    for task_set in task_sets:
        # Iterate over a snapshot because the set is mutated inside the loop.
        for task in list(task_set):
            task_set.discard(task)
            if task is current_task or task.done():
                continue
            task.cancel()
            detached += 1
    return detached

async def _cancel_and_await_tasks(
    self,
    task_sets: Sequence[set[asyncio.Task[Any]]],
    *,
    timeout: float | None = None,
) -> None:
    """Cancel every tracked task and wait, within a shared deadline, for them to finish.

    Args:
        task_sets: Tracking sets to drain; they are mutated in place.
        timeout: Total time budget in seconds. ``None`` selects
            ``_CLEANUP_BACKGROUND_TASK_TIMEOUT``.

    The loop repeats while any set is non-empty, so tasks added to the sets
    while a wait is in progress are picked up on the next pass. The calling
    task is dropped from the sets without being cancelled or awaited. When
    the deadline passes, whatever is left is cancelled and detached, and a
    warning is logged.
    """
    if timeout is None:
        timeout = _CLEANUP_BACKGROUND_TASK_TIMEOUT
    current_task = asyncio.current_task()
    loop = asyncio.get_running_loop()
    deadline = loop.time() + timeout
    timeout_message = "Timed out waiting for %d realtime background task(s) to cancel."

    while any(task_sets):
        awaiting: list[asyncio.Task[Any]] = []
        for tracked in task_sets:
            # Snapshot the set: entries are discarded while walking it.
            for task in tuple(tracked):
                if task is current_task:
                    tracked.discard(task)
                elif task.done():
                    awaiting.append(task)
                else:
                    task.cancel()
                    awaiting.append(task)

        if not awaiting:
            return

        time_left = deadline - loop.time()
        if time_left <= 0:
            abandoned = self._cancel_and_detach_remaining_tasks(task_sets, current_task)
            if abandoned:
                logger.warning(timeout_message, abandoned)
            return

        finished, unfinished = await asyncio.wait(awaiting, timeout=time_left)
        for task in finished:
            self._discard_task_from_sets(task, task_sets)
            self._retrieve_task_exception(task)

        if unfinished:
            for task in unfinished:
                self._discard_task_from_sets(task, task_sets)
            abandoned = self._cancel_and_detach_remaining_tasks(task_sets, current_task)
            logger.warning(timeout_message, len(unfinished) + abandoned)
            return

async def _cleanup_guardrail_tasks(self) -> None:
    """Cancel the tracked guardrail tasks and wait for them via ``_cancel_and_await_tasks``."""
    guardrail_sets = (self._guardrail_tasks,)
    await self._cancel_and_await_tasks(guardrail_sets)

def _enqueue_tool_call_task(
self,
Expand All @@ -1261,6 +1393,11 @@ def _enqueue_tool_call_task(
call_id_reserved: bool = False,
) -> None:
"""Run tool calls in the background to avoid blocking realtime transport."""
if self._closing or self._closed:
if call_id_reserved:
self._finish_tool_call(event.call_id, mark_completed=False)
return

handle_kwargs: dict[str, Any] = {"agent_snapshot": agent_snapshot}
if from_pending_approval:
handle_kwargs["from_pending_approval"] = True
Expand All @@ -1274,6 +1411,10 @@ def _enqueue_tool_call_task(
def _on_tool_call_task_done(self, task: asyncio.Task[Any]) -> None:
self._tool_call_tasks.discard(task)

if self._closing or self._closed:
self._retrieve_task_exception(task)
return

if task.cancelled():
return

Expand Down Expand Up @@ -1316,11 +1457,8 @@ def _on_tool_call_task_done(self, task: asyncio.Task[Any]) -> None:
)
)

def _cleanup_tool_call_tasks(self) -> None:
    """Request cancellation of unfinished tool-call tasks and stop tracking all of them.

    This synchronous variant only calls ``cancel()``; it does not wait for the
    tasks to finish.
    """
    unfinished = (task for task in self._tool_call_tasks if not task.done())
    for task in unfinished:
        task.cancel()
    self._tool_call_tasks.clear()
async def _cleanup_tool_call_tasks(self) -> None:
    """Cancel the tracked tool-call tasks and wait for them via ``_cancel_and_await_tasks``."""
    tool_call_sets = (self._tool_call_tasks,)
    await self._cancel_and_await_tasks(tool_call_sets)

def _wake_event_iterators(self) -> None:
for _ in range(self._event_iterator_waiters):
Expand All @@ -1332,19 +1470,22 @@ async def _cleanup(self) -> None:
self._wake_event_iterators()
return

# Cancel and cleanup guardrail tasks
self._cleanup_guardrail_tasks()
self._cleanup_tool_call_tasks()
self._closing = True

# Remove ourselves as a listener
self._model.remove_listener(self)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Prevent in-flight events from enqueuing after cleanup

Removing the listener here does not stop model events whose dispatch had already copied this listener before close() began. For evidence, see OpenAIRealtimeModel._emit_event in src/agents/realtime/openai_realtime.py (lines 624-626), which iterates over list(self._listeners) and then awaits listener.on_event(event). When close() races with an already-dispatched function_call while _tool_call_tasks is still empty, _cleanup_tool_call_tasks() returns before that in-flight on_event reaches _enqueue_tool_call_task, so the newly added tool task is never cancelled or awaited and can keep running after the session is closed.

Useful? React with 👍 / 👎.


# Cancel and cleanup guardrail/tool-call tasks together so one group cannot block the other.
await self._cancel_and_await_tasks((self._guardrail_tasks, self._tool_call_tasks))

# Close the model connection
await self._model.close()

# Clear pending approval tracking
self._pending_tool_calls.clear()
self._pending_tool_outputs.clear()
self._active_tool_call_ids.clear()
self._completed_tool_call_ids.clear()

# Mark as closed
self._closed = True
Expand Down
Loading