diff --git a/backend/main.py b/backend/main.py index ca72dce..5f220d9 100644 --- a/backend/main.py +++ b/backend/main.py @@ -2,7 +2,7 @@ CodeIntel Backend API FastAPI backend for codebase intelligence """ -from fastapi import FastAPI, HTTPException, Header, WebSocket, WebSocketDisconnect, Depends +from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect, Depends from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel from typing import Optional, List @@ -27,7 +27,7 @@ # Import routers from routes.auth import router as auth_router -from middleware.auth import get_current_user +from middleware.auth import require_auth, AuthContext app = FastAPI( title="CodeIntel API", @@ -83,34 +83,6 @@ async def dispatch(self, request: Request, call_next): api_key_manager = APIKeyManager(get_supabase_service().client) cost_controller = CostController(get_supabase_service().client) -# Development API Key (for local testing only) -DEV_API_KEY = os.getenv("API_KEY", "dev-secret-key") - - -def verify_api_key(authorization: str = Header(None)): - """Verify API key and check rate limits""" - if not authorization or not authorization.startswith("Bearer "): - raise HTTPException(status_code=401, detail="Invalid authorization header") - - token = authorization.replace("Bearer ", "") - - # Allow dev key for local development - if token == DEV_API_KEY and os.getenv("DEBUG", "false").lower() == "true": - return {"key": token, "tier": "enterprise", "user_id": None, "name": "Development"} - - # Verify production API key - key_data = api_key_manager.verify_key(token) - if not key_data: - raise HTTPException(status_code=401, detail="Invalid API key") - - # Check rate limits - allowed, error_msg = rate_limiter.check_rate_limit(token, key_data.get("tier", "free")) - if not allowed: - raise HTTPException(status_code=429, detail=error_msg) - - return key_data - - # Request/Response Models class SearchRequest(BaseModel): query: str @@ -144,9 +116,9 @@ async def health_check(): @app.get("/api/repos") -async def list_repositories(current_user: dict = Depends(get_current_user)): +async def list_repositories(auth: AuthContext = Depends(require_auth)): """List all repositories for authenticated user""" - user_id = current_user["user_id"] + user_id = auth.user_id # TODO: Filter repos by user_id once we add user_id column to repositories table # For now, return all repos (will fix in next section) @@ -157,10 +129,10 @@ async def list_repositories(current_user: dict = Depends(get_current_user)): @app.post("/api/repos") async def add_repository( request: AddRepoRequest, - current_user: dict = Depends(get_current_user) + auth: AuthContext = Depends(require_auth) ): """Add a new repository with validation and cost controls""" - user_id = current_user["user_id"] + user_id = auth.user_id or auth.identifier # Validate repository name valid_name, name_error = InputValidator.validate_repo_name(request.name) @@ -262,10 +234,9 @@ async def progress_callback(files_processed: int, functions_indexed: int, total_ async def index_repository( repo_id: str, incremental: bool = True, - api_key: str = Header(None, alias="Authorization") + auth: AuthContext = Depends(require_auth) ): """Trigger indexing for a repository - automatically uses incremental if possible""" - verify_api_key(api_key) import time import git @@ -322,10 +293,9 @@ async def index_repository( @app.post("/api/search") async def search_code( request: SearchRequest, - api_key: str = Header(None, alias="Authorization") + auth: AuthContext = Depends(require_auth) ): """Search code semantically with caching and validation""" - verify_api_key(api_key) # Validate search query valid_query, query_error = InputValidator.validate_search_query(request.query) @@ -368,10 +338,9 @@ async def search_code( @app.post("/api/explain") async def explain_code( request: ExplainRequest, - api_key: str = Header(None, alias="Authorization") + auth: AuthContext = Depends(require_auth) ): """Generate code explanation""" - verify_api_key(api_key) try: repo = repo_manager.get_repo(request.repo_id) @@ -400,10 +369,9 @@ class ImpactRequest(BaseModel): @app.get("/api/repos/{repo_id}/dependencies") async def get_dependency_graph( repo_id: str, - api_key: str = Header(None, alias="Authorization") + auth: AuthContext = Depends(require_auth) ): """Get dependency graph for repository with Supabase caching""" - verify_api_key(api_key) try: repo = repo_manager.get_repo(repo_id) @@ -434,10 +402,9 @@ async def get_dependency_graph( async def analyze_impact( repo_id: str, request: ImpactRequest, - api_key: str = Header(None, alias="Authorization") + auth: AuthContext = Depends(require_auth) ): """Analyze impact of changing a file with validation and caching""" - verify_api_key(api_key) try: repo = repo_manager.get_repo(repo_id) @@ -474,10 +441,9 @@ async def analyze_impact( @app.get("/api/repos/{repo_id}/insights") async def get_repository_insights( repo_id: str, - api_key: str = Header(None, alias="Authorization") + auth: AuthContext = Depends(require_auth) ): """Get comprehensive insights about repository with Supabase caching""" - verify_api_key(api_key) try: repo = repo_manager.get_repo(repo_id) @@ -517,10 +483,9 @@ class ImpactRequest(BaseModel): @app.get("/api/repos/{repo_id}/style-analysis") async def get_style_analysis( repo_id: str, - api_key: str = Header(None, alias="Authorization") + auth: AuthContext = Depends(require_auth) ): """Analyze code style and team patterns with Supabase caching""" - verify_api_key(api_key) try: repo = repo_manager.get_repo(repo_id) @@ -549,11 +514,9 @@ async def get_style_analysis( @app.get("/api/metrics") async def get_performance_metrics( - api_key: str = Header(None, alias="Authorization") + auth: AuthContext = Depends(require_auth) ): """Get performance metrics and monitoring data""" - verify_api_key(api_key) - return metrics.get_metrics() @@ -567,16 +530,14 @@ class CreateAPIKeyRequest(BaseModel): @app.post("/api/keys/generate") async def generate_api_key( request: CreateAPIKeyRequest, - api_key: str = Header(None, alias="Authorization") + auth: AuthContext = Depends(require_auth) ): """Generate a new API key (requires existing valid key or dev mode)""" - key_data = verify_api_key(api_key) - # Generate new key new_key = api_key_manager.generate_key( name=request.name, tier=request.tier, - user_id=key_data.get("user_id") + user_id=auth.user_id ) return { @@ -589,21 +550,18 @@ async def generate_api_key( @app.get("/api/keys/usage") async def get_api_usage( - api_key: str = Header(None, alias="Authorization") + auth: AuthContext = Depends(require_auth) ): """Get current API usage stats""" - key_data = verify_api_key(api_key) - token = api_key.replace("Bearer ", "") - - usage = rate_limiter.get_usage(token) + usage = rate_limiter.get_usage(auth.identifier) return { - "tier": key_data.get("tier", "free"), + "tier": auth.tier, "limits": { "free": {"minute": 20, "hour": 200, "day": 1000}, "pro": {"minute": 100, "hour": 2000, "day": 20000}, "enterprise": {"minute": 500, "hour": 10000, "day": 100000} - }[key_data.get("tier", "free")], + }[auth.tier], "usage": usage } diff --git a/backend/middleware/__init__.py b/backend/middleware/__init__.py index 743ca02..63a583b 100644 --- a/backend/middleware/__init__.py +++ b/backend/middleware/__init__.py @@ -1,4 +1,18 @@ """Authentication middleware package""" -from .auth import get_current_user, get_optional_user +from .auth import ( + # New unified auth (recommended) + AuthContext, + require_auth, + public_auth, + # Legacy (backwards compatibility) + get_current_user, + get_optional_user, +) -__all__ = ["get_current_user", "get_optional_user"] +__all__ = [ + "AuthContext", + "require_auth", + "public_auth", + "get_current_user", + "get_optional_user", +] diff --git a/backend/middleware/auth.py b/backend/middleware/auth.py index 79d9cb4..0bdb19d 100644 --- a/backend/middleware/auth.py +++ b/backend/middleware/auth.py @@ -1,56 +1,204 @@ """ -Authentication Middleware -Protects routes requiring authentication +Authentication Middleware for CodeIntel API + +Supports three auth modes: + 1. JWT tokens (Supabase) - for web UI users + 2. API keys (ci_xxx) - for MCP/programmatic access + 3. Public access - for demo endpoints (no auth required) + +Usage: + from middleware.auth import require_auth, public_auth, AuthContext + + @app.get("/api/repos") + async def list_repos(auth: AuthContext = Depends(require_auth)): + user_id = auth.user_id + ... + + @app.get("/api/demo/search") + async def demo_search(auth: AuthContext = Depends(public_auth)): + # Works with or without auth + ... """ +from dataclasses import dataclass +from typing import Optional +import os +import hashlib + from fastapi import Depends, HTTPException, status from fastapi.security import HTTPBearer from fastapi.security.http import HTTPAuthorizationCredentials -from typing import Dict, Any, Optional -from services.auth import get_auth_service -# HTTP Bearer token scheme -security = HTTPBearer() +# --------------------------------------------------------------------------- +# Auth Context - unified return type for all auth methods +# --------------------------------------------------------------------------- -async def get_current_user( - credentials: HTTPAuthorizationCredentials = Depends(security) -) -> Dict[str, Any]: - """ - Dependency to get current authenticated user from JWT token +@dataclass +class AuthContext: + """Authentication context passed to route handlers""" + user_id: Optional[str] = None # Supabase user ID (JWT auth) + email: Optional[str] = None # User email (JWT auth) + api_key_name: Optional[str] = None # API key name (key auth) + tier: str = "free" # Rate limit tier + is_public: bool = False # True for unauthenticated demo access - Usage: - @app.get("/protected") - async def protected_route(user: Dict = Depends(get_current_user)): - return {"user_id": user["user_id"]} + @property + def is_authenticated(self) -> bool: + return not self.is_public - Returns: - Dict with user_id, email, and metadata + @property + def identifier(self) -> str: + """Unique ID for rate limiting""" + return self.user_id or self.api_key_name or "anonymous" + + +# --------------------------------------------------------------------------- +# Bearer token scheme (auto_error=False allows optional auth) +# --------------------------------------------------------------------------- + +_bearer = HTTPBearer(auto_error=False) +_bearer_required = HTTPBearer(auto_error=True) + + +# --------------------------------------------------------------------------- +# Core validation functions +# --------------------------------------------------------------------------- + +def _validate_jwt(token: str) -> Optional[AuthContext]: + """Validate Supabase JWT token""" + try: + from services.auth import get_auth_service + auth_service = get_auth_service() + user = auth_service.verify_jwt(token) + + return AuthContext( + user_id=user["user_id"], + email=user.get("email"), + tier=user.get("metadata", {}).get("tier", "free") + ) + except Exception: + return None + + +def _validate_api_key(token: str) -> Optional[AuthContext]: + """Validate API key (ci_xxx format)""" + # Dev key for local development + dev_key = os.getenv("API_KEY", "dev-secret-key") + if token == dev_key and os.getenv("DEBUG", "false").lower() == "true": + return AuthContext( + api_key_name="development", + tier="enterprise" + ) + + # Production API keys start with ci_ + if not token.startswith("ci_"): + return None + + try: + from services.supabase_service import get_supabase_service + db = get_supabase_service().client + + key_hash = hashlib.sha256(token.encode()).hexdigest() + result = db.table("api_keys").select("*").eq("key_hash", key_hash).eq("active", True).execute() + + if not result.data: + return None - Raises: - HTTPException: 401 if token invalid + key_data = result.data[0] + return AuthContext( + api_key_name=key_data.get("name"), + user_id=key_data.get("user_id"), + tier=key_data.get("tier", "free") + ) + except Exception: + return None + + +def _authenticate(token: str) -> AuthContext: + """Try JWT first, then API key""" + # Try JWT (Supabase tokens) + ctx = _validate_jwt(token) + if ctx: + return ctx + + # Try API key + ctx = _validate_api_key(token) + if ctx: + return ctx + + # Neither worked + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid token or API key", + headers={"WWW-Authenticate": "Bearer"} + ) + + +# --------------------------------------------------------------------------- +# FastAPI Dependencies - use these in your routes +# --------------------------------------------------------------------------- + +async def require_auth( + credentials: HTTPAuthorizationCredentials = Depends(_bearer_required) +) -> AuthContext: + """ + Require authentication (JWT or API key) + + Raises 401 if no valid credentials provided. + """ + return _authenticate(credentials.credentials) + + +async def public_auth( + credentials: Optional[HTTPAuthorizationCredentials] = Depends(_bearer) +) -> AuthContext: + """ + Optional authentication for public/demo routes + + Returns authenticated context if valid token provided, + otherwise returns public context (is_public=True). + """ + if not credentials: + return AuthContext(is_public=True) + + try: + return _authenticate(credentials.credentials) + except HTTPException: + # Invalid token on public route = treat as anonymous + return AuthContext(is_public=True) + + +# --------------------------------------------------------------------------- +# Legacy functions - kept for backwards compatibility +# --------------------------------------------------------------------------- + +async def get_current_user( + credentials: HTTPAuthorizationCredentials = Depends(_bearer_required) +) -> dict: + """ + [LEGACY] Get current user from JWT token + + Prefer using require_auth() for new code. """ + from services.auth import get_auth_service auth_service = get_auth_service() return auth_service.verify_jwt(credentials.credentials) async def get_optional_user( - credentials: Optional[HTTPAuthorizationCredentials] = Depends(HTTPBearer(auto_error=False)) -) -> Optional[Dict[str, Any]]: + credentials: Optional[HTTPAuthorizationCredentials] = Depends(_bearer) +) -> Optional[dict]: """ - Optional authentication - returns None if no token provided + [LEGACY] Optional JWT authentication - Usage: - @app.get("/optional-auth") - async def route(user: Optional[Dict] = Depends(get_optional_user)): - if user: - return {"message": f"Hello {user['email']}"} - return {"message": "Hello guest"} + Prefer using public_auth() for new code. """ if not credentials: return None try: + from services.auth import get_auth_service auth_service = get_auth_service() return auth_service.verify_jwt(credentials.credentials) - except: + except Exception: return None diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py index 91ead49..0c39329 100644 --- a/backend/tests/conftest.py +++ b/backend/tests/conftest.py @@ -68,7 +68,7 @@ def mock_supabase(): client = MagicMock() table = MagicMock() - # Mock the fluent interface + # Mock the fluent interface for tables execute_result = MagicMock() execute_result.data = [] execute_result.count = 0 @@ -84,6 +84,13 @@ def mock_supabase(): table.execute.return_value = execute_result client.table.return_value = table + + # Mock auth.get_user to reject invalid tokens + # By default, return response with user=None (invalid token) + auth_response = MagicMock() + auth_response.user = None + client.auth.get_user.return_value = auth_response + mock.return_value = client yield mock