|
1 | 1 | import io |
| 2 | +import json |
2 | 3 | import sys |
3 | 4 | import threading |
4 | 5 | from collections.abc import AsyncIterator |
|
7 | 8 |
|
8 | 9 | import anyio |
9 | 10 | import pytest |
| 11 | +from anyio.lowlevel import checkpoint |
10 | 12 |
|
11 | 13 | from mcp.server.mcpserver import MCPServer |
12 | 14 | from mcp.server.stdio import stdio_server |
13 | 15 | from mcp.shared.message import SessionMessage |
14 | | -from mcp.types import JSONRPCMessage, JSONRPCRequest, JSONRPCResponse, jsonrpc_message_adapter |
| 16 | +from mcp.types import JSONRPCMessage, JSONRPCNotification, JSONRPCRequest, JSONRPCResponse, jsonrpc_message_adapter |
15 | 17 |
|
16 | 18 |
|
17 | 19 | @pytest.mark.anyio |
@@ -142,6 +144,59 @@ def test_mcpserver_run_stdio_serves_until_stdin_closes(monkeypatch: pytest.Monke |
142 | 144 | assert response == JSONRPCResponse(jsonrpc="2.0", id=1, result={}) |
143 | 145 |
|
144 | 146 |
|
| 147 | +def test_mcpserver_run_stdio_drains_in_flight_tool_responses_after_stdin_eof( |
| 148 | + monkeypatch: pytest.MonkeyPatch, |
| 149 | +) -> None: |
| 150 | + """stdin EOF must not drop responses for requests the server already accepted.""" |
| 151 | + server = MCPServer(name="DrainStdioServer") |
| 152 | + |
| 153 | + @server.tool() |
| 154 | + async def slow_echo(text: str) -> str: |
| 155 | + await checkpoint() |
| 156 | + return text |
| 157 | + |
| 158 | + payload_lines = [ |
| 159 | + JSONRPCRequest( |
| 160 | + jsonrpc="2.0", |
| 161 | + id=0, |
| 162 | + method="initialize", |
| 163 | + params={ |
| 164 | + "protocolVersion": "2024-11-05", |
| 165 | + "capabilities": {}, |
| 166 | + "clientInfo": {"name": "stdio-replay", "version": "0.1"}, |
| 167 | + }, |
| 168 | + ).model_dump_json(by_alias=True, exclude_none=True), |
| 169 | + JSONRPCNotification(jsonrpc="2.0", method="notifications/initialized", params={}).model_dump_json( |
| 170 | + by_alias=True, exclude_none=True |
| 171 | + ), |
| 172 | + JSONRPCRequest( |
| 173 | + jsonrpc="2.0", |
| 174 | + id=1, |
| 175 | + method="tools/call", |
| 176 | + params={"name": "slow_echo", "arguments": {"text": "first"}}, |
| 177 | + ).model_dump_json(by_alias=True, exclude_none=True), |
| 178 | + JSONRPCRequest( |
| 179 | + jsonrpc="2.0", |
| 180 | + id=2, |
| 181 | + method="tools/call", |
| 182 | + params={"name": "slow_echo", "arguments": {"text": "second"}}, |
| 183 | + ).model_dump_json(by_alias=True, exclude_none=True), |
| 184 | + ] |
| 185 | + stdin_bytes = io.BytesIO(("\n".join(payload_lines) + "\n").encode()) |
| 186 | + captured = _KeepOpenBytesIO() |
| 187 | + monkeypatch.setattr(sys, "stdin", TextIOWrapper(stdin_bytes, encoding="utf-8")) |
| 188 | + monkeypatch.setattr(sys, "stdout", TextIOWrapper(captured, encoding="utf-8")) |
| 189 | + |
| 190 | + _run_stdio_bounded(server) |
| 191 | + |
| 192 | + output = captured.getvalue().decode() |
| 193 | + responses = [json.loads(line) for line in output.splitlines() if line] |
| 194 | + |
| 195 | + assert [response["id"] for response in responses] == [0, 1, 2] |
| 196 | + assert responses[1]["result"]["content"][0]["text"] == "first" |
| 197 | + assert responses[2]["result"]["content"][0]["text"] == "second" |
| 198 | + |
| 199 | + |
145 | 200 | def test_mcpserver_run_stdio_runs_lifespan_cleanup_after_stdin_closes(monkeypatch: pytest.MonkeyPatch) -> None: |
146 | 201 | """Code after `yield` in a lifespan runs when stdin EOF ends `run("stdio")`. |
147 | 202 |
|
|
0 commit comments