diff --git a/mcp-server/api_client.py b/mcp-server/api_client.py
index 95b74da..38c97ac 100644
--- a/mcp-server/api_client.py
+++ b/mcp-server/api_client.py
@@ -3,8 +3,16 @@
 Uses a module-level client to avoid creating new TCP connections per tool call.
 The client is initialized lazily on first use and reused for all subsequent calls.
 Concurrent access is serialized via asyncio.Lock to prevent duplicate clients.
+
+Auth identity is per-request, not per-client. On the remote (http) transport each
+MCP request carries the caller's own ci_ key, so we resolve the key at call time
+from a ContextVar and send it as a per-request header. The shared client holds NO
+Authorization header: baking one in would let concurrent requests race on a single
+shared identity, which is a cross-tenant leak. A ContextVar is task-local, so two
+in-flight requests can never read each other's key.
 """
 import asyncio
+import contextvars
 from typing import Any, Optional
 
 import httpx
@@ -12,56 +20,88 @@
 from config import BACKEND_API_URL, API_KEY
 
 
-# Persistent client reused across all tool calls
+# Persistent client reused across all tool calls. Identity-free by design.
 _client: Optional[httpx.AsyncClient] = None
 _client_lock: asyncio.Lock = asyncio.Lock()
 
+# Per-request caller identity. The MCP server's auth middleware sets this (remote
+# transport) before the tool runs; it stays unset on stdio, where the configured
+# key is the identity. Task-local, so concurrent requests are isolated.
+_request_api_key: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar(
+    "request_api_key", default=None
+)
+
 
-def _get_headers() -> dict[str, str]:
-    """Return Authorization header with the configured API_KEY.
+def set_request_api_key(key: Optional[str]) -> None:
+    """Record the caller's key for the current request context.
 
-    Raises ValueError if API_KEY is empty or unset.
+    Called by the MCP server's auth middleware on the remote transport.
     """
-    if not API_KEY:
-        raise ValueError(
-            "No API_KEY configured. Set API_KEY in .env or environment."
-        )
-    return {"Authorization": f"Bearer {API_KEY}"}
+    _request_api_key.set(key)
+
+
+def _resolve_key() -> str:
+    """Per-request key if present (remote), else the configured key (stdio/admin)."""
+    key = _request_api_key.get()
+    if key:
+        return key
+    if API_KEY:
+        return API_KEY
+    raise ValueError("No API key available for backend call")
+
+
+def _auth_header() -> dict[str, str]:
+    """Build the Authorization header. The space after 'Bearer' is mandatory (PR #292)."""
+    return {"Authorization": f"Bearer {_resolve_key()}"}
+
+
+def _merge_headers(extra: Optional[dict[str, Any]]) -> dict[str, str]:
+    """Combine the per-request Authorization header with caller-supplied headers.
+
+    `extra` is applied last, so it wins on an exact-name clash.
+    NOTE(review): that lets a caller replace Authorization -- confirm it is intended.
+    """
+    headers = _auth_header()
+    if extra:
+        headers.update(extra)
+    return headers
 
 
 async def get_client() -> httpx.AsyncClient:
-    """Get or create the persistent HTTP client."""
+    """Get or create the persistent, identity-free HTTP client."""
     global _client
     async with _client_lock:
         if _client is None or _client.is_closed:
             _client = httpx.AsyncClient(
                 base_url=BACKEND_API_URL,
                 timeout=120.0,
-                headers=_get_headers(),
             )
         return _client
 
 
 async def api_get(path: str, **kwargs: Any) -> dict:
-    """Make a GET request to the backend API."""
+    """Make a GET request to the backend API with the caller's identity."""
     client = await get_client()
-    response = await client.get(path, **kwargs)
+    headers = _merge_headers(kwargs.pop("headers", None))
+    response = await client.get(path, headers=headers, **kwargs)
     response.raise_for_status()
     return response.json()
 
 
 async def api_post(path: str, json: dict, **kwargs: Any) -> dict:
