1616import logging
1717from collections .abc import Mapping
1818from contextlib import AsyncExitStack
19- from typing import Any
19+ from typing import Any , TypeVar , overload
2020
2121import anyio
22+ from pydantic import BaseModel
2223
23- from mcp .server ._typed_request import TypedServerRequestMixin
2424from mcp .shared .dispatcher import CallOptions , Outbound
2525from mcp .shared .exceptions import NoBackChannelError
2626from mcp .shared .peer import Meta , dump_params
27- from mcp .types import ClientCapabilities , InitializeRequestParams , LoggingLevel
27+ from mcp .types import (
28+ ClientCapabilities ,
29+ CreateMessageRequest ,
30+ CreateMessageResult ,
31+ ElicitRequest ,
32+ ElicitResult ,
33+ EmptyResult ,
34+ InitializeRequestParams ,
35+ ListRootsRequest ,
36+ ListRootsResult ,
37+ LoggingLevel ,
38+ PingRequest ,
39+ Request ,
40+ )
2841
2942__all__ = ["Connection" ]
3043
3144logger = logging .getLogger (__name__ )
3245
46+ ResultT = TypeVar ("ResultT" , bound = BaseModel )
47+
48+ # Result types for the spec's server-to-client request set, used by
49+ # `Connection.send_request` to infer the result type. If the spec's request
50+ # set grows substantially, consider declaring the result mapping on the
51+ # request types themselves (a `__mcp_result__` ClassVar read via a structural
52+ # protocol) so this table and the overload ladder don't need maintaining.
53+ _RESULT_FOR : dict [type [Request [Any , Any ]], type [BaseModel ]] = {
54+ CreateMessageRequest : CreateMessageResult ,
55+ ElicitRequest : ElicitResult ,
56+ ListRootsRequest : ListRootsResult ,
57+ PingRequest : EmptyResult ,
58+ }
59+
3360
3461def _notification_params (payload : dict [str , Any ] | None , meta : Meta | None ) -> dict [str , Any ] | None :
3562 if not meta :
@@ -39,7 +66,7 @@ def _notification_params(payload: dict[str, Any] | None, meta: Meta | None) -> d
3966 return out
4067
4168
42- class Connection ( TypedServerRequestMixin ) :
69+ class Connection :
4370 """Per-client connection state and standalone-stream `Outbound`.
4471
4572 Constructed by `ServerRunner` once per connection. The peer-info fields
@@ -98,10 +125,10 @@ async def send_raw_request(
98125 ) -> dict [str , Any ]:
99126 """Send a raw request on the standalone stream.
100127
101- Low-level `Outbound` channel. Prefer the typed `send_request` (from
102- `TypedServerRequestMixin`) or the convenience methods below; use this
103- directly only for off-spec messages. `opts` carries per-call `timeout`
104- / `on_progress` / resumption hints; see `CallOptions`.
128+ Low-level `Outbound` channel. Prefer the typed `send_request` or the
129+ convenience methods below; use this directly only for off-spec
130+ messages. `opts` carries per-call `timeout` / `on_progress` /
131+ resumption hints; see `CallOptions`.
105132
106133 Raises:
107134 MCPError: The peer responded with an error.
@@ -111,6 +138,42 @@ async def send_raw_request(
111138 raise NoBackChannelError (method )
112139 return await self ._outbound .send_raw_request (method , params , opts )
113140
141+ @overload
142+ async def send_request (
143+ self , req : CreateMessageRequest , * , opts : CallOptions | None = None
144+ ) -> CreateMessageResult : ...
145+ @overload
146+ async def send_request (self , req : ElicitRequest , * , opts : CallOptions | None = None ) -> ElicitResult : ...
147+ @overload
148+ async def send_request (self , req : ListRootsRequest , * , opts : CallOptions | None = None ) -> ListRootsResult : ...
149+ @overload
150+ async def send_request (self , req : PingRequest , * , opts : CallOptions | None = None ) -> EmptyResult : ...
151+ @overload
152+ async def send_request (
153+ self , req : Request [Any , Any ], * , result_type : type [ResultT ], opts : CallOptions | None = None
154+ ) -> ResultT : ...
155+ async def send_request (
156+ self ,
157+ req : Request [Any , Any ],
158+ * ,
159+ result_type : type [BaseModel ] | None = None ,
160+ opts : CallOptions | None = None ,
161+ ) -> BaseModel :
162+ """Send a typed server-to-client request and return its typed result.
163+
164+ For spec request types the result type is inferred. For custom requests
165+ pass `result_type=` explicitly.
166+
167+ Raises:
168+ MCPError: The peer responded with an error.
169+ NoBackChannelError: No back-channel for server-initiated requests.
170+ pydantic.ValidationError: The peer's result does not match the expected result type.
171+ KeyError: `result_type` omitted for a non-spec request type.
172+ """
173+ raw = await self .send_raw_request (req .method , dump_params (req .params ), opts )
174+ cls = result_type if result_type is not None else _RESULT_FOR [type (req )]
175+ return cls .model_validate (raw , by_name = False )
176+
114177 async def notify (self , method : str , params : Mapping [str , Any ] | None ) -> None :
115178 """Send a best-effort notification on the standalone stream.
116179
0 commit comments