Skip to content

Commit 93a88fa

Browse files
committed
fix: drain terminal streamable HTTP responses
1 parent ac96f88 commit 93a88fa

3 files changed

Lines changed: 147 additions & 10 deletions

File tree

src/mcp/client/streamable_http.py

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -240,16 +240,19 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None:
240240
event_source.response.raise_for_status()
241241
logger.debug("Resumption GET SSE connection established")
242242

243+
response_complete = False
243244
async for sse in event_source.aiter_sse(): # pragma: no branch
245+
if response_complete:
246+
continue
247+
244248
is_complete = await self._handle_sse_event(
245249
sse,
246250
ctx.read_stream_writer,
247251
original_request_id,
248252
ctx.metadata.on_resumption_token_update if ctx.metadata else None,
249253
)
250254
if is_complete:
251-
await event_source.response.aclose()
252-
break
255+
response_complete = True
253256

254257
async def _handle_post_request(self, ctx: RequestContext) -> None:
255258
"""Handle a POST request with response processing."""
@@ -340,9 +343,13 @@ async def _handle_sse_response(
340343
assert isinstance(ctx.session_message.message, JSONRPCRequest)
341344
original_request_id = ctx.session_message.message.id
342345

346+
response_complete = False
343347
try:
344348
event_source = EventSource(response)
345349
async for sse in event_source.aiter_sse(): # pragma: no branch
350+
if response_complete:
351+
continue
352+
346353
# Track last event ID for potential reconnection
347354
if sse.id:
348355
last_event_id = sse.id
@@ -359,13 +366,15 @@ async def _handle_sse_response(
359366
is_initialization=is_initialization,
360367
)
361368
# If the SSE event indicates completion, like returning response/error
362-
# break the loop
369+
# keep draining the response to EOF so the HTTP connection can be reused.
363370
if is_complete:
364-
await response.aclose()
365-
return # Normal completion, no reconnect needed
371+
response_complete = True
366372
except Exception:
367373
logger.debug("SSE stream ended", exc_info=True) # pragma: no cover
368374

375+
if response_complete:
376+
return # Normal completion, no reconnect needed
377+
369378
# Stream ended without response - reconnect if we received an event with ID
370379
if last_event_id is not None: # pragma: no branch
371380
logger.info("SSE stream disconnected, reconnecting...")
@@ -405,7 +414,11 @@ async def _handle_reconnection(
405414
reconnect_last_event_id: str = last_event_id
406415
reconnect_retry_ms = retry_interval_ms
407416

417+
response_complete = False
408418
async for sse in event_source.aiter_sse():
419+
if response_complete:
420+
continue
421+
409422
if sse.id: # pragma: no branch
410423
reconnect_last_event_id = sse.id
411424
if sse.retry is not None:
@@ -418,13 +431,15 @@ async def _handle_reconnection(
418431
ctx.metadata.on_resumption_token_update if ctx.metadata else None,
419432
)
420433
if is_complete:
421-
await event_source.response.aclose()
422-
return
434+
response_complete = True
435+
436+
if response_complete:
437+
return
423438

424439
# Stream ended again without response - reconnect again (reset attempt counter)
425440
logger.info("SSE stream disconnected, reconnecting...")
426441
await self._handle_reconnection(ctx, reconnect_last_event_id, reconnect_retry_ms, 0)
427-
except Exception as e: # pragma: no cover
442+
except Exception as e: # pragma: lax no cover
428443
logger.debug(f"Reconnection failed: {e}")
429444
# Try to reconnect again if we still have an event ID
430445
await self._handle_reconnection(ctx, last_event_id, retry_interval_ms, attempt + 1)

tests/interaction/transports/test_hosting_resume.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -339,7 +339,7 @@ async def collect(params: LoggingMessageNotificationParams) -> None:
339339
capture = ClientMessageMetadata(on_resumption_token_update=on_token)
340340

341341
async with mounted_app(mcp, event_store=store, retry_interval=0) as (http, manager):
342-
with anyio.fail_after(5): # pragma: no branch
342+
with anyio.fail_after(5): # pragma: lax no cover
343343
async with ( # pragma: no branch
344344
streamable_http_client(f"{BASE_URL}/mcp", http_client=http, terminate_on_close=False) as (r1, w1),
345345
ClientSession(r1, w1, logging_callback=collect) as first,
@@ -357,7 +357,7 @@ async def collect(params: LoggingMessageNotificationParams) -> None:
357357
http.headers["mcp-protocol-version"] = LATEST_PROTOCOL_VERSION
358358
tg.cancel_scope.cancel()
359359

360-
with anyio.fail_after(5): # pragma: no branch
360+
with anyio.fail_after(5): # pragma: lax no cover
361361
release.set() # pragma: lax no cover — python/cpython#106749: 3.11 drops this line event
362362
# init priming + init response + call priming + "first" + "second" + result = 6 stored events.
363363
await store.wait_until_stored(6)

tests/shared/test_streamable_http.py

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from starlette.requests import Request
2424
from starlette.routing import Mount
2525

26+
import mcp.client.streamable_http as streamable_http
2627
from mcp import MCPError, types
2728
from mcp.client.session import ClientSession
2829
from mcp.client.streamable_http import StreamableHTTPTransport, streamable_http_client
@@ -123,6 +124,39 @@ async def replay_events_after(
123124
return target_stream_id
124125

125126

127+
class FakeStreamResponse(httpx.Response):
128+
def __init__(self) -> None:
129+
super().__init__(
130+
200,
131+
request=httpx.Request("POST", "http://localhost:8000/mcp"),
132+
)
133+
self.close_count = 0
134+
135+
async def aclose(self) -> None: # pragma: no cover
136+
self.close_count += 1
137+
138+
139+
class FakeEventSource:
140+
def __init__(self, events: list[ServerSentEvent]) -> None:
141+
self.response = FakeStreamResponse()
142+
self.events = events
143+
self.seen = 0
144+
145+
async def aiter_sse(self) -> AsyncIterator[ServerSentEvent]:
146+
for event in self.events:
147+
self.seen += 1
148+
yield event
149+
150+
151+
def jsonrpc_response_event(request_id: str, event_id: str) -> ServerSentEvent:
152+
return ServerSentEvent(
153+
event="message",
154+
data=json.dumps({"jsonrpc": "2.0", "id": request_id, "result": {}}),
155+
id=event_id,
156+
retry=None,
157+
)
158+
159+
126160
@dataclass
127161
class ServerState:
128162
lock: anyio.Event = field(default_factory=anyio.Event)
@@ -1583,6 +1617,94 @@ async def test_handle_sse_event_skips_empty_data() -> None:
15831617
await read_stream.aclose()
15841618

15851619

1620+
@pytest.mark.anyio
1621+
async def test_handle_sse_response_drains_after_terminal_event(monkeypatch: pytest.MonkeyPatch) -> None:
1622+
transport = StreamableHTTPTransport(url="http://localhost:8000/mcp")
1623+
response = FakeStreamResponse()
1624+
event_source = FakeEventSource(
1625+
[
1626+
jsonrpc_response_event("request-1", "event-1"),
1627+
ServerSentEvent(event="message", data="", id="event-2", retry=None),
1628+
]
1629+
)
1630+
1631+
def event_source_factory(_response: httpx.Response) -> FakeEventSource:
1632+
return event_source
1633+
1634+
monkeypatch.setattr(streamable_http, "EventSource", event_source_factory)
1635+
1636+
write_stream, read_stream = create_context_streams[SessionMessage | Exception](1)
1637+
try:
1638+
async with httpx.AsyncClient() as client:
1639+
ctx = streamable_http.RequestContext(
1640+
client=client,
1641+
session_id=None,
1642+
session_message=SessionMessage(
1643+
JSONRPCRequest(jsonrpc="2.0", id="request-1", method="tools/call", params={})
1644+
),
1645+
metadata=None,
1646+
read_stream_writer=write_stream,
1647+
)
1648+
1649+
await transport._handle_sse_response(response, ctx)
1650+
1651+
received = await read_stream.receive()
1652+
assert isinstance(received, SessionMessage)
1653+
assert isinstance(received.message, types.JSONRPCResponse)
1654+
assert received.message.id == "request-1"
1655+
assert event_source.seen == 2
1656+
assert response.close_count == 0
1657+
finally:
1658+
await write_stream.aclose()
1659+
await read_stream.aclose()
1660+
1661+
1662+
@pytest.mark.anyio
1663+
async def test_reconnection_drains_after_terminal_event(monkeypatch: pytest.MonkeyPatch) -> None:
1664+
transport = StreamableHTTPTransport(url="http://localhost:8000/mcp")
1665+
event_source = FakeEventSource(
1666+
[
1667+
jsonrpc_response_event("request-1", "event-2"),
1668+
ServerSentEvent(event="message", data="", id="event-3", retry=None),
1669+
]
1670+
)
1671+
1672+
async def sleep_noop(_delay: float) -> None:
1673+
pass
1674+
1675+
@asynccontextmanager
1676+
async def connect_sse(*args: Any, **kwargs: Any) -> AsyncIterator[FakeEventSource]:
1677+
yield event_source
1678+
1679+
monkeypatch.setattr(streamable_http.anyio, "sleep", sleep_noop)
1680+
monkeypatch.setattr(streamable_http, "aconnect_sse", connect_sse)
1681+
1682+
write_stream, read_stream = create_context_streams[SessionMessage | Exception](1)
1683+
try:
1684+
async with httpx.AsyncClient() as client:
1685+
ctx = streamable_http.RequestContext(
1686+
client=client,
1687+
session_id=None,
1688+
session_message=SessionMessage(
1689+
JSONRPCRequest(jsonrpc="2.0", id="request-1", method="tools/call", params={})
1690+
),
1691+
metadata=None,
1692+
read_stream_writer=write_stream,
1693+
)
1694+
1695+
await transport._handle_reconnection(ctx, last_event_id="event-1")
1696+
1697+
received = await read_stream.receive()
1698+
assert isinstance(received, SessionMessage)
1699+
assert isinstance(received.message, types.JSONRPCResponse)
1700+
assert received.message.id == "request-1"
1701+
assert event_source.seen == 2
1702+
assert event_source.response.close_count == 0
1703+
finally:
1704+
await write_stream.aclose()
1705+
await read_stream.aclose()
1706+
1707+
15861708
@pytest.mark.anyio
15871709
async def test_priming_event_not_sent_for_old_protocol_version() -> None:
15881710
"""_maybe_send_priming_event skips for old protocol versions (backwards compat)."""

0 commit comments

Comments
 (0)