@@ -476,24 +476,12 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re
476476 await response (scope , receive , send )
477477 return
478478
479- # Check if this is an initialization request
480- is_initialization_request = isinstance (message , JSONRPCRequest ) and message .method == "initialize"
481-
482- if is_initialization_request :
483- # Check if the server already has an established session
484- if self .mcp_session_id :
485- # Check if request has a session ID
486- request_session_id = self ._get_session_id (request )
487-
488- # If request has a session ID but doesn't match, return 404
489- if request_session_id and request_session_id != self .mcp_session_id : # pragma: no cover
490- response = self ._create_error_response (
491- "Not Found: Invalid or expired session ID" ,
492- HTTPStatus .NOT_FOUND ,
493- )
494- await response (scope , receive , send )
495- return
496- elif not await self ._validate_request_headers (request , send ):
479+ is_initialization_request = False
480+ if isinstance (message , JSONRPCRequest ) and message .method == "initialize" :
481+ is_initialization_request = True
482+ if not await self ._validate_initialization_request (message , request , send ):
483+ return
484+ elif not await self ._validate_request_headers (request , send ): # pragma: no cover
497485 return
498486
499487 # For notifications and responses only, return 202 Accepted
@@ -865,6 +853,44 @@ async def _validate_protocol_version(self, request: Request, send: Send) -> bool
865853
866854 return True
867855
856+ async def _validate_initialization_request (self , message : JSONRPCRequest , request : Request , send : Send ) -> bool :
857+ if not await self ._validate_initialization_protocol_version (message , request , send ):
858+ return False
859+
860+ if not self .mcp_session_id :
861+ return True
862+
863+ request_session_id = self ._get_session_id (request )
864+ if request_session_id and request_session_id != self .mcp_session_id : # pragma: no cover
865+ response = self ._create_error_response (
866+ "Not Found: Invalid or expired session ID" ,
867+ HTTPStatus .NOT_FOUND ,
868+ )
869+ await response (request .scope , request .receive , send )
870+ return False
871+
872+ return True
873+
874+ async def _validate_initialization_protocol_version (
875+ self , message : JSONRPCRequest , request : Request , send : Send
876+ ) -> bool :
877+ header_protocol_version = request .headers .get (MCP_PROTOCOL_VERSION_HEADER )
878+ body_protocol_version = str (message .params .get ("protocolVersion" )) if message .params else None
879+ if (
880+ header_protocol_version is not None
881+ and body_protocol_version is not None
882+ and header_protocol_version != body_protocol_version
883+ ):
884+ response = self ._create_error_response (
885+ f"Bad Request: { MCP_PROTOCOL_VERSION_HEADER } header does not match initialize.params.protocolVersion" ,
886+ HTTPStatus .BAD_REQUEST ,
887+ INVALID_REQUEST ,
888+ )
889+ await response (request .scope , request .receive , send )
890+ return False
891+
892+ return True
893+
868894 async def _replay_events (self , last_event_id : str , request : Request , send : Send ) -> None :
869895 """Replays events that would have been sent after the specified event ID.
870896
0 commit comments