Skip to content

Commit f20ccb6

Browse files
committed
fix: preserve auth endpoint query params
1 parent ac96f88 commit f20ccb6

2 files changed

Lines changed: 53 additions & 2 deletions

File tree

src/mcp/client/auth/oauth2.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from collections.abc import AsyncGenerator, Awaitable, Callable
1313
from dataclasses import dataclass, field
1414
from typing import Any, Protocol
15-
from urllib.parse import quote, urlencode, urljoin, urlparse
15+
from urllib.parse import parse_qsl, quote, urlencode, urljoin, urlparse, urlsplit, urlunsplit
1616

1717
import anyio
1818
import httpx
@@ -353,7 +353,14 @@ async def _perform_authorization_code_grant(self) -> tuple[str, str]:
353353
if "offline_access" in self.context.client_metadata.scope.split():
354354
auth_params["prompt"] = "consent"
355355

356-
authorization_url = f"{auth_endpoint}?{urlencode(auth_params)}"
356+
auth_endpoint_parts = urlsplit(auth_endpoint)
357+
authorization_query = urlencode(
358+
[
359+
*parse_qsl(auth_endpoint_parts.query, keep_blank_values=True),
360+
*auth_params.items(),
361+
]
362+
)
363+
authorization_url = urlunsplit(auth_endpoint_parts._replace(query=authorization_query))
357364
await self.context.redirect_handler(authorization_url)
358365

359366
# Wait for callback

tests/client/test_auth.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,50 @@ def test_clear_tokens(self, oauth_provider: OAuthClientProvider, valid_tokens: O
263263
class TestOAuthFlow:
264264
"""Test OAuth flow methods."""
265265

266+
@pytest.mark.anyio
267+
async def test_authorization_endpoint_query_params_are_preserved(
268+
self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage
269+
):
270+
"""OAuth authorization endpoints may already carry provider-specific query params."""
271+
captured_state: str | None = None
272+
273+
async def redirect_handler(url: str) -> None:
274+
nonlocal captured_state
275+
parsed = urlparse(url)
276+
params = parse_qs(parsed.query)
277+
278+
assert params["prompt"] == ["select_account"]
279+
assert params["response_type"] == ["code"]
280+
assert params["client_id"] == ["test_client"]
281+
282+
captured_state = params.get("state", [None])[0]
283+
284+
async def callback_handler() -> tuple[str, str | None]:
285+
return "test_auth_code", captured_state
286+
287+
provider = OAuthClientProvider(
288+
server_url="https://api.example.com/v1/mcp",
289+
client_metadata=client_metadata,
290+
storage=mock_storage,
291+
redirect_handler=redirect_handler,
292+
callback_handler=callback_handler,
293+
)
294+
provider.context.oauth_metadata = OAuthMetadata(
295+
issuer=AnyHttpUrl("https://auth.example.com"),
296+
authorization_endpoint=AnyHttpUrl("https://auth.example.com/authorize?prompt=select_account"),
297+
token_endpoint=AnyHttpUrl("https://auth.example.com/token"),
298+
)
299+
provider.context.client_info = OAuthClientInformationFull(
300+
client_id="test_client",
301+
client_secret="test_secret",
302+
redirect_uris=[AnyUrl("http://localhost:3030/callback")],
303+
)
304+
305+
auth_code, code_verifier = await provider._perform_authorization_code_grant()
306+
307+
assert auth_code == "test_auth_code"
308+
assert code_verifier
309+
266310
@pytest.mark.anyio
267311
async def test_build_protected_resource_discovery_urls(
268312
self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage

0 commit comments

Comments
 (0)