Skip to content

Commit 9eb10e3

Browse files
committed
fix: drain stdio responses after stdin EOF
1 parent ac96f88 commit 9eb10e3

6 files changed

Lines changed: 162 additions & 9 deletions

File tree

src/mcp/server/lowlevel/server.py

Lines changed: 19 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -45,6 +45,7 @@ async def main():
4545
from typing import Any, Generic, cast
4646

4747
import anyio
48+
from anyio.lowlevel import checkpoint
4849
from opentelemetry.trace import SpanKind, StatusCode
4950
from starlette.applications import Starlette
5051
from starlette.middleware import Middleware
@@ -74,6 +75,8 @@ async def main():
7475

7576
LifespanResultT = TypeVar("LifespanResultT", default=Any)
7677

78+
STDIO_READ_EOF_RESPONSE_DRAIN_TIMEOUT = 5.0
79+
7780

7881
class NotificationOptions:
7982
def __init__(self, prompts_changed: bool = False, resources_changed: bool = False, tools_changed: bool = False):
@@ -347,6 +350,8 @@ async def run(
347350
# the initialization lifecycle, but can do so with any available node
348351
# rather than requiring initialization for each connection.
349352
stateless: bool = False,
353+
drain_in_flight_on_read_eof: bool = False,
354+
read_eof_response_drain_timeout: float = STDIO_READ_EOF_RESPONSE_DRAIN_TIMEOUT,
350355
):
351356
async with AsyncExitStack() as stack:
352357
lifespan_context = await stack.enter_async_context(self.lifespan(self))
@@ -356,6 +361,7 @@ async def run(
356361
write_stream,
357362
initialization_options,
358363
stateless=stateless,
364+
close_write_stream_on_read_eof=not drain_in_flight_on_read_eof,
359365
)
360366
)
361367

@@ -378,11 +384,19 @@ async def run(
378384
raise_exceptions,
379385
)
380386
finally:
381-
# Transport closed: cancel in-flight handlers. Without this the
382-
# TG join waits for them, and when they eventually try to
383-
# respond they hit a closed write stream (the session's
384-
# _receive_loop closed it when the read stream ended).
385-
tg.cancel_scope.cancel()
387+
cancel_in_flight = True
388+
if drain_in_flight_on_read_eof:
389+
with anyio.move_on_after(read_eof_response_drain_timeout) as drain_scope:
390+
while session.has_in_flight_requests:
391+
await checkpoint()
392+
cancel_in_flight = drain_scope.cancel_called
393+
394+
# Transport closed or drain timed out: cancel in-flight handlers.
395+
# Without this the TG join can wait indefinitely, or handlers can
396+
# eventually try to respond through a write stream that the session
397+
# closed when the read stream ended.
398+
if cancel_in_flight:
399+
tg.cancel_scope.cancel()
386400

387401
async def _handle_message(
388402
self,

src/mcp/server/mcpserver/server.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -852,6 +852,7 @@ async def run_stdio_async(self) -> None:
852852
read_stream,
853853
write_stream,
854854
self._lowlevel_server.create_initialization_options(),
855+
drain_in_flight_on_read_eof=True,
855856
)
856857

