Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 52 additions & 17 deletions mcp-server/api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,65 +3,100 @@
Uses a module-level client to avoid creating new TCP connections per tool call.
The client is initialized lazily on first use and reused for all subsequent calls.
Concurrent access is serialized via asyncio.Lock to prevent duplicate clients.

Auth identity is per-request, not per-client. On the remote (http) transport each
MCP request carries the caller's own ci_ key, so we resolve the key at call time
from a ContextVar and send it as a per-request header. The shared client holds NO
Authorization header: baking one in would let concurrent requests race on a single
shared identity, which is a cross-tenant leak. A ContextVar is task-local, so two
in-flight requests can never read each other's key.
"""
import asyncio
import contextvars
from typing import Any, Optional

import httpx

from config import BACKEND_API_URL, API_KEY


# Persistent client reused across all tool calls
# Persistent client reused across all tool calls. Identity-free by design.
_client: Optional[httpx.AsyncClient] = None
_client_lock: asyncio.Lock = asyncio.Lock()

# Per-request caller identity. The MCP server's auth middleware sets this (remote
# transport) before the tool runs; it stays unset on stdio, where the configured
# key is the identity. Task-local, so concurrent requests are isolated.
_request_api_key: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar(
"request_api_key", default=None
)


def _get_headers() -> dict[str, str]:
"""Return Authorization header with the configured API_KEY.
def set_request_api_key(key: Optional[str]) -> None:
"""Record the caller's key for the current request context.

Raises ValueError if API_KEY is empty or unset.
Called by the MCP server's auth middleware on the remote transport.
"""
if not API_KEY:
raise ValueError(
"No API_KEY configured. Set API_KEY in .env or environment."
)
return {"Authorization": f"Bearer {API_KEY}"}
_request_api_key.set(key)


def _resolve_key() -> str:
"""Per-request key if present (remote), else the configured key (stdio/admin)."""
key = _request_api_key.get()
if key:
return key
if API_KEY:
return API_KEY
raise ValueError("No API key available for backend call")


def _auth_header() -> dict[str, str]:
"""Build the Authorization header. The space after 'Bearer' is mandatory (PR #292)."""
return {"Authorization": f"Bearer {_resolve_key()}"}


def _merge_headers(extra: Optional[dict[str, Any]]) -> dict[str, str]:
headers = _auth_header()
if extra:
headers.update(extra)
return headers


async def get_client() -> httpx.AsyncClient:
"""Get or create the persistent HTTP client."""
"""Get or create the persistent, identity-free HTTP client."""
global _client
async with _client_lock:
if _client is None or _client.is_closed:
_client = httpx.AsyncClient(
base_url=BACKEND_API_URL,
timeout=120.0,
headers=_get_headers(),
)
return _client


async def api_get(path: str, **kwargs: Any) -> dict:
"""Make a GET request to the backend API."""
"""Make a GET request to the backend API with the caller's identity."""
client = await get_client()
response = await client.get(path, **kwargs)
headers = _merge_headers(kwargs.pop("headers", None))
response = await client.get(path, headers=headers, **kwargs)
response.raise_for_status()
return response.json()


async def api_post(path: str, json: dict, **kwargs: Any) -> dict:
"""Make a POST request to the backend API."""
"""Make a POST request to the backend API with the caller's identity."""
client = await get_client()
response = await client.post(path, json=json, **kwargs)
headers = _merge_headers(kwargs.pop("headers", None))
response = await client.post(path, json=json, headers=headers, **kwargs)
response.raise_for_status()
return response.json()


async def api_delete(path: str, **kwargs: Any) -> dict:
"""Make a DELETE request to the backend API."""
"""Make a DELETE request to the backend API with the caller's identity."""
client = await get_client()
response = await client.delete(path, **kwargs)
headers = _merge_headers(kwargs.pop("headers", None))
response = await client.delete(path, headers=headers, **kwargs)
response.raise_for_status()
if response.status_code == 204 or not response.content:
return {}
Expand Down
105 changes: 88 additions & 17 deletions mcp-server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,23 @@
python server.py # stdio (default)
python server.py --transport http # streamable HTTP on $PORT
"""
import json
import logging
import sys
from typing import Optional

from mcp.server.fastmcp import FastMCP
import mcp.types as types
from starlette.routing import Route
from starlette.responses import JSONResponse

from api_client import set_request_api_key
from config import SERVER_NAME, SERVER_VERSION, TRANSPORT, HOST, PORT, MCP_AUTH_TOKEN
from tools import get_tool_schemas
from handlers import call_tool

logger = logging.getLogger(__name__)

mcp = FastMCP(
name=SERVER_NAME,
instructions=(
Expand Down Expand Up @@ -62,27 +68,92 @@ async def _health(request):
return JSONResponse({"status": "ok", "server": SERVER_NAME, "version": SERVER_VERSION})


def _get_http_app():
"""Build the Starlette app with health check + MCP endpoint."""
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
def _extract_bearer(scope) -> Optional[str]:
"""Pull the Bearer token from an ASGI scope's headers, or None.

app = mcp.streamable_http_app()
app.routes.insert(0, Route("/health", _health, methods=["GET"]))
Requires the space after 'Bearer' -- a missing space is the PR #292 bug, and a
request carrying it is treated as malformed (rejected), never silently parsed.
"""
for name, value in scope.get("headers", []):
if name == b"authorization":
raw = value.decode("latin-1")
if raw.startswith("Bearer "):
return raw[len("Bearer "):].strip()
return None
return None


def _key_suffix(token: str) -> str:
"""Last 8 chars for safe log correlation. Never log the full key (bug #7)."""
return token[-8:] if len(token) >= 8 else "********"


async def _send_401(send, detail: str) -> None:
body = json.dumps({"error": detail}).encode()
await send({
"type": "http.response.start",
"status": 401,
"headers": [
(b"content-type", b"application/json"),
(b"www-authenticate", b"Bearer"),
],
})
await send({"type": "http.response.body", "body": body})


class MCPAuthMiddleware:
"""Authenticate /mcp with a per-user ci_ key, forwarded to the backend.

Pure ASGI (not BaseHTTPMiddleware) on purpose: the ContextVar set here must
propagate to the tool-handler task, and BaseHTTPMiddleware runs the downstream
app in a separate task that breaks that propagation.

if MCP_AUTH_TOKEN:
class MCPAuthMiddleware(BaseHTTPMiddleware):
"""Require Bearer token on /mcp, leave /health public."""
async def dispatch(self, request: Request, call_next):
if request.url.path == "/health":
return await call_next(request)
auth = request.headers.get("authorization", "")
if not auth.startswith("Bearer ") or auth[7:] != MCP_AUTH_TOKEN:
return JSONResponse({"error": "Unauthorized"}, status_code=401)
return await call_next(request)
Fails closed -- a missing or invalid credential is a 401, never a fallback to a
shared identity for data calls. /health stays public.
"""

app.add_middleware(MCPAuthMiddleware)
def __init__(self, app):
self.app = app

async def __call__(self, scope, receive, send):
if scope["type"] != "http":
await self.app(scope, receive, send)
return

if scope.get("path") == "/health":
await self.app(scope, receive, send)
return

token = _extract_bearer(scope)
if not token:
logger.warning("mcp auth: missing or malformed Authorization header")
await _send_401(send, "Missing or malformed Authorization header")
return

if token.startswith("ci_"):
# Per-user key: carry the caller's own identity to the backend.
set_request_api_key(token)
logger.info("mcp auth: ci_ key accepted (suffix=%s)", _key_suffix(token))
await self.app(scope, receive, send)
return

if MCP_AUTH_TOKEN and token == MCP_AUTH_TOKEN:
# Admin path: authenticates the endpoint only. No user key is set, so
# api_client uses the configured key -- the data scope is not widened.
logger.info("mcp auth: admin token accepted")
await self.app(scope, receive, send)
return

logger.warning("mcp auth: invalid token rejected (suffix=%s)", _key_suffix(token))
await _send_401(send, "Invalid API key")


def _get_http_app():
"""Build the Starlette app: public /health + auth-required /mcp."""
app = mcp.streamable_http_app()
app.routes.insert(0, Route("/health", _health, methods=["GET"]))
# Always enforce: remote /mcp requires a per-user ci_ key (or the admin token).
app.add_middleware(MCPAuthMiddleware)
return app


Expand Down
Loading
Loading