diff --git a/python/packages/core/agent_framework/_mcp.py b/python/packages/core/agent_framework/_mcp.py index 267e176ee8..9f56df455f 100644 --- a/python/packages/core/agent_framework/_mcp.py +++ b/python/packages/core/agent_framework/_mcp.py @@ -9,7 +9,7 @@ import re import sys from abc import abstractmethod -from collections.abc import Callable, Collection, Sequence +from collections.abc import Callable, Collection, Coroutine, Sequence from contextlib import AsyncExitStack, _AsyncGeneratorContextManager # type: ignore from datetime import timedelta from functools import partial @@ -227,6 +227,7 @@ def __init__( self.is_connected: bool = False self._tools_loaded: bool = False self._prompts_loaded: bool = False + self._pending_reload_tasks: set[asyncio.Task[None]] = set() def __str__(self) -> str: return f"MCPTool(name={self.name}, description={self.description})" @@ -840,12 +841,46 @@ async def message_handler( if isinstance(message, types.ServerNotification): match message.root.method: case "notifications/tools/list_changed": - await self.load_tools() + self._schedule_reload(self.load_tools()) case "notifications/prompts/list_changed": - await self.load_prompts() + self._schedule_reload(self.load_prompts()) case _: logger.debug("Unhandled notification: %s", message.root.method) + def _schedule_reload(self, coro: Coroutine[Any, Any, None]) -> None: + """Schedule a reload coroutine as a background task. + + Reloads (load_tools / load_prompts) triggered by MCP server + notifications must NOT be awaited inside the message handler because + the handler runs on the MCP SDK's single-threaded receive loop. + Awaiting a session request (e.g. ``list_tools``) from within that loop + deadlocks: the receive loop cannot read the response while it is + blocked waiting for the handler to return. + + Instead we fire the reload as an independent ``asyncio.Task`` and keep + a strong reference in ``_pending_reload_tasks`` so it is not garbage- + collected before completion. Only one reload per kind (tools / prompts) + is kept in flight; a new notification cancels the previous pending task + for the same coroutine name to avoid unbounded growth. + """ + # Cancel-and-replace: only one reload per kind should be in flight. + reload_name = f"mcp-reload:{self.name}:{coro.__qualname__}" + for existing in list(self._pending_reload_tasks): + if existing.get_name() == reload_name and not existing.done(): + existing.cancel() + + async def _safe_reload() -> None: + try: + await coro + except asyncio.CancelledError: + raise + except Exception: + logger.warning("Background MCP reload failed", exc_info=True) + + task = asyncio.create_task(_safe_reload(), name=reload_name) + self._pending_reload_tasks.add(task) + task.add_done_callback(self._pending_reload_tasks.discard) + def _determine_approval_mode( self, *candidate_names: str, @@ -971,6 +1006,14 @@ async def load_tools(self) -> None: params = types.PaginatedRequestParams(cursor=tool_list.nextCursor) async def _close_on_owner(self) -> None: + # Cancel any pending reload tasks before tearing down the session. + tasks = list(self._pending_reload_tasks) + for task in tasks: + task.cancel() + self._pending_reload_tasks.clear() + if tasks: + await asyncio.gather(*tasks, return_exceptions=True) + await self._safe_close_exit_stack() self._exit_stack = AsyncExitStack() self.session = None diff --git a/python/packages/core/tests/core/test_mcp.py b/python/packages/core/tests/core/test_mcp.py index eb233eea99..7a70b7b272 100644 --- a/python/packages/core/tests/core/test_mcp.py +++ b/python/packages/core/tests/core/test_mcp.py @@ -1,5 +1,7 @@ # Copyright (c) Microsoft. All rights reserved. # type: ignore[reportPrivateUsage] +import asyncio +import contextlib import json import logging import os @@ -1612,7 +1614,7 @@ async def call_tool_with_error(*args, **kwargs): async def test_mcp_tool_message_handler_notification(): """Test that message_handler correctly processes tools/list_changed and prompts/list_changed - notifications.""" + notifications by scheduling reloads as background tasks.""" tool = MCPStdioTool(name="test_tool", command="python") # Mock the load_tools and load_prompts methods @@ -1626,6 +1628,8 @@ async def test_mcp_tool_message_handler_notification(): result = await tool.message_handler(tools_notification) assert result is None + # The reload is scheduled as a background task; let it run. + await asyncio.sleep(0) tool.load_tools.assert_called_once() # Reset mock @@ -1638,6 +1642,7 @@ async def test_mcp_tool_message_handler_notification(): result = await tool.message_handler(prompts_notification) assert result is None + await asyncio.sleep(0) tool.load_prompts.assert_called_once() # Test unhandled notification @@ -1661,6 +1666,112 @@ async def test_mcp_tool_message_handler_error(): assert result is None +async def test_mcp_tool_message_handler_does_not_block_receive_loop(): + """Test that message_handler does not deadlock the MCP receive loop. + + Regression test for https://github.com/microsoft/agent-framework/issues/4828. + When the MCP server sends a ``notifications/tools/list_changed`` + notification, the handler must NOT await ``load_tools()`` synchronously + because that would block the single-threaded MCP receive loop, preventing + it from delivering the ``list_tools`` response — a classic deadlock. + """ + tool = MCPStdioTool(name="test_tool", command="python") + + # Use an event to make load_tools block until we release it. + # This simulates load_tools waiting for a session response that the + # receive loop would need to deliver. + release = asyncio.Event() + + async def slow_load_tools(): + await release.wait() + + tool.load_tools = slow_load_tools # type: ignore[assignment] + + tools_notification = Mock(spec=types.ServerNotification) + tools_notification.root = Mock() + tools_notification.root.method = "notifications/tools/list_changed" + + # message_handler must return immediately even though load_tools blocks. + await tool.message_handler(tools_notification) + + # If the handler had awaited load_tools synchronously, we would never + # reach this line (deadlock). Verify the reload task is pending. + assert len(tool._pending_reload_tasks) == 1 + + # Unblock the reload so the background task finishes cleanly. + release.set() + # Wait for the pending reload task(s) to complete so their done-callbacks + # have a chance to remove them from _pending_reload_tasks. + await asyncio.wait_for(asyncio.gather(*tool._pending_reload_tasks), timeout=1) + assert len(tool._pending_reload_tasks) == 0 + + +async def test_mcp_tool_message_handler_reload_failure_is_logged(caplog: pytest.LogCaptureFixture): + """Background reload errors are logged, not raised into the receive loop.""" + tool = MCPStdioTool(name="test_tool", command="python") + tool.load_tools = AsyncMock(side_effect=RuntimeError("connection lost")) + + tools_notification = Mock(spec=types.ServerNotification) + tools_notification.root = Mock() + tools_notification.root.method = "notifications/tools/list_changed" + + await tool.message_handler(tools_notification) + # Let the background task run — it should not propagate the exception. + # Snapshot tasks and await them to ensure done-callbacks fire. + pending = list(tool._pending_reload_tasks) + if pending: + await asyncio.wait_for(asyncio.gather(*pending, return_exceptions=True), timeout=1) + tool.load_tools.assert_called_once() + assert len(tool._pending_reload_tasks) == 0 + + # Verify the warning was actually logged with exception info. + reload_warnings = [r for r in caplog.records if "Background MCP reload failed" in r.message] + assert len(reload_warnings) == 1 + assert reload_warnings[0].levelname == "WARNING" + assert reload_warnings[0].exc_info is not None + + +async def test_mcp_tool_message_handler_cancel_and_replace(): + """Sending two notifications in quick succession cancels the first reload task.""" + tool = MCPStdioTool(name="test_tool", command="python") + + release = asyncio.Event() + call_count = 0 + + async def blocking_load_tools(): + nonlocal call_count + call_count += 1 + await release.wait() + + tool.load_tools = blocking_load_tools # type: ignore[assignment] + + notification = Mock(spec=types.ServerNotification) + notification.root = Mock() + notification.root.method = "notifications/tools/list_changed" + + # First notification — starts a blocking reload task. + await tool.message_handler(notification) + assert len(tool._pending_reload_tasks) == 1 + first_task = next(iter(tool._pending_reload_tasks)) + + # Second notification — should cancel the first and replace it. + await tool.message_handler(notification) + # Yield to the event loop so the cancellation is processed. + with contextlib.suppress(asyncio.CancelledError): + await first_task + + assert first_task.cancelled() + + assert len(tool._pending_reload_tasks) == 1 + second_task = next(iter(tool._pending_reload_tasks)) + assert second_task is not first_task + + # Unblock and let the second task finish. + release.set() + await asyncio.wait_for(asyncio.gather(*tool._pending_reload_tasks), timeout=1) + assert len(tool._pending_reload_tasks) == 0 + + async def test_mcp_tool_sampling_callback_no_client(): """Test sampling callback error path when no chat client is available.""" tool = MCPStdioTool(name="test_tool", command="python")