Skip to content

Commit 95fe1b8

Browse files
committed
fix: reject initialize protocol version conflicts
1 parent ac96f88 commit 95fe1b8

2 files changed

Lines changed: 83 additions & 18 deletions

File tree

src/mcp/server/streamable_http.py

Lines changed: 44 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -476,24 +476,12 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re
476476
await response(scope, receive, send)
477477
return
478478

479-
# Check if this is an initialization request
480-
is_initialization_request = isinstance(message, JSONRPCRequest) and message.method == "initialize"
481-
482-
if is_initialization_request:
483-
# Check if the server already has an established session
484-
if self.mcp_session_id:
485-
# Check if request has a session ID
486-
request_session_id = self._get_session_id(request)
487-
488-
# If request has a session ID but doesn't match, return 404
489-
if request_session_id and request_session_id != self.mcp_session_id: # pragma: no cover
490-
response = self._create_error_response(
491-
"Not Found: Invalid or expired session ID",
492-
HTTPStatus.NOT_FOUND,
493-
)
494-
await response(scope, receive, send)
495-
return
496-
elif not await self._validate_request_headers(request, send):
479+
is_initialization_request = False
480+
if isinstance(message, JSONRPCRequest) and message.method == "initialize":
481+
is_initialization_request = True
482+
if not await self._validate_initialization_request(message, request, send):
483+
return
484+
elif not await self._validate_request_headers(request, send): # pragma: no cover
497485
return
498486

499487
# For notifications and responses only, return 202 Accepted
@@ -865,6 +853,44 @@ async def _validate_protocol_version(self, request: Request, send: Send) -> bool
865853

866854
return True
867855

856+
async def _validate_initialization_request(self, message: JSONRPCRequest, request: Request, send: Send) -> bool:
857+
if not await self._validate_initialization_protocol_version(message, request, send):
858+
return False
859+
860+
if not self.mcp_session_id:
861+
return True
862+
863+
request_session_id = self._get_session_id(request)
864+
if request_session_id and request_session_id != self.mcp_session_id: # pragma: no cover
865+
response = self._create_error_response(
866+
"Not Found: Invalid or expired session ID",
867+
HTTPStatus.NOT_FOUND,
868+
)
869+
await response(request.scope, request.receive, send)
870+
return False
871+
872+
return True
873+
874+
async def _validate_initialization_protocol_version(
875+
self, message: JSONRPCRequest, request: Request, send: Send
876+
) -> bool:
877+
header_protocol_version = request.headers.get(MCP_PROTOCOL_VERSION_HEADER)
878+
body_protocol_version = str(message.params.get("protocolVersion")) if message.params else None
879+
if (
880+
header_protocol_version is not None
881+
and body_protocol_version is not None
882+
and header_protocol_version != body_protocol_version
883+
):
884+
response = self._create_error_response(
885+
f"Bad Request: {MCP_PROTOCOL_VERSION_HEADER} header does not match initialize.params.protocolVersion",
886+
HTTPStatus.BAD_REQUEST,
887+
INVALID_REQUEST,
888+
)
889+
await response(request.scope, request.receive, send)
890+
return False
891+
892+
return True
893+
868894
async def _replay_events(self, last_event_id: str, request: Request, send: Send) -> None:
869895
"""Replays events that would have been sent after the specified event ID.
870896

tests/shared/test_streamable_http.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1489,6 +1489,45 @@ async def test_server_validates_protocol_version_header(basic_app: Starlette) ->
14891489
assert response.status_code == 200
14901490

14911491

1492+
@pytest.mark.anyio
1493+
@pytest.mark.parametrize(
1494+
("header_version", "body_version"),
1495+
[
1496+
("2025-03-26", "2025-06-18"),
1497+
("2025-06-18", "2025-03-26"),
1498+
],
1499+
)
1500+
async def test_server_rejects_initialize_protocol_version_mismatch(
1501+
basic_app: Starlette, header_version: str, body_version: str
1502+
) -> None:
1503+
"""Initialize rejects conflicting protocol versions in header and body."""
1504+
init_request: dict[str, Any] = {
1505+
"jsonrpc": "2.0",
1506+
"method": "initialize",
1507+
"params": {
1508+
"clientInfo": {"name": "test-client", "version": "1.0"},
1509+
"protocolVersion": body_version,
1510+
"capabilities": {},
1511+
},
1512+
"id": "init-1",
1513+
}
1514+
1515+
async with make_client(basic_app) as client:
1516+
response = await client.post(
1517+
"/mcp",
1518+
headers={
1519+
"Accept": "application/json, text/event-stream",
1520+
"Content-Type": "application/json",
1521+
MCP_PROTOCOL_VERSION_HEADER: header_version,
1522+
},
1523+
json=init_request,
1524+
)
1525+
1526+
assert response.status_code == 400
1527+
assert MCP_PROTOCOL_VERSION_HEADER in response.text
1528+
assert "protocolVersion" in response.text
1529+
1530+
14921531
@pytest.mark.anyio
14931532
async def test_server_backwards_compatibility_no_protocol_version(basic_app: Starlette) -> None:
14941533
"""A request without a protocol version header is accepted for backwards compatibility."""

0 commit comments

Comments
 (0)