-    """Make a POST request to the backend API."""
+    """Make a POST request to the backend API with the caller's identity."""
     client = await get_client()
-    response = await client.post(path, json=json, **kwargs)
+    headers = _merge_headers(kwargs.pop("headers", None))
+    response = await client.post(path, json=json, headers=headers, **kwargs)
     response.raise_for_status()
     return response.json()
 
 
 async def api_delete(path: str, **kwargs: Any) -> dict:
-    """Make a DELETE request to the backend API."""
+    """Make a DELETE request to the backend API with the caller's identity."""
     client = await get_client()
-    response = await client.delete(path, **kwargs)
+    headers = _merge_headers(kwargs.pop("headers", None))
+    response = await client.delete(path, headers=headers, **kwargs)
     response.raise_for_status()
     if response.status_code == 204 or not response.content:
         return {}
diff --git a/mcp-server/server.py b/mcp-server/server.py
index 6d08040..5919913 100644
--- a/mcp-server/server.py
+++ b/mcp-server/server.py
@@ -9,17 +9,24 @@
     python server.py                    # stdio (default)
     python server.py --transport http   # streamable HTTP on $PORT
 """
+import hmac
+import json
+import logging
 import sys
+from typing import Optional
 
 from mcp.server.fastmcp import FastMCP
 import mcp.types as types
 from starlette.routing import Route
 from starlette.responses import JSONResponse
 
+from api_client import set_request_api_key
 from config import SERVER_NAME, SERVER_VERSION, TRANSPORT, HOST, PORT, MCP_AUTH_TOKEN
 from tools import get_tool_schemas
 from handlers import call_tool
 
+logger = logging.getLogger(__name__)
+
 mcp = FastMCP(
     name=SERVER_NAME,
     instructions=(
@@ -62,27 +69,114 @@ async def _health(request):
     return JSONResponse({"status": "ok", "server": SERVER_NAME, "version": SERVER_VERSION})
 
 
-def _get_http_app():
-    """Build the Starlette app with health check + MCP endpoint."""
-    from starlette.middleware.base import BaseHTTPMiddleware
-    from starlette.requests import Request
+def _extract_bearer(scope) -> Optional[str]:
+    """Pull the Bearer token from an ASGI scope's headers, or None.
 
