Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 46 additions & 3 deletions python/packages/core/agent_framework/_mcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import re
import sys
from abc import abstractmethod
from collections.abc import Callable, Collection, Sequence
from collections.abc import Callable, Collection, Coroutine, Sequence
from contextlib import AsyncExitStack, _AsyncGeneratorContextManager # type: ignore
from datetime import timedelta
from functools import partial
Expand Down Expand Up @@ -227,6 +227,7 @@ def __init__(
self.is_connected: bool = False
self._tools_loaded: bool = False
self._prompts_loaded: bool = False
self._pending_reload_tasks: set[asyncio.Task[None]] = set()

def __str__(self) -> str:
return f"MCPTool(name={self.name}, description={self.description})"
Expand Down Expand Up @@ -840,12 +841,46 @@ async def message_handler(
if isinstance(message, types.ServerNotification):
match message.root.method:
case "notifications/tools/list_changed":
await self.load_tools()
self._schedule_reload(self.load_tools())
case "notifications/prompts/list_changed":
await self.load_prompts()
self._schedule_reload(self.load_prompts())
case _:
logger.debug("Unhandled notification: %s", message.root.method)

def _schedule_reload(self, coro: Coroutine[Any, Any, None]) -> None:
"""Schedule a reload coroutine as a background task.

Reloads (load_tools / load_prompts) triggered by MCP server
notifications must NOT be awaited inside the message handler because
the handler runs on the MCP SDK's single-threaded receive loop.
Awaiting a session request (e.g. ``list_tools``) from within that loop
deadlocks: the receive loop cannot read the response while it is
blocked waiting for the handler to return.

Instead we fire the reload as an independent ``asyncio.Task`` and keep
a strong reference in ``_pending_reload_tasks`` so it is not garbage-
collected before completion. Only one reload per kind (tools / prompts)
is kept in flight; a new notification cancels the previous pending task
for the same coroutine name to avoid unbounded growth.
"""
# Cancel-and-replace: only one reload per kind should be in flight.
reload_name = f"mcp-reload:{self.name}:{coro.__qualname__}"
for existing in list(self._pending_reload_tasks):
if existing.get_name() == reload_name and not existing.done():
existing.cancel()

async def _safe_reload() -> None:
try:
await coro
except asyncio.CancelledError:
raise
except Exception:
logger.warning("Background MCP reload failed", exc_info=True)

task = asyncio.create_task(_safe_reload(), name=reload_name)
self._pending_reload_tasks.add(task)
task.add_done_callback(self._pending_reload_tasks.discard)

def _determine_approval_mode(
self,
*candidate_names: str,
Expand Down Expand Up @@ -971,6 +1006,14 @@ async def load_tools(self) -> None:
params = types.PaginatedRequestParams(cursor=tool_list.nextCursor)

async def _close_on_owner(self) -> None:
# Cancel any pending reload tasks before tearing down the session.
tasks = list(self._pending_reload_tasks)
for task in tasks:
task.cancel()
self._pending_reload_tasks.clear()
if tasks:
await asyncio.gather(*tasks, return_exceptions=True)

await self._safe_close_exit_stack()
self._exit_stack = AsyncExitStack()
self.session = None
Expand Down
113 changes: 112 additions & 1 deletion python/packages/core/tests/core/test_mcp.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# Copyright (c) Microsoft. All rights reserved.
# type: ignore[reportPrivateUsage]
import asyncio
import contextlib
import json
import logging
import os
Expand Down Expand Up @@ -1612,7 +1614,7 @@ async def call_tool_with_error(*args, **kwargs):

async def test_mcp_tool_message_handler_notification():
"""Test that message_handler correctly processes tools/list_changed and prompts/list_changed
notifications."""
notifications by scheduling reloads as background tasks."""
tool = MCPStdioTool(name="test_tool", command="python")

# Mock the load_tools and load_prompts methods
Expand All @@ -1626,6 +1628,8 @@ async def test_mcp_tool_message_handler_notification():

result = await tool.message_handler(tools_notification)
assert result is None
# The reload is scheduled as a background task; let it run.
await asyncio.sleep(0)
tool.load_tools.assert_called_once()

# Reset mock
Expand All @@ -1638,6 +1642,7 @@ async def test_mcp_tool_message_handler_notification():

result = await tool.message_handler(prompts_notification)
assert result is None
await asyncio.sleep(0)
tool.load_prompts.assert_called_once()

# Test unhandled notification
Expand All @@ -1661,6 +1666,112 @@ async def test_mcp_tool_message_handler_error():
assert result is None


async def test_mcp_tool_message_handler_does_not_block_receive_loop():
"""Test that message_handler does not deadlock the MCP receive loop.

Regression test for https://github.com/microsoft/agent-framework/issues/4828.
When the MCP server sends a ``notifications/tools/list_changed``
notification, the handler must NOT await ``load_tools()`` synchronously
because that would block the single-threaded MCP receive loop, preventing
it from delivering the ``list_tools`` response — a classic deadlock.
"""
tool = MCPStdioTool(name="test_tool", command="python")

# Use an event to make load_tools block until we release it.
# This simulates load_tools waiting for a session response that the
# receive loop would need to deliver.
release = asyncio.Event()

async def slow_load_tools():
await release.wait()

tool.load_tools = slow_load_tools # type: ignore[assignment]

tools_notification = Mock(spec=types.ServerNotification)
tools_notification.root = Mock()
tools_notification.root.method = "notifications/tools/list_changed"

# message_handler must return immediately even though load_tools blocks.
await tool.message_handler(tools_notification)

# If the handler had awaited load_tools synchronously, we would never
# reach this line (deadlock). Verify the reload task is pending.
assert len(tool._pending_reload_tasks) == 1

# Unblock the reload so the background task finishes cleanly.
release.set()
# Wait for the pending reload task(s) to complete so their done-callbacks
# have a chance to remove them from _pending_reload_tasks.
await asyncio.wait_for(asyncio.gather(*tool._pending_reload_tasks), timeout=1)
assert len(tool._pending_reload_tasks) == 0


async def test_mcp_tool_message_handler_reload_failure_is_logged(caplog: pytest.LogCaptureFixture):
"""Background reload errors are logged, not raised into the receive loop."""
tool = MCPStdioTool(name="test_tool", command="python")
tool.load_tools = AsyncMock(side_effect=RuntimeError("connection lost"))

tools_notification = Mock(spec=types.ServerNotification)
tools_notification.root = Mock()
tools_notification.root.method = "notifications/tools/list_changed"

await tool.message_handler(tools_notification)
# Let the background task run — it should not propagate the exception.
# Snapshot tasks and await them to ensure done-callbacks fire.
pending = list(tool._pending_reload_tasks)
if pending:
await asyncio.wait_for(asyncio.gather(*pending, return_exceptions=True), timeout=1)
tool.load_tools.assert_called_once()
assert len(tool._pending_reload_tasks) == 0

# Verify the warning was actually logged with exception info.
reload_warnings = [r for r in caplog.records if "Background MCP reload failed" in r.message]
assert len(reload_warnings) == 1
assert reload_warnings[0].levelname == "WARNING"
assert reload_warnings[0].exc_info is not None


async def test_mcp_tool_message_handler_cancel_and_replace():
"""Sending two notifications in quick succession cancels the first reload task."""
tool = MCPStdioTool(name="test_tool", command="python")

release = asyncio.Event()
call_count = 0

async def blocking_load_tools():
nonlocal call_count
call_count += 1
await release.wait()

tool.load_tools = blocking_load_tools # type: ignore[assignment]

notification = Mock(spec=types.ServerNotification)
notification.root = Mock()
notification.root.method = "notifications/tools/list_changed"

# First notification — starts a blocking reload task.
await tool.message_handler(notification)
assert len(tool._pending_reload_tasks) == 1
first_task = next(iter(tool._pending_reload_tasks))

# Second notification — should cancel the first and replace it.
await tool.message_handler(notification)
# Yield to the event loop so the cancellation is processed.
with contextlib.suppress(asyncio.CancelledError):
await first_task

assert first_task.cancelled()

assert len(tool._pending_reload_tasks) == 1
second_task = next(iter(tool._pending_reload_tasks))
assert second_task is not first_task

# Unblock and let the second task finish.
release.set()
await asyncio.wait_for(asyncio.gather(*tool._pending_reload_tasks), timeout=1)
assert len(tool._pending_reload_tasks) == 0


async def test_mcp_tool_sampling_callback_no_client():
"""Test sampling callback error path when no chat client is available."""
tool = MCPStdioTool(name="test_tool", command="python")
Expand Down
Loading