|
23 | 23 | from starlette.requests import Request |
24 | 24 | from starlette.routing import Mount |
25 | 25 |
|
| 26 | +import mcp.client.streamable_http as streamable_http |
26 | 27 | from mcp import MCPError, types |
27 | 28 | from mcp.client.session import ClientSession |
28 | 29 | from mcp.client.streamable_http import StreamableHTTPTransport, streamable_http_client |
@@ -123,6 +124,39 @@ async def replay_events_after( |
123 | 124 | return target_stream_id |
124 | 125 |
|
125 | 126 |
|
| 127 | +class FakeStreamResponse(httpx.Response): |
| 128 | + def __init__(self) -> None: |
| 129 | + super().__init__( |
| 130 | + 200, |
| 131 | + request=httpx.Request("POST", "http://localhost:8000/mcp"), |
| 132 | + ) |
| 133 | + self.close_count = 0 |
| 134 | + |
| 135 | + async def aclose(self) -> None: # pragma: no cover |
| 136 | + self.close_count += 1 |
| 137 | + |
| 138 | + |
| 139 | +class FakeEventSource: |
| 140 | + def __init__(self, events: list[ServerSentEvent]) -> None: |
| 141 | + self.response = FakeStreamResponse() |
| 142 | + self.events = events |
| 143 | + self.seen = 0 |
| 144 | + |
| 145 | + async def aiter_sse(self) -> AsyncIterator[ServerSentEvent]: |
| 146 | + for event in self.events: |
| 147 | + self.seen += 1 |
| 148 | + yield event |
| 149 | + |
| 150 | + |
| 151 | +def jsonrpc_response_event(request_id: str, event_id: str) -> ServerSentEvent: |
| 152 | + return ServerSentEvent( |
| 153 | + event="message", |
| 154 | + data=json.dumps({"jsonrpc": "2.0", "id": request_id, "result": {}}), |
| 155 | + id=event_id, |
| 156 | + retry=None, |
| 157 | + ) |
| 158 | + |
| 159 | + |
126 | 160 | @dataclass |
127 | 161 | class ServerState: |
128 | 162 | lock: anyio.Event = field(default_factory=anyio.Event) |
@@ -1583,6 +1617,94 @@ async def test_handle_sse_event_skips_empty_data() -> None: |
1583 | 1617 | await read_stream.aclose() |
1584 | 1618 |
|
1585 | 1619 |
|
| 1620 | +@pytest.mark.anyio |
| 1621 | +async def test_handle_sse_response_drains_after_terminal_event(monkeypatch: pytest.MonkeyPatch) -> None: |
| 1622 | + transport = StreamableHTTPTransport(url="http://localhost:8000/mcp") |
| 1623 | + response = FakeStreamResponse() |
| 1624 | + event_source = FakeEventSource( |
| 1625 | + [ |
| 1626 | + jsonrpc_response_event("request-1", "event-1"), |
| 1627 | + ServerSentEvent(event="message", data="", id="event-2", retry=None), |
| 1628 | + ] |
| 1629 | + ) |
| 1630 | + |
| 1631 | + def event_source_factory(_response: httpx.Response) -> FakeEventSource: |
| 1632 | + return event_source |
| 1633 | + |
| 1634 | + monkeypatch.setattr(streamable_http, "EventSource", event_source_factory) |
| 1635 | + |
| 1636 | + write_stream, read_stream = create_context_streams[SessionMessage | Exception](1) |
| 1637 | + try: |
| 1638 | + async with httpx.AsyncClient() as client: |
| 1639 | + ctx = streamable_http.RequestContext( |
| 1640 | + client=client, |
| 1641 | + session_id=None, |
| 1642 | + session_message=SessionMessage( |
| 1643 | + JSONRPCRequest(jsonrpc="2.0", id="request-1", method="tools/call", params={}) |
| 1644 | + ), |
| 1645 | + metadata=None, |
| 1646 | + read_stream_writer=write_stream, |
| 1647 | + ) |
| 1648 | + |
| 1649 | + await transport._handle_sse_response(response, ctx) |
| 1650 | + |
| 1651 | + received = await read_stream.receive() |
| 1652 | + assert isinstance(received, SessionMessage) |
| 1653 | + assert isinstance(received.message, types.JSONRPCResponse) |
| 1654 | + assert received.message.id == "request-1" |
| 1655 | + assert event_source.seen == 2 |
| 1656 | + assert response.close_count == 0 |
| 1657 | + finally: |
| 1658 | + await write_stream.aclose() |
| 1659 | + await read_stream.aclose() |
| 1660 | + |
| 1661 | + |
| 1662 | +@pytest.mark.anyio |
| 1663 | +async def test_reconnection_drains_after_terminal_event(monkeypatch: pytest.MonkeyPatch) -> None: |
| 1664 | + transport = StreamableHTTPTransport(url="http://localhost:8000/mcp") |
| 1665 | + event_source = FakeEventSource( |
| 1666 | + [ |
| 1667 | + jsonrpc_response_event("request-1", "event-2"), |
| 1668 | + ServerSentEvent(event="message", data="", id="event-3", retry=None), |
| 1669 | + ] |
| 1670 | + ) |
| 1671 | + |
| 1672 | + async def sleep_noop(_delay: float) -> None: |
| 1673 | + pass |
| 1674 | + |
| 1675 | + @asynccontextmanager |
| 1676 | + async def connect_sse(*args: Any, **kwargs: Any) -> AsyncIterator[FakeEventSource]: |
| 1677 | + yield event_source |
| 1678 | + |
| 1679 | + monkeypatch.setattr(streamable_http.anyio, "sleep", sleep_noop) |
| 1680 | + monkeypatch.setattr(streamable_http, "aconnect_sse", connect_sse) |
| 1681 | + |
| 1682 | + write_stream, read_stream = create_context_streams[SessionMessage | Exception](1) |
| 1683 | + try: |
| 1684 | + async with httpx.AsyncClient() as client: |
| 1685 | + ctx = streamable_http.RequestContext( |
| 1686 | + client=client, |
| 1687 | + session_id=None, |
| 1688 | + session_message=SessionMessage( |
| 1689 | + JSONRPCRequest(jsonrpc="2.0", id="request-1", method="tools/call", params={}) |
| 1690 | + ), |
| 1691 | + metadata=None, |
| 1692 | + read_stream_writer=write_stream, |
| 1693 | + ) |
| 1694 | + |
| 1695 | + await transport._handle_reconnection(ctx, last_event_id="event-1") |
| 1696 | + |
| 1697 | + received = await read_stream.receive() |
| 1698 | + assert isinstance(received, SessionMessage) |
| 1699 | + assert isinstance(received.message, types.JSONRPCResponse) |
| 1700 | + assert received.message.id == "request-1" |
| 1701 | + assert event_source.seen == 2 |
| 1702 | + assert event_source.response.close_count == 0 |
| 1703 | + finally: |
| 1704 | + await write_stream.aclose() |
| 1705 | + await read_stream.aclose() |
| 1706 | + |
| 1707 | + |
1586 | 1708 | @pytest.mark.anyio |
1587 | 1709 | async def test_priming_event_not_sent_for_old_protocol_version() -> None: |
1588 | 1710 | """_maybe_send_priming_event skips for old protocol versions (backwards compat).""" |
|
0 commit comments