Skip to content

Commit a8efe77

Browse files
committed
fix(stdio): bound EOF drain wait
1 parent 5332b0e commit a8efe77

3 files changed

Lines changed: 82 additions & 1 deletion

File tree

src/mcp/server/lowlevel/server.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,8 @@ async def main():
7272

7373
logger = logging.getLogger(__name__)
7474

75+
DEFAULT_READ_EOF_DRAIN_TIMEOUT_SECONDS = 1.0
76+
7577
LifespanResultT = TypeVar("LifespanResultT", default=Any)
7678

7779

@@ -351,6 +353,9 @@ async def run(
351353
# to drain their responses via the still-open write stream (e.g. stdio
352354
# with bash-redirected stdin).
353355
drain_on_read_close: bool = False,
356+
# Maximum time to wait for in-flight handlers to drain after read EOF.
357+
# None means wait indefinitely.
358+
read_eof_drain_timeout_seconds: float | None = DEFAULT_READ_EOF_DRAIN_TIMEOUT_SECONDS,
354359
):
355360
async with AsyncExitStack() as stack:
356361
lifespan_context = await stack.enter_async_context(self.lifespan(self))
@@ -383,7 +388,14 @@ async def run(
383388
raise_exceptions,
384389
)
385390
finally:
386-
if not drain_on_read_close:
391+
if drain_on_read_close:
392+
if read_eof_drain_timeout_seconds is not None:
393+
with anyio.move_on_after(read_eof_drain_timeout_seconds) as drain_scope:
394+
while session.has_in_flight_requests:
395+
await anyio.sleep(0.01)
396+
if drain_scope.cancelled_caught:
397+
tg.cancel_scope.cancel()
398+
else:
387399
# Transport closed: cancel in-flight handlers. Without this the
388400
# TG join waits for them, and when they eventually try to
389401
# respond they hit a closed write stream (the session's

src/mcp/shared/session.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,11 @@ def __init__(
209209
self._exit_stack.push_async_callback(self._read_stream.aclose)
210210
self._exit_stack.push_async_callback(self._write_stream.aclose)
211211

212+
@property
213+
def has_in_flight_requests(self) -> bool:
214+
"""Whether any received requests are still awaiting a response."""
215+
return bool(self._in_flight)
216+
212217
async def __aenter__(self) -> Self:
213218
self._task_group = anyio.create_task_group()
214219
await self._task_group.__aenter__()

tests/server/test_cancel_handling.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,70 @@ async def run_server():
166166
await server_run_returned.wait()
167167

168168

169+
@pytest.mark.anyio
170+
async def test_server_bounds_drain_on_read_eof_when_handler_never_finishes():
171+
handler_started = anyio.Event()
172+
handler_cancelled = anyio.Event()
173+
server_run_returned = anyio.Event()
174+
175+
async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult:
176+
handler_started.set()
177+
try:
178+
await anyio.sleep_forever()
179+
finally:
180+
handler_cancelled.set()
181+
raise AssertionError # pragma: no cover
182+
183+
server = Server("test", on_call_tool=handle_call_tool)
184+
185+
to_server, server_read = anyio.create_memory_object_stream[SessionMessage | Exception](10)
186+
server_write, from_server = anyio.create_memory_object_stream[SessionMessage](10)
187+
188+
async def run_server():
189+
await server.run(
190+
server_read,
191+
server_write,
192+
server.create_initialization_options(),
193+
drain_on_read_close=True,
194+
read_eof_drain_timeout_seconds=0.05,
195+
)
196+
server_run_returned.set()
197+
198+
init_req = JSONRPCRequest(
199+
jsonrpc="2.0",
200+
id=1,
201+
method="initialize",
202+
params=InitializeRequestParams(
203+
protocol_version=LATEST_PROTOCOL_VERSION,
204+
capabilities=ClientCapabilities(),
205+
client_info=Implementation(name="test", version="1.0"),
206+
).model_dump(by_alias=True, mode="json", exclude_none=True),
207+
)
208+
initialized = JSONRPCNotification(jsonrpc="2.0", method="notifications/initialized")
209+
call_req = JSONRPCRequest(
210+
jsonrpc="2.0",
211+
id=2,
212+
method="tools/call",
213+
params=CallToolRequestParams(name="slow", arguments={}).model_dump(by_alias=True, mode="json"),
214+
)
215+
216+
with anyio.fail_after(2):
217+
async with anyio.create_task_group() as tg, to_server, server_read, server_write, from_server:
218+
tg.start_soon(run_server)
219+
220+
await to_server.send(SessionMessage(init_req))
221+
await from_server.receive() # init response
222+
await to_server.send(SessionMessage(initialized))
223+
await to_server.send(SessionMessage(call_req))
224+
225+
await handler_started.wait()
226+
await to_server.aclose()
227+
228+
await server_run_returned.wait()
229+
230+
assert handler_cancelled.is_set()
231+
232+
169233
@pytest.mark.anyio
170234
async def test_server_reraises_handler_cancellation_when_server_is_cancelled():
171235
"""If the server task is cancelled (e.g. KeyboardInterrupt), in-flight

0 commit comments

Comments
 (0)