Skip to content

Commit 5710ff3

Browse files
committed
fix(streamable-http): avoid startup race after initialize (#1675)
1 parent ac96f88 commit 5710ff3

2 files changed

Lines changed: 45 additions & 2 deletions

File tree

src/mcp/client/streamable_http.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
import anyio
1212
import httpx
13-
from anyio.abc import TaskGroup
13+
from anyio.abc import TaskGroup, TaskStatus
1414
from httpx_sse import EventSource, ServerSentEvent, aconnect_sse
1515
from pydantic import ValidationError
1616

@@ -437,10 +437,13 @@ async def post_writer(
437437
write_stream: ContextSendStream[SessionMessage],
438438
start_get_stream: Callable[[], None],
439439
tg: TaskGroup,
440+
*,
441+
task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORED,
440442
) -> None:
441443
"""Handle writing requests to the server."""
442444
try:
443445
async with write_stream_reader, read_stream_writer, write_stream:
446+
task_status.started()
444447

445448
async def _handle_message(session_message: SessionMessage) -> None:
446449
message = session_message.message
@@ -570,7 +573,7 @@ async def streamable_http_client(
570573
def start_get_stream() -> None:
571574
tg.start_soon(transport.handle_get_stream, client, read_stream_writer)
572575

573-
tg.start_soon(
576+
await tg.start(
574577
transport.post_writer,
575578
client,
576579
write_stream_reader,

tests/shared/test_streamable_http.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -868,6 +868,46 @@ async def test_streamable_http_client_basic_connection(basic_app: Starlette) ->
868868
assert result.server_info.name == SERVER_NAME
869869

870870

871+
@pytest.mark.anyio
872+
async def test_streamable_http_client_no_race_on_consecutive_requests(basic_app: Starlette) -> None:
873+
"""The first request after initialize can run repeatedly without racing startup."""
874+
for iteration in range(10): # pragma: no branch
875+
async with (
876+
make_client(basic_app) as http_client,
877+
streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) as (read_stream, write_stream),
878+
ClientSession(read_stream, write_stream) as session,
879+
):
880+
await session.initialize()
881+
882+
tools = await session.list_tools()
883+
assert len(tools.tools) == 8, f"Iteration {iteration}: expected 8 tools, got {len(tools.tools)}"
884+
assert tools.tools[0].name == "test_tool"
885+
886+
tools2 = await session.list_tools()
887+
assert len(tools2.tools) == 8
888+
889+
resource = await session.read_resource(uri="foobar://test-iteration")
890+
assert len(resource.contents) == 1
891+
892+
893+
@pytest.mark.anyio
894+
async def test_streamable_http_client_rapid_request_sequence(basic_app: Starlette) -> None:
895+
"""A rapid sequence of requests reuses the initialized stream reliably."""
896+
async with (
897+
make_client(basic_app) as http_client,
898+
streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) as (read_stream, write_stream),
899+
ClientSession(read_stream, write_stream) as session,
900+
):
901+
await session.initialize()
902+
903+
for i in range(20):
904+
tools = await session.list_tools()
905+
assert len(tools.tools) == 8, f"Request {i}: expected 8 tools, got {len(tools.tools)}"
906+
907+
resource = await session.read_resource(uri="foobar://final-test")
908+
assert len(resource.contents) == 1
909+
910+
871911
@pytest.mark.anyio
872912
async def test_streamable_http_client_resource_read(initialized_client_session: ClientSession) -> None:
873913
"""A resource read round-trips its arguments and the handler's content."""

0 commit comments

Comments
 (0)