-    app = mcp.streamable_http_app()
-    app.routes.insert(0, Route("/health", _health, methods=["GET"]))
+    Requires the space after 'Bearer' -- a missing space is the PR #292 bug, and a
+    request carrying it is treated as malformed (rejected), never silently parsed.
+    The scheme name is matched case-insensitively, as RFC 7235 section 2.1 requires.
+    Only the first Authorization header is considered; empty credentials yield None.
+    """
+    for name, value in scope.get("headers", []):
+        if name == b"authorization":
+            raw = value.decode("latin-1")
+            scheme, sep, credentials = raw.partition(" ")
+            if sep and scheme.lower() == "bearer":
+                return credentials.strip() or None
+            return None
+    return None
+
+
+def _key_suffix(token: str) -> str:
+    """Short tail of the token for log correlation. Never log the full key (bug #7).
+
+    At most 8 chars and never more than a quarter of the token are revealed, so a
+    short token (e.g. exactly 8 chars) is not echoed whole; tiny tokens are masked.
+    """
+    visible = min(8, len(token) // 4)
+    return token[-visible:] if visible else "********"
+
+
+async def _send_401(send, detail: str) -> None:
+    """Send a JSON 401 with a `WWW-Authenticate: Bearer` challenge over raw ASGI."""
+    body = json.dumps({"error": detail}).encode()
+    await send({
+        "type": "http.response.start",
+        "status": 401,
+        "headers": [
+            (b"content-type", b"application/json"),
+            (b"www-authenticate", b"Bearer"),
+        ],
+    })
+    await send({"type": "http.response.body", "body": body})
+
+
+class MCPAuthMiddleware:
+    """Authenticate /mcp with a per-user ci_ key, forwarded to the backend.
+
+    Pure ASGI (not BaseHTTPMiddleware) on purpose: the ContextVar set here must
+    propagate to the tool-handler task, and BaseHTTPMiddleware runs the downstream
+    app in a separate task that breaks that propagation.
 
-    if MCP_AUTH_TOKEN:
-        class MCPAuthMiddleware(BaseHTTPMiddleware):
-            """Require Bearer token on /mcp, leave /health public."""
-            async def dispatch(self, request: Request, call_next):
-                if request.url.path == "/health":
-                    return await call_next(request)
-                auth = request.headers.get("authorization", "")
-                if not auth.startswith("Bearer ") or auth[7:] != MCP_AUTH_TOKEN:
-                    return JSONResponse({"error": "Unauthorized"}, status_code=401)
-                return await call_next(request)
+    Fails closed -- a missing or invalid credential is a 401, never a fallback to a
+    shared identity for data calls. /health stays public.
+
+    NOTE(review): this assumes tool handlers run in the context of the request
+    that carries the key. If the MCP session manager dispatches calls from a task
+    started by an earlier request of the same session, they would see that
+    request's key instead -- confirm against the SDK's streamable-HTTP transport.
+    """
 
-        app.add_middleware(MCPAuthMiddleware)
+    def __init__(self, app):
+        self.app = app
 
+    async def __call__(self, scope, receive, send):
+        if scope["type"] != "http":
+            await self.app(scope, receive, send)
+            return
+
+        if scope.get("path") == "/health":
+            await self.app(scope, receive, send)
+            return
+
+        token = _extract_bearer(scope)
+        if not token:
+            logger.warning("mcp auth: missing or malformed Authorization header")
+            await _send_401(send, "Missing or malformed Authorization header")
+            return
+
+        if token.startswith("ci_"):
+            # Per-user key: carry the caller's own identity to the backend.
+            # NOTE(review): the key is not verified here -- presumably the backend
+            # rejects unknown ci_ keys on the first data call; confirm.
+            set_request_api_key(token)
+            logger.info("mcp auth: ci_ key accepted (suffix=%s)", _key_suffix(token))
+            await self.app(scope, receive, send)
+            return
+
+        # Constant-time comparison (on bytes, so non-ASCII input cannot raise):
+        # a plain == would leak how much of the admin token a guess got right.
+        if MCP_AUTH_TOKEN and hmac.compare_digest(
+            token.encode("utf-8"), MCP_AUTH_TOKEN.encode("utf-8")
+        ):
+            # Admin path: authenticates the endpoint only. No user key is set, so
+            # api_client uses the configured key -- the data scope is not widened.
+            # Clear it explicitly so a key already in this context cannot ride along.
+            set_request_api_key(None)
+            logger.info("mcp auth: admin token accepted")
+            await self.app(scope, receive, send)
+            return
+
+        logger.warning("mcp auth: invalid token rejected (suffix=%s)", _key_suffix(token))
+        await _send_401(send, "Invalid API key")
+
+
+def _get_http_app():
+    """Build the Starlette app: public /health + auth-required /mcp."""
+    app = mcp.streamable_http_app()
+    app.routes.insert(0, Route("/health", _health, methods=["GET"]))
+    # Always enforce: remote /mcp requires a per-user ci_ key (or the admin token).
+    app.add_middleware(MCPAuthMiddleware)
     return app
 
 
diff --git a/mcp-server/tests/test_auth_forwarding.py b/mcp-server/tests/test_auth_forwarding.py
new file mode 100644
index 0000000..c171ebe
--- /dev/null
+++ b/mcp-server/tests/test_auth_forwarding.py
@@ -0,0 +1,224 @@
+"""Tests for per-user ci_ key forwarding at /mcp (issue #323).
+
+Two layers under test:
+  1. api_client: the per-request Authorization header is resolved from a task-local
+     ContextVar, not baked into the shared client (the cross-tenant-leak defense).
+  2. MCPAuthMiddleware: fail-closed auth at the ASGI edge; /health stays public;
+     ci_ keys are forwarded, the admin token authenticates without widening scope.
+"""
+import asyncio
+
+import pytest
+
+import api_client
+from api_client import (
+    _auth_header,
+    _resolve_key,
+    _request_api_key,
+    set_request_api_key,
+    get_client,
+    close_client,
+)
+
+
+# -- Per-request header resolution --
+
+class TestAuthHeaderResolution:
+    def setup_method(self):
+        _request_api_key.set(None)
+
+    def test_uses_per_request_key(self):
+        set_request_api_key("ci_userA")
+        assert _auth_header() == {"Authorization": "Bearer ci_userA"}
+
+    def test_bearer_has_exactly_one_space(self):
+        # PR #292 / commit df958de regression guard: f"Bearer{key}" must never recur.
+        set_request_api_key("ci_xyz")
+        header = _auth_header()["Authorization"]
+        assert header.startswith("Bearer ")
+        assert header.split(" ", 1) == ["Bearer", "ci_xyz"]
+
+    def test_falls_back_to_configured_key_when_no_request_key(self, monkeypatch):
+        # stdio / admin path: no per-request key, use the configured one.
+        monkeypatch.setattr(api_client, "API_KEY", "ci_configured")
+        _request_api_key.set(None)
+        assert _resolve_key() == "ci_configured"
+
+    def test_per_request_key_overrides_configured(self, monkeypatch):
+        monkeypatch.setattr(api_client, "API_KEY", "ci_configured")
+        set_request_api_key("ci_caller")
+        assert _resolve_key() == "ci_caller"
+
+    def test_raises_when_no_key_anywhere(self, monkeypatch):
+        monkeypatch.setattr(api_client, "API_KEY", "")
+        _request_api_key.set(None)
+        with pytest.raises(ValueError):
+            _resolve_key()
+
+
+# -- Concurrency isolation (the cross-tenant-leak defense) --
+
+class TestConcurrentIsolation:
+    @pytest.mark.asyncio
+    async def test_concurrent_requests_do_not_share_identity(self):
+        """Two interleaved tasks with different keys must each see only their own.
+
+        asyncio tasks copy the context at creation, so each gather'd coroutine has
+        an independent ContextVar. If identity lived in a module global instead,
+        one task's set would clobber the other and this test would fail -- that is
+        precisely the production race we are guarding against.
+        """
+        seen: dict[str, str] = {}
+
+        async def request(name: str, key: str) -> None:
+            set_request_api_key(key)
+            await asyncio.sleep(0)  # force interleaving with the other task
+            header = _auth_header()["Authorization"]
+            await asyncio.sleep(0)
+            seen[name] = header
+
+        await asyncio.gather(
+            request("A", "ci_userA"),
+            request("B", "ci_userB"),
+        )
+
+        assert seen["A"] == "Bearer ci_userA"
+        assert seen["B"] == "Bearer ci_userB"
+
+
+# -- Identity-free shared client --
+
+class TestSharedClient:
+    @pytest.mark.asyncio
+    async def test_shared_client_has_no_default_auth_header(self):
+        await close_client()
+        try:
+            client = await get_client()
+            header_names = {k.lower() for k in client.headers.keys()}
+            assert "authorization" not in header_names
+        finally:
+            await close_client()
+
+
+# -- ASGI auth middleware --
+
+import server  # noqa: E402 (constructs FastMCP; no network, safe to import)
+from server import MCPAuthMiddleware, _extract_bearer, _key_suffix  # noqa: E402
+
+
+class _Recorder:
+    """Minimal inner ASGI app that records whether it ran and the key it saw."""
+
+    def __init__(self):
+        self.called = False
+        self.captured_key = "UNSET"
+
+    async def __call__(self, scope, receive, send):
+        self.called = True
+        self.captured_key = _request_api_key.get()
+
+
+def _http_scope(path="/mcp", headers=None):
+    return {"type": "http", "path": path, "headers": headers or []}
+
+
+async def _run(mw, scope):
+    sent = []
+
+    async def receive():
+        return {"type": "http.request", "body": b"", "more_body": False}
+
+    async def send(message):
+        sent.append(message)
+
+    await mw(scope, receive, send)
+    return sent
+
+
+class TestMCPAuthMiddleware:
+    def setup_method(self):
+        _request_api_key.set(None)
+
+    @pytest.mark.asyncio
+    async def test_health_is_public(self):
+        inner = _Recorder()
+        await _run(MCPAuthMiddleware(inner), _http_scope(path="/health"))
+        assert inner.called is True
+
+    @pytest.mark.asyncio
+    async def test_missing_auth_returns_401(self):
+        inner = _Recorder()
+        sent = await _run(MCPAuthMiddleware(inner), _http_scope(headers=[]))
+        assert inner.called is False
+        assert sent[0]["status"] == 401
+
+    @pytest.mark.asyncio
+    async def test_non_ci_token_returns_401(self):
+        inner = _Recorder()
+        headers = [(b"authorization", b"Bearer notakey")]
+        sent = await _run(MCPAuthMiddleware(inner), _http_scope(headers=headers))
+        assert inner.called is False
+        assert sent[0]["status"] == 401
+
+    @pytest.mark.asyncio
+    async def test_malformed_bearer_no_space_returns_401(self):
+        # "Bearerci_x" (missing space) must be rejected, not silently parsed.
+        inner = _Recorder()
+        headers = [(b"authorization", b"Bearerci_userA")]
+        sent = await _run(MCPAuthMiddleware(inner), _http_scope(headers=headers))
+        assert inner.called is False
+        assert sent[0]["status"] == 401
+
+    @pytest.mark.asyncio
+    async def test_ci_key_is_forwarded_to_context(self):
+        inner = _Recorder()
+        headers = [(b"authorization", b"Bearer ci_userA")]
+        await _run(MCPAuthMiddleware(inner), _http_scope(headers=headers))
+        assert inner.called is True
+        assert inner.captured_key == "ci_userA"
+
+    @pytest.mark.asyncio
+    async def test_admin_token_passes_without_setting_user_key(self, monkeypatch):
+        monkeypatch.setattr(server, "MCP_AUTH_TOKEN", "admin-secret")
+        inner = _Recorder()
+        headers = [(b"authorization", b"Bearer admin-secret")]
+        await _run(MCPAuthMiddleware(inner), _http_scope(headers=headers))
+        assert inner.called is True
+        # Non-widening: admin authenticates the endpoint but sets no user identity.
+        assert inner.captured_key is None
+
+    @pytest.mark.asyncio
+    async def test_lifespan_scope_passes_through(self):
+        inner = _Recorder()
+        await _run(MCPAuthMiddleware(inner), {"type": "lifespan"})
+        assert inner.called is True
+
+    @pytest.mark.asyncio
+    async def test_admin_token_clears_preexisting_user_key(self, monkeypatch):
+        monkeypatch.setattr(server, "MCP_AUTH_TOKEN", "admin-secret")
+        set_request_api_key("ci_stale")
+        inner = _Recorder()
+        headers = [(b"authorization", b"Bearer admin-secret")]
+        await _run(MCPAuthMiddleware(inner), _http_scope(headers=headers))
+        assert inner.captured_key is None
+
+
+class TestKeySuffix:
+    def test_long_key_shows_last_eight_chars(self):
+        token = "ci_" + "a" * 29 + "12345678"
+        assert _key_suffix(token) == "12345678"
+
+    def test_short_token_is_never_echoed_whole(self):
+        # An 8-char token used to be logged in full via token[-8:].
+        assert _key_suffix("12345678") == "78"
+        assert _key_suffix("abc") == "********"
+
+
+class TestExtractBearer:
+    def test_scheme_is_case_insensitive(self):
+        scope = _http_scope(headers=[(b"authorization", b"bearer ci_userA")])
+        assert _extract_bearer(scope) == "ci_userA"
+
+    def test_empty_credentials_are_rejected(self):
+        scope = _http_scope(headers=[(b"authorization", b"Bearer   ")])
+        assert _extract_bearer(scope) is None