Skip to content

Commit 5f14f6d

Browse files
Respect negotiated capabilities in ClientSessionGroup
1 parent ac96f88 commit 5f14f6d

2 files changed

Lines changed: 77 additions & 28 deletions

File tree

src/mcp/client/session_group.py

Lines changed: 34 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -344,39 +344,45 @@ async def _aggregate_components(self, server_info: types.Implementation, session
344344
tools_temp: dict[str, types.Tool] = {}
345345
tool_to_session_temp: dict[str, mcp.ClientSession] = {}
346346

347+
# Query capabilities negotiated during initialize().
348+
capabilities = (
349+
session.initialize_result.capabilities
350+
if session.initialize_result is not None
351+
else None
352+
)
347353
# Query the server for its prompts and aggregate to list.
348-
try:
349-
prompts = (await session.list_prompts()).prompts
350-
for prompt in prompts:
351-
name = self._component_name(prompt.name, server_info)
352-
prompts_temp[name] = prompt
353-
component_names.prompts.add(name)
354-
except MCPError as err: # pragma: no cover
355-
logging.warning(f"Could not fetch prompts: {err}")
354+
if capabilities is None or capabilities.prompts is not None:
355+
try:
356+
prompts = (await session.list_prompts()).prompts
357+
for prompt in prompts:
358+
name = self._component_name(prompt.name, server_info)
359+
prompts_temp[name] = prompt
360+
component_names.prompts.add(name)
361+
except MCPError as err: # pragma: no cover
362+
logging.warning(f"Could not fetch prompts: {err}")
356363

357364
# Query the server for its resources and aggregate to list.
358-
try:
359-
resources = (await session.list_resources()).resources
360-
for resource in resources:
361-
name = self._component_name(resource.name, server_info)
362-
resources_temp[name] = resource
363-
component_names.resources.add(name)
364-
except MCPError as err: # pragma: no cover
365-
logging.warning(f"Could not fetch resources: {err}")
365+
if capabilities is None or capabilities.resources is not None:
366+
try:
367+
resources = (await session.list_resources()).resources
368+
for resource in resources:
369+
name = self._component_name(resource.name, server_info)
370+
resources_temp[name] = resource
371+
component_names.resources.add(name)
372+
except MCPError as err: # pragma: no cover
373+
logging.warning(f"Could not fetch resources: {err}")
366374

367375
# Query the server for its tools and aggregate to list.
368-
try:
369-
tools = (await session.list_tools()).tools
370-
for tool in tools:
371-
name = self._component_name(tool.name, server_info)
372-
tools_temp[name] = tool
373-
tool_to_session_temp[name] = session
374-
component_names.tools.add(name)
375-
except MCPError as err: # pragma: no cover
376-
logging.warning(f"Could not fetch tools: {err}")
377-
378-
# Clean up exit stack for session if we couldn't retrieve anything
379-
# from the server.
376+
if capabilities is None or capabilities.tools is not None:
377+
try:
378+
tools = (await session.list_tools()).tools
379+
for tool in tools:
380+
name = self._component_name(tool.name, server_info)
381+
tools_temp[name] = tool
382+
tool_to_session_temp[name] = session
383+
component_names.tools.add(name)
384+
except MCPError as err: # pragma: no cover
385+
logging.warning(f"Could not fetch tools: {err}")
380386
if not any((prompts_temp, resources_temp, tools_temp)):
381387
del self._session_exit_stacks[session] # pragma: no cover
382388

tests/client/test_session_group.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,49 @@ async def test_client_session_group_connect_to_server(mock_exit_stack: contextli
125125
mock_session.list_prompts.assert_awaited_once()
126126

127127

128+
@pytest.mark.anyio
129+
130+
@pytest.mark.anyio
131+
async def test_client_session_group_skips_unsupported_capabilities(
132+
mock_exit_stack: contextlib.AsyncExitStack,
133+
):
134+
"""Only query capabilities advertised by the server."""
135+
136+
mock_server_info = mock.Mock(spec=types.Implementation)
137+
mock_server_info.name = "ToolsOnlyServer"
138+
139+
mock_session = mock.AsyncMock(spec=mcp.ClientSession)
140+
141+
mock_tool = mock.Mock(spec=types.Tool)
142+
mock_tool.name = "ping"
143+
144+
mock_session.list_tools.return_value = mock.AsyncMock(tools=[mock_tool])
145+
mock_session.list_resources.return_value = mock.AsyncMock(resources=[])
146+
mock_session.list_prompts.return_value = mock.AsyncMock(prompts=[])
147+
148+
capabilities = mock.Mock()
149+
capabilities.tools = object()
150+
capabilities.prompts = None
151+
capabilities.resources = None
152+
153+
initialize_result = mock.Mock()
154+
initialize_result.capabilities = capabilities
155+
156+
mock_session.initialize_result = initialize_result
157+
158+
group = ClientSessionGroup(exit_stack=mock_exit_stack)
159+
160+
await group._aggregate_components(
161+
mock_server_info,
162+
mock_session,
163+
)
164+
165+
mock_session.list_tools.assert_awaited_once()
166+
mock_session.list_prompts.assert_not_awaited()
167+
mock_session.list_resources.assert_not_awaited()
168+
169+
assert "ping" in group.tools
170+
128171
@pytest.mark.anyio
129172
async def test_client_session_group_connect_to_server_with_name_hook(mock_exit_stack: contextlib.AsyncExitStack):
130173
"""Test connecting with a component name hook."""

0 commit comments

Comments
 (0)