Skip to content

Commit ceebd33

Browse files
committed
feat: forward per-user ci_ keys at /mcp, scoped to the owner (#323)
The remote /mcp endpoint authenticated every caller as one shared identity (MCP_API_KEY). Now each request carries the caller's own ci_ key, resolved per-request from a task-local ContextVar and sent as a per-request header; the backend's existing validator scopes results to that owner. Chose forwarding over re-validating in the MCP service: keeps auth in one authority (auth.py untouched, off-limits), adds zero DB round-trips, zero schema change. Chose a ContextVar set in pure-ASGI middleware over mutating the shared httpx client's headers: the latter races under concurrency and leaks tenant data; the ContextVar is task-local so it cannot. Pure ASGI (not BaseHTTPMiddleware) because the latter breaks contextvar propagation to the handler task. Fails closed: missing/invalid credential -> 401, no fallback to a shared identity for data calls. MCP_AUTH_TOKEN retained as a non-widening admin path. Latency-neutral; same single backend validation as before.
1 parent 42d6b1c commit ceebd33

3 files changed

Lines changed: 334 additions & 34 deletions

File tree

mcp-server/api_client.py

Lines changed: 52 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3,65 +3,100 @@
33
Uses a module-level client to avoid creating new TCP connections per tool call.
44
The client is initialized lazily on first use and reused for all subsequent calls.
55
Concurrent access is serialized via asyncio.Lock to prevent duplicate clients.
6+
7+
Auth identity is per-request, not per-client. On the remote (http) transport each
8+
MCP request carries the caller's own ci_ key, so we resolve the key at call time
9+
from a ContextVar and send it as a per-request header. The shared client holds NO
10+
Authorization header: baking one in would let concurrent requests race on a single
11+
shared identity, which is a cross-tenant leak. A ContextVar is task-local, so two
12+
in-flight requests can never read each other's key.
613
"""
714
import asyncio
15+
import contextvars
816
from typing import Any, Optional
917

1018
import httpx
1119

1220
from config import BACKEND_API_URL, API_KEY
1321

1422

15-
# Persistent client reused across all tool calls
23+
# Persistent client reused across all tool calls. Identity-free by design.
1624
_client: Optional[httpx.AsyncClient] = None
1725
_client_lock: asyncio.Lock = asyncio.Lock()
1826

27+
# Per-request caller identity. The MCP server's auth middleware sets this (remote
28+
# transport) before the tool runs; it stays unset on stdio, where the configured
29+
# key is the identity. Task-local, so concurrent requests are isolated.
30+
_request_api_key: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar(
31+
"request_api_key", default=None
32+
)
33+
1934

20-
def _get_headers() -> dict[str, str]:
21-
"""Return Authorization header with the configured API_KEY.
35+
def set_request_api_key(key: Optional[str]) -> None:
36+
"""Record the caller's key for the current request context.
2237
23-
Raises ValueError if API_KEY is empty or unset.
38+
Called by the MCP server's auth middleware on the remote transport.
2439
"""
25-
if not API_KEY:
26-
raise ValueError(
27-
"No API_KEY configured. Set API_KEY in .env or environment."
28-
)
29-
return {"Authorization": f"Bearer {API_KEY}"}
40+
_request_api_key.set(key)
41+
42+
43+
def _resolve_key() -> str:
44+
"""Per-request key if present (remote), else the configured key (stdio/admin)."""
45+
key = _request_api_key.get()
46+
if key:
47+
return key
48+
if API_KEY:
49+
return API_KEY
50+
raise ValueError("No API key available for backend call")
51+
52+
53+
def _auth_header() -> dict[str, str]:
54+
"""Build the Authorization header. The space after 'Bearer' is mandatory (PR #292)."""
55+
return {"Authorization": f"Bearer {_resolve_key()}"}
56+
57+
58+
def _merge_headers(extra: Optional[dict[str, Any]]) -> dict[str, str]:
59+
headers = _auth_header()
60+
if extra:
61+
headers.update(extra)
62+
return headers
3063

3164

3265
async def get_client() -> httpx.AsyncClient:
33-
"""Get or create the persistent HTTP client."""
66+
"""Get or create the persistent, identity-free HTTP client."""
3467
global _client
3568
async with _client_lock:
3669
if _client is None or _client.is_closed:
3770
_client = httpx.AsyncClient(
3871
base_url=BACKEND_API_URL,
3972
timeout=120.0,
40-
headers=_get_headers(),
4173
)
4274
return _client
4375

4476

4577
async def api_get(path: str, **kwargs: Any) -> dict:
46-
"""Make a GET request to the backend API."""
78+
"""Make a GET request to the backend API with the caller's identity."""
4779
client = await get_client()
48-
response = await client.get(path, **kwargs)
80+
headers = _merge_headers(kwargs.pop("headers", None))
81+
response = await client.get(path, headers=headers, **kwargs)
4982
response.raise_for_status()
5083
return response.json()
5184

5285

5386
async def api_post(path: str, json: dict, **kwargs: Any) -> dict:
54-
"""Make a POST request to the backend API."""
87+
"""Make a POST request to the backend API with the caller's identity."""
5588
client = await get_client()
56-
response = await client.post(path, json=json, **kwargs)
89+
headers = _merge_headers(kwargs.pop("headers", None))
90+
response = await client.post(path, json=json, headers=headers, **kwargs)
5791
response.raise_for_status()
5892
return response.json()
5993

6094

6195
async def api_delete(path: str, **kwargs: Any) -> dict:
62-
"""Make a DELETE request to the backend API."""
96+
"""Make a DELETE request to the backend API with the caller's identity."""
6397
client = await get_client()
64-
response = await client.delete(path, **kwargs)
98+
headers = _merge_headers(kwargs.pop("headers", None))
99+
response = await client.delete(path, headers=headers, **kwargs)
65100
response.raise_for_status()
66101
if response.status_code == 204 or not response.content:
67102
return {}

mcp-server/server.py

Lines changed: 88 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,23 @@
99
python server.py # stdio (default)
1010
python server.py --transport http # streamable HTTP on $PORT
1111
"""
12+
import json
13+
import logging
1214
import sys
15+
from typing import Optional
1316

1417
from mcp.server.fastmcp import FastMCP
1518
import mcp.types as types
1619
from starlette.routing import Route
1720
from starlette.responses import JSONResponse
1821

22+
from api_client import set_request_api_key
1923
from config import SERVER_NAME, SERVER_VERSION, TRANSPORT, HOST, PORT, MCP_AUTH_TOKEN
2024
from tools import get_tool_schemas
2125
from handlers import call_tool
2226

27+
logger = logging.getLogger(__name__)
28+
2329
mcp = FastMCP(
2430
name=SERVER_NAME,
2531
instructions=(
@@ -62,27 +68,92 @@ async def _health(request):
6268
return JSONResponse({"status": "ok", "server": SERVER_NAME, "version": SERVER_VERSION})
6369

6470

65-
def _get_http_app():
66-
"""Build the Starlette app with health check + MCP endpoint."""
67-
from starlette.middleware.base import BaseHTTPMiddleware
68-
from starlette.requests import Request
71+
def _extract_bearer(scope) -> Optional[str]:
72+
"""Pull the Bearer token from an ASGI scope's headers, or None.
6973
70-
app = mcp.streamable_http_app()
71-
app.routes.insert(0, Route("/health", _health, methods=["GET"]))
74+
Requires the space after 'Bearer' -- a missing space is the PR #292 bug, and a
75+
request carrying it is treated as malformed (rejected), never silently parsed.
76+
"""
77+
for name, value in scope.get("headers", []):
78+
if name == b"authorization":
79+
raw = value.decode("latin-1")
80+
if raw.startswith("Bearer "):
81+
return raw[len("Bearer "):].strip()
82+
return None
83+
return None
84+
85+
86+
def _key_suffix(token: str) -> str:
87+
"""Last 8 chars for safe log correlation. Never log the full key (bug #7)."""
88+
return token[-8:] if len(token) >= 8 else "********"
89+
90+
91+
async def _send_401(send, detail: str) -> None:
92+
body = json.dumps({"error": detail}).encode()
93+
await send({
94+
"type": "http.response.start",
95+
"status": 401,
96+
"headers": [
97+
(b"content-type", b"application/json"),
98+
(b"www-authenticate", b"Bearer"),
99+
],
100+
})
101+
await send({"type": "http.response.body", "body": body})
102+
103+
104+
class MCPAuthMiddleware:
105+
"""Authenticate /mcp with a per-user ci_ key, forwarded to the backend.
106+
107+
Pure ASGI (not BaseHTTPMiddleware) on purpose: the ContextVar set here must
108+
propagate to the tool-handler task, and BaseHTTPMiddleware runs the downstream
109+
app in a separate task that breaks that propagation.
72110
73-
if MCP_AUTH_TOKEN:
74-
class MCPAuthMiddleware(BaseHTTPMiddleware):
75-
"""Require Bearer token on /mcp, leave /health public."""
76-
async def dispatch(self, request: Request, call_next):
77-
if request.url.path == "/health":
78-
return await call_next(request)
79-
auth = request.headers.get("authorization", "")
80-
if not auth.startswith("Bearer ") or auth[7:] != MCP_AUTH_TOKEN:
81-
return JSONResponse({"error": "Unauthorized"}, status_code=401)
82-
return await call_next(request)
111+
Fails closed -- a missing or invalid credential is a 401, never a fallback to a
112+
shared identity for data calls. /health stays public.
113+
"""
83114

84-
app.add_middleware(MCPAuthMiddleware)
115+
def __init__(self, app):
116+
self.app = app
85117

118+
async def __call__(self, scope, receive, send):
119+
if scope["type"] != "http":
120+
await self.app(scope, receive, send)
121+
return
122+
123+
if scope.get("path") == "/health":
124+
await self.app(scope, receive, send)
125+
return
126+
127+
token = _extract_bearer(scope)
128+
if not token:
129+
logger.warning("mcp auth: missing or malformed Authorization header")
130+
await _send_401(send, "Missing or malformed Authorization header")
131+
return
132+
133+
if token.startswith("ci_"):
134+
# Per-user key: carry the caller's own identity to the backend.
135+
set_request_api_key(token)
136+
logger.info("mcp auth: ci_ key accepted (suffix=%s)", _key_suffix(token))
137+
await self.app(scope, receive, send)
138+
return
139+
140+
if MCP_AUTH_TOKEN and token == MCP_AUTH_TOKEN:
141+
# Admin path: authenticates the endpoint only. No user key is set, so
142+
# api_client uses the configured key -- the data scope is not widened.
143+
logger.info("mcp auth: admin token accepted")
144+
await self.app(scope, receive, send)
145+
return
146+
147+
logger.warning("mcp auth: invalid token rejected (suffix=%s)", _key_suffix(token))
148+
await _send_401(send, "Invalid API key")
149+
150+
151+
def _get_http_app():
152+
"""Build the Starlette app: public /health + auth-required /mcp."""
153+
app = mcp.streamable_http_app()
154+
app.routes.insert(0, Route("/health", _health, methods=["GET"]))
155+
# Always enforce: remote /mcp requires a per-user ci_ key (or the admin token).
156+
app.add_middleware(MCPAuthMiddleware)
86157
return app
87158

88159

0 commit comments

Comments
 (0)