857858
async def run_sse_async( # pragma: no cover

src/mcp/server/session.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,13 @@ def __init__(
8080
write_stream: WriteStream[SessionMessage],
8181
init_options: InitializationOptions,
8282
stateless: bool = False,
83+
close_write_stream_on_read_eof: bool = True,
8384
) -> None:
84-
super().__init__(read_stream, write_stream)
85+
super().__init__(
86+
read_stream,
87+
write_stream,
88+
close_write_stream_on_read_eof=close_write_stream_on_read_eof,
89+
)
8590
self._stateless = stateless
8691
self._initialization_state = (
8792
InitializationState.Initialized if stateless else InitializationState.NotInitialized

src/mcp/shared/session.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -189,12 +189,14 @@ def __init__(
189189
write_stream: WriteStream[SessionMessage],
190190
# If none, reading will never time out
191191
read_timeout_seconds: float | None = None,
192+
close_write_stream_on_read_eof: bool = True,
192193
) -> None:
193194
self._read_stream = read_stream
194195
self._write_stream = write_stream
195196
self._response_streams = {}
196197
self._request_id = 0
197198
self._session_read_timeout_seconds = read_timeout_seconds
199+
self._close_write_stream_on_read_eof = close_write_stream_on_read_eof
198200
self._in_flight = {}
199201
self._progress_callbacks = {}
200202
self._exit_stack = AsyncExitStack()
@@ -216,7 +218,10 @@ async def __aexit__(
216218
# would be very surprising behavior), so make sure to cancel the tasks
217219
# in the task group.
218220
self._task_group.cancel_scope.cancel()
219-
return await self._task_group.__aexit__(exc_type, exc_val, exc_tb)
221+
result = await self._task_group.__aexit__(exc_type, exc_val, exc_tb)
222+
if not self._close_write_stream_on_read_eof:
223+
await self._write_stream.aclose()
224+
return result
220225

221226
async def send_request(
222227
self,
@@ -331,7 +336,10 @@ def _receive_notification_adapter(self) -> TypeAdapter[ReceiveNotificationT]:
331336
raise NotImplementedError
332337

333338
async def _receive_loop(self) -> None:
334-
async with self._read_stream, self._write_stream:
339+
async with AsyncExitStack() as stream_stack:
340+
await stream_stack.enter_async_context(self._read_stream)
341+
if self._close_write_stream_on_read_eof:
342+
await stream_stack.enter_async_context(self._write_stream)
335343
try:
336344

337345
async def _handle_session_message(message: SessionMessage) -> None:
@@ -438,6 +446,11 @@ async def _handle_session_message(message: SessionMessage) -> None:
438446
pass
439447
self._response_streams.clear()
440448

449+
@property
450+
def has_in_flight_requests(self) -> bool:
451+
"""Return whether client requests are still being handled."""
452+
return bool(self._in_flight)
453+
441454
def _normalize_request_id(self, response_id: RequestId) -> RequestId:
442455
"""Normalize a response ID to match how request IDs are stored.
443456

tests/server/test_cancel_handling.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,71 @@ async def run_server():
172172
assert handler_cancelled.is_set()
173173

174174

175+
@pytest.mark.anyio
176+
async def test_server_cancels_in_flight_handlers_when_read_eof_drain_times_out():
177+
"""A bounded read-EOF drain still cancels handlers that never finish."""
178+
handler_started = anyio.Event()
179+
handler_cancelled = anyio.Event()
180+
server_run_returned = anyio.Event()
181+
182+
async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult:
183+
handler_started.set()
184+
try:
185+
await anyio.sleep_forever()
186+
finally:
187+
handler_cancelled.set()
188+
raise AssertionError # pragma: no cover
189+
190+
server = Server("test", on_call_tool=handle_call_tool)
191+
192+
to_server, server_read = anyio.create_memory_object_stream[SessionMessage | Exception](10)
193+
server_write, from_server = anyio.create_memory_object_stream[SessionMessage](10)
194+
195+
async def run_server():
196+
await server.run(
197+
server_read,
198+
server_write,
199+
server.create_initialization_options(),
200+
drain_in_flight_on_read_eof=True,
201+
read_eof_response_drain_timeout=0.01,
202+
)
203+
server_run_returned.set()
204+
205+
init_req = JSONRPCRequest(
206+
jsonrpc="2.0",
207+
id=1,
208+
method="initialize",
209+
params=InitializeRequestParams(
210+
protocol_version=LATEST_PROTOCOL_VERSION,
211+
capabilities=ClientCapabilities(),
212+
client_info=Implementation(name="test", version="1.0"),
213+
).model_dump(by_alias=True, mode="json", exclude_none=True),
214+
)
215+
initialized = JSONRPCNotification(jsonrpc="2.0", method="notifications/initialized")
216+
call_req = JSONRPCRequest(
217+
jsonrpc="2.0",
218+
id=2,
219+
method="tools/call",
220+
params=CallToolRequestParams(name="slow", arguments={}).model_dump(by_alias=True, mode="json"),
221+
)
222+
223+
with anyio.fail_after(5):
224+
async with anyio.create_task_group() as tg, to_server, server_read, server_write, from_server:
225+
tg.start_soon(run_server)
226+
227+
await to_server.send(SessionMessage(init_req))
228+
await from_server.receive()
229+
await to_server.send(SessionMessage(initialized))
230+
await to_server.send(SessionMessage(call_req))
231+
232+
await handler_started.wait()
233+
await to_server.aclose()
234+
235+
await server_run_returned.wait()
236+
237+
assert handler_cancelled.is_set()
238+
239+
175240
@pytest.mark.anyio
176241
async def test_server_handles_transport_close_with_pending_server_to_client_requests():
177242
"""When the transport closes while handlers are blocked on server→client

tests/server/test_stdio.py

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import io
2+
import json
23
import sys
34
import threading
45
from collections.abc import AsyncIterator
@@ -7,11 +8,12 @@
78

89
import anyio
910
import pytest
11+
from anyio.lowlevel import checkpoint
1012

1113
from mcp.server.mcpserver import MCPServer
1214
from mcp.server.stdio import stdio_server
1315
from mcp.shared.message import SessionMessage
14-
from mcp.types import JSONRPCMessage, JSONRPCRequest, JSONRPCResponse, jsonrpc_message_adapter
16+
from mcp.types import JSONRPCMessage, JSONRPCNotification, JSONRPCRequest, JSONRPCResponse, jsonrpc_message_adapter
1517

1618

1719
@pytest.mark.anyio
@@ -142,6 +144,59 @@ def test_mcpserver_run_stdio_serves_until_stdin_closes(monkeypatch: pytest.Monke
142144
assert response == JSONRPCResponse(jsonrpc="2.0", id=1, result={})
143145

144146

147+
def test_mcpserver_run_stdio_drains_in_flight_tool_responses_after_stdin_eof(
148+
monkeypatch: pytest.MonkeyPatch,
149+
) -> None:
150+
"""stdin EOF must not drop responses for requests the server already accepted."""
151+
server = MCPServer(name="DrainStdioServer")
152+
153+
@server.tool()
154+
async def slow_echo(text: str) -> str:
155+
await checkpoint()
156+
return text
157+
158+
payload_lines = [
159+
JSONRPCRequest(
160+
jsonrpc="2.0",
161+
id=0,
162+
method="initialize",
163+
params={
164+
"protocolVersion": "2024-11-05",
165+
"capabilities": {},
166+
"clientInfo": {"name": "stdio-replay", "version": "0.1"},
167+
},
168+
).model_dump_json(by_alias=True, exclude_none=True),
169+
JSONRPCNotification(jsonrpc="2.0", method="notifications/initialized", params={}).model_dump_json(
170+
by_alias=True, exclude_none=True
171+
),
172+
JSONRPCRequest(
173+
jsonrpc="2.0",
174+
id=1,
175+
method="tools/call",
176+
params={"name": "slow_echo", "arguments": {"text": "first"}},
177+
).model_dump_json(by_alias=True, exclude_none=True),
178+
JSONRPCRequest(
179+
jsonrpc="2.0",
180+
id=2,
181+
method="tools/call",
182+
params={"name": "slow_echo", "arguments": {"text": "second"}},
183+
).model_dump_json(by_alias=True, exclude_none=True),
184+
]
185+
stdin_bytes = io.BytesIO(("\n".join(payload_lines) + "\n").encode())
186+
captured = _KeepOpenBytesIO()
187+
monkeypatch.setattr(sys, "stdin", TextIOWrapper(stdin_bytes, encoding="utf-8"))
188+
monkeypatch.setattr(sys, "stdout", TextIOWrapper(captured, encoding="utf-8"))
189+
190+
_run_stdio_bounded(server)
191+
192+
output = captured.getvalue().decode()
193+
responses = [json.loads(line) for line in output.splitlines() if line]
194+
195+
assert [response["id"] for response in responses] == [0, 1, 2]
196+
assert responses[1]["result"]["content"][0]["text"] == "first"
197+
assert responses[2]["result"]["content"][0]["text"] == "second"
198+
199+
145200
def test_mcpserver_run_stdio_runs_lifespan_cleanup_after_stdin_closes(monkeypatch: pytest.MonkeyPatch) -> None:
146201
"""Code after `yield` in a lifespan runs when stdin EOF ends `run("stdio")`.
147202

0 commit comments

Comments
 (0)