|
16 | 16 | from starlette.routing import Route |
17 | 17 | from starlette.responses import JSONResponse |
18 | 18 |
|
19 | | -from config import SERVER_NAME, SERVER_VERSION, TRANSPORT, HOST, PORT |
| 19 | +from config import SERVER_NAME, SERVER_VERSION, TRANSPORT, HOST, PORT, MCP_AUTH_TOKEN |
20 | 20 | from tools import get_tool_schemas |
21 | 21 | from handlers import call_tool |
22 | 22 |
|
@@ -64,8 +64,25 @@ async def _health(request): |
64 | 64 |
|
65 | 65 | def _get_http_app(): |
66 | 66 | """Build the Starlette app with health check + MCP endpoint.""" |
| 67 | + from starlette.middleware.base import BaseHTTPMiddleware |
| 68 | + from starlette.requests import Request |
| 69 | + |
67 | 70 | app = mcp.streamable_http_app() |
68 | 71 | app.routes.insert(0, Route("/health", _health, methods=["GET"])) |
| 72 | + |
| 73 | + if MCP_AUTH_TOKEN: |
| 74 | + class MCPAuthMiddleware(BaseHTTPMiddleware): |
| 75 | + """Require Bearer token on /mcp, leave /health public.""" |
| 76 | + async def dispatch(self, request: Request, call_next): |
| 77 | + if request.url.path == "/health": |
| 78 | + return await call_next(request) |
| 79 | + auth = request.headers.get("authorization", "") |
| 80 | + if not auth.startswith("Bearer ") or auth[7:] != MCP_AUTH_TOKEN: |
| 81 | + return JSONResponse({"error": "Unauthorized"}, status_code=401) |
| 82 | + return await call_next(request) |
| 83 | + |
| 84 | + app.add_middleware(MCPAuthMiddleware) |
| 85 | + |
69 | 86 | return app |
70 | 87 |
|
71 | 88 |
|
|
0 commit comments