|
1 | 1 | """ |
2 | | -Authentication Middleware |
3 | | -Protects routes requiring authentication |
| 2 | +Authentication Middleware for CodeIntel API |
| 3 | +
|
| 4 | +Supports three auth modes: |
| 5 | + 1. JWT tokens (Supabase) - for web UI users |
| 6 | + 2. API keys (ci_xxx) - for MCP/programmatic access |
| 7 | + 3. Public access - for demo endpoints (no auth required) |
| 8 | +
|
| 9 | +Usage: |
| 10 | + from middleware.auth import require_auth, public_auth, AuthContext |
| 11 | +
|
| 12 | + @app.get("/api/repos") |
| 13 | + async def list_repos(auth: AuthContext = Depends(require_auth)): |
| 14 | + user_id = auth.user_id |
| 15 | + ... |
| 16 | +
|
| 17 | + @app.get("/api/demo/search") |
| 18 | + async def demo_search(auth: AuthContext = Depends(public_auth)): |
| 19 | + # Works with or without auth |
| 20 | + ... |
4 | 21 | """ |
| 22 | +from dataclasses import dataclass |
| 23 | +from typing import Optional |
| 24 | +import os |
| 25 | +import hashlib |
| 26 | + |
5 | 27 | from fastapi import Depends, HTTPException, status |
6 | 28 | from fastapi.security import HTTPBearer |
7 | 29 | from fastapi.security.http import HTTPAuthorizationCredentials |
8 | | -from typing import Dict, Any, Optional |
9 | | -from services.auth import get_auth_service |
10 | 30 |
|
11 | | -# HTTP Bearer token scheme |
12 | | -security = HTTPBearer() |
13 | 31 |
|
| 32 | +# --------------------------------------------------------------------------- |
| 33 | +# Auth Context - unified return type for all auth methods |
| 34 | +# --------------------------------------------------------------------------- |
14 | 35 |
|
15 | | -async def get_current_user( |
16 | | - credentials: HTTPAuthorizationCredentials = Depends(security) |
17 | | -) -> Dict[str, Any]: |
18 | | - """ |
19 | | - Dependency to get current authenticated user from JWT token |
| 36 | +@dataclass |
| 37 | +class AuthContext: |
| 38 | + """Authentication context passed to route handlers""" |
| 39 | + user_id: Optional[str] = None # Supabase user ID (JWT auth) |
| 40 | + email: Optional[str] = None # User email (JWT auth) |
| 41 | + api_key_name: Optional[str] = None # API key name (key auth) |
| 42 | + tier: str = "free" # Rate limit tier |
| 43 | + is_public: bool = False # True for unauthenticated demo access |
20 | 44 |
|
21 | | - Usage: |
22 | | - @app.get("/protected") |
23 | | - async def protected_route(user: Dict = Depends(get_current_user)): |
24 | | - return {"user_id": user["user_id"]} |
| 45 | + @property |
| 46 | + def is_authenticated(self) -> bool: |
| 47 | + return not self.is_public |
25 | 48 |
|
26 | | - Returns: |
27 | | - Dict with user_id, email, and metadata |
| 49 | + @property |
| 50 | + def identifier(self) -> str: |
| 51 | + """Unique ID for rate limiting""" |
| 52 | + return self.user_id or self.api_key_name or "anonymous" |
| 53 | + |
| 54 | + |
| 55 | +# --------------------------------------------------------------------------- |
| 56 | +# Bearer token scheme (auto_error=False allows optional auth) |
| 57 | +# --------------------------------------------------------------------------- |
| 58 | + |
| 59 | +_bearer = HTTPBearer(auto_error=False) |
| 60 | +_bearer_required = HTTPBearer(auto_error=True) |
| 61 | + |
| 62 | + |
| 63 | +# --------------------------------------------------------------------------- |
| 64 | +# Core validation functions |
| 65 | +# --------------------------------------------------------------------------- |
| 66 | + |
| 67 | +def _validate_jwt(token: str) -> Optional[AuthContext]: |
| 68 | + """Validate Supabase JWT token""" |
| 69 | + try: |
| 70 | + from services.auth import get_auth_service |
| 71 | + auth_service = get_auth_service() |
| 72 | + user = auth_service.verify_jwt(token) |
| 73 | + |
| 74 | + return AuthContext( |
| 75 | + user_id=user["user_id"], |
| 76 | + email=user.get("email"), |
| 77 | + tier=user.get("metadata", {}).get("tier", "free") |
| 78 | + ) |
| 79 | + except Exception: |
| 80 | + return None |
| 81 | + |
| 82 | + |
| 83 | +def _validate_api_key(token: str) -> Optional[AuthContext]: |
| 84 | + """Validate API key (ci_xxx format)""" |
| 85 | + # Dev key for local development |
| 86 | + dev_key = os.getenv("API_KEY", "dev-secret-key") |
| 87 | + if token == dev_key and os.getenv("DEBUG", "false").lower() == "true": |
| 88 | + return AuthContext( |
| 89 | + api_key_name="development", |
| 90 | + tier="enterprise" |
| 91 | + ) |
| 92 | + |
| 93 | + # Production API keys start with ci_ |
| 94 | + if not token.startswith("ci_"): |
| 95 | + return None |
| 96 | + |
| 97 | + try: |
| 98 | + from services.supabase_service import get_supabase_service |
| 99 | + db = get_supabase_service().client |
| 100 | + |
| 101 | + key_hash = hashlib.sha256(token.encode()).hexdigest() |
| 102 | + result = db.table("api_keys").select("*").eq("key_hash", key_hash).eq("active", True).execute() |
| 103 | + |
| 104 | + if not result.data: |
| 105 | + return None |
28 | 106 |
|
29 | | - Raises: |
30 | | - HTTPException: 401 if token invalid |
| 107 | + key_data = result.data[0] |
| 108 | + return AuthContext( |
| 109 | + api_key_name=key_data.get("name"), |
| 110 | + user_id=key_data.get("user_id"), |
| 111 | + tier=key_data.get("tier", "free") |
| 112 | + ) |
| 113 | + except Exception: |
| 114 | + return None |
| 115 | + |
| 116 | + |
| 117 | +def _authenticate(token: str) -> AuthContext: |
| 118 | + """Try JWT first, then API key""" |
| 119 | + # Try JWT (Supabase tokens) |
| 120 | + ctx = _validate_jwt(token) |
| 121 | + if ctx: |
| 122 | + return ctx |
| 123 | + |
| 124 | + # Try API key |
| 125 | + ctx = _validate_api_key(token) |
| 126 | + if ctx: |
| 127 | + return ctx |
| 128 | + |
| 129 | + # Neither worked |
| 130 | + raise HTTPException( |
| 131 | + status_code=status.HTTP_401_UNAUTHORIZED, |
| 132 | + detail="Invalid token or API key", |
| 133 | + headers={"WWW-Authenticate": "Bearer"} |
| 134 | + ) |
| 135 | + |
| 136 | + |
| 137 | +# --------------------------------------------------------------------------- |
| 138 | +# FastAPI Dependencies - use these in your routes |
| 139 | +# --------------------------------------------------------------------------- |
| 140 | + |
| 141 | +async def require_auth( |
| 142 | + credentials: HTTPAuthorizationCredentials = Depends(_bearer_required) |
| 143 | +) -> AuthContext: |
| 144 | + """ |
| 145 | + Require authentication (JWT or API key) |
| 146 | + |
| 147 | + Raises 401 if no valid credentials provided. |
| 148 | + """ |
| 149 | + return _authenticate(credentials.credentials) |
| 150 | + |
| 151 | + |
| 152 | +async def public_auth( |
| 153 | + credentials: Optional[HTTPAuthorizationCredentials] = Depends(_bearer) |
| 154 | +) -> AuthContext: |
| 155 | + """ |
| 156 | + Optional authentication for public/demo routes |
| 157 | + |
| 158 | + Returns authenticated context if valid token provided, |
| 159 | + otherwise returns public context (is_public=True). |
| 160 | + """ |
| 161 | + if not credentials: |
| 162 | + return AuthContext(is_public=True) |
| 163 | + |
| 164 | + try: |
| 165 | + return _authenticate(credentials.credentials) |
| 166 | + except HTTPException: |
| 167 | + # Invalid token on public route = treat as anonymous |
| 168 | + return AuthContext(is_public=True) |
| 169 | + |
| 170 | + |
| 171 | +# --------------------------------------------------------------------------- |
| 172 | +# Legacy functions - kept for backwards compatibility |
| 173 | +# --------------------------------------------------------------------------- |
| 174 | + |
| 175 | +async def get_current_user( |
| 176 | + credentials: HTTPAuthorizationCredentials = Depends(_bearer_required) |
| 177 | +) -> dict: |
| 178 | + """ |
| 179 | + [LEGACY] Get current user from JWT token |
| 180 | + |
| 181 | + Prefer using require_auth() for new code. |
31 | 182 | """ |
| 183 | + from services.auth import get_auth_service |
32 | 184 | auth_service = get_auth_service() |
33 | 185 | return auth_service.verify_jwt(credentials.credentials) |
34 | 186 |
|
35 | 187 |
|
36 | 188 | async def get_optional_user( |
37 | | - credentials: Optional[HTTPAuthorizationCredentials] = Depends(HTTPBearer(auto_error=False)) |
38 | | -) -> Optional[Dict[str, Any]]: |
| 189 | + credentials: Optional[HTTPAuthorizationCredentials] = Depends(_bearer) |
| 190 | +) -> Optional[dict]: |
39 | 191 | """ |
40 | | - Optional authentication - returns None if no token provided |
| 192 | + [LEGACY] Optional JWT authentication |
41 | 193 | |
42 | | - Usage: |
43 | | - @app.get("/optional-auth") |
44 | | - async def route(user: Optional[Dict] = Depends(get_optional_user)): |
45 | | - if user: |
46 | | - return {"message": f"Hello {user['email']}"} |
47 | | - return {"message": "Hello guest"} |
| 194 | + Prefer using public_auth() for new code. |
48 | 195 | """ |
49 | 196 | if not credentials: |
50 | 197 | return None |
51 | 198 |
|
52 | 199 | try: |
| 200 | + from services.auth import get_auth_service |
53 | 201 | auth_service = get_auth_service() |
54 | 202 | return auth_service.verify_jwt(credentials.credentials) |
55 | | - except: |
| 203 | + except Exception: |
56 | 204 | return None |
0 commit comments