diff --git a/src/mcp/client/auth/oauth2.py b/src/mcp/client/auth/oauth2.py index ddc61ef66..276ec518b 100644 --- a/src/mcp/client/auth/oauth2.py +++ b/src/mcp/client/auth/oauth2.py @@ -12,7 +12,7 @@ import time from collections.abc import AsyncGenerator, Awaitable, Callable from dataclasses import dataclass, field -from typing import Any, Protocol +from typing import Any, NewType, Protocol from urllib.parse import quote, urlencode, urljoin, urlparse import anyio @@ -53,6 +53,8 @@ logger = logging.getLogger(__name__) +AuthorizationState = NewType("AuthorizationState", str) + class PKCEParameters(BaseModel): """PKCE (Proof Key for Code Exchange) parameters.""" @@ -305,14 +307,10 @@ async def _perform_authorization(self) -> httpx.Request: token_request = await self._exchange_token_authorization_code(auth_code, code_verifier) return token_request - async def _perform_authorization_code_grant(self) -> tuple[str, str]: - """Perform the authorization redirect and get auth code.""" + async def _build_authorization_url(self) -> tuple[str, AuthorizationState, PKCEParameters]: + """Build authorization URL and state.""" if self.context.client_metadata.redirect_uris is None: raise OAuthFlowError("No redirect URIs provided for authorization code grant") # pragma: no cover - if not self.context.redirect_handler: - raise OAuthFlowError("No redirect handler provided for authorization code grant") # pragma: no cover - if not self.context.callback_handler: - raise OAuthFlowError("No callback handler provided for authorization code grant") # pragma: no cover if self.context.oauth_metadata and self.context.oauth_metadata.authorization_endpoint: auth_endpoint = str(self.context.oauth_metadata.authorization_endpoint) # pragma: no cover @@ -325,7 +323,7 @@ async def _perform_authorization_code_grant(self) -> tuple[str, str]: # Generate PKCE parameters pkce_params = PKCEParameters.generate() - state = secrets.token_urlsafe(32) + state = AuthorizationState(secrets.token_urlsafe(32)) auth_params = { "response_type": "code", @@ -344,6 +342,17 @@ async def _perform_authorization_code_grant(self) -> tuple[str, str]: auth_params["scope"] = self.context.client_metadata.scope authorization_url = f"{auth_endpoint}?{urlencode(auth_params)}" + + return authorization_url, state, pkce_params + + async def _perform_authorization_code_grant(self) -> tuple[str, str]: + """Perform the authorization redirect and get auth code.""" + if not self.context.redirect_handler: + raise OAuthFlowError("No redirect handler provided for authorization code grant") # pragma: no cover + if not self.context.callback_handler: + raise OAuthFlowError("No callback handler provided for authorization code grant") # pragma: no cover + + authorization_url, state, pkce_params = await self._build_authorization_url() await self.context.redirect_handler(authorization_url) # Wait